test_rrule fails to handle QRCompactWY return types #184

Open
rkube opened this issue Jun 25, 2021 · 1 comment

rkube commented Jun 25, 2021

The QR factorization returns a QRCompactWY object, which stores the factorization in two matrices (factors and T). When testing a custom pullback for the QR factorization, test_rrule calls length on that type, for which no method is defined.

MWE:

using LinearAlgebra
using Zygote
using ChainRules
using ChainRulesCore
using Random
using Statistics
using FiniteDifferences

using ChainRulesTestUtils

ChainRulesCore.debug_mode() = true
Random.seed!(1234);

function ChainRules.rrule(::typeof(qr), A::AbstractMatrix{T}) where {T} 
    QR = qr(A)
    m, n = size(A)
    function qr_pullback(Ȳ::Tangent)
        # For square (m=n) or tall and skinny (m >= n), use the rule derived by 
        # Seeger et al. (2019) https://arxiv.org/pdf/1710.08717.pdf
        #   
        # Ā = [Q̄ + Q copyltu(M)] R⁻ᵀ
        #   
        # where copyltu(C) is the symmetric matrix generated from C by copying its lower triangle
        # to its upper triangle: copyltu(C)ᵢⱼ = C_{max(i,j), min(i,j)}
        #   
        # This code is re-used in the wide case and we put it in a separate function.

        function qr_pullback_square_deep(Q̄, R̄, A, Q, R)
            M = R̄*R' - Q'*Q̄
            # M <- copyltu(M)
            M = triu(M) + transpose(triu(M,1))
            Ā = (Q̄ + Q * M) / R'
        end 

        # For the wide (m < n) case, we implement the rule derived by
        # Liao et al. (2019) https://arxiv.org/pdf/1903.09650.pdf
        #   
        # Ā = ([Q̄ + V̄Yᵀ + Q copyltu(M)] U⁻ᵀ, Q V̄)
        # where A = (X, Y) is the column-wise concatenation of the matrices X (m×m) and Y (m×(n-m)),
        # and R = (U, V). Both X and U are full-rank square matrices.
        #   
        # See also the discussion in https://github.com/JuliaDiff/ChainRules.jl/pull/306
        # And https://github.com/pytorch/pytorch/blob/b162d95e461a5ea22f6840bf492a5dbb2ebbd151/torch/csrc/autograd/FunctionsManual.cpp 
        Q̄ = Ȳ.factors
        R̄ = Ȳ.T 
        Q = QR.Q
        R = QR.R
        if m ≥ n 
            Q̄ = Q̄ isa ChainRules.AbstractZero ? Q̄ : @view Q̄[:, axes(Q, 2)] 
            Ā = qr_pullback_square_deep(Q̄, R̄, A, Q, R)
        else
            # partition A = [X | Y]
            # X = A[1:m, 1:m]
            Y = A[1:m, m + 1:end]
    
            # partition R = [U | V], and we don't need V
            U = R[1:m, 1:m]
            if R̄ isa ChainRules.AbstractZero
                V̄ = zeros(size(Y))
                Q̄_prime = zeros(size(Q))
                Ū = R̄ 
            else
                # partition R̄ = [Ū | V̄]
                Ū = R̄[1:m, 1:m]
                V̄ = R̄[1:m, m + 1:end]
                Q̄_prime = Y * V̄'
            end 

            Q̄_prime = Q̄ isa ChainRules.AbstractZero ? Q̄_prime : Q̄_prime + Q̄ 

            X̄ = qr_pullback_square_deep(Q̄_prime, Ū, A, Q, U)
            Ȳ = Q * V̄ 
            # partition Ā = [X̄ | Ȳ]
            Ā = [X̄ Ȳ]
        end 
        return (NoTangent(), Ā)
    end 
    return QR, qr_pullback
end


function ChainRulesCore.rrule(::typeof(getproperty), F::LinearAlgebra.QRCompactWY, d::Symbol) 
    function getproperty_qr_pullback(Ȳ)
        # The QR factorization is stored as the matrices `factors` and `T` in the compact WY format, see
        # R. Schreiber and C. Van Loan, SIAM J. Sci. Stat. Comput. 10, 53-57 (1989).
        # Instead of backpropagating through the factors, we re-use factors to carry Q̄ and T to carry R̄
        # in the Tangent object.
        ∂factors = if d === :Q
            Ȳ
        else
            nothing
        end

        ∂T = if d === :R
            Ȳ
        else
            nothing
        end

        ∂F = Tangent{LinearAlgebra.QRCompactWY}(; factors=∂factors, T=∂T)
        # One tangent each for getproperty itself, F, and the non-differentiable symbol d
        return (NoTangent(), ∂F, NoTangent())
    end

    return getproperty(F, d), getproperty_qr_pullback
end

V = randn((4,4))
test_rrule(qr, V)

Fails with:

  Got exception outside of a @test
  MethodError: no method matching length(::LinearAlgebra.QRCompactWY{Float32, Matrix{Float32}})
  Closest candidates are:
    length(::Union{Base.KeySet, Base.ValueIterator}) at abstractdict.jl:58
    length(::Union{Adjoint{T, var"#s6"} where var"#s6"<:Union{StaticArrays.StaticVector{var"#s1", T} where var"#s1", StaticArrays.StaticMatrix{var"#s4", var"#s5", T} where {var"#s4", var"#s5"}}, Diagonal{T, var"#s13"} where var"#s13"<:(StaticArrays.StaticVector{var"#s14", T} where var"#s14"), Hermitian{T, var"#s10"} where var"#s10"<:(StaticArrays.StaticMatrix{var"#s11", var"#s12", T} where {var"#s11", var"#s12"}), LowerTriangular{T, var"#s18"} where var"#s18"<:(StaticArrays.StaticMatrix{var"#s19", var"#s20", T} where {var"#s19", var"#s20"}), Symmetric{T, var"#s7"} where var"#s7"<:(StaticArrays.StaticMatrix{var"#s8", var"#s9", T} where {var"#s8", var"#s9"}), Transpose{T, var"#s1"} where var"#s1"<:Union{StaticArrays.StaticVector{var"#s1", T} where var"#s1", StaticArrays.StaticMatrix{var"#s4", var"#s5", T} where {var"#s4", var"#s5"}}, UnitLowerTriangular{T, var"#s24"} where var"#s24"<:(StaticArrays.StaticMatrix{var"#s25", var"#s26", T} where {var"#s25", var"#s26"}), UnitUpperTriangular{T, var"#s21"} where var"#s21"<:(StaticArrays.StaticMatrix{var"#s22", var"#s23", T} where {var"#s22", var"#s23"}), UpperTriangular{T, var"#s15"} where var"#s15"<:(StaticArrays.StaticMatrix{var"#s16", var"#s17", T} where {var"#s16", var"#s17"}), StaticArrays.StaticVector{var"#s26", T} where var"#s26", StaticArrays.StaticMatrix{var"#s5", var"#s4", T} where {var"#s5", var"#s4"}, StaticArrays.StaticArray{var"#s26", T, N} where {var"#s26"<:Tuple, N}} where T) at /home/rkube/.julia/packages/StaticArrays/uH2MB/src/abstractarray.jl:1
    length(::Union{Adjoint{T, S}, Transpose{T, S}} where {T, S}) at /buildworker/worker/package_linuxppc64le/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/adjtrans.jl:195

This is because length is not defined for ::LinearAlgebra.QRCompactWY:

julia> typeof(qr(V))
LinearAlgebra.QRCompactWY{Float32, Matrix{Float32}}
julia> length(qr(V))
ERROR: MethodError: no method matching length(::LinearAlgebra.QRCompactWY{Float32, Matrix{Float32}})
Closest candidates are:
  length(::Union{Base.KeySet, Base.ValueIterator}) at abstractdict.jl:58
  length(::Union{Adjoint{T, var"#s6"} where var"#s6"<:Union{StaticArrays.StaticVector{var"#s1", T} where var"#s1", StaticArrays.StaticMatrix{var"#s4", var"#s5", T} where {var"#s4", var"#s5"}}, Diagonal{T, var"#s13"} where var"#s13"<:(StaticArrays.StaticVector{var"#s14", T} where var"#s14"), Hermitian{T, var"#s10"} where var"#s10"<:(StaticArrays.StaticMatrix{var"#s11", var"#s12", T} where {var"#s11", var"#s12"}), LowerTriangular{T, var"#s18"} where var"#s18"<:(StaticArrays.StaticMatrix{var"#s19", var"#s20", T} where {var"#s19", var"#s20"}), Symmetric{T, var"#s7"} where var"#s7"<:(StaticArrays.StaticMatrix{var"#s8", var"#s9", T} where {var"#s8", var"#s9"}), Transpose{T, var"#s1"} where var"#s1"<:Union{StaticArrays.StaticVector{var"#s1", T} where var"#s1", StaticArrays.StaticMatrix{var"#s4", var"#s5", T} where {var"#s4", var"#s5"}}, UnitLowerTriangular{T, var"#s24"} where var"#s24"<:(StaticArrays.StaticMatrix{var"#s25", var"#s26", T} where {var"#s25", var"#s26"}), UnitUpperTriangular{T, var"#s21"} where var"#s21"<:(StaticArrays.StaticMatrix{var"#s22", var"#s23", T} where {var"#s22", var"#s23"}), UpperTriangular{T, var"#s15"} where var"#s15"<:(StaticArrays.StaticMatrix{var"#s16", var"#s17", T} where {var"#s16", var"#s17"}), StaticArrays.StaticVector{var"#s26", T} where var"#s26", StaticArrays.StaticMatrix{var"#s5", var"#s4", T} where {var"#s5", var"#s4"}, StaticArrays.StaticArray{var"#s26", T, N} where {var"#s26"<:Tuple, N}} where T) at /home/rkube/.julia/packages/StaticArrays/uH2MB/src/abstractarray.jl:1
  length(::Union{Adjoint{T, S}, Transpose{T, S}} where {T, S}) at /buildworker/worker/package_linuxppc64le/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/adjtrans.jl:195
  ...
Stacktrace:
 [1] top-level scope
   @ REPL[9]:1

oxinabox (Member) commented Jun 25, 2021

The error is here

if _can_pass_early(actual, expected)
    @test true
else
    c_actual = collect(actual)

It is thrown in the call to collect.
That bit should probably be guarded by an isiterable check or similar.
If the value isn't iterable, it should just @test false, since the two things are not equal.
(Otherwise it would have been caught by _can_pass_early, or it would iterate into parts and we could check that the parts are all equal.)

However, the underlying reason this didn't pass _can_pass_early is that equality doesn't work for QR factorizations, because of a bug in Julia:
JuliaLang/julia#41363

As a hack, one could monkey-patch Base.:(==) or Base.isapprox, or add a method to ChainRulesTestUtils.test_approx.
In particular:

function ChainRulesTestUtils.test_approx(actual::LinearAlgebra.QRCompactWY, expected::LinearAlgebra.QRCompactWY, msg::String; kwargs...)
     ChainRulesTestUtils.test_approx(actual.Q, expected.Q, msg * " Q:"; kwargs...)
     ChainRulesTestUtils.test_approx(actual.R, expected.R, msg * " R:"; kwargs...)
end
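
As a rough illustration of why a component-wise comparison is a reasonable workaround, here is a minimal sketch that uses only the LinearAlgebra standard library. The matrix A and the names F1, F2, same_Q and same_R are made up for this example and do not come from ChainRulesTestUtils. It compares two factorizations of the same matrix through their Q and R parts instead of through the factorization objects themselves.

```julia
using LinearAlgebra

# Small example matrix, chosen only for illustration
A = [1.0 2.0; 3.0 4.0]

# Two independent factorizations of the same input
F1 = qr(A)
F2 = qr(copy(A))

# Compare through the components rather than the factorization objects.
# Matrix(F.Q) materializes Q as a dense matrix so that ≈ applies directly.
same_Q = Matrix(F1.Q) ≈ Matrix(F2.Q)
same_R = F1.R ≈ F2.R

# Sanity check: the components reproduce the original matrix
reconstructs = Matrix(F1.Q) * F1.R ≈ A
```

The test_approx method above follows the same idea, delegating the comparison to the Q and R parts so that neither == nor collect is ever called on the factorization object.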
