The QR factorization returns a QRCompactWY, which describes two matrices. When testing a custom pullback for the QR factorization, test_rrule calls length on that type, and length is not defined for it. MWE:
using LinearAlgebra
using Zygote
using ChainRules
using ChainRulesCore
using Random
using Statistics
using FiniteDifferences
using ChainRulesTestUtils
ChainRulesCore.debug_mode() = true
function ChainRules.rrule(::typeof(qr), A::AbstractMatrix{T}) where {T}
    QR = qr(A)
    m, n = size(A)
    function qr_pullback(Ȳ::Tangent)
        # For square (m=n) or tall and skinny (m >= n), use the rule derived by
        # Seeger et al. (2019)
        # Ā = [Q̄ + Q copyltu(M)] R⁻ᵀ
        # where copyltu(C) is the symmetric matrix generated from C by taking the lower triangle of the input and
        # copying it to its upper triangle : copyltu(C)ᵢⱼ = C_{max(i,j), min(i,j)}
        # This code is re-used in the wide case and we put it in a separate function.
        function qr_pullback_square_deep(Q̄, R̄, A, Q, R)
            M = R̄*R' - Q'*Q̄
            # M <- copyltu(M)
            M = triu(M) + transpose(triu(M,1))
            Ā = (Q̄ + Q * M) / R'
        end
        # For the wide (m < n) case, we implement the rule derived by
        # Liao et al. (2019)
        # Ā = ([(Q̄ + V̄Yᵀ) + Q copyltu(M)] U⁻ᵀ, Q V̄)
        # where A=(X,Y) is the column-wise concatenation of the matrices X (n*n) and Y(n, m-n).
        # R = (U,V). Both X and U are full rank square matrices.
        # See also the discussion in
        # And
        Q̄ = Ȳ.factors
        R̄ = Ȳ.T
        Q = QR.Q
        R = QR.R
        if m ≥ n
            Q̄ = Q̄ isa ChainRules.AbstractZero ? Q̄ : @view Q̄[:, axes(Q, 2)]
            Ā = qr_pullback_square_deep(Q̄, R̄, A, Q, R)
        else
            # partition A = [X | Y]
            # X = A[1:m, 1:m]
            Y = A[1:m, m + 1:end]
            # partition R = [U | V], and we don't need V
            U = R[1:m, 1:m]
            if R̄ isa ChainRules.AbstractZero
                V̄ = zeros(size(Y))
                Q̄_prime = zeros(size(Q))
                Ū = R̄
            else
                # partition R̄ = [Ū | V̄]
                Ū = R̄[1:m, 1:m]
                V̄ = R̄[1:m, m + 1:end]
                Q̄_prime = Y * V̄'
            end
            Q̄_prime = Q̄ isa ChainRules.AbstractZero ? Q̄_prime : Q̄_prime + Q̄
            X̄ = qr_pullback_square_deep(Q̄_prime, Ū, A, Q, U)
            Ȳ = Q * V̄
            # partition Ā = [X̄ | Ȳ]
            Ā = [X̄ Ȳ]
        end
        return (NoTangent(), Ā)
    end
    return QR, qr_pullback
end
function ChainRulesCore.rrule(::typeof(getproperty), F::LinearAlgebra.QRCompactWY, d::Symbol)
function getproperty_qr_pullback(Ȳ)
# The QR factorization is calculated from `factors` and T, matrices stored in the QRCompactWYQ format, see
# R. Schreiber and C. van Loan, Sci. Stat. Comput. 10, 53-57 (1989).
# Instead of backpropagating through the factors, we re-use factors to carry Q̄ and T to carry R̄
# in the Tangent object.
∂factors = if d === :Q
∂T = if d === :R
∂F = Tangent{LinearAlgebra.QRCompactWY}(; factors=∂factors, T=∂T)
return (NoTangent(), ∂F)
return getproperty(F, d), getproperty_qr_pullback
V = randn((4,4))
test_rrule(qr, V)
Fails with:
Got exception outside of a @test
MethodError: no method matching length(::LinearAlgebra.QRCompactWY{Float32, Matrix{Float32}})
Closest candidates are:
length(::Union{Base.KeySet, Base.ValueIterator}) at abstractdict.jl:58
length(::Union{Adjoint{T, var"#s6"} where var"#s6"<:Union{StaticArrays.StaticVector{var"#s1", T} where var"#s1", StaticArrays.StaticMatrix{var"#s4", var"#s5", T} where {var"#s4", var"#s5"}}, Diagonal{T, var"#s13"} where var"#s13"<:(StaticArrays.StaticVector{var"#s14", T} where var"#s14"), Hermitian{T, var"#s10"} where var"#s10"<:(StaticArrays.StaticMatrix{var"#s11", var"#s12", T} where {var"#s11", var"#s12"}), LowerTriangular{T, var"#s18"} where var"#s18"<:(StaticArrays.StaticMatrix{var"#s19", var"#s20", T} where {var"#s19", var"#s20"}), Symmetric{T, var"#s7"} where var"#s7"<:(StaticArrays.StaticMatrix{var"#s8", var"#s9", T} where {var"#s8", var"#s9"}), Transpose{T, var"#s1"} where var"#s1"<:Union{StaticArrays.StaticVector{var"#s1", T} where var"#s1", StaticArrays.StaticMatrix{var"#s4", var"#s5", T} where {var"#s4", var"#s5"}}, UnitLowerTriangular{T, var"#s24"} where var"#s24"<:(StaticArrays.StaticMatrix{var"#s25", var"#s26", T} where {var"#s25", var"#s26"}), UnitUpperTriangular{T, var"#s21"} where var"#s21"<:(StaticArrays.StaticMatrix{var"#s22", var"#s23", T} where {var"#s22", var"#s23"}), UpperTriangular{T, var"#s15"} where var"#s15"<:(StaticArrays.StaticMatrix{var"#s16", var"#s17", T} where {var"#s16", var"#s17"}), StaticArrays.StaticVector{var"#s26", T} where var"#s26", StaticArrays.StaticMatrix{var"#s5", var"#s4", T} where {var"#s5", var"#s4"}, StaticArrays.StaticArray{var"#s26", T, N} where {var"#s26"<:Tuple, N}} where T) at /home/rkube/.julia/packages/StaticArrays/uH2MB/src/abstractarray.jl:1
length(::Union{Adjoint{T, S}, Transpose{T, S}} where {T, S}) at /buildworker/worker/package_linuxppc64le/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/adjtrans.jl:195
This is because length is not defined for ::LinearAlgebra.QRCompactWY:
julia> typeof(qr(V))
LinearAlgebra.QRCompactWY{Float32, Matrix{Float32}}
julia> length(qr(V))
ERROR: MethodError: no method matching length(::LinearAlgebra.QRCompactWY{Float32, Matrix{Float32}})
Closest candidates are:
length(::Union{Base.KeySet, Base.ValueIterator}) at abstractdict.jl:58
length(::Union{Adjoint{T, var"#s6"} where var"#s6"<:Union{StaticArrays.StaticVector{var"#s1", T} where var"#s1", StaticArrays.StaticMatrix{var"#s4", var"#s5", T} where {var"#s4", var"#s5"}}, Diagonal{T, var"#s13"} where var"#s13"<:(StaticArrays.StaticVector{var"#s14", T} where var"#s14"), Hermitian{T, var"#s10"} where var"#s10"<:(StaticArrays.StaticMatrix{var"#s11", var"#s12", T} where {var"#s11", var"#s12"}), LowerTriangular{T, var"#s18"} where var"#s18"<:(StaticArrays.StaticMatrix{var"#s19", var"#s20", T} where {var"#s19", var"#s20"}), Symmetric{T, var"#s7"} where var"#s7"<:(StaticArrays.StaticMatrix{var"#s8", var"#s9", T} where {var"#s8", var"#s9"}), Transpose{T, var"#s1"} where var"#s1"<:Union{StaticArrays.StaticVector{var"#s1", T} where var"#s1", StaticArrays.StaticMatrix{var"#s4", var"#s5", T} where {var"#s4", var"#s5"}}, UnitLowerTriangular{T, var"#s24"} where var"#s24"<:(StaticArrays.StaticMatrix{var"#s25", var"#s26", T} where {var"#s25", var"#s26"}), UnitUpperTriangular{T, var"#s21"} where var"#s21"<:(StaticArrays.StaticMatrix{var"#s22", var"#s23", T} where {var"#s22", var"#s23"}), UpperTriangular{T, var"#s15"} where var"#s15"<:(StaticArrays.StaticMatrix{var"#s16", var"#s17", T} where {var"#s16", var"#s17"}), StaticArrays.StaticVector{var"#s26", T} where var"#s26", StaticArrays.StaticMatrix{var"#s5", var"#s4", T} where {var"#s5", var"#s4"}, StaticArrays.StaticArray{var"#s26", T, N} where {var"#s26"<:Tuple, N}} where T) at /home/rkube/.julia/packages/StaticArrays/uH2MB/src/abstractarray.jl:1
length(::Union{Adjoint{T, S}, Transpose{T, S}} where {T, S}) at /buildworker/worker/package_linuxppc64le/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/adjtrans.jl:195
[1] top-level scope
@ REPL[9]:1
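For reference, the copyltu operation named in the pullback comments of the MWE can be checked in isolation. The following is a minimal sketch using only LinearAlgebra; the function name copyltu is taken from the comments and defined here as an assumption, it is not something the MWE or LinearAlgebra provides. The sketch also checks that the upper-triangle expression used in the MWE, triu(M) + transpose(triu(M, 1)), gives the same matrix when applied to the transposed argument.

```julia
using LinearAlgebra

# copyltu(C)[i, j] == C[max(i, j), min(i, j)]: keep the lower triangle
# (including the diagonal) and mirror its strict part into the upper triangle.
# Defined here for illustration only.
copyltu(C::AbstractMatrix) = tril(C) + transpose(tril(C, -1))

C = [1.0 2.0 3.0;
     4.0 5.0 6.0;
     7.0 8.0 9.0]
S = copyltu(C)

# The same construction from the upper triangle, applied to the transpose of C.
Ct = Matrix(transpose(C))
S_from_upper = triu(Ct) + transpose(triu(Ct, 1))
```

Here S is symmetric and its lower triangle agrees with that of C; S_from_upper is the same matrix.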
It is thrown in the call to collect.
That bit should probably be guarded by an isiterable check or similar.
If the value isn't iterable, it should just @test false, since the two things are not equal.
(Otherwise it would have been caught by _can_pass_early, or it would iterate into parts and we could check that the parts are all equal.)
However, the underlying reason this didn't pass _can_pass_early is that QR equality doesn't work, because of a bug in QR's equality: JuliaLang/julia#41363
As a hack one could monkey-patch Base.:(==) or Base.isapprox, or ChainRulesTestUtils.test_approx.
In particular:
function ChainRulesTestUtils.test_approx(actual::LinearAlgebra.QRCompactWY, expected::LinearAlgebra.QRCompactWY, msg::String; kwargs...)
    ChainRulesTestUtils.test_approx(actual.Q, expected.Q, msg *" Q:"; kwargs...)
    ChainRulesTestUtils.test_approx(actual.R, expected.R, msg *" R:"; kwargs...)
end
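The idea behind that overload, comparing two factorizations through their Q and R factors rather than as whole objects, can be tried without ChainRulesTestUtils. This is only a sketch using LinearAlgebra: qr_components_approx is a hypothetical helper name introduced for illustration, and the input matrix is made up.

```julia
using LinearAlgebra

# Hypothetical helper (not part of any package): two QR factorizations count
# as approximately equal when both their Q and their R factors are.
function qr_components_approx(F1, F2; kwargs...)
    return isapprox(Matrix(F1.Q), Matrix(F2.Q); kwargs...) &&
           isapprox(F1.R, F2.R; kwargs...)
end

A = [1.0 2.0; 3.0 4.0; 5.0 6.0]
F1 = qr(A)
F2 = qr(copy(A))

same = qr_components_approx(F1, F2)                 # same input matrix
different = qr_components_approx(F1, qr(A .+ 1.0))  # shifted input matrix
```

Converting Q with Matrix avoids relying on how isapprox treats the lazy Q type; extra keyword arguments such as rtol are forwarded to isapprox.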