Use clamptype mechanism to project onto cotangent space #965

Closed
wants to merge 11 commits into from
2 changes: 1 addition & 1 deletion Project.toml
@@ -34,7 +34,7 @@ NaNMath = "0.3"
Requires = "1.1"
SpecialFunctions = "0.10, 1.0"
StatsFuns = "0.9.8"
-ZygoteRules = "0.2.1"
+ZygoteRules = "0.2.2"
julia = "1.3"

[extras]
1 change: 1 addition & 0 deletions src/Zygote.jl
@@ -39,6 +39,7 @@ include("lib/broadcast.jl")
include("lib/forward.jl")
include("lib/utils.jl")
include("lib/range.jl")
include("lib/clamp.jl")
@init @require Distances="b4f34e82-e78d-54a5-968a-f98e89d6e8f7" include("lib/distances.jl")
@init @require LogExpFunctions="2ab3a3ac-af41-5b50-aa03-7779005ae688" include("lib/logexpfunctions.jl")

2 changes: 1 addition & 1 deletion src/lib/array.jl
@@ -45,7 +45,7 @@ end
end

_zero(xs::AbstractArray{<:Number}, T::Type{Nothing}) = fill!(similar(xs), zero(eltype(xs)))
-_zero(xs::AbstractArray{<:Number}, T) = fill!(similar(xs, T), false)
+_zero(xs::AbstractArray{<:Number}, T) = fill!(similar(xs, T, size(xs)), false)
_zero(xs::AbstractArray, T) = fill!(similar(xs, Union{Nothing, T}), nothing)

_droplike(dy, dxv) = dy
64 changes: 64 additions & 0 deletions src/lib/clamp.jl
@@ -0,0 +1,64 @@

import ZygoteRules: clamptype
# This sees a tuple of argument types, and can modify the resulting tuple of tangents

clamptype(Ts::Tuple{}, dxs::Tuple{}) = ()
clamptype(Ts::Tuple, dxs::Tuple) =
  first(Ts) === GlobalRef ? clamptype(Base.tail(Ts), dxs) :
  (clamptype(first(Ts), first(dxs)), clamptype(Base.tail(Ts), Base.tail(dxs))...)

clamptype(Ts::Tuple{}, dxs::Tuple) = (@error "mismatch!" dxs; dxs)
clamptype(Ts::Tuple, dxs::Tuple{}) = (@error "mismatch!" Ts; ())

# Bool, Real, Complex

# clamptype(::Type{Bool}, dx) = nothing
# clamptype(::Type{Bool}, dx::Complex) = nothing # ambiguity
# clamptype(::Type{<:AbstractArray{Bool}}, dx::AbstractArray) = nothing
# clamptype(::Type{<:AbstractArray{Bool}}, dx::AbstractArray) = (@info "bool array" summary(dx); nothing)

# clamptype(::Type{Bool}, dx) = (@info "bool, dropping" typeof(dx); nothing)
# clamptype(::Type{Bool}, dx::Complex) = (@info "bool, dropping" typeof(dx); nothing)
# clamptype(::Type{<:AbstractArray{Bool}}, dx::AbstractArray) = (@info "bool array, disabled" summary(dx); dx)

clamptype(::Type{<:Real}, dx::Complex) = real(dx)
clamptype(::Type{<:AbstractArray{<:Real}}, dx::AbstractArray) = real(dx)

using LinearAlgebra: Diagonal, UpperTriangular, UnitUpperTriangular, LowerTriangular, UnitLowerTriangular
using LinearAlgebra: AdjointAbsVec, TransposeAbsVec, AdjOrTransAbsVec

# LinearAlgebra's matrix types

for Wrap in [:Diagonal, :UpperTriangular, :LowerTriangular]
  @eval begin
    clamptype(::Type{<:$Wrap{T,PT}}, dx::$Wrap) where {T,PT} =
      clamptype(PT, dx)
    clamptype(::Type{<:$Wrap{T,PT}}, dx::AbstractMatrix) where {T,PT} =
      clamptype(PT, $Wrap(dx))
    # not right for :UnitUpperTriangular, :UnitLowerTriangular
  end
end

for (trans, Wrap) in [(transpose, :Symmetric), (Base.adjoint, :Hermitian)]
  @eval begin
    clamptype(::Type{<:$Wrap{T,PT}}, dx::$Wrap) where {T,PT} =
      clamptype(PT, dx)
    clamptype(::Type{<:$Wrap{T,PT}}, dx::AbstractMatrix) where {T,PT} =
      clamptype(PT, $Wrap(_twofold($trans, dx)))
  end
end

_twofold(trans, dx) = (dx .+ trans(dx)) ./ 2
Member Author

Notice this:

using Zygote # this PR
using LinearAlgebra, FiniteDifferences, ForwardDiff, ChainRules

gradient(x -> Symmetric([x 3x; 5x 7x])[1,2], pi)   # (3,)
gradient(A -> A[1,2], Symmetric(rand(2,2)))        # ([0.0 0.5; 0.5 0.0],)
gradient(A -> Symmetric(A)[1,2], rand(2,2))        # ([0.0 1.0; 0.0 0.0],)

grad(central_fdm(5, 1), A -> A[1,2], Symmetric(rand(2,2))) # ([0.0 1.0; 1.0 0.0],) # weird? 
grad(central_fdm(5, 1), A -> Symmetric(A)[1,2], rand(2,2)) # ([0.0 1.0; 0.0 0.0],) # fine

If you instead use _twofold(trans, dx) = dx .+ trans(dx) .- Diagonal(dx), then the Zygote results double, and the 1st and 3rd become quite obviously wrong. That version also isn't a projection operator.
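
To make the "projection operator" point concrete, here is a small sketch comparing the two candidate definitions on a plain matrix. The names `average` and `fold` are hypothetical, chosen only for this illustration; they are not part of Zygote or ChainRules. A projection should be idempotent, so applying it twice must give the same result as applying it once:

```julia
using LinearAlgebra

# Hypothetical helper names, for illustration only.
# `average` mirrors the generic _twofold in this PR (with trans = transpose).
average(dx) = (dx .+ transpose(dx)) ./ 2
# `fold` mirrors the alternative dx .+ trans(dx) .- Diagonal(dx).
fold(dx) = dx .+ transpose(dx) .- Diagonal(dx)

dx = [1.0 3.0; 5.0 7.0]

average(dx)                          # [1.0 4.0; 4.0 7.0]
average(average(dx)) == average(dx)  # true: averaging is idempotent

fold(dx)                             # [1.0 8.0; 8.0 7.0]
fold(fold(dx))                       # [1.0 16.0; 16.0 7.0], off-diagonals double again
```

The off-diagonal entries of `fold` keep doubling on each application, which is why it cannot be a projection onto symmetric matrices.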

Is the FiniteDifferences result the right thing for a tangent, as opposed to a cotangent? Or is it a bug? As far as I can tell, it's doing something like this:

ve, re = to_vec(Symmetric([1 3; 5 7])) # ve == [1,3,3,7]
re([1,-42,3,7]) == [1 3; 3 7]

re2(v) = Symmetric(reshape(v,2,2))
v2 = rand(1:999, 4); re(v2) == re2(v2)
re2(ForwardDiff.gradient(v -> re2(v)[1,2], ve)) == [0 1; 1 0]

Looking in ChainRules, there seem to be few tests that rely on this behaviour of FiniteDifferences. There is a rule for Matrix, which applies dx .+ dx' .- Diagonal(dx) to match:

rrule(Matrix, Symmetric(rand(2,2)))[2]([1 3; 5 7])[2]    == [1 8; 8 7] 
pullback(Matrix, Symmetric(rand(2,2)))[2]([1 3; 5 7])[1] == [1 4; 4 7] # with this PR

It is defined here: https://github.com/JuliaDiff/ChainRules.jl/blob/4e3164a3a48d4da35e0112d30be7ea9dbdaf3920/src/rulesets/LinearAlgebra/symmetric.jl#L71 where _symmetric_back is also used as the gradient of Symmetric, which is where it makes more sense IMO. (Originally from Zygote, I think.)

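# In-place method for dense float arrays: only the strict upper triangle of dx is
# overwritten, which is the part Symmetric/Hermitian read by default (uplo = :U).
function _twofold(trans, dx::Array{<:AbstractFloat})
  @inbounds for i in axes(dx,1)
    for j in i+1:lastindex(dx,2)
      dx[i,j] = (dx[i,j] + trans(dx[j,i])) / 2
    end
  end
  dx
end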

clamptype(::Type{<:AdjOrTransAbsVec{T,PT}}, dx::AdjOrTransAbsVec) where {T,PT} =
  clamptype(PT, dx)
clamptype(::Type{<:AdjOrTransAbsVec{T,PT}}, dx::AbstractMatrix) where {T,PT} =
  clamptype(PT, transpose(vec(dx))) # sometimes wrong wrapper but avoids conjugation
89 changes: 89 additions & 0 deletions test/lib/clamp.jl
@@ -0,0 +1,89 @@
using ZygoteRules: clamptype
using LinearAlgebra

@info "----- starting type clamp tests"

@testset "clamptype" begin

# Real & Complex
@test clamptype(Float32, 1+im) === 1
@test clamptype(ComplexF64, 1+im) === 1+im

TA = typeof(rand(3))
@test clamptype(TA, 1:3) === 1:3
@test clamptype(TA, (1:3) .+ im) isa Vector{Int}

# Boolean
# @test clamptype(Bool, 1+im) === nothing
# TB = typeof(rand(3) .> 0.5)
# @test clamptype(TB, rand(3)) === nothing
# @test clamptype(TB, Diagonal(1:3)) === nothing

# Structured, I
TD = typeof(Diagonal(1:3))
@test clamptype(TD, reshape(1:9, 3, 3)) isa Diagonal{Int,<:Vector}
@test clamptype(TD, Diagonal((1:3) .+ im)) == Diagonal(1:3)

# Structured, II
TH = typeof(Hermitian(rand(3,3) .+ im))
TS = typeof(Symmetric(rand(3,3)))
@test clamptype(TS, reshape(1:4,2,2) .+ im) == [1 2.5; 2.5 4]
AH = clamptype(TH, reshape(1:4,2,2) .+ im)
@test AH == [1 2.5; 2.5 4]
@test AH isa Hermitian{ComplexF64}
@test clamptype(TH, reshape(1:4,2,2)) isa Hermitian{Float64}

# Row vectors
TA = typeof((1:3)')
TT = typeof(transpose(1:3))
TC = typeof(adjoint(rand(3) .+ im))
@test clamptype(TA, permutedims(1:3)) isa LinearAlgebra.AdjOrTransAbsVec
@test clamptype(TA, ones(1,3) .+ im) isa LinearAlgebra.AdjOrTrans{Float64,<:Vector}
@test clamptype(TC, ones(1,3) .+ im) == [1+im 1+im 1+im]

# Tricky
# TDB = typeof(Diagonal(rand(3) .> 0.5))
# @test clamptype(TDB, rand(3,3)) === nothing
# @test clamptype(TDB, rand(ComplexF32, 3,3)) === nothing
# TAB = typeof(transpose([true, false]))
# @test clamptype(TAB, rand(3)') === nothing
end

@testset "clamped gradients" begin # only the marked tests pass on master

# Real & Complex
@test gradient(x -> abs2(x+im), 2) == (4,)
@test gradient(x -> abs2(x+im), 2+0im) == (4 + 2im,) # as before

@test gradient(x -> abs2(sum(x .+ im)), [1, 2])[1] == [6, 6]
@test gradient(x -> abs2(sum(x .+ im)), Any[1, 2])[1] == [6, 6]
@test gradient(x -> abs2(sum(x .+ im)), [1, 2+0im])[1] == [6 + 4im, 6 + 4im] # as before

# Structured, some zeros
# (if rules improve, these will end up testing them not the projection)
@test gradient(x -> sum(x .+ 1), Diagonal(rand(3)))[1] == Diagonal([1,1,1])
@test gradient(x -> sum(sqrt.(x .+ 1)./2), Diagonal(rand(3)))[1] isa Diagonal
@test gradient(x -> sum(x .+ 1), UpperTriangular(rand(3,3)))[1] == UpperTriangular(ones(3,3))

@test gradient(x -> x[1,2], LowerTriangular(rand(3,3)))[1] == zeros(3,3)
@test_broken gradient(x -> x[1,2], UnitLowerTriangular(rand(3,3)))[1] == zeros(3,3)

ld = gradient((x,y) -> sum(x * y), LowerTriangular(ones(3,3)), Diagonal(ones(3,3)))
@test ld[1] isa LowerTriangular
@test_broken ld[2] isa Diagonal

# Structured, some symmetry
@test gradient(x -> sum(x .+ 1), Symmetric(rand(3,3)))[1] isa Symmetric
@test gradient(x -> x[1,2], Symmetric(rand(3,3)))[1] == [0 1/2 0; 1/2 0 0; 0 0 0]

@test_broken gradient(x -> sum(x * x'), Symmetric(ones(3,3)))[1] isa Symmetric

# Row vector restoration
@test pullback(x -> x.+1, rand(3)')[2](ones(1,3))[1] isa LinearAlgebra.AdjOrTransAbsVec
@test pullback(x -> x.+1, rand(3)')[2]([1 2 3+im])[1] == [1 2 3]
@test pullback(x -> x.+1, rand(ComplexF64, 3)')[2]([1 2 3+im])[1] == [1 2 3+im] # as before

@test gradient(x -> x[1,2], rand(3)')[1] isa LinearAlgebra.AdjOrTransAbsVec # worked, broken by _zero change
end

@info "----- done type clamp tests"
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -26,6 +26,7 @@ end
include("lib/number.jl")
include("lib/lib.jl")
include("lib/array.jl")
include("lib/clamp.jl")
end

@testset "Features" begin