Skip to content

Commit 6ab660d

Browse files
authored
Optimizations for diagonal HermOrSym (#48189)
1 parent 45b7e7a commit 6ab660d

File tree

7 files changed

+55
-10
lines changed

7 files changed

+55
-10
lines changed

stdlib/LinearAlgebra/src/hessenberg.jl

+5
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,11 @@ real(H::UpperHessenberg{<:Real}) = H
7070
real(H::UpperHessenberg{<:Complex}) = UpperHessenberg(triu!(real(H.data),-1))
7171
imag(H::UpperHessenberg) = UpperHessenberg(triu!(imag(H.data),-1))
7272

73+
function istriu(A::UpperHessenberg, k::Integer=0)
74+
k <= -1 && return true
75+
return _istriu(A, k)
76+
end
77+
7378
function Matrix{T}(H::UpperHessenberg) where T
7479
m,n = size(H)
7580
return triu!(copyto!(Matrix{T}(undef, m, n), H.data), -1)

stdlib/LinearAlgebra/src/special.jl

+4
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,10 @@ end
324324
zero(D::Diagonal) = Diagonal(zero.(D.diag))
325325
oneunit(D::Diagonal) = Diagonal(oneunit.(D.diag))
326326

327+
isdiag(A::HermOrSym{<:Any,<:Diagonal}) = isdiag(parent(A))
328+
dot(x::AbstractVector, A::RealHermSymComplexSym{<:Real,<:Diagonal}, y::AbstractVector) =
329+
dot(x, A.data, y)
330+
327331
# equals and approx equals methods for structured matrices
328332
# SymTridiagonal == Tridiagonal is already defined in tridiag.jl
329333

stdlib/LinearAlgebra/src/symmetric.jl

+11-4
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,8 @@ end
241241
diag(A::Symmetric) = symmetric.(diag(parent(A)), sym_uplo(A.uplo))
242242
diag(A::Hermitian) = hermitian.(diag(parent(A)), sym_uplo(A.uplo))
243243

244+
isdiag(A::HermOrSym) = isdiag(A.uplo == 'U' ? UpperTriangular(A.data) : LowerTriangular(A.data))
245+
244246
# For A<:Union{Symmetric,Hermitian}, similar(A[, neweltype]) should yield a matrix with the same
245247
# symmetry type, uplo flag, and underlying storage type as A. The following methods cover these cases.
246248
similar(A::Symmetric, ::Type{T}) where {T} = Symmetric(similar(parent(A), T), ifelse(A.uplo == 'U', :U, :L))
@@ -316,6 +318,7 @@ function fillstored!(A::HermOrSym{T}, x) where T
316318
return A
317319
end
318320

321+
Base.isreal(A::HermOrSym{<:Real}) = true
319322
function Base.isreal(A::HermOrSym)
320323
n = size(A, 1)
321324
@inbounds if A.uplo == 'U'
@@ -578,9 +581,11 @@ end
578581

579582
function dot(x::AbstractVector, A::RealHermSymComplexHerm, y::AbstractVector)
580583
require_one_based_indexing(x, y)
581-
(length(x) == length(y) == size(A, 1)) || throw(DimensionMismatch())
584+
n = length(x)
585+
(n == length(y) == size(A, 1)) || throw(DimensionMismatch())
582586
data = A.data
583-
r = zero(eltype(x)) * zero(eltype(A)) * zero(eltype(y))
587+
r = dot(zero(eltype(x)), zero(eltype(A)), zero(eltype(y)))
588+
iszero(n) && return r
584589
if A.uplo == 'U'
585590
@inbounds for j = 1:length(y)
586591
r += dot(x[j], real(data[j,j]), y[j])
@@ -612,7 +617,9 @@ end
612617
factorize(A::HermOrSym) = _factorize(A)
613618
function _factorize(A::HermOrSym{T}; check::Bool=true) where T
614619
TT = typeof(sqrt(oneunit(T)))
615-
if TT <: BlasFloat
620+
if isdiag(A)
621+
return Diagonal(A)
622+
elseif TT <: BlasFloat
616623
return bunchkaufman(A; check=check)
617624
else # fallback
618625
return lu(A; check=check)
@@ -626,7 +633,7 @@ det(A::Symmetric) = det(_factorize(A; check=false))
626633
\(A::HermOrSym, B::AbstractVector) = \(factorize(A), B)
627634
# Bunch-Kaufman solves can not utilize BLAS-3 for multiple right hand sides
628635
# so using LU is faster for AbstractMatrix right hand side
629-
\(A::HermOrSym, B::AbstractMatrix) = \(lu(A), B)
636+
\(A::HermOrSym, B::AbstractMatrix) = \(isdiag(A) ? Diagonal(A) : lu(A), B)
630637

631638
function _inv(A::HermOrSym)
632639
n = checksquare(A)

stdlib/LinearAlgebra/src/triangular.jl

+4-4
Original file line numberDiff line numberDiff line change
@@ -297,10 +297,10 @@ function istriu(A::Union{UpperTriangular,UnitUpperTriangular}, k::Integer=0)
297297
k <= 0 && return true
298298
return _istriu(A, k)
299299
end
300-
istril(A::Adjoint) = istriu(A.parent)
301-
istril(A::Transpose) = istriu(A.parent)
302-
istriu(A::Adjoint) = istril(A.parent)
303-
istriu(A::Transpose) = istril(A.parent)
300+
istril(A::Adjoint, k::Integer=0) = istriu(A.parent, -k)
301+
istril(A::Transpose, k::Integer=0) = istriu(A.parent, -k)
302+
istriu(A::Adjoint, k::Integer=0) = istril(A.parent, -k)
303+
istriu(A::Transpose, k::Integer=0) = istril(A.parent, -k)
304304

305305
function tril!(A::UpperTriangular, k::Integer=0)
306306
n = size(A,1)

stdlib/LinearAlgebra/test/hessenberg.jl

+5
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,11 @@ let n = 10
2424
A = Areal
2525
H = UpperHessenberg(A)
2626
AH = triu(A,-1)
27+
for k in -2:2
28+
@test istril(H, k) == istril(AH, k)
29+
@test istriu(H, k) == istriu(AH, k)
30+
@test (k <= -1 ? istriu(H, k) : !istriu(H, k))
31+
end
2732
@test UpperHessenberg(H) === H
2833
@test parent(H) === A
2934
@test Matrix(H) == Array(H) == H == AH

stdlib/LinearAlgebra/test/symmetric.jl

+16-2
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,16 @@ end
7676
end
7777
@testset "diag" begin
7878
D = Diagonal(x)
79-
@test diag(Symmetric(D, :U))::Vector == x
80-
@test diag(Hermitian(D, :U))::Vector == real(x)
79+
DM = Matrix(D)
80+
B = diagm(-1 => x, 1 => x)
81+
for uplo in (:U, :L)
82+
@test diag(Symmetric(D, uplo))::Vector == x
83+
@test diag(Hermitian(D, uplo))::Vector == real(x)
84+
@test isdiag(Symmetric(DM, uplo))
85+
@test isdiag(Hermitian(DM, uplo))
86+
@test !isdiag(Symmetric(B, uplo))
87+
@test !isdiag(Hermitian(B, uplo))
88+
end
8189
end
8290
@testset "similar" begin
8391
@test isa(similar(Symmetric(asym)), Symmetric{eltya})
@@ -394,13 +402,19 @@ end
394402
@test Hermitian(aherm)\b aherm\b
395403
@test Symmetric(asym)\x asym\x
396404
@test Symmetric(asym)\b asym\b
405+
@test Hermitian(Diagonal(aherm))\x Diagonal(aherm)\x
406+
@test Hermitian(Matrix(Diagonal(aherm)))\b Diagonal(aherm)\b
407+
@test Symmetric(Diagonal(asym))\x Diagonal(asym)\x
408+
@test Symmetric(Matrix(Diagonal(asym)))\b Diagonal(asym)\b
397409
end
398410
end
399411
@testset "generalized dot product" begin
400412
for uplo in (:U, :L)
401413
@test dot(x, Hermitian(aherm, uplo), y) dot(x, Hermitian(aherm, uplo)*y) dot(x, Matrix(Hermitian(aherm, uplo)), y)
402414
@test dot(x, Hermitian(aherm, uplo), x) dot(x, Hermitian(aherm, uplo)*x) dot(x, Matrix(Hermitian(aherm, uplo)), x)
403415
end
416+
@test dot(x, Hermitian(Diagonal(a)), y) dot(x, Hermitian(Diagonal(a))*y) dot(x, Matrix(Hermitian(Diagonal(a))), y)
417+
@test dot(x, Hermitian(Diagonal(a)), x) dot(x, Hermitian(Diagonal(a))*x) dot(x, Matrix(Hermitian(Diagonal(a))), x)
404418
if eltya <: Real
405419
for uplo in (:U, :L)
406420
@test dot(x, Symmetric(aherm, uplo), y) dot(x, Symmetric(aherm, uplo)*y) dot(x, Matrix(Symmetric(aherm, uplo)), y)

stdlib/LinearAlgebra/test/triangular.jl

+10
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,16 @@ for elty1 in (Float32, Float64, BigFloat, ComplexF32, ComplexF64, Complex{BigFlo
119119
@test !istriu(A1')
120120
@test !istriu(transpose(A1))
121121
end
122+
M = copy(parent(A1))
123+
for trans in (adjoint, transpose), k in -1:1
124+
triu!(M, k)
125+
@test istril(trans(M), -k) == istril(copy(trans(M)), -k) == true
126+
end
127+
M = copy(parent(A1))
128+
for trans in (adjoint, transpose), k in 1:-1:-1
129+
tril!(M, k)
130+
@test istriu(trans(M), -k) == istriu(copy(trans(M)), -k) == true
131+
end
122132

123133
#tril/triu
124134
if uplo1 === :L

0 commit comments

Comments
 (0)