Skip to content

Commit 9fd4087

Browse files
authored
Generalize 2-arg eigen towards AbstractMatrix (#46797)
1 parent 9118373 commit 9fd4087

File tree

5 files changed

+128
-147
lines changed

5 files changed

+128
-147
lines changed

stdlib/LinearAlgebra/src/bunchkaufman.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ julia> S.L*S.D*S.L' - A[S.p, S.p]
197197
```
198198
"""
199199
bunchkaufman(A::AbstractMatrix{T}, rook::Bool=false; check::Bool = true) where {T} =
200-
bunchkaufman!(copymutable_oftype(A, typeof(sqrt(oneunit(T)))), rook; check = check)
200+
bunchkaufman!(eigencopy_oftype(A, typeof(sqrt(oneunit(T)))), rook; check = check)
201201

202202
BunchKaufman{T}(B::BunchKaufman) where {T} =
203203
BunchKaufman(convert(Matrix{T}, B.LD), B.ipiv, B.uplo, B.symmetric, B.rook, B.info)

stdlib/LinearAlgebra/src/diagonal.jl

+24-1
Original file line numberDiff line numberDiff line change
@@ -752,7 +752,7 @@ function eigen(D::Diagonal; permute::Bool=true, scale::Bool=true, sortby::Union{
752752
λ = eigvals(D)
753753
if !isnothing(sortby)
754754
p = sortperm(λ; alg=QuickSort, by=sortby)
755-
λ = λ[p] # make a copy, otherwise this permutes D.diag
755+
λ = λ[p]
756756
evecs = zeros(Td, size(D))
757757
@inbounds for i in eachindex(p)
758758
evecs[p[i],i] = one(Td)
@@ -762,6 +762,29 @@ function eigen(D::Diagonal; permute::Bool=true, scale::Bool=true, sortby::Union{
762762
end
763763
Eigen(λ, evecs)
764764
end
765+
function eigen(Da::Diagonal, Db::Diagonal; sortby::Union{Function,Nothing}=nothing)
766+
if any(!isfinite, Da.diag) || any(!isfinite, Db.diag)
767+
throw(ArgumentError("matrices contain Infs or NaNs"))
768+
end
769+
if any(iszero, Db.diag)
770+
throw(ArgumentError("right-hand side diagonal matrix is singular"))
771+
end
772+
return GeneralizedEigen(eigen(Db \ Da; sortby)...)
773+
end
774+
function eigen(A::AbstractMatrix, D::Diagonal; sortby::Union{Function,Nothing}=nothing)
775+
if any(iszero, D.diag)
776+
throw(ArgumentError("right-hand side diagonal matrix is singular"))
777+
end
778+
if size(A, 1) == size(A, 2) && isdiag(A)
779+
return eigen(Diagonal(A), D; sortby)
780+
elseif ishermitian(A)
781+
S = promote_type(eigtype(eltype(A)), eltype(D))
782+
return eigen!(eigencopy_oftype(Hermitian(A), S), Diagonal{S}(D); sortby)
783+
else
784+
S = promote_type(eigtype(eltype(A)), eltype(D))
785+
return eigen!(eigencopy_oftype(A, S), Diagonal{S}(D); sortby)
786+
end
787+
end
765788

766789
#Singular system
767790
svdvals(D::Diagonal{<:Number}) = sort!(abs.(D.diag), rev = true)

stdlib/LinearAlgebra/src/eigen.jl

+26-14
Original file line numberDiff line numberDiff line change
@@ -233,16 +233,20 @@ true
233233
```
234234
"""
235235
function eigen(A::AbstractMatrix{T}; permute::Bool=true, scale::Bool=true, sortby::Union{Function,Nothing}=eigsortby) where T
236-
AA = copymutable_oftype(A, eigtype(T))
237-
isdiag(AA) && return eigen(Diagonal(AA); permute=permute, scale=scale, sortby=sortby)
238-
return eigen!(AA; permute=permute, scale=scale, sortby=sortby)
236+
isdiag(A) && return eigen(Diagonal{eigtype(T)}(diag(A)); sortby)
237+
ishermitian(A) && return eigen!(eigencopy_oftype(Hermitian(A), eigtype(T)); sortby)
238+
AA = eigencopy_oftype(A, eigtype(T))
239+
return eigen!(AA; permute, scale, sortby)
239240
end
240241
function eigen(A::AbstractMatrix{T}; permute::Bool=true, scale::Bool=true, sortby::Union{Function,Nothing}=eigsortby) where {T <: Union{Float16,Complex{Float16}}}
241-
AA = copymutable_oftype(A, eigtype(T))
242-
isdiag(AA) && return eigen(Diagonal(AA); permute=permute, scale=scale, sortby=sortby)
243-
A = eigen!(AA; permute, scale, sortby)
244-
values = convert(AbstractVector{isreal(A.values) ? Float16 : Complex{Float16}}, A.values)
245-
vectors = convert(AbstractMatrix{isreal(A.vectors) ? Float16 : Complex{Float16}}, A.vectors)
242+
isdiag(A) && return eigen(Diagonal{eigtype(T)}(diag(A)); sortby)
243+
E = if ishermitian(A)
244+
eigen!(eigencopy_oftype(Hermitian(A), eigtype(T)); sortby)
245+
else
246+
eigen!(eigencopy_oftype(A, eigtype(T)); permute, scale, sortby)
247+
end
248+
values = convert(AbstractVector{isreal(E.values) ? Float16 : Complex{Float16}}, E.values)
249+
vectors = convert(AbstractMatrix{isreal(E.vectors) ? Float16 : Complex{Float16}}, E.vectors)
246250
return Eigen(values, vectors)
247251
end
248252
eigen(x::Number) = Eigen([x], fill(one(x), 1, 1))
@@ -333,7 +337,7 @@ julia> eigvals(diag_matrix)
333337
```
334338
"""
335339
eigvals(A::AbstractMatrix{T}; kws...) where T =
336-
eigvals!(copymutable_oftype(A, eigtype(T)); kws...)
340+
eigvals!(eigencopy_oftype(A, eigtype(T)); kws...)
337341

338342
"""
339343
For a scalar input, `eigvals` will return a scalar.
@@ -507,12 +511,20 @@ true
507511
```
508512
"""
509513
function eigen(A::AbstractMatrix{TA}, B::AbstractMatrix{TB}; kws...) where {TA,TB}
510-
S = promote_type(eigtype(TA),TB)
511-
eigen!(copymutable_oftype(A, S), copymutable_oftype(B, S); kws...)
514+
S = promote_type(eigtype(TA), TB)
515+
eigen!(eigencopy_oftype(A, S), eigencopy_oftype(B, S); kws...)
512516
end
513-
514517
eigen(A::Number, B::Number) = eigen(fill(A,1,1), fill(B,1,1))
515518

519+
"""
520+
LinearAlgebra.eigencopy_oftype(A::AbstractMatrix, ::Type{S})
521+
522+
Creates a dense copy of `A` with eltype `S` by calling `copy_similar(A, S)`.
523+
In the case of `Hermitian` or `Symmetric` matrices additionally retains the wrapper,
524+
together with the `uplo` field.
525+
"""
526+
eigencopy_oftype(A, S) = copy_similar(A, S)
527+
516528
"""
517529
eigvals!(A, B; sortby) -> values
518530
@@ -586,8 +598,8 @@ julia> eigvals(A,B)
586598
```
587599
"""
588600
function eigvals(A::AbstractMatrix{TA}, B::AbstractMatrix{TB}; kws...) where {TA,TB}
589-
S = promote_type(eigtype(TA),TB)
590-
return eigvals!(copymutable_oftype(A, S), copymutable_oftype(B, S); kws...)
601+
S = promote_type(eigtype(TA), TB)
602+
return eigvals!(eigencopy_oftype(A, S), eigencopy_oftype(B, S); kws...)
591603
end
592604

593605
"""

stdlib/LinearAlgebra/src/symmetriceigen.jl

+27-111
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
# This file is a part of Julia. License is MIT: https://julialang.org/license
22

3+
# preserve HermOrSym wrapper
4+
eigencopy_oftype(A::Hermitian, S) = Hermitian(copy_similar(A, S), sym_uplo(A.uplo))
5+
eigencopy_oftype(A::Symmetric, S) = Symmetric(copy_similar(A, S), sym_uplo(A.uplo))
6+
37
# Eigensolvers for symmetric and Hermitian matrices
48
eigen!(A::RealHermSymComplexHerm{<:BlasReal,<:StridedMatrix}; sortby::Union{Function,Nothing}=nothing) =
59
Eigen(sorteig!(LAPACK.syevr!('V', 'A', A.uplo, A.data, 0.0, 0.0, 0, 0, -1.0)..., sortby)...)
610

711
function eigen(A::RealHermSymComplexHerm; sortby::Union{Function,Nothing}=nothing)
8-
T = eltype(A)
9-
S = eigtype(T)
10-
eigen!(S != T ? convert(AbstractMatrix{S}, A) : copy(A), sortby=sortby)
12+
S = eigtype(eltype(A))
13+
eigen!(eigencopy_oftype(A, S), sortby=sortby)
1114
end
1215

1316
eigen!(A::RealHermSymComplexHerm{<:BlasReal,<:StridedMatrix}, irange::UnitRange) =
@@ -31,9 +34,8 @@ The [`UnitRange`](@ref) `irange` specifies indices of the sorted eigenvalues to
3134
will be a *truncated* factorization.
3235
"""
3336
function eigen(A::RealHermSymComplexHerm, irange::UnitRange)
34-
T = eltype(A)
35-
S = eigtype(T)
36-
eigen!(S != T ? convert(AbstractMatrix{S}, A) : copy(A), irange)
37+
S = eigtype(eltype(A))
38+
eigen!(eigencopy_oftype(A, S), irange)
3739
end
3840

3941
eigen!(A::RealHermSymComplexHerm{T,<:StridedMatrix}, vl::Real, vh::Real) where {T<:BlasReal} =
@@ -57,9 +59,8 @@ The following functions are available for `Eigen` objects: [`inv`](@ref), [`det`
5759
will be a *truncated* factorization.
5860
"""
5961
function eigen(A::RealHermSymComplexHerm, vl::Real, vh::Real)
60-
T = eltype(A)
61-
S = eigtype(T)
62-
eigen!(S != T ? convert(AbstractMatrix{S}, A) : copy(A), vl, vh)
62+
S = eigtype(eltype(A))
63+
eigen!(eigencopy_oftype(A, S), vl, vh)
6364
end
6465

6566
function eigvals!(A::RealHermSymComplexHerm{<:BlasReal,<:StridedMatrix}; sortby::Union{Function,Nothing}=nothing)
@@ -69,9 +70,8 @@ function eigvals!(A::RealHermSymComplexHerm{<:BlasReal,<:StridedMatrix}; sortby:
6970
end
7071

7172
function eigvals(A::RealHermSymComplexHerm; sortby::Union{Function,Nothing}=nothing)
72-
T = eltype(A)
73-
S = eigtype(T)
74-
eigvals!(S != T ? convert(AbstractMatrix{S}, A) : copy(A), sortby=sortby)
73+
S = eigtype(eltype(A))
74+
eigvals!(eigencopy_oftype(A, S), sortby=sortby)
7575
end
7676

7777
"""
@@ -110,9 +110,8 @@ julia> eigvals(A)
110110
```
111111
"""
112112
function eigvals(A::RealHermSymComplexHerm, irange::UnitRange)
113-
T = eltype(A)
114-
S = eigtype(T)
115-
eigvals!(S != T ? convert(AbstractMatrix{S}, A) : copy(A), irange)
113+
S = eigtype(eltype(A))
114+
eigvals!(eigencopy_oftype(A, S), irange)
116115
end
117116

118117
"""
@@ -150,9 +149,8 @@ julia> eigvals(A)
150149
```
151150
"""
152151
function eigvals(A::RealHermSymComplexHerm, vl::Real, vh::Real)
153-
T = eltype(A)
154-
S = eigtype(T)
155-
eigvals!(S != T ? convert(AbstractMatrix{S}, A) : copy(A), vl, vh)
152+
S = eigtype(eltype(A))
153+
eigvals!(eigencopy_oftype(A, S), vl, vh)
156154
end
157155

158156
eigmax(A::RealHermSymComplexHerm{<:Real,<:StridedMatrix}) = eigvals(A, size(A, 1):size(A, 1))[1]
@@ -166,107 +164,25 @@ function eigen!(A::Hermitian{T,S}, B::Hermitian{T,S}; sortby::Union{Function,Not
166164
vals, vecs, _ = LAPACK.sygvd!(1, 'V', A.uplo, A.data, B.uplo == A.uplo ? B.data : copy(B.data'))
167165
GeneralizedEigen(sorteig!(vals, vecs, sortby)...)
168166
end
169-
170167
function eigen!(A::RealHermSymComplexHerm{T,S}, B::AbstractMatrix{T}; sortby::Union{Function,Nothing}=nothing) where {T<:Number,S<:StridedMatrix}
168+
return _choleigen!(A, B, sortby)
169+
end
170+
function eigen!(A::StridedMatrix{T}, B::Union{RealHermSymComplexHerm{T},Diagonal{T}}; sortby::Union{Function,Nothing}=nothing) where {T<:Number}
171+
return _choleigen!(A, B, sortby)
172+
end
173+
function _choleigen!(A, B, sortby)
171174
U = cholesky(B).U
172175
vals, w = eigen!(UtiAUi!(A, U))
173176
vecs = U \ w
174177
GeneralizedEigen(sorteig!(vals, vecs, sortby)...)
175178
end
176179

177-
# Perform U' \ A / U in-place.
178-
UtiAUi!(As::Symmetric, Utr::UpperTriangular) = Symmetric(_UtiAsymUi!(As.uplo, parent(As), parent(Utr)), sym_uplo(As.uplo))
179-
UtiAUi!(As::Hermitian, Utr::UpperTriangular) = Hermitian(_UtiAsymUi!(As.uplo, parent(As), parent(Utr)), sym_uplo(As.uplo))
180-
UtiAUi!(As::Symmetric, Udi::Diagonal) = Symmetric(_UtiAsymUi_diag!(As.uplo, parent(As), Udi), sym_uplo(As.uplo))
181-
UtiAUi!(As::Hermitian, Udi::Diagonal) = Hermitian(_UtiAsymUi_diag!(As.uplo, parent(As), Udi), sym_uplo(As.uplo))
182-
183-
# U is upper triangular
184-
function _UtiAsymUi!(uplo, A, U)
185-
n = size(A, 1)
186-
μ⁻¹ = 1 / U[1, 1]
187-
αμ⁻² = A[1, 1] * μ⁻¹' * μ⁻¹
188-
189-
# Update (1, 1) element
190-
A[1, 1] = αμ⁻²
191-
if n > 1
192-
Unext = view(U, 2:n, 2:n)
193-
194-
if uplo === 'U'
195-
# Update submatrix
196-
for j in 2:n, i in 2:j
197-
A[i, j] = (
198-
A[i, j]
199-
- μ⁻¹' * U[1, j] * A[1, i]'
200-
- μ⁻¹ * A[1, j] * U[1, i]'
201-
+ αμ⁻² * U[1, j] * U[1, i]'
202-
)
203-
end
204-
205-
# Update vector
206-
for j in 2:n
207-
A[1, j] = A[1, j] * μ⁻¹' - U[1, j] * αμ⁻²
208-
end
209-
ldiv!(view(A', 2:n, 1), UpperTriangular(Unext)', view(A', 2:n, 1))
210-
else
211-
# Update submatrix
212-
for j in 2:n, i in 2:j
213-
A[j, i] = (
214-
A[j, i]
215-
- μ⁻¹ * A[i, 1]' * U[1, j]'
216-
- μ⁻¹' * U[1, i] * A[j, 1]
217-
+ αμ⁻² * U[1, i] * U[1, j]'
218-
)
219-
end
220-
221-
# Update vector
222-
for j in 2:n
223-
A[j, 1] = A[j, 1] * μ⁻¹ - U[1, j]' * αμ⁻²
224-
end
225-
ldiv!(view(A, 2:n, 1), UpperTriangular(Unext)', view(A, 2:n, 1))
226-
end
227-
228-
# Recurse
229-
_UtiAsymUi!(uplo, view(A, 2:n, 2:n), Unext)
230-
end
231-
232-
return A
233-
end
180+
# Perform U' \ A / U in-place, where U::Union{UpperTriangular,Diagonal}
181+
UtiAUi!(A::StridedMatrix, U) = _UtiAUi!(A, U)
182+
UtiAUi!(A::Symmetric, U) = Symmetric(_UtiAUi!(copytri!(parent(A), A.uplo), U), sym_uplo(A.uplo))
183+
UtiAUi!(A::Hermitian, U) = Hermitian(_UtiAUi!(copytri!(parent(A), A.uplo, true), U), sym_uplo(A.uplo))
234184

235-
# U is diagonal
236-
function _UtiAsymUi_diag!(uplo, A, U)
237-
n = size(A, 1)
238-
μ⁻¹ = 1 / U[1, 1]
239-
αμ⁻² = A[1, 1] * μ⁻¹' * μ⁻¹
240-
241-
# Update (1, 1) element
242-
A[1, 1] = αμ⁻²
243-
if n > 1
244-
Unext = view(U, 2:n, 2:n)
245-
246-
if uplo === 'U'
247-
# No need to update any submatrix when U is diagonal
248-
249-
# Update vector
250-
for j in 2:n
251-
A[1, j] = A[1, j] * μ⁻¹'
252-
end
253-
ldiv!(view(A', 2:n, 1), Diagonal(Unext)', view(A', 2:n, 1))
254-
else
255-
# No need to update any submatrix when U is diagonal
256-
257-
# Update vector
258-
for j in 2:n
259-
A[j, 1] = A[j, 1] * μ⁻¹
260-
end
261-
ldiv!(view(A, 2:n, 1), Diagonal(Unext)', view(A, 2:n, 1))
262-
end
263-
264-
# Recurse
265-
_UtiAsymUi!(uplo, view(A, 2:n, 2:n), Unext)
266-
end
267-
268-
return A
269-
end
185+
_UtiAUi!(A, U) = rdiv!(ldiv!(U', A), U)
270186

271187
function eigvals!(A::HermOrSym{T,S}, B::HermOrSym{T,S}; sortby::Union{Function,Nothing}=nothing) where {T<:BlasReal,S<:StridedMatrix}
272188
vals = LAPACK.sygvd!(1, 'N', A.uplo, A.data, B.uplo == A.uplo ? B.data : copy(B.data'))[1]

0 commit comments

Comments
 (0)