Skip to content

Commit

Permalink
fix some issues with equality of factorizations
Browse files Browse the repository at this point in the history
- `hash` did not respect the type of a factorization, so completely
  different factorizations with the same underlying data would result in
  same `hash` leading to inconsistencies with `isequal`. This likely
  doesn't occur very often in practice, but definitely seems worth
  fixing.
- `==` and `isequal` only returned true if two factorizations are of
  exactly the same type, which is inconsistent with their implementation
  for other objects and with the definition of `hash` for factorizations.
- Equality for `QRCompactWY` did not ignore the subdiagonal entries of
  `T` leading to nondeterministic behavior. Perhaps `T` should be
  directly stored as `UpperTriangular` in `QRCompactWY`, but that seems
  potentially breaking.

Relying on implementation details of `DataType` here is certainly less
than ideal, but I could not come up with a nicer solution.
  • Loading branch information
simeonschaub committed Jun 14, 2021
1 parent 4a81b08 commit 642e6e9
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 3 deletions.
14 changes: 11 additions & 3 deletions stdlib/LinearAlgebra/src/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,17 @@ Factorization{T}(A::Adjoint{<:Any,<:Factorization}) where {T} =
adjoint(Factorization{T}(parent(A)))
inv(F::Factorization{T}) where {T} = (n = size(F, 1); ldiv!(F, Matrix{T}(I, n, n)))

Base.hash(F::Factorization, h::UInt) = mapreduce(f -> hash(getfield(F, f)), hash, 1:nfields(F); init=h)
Base.:(==)( F::T, G::T) where {T<:Factorization} = all(f -> getfield(F, f) == getfield(G, f), 1:nfields(F))
Base.isequal(F::T, G::T) where {T<:Factorization} = all(f -> isequal(getfield(F, f), getfield(G, f)), 1:nfields(F))::Bool
function Base.hash(F::Factorization, h::UInt)
return mapreduce(f -> hash(getfield(F, f)), hash, 1:nfields(F); init=hash(typeof(F).name.wrapper, h))
end
function Base.:(==)(F::Factorization, G::Factorization)
typeof(F).name.wrapper == typeof(G).name.wrapper || return false
return all(f -> getfield(F, f) == getfield(G, f), 1:nfields(F))
end
function Base.isequal(F::Factorization, G::Factorization)
typeof(F).name.wrapper == typeof(G).name.wrapper || return false
return all(f -> isequal(getfield(F, f), getfield(G, f)), 1:nfields(F))::Bool
end

function Base.show(io::IO, x::Adjoint{<:Any,<:Factorization})
print(io, "Adjoint of ")
Expand Down
10 changes: 10 additions & 0 deletions stdlib/LinearAlgebra/src/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,16 @@ Base.iterate(S::QRCompactWY) = (S.Q, Val(:R))
Base.iterate(S::QRCompactWY, ::Val{:R}) = (S.R, Val(:done))
Base.iterate(S::QRCompactWY, ::Val{:done}) = nothing

function Base.hash(F::QRCompactWY, h::UInt)
return hash(F.factors, hash(UpperTriangular(F.T), hash(QRCompactWY, h)))
end
function Base.:(==)(A::QRCompactWY, B::QRCompactWY)
return A.factors == B.factors && UpperTriangular(A.T) == UpperTriangular(B.T)
end
function Base.isequal(A::QRCompactWY, B::QRCompactWY)
return isequal(A.factors, B.factors) && isequal(UpperTriangular(A.T), UpperTriangular(B.T))
end

"""
QRPivoted <: Factorization
Expand Down
50 changes: 50 additions & 0 deletions stdlib/LinearAlgebra/test/factorization.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

module TestFactorization
using Test, LinearAlgebra

@testset "equality for factorizations - $f" for f in Any[
bunchkaufman,
cholesky,
x -> cholesky(x, Val(true)),
eigen,
hessenberg,
lq,
lu,
qr,
x -> qr(x, ColumnNorm()),
svd,
schur,
]
A = randn(3, 3)
A = A * A' # ensure A is pos. def. and symmetric
F, G = f(A), f(A)

@test F == G
@test isequal(F, G)
@test hash(F) == hash(G)

f === hessenberg && continue

F = typeof(F).name.wrapper(Base.mapany(1:nfields(F)) do i
x = getfield(F, i)
return x isa AbstractArray{Float64} ? Float32.(x) : x
end...)
G = typeof(G).name.wrapper(Base.mapany(1:nfields(G)) do i
x = getfield(G, i)
return x isa AbstractArray{Float64} ? Float64.(Float32.(x)) : x
end...)

@test F == G
@test isequal(F, G)
@test hash(F) == hash(G)
end

@testset "hash collisions" begin
A, v = randn(2, 2), randn(2)
F, G = LQ(A, v), QR(A, v)
@test !isequal(F, G)
@test hash(F) != hash(G)
end

end
1 change: 1 addition & 0 deletions stdlib/LinearAlgebra/test/testgroups
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@ givens
structuredbroadcast
addmul
ldlt
factorization

0 comments on commit 642e6e9

Please sign in to comment.