Wrong call on cholesky rrule #611

Closed
theogf opened this issue May 7, 2022 · 12 comments · Fixed by #631

@theogf

theogf commented May 7, 2022

Identified in JuliaStats/PDMats.jl#159

When calling the cholesky rrule, there is an error in the code at

`Ū = ΔC.U` will be `NoTangent()`, since the `Tangent` of a `Cholesky` object does not contain the field `U`.
I think it should be `Ū = ΔC.factors`.

@sethaxen
Member

sethaxen commented May 7, 2022

> `Ū = ΔC.U` will be `NoTangent()`, since the `Tangent` of a `Cholesky` object does not contain the field `U`.

`Cholesky` does not have the field `U`, but it does have the property `U`. Then there's this `rrule` for `getproperty`:

```julia
function rrule(::typeof(getproperty), F::T, x::Symbol) where {T <: Cholesky}
    function getproperty_cholesky_pullback(Ȳ)
        C = Tangent{T}
        ∂F = if x === :U
            if F.uplo === 'U'
                C(U=UpperTriangular(Ȳ),)
            else
                C(L=LowerTriangular(Ȳ'),)
            end
        elseif x === :L
            if F.uplo === 'L'
                C(L=LowerTriangular(Ȳ),)
            else
                C(U=UpperTriangular(Ȳ'),)
            end
        end
        return NoTangent(), ∂F, NoTangent()
    end
    return getproperty(F, x), getproperty_cholesky_pullback
end
```
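To make the field-versus-property distinction concrete, here is a minimal check using only the `LinearAlgebra` standard library (no ChainRules involved). It assumes a dense `Matrix{Float64}` input, for which `cholesky` stores the upper factor (`uplo == 'U'`); the matrix values are arbitrary.

```julia
using LinearAlgebra

A = [4.0 2.0; 2.0 3.0]   # arbitrary symmetric positive-definite example
C = cholesky(A)

# The struct's actual fields; a purely structural tangent would mirror these.
@assert fieldnames(typeof(C)) == (:factors, :uplo, :info)
@assert !hasfield(typeof(C), :U)

# U and L are computed properties, both built from the same `factors` storage.
@assert hasproperty(C, :U) && hasproperty(C, :L)
@assert C.uplo == 'U'
@assert C.U == UpperTriangular(C.factors)
@assert C.L ≈ C.U'
@assert C.U' * C.U ≈ A
```

This is why a cotangent recorded under `U` and one recorded under `factors` describe the same storage, and why rules that disagree on the key do not compose.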

Note that the current rules won't compose well if any downstream code accesses `factors`, but everything should work well if it accesses `U` or `L`. This should probably be improved, though, to look more like the rules for `lu`:

```julia
function _lu_pullback(ΔF::Tangent, m, n, eltypeA, pivot, F)
    Δfactors = ΔF.factors
    Δfactors isa AbstractZero && return (NoTangent(), Δfactors, NoTangent())
    factors = F.factors
    ∂factors = eltypeA <: Real ? real(Δfactors) : Δfactors
    ∂A = similar(factors)
    q = min(m, n)
    if m == n # square A
        # ∂A = P' * (L' \ (tril(L' * ∂L, -1) + triu(∂U * U')) / U')
        L = UnitLowerTriangular(factors)
        U = UpperTriangular(factors)
        ∂U = UpperTriangular(∂factors)
        tril!(copyto!(∂A, ∂factors), -1)
        lmul!(L', ∂A)
        copyto!(UpperTriangular(∂A), UpperTriangular(∂U * U'))
        rdiv!(∂A, U')
        ldiv!(L', ∂A)
    elseif m < n # wide A, system is [P*A1 P*A2] = [L*U1 L*U2]
        triu!(copyto!(∂A, ∂factors))
        @views begin
            factors1 = factors[:, 1:q]
            U2 = factors[:, (q + 1):end]
            ∂A1 = ∂A[:, 1:q]
            ∂A2 = ∂A[:, (q + 1):end]
            ∂L = tril(∂factors[:, 1:q], -1)
        end
        L = UnitLowerTriangular(factors1)
        U1 = UpperTriangular(factors1)
        triu!(rmul!(∂A1, U1'))
        ∂A1 .+= tril!(mul!(lmul!(L', ∂L), ∂A2, U2', -1, 1), -1)
        rdiv!(∂A1, U1')
        ldiv!(L', ∂A)
    else # tall A, system is [P1*A; P2*A] = [L1*U; L2*U]
        tril!(copyto!(∂A, ∂factors), -1)
        @views begin
            factors1 = factors[1:q, :]
            L2 = factors[(q + 1):end, :]
            ∂A1 = ∂A[1:q, :]
            ∂A2 = ∂A[(q + 1):end, :]
            ∂U = triu(∂factors[1:q, :])
        end
        U = UpperTriangular(factors1)
        L1 = UnitLowerTriangular(factors1)
        tril!(lmul!(L1', ∂A1), -1)
        ∂A1 .+= triu!(mul!(rmul!(∂U, U'), L2', ∂A2, -1, 1))
        ldiv!(L1', ∂A1)
        rdiv!(∂A, U')
    end
    if pivot isa LU_RowMaximum
        ∂A = ∂A[invperm(F.p), :]
    end
    return NoTangent(), ∂A, NoTangent()
end
_lu_pullback(ΔF::AbstractThunk, m, n, eltypeA, pivot, F) = _lu_pullback(unthunk(ΔF), m, n, eltypeA, pivot, F)
function rrule(
    ::typeof(lu), A::StridedMatrix, pivot::Union{LU_RowMaximum,LU_NoPivot}; kwargs...
)
    m, n = size(A)
    F = lu(A, pivot; kwargs...)
    lu_pullback(ȳ) = _lu_pullback(ȳ, m, n, eltype(A), pivot, F)
    return F, lu_pullback
end

#####
##### functions of `LU`
#####

# this rrule is necessary because the primal mutates
function rrule(::typeof(getproperty), F::TF, x::Symbol) where {T,TF<:LU{T,<:StridedMatrix{T}}}
    function getproperty_LU_pullback(ΔY)
        ∂factors = if x === :L
            m, n = size(F.factors)
            S = eltype(ΔY)
            tril!([ΔY zeros(S, m, max(0, n - m))], -1)
        elseif x === :U
            m, n = size(F.factors)
            S = eltype(ΔY)
            triu!([ΔY; zeros(S, max(0, m - n), n)])
        elseif x === :factors
            Matrix(ΔY)
        else
            return (NoTangent(), NoTangent(), NoTangent())
        end
        ∂F = Tangent{TF}(; factors=∂factors)
        return NoTangent(), ∂F, NoTangent()
    end
    getproperty_LU_pullback(ΔY::AbstractThunk) = getproperty_LU_pullback(unthunk(ΔY))
    return getproperty(F, x), getproperty_LU_pullback
end
```
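The `:L` and `:U` branches of the `LU` `getproperty` rule pad the incoming cotangent out to the shape of `factors` before returning it. A small wide-matrix check with plain `LinearAlgebra` shows why that padding is needed; the matrix is arbitrary and `ΔL` is a made-up cotangent of all ones, not output from any AD system.

```julia
using LinearAlgebra

A = [4.0 3.0 2.0;
     6.0 3.0 1.0]            # arbitrary wide example: m = 2 < n = 3
F = lu(A)
m, n = size(F.factors)

# F.L is m×min(m,n) and F.U is min(m,n)×n, so only F.U matches the shape of `factors`.
@assert (m, n) == (2, 3)
@assert size(F.L) == (2, 2) && size(F.U) == (2, 3)
@assert F.L * F.U ≈ A[F.p, :]

# Same construction as the :L branch: pad a cotangent of L to m×n,
# then keep the strict lower triangle, since L's unit diagonal is implicit.
ΔL = ones(m, m)
∂factors = tril!([ΔL zeros(eltype(ΔL), m, max(0, n - m))], -1)
@assert ∂factors == [0.0 0.0 0.0; 1.0 0.0 0.0]
```

Because every property cotangent is normalized to a `factors`-shaped array, `_lu_pullback` only ever has to read `ΔF.factors`.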

@theogf
Author

theogf commented May 7, 2022

So this assumes that `ΔC` here is a `Cholesky` object, right?

```julia
function _cholesky_pullback_shared_code(C, ΔC)
```

That's not possible, since it is obviously a tangent:

```julia
Ā, U = _cholesky_pullback_shared_code(C, ΔC)
```

@sethaxen
Member

sethaxen commented May 8, 2022

> So this assumes that `ΔC` here is a `Cholesky` object, right?

I don't see anywhere it assumes that. The preceding line has the type annotation `ΔC::Tangent`. The (co)tangent for `Cholesky` will not be a `Cholesky`.

@theogf
Author

theogf commented May 8, 2022

Right, so `ΔC` is a `Tangent`! So `ΔC.U` will call `getproperty` here:
https://github.com/JuliaDiff/ChainRulesCore.jl/blob/2d75b4be102bb41ba3ac6df6dec8bb9617b20f0f/src/tangent_types/tangent.jl#L104
And since `hasfield(NamedTuple, :U)` is false, it will return `NoTangent()`...

@sethaxen
Member

sethaxen commented May 9, 2022

Here `T` refers to the type of the backing, which has fields for whatever keywords were passed to the `Tangent` constructor, and those can be anything. For example:

```julia
julia> using ChainRulesCore

julia> struct Cholesky
           factors
       end

julia> F = Cholesky(randn(5, 5));

julia> ΔF = Tangent{typeof(F)}(U=randn(5,5))
Tangent{Cholesky}(U = [-0.7059268385030435 -0.06675693167812352  -0.2413203958757456 0.03638639429591188; -0.47788794725474393 -1.2764613865307353  1.2012293162255412 0.15961345415826841;  ; 0.6552274641328225 0.24813748236499192  1.3140512188252975 0.22398055650819915; 2.2445029073821035 -0.5210976005765643  1.6755604234385513 0.4181320724677388],)

julia> ΔF.factors
ZeroTangent()

julia> ΔF.U
5×5 Matrix{Float64}:
 -0.705927  -0.0667569  -1.42752    -0.24132  0.0363864
 -0.477888  -1.27646    -0.652307    1.20123  0.159613
 -0.449378  -0.696023   -1.21753    -1.20449  0.23459
  0.655227   0.248137    0.0819825   1.31405  0.223981
  2.2445    -0.521098   -0.145292    1.67556  0.418132
```

Note that if this weren't the current behavior, the rule for `cholesky` would always produce a `ZeroTangent()`, which it doesn't. Is there a particular surprising behavior or error you encountered that led to this issue?

@devmotion
Member

I think that

causes some of the test errors in FluxML/Zygote.jl#1114 (e.g., in https://github.com/FluxML/Zygote.jl/runs/6021203148?check_suite_focus=true#step:6:192).

Due to being a `ZeroTangent`, the call in

errors. It might be useful to define 3- and 5-argument `mul!` methods with `AbstractZero` arguments (even though I assume one might run into many method ambiguity issues), but I don't think that (all) the examples in the Zygote tests should use a `ZeroTangent` there, so the primary issue seems to be that is wrong.

@sethaxen
Member

> so the primary issue seems to be that is wrong.

Can you clarify what you mean here?

@devmotion
Member

It is a `ZeroTangent`, but I think it shouldn't be. The problem is exactly the one @theogf described above: there's no field named `U` in the tangent, but there may be one named e.g. `L` or `factors` (the main difference in the Zygote implementations is actually that they access `factors` instead).

It seems this problem could occur with any hardcoded field access here, so maybe a custom `_get_U(tangent)` would be needed: if the field `U` exists, it's returned; otherwise it's computed from `factors`; and if that does not exist either, we transpose `L` (and if `L` doesn't exist, that should automatically return a `ZeroTangent`).
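A minimal sketch of that lookup order, under loud assumptions: `_get_U` is only the hypothetical name from the comment (not ChainRules API), the tangent's backing is modeled as a plain `NamedTuple`, and `nothing` stands in for `ZeroTangent()` so the snippet needs only `LinearAlgebra`.

```julia
using LinearAlgebra

# Hypothetical helper following the suggested lookup order; not part of ChainRules.
# `Δ` models the backing of a Tangent{Cholesky}; `nothing` stands in for ZeroTangent().
function _get_U(Δ::NamedTuple, uplo::Char)
    haskey(Δ, :U) && return Δ.U
    if haskey(Δ, :factors)
        # `factors` holds U in its upper triangle when uplo == 'U', otherwise it holds L = U'.
        return uplo == 'U' ? triu(Δ.factors) : Matrix(tril(Δ.factors)')
    end
    haskey(Δ, :L) && return Matrix(Δ.L')   # U = L'
    return nothing
end

Δ = [1.0 2.0; 3.0 4.0]
_get_U((factors = Δ,), 'L')   # [1.0 3.0; 0.0 4.0]
```

Because it returns at the first key it finds, this version would silently drop, say, an `L` contribution when a `U` entry is also present; summing every contribution into one array avoids that.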

@sethaxen
Member

> It is a `ZeroTangent`, but I think it shouldn't be. The problem is exactly the one @theogf described above: there's no field named `U` in the tangent, but there may be one named e.g. `L` or `factors` (the main difference in the Zygote implementations is actually that they access `factors` instead).

It would be helpful if we had an MWE to structure this conversation around. Is there one you can construct from the Zygote failures you mentioned?

> It seems this problem could occur with any hardcoded field access here, so maybe a custom `_get_U(tangent)` would be needed: if the field `U` exists, it's returned; otherwise it's computed from `factors`; and if that does not exist either, we transpose `L` (and if `L` doesn't exist, that should automatically return a `ZeroTangent`).

Perhaps, but it might be cleaner to rewrite the `rrule` for `getproperty` to accumulate cotangents of `factors`, and then have the `rrule` for `cholesky` work with the cotangent of `factors`, like what we do with `lu` (the same `_lu_pullback` and `getproperty` rules quoted in my earlier comment).
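A rough sketch of what accumulating into `factors` could look like, again modeling the tangent backing as a `NamedTuple` and using `nothing` for a zero cotangent. `accumulate_factors` and `_acc` are hypothetical names, not ChainRules functions. The triangle logic follows the Cholesky `getproperty` rule quoted earlier: with `uplo == 'U'` the factor lives in the upper triangle of `factors` and `L = U'`, and the reverse holds for `'L'`.

```julia
using LinearAlgebra

# Hypothetical helpers, not part of ChainRules. `nothing` plays the role of ZeroTangent().
_acc(::Nothing, b) = b
_acc(a, b) = a + b

# Fold every cotangent stored under :factors, :U, and/or :L into one cotangent
# for the `factors` field, restricted to the triangle that `uplo` says is stored.
function accumulate_factors(Δ::NamedTuple, uplo::Char)
    stored = uplo == 'U' ? triu : tril
    total = nothing
    haskey(Δ, :factors) && (total = _acc(total, stored(Δ.factors)))
    if haskey(Δ, :U)   # U is stored directly when uplo == 'U', else U = L'
        total = _acc(total, uplo == 'U' ? triu(Δ.U) : tril(Matrix(Δ.U')))
    end
    if haskey(Δ, :L)   # L is stored directly when uplo == 'L', else L = U'
        total = _acc(total, uplo == 'L' ? tril(Δ.L) : triu(Matrix(Δ.L')))
    end
    return total
end

ΔU = [1.0 2.0; 3.0 4.0]
accumulate_factors((U = ΔU, L = ΔU), 'U')   # [2.0 5.0; 0.0 8.0]
```

Since every contribution is summed into the same array, downstream code that reads `U`, `L`, or `factors` (or several of them) all ends up in one cotangent that a `factors`-based `cholesky` pullback can consume.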

@devmotion
Member

> Perhaps, but it might be cleaner to rewrite the `rrule` for `getproperty` to accumulate cotangents of `factors`, and then have the `rrule` for `cholesky` work with the cotangent of `factors`, like what we do with `lu`.

Yes, that sounds simpler.

@devmotion
Member

> Is there one you can construct from the Zygote failures you mentioned?

These Zygote failures happen once one removes the Zygote adjoint for `cholesky`. Then basically even the simplest examples start to fail. E.g., the linked issue above is triggered by https://github.com/FluxML/Zygote.jl/blob/af8aee4d8acc94bdfa8b9a1c7e16ef0b6a3df32e/test/gradcheck.jl#L650.

@sethaxen
Member

> Is there one you can construct from the Zygote failures you mentioned?

> These Zygote failures happen once one removes the Zygote adjoint for `cholesky`. Then basically even the simplest examples start to fail. E.g., the linked issue above is triggered by https://github.com/FluxML/Zygote.jl/blob/af8aee4d8acc94bdfa8b9a1c7e16ef0b6a3df32e/test/gradcheck.jl#L650.

The main issue there is that the function `\` depends on the adjoints for `cholesky` and `\(::Cholesky, ::AbstractVecOrMat)`, both of which are composable in Zygote. The linked PR deletes only Zygote's adjoint for `cholesky`, and Zygote's adjoint for `\` does not compose well with ChainRules's `cholesky` rule, since it's written in terms of `factors`. We should actually have the `\` rrule here in ChainRules.

I'll open up PRs both to use `factors` in the `cholesky`-related rrules and to migrate the rrule for `\` here.
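The dependency described above can be seen in plain `LinearAlgebra`: solving with a `Cholesky` amounts to two triangular solves against the stored factor, so a reverse rule for `\` has to put its cotangent somewhere in the `Cholesky` tangent, and the `cholesky` rule has to read from the same place. The last lines model the key mismatch with a plain `NamedTuple` standing in for a `Tangent` backing, an assumption made so the sketch does not depend on ChainRulesCore.

```julia
using LinearAlgebra

A = [4.0 2.0; 2.0 3.0]   # arbitrary symmetric positive-definite example
b = [1.0, 2.0]
C = cholesky(A)
x = C \ b

@assert x ≈ A \ b
# Since A = U'U, the solve is x = U \ (U' \ b): it only reads the stored factor.
@assert x ≈ C.U \ (C.U' \ b)

# Stand-in for the mismatch: one rule writes its cotangent under :factors,
# another reads :U and silently finds nothing (i.e. a zero cotangent).
Δ_written = (factors = ones(2, 2),)
@assert get(Δ_written, :U, nothing) === nothing
```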
