From 642e6e93355f589304ebcdce91039c5ec9c5c1cb Mon Sep 17 00:00:00 2001 From: Simeon Schaub Date: Mon, 14 Jun 2021 21:40:14 +0200 Subject: [PATCH] fix some issues with equality of factorizations - `hash` did not respect the type of a factorization, so completely different factorizations with the same underlying data would result in same `hash` leading to inconsistencies with `isequal`. This likely doesn't occur very often in practice, but definitely seems worth fixing. - `==` and `isequal` only returned true if two factorizations are of exactly the same type, which is inconsistent with their implementation for other objects and with the definition of `hash` for factorizations. - Equality for `QRCompactWY` did not ignore the subdiagonal entries of `T` leading to nondeterministic behavior. Perhaps `T` should be directly stored as `UpperTriangular` in `QRCompactWY`, but that seems potentially breaking. Relying on implementation details of `DataType` here is certainly less than ideal, but I could not come up with a nicer solution. --- stdlib/LinearAlgebra/src/factorization.jl | 14 ++++-- stdlib/LinearAlgebra/src/qr.jl | 10 +++++ stdlib/LinearAlgebra/test/factorization.jl | 50 ++++++++++++++++++++++ stdlib/LinearAlgebra/test/testgroups | 1 + 4 files changed, 72 insertions(+), 3 deletions(-) create mode 100644 stdlib/LinearAlgebra/test/factorization.jl diff --git a/stdlib/LinearAlgebra/src/factorization.jl b/stdlib/LinearAlgebra/src/factorization.jl index b651e85512f6d..641965b0e8d44 100644 --- a/stdlib/LinearAlgebra/src/factorization.jl +++ b/stdlib/LinearAlgebra/src/factorization.jl @@ -64,9 +64,17 @@ Factorization{T}(A::Adjoint{<:Any,<:Factorization}) where {T} = adjoint(Factorization{T}(parent(A))) inv(F::Factorization{T}) where {T} = (n = size(F, 1); ldiv!(F, Matrix{T}(I, n, n))) -Base.hash(F::Factorization, h::UInt) = mapreduce(f -> hash(getfield(F, f)), hash, 1:nfields(F); init=h) -Base.:(==)( F::T, G::T) where {T<:Factorization} = all(f -> getfield(F, f) == getfield(G, f), 1:nfields(F)) -Base.isequal(F::T, G::T) where {T<:Factorization} = all(f -> isequal(getfield(F, f), getfield(G, f)), 1:nfields(F))::Bool +function Base.hash(F::Factorization, h::UInt) + return mapreduce(f -> hash(getfield(F, f)), hash, 1:nfields(F); init=hash(typeof(F).name.wrapper, h)) +end +function Base.:(==)(F::Factorization, G::Factorization) + typeof(F).name.wrapper == typeof(G).name.wrapper || return false + return all(f -> getfield(F, f) == getfield(G, f), 1:nfields(F)) +end +function Base.isequal(F::Factorization, G::Factorization) + typeof(F).name.wrapper == typeof(G).name.wrapper || return false + return all(f -> isequal(getfield(F, f), getfield(G, f)), 1:nfields(F))::Bool +end function Base.show(io::IO, x::Adjoint{<:Any,<:Factorization}) print(io, "Adjoint of ") diff --git a/stdlib/LinearAlgebra/src/qr.jl b/stdlib/LinearAlgebra/src/qr.jl index d0ec430347193..32b21ac0045c9 100644 --- a/stdlib/LinearAlgebra/src/qr.jl +++ b/stdlib/LinearAlgebra/src/qr.jl @@ -127,6 +127,16 @@ Base.iterate(S::QRCompactWY) = (S.Q, Val(:R)) Base.iterate(S::QRCompactWY, ::Val{:R}) = (S.R, Val(:done)) Base.iterate(S::QRCompactWY, ::Val{:done}) = nothing +function Base.hash(F::QRCompactWY, h::UInt) + return hash(F.factors, hash(UpperTriangular(F.T), hash(QRCompactWY, h))) +end +function Base.:(==)(A::QRCompactWY, B::QRCompactWY) + return A.factors == B.factors && UpperTriangular(A.T) == UpperTriangular(B.T) +end +function Base.isequal(A::QRCompactWY, B::QRCompactWY) + return isequal(A.factors, B.factors) && isequal(UpperTriangular(A.T), UpperTriangular(B.T)) +end + """ QRPivoted <: Factorization diff --git a/stdlib/LinearAlgebra/test/factorization.jl b/stdlib/LinearAlgebra/test/factorization.jl new file mode 100644 index 0000000000000..fda9b316ca5ff --- /dev/null +++ b/stdlib/LinearAlgebra/test/factorization.jl @@ -0,0 +1,50 @@ +# This file is a part of Julia. License is MIT: https://julialang.org/license + +module TestFactorization +using Test, LinearAlgebra + +@testset "equality for factorizations - $f" for f in Any[ + bunchkaufman, + cholesky, + x -> cholesky(x, Val(true)), + eigen, + hessenberg, + lq, + lu, + qr, + x -> qr(x, ColumnNorm()), + svd, + schur, +] + A = randn(3, 3) + A = A * A' # ensure A is pos. def. and symmetric + F, G = f(A), f(A) + + @test F == G + @test isequal(F, G) + @test hash(F) == hash(G) + + f === hessenberg && continue + + F = typeof(F).name.wrapper(Base.mapany(1:nfields(F)) do i + x = getfield(F, i) + return x isa AbstractArray{Float64} ? Float32.(x) : x + end...) + G = typeof(G).name.wrapper(Base.mapany(1:nfields(G)) do i + x = getfield(G, i) + return x isa AbstractArray{Float64} ? Float64.(Float32.(x)) : x + end...) + + @test F == G + @test isequal(F, G) + @test hash(F) == hash(G) +end + +@testset "hash collisions" begin + A, v = randn(2, 2), randn(2) + F, G = LQ(A, v), QR(A, v) + @test !isequal(F, G) + @test hash(F) != hash(G) +end + +end diff --git a/stdlib/LinearAlgebra/test/testgroups b/stdlib/LinearAlgebra/test/testgroups index b33dfecaa82ee..de082d8e7dce0 100644 --- a/stdlib/LinearAlgebra/test/testgroups +++ b/stdlib/LinearAlgebra/test/testgroups @@ -25,3 +25,4 @@ givens structuredbroadcast addmul ldlt +factorization