From 914bd92700b609b40d13eff98f51df423c8073cd Mon Sep 17 00:00:00 2001 From: WT Date: Wed, 24 Feb 2021 22:57:10 +0000 Subject: [PATCH 01/43] Sketch project implementation --- src/ChainRulesCore.jl | 1 + src/projection.jl | 55 +++++++++++++++++++++++++++++++++++++++++++ test/projection.jl | 3 +++ test/runtests.jl | 1 + 4 files changed, 60 insertions(+) create mode 100644 src/projection.jl create mode 100644 test/projection.jl diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index e3bd53deb..5156a83fd 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -26,6 +26,7 @@ include("differentials/notimplemented.jl") include("differential_arithmetic.jl") include("accumulation.jl") +include("projection.jl") include("config.jl") include("rules.jl") diff --git a/src/projection.jl b/src/projection.jl new file mode 100644 index 000000000..51cce00c3 --- /dev/null +++ b/src/projection.jl @@ -0,0 +1,55 @@ +using LinearAlgebra: Diagonal, diag + +""" + project(T::Type, x, dx) + +"project" `dx` onto type `T` such that it is the same size as `x`. + +It's necessary to have `x` to ensure that it's possible to project e.g. `AbstractZero`s +onto `Array`s -- this wouldn't be possible with type information alone because the neither +`AbstractZero`s nor `T` know what size of `Array` to produce. 
+""" +function project end + +# Number-types +project(::Type{T}, x::T, dx::T) where {T<:Real} = dx + +project(::Type{T}, x::T, dx::AbstractZero) where {T<:Real} = zero(x) + +project(::Type{T}, x::T, dx::AbstractThunk) where {T<:Real} = project(x, unthunk(dx)) + + + +# Arrays +project(::Type{Array{T, N}}, x::Array{T, N}, dx::Array{T, N}) where {T<:Real, N} = dx + +project(::Type{<:Array{T}}, x::Array, dx::Array) where {T} = project.(Ref(T), x, dx) + +function project(::Type{T}, x::Array, dx::AbstractArray) where {T<:Array} + return project(T, x, collect(dx)) +end + +function project(::Type{<:Array{T}}, x::Array, dx::AbstractZero) where {T} + return project.(Ref(T), x, Ref(dx)) +end + + + +# Diagonal +function project(::Type{<:Diagonal{<:Any, V}}, x::Diagonal, dx::AbstractMatrix) where {V} + return Diagonal(project(V, diag(x), diag(dx))) +end + +function project(::Type{<:Diagonal{<:Any, V}}, x::Diagonal, dx::Composite) where {V} + return Diagonal(project(V, diag(x), dx.diag)) +end + +function project(::Type{<:Composite}, x::Diagonal, dx::Diagonal) + return Composite{typeof(x)}(diag=diag(dx)) +end + + + +# One use for this functionality is to make it easy to define addition between two different +# representations of the same tangent. 
This also makes it clear that the +Base.:(+)(x::Composite{<:Diagonal}, y::Diagonal) = x + project(typeof(x), x, y) diff --git a/test/projection.jl b/test/projection.jl new file mode 100644 index 000000000..25b38fc1d --- /dev/null +++ b/test/projection.jl @@ -0,0 +1,3 @@ +@testset "projection" begin + +end diff --git a/test/runtests.jl b/test/runtests.jl index 090a85828..f98e971b2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -16,6 +16,7 @@ using Test end include("accumulation.jl") + include("projection.jl") include("rules.jl") include("rule_definition_tools.jl") From 06678a4f64e222b5ca352e2600bae6fde439b253 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Tue, 22 Jun 2021 11:54:38 +0100 Subject: [PATCH 02/43] change Composite to Tangent --- src/projection.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/projection.jl b/src/projection.jl index 51cce00c3..ab77978ef 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -40,16 +40,16 @@ function project(::Type{<:Diagonal{<:Any, V}}, x::Diagonal, dx::AbstractMatrix) return Diagonal(project(V, diag(x), diag(dx))) end -function project(::Type{<:Diagonal{<:Any, V}}, x::Diagonal, dx::Composite) where {V} +function project(::Type{<:Diagonal{<:Any, V}}, x::Diagonal, dx::Tangent) where {V} return Diagonal(project(V, diag(x), dx.diag)) end -function project(::Type{<:Composite}, x::Diagonal, dx::Diagonal) - return Composite{typeof(x)}(diag=diag(dx)) +function project(::Type{<:Tangent}, x::Diagonal, dx::Diagonal) + return Tangent{typeof(x)}(diag=diag(dx)) end # One use for this functionality is to make it easy to define addition between two different # representations of the same tangent. 
This also makes it clear that the -Base.:(+)(x::Composite{<:Diagonal}, y::Diagonal) = x + project(typeof(x), x, y) +Base.:(+)(x::Tangent{<:Diagonal}, y::Diagonal) = x + project(typeof(x), x, y) From c58f974909386311fa16d2dbc8b3eb84717504e9 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Tue, 22 Jun 2021 11:54:57 +0100 Subject: [PATCH 03/43] export project --- src/ChainRulesCore.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index 5156a83fd..7ba0570ab 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -10,7 +10,7 @@ export RuleConfig, HasReverseMode, NoReverseMode, HasForwardsMode, NoForwardsMod export frule_via_ad, rrule_via_ad # definition helper macros export @non_differentiable, @scalar_rule, @thunk, @not_implemented -export canonicalize, extern, unthunk # differential operations +export canonicalize, extern, unthunk, project # differential operations export add!! # gradient accumulation operations # differentials export Tangent, NoTangent, InplaceableThunk, Thunk, ZeroTangent, AbstractZero, AbstractThunk From 00020e3d00afd1b9654552164d7e023893e6f25b Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Tue, 22 Jun 2021 11:55:24 +0100 Subject: [PATCH 04/43] make T optional --- src/projection.jl | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/projection.jl b/src/projection.jl index ab77978ef..c9fa6b14f 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -1,9 +1,10 @@ using LinearAlgebra: Diagonal, diag """ - project(T::Type, x, dx) + project([T::Type], x, dx) -"project" `dx` onto type `T` such that it is the same size as `x`. +"project" `dx` onto type `T` such that it is the same size as `x`. If `T` is not provided, +it is assumed to be the type of `x`. It's necessary to have `x` to ensure that it's possible to project e.g. 
`AbstractZero`s onto `Array`s -- this wouldn't be possible with type information alone because the neither @@ -11,6 +12,8 @@ onto `Array`s -- this wouldn't be possible with type information alone because t """ function project end +project(x, dx) = project(typeof(x), x, dx) + # Number-types project(::Type{T}, x::T, dx::T) where {T<:Real} = dx From 37f9253a767909e441750415e4c818f806dc8d07 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Tue, 22 Jun 2021 12:05:01 +0100 Subject: [PATCH 05/43] add tests and Complex --- src/projection.jl | 2 ++ test/projection.jl | 7 +++++++ 2 files changed, 9 insertions(+) diff --git a/src/projection.jl b/src/projection.jl index c9fa6b14f..e8088b8c4 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -17,6 +17,8 @@ project(x, dx) = project(typeof(x), x, dx) # Number-types project(::Type{T}, x::T, dx::T) where {T<:Real} = dx +project(::Type{T}, x::T, dx::Complex) where {T<:Real} = real(dx) + project(::Type{T}, x::T, dx::AbstractZero) where {T<:Real} = zero(x) project(::Type{T}, x::T, dx::AbstractThunk) where {T<:Real} = project(x, unthunk(dx)) diff --git a/test/projection.jl b/test/projection.jl index 25b38fc1d..1fbfeeccc 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -1,3 +1,10 @@ @testset "projection" begin + @testset "Number types" begin + @test 3.2 == project(1.0, 3.2) + @test 3.2 == project(1.0, 3.2 + 3im) + @test 3.2f0 == project(Float32, 1.0f0, 3.2 - 3im) + @test 0.0 == project(1.1, ZeroTangent()) + @test 3.2 == project(1.0, @thunk(3.2)) + end end From 4e1b79d19d8eae25ab3245bc745139bfc8c5e26f Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Tue, 22 Jun 2021 15:50:24 +0100 Subject: [PATCH 06/43] workout the edge cases --- src/projection.jl | 30 +++++++++-------- test/projection.jl | 81 +++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 92 insertions(+), 19 deletions(-) diff --git a/src/projection.jl b/src/projection.jl index e8088b8c4..ff5a8d101 100644 --- a/src/projection.jl +++ b/src/projection.jl 
@@ -14,47 +14,49 @@ function project end project(x, dx) = project(typeof(x), x, dx) -# Number-types -project(::Type{T}, x::T, dx::T) where {T<:Real} = dx +# identity +project(::Type{T}, x::T, dx::T) where T = dx -project(::Type{T}, x::T, dx::Complex) where {T<:Real} = real(dx) +### AbstractZero +project(::Type{T}, x::T, dx::AbstractZero) where T = zero(x) -project(::Type{T}, x::T, dx::AbstractZero) where {T<:Real} = zero(x) +### AbstractThunk +project(::Type{T}, x::T, dx::AbstractThunk) where T = project(x, unthunk(dx)) -project(::Type{T}, x::T, dx::AbstractThunk) where {T<:Real} = project(x, unthunk(dx)) +### Number-types +project(::Type{T}, x::T, dx::T2) where {T<:Number, T2<:Number} = T(dx) +project(::Type{T}, x::T, dx::Complex) where {T<:Real} = T(real(dx)) # Arrays project(::Type{Array{T, N}}, x::Array{T, N}, dx::Array{T, N}) where {T<:Real, N} = dx +# for project([Foo(0.0), Foo(0.0)], [ZeroTangent(), ZeroTangent()]) project(::Type{<:Array{T}}, x::Array, dx::Array) where {T} = project.(Ref(T), x, dx) +# for project(rand(2, 2), Diagonal(rand(2))) function project(::Type{T}, x::Array, dx::AbstractArray) where {T<:Array} return project(T, x, collect(dx)) end +# for project([Foo(0.0), Foo(0.0)], ZeroTangent()) function project(::Type{<:Array{T}}, x::Array, dx::AbstractZero) where {T} return project.(Ref(T), x, Ref(dx)) end - -# Diagonal +## Diagonal function project(::Type{<:Diagonal{<:Any, V}}, x::Diagonal, dx::AbstractMatrix) where {V} return Diagonal(project(V, diag(x), diag(dx))) end - function project(::Type{<:Diagonal{<:Any, V}}, x::Diagonal, dx::Tangent) where {V} return Diagonal(project(V, diag(x), dx.diag)) end +function project(::Type{<:Diagonal{<:Any, V}}, x::Diagonal, dx::AbstractZero) where {V} + return Diagonal(project(V, diag(x), dx)) +end function project(::Type{<:Tangent}, x::Diagonal, dx::Diagonal) return Tangent{typeof(x)}(diag=diag(dx)) end - - - -# One use for this functionality is to make it easy to define addition between two different -# 
representations of the same tangent. This also makes it clear that the -Base.:(+)(x::Tangent{<:Diagonal}, y::Diagonal) = x + project(typeof(x), x, y) diff --git a/test/projection.jl b/test/projection.jl index 1fbfeeccc..68308c396 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -1,10 +1,81 @@ +struct Foo + a::Float64 +end + +Base.zero(::Foo) = Foo(0.0) +Base.zero(::Type{Foo}) = "F0" + @testset "projection" begin - @testset "Number types" begin - @test 3.2 == project(1.0, 3.2) - @test 3.2 == project(1.0, 3.2 + 3im) - @test 3.2f0 == project(Float32, 1.0f0, 3.2 - 3im) - @test 0.0 == project(1.1, ZeroTangent()) + #identity + @test Foo(1.2) == project(Foo(-0.2), Foo(1.2)) + @test 3.2 == project(1.0, 3.2) + @test 2.0 + 0.0im == project(1.0im, 2.0) + + @testset "From AbstractZero" begin + @testset "to numbers" begin + @test 0.0 == project(1.1, ZeroTangent()) + @test 0.0f0 == project(1.1f0, ZeroTangent()) + end + + @testset "to arrays (dense and structured)" begin + @test zeros(2, 2) == project([1.0 2; 3 4], ZeroTangent()) + @test Diagonal(zeros(2)) == project(Diagonal([1.0, 4]), ZeroTangent()) + @test Diagonal(zeros(ComplexF64, 2)) == project(Diagonal([1.0 + 0im, 4]), ZeroTangent()) + end + + @testset "to structs" begin + @test Foo(0.0) == project(Foo(3.2), ZeroTangent()) + end + + @testset "to arrays of structs" begin + @test [Foo(0.0), Foo(0.0)] == project([Foo(0.0), Foo(0.0)], ZeroTangent()) + @test Diagonal([Foo(0.0), Foo(0.0)]) == project(Diagonal([Foo(3.2,), Foo(4.2)]), ZeroTangent()) + end + end + + @testset "From AbstractThunk" begin @test 3.2 == project(1.0, @thunk(3.2)) + @test Foo(3.2) == project(Foo(-0.2), @thunk(Foo(3.2))) + @test zeros(2) == project([1.0, 2.0], @thunk(ZeroTangent())) + @test Diagonal([Foo(0.0), Foo(0.0)]) == project(Diagonal([Foo(3.2,), Foo(4.2)]), @thunk(ZeroTangent())) + end + + @testset "To number types" begin + @testset "to subset" begin + @test 3.2 == project(1.0, 3.2 + 3im) + @test 3.2f0 == project(1.0f0, 3.2) + @test 3.2f0 
== project(1.0f0, 3.2 - 3im) + end + + @testset "to superset" begin + @test 2.0 + 0.0im == project(2.0 + 1.0im, 2.0) + @test 2.0 == project(2.0, 2.0f0) + end + end + + @testset "To Arrays" begin + # change eltype + @test [1.0 2.0; 3.0 4.0] == project(zeros(2, 2), [1.0 2.0; 3.0 4.0]) + @test [1.0f0 2; 3 4] == project(zeros(Float32, 2, 2), [1.0 2; 3 4]) + + # from a structured array + @test [1.0 0; 0 4] == project(zeros(2, 2), Diagonal([1.0, 4])) + + # from an array of specials + @test [Foo(0.0), Foo(0.0)] == project([Foo(0.0), Foo(0.0)], [ZeroTangent(), ZeroTangent()]) end + + @testset "Diagonal" begin + d = Diagonal([1.0, 4.0]) + t = Tangent{Diagonal}(;diag=[1.0, 4.0]) + @test d == project(d, [1.0 2; 3 4]) + @test d == project(d, t) + @test project(Tangent, d, d) isa Tangent + + @test Diagonal([Foo(0.0), Foo(0.0)]) == project(Diagonal([Foo(3.2,), Foo(4.2)]), Diagonal([ZeroTangent(), ZeroTangent()])) + @test Diagonal([Foo(0.0), Foo(0.0)]) == project(Diagonal([Foo(3.2,), Foo(4.2)]), @thunk(ZeroTangent())) + end + + # how to project to Upper/Lower Symmetric end From 7dc58ee3457830552c6adbadcc47deb269e5615e Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Tue, 22 Jun 2021 18:10:27 +0100 Subject: [PATCH 07/43] rename dummy struct --- test/projection.jl | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/test/projection.jl b/test/projection.jl index 68308c396..030808195 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -1,14 +1,14 @@ -struct Foo +struct Fred a::Float64 end -Base.zero(::Foo) = Foo(0.0) -Base.zero(::Type{Foo}) = "F0" +Base.zero(::Fred) = Fred(0.0) +Base.zero(::Type{Fred}) = "F0" @testset "projection" begin #identity - @test Foo(1.2) == project(Foo(-0.2), Foo(1.2)) + @test Fred(1.2) == project(Fred(-0.2), Fred(1.2)) @test 3.2 == project(1.0, 3.2) @test 2.0 + 0.0im == project(1.0im, 2.0) @@ -25,20 +25,20 @@ Base.zero(::Type{Foo}) = "F0" end @testset "to structs" begin - @test Foo(0.0) == project(Foo(3.2), 
ZeroTangent()) + @test Fred(0.0) == project(Fred(3.2), ZeroTangent()) end @testset "to arrays of structs" begin - @test [Foo(0.0), Foo(0.0)] == project([Foo(0.0), Foo(0.0)], ZeroTangent()) - @test Diagonal([Foo(0.0), Foo(0.0)]) == project(Diagonal([Foo(3.2,), Foo(4.2)]), ZeroTangent()) + @test [Fred(0.0), Fred(0.0)] == project([Fred(0.0), Fred(0.0)], ZeroTangent()) + @test Diagonal([Fred(0.0), Fred(0.0)]) == project(Diagonal([Fred(3.2,), Fred(4.2)]), ZeroTangent()) end end @testset "From AbstractThunk" begin @test 3.2 == project(1.0, @thunk(3.2)) - @test Foo(3.2) == project(Foo(-0.2), @thunk(Foo(3.2))) + @test Fred(3.2) == project(Fred(-0.2), @thunk(Fred(3.2))) @test zeros(2) == project([1.0, 2.0], @thunk(ZeroTangent())) - @test Diagonal([Foo(0.0), Foo(0.0)]) == project(Diagonal([Foo(3.2,), Foo(4.2)]), @thunk(ZeroTangent())) + @test Diagonal([Fred(0.0), Fred(0.0)]) == project(Diagonal([Fred(3.2,), Fred(4.2)]), @thunk(ZeroTangent())) end @testset "To number types" begin @@ -63,7 +63,7 @@ Base.zero(::Type{Foo}) = "F0" @test [1.0 0; 0 4] == project(zeros(2, 2), Diagonal([1.0, 4])) # from an array of specials - @test [Foo(0.0), Foo(0.0)] == project([Foo(0.0), Foo(0.0)], [ZeroTangent(), ZeroTangent()]) + @test [Fred(0.0), Fred(0.0)] == project([Fred(0.0), Fred(0.0)], [ZeroTangent(), ZeroTangent()]) end @testset "Diagonal" begin @@ -73,8 +73,8 @@ Base.zero(::Type{Foo}) = "F0" @test d == project(d, t) @test project(Tangent, d, d) isa Tangent - @test Diagonal([Foo(0.0), Foo(0.0)]) == project(Diagonal([Foo(3.2,), Foo(4.2)]), Diagonal([ZeroTangent(), ZeroTangent()])) - @test Diagonal([Foo(0.0), Foo(0.0)]) == project(Diagonal([Foo(3.2,), Foo(4.2)]), @thunk(ZeroTangent())) + @test Diagonal([Fred(0.0), Fred(0.0)]) == project(Diagonal([Fred(3.2,), Fred(4.2)]), Diagonal([ZeroTangent(), ZeroTangent()])) + @test Diagonal([Fred(0.0), Fred(0.0)]) == project(Diagonal([Fred(3.2,), Fred(4.2)]), @thunk(ZeroTangent())) end # how to project to Upper/Lower Symmetric From 
3345ba91778286d9f560f6fbb7ec472bcb7767f4 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Wed, 23 Jun 2021 18:23:50 +0100 Subject: [PATCH 08/43] rename project to projector --- src/projection.jl | 60 +++++++++++++++++++++++++---------------------- 1 file changed, 32 insertions(+), 28 deletions(-) diff --git a/src/projection.jl b/src/projection.jl index ff5a8d101..340f5ca46 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -1,62 +1,66 @@ using LinearAlgebra: Diagonal, diag """ - project([T::Type], x, dx) + projector([T::Type], x, dx) -"project" `dx` onto type `T` such that it is the same size as `x`. If `T` is not provided, +"projector" `dx` onto type `T` such that it is the same size as `x`. If `T` is not provided, it is assumed to be the type of `x`. -It's necessary to have `x` to ensure that it's possible to project e.g. `AbstractZero`s +It's necessary to have `x` to ensure that it's possible to projector e.g. `AbstractZero`s onto `Array`s -- this wouldn't be possible with type information alone because the neither `AbstractZero`s nor `T` know what size of `Array` to produce. 
-""" -function project end +""" # TODO change docstring to reflect projecor returns a closure +function projector end -project(x, dx) = project(typeof(x), x, dx) +projector(x, dx) = projector(typeof(x), x, dx) # identity -project(::Type{T}, x::T, dx::T) where T = dx +projector(::Type{T}, x::T, dx::T) where T = identity ### AbstractZero -project(::Type{T}, x::T, dx::AbstractZero) where T = zero(x) +projector(::Type{T}, x::T, dx::AbstractZero) where T = _ -> zero(x) ### AbstractThunk -project(::Type{T}, x::T, dx::AbstractThunk) where T = project(x, unthunk(dx)) +projector(::Type{T}, x::T, dx::AbstractThunk) where T = projector(x, unthunk(dx)) ### Number-types -project(::Type{T}, x::T, dx::T2) where {T<:Number, T2<:Number} = T(dx) -project(::Type{T}, x::T, dx::Complex) where {T<:Real} = T(real(dx)) +projector(::Type{T}, x::T, dx::T2) where {T<:Number, T2<:Number} = dx -> T(dx) +projector(::Type{T}, x::T, dx::Complex) where {T<:Real} = dx -> T(real(dx)) # Arrays -project(::Type{Array{T, N}}, x::Array{T, N}, dx::Array{T, N}) where {T<:Real, N} = dx +projector(::Type{Array{T, N}}, x::Array{T, N}, dx::Array{T, N}) where {T<:Real, N} = identity -# for project([Foo(0.0), Foo(0.0)], [ZeroTangent(), ZeroTangent()]) -project(::Type{<:Array{T}}, x::Array, dx::Array) where {T} = project.(Ref(T), x, dx) +# for projector([Foo(0.0), Foo(0.0)], [ZeroTangent(), ZeroTangent()]) +projector(::Type{<:Array{T}}, x::Array, dx::Array) where {T} = projector.(Ref(T), x, dx) # TODO -# for project(rand(2, 2), Diagonal(rand(2))) -function project(::Type{T}, x::Array, dx::AbstractArray) where {T<:Array} - return project(T, x, collect(dx)) +# for projector(rand(2, 2), Diagonal(rand(2))) +function projector(::Type{T}, x::Array, dx::AbstractArray) where {T<:Array} + return projector(T, x, collect(dx)) end -# for project([Foo(0.0), Foo(0.0)], ZeroTangent()) -function project(::Type{<:Array{T}}, x::Array, dx::AbstractZero) where {T} - return project.(Ref(T), x, Ref(dx)) +# for projector([Foo(0.0), 
Foo(0.0)], ZeroTangent()) +function projector(::Type{<:Array{T}}, x::Array, dx::AbstractZero) where {T} + return projector.(Ref(T), x, Ref(dx)) # TODO end ## Diagonal -function project(::Type{<:Diagonal{<:Any, V}}, x::Diagonal, dx::AbstractMatrix) where {V} - return Diagonal(project(V, diag(x), diag(dx))) +function projector(::Type{<:Diagonal{<:Any, V}}, x::Diagonal, dx::AbstractMatrix) where {V} + d = diag(x) + return dx -> Diagonal(projector(V, d, diag(dx))) end -function project(::Type{<:Diagonal{<:Any, V}}, x::Diagonal, dx::Tangent) where {V} - return Diagonal(project(V, diag(x), dx.diag)) +function projector(::Type{<:Diagonal{<:Any, V}}, x::Diagonal, dx::Tangent) where {V} + d = diag(x) + return dx -> Diagonal(projector(V, d, dx.diag)) end -function project(::Type{<:Diagonal{<:Any, V}}, x::Diagonal, dx::AbstractZero) where {V} - return Diagonal(project(V, diag(x), dx)) +function projector(::Type{<:Diagonal{<:Any, V}}, x::Diagonal, dx::AbstractZero) where {V} + d = diag(x) + return dx -> Diagonal(projector(V, d, dx)) end -function project(::Type{<:Tangent}, x::Diagonal, dx::Diagonal) - return Tangent{typeof(x)}(diag=diag(dx)) +function projector(::Type{<:Tangent}, x::Diagonal, dx::Diagonal) + T = typeof(x) + return dx -> Tangent{T}(diag=diag(dx)) end From 31d81edc819ce0c1967eafcd124b5a0389487395 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Thu, 24 Jun 2021 10:16:09 +0100 Subject: [PATCH 09/43] move to projector --- src/ChainRulesCore.jl | 2 +- src/differentials/abstract_zero.jl | 1 + src/differentials/thunks.jl | 2 + src/projection.jl | 93 +++++++++--------- test/projection.jl | 145 +++++++++++++++++------------ 5 files changed, 139 insertions(+), 104 deletions(-) diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index 7ba0570ab..9cf6b7c31 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -10,7 +10,7 @@ export RuleConfig, HasReverseMode, NoReverseMode, HasForwardsMode, NoForwardsMod export frule_via_ad, rrule_via_ad # definition 
helper macros export @non_differentiable, @scalar_rule, @thunk, @not_implemented -export canonicalize, extern, unthunk, project # differential operations +export canonicalize, extern, unthunk, projector # differential operations export add!! # gradient accumulation operations # differentials export Tangent, NoTangent, InplaceableThunk, Thunk, ZeroTangent, AbstractZero, AbstractThunk diff --git a/src/differentials/abstract_zero.jl b/src/differentials/abstract_zero.jl index 01dbfc8f3..fb5342dd2 100644 --- a/src/differentials/abstract_zero.jl +++ b/src/differentials/abstract_zero.jl @@ -27,6 +27,7 @@ Base.:/(z::AbstractZero, ::Any) = z Base.convert(::Type{T}, x::AbstractZero) where T <: Number = zero(T) Base.getindex(z::AbstractZero, k) = z +Base.getproperty(z::AbstractZero, f::Symbol) = z """ ZeroTangent() <: AbstractZero diff --git a/src/differentials/thunks.jl b/src/differentials/thunks.jl index 781fff60c..ed6a3d35c 100644 --- a/src/differentials/thunks.jl +++ b/src/differentials/thunks.jl @@ -37,6 +37,8 @@ Base.imag(a::AbstractThunk) = imag(unthunk(a)) Base.Complex(a::AbstractThunk) = Complex(unthunk(a)) Base.Complex(a::AbstractThunk, b::AbstractThunk) = Complex(unthunk(a), unthunk(b)) +Base.getproperty(a::AbstractThunk, f::Symbol) = f === :f ? getfield(a, f) : getproperty(unthunk(a), f) + Base.mapreduce(f, op, a::AbstractThunk; kws...) = mapreduce(f, op, unthunk(a); kws...) function Base.mapreduce(f, op, itr, a::AbstractThunk; kws...) return mapreduce(f, op, itr, unthunk(a); kws...) diff --git a/src/projection.jl b/src/projection.jl index 340f5ca46..daedb0175 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -1,9 +1,9 @@ using LinearAlgebra: Diagonal, diag """ - projector([T::Type], x, dx) + projector([T::Type], x) -"projector" `dx` onto type `T` such that it is the same size as `x`. If `T` is not provided, +"project" `dx` onto type `T` such that it is the same size as `x`. If `T` is not provided, it is assumed to be the type of `x`. 
It's necessary to have `x` to ensure that it's possible to projector e.g. `AbstractZero`s @@ -12,55 +12,62 @@ onto `Array`s -- this wouldn't be possible with type information alone because t """ # TODO change docstring to reflect projecor returns a closure function projector end -projector(x, dx) = projector(typeof(x), x, dx) +projector(x) = projector(typeof(x), x) -# identity -projector(::Type{T}, x::T, dx::T) where T = identity - -### AbstractZero -projector(::Type{T}, x::T, dx::AbstractZero) where T = _ -> zero(x) - -### AbstractThunk -projector(::Type{T}, x::T, dx::AbstractThunk) where T = projector(x, unthunk(dx)) - - -### Number-types -projector(::Type{T}, x::T, dx::T2) where {T<:Number, T2<:Number} = dx -> T(dx) -projector(::Type{T}, x::T, dx::Complex) where {T<:Real} = dx -> T(real(dx)) +# fallback +function projector(::Type{T}, x::T) where T + println("to Any") + project(dx::T) = dx + project(dx::AbstractZero) = zero(x) + project(dx::AbstractThunk) = project(unthunk(dx)) + return project +end +# Numbers +function projector(::Type{T}, x::T) where {T<:Real} + println("to Real") + project(dx::Real) = T(dx) + project(dx::Number) = T(real(dx)) # to avoid InexactError + project(dx::AbstractZero) = zero(x) + project(dx::AbstractThunk) = project(unthunk(dx)) + return project +end +function projector(::Type{T}, x::T) where {T<:Number} + println("to Number") + project(dx::Number) = T(dx) + project(dx::AbstractZero) = zero(x) + project(dx::AbstractThunk) = project(unthunk(dx)) + return project +end # Arrays -projector(::Type{Array{T, N}}, x::Array{T, N}, dx::Array{T, N}) where {T<:Real, N} = identity - -# for projector([Foo(0.0), Foo(0.0)], [ZeroTangent(), ZeroTangent()]) -projector(::Type{<:Array{T}}, x::Array, dx::Array) where {T} = projector.(Ref(T), x, dx) # TODO - -# for projector(rand(2, 2), Diagonal(rand(2))) -function projector(::Type{T}, x::Array, dx::AbstractArray) where {T<:Array} - return projector(T, x, collect(dx)) +function projector(::Type{Array{T, 
N}}, x::Array{T, N}) where {T, N} + println("to Array") + element = zero(eltype(x)) + project(dx::Array{T, N}) = dx # identity + project(dx::AbstractArray) = project(collect(dx)) # from Diagonal + project(dx::Array) = projector(element).(dx) # from different element type + project(dx::AbstractZero) = zero(x) + project(dx::AbstractThunk) = project(unthunk(dx)) + return project end -# for projector([Foo(0.0), Foo(0.0)], ZeroTangent()) -function projector(::Type{<:Array{T}}, x::Array, dx::AbstractZero) where {T} - return projector.(Ref(T), x, Ref(dx)) # TODO +# Tangent +function projector(::Type{<:Tangent}, x) + println("to Tangent") + keys = fieldnames(typeof(x)) + project(dx) = Tangent{typeof(x)}(; ((k, getproperty(dx, k)) for k in keys)...) + return project end - -## Diagonal -function projector(::Type{<:Diagonal{<:Any, V}}, x::Diagonal, dx::AbstractMatrix) where {V} - d = diag(x) - return dx -> Diagonal(projector(V, d, diag(dx))) -end -function projector(::Type{<:Diagonal{<:Any, V}}, x::Diagonal, dx::Tangent) where {V} +# Diagonal +function projector(::Type{<:Diagonal{<:Any, V}}, x::Diagonal) where {V} + println("to Diagonal") d = diag(x) - return dx -> Diagonal(projector(V, d, dx.diag)) -end -function projector(::Type{<:Diagonal{<:Any, V}}, x::Diagonal, dx::AbstractZero) where {V} - d = diag(x) - return dx -> Diagonal(projector(V, d, dx)) + project(dx::AbstractMatrix) = Diagonal(projector(V, d)(diag(dx))) + project(dx::Tangent) = Diagonal(projector(V, d)(dx.diag)) + project(dx::AbstractZero) = Diagonal(projector(V, d)(dx)) + project(dx::AbstractThunk) = project(unthunk(dx)) + return project end -function projector(::Type{<:Tangent}, x::Diagonal, dx::Diagonal) - T = typeof(x) - return dx -> Tangent{T}(diag=diag(dx)) -end diff --git a/test/projection.jl b/test/projection.jl index 030808195..d581e7100 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -3,78 +3,103 @@ struct Fred end Base.zero(::Fred) = Fred(0.0) -Base.zero(::Type{Fred}) = "F0" 
+Base.zero(::Type{Fred}) = Fred(0.0) @testset "projection" begin - - #identity - @test Fred(1.2) == project(Fred(-0.2), Fred(1.2)) - @test 3.2 == project(1.0, 3.2) - @test 2.0 + 0.0im == project(1.0im, 2.0) - - @testset "From AbstractZero" begin - @testset "to numbers" begin - @test 0.0 == project(1.1, ZeroTangent()) - @test 0.0f0 == project(1.1f0, ZeroTangent()) - end - - @testset "to arrays (dense and structured)" begin - @test zeros(2, 2) == project([1.0 2; 3 4], ZeroTangent()) - @test Diagonal(zeros(2)) == project(Diagonal([1.0, 4]), ZeroTangent()) - @test Diagonal(zeros(ComplexF64, 2)) == project(Diagonal([1.0 + 0im, 4]), ZeroTangent()) - end - - @testset "to structs" begin - @test Fred(0.0) == project(Fred(3.2), ZeroTangent()) - end - - @testset "to arrays of structs" begin - @test [Fred(0.0), Fred(0.0)] == project([Fred(0.0), Fred(0.0)], ZeroTangent()) - @test Diagonal([Fred(0.0), Fred(0.0)]) == project(Diagonal([Fred(3.2,), Fred(4.2)]), ZeroTangent()) - end + @testset "fallback" begin + @test Fred(1.2) == projector(Fred(3.2))(Fred(1.2)) + @test Fred(0.0) == projector(Fred(3.2))(ZeroTangent()) + @test Fred(3.2) == projector(Fred(-0.2))(@thunk(Fred(3.2))) end - @testset "From AbstractThunk" begin - @test 3.2 == project(1.0, @thunk(3.2)) - @test Fred(3.2) == project(Fred(-0.2), @thunk(Fred(3.2))) - @test zeros(2) == project([1.0, 2.0], @thunk(ZeroTangent())) - @test Diagonal([Fred(0.0), Fred(0.0)]) == project(Diagonal([Fred(3.2,), Fred(4.2)]), @thunk(ZeroTangent())) - end + @testset "to Real" begin + # Float64 + @test 3.2 == projector(1.0)(3.2) + @test 0.0 == projector(1.1)(ZeroTangent()) + @test 3.2 == projector(1.0)(@thunk(3.2)) - @testset "To number types" begin - @testset "to subset" begin - @test 3.2 == project(1.0, 3.2 + 3im) - @test 3.2f0 == project(1.0f0, 3.2) - @test 3.2f0 == project(1.0f0, 3.2 - 3im) - end - - @testset "to superset" begin - @test 2.0 + 0.0im == project(2.0 + 1.0im, 2.0) - @test 2.0 == project(2.0, 2.0f0) - end + # down + @test 3.2 == 
projector(1.0)(3.2 + 3im) + @test 3.2f0 == projector(1.0f0)(3.2) + @test 3.2f0 == projector(1.0f0)(3.2 - 3im) + + # up + @test 2.0 == projector(2.0)(2.0f0) end - @testset "To Arrays" begin - # change eltype - @test [1.0 2.0; 3.0 4.0] == project(zeros(2, 2), [1.0 2.0; 3.0 4.0]) - @test [1.0f0 2; 3 4] == project(zeros(Float32, 2, 2), [1.0 2; 3 4]) + @testset "to Number" begin + # Complex + @test 2.0 + 0.0im == projector(1.0im)(2.0 + 0.0im) + + # down + @test 2.0 + 0.0im == projector(1.0im)(2.0) + @test 0.0 + 0.0im == projector(1.0im)(ZeroTangent()) + @test 0.0 + 0.0im == projector(1.0im)(@thunk(ZeroTangent())) - # from a structured array - @test [1.0 0; 0 4] == project(zeros(2, 2), Diagonal([1.0, 4])) + # up + @test 2.0 + 0.0im == projector(2.0 + 1.0im)(2.0) + end + + @testset "to Array" begin + # to an array of numbers + @test [1.0 2.0; 3.0 4.0] == projector(zeros(2, 2))([1.0 2.0; 3.0 4.0]) + @test zeros(2, 2) == projector([1.0 2; 3 4])(ZeroTangent()) + @test zeros(2) == projector([1.0, 2.0])(@thunk(ZeroTangent())) + @test [1.0f0 2; 3 4] == projector(zeros(Float32, 2, 2))([1.0 2; 3 4]) + @test [1.0 0; 0 4] == projector(zeros(2, 2))(Diagonal([1.0, 4])) + + # to a array of structs + @test [Fred(0.0), Fred(0.0)] == projector([Fred(0.0), Fred(0.0)])([Fred(0.0), Fred(0.0)]) + @test [Fred(0.0), Fred(0.0)] == projector([Fred(0.0), Fred(0.0)])([ZeroTangent(), ZeroTangent()]) + @test [Fred(0.0), Fred(3.2)] == projector([Fred(0.0), Fred(0.0)])([ZeroTangent(), @thunk(Fred(3.2))]) + @test [Fred(0.0), Fred(0.0)] == projector([Fred(1.0), Fred(2.0)])(ZeroTangent()) + @test [Fred(0.0), Fred(0.0)] == projector([Fred(0.0), Fred(0.0)])(@thunk(ZeroTangent())) + diagfreds = [Fred(1.0) Fred(0.0); Fred(0.0) Fred(4.0)] + @test diagfreds == projector(diagfreds)(Diagonal([Fred(1.0), Fred(4.0)])) + end - # from an array of specials - @test [Fred(0.0), Fred(0.0)] == project([Fred(0.0), Fred(0.0)], [ZeroTangent(), ZeroTangent()]) + @testset "to Diagonal" begin + d_F64 = Diagonal([0.0, 0.0]) + 
d_F32 = Diagonal([0.0f0, 0.0f0]) + d_C64 = Diagonal([0.0 + 0im, 0.0]) + d_Fred = Diagonal([Fred(0.0), Fred(0.0)]) + + # from Matrix + @test d_F64 == projector(d_F64)(zeros(2, 2)) + @test d_F64 == projector(d_F64)(zeros(Float32, 2, 2)) + @test d_F64 == projector(d_F64)(zeros(ComplexF64, 2, 2)) + + # from Diagonal of Numbers + @test d_F64 == projector(d_F64)(d_F64) + @test d_F64 == projector(d_F64)(d_F32) + @test d_F64 == projector(d_F64)(d_C64) + + # from Diagonal of AbstractTangent + @test d_F64 == projector(d_F64)(ZeroTangent()) + @test d_C64 == projector(d_C64)(ZeroTangent()) + @test d_F64 == projector(d_F64)(@thunk(ZeroTangent())) + @test d_F64 == projector(d_F64)(Diagonal([ZeroTangent(), ZeroTangent()])) + @test d_F64 == projector(d_F64)(Diagonal([ZeroTangent(), @thunk(ZeroTangent())])) + + # from Diagonal of structs + @test d_Fred == projector(d_Fred)(ZeroTangent()) + @test d_Fred == projector(d_Fred)(@thunk(ZeroTangent())) + @test d_Fred == projector(d_Fred)(Diagonal([ZeroTangent(), ZeroTangent()])) + + # from Tangent + @test d_F64 == projector(d_F64)(Tangent{Diagonal}(;diag=[0.0, 0.0])) + @test d_F64 == projector(d_F64)(Tangent{Diagonal}(;diag=[0.0f0, 0.0f0])) + @test d_F64 == projector(d_F64)(Tangent{Diagonal}(;diag=[ZeroTangent(), @thunk(ZeroTangent())])) end - @testset "Diagonal" begin - d = Diagonal([1.0, 4.0]) - t = Tangent{Diagonal}(;diag=[1.0, 4.0]) - @test d == project(d, [1.0 2; 3 4]) - @test d == project(d, t) - @test project(Tangent, d, d) isa Tangent + @testset "to Tangent" begin + @test Tangent{Fred}(; a = 3.2,) == projector(Tangent, Fred(3.2))(Fred(3.2)) + @test Tangent{Fred}(; a = ZeroTangent(),) == projector(Tangent, Fred(3.2))(ZeroTangent()) + @test Tangent{Fred}(; a = ZeroTangent(),) == projector(Tangent, Fred(3.2))(@thunk(ZeroTangent())) - @test Diagonal([Fred(0.0), Fred(0.0)]) == project(Diagonal([Fred(3.2,), Fred(4.2)]), Diagonal([ZeroTangent(), ZeroTangent()])) - @test Diagonal([Fred(0.0), Fred(0.0)]) == project(Diagonal([Fred(3.2,), 
Fred(4.2)]), @thunk(ZeroTangent())) + @test projector(Tangent, Diagonal(zeros(2)))(Diagonal([1.0f0, 2.0f0])) isa Tangent + @test projector(Tangent, Diagonal(zeros(2)))(ZeroTangent()) isa Tangent + @test projector(Tangent, Diagonal(zeros(2)))(@thunk(ZeroTangent())) isa Tangent end # how to project to Upper/Lower Symmetric From 2ea4845df1702c89b9d454e05eaeb4f6bd3d5b9f Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Thu, 24 Jun 2021 15:32:27 +0100 Subject: [PATCH 10/43] do not close over x (other than in the general case) --- src/projection.jl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/projection.jl b/src/projection.jl index daedb0175..407d80cc0 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -44,29 +44,29 @@ end function projector(::Type{Array{T, N}}, x::Array{T, N}) where {T, N} println("to Array") element = zero(eltype(x)) + sizex = size(x) project(dx::Array{T, N}) = dx # identity project(dx::AbstractArray) = project(collect(dx)) # from Diagonal project(dx::Array) = projector(element).(dx) # from different element type - project(dx::AbstractZero) = zero(x) + project(dx::AbstractZero) = zeros(T, sizex...) project(dx::AbstractThunk) = project(unthunk(dx)) return project end # Tangent -function projector(::Type{<:Tangent}, x) +function projector(::Type{<:Tangent}, x::T) where {T} println("to Tangent") - keys = fieldnames(typeof(x)) - project(dx) = Tangent{typeof(x)}(; ((k, getproperty(dx, k)) for k in keys)...) + project(dx) = Tangent{T}(; ((k, getproperty(dx, k)) for k in fieldnames(T))...) 
return project end # Diagonal function projector(::Type{<:Diagonal{<:Any, V}}, x::Diagonal) where {V} println("to Diagonal") - d = diag(x) - project(dx::AbstractMatrix) = Diagonal(projector(V, d)(diag(dx))) - project(dx::Tangent) = Diagonal(projector(V, d)(dx.diag)) - project(dx::AbstractZero) = Diagonal(projector(V, d)(dx)) + projV = projector(V, diag(x)) + project(dx::AbstractMatrix) = Diagonal(projV(diag(dx))) + project(dx::Tangent) = Diagonal(projV(dx.diag)) + project(dx::AbstractZero) = Diagonal(projV(dx)) project(dx::AbstractThunk) = project(unthunk(dx)) return project end From 465e1d7bb7b2ba7e2e0e2b954e7ccd30a6f26c49 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Thu, 24 Jun 2021 15:34:16 +0100 Subject: [PATCH 11/43] update docstring --- src/projection.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/projection.jl b/src/projection.jl index 407d80cc0..308f38276 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -3,13 +3,13 @@ using LinearAlgebra: Diagonal, diag """ projector([T::Type], x) -"project" `dx` onto type `T` such that it is the same size as `x`. If `T` is not provided, -it is assumed to be the type of `x`. +Returns a `project(dx)` closure which maps `dx` onto type `T`, such that it is the +same size as `x`. If `T` is not provided, it is assumed to be the type of `x`. -It's necessary to have `x` to ensure that it's possible to projector e.g. `AbstractZero`s +It's necessary to have `x` to ensure that it's possible to project e.g. `AbstractZero`s onto `Array`s -- this wouldn't be possible with type information alone because the neither `AbstractZero`s nor `T` know what size of `Array` to produce. 
-""" # TODO change docstring to reflect projecor returns a closure +""" function projector end projector(x) = projector(typeof(x), x) From 0a06dce733176c7e06c3b9ec4cdfb190092b1b5f Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Thu, 24 Jun 2021 17:26:11 +0100 Subject: [PATCH 12/43] fix getproperty --- src/differentials/thunks.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/differentials/thunks.jl b/src/differentials/thunks.jl index ed6a3d35c..00f8d18c7 100644 --- a/src/differentials/thunks.jl +++ b/src/differentials/thunks.jl @@ -37,8 +37,6 @@ Base.imag(a::AbstractThunk) = imag(unthunk(a)) Base.Complex(a::AbstractThunk) = Complex(unthunk(a)) Base.Complex(a::AbstractThunk, b::AbstractThunk) = Complex(unthunk(a), unthunk(b)) -Base.getproperty(a::AbstractThunk, f::Symbol) = f === :f ? getfield(a, f) : getproperty(unthunk(a), f) - Base.mapreduce(f, op, a::AbstractThunk; kws...) = mapreduce(f, op, unthunk(a); kws...) function Base.mapreduce(f, op, itr, a::AbstractThunk; kws...) return mapreduce(f, op, itr, unthunk(a); kws...) @@ -190,6 +188,8 @@ end @inline unthunk(x::Thunk) = x.f() +Base.getproperty(a::Thunk, f::Symbol) = f === :f ? getfield(a, f) : getproperty(unthunk(a), f) + Base.show(io::IO, x::Thunk) = print(io, "Thunk($(repr(x.f)))") """ @@ -211,6 +211,8 @@ end unthunk(x::InplaceableThunk) = unthunk(x.val) +Base.getproperty(a::InplaceableThunk, f::Symbol) = f in (:val, :add!) ? 
getfield(a, f) : getproperty(unthunk(a), f) + function Base.show(io::IO, x::InplaceableThunk) return print(io, "InplaceableThunk($(repr(x.val)), $(repr(x.add!)))") end From d822b020a7c31ce57d03df6b603dc2c18538f410 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Thu, 24 Jun 2021 17:26:43 +0100 Subject: [PATCH 13/43] add to Tangent and to Symmetric --- src/projection.jl | 15 +++++++++++++-- test/projection.jl | 26 +++++++++++++++++--------- 2 files changed, 30 insertions(+), 11 deletions(-) diff --git a/src/projection.jl b/src/projection.jl index 308f38276..86dd7ec81 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -14,9 +14,9 @@ function projector end projector(x) = projector(typeof(x), x) -# fallback +# fallback (structs) function projector(::Type{T}, x::T) where T - println("to Any") + println("to Any, T=$T") project(dx::T) = dx project(dx::AbstractZero) = zero(x) project(dx::AbstractThunk) = project(unthunk(dx)) @@ -71,3 +71,14 @@ function projector(::Type{<:Diagonal{<:Any, V}}, x::Diagonal) where {V} return project end +# Symmetric +function projector(::Type{<:Symmetric{<:Any, M}}, x::Symmetric) where {M} + println("to Symetric") + projM = projector(M, parent(x)) + uplo = Symbol(x.uplo) + project(dx::AbstractMatrix) = Symmetric(projM(dx), uplo) + project(dx::Tangent) = Symmetric(projM(dx.data), uplo) + project(dx::AbstractZero) = Symmetric(projM(dx), uplo) + project(dx::AbstractThunk) = project(unthunk(dx)) + return project +end diff --git a/test/projection.jl b/test/projection.jl index d581e7100..4a5e57217 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -58,6 +58,16 @@ Base.zero(::Type{Fred}) = Fred(0.0) @test diagfreds == projector(diagfreds)(Diagonal([Fred(1.0), Fred(4.0)])) end + @testset "to Tangent" begin + @test Tangent{Fred}(; a = 3.2,) == projector(Tangent, Fred(3.2))(Fred(3.2)) + @test Tangent{Fred}(; a = ZeroTangent(),) == projector(Tangent, Fred(3.2))(ZeroTangent()) + @test Tangent{Fred}(; a = ZeroTangent(),) == 
projector(Tangent, Fred(3.2))(@thunk(ZeroTangent())) + + @test projector(Tangent, Diagonal(zeros(2)))(Diagonal([1.0f0, 2.0f0])) isa Tangent + @test projector(Tangent, Diagonal(zeros(2)))(ZeroTangent()) isa Tangent + @test projector(Tangent, Diagonal(zeros(2)))(@thunk(ZeroTangent())) isa Tangent + end + @testset "to Diagonal" begin d_F64 = Diagonal([0.0, 0.0]) d_F32 = Diagonal([0.0f0, 0.0f0]) @@ -92,15 +102,13 @@ Base.zero(::Type{Fred}) = Fred(0.0) @test d_F64 == projector(d_F64)(Tangent{Diagonal}(;diag=[ZeroTangent(), @thunk(ZeroTangent())])) end - @testset "to Tangent" begin - @test Tangent{Fred}(; a = 3.2,) == projector(Tangent, Fred(3.2))(Fred(3.2)) - @test Tangent{Fred}(; a = ZeroTangent(),) == projector(Tangent, Fred(3.2))(ZeroTangent()) - @test Tangent{Fred}(; a = ZeroTangent(),) == projector(Tangent, Fred(3.2))(@thunk(ZeroTangent())) + @testset "to Symmetric" begin + data = [1.0 2; 3 4] + @test Symmetric(data) == projector(Symmetric(data))(data) + @test Symmetric(data, :L) == projector(Symmetric(data, :L))(data) + @test Symmetric(Diagonal(data)) == projector(Symmetric(data))(Diagonal(diag(data))) - @test projector(Tangent, Diagonal(zeros(2)))(Diagonal([1.0f0, 2.0f0])) isa Tangent - @test projector(Tangent, Diagonal(zeros(2)))(ZeroTangent()) isa Tangent - @test projector(Tangent, Diagonal(zeros(2)))(@thunk(ZeroTangent())) isa Tangent + @test Symmetric(zeros(2, 2)) == projector(Symmetric(data))(ZeroTangent()) + @test Symmetric(zeros(2, 2)) == projector(Symmetric(data))(@thunk(ZeroTangent())) end - - # how to project to Upper/Lower Symmetric end From 25a7ceeea230438656a41c3ea55080f6949ac4aa Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Thu, 24 Jun 2021 17:27:20 +0100 Subject: [PATCH 14/43] remove debug strings --- src/projection.jl | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/projection.jl b/src/projection.jl index 86dd7ec81..203342387 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -16,7 +16,6 @@ projector(x) = projector(typeof(x), 
x) # fallback (structs) function projector(::Type{T}, x::T) where T - println("to Any, T=$T") project(dx::T) = dx project(dx::AbstractZero) = zero(x) project(dx::AbstractThunk) = project(unthunk(dx)) @@ -25,7 +24,6 @@ end # Numbers function projector(::Type{T}, x::T) where {T<:Real} - println("to Real") project(dx::Real) = T(dx) project(dx::Number) = T(real(dx)) # to avoid InexactError project(dx::AbstractZero) = zero(x) @@ -33,7 +31,6 @@ function projector(::Type{T}, x::T) where {T<:Real} return project end function projector(::Type{T}, x::T) where {T<:Number} - println("to Number") project(dx::Number) = T(dx) project(dx::AbstractZero) = zero(x) project(dx::AbstractThunk) = project(unthunk(dx)) @@ -42,7 +39,6 @@ end # Arrays function projector(::Type{Array{T, N}}, x::Array{T, N}) where {T, N} - println("to Array") element = zero(eltype(x)) sizex = size(x) project(dx::Array{T, N}) = dx # identity @@ -55,14 +51,12 @@ end # Tangent function projector(::Type{<:Tangent}, x::T) where {T} - println("to Tangent") project(dx) = Tangent{T}(; ((k, getproperty(dx, k)) for k in fieldnames(T))...) 
return project end # Diagonal function projector(::Type{<:Diagonal{<:Any, V}}, x::Diagonal) where {V} - println("to Diagonal") projV = projector(V, diag(x)) project(dx::AbstractMatrix) = Diagonal(projV(diag(dx))) project(dx::Tangent) = Diagonal(projV(dx.diag)) @@ -73,7 +67,6 @@ end # Symmetric function projector(::Type{<:Symmetric{<:Any, M}}, x::Symmetric) where {M} - println("to Symetric") projM = projector(M, parent(x)) uplo = Symbol(x.uplo) project(dx::AbstractMatrix) = Symmetric(projM(dx), uplo) From 7801e19ca8590b58da9a3700e4081000947ee0f9 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Thu, 24 Jun 2021 17:41:10 +0100 Subject: [PATCH 15/43] separate out the projector --- src/projection.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/projection.jl b/src/projection.jl index 203342387..171cca1d9 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -39,11 +39,11 @@ end # Arrays function projector(::Type{Array{T, N}}, x::Array{T, N}) where {T, N} - element = zero(eltype(x)) sizex = size(x) + projT = projector(zero(T)) project(dx::Array{T, N}) = dx # identity project(dx::AbstractArray) = project(collect(dx)) # from Diagonal - project(dx::Array) = projector(element).(dx) # from different element type + project(dx::Array) = projT.(dx) # from different element type project(dx::AbstractZero) = zeros(T, sizex...) 
project(dx::AbstractThunk) = project(unthunk(dx)) return project From 9147fadddc1fe72df767b6a29c8799edb8502216 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Fri, 25 Jun 2021 15:47:51 +0100 Subject: [PATCH 16/43] implement preproject --- src/ChainRulesCore.jl | 2 +- src/projection.jl | 103 +++++++++++++++------------------ test/projection.jl | 131 ++++++++++++++++++++++++------------------ 3 files changed, 122 insertions(+), 114 deletions(-) diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index 9cf6b7c31..df10a5484 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -10,7 +10,7 @@ export RuleConfig, HasReverseMode, NoReverseMode, HasForwardsMode, NoForwardsMod export frule_via_ad, rrule_via_ad # definition helper macros export @non_differentiable, @scalar_rule, @thunk, @not_implemented -export canonicalize, extern, unthunk, projector # differential operations +export canonicalize, extern, unthunk, project, preproject # differential operations export add!! # gradient accumulation operations # differentials export Tangent, NoTangent, InplaceableThunk, Thunk, ZeroTangent, AbstractZero, AbstractThunk diff --git a/src/projection.jl b/src/projection.jl index 171cca1d9..6e333cc87 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -1,7 +1,7 @@ using LinearAlgebra: Diagonal, diag """ - projector([T::Type], x) + project([T::Type], dx; info) Returns a `project(dx)` closure which maps `dx` onto type `T`, such that it is the same size as `x`. If `T` is not provided, it is assumed to be the type of `x`. @@ -9,69 +9,60 @@ same size as `x`. If `T` is not provided, it is assumed to be the type of `x`. It's necessary to have `x` to ensure that it's possible to project e.g. `AbstractZero`s onto `Array`s -- this wouldn't be possible with type information alone because the neither `AbstractZero`s nor `T` know what size of `Array` to produce. 
-""" -function projector end +""" # TODO docstring +function project end -projector(x) = projector(typeof(x), x) +""" +""" # TODO add docstring +function preproject end # fallback (structs) -function projector(::Type{T}, x::T) where T - project(dx::T) = dx - project(dx::AbstractZero) = zero(x) - project(dx::AbstractThunk) = project(unthunk(dx)) - return project -end +project(::Type{T}, dx::T) where T = dx +project(::Type{T}, dx::AbstractZero) where T = zero(T) +project(::Type{T}, dx::AbstractThunk) where T = project(T, unthunk(dx)) +function project(::Type{T}, dx::Tangent{<:T}) where T + fnames = fieldnames(T) + values = [getproperty(dx, fn) for fn in fnames] + return T((; zip(fnames, values)...)...) +end # TODO: make Tangent work recursively -# Numbers -function projector(::Type{T}, x::T) where {T<:Real} - project(dx::Real) = T(dx) - project(dx::Number) = T(real(dx)) # to avoid InexactError - project(dx::AbstractZero) = zero(x) - project(dx::AbstractThunk) = project(unthunk(dx)) - return project -end -function projector(::Type{T}, x::T) where {T<:Number} - project(dx::Number) = T(dx) - project(dx::AbstractZero) = zero(x) - project(dx::AbstractThunk) = project(unthunk(dx)) - return project -end +# Real +project(::Type{T}, dx::Real) where {T<:Real} = T(dx) +project(::Type{T}, dx::Number) where {T<:Real} = T(real(dx)) +project(::Type{T}, dx::AbstractZero) where {T<:Real} = zero(T) +project(::Type{T}, dx::AbstractThunk) where {T<:Real} = project(T, unthunk(dx)) +# Number +project(::Type{T}, dx::Number) where {T<:Number} = T(dx) +project(::Type{T}, dx::AbstractZero) where {T<:Number} = zero(T) +project(::Type{T}, dx::AbstractThunk) where {T<:Number} = project(T, unthunk(dx)) # Arrays -function projector(::Type{Array{T, N}}, x::Array{T, N}) where {T, N} - sizex = size(x) - projT = projector(zero(T)) - project(dx::Array{T, N}) = dx # identity - project(dx::AbstractArray) = project(collect(dx)) # from Diagonal - project(dx::Array) = projT.(dx) # from different element 
type - project(dx::AbstractZero) = zeros(T, sizex...) - project(dx::AbstractThunk) = project(unthunk(dx)) - return project -end +preproject(x::Array) = (; size=size(x), eltype=eltype(x)) + +project(AT::Type{Array{T, N}}, dx::Array{T, N}; info) where {T, N} = dx +project(AT::Type{Array{T, N}}, dx::AbstractArray; info) where {T, N} = project(AT, collect(dx); info=info) +project(AT::Type{Array{T, N}}, dx::Array; info) where {T, N} = project.(T, dx) +project(AT::Type{Array{T, N}}, dx::AbstractZero; info) where {T, N} = zeros(T, info.size...) +project(AT::Type{Array{T, N}}, dx::AbstractThunk; info) where {T, N} = project(AT, unthunk(dx); info=info) -# Tangent -function projector(::Type{<:Tangent}, x::T) where {T} - project(dx) = Tangent{T}(; ((k, getproperty(dx, k)) for k in fieldnames(T))...) - return project -end +# Tangent # TODO: do we need this? +#function projector(::Type{<:Tangent}, x::T) where {T} +# project(dx) = Tangent{T}(; ((k, getfield(dx, k)) for k in fieldnames(T))...) +# return project +#end # Diagonal -function projector(::Type{<:Diagonal{<:Any, V}}, x::Diagonal) where {V} - projV = projector(V, diag(x)) - project(dx::AbstractMatrix) = Diagonal(projV(diag(dx))) - project(dx::Tangent) = Diagonal(projV(dx.diag)) - project(dx::AbstractZero) = Diagonal(projV(dx)) - project(dx::AbstractThunk) = project(unthunk(dx)) - return project -end +preproject(x::Diagonal{<:Any, V}) where {V} = (; Vinfo=preproject(diag(x))) + +project(DT::Type{<:Diagonal{<:Any, V}}, dx::AbstractMatrix; info) where {V} = Diagonal(project(V, diag(dx); info=info.Vinfo)) +project(DT::Type{<:Diagonal{<:Any, V}}, dx::Tangent; info) where {V} = Diagonal(project(V, dx.diag; info=info.Vinfo)) +project(DT::Type{<:Diagonal{<:Any, V}}, dx::AbstractZero; info) where {V} = Diagonal(project(V, dx; info=info.Vinfo)) +project(DT::Type{<:Diagonal{<:Any, V}}, dx::AbstractThunk; info) where {V} = project(DT, unthunk(dx); info=info) # Symmetric -function projector(::Type{<:Symmetric{<:Any, M}}, 
x::Symmetric) where {M} - projM = projector(M, parent(x)) - uplo = Symbol(x.uplo) - project(dx::AbstractMatrix) = Symmetric(projM(dx), uplo) - project(dx::Tangent) = Symmetric(projM(dx.data), uplo) - project(dx::AbstractZero) = Symmetric(projM(dx), uplo) - project(dx::AbstractThunk) = project(unthunk(dx)) - return project -end +preproject(x::Symmetric{<:Any, M}) where {M} = (; uplo=Symbol(x.uplo), Minfo=preproject(parent(x))) + +project(ST::Type{<:Symmetric{<:Any, M}}, dx::AbstractMatrix; info) where {M} = Symmetric(project(M, dx; info=info.Minfo), info.uplo) +project(ST::Type{<:Symmetric{<:Any, M}}, dx::Tangent; info) where {M} = Symmetric(project(M, dx.data; info=info.Minfo), info.uplo) +project(ST::Type{<:Symmetric{<:Any, M}}, dx::AbstractZero; info) where {M} = Symmetric(project(M, dx; info=info.Minfo), info.uplo) +project(ST::Type{<:Symmetric{<:Any, M}}, dx::AbstractThunk; info) where {M} = project(ST, unthunk(dx); info=info) diff --git a/test/projection.jl b/test/projection.jl index 4a5e57217..bb6774059 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -7,66 +7,75 @@ Base.zero(::Type{Fred}) = Fred(0.0) @testset "projection" begin @testset "fallback" begin - @test Fred(1.2) == projector(Fred(3.2))(Fred(1.2)) - @test Fred(0.0) == projector(Fred(3.2))(ZeroTangent()) - @test Fred(3.2) == projector(Fred(-0.2))(@thunk(Fred(3.2))) + @test Fred(1.2) == project(Fred, Fred(1.2)) + @test Fred(0.0) == project(Fred, ZeroTangent()) + @test Fred(3.2) == project(Fred, @thunk(Fred(3.2))) end @testset "to Real" begin # Float64 - @test 3.2 == projector(1.0)(3.2) - @test 0.0 == projector(1.1)(ZeroTangent()) - @test 3.2 == projector(1.0)(@thunk(3.2)) + @test 3.2 == project(Float64, 3.2) + @test 0.0 == project(Float64, ZeroTangent()) + @test 3.2 == project(Float64, @thunk(3.2)) # down - @test 3.2 == projector(1.0)(3.2 + 3im) - @test 3.2f0 == projector(1.0f0)(3.2) - @test 3.2f0 == projector(1.0f0)(3.2 - 3im) + @test 3.2 == project(Float64, 3.2 + 3im) + @test 3.2f0 == 
project(Float32, 3.2) + @test 3.2f0 == project(Float32, 3.2 - 3im) # up - @test 2.0 == projector(2.0)(2.0f0) + @test 2.0 == project(Float64, 2.0f0) end @testset "to Number" begin # Complex - @test 2.0 + 0.0im == projector(1.0im)(2.0 + 0.0im) + @test 2.0 + 0.0im == project(ComplexF64, 2.0 + 0.0im) # down - @test 2.0 + 0.0im == projector(1.0im)(2.0) - @test 0.0 + 0.0im == projector(1.0im)(ZeroTangent()) - @test 0.0 + 0.0im == projector(1.0im)(@thunk(ZeroTangent())) + @test 2.0 + 0.0im == project(ComplexF64, 2.0) + @test 0.0 + 0.0im == project(ComplexF64, ZeroTangent()) + @test 0.0 + 0.0im == project(ComplexF64, @thunk(ZeroTangent())) # up - @test 2.0 + 0.0im == projector(2.0 + 1.0im)(2.0) + @test 2.0 + 0.0im == project(ComplexF64, 2.0) end @testset "to Array" begin # to an array of numbers - @test [1.0 2.0; 3.0 4.0] == projector(zeros(2, 2))([1.0 2.0; 3.0 4.0]) - @test zeros(2, 2) == projector([1.0 2; 3 4])(ZeroTangent()) - @test zeros(2) == projector([1.0, 2.0])(@thunk(ZeroTangent())) - @test [1.0f0 2; 3 4] == projector(zeros(Float32, 2, 2))([1.0 2; 3 4]) - @test [1.0 0; 0 4] == projector(zeros(2, 2))(Diagonal([1.0, 4])) + x = zeros(2, 2) + @test [1.0 2.0; 3.0 4.0] == project(typeof(x), [1.0 2.0; 3.0 4.0]; info=preproject(x)) + @test x == project(typeof(x), ZeroTangent(); info=preproject(x)) + + x = zeros(2) + @test x == project(typeof(x), @thunk(ZeroTangent()); info=preproject(x)) + + x = zeros(Float32, 2, 2) + @test x == project(typeof(x), [0.0 0; 0 0]; info=preproject(x)) + + x = [1.0 0; 0 4] + @test x == project(typeof(x), Diagonal([1.0, 4]); info=preproject(x)) # to a array of structs - @test [Fred(0.0), Fred(0.0)] == projector([Fred(0.0), Fred(0.0)])([Fred(0.0), Fred(0.0)]) - @test [Fred(0.0), Fred(0.0)] == projector([Fred(0.0), Fred(0.0)])([ZeroTangent(), ZeroTangent()]) - @test [Fred(0.0), Fred(3.2)] == projector([Fred(0.0), Fred(0.0)])([ZeroTangent(), @thunk(Fred(3.2))]) - @test [Fred(0.0), Fred(0.0)] == projector([Fred(1.0), Fred(2.0)])(ZeroTangent()) - 
@test [Fred(0.0), Fred(0.0)] == projector([Fred(0.0), Fred(0.0)])(@thunk(ZeroTangent())) - diagfreds = [Fred(1.0) Fred(0.0); Fred(0.0) Fred(4.0)] - @test diagfreds == projector(diagfreds)(Diagonal([Fred(1.0), Fred(4.0)])) + x = [Fred(0.0), Fred(0.0)] + @test x == project(typeof(x), [Fred(0.0), Fred(0.0)]; info=preproject(x)) + @test x == project(typeof(x), [ZeroTangent(), ZeroTangent()]; info=preproject(x)) + @test x == project(typeof(x), [ZeroTangent(), @thunk(Fred(0.0))]; info=preproject(x)) + @test x == project(typeof(x), ZeroTangent(); info=preproject(x)) + @test x == project(typeof(x), @thunk(ZeroTangent()); info=preproject(x)) + + x = [Fred(1.0) Fred(0.0); Fred(0.0) Fred(4.0)] + @test x == project(typeof(x), Diagonal([Fred(1.0), Fred(4.0)]); info=preproject(x)) end - @testset "to Tangent" begin - @test Tangent{Fred}(; a = 3.2,) == projector(Tangent, Fred(3.2))(Fred(3.2)) - @test Tangent{Fred}(; a = ZeroTangent(),) == projector(Tangent, Fred(3.2))(ZeroTangent()) - @test Tangent{Fred}(; a = ZeroTangent(),) == projector(Tangent, Fred(3.2))(@thunk(ZeroTangent())) + #@testset "to Tangent" begin + # @test Tangent{Fred}(; a = 3.2,) == project(Tangent, Fred(3.2), Fred(3.2)) + # @test Tangent{Fred}(; a = ZeroTangent(),) == project(Tangent, Fred(3.2), ZeroTangent()) + # @test Tangent{Fred}(; a = ZeroTangent(),) == project(Tangent, Fred(3.2), @thunk(ZeroTangent())) - @test projector(Tangent, Diagonal(zeros(2)))(Diagonal([1.0f0, 2.0f0])) isa Tangent - @test projector(Tangent, Diagonal(zeros(2)))(ZeroTangent()) isa Tangent - @test projector(Tangent, Diagonal(zeros(2)))(@thunk(ZeroTangent())) isa Tangent - end + # @test project(Tangent, Diagonal(zeros(2)), Diagonal([1.0f0, 2.0f0])) isa Tangent + # @test project(Tangent, Diagonal(zeros(2)), ZeroTangent()) isa Tangent + # @test project(Tangent, Diagonal(zeros(2)), @thunk(ZeroTangent())) isa Tangent + #end @testset "to Diagonal" begin d_F64 = Diagonal([0.0, 0.0]) @@ -75,40 +84,48 @@ Base.zero(::Type{Fred}) = Fred(0.0) d_Fred 
= Diagonal([Fred(0.0), Fred(0.0)]) # from Matrix - @test d_F64 == projector(d_F64)(zeros(2, 2)) - @test d_F64 == projector(d_F64)(zeros(Float32, 2, 2)) - @test d_F64 == projector(d_F64)(zeros(ComplexF64, 2, 2)) + @test d_F64 == project(typeof(d_F64), zeros(2, 2); info=preproject(d_F64)) + @test d_F64 == project(typeof(d_F64), zeros(Float32, 2, 2); info=preproject(d_F64)) + @test d_F64 == project(typeof(d_F64), zeros(ComplexF64, 2, 2); info=preproject(d_F64)) # from Diagonal of Numbers - @test d_F64 == projector(d_F64)(d_F64) - @test d_F64 == projector(d_F64)(d_F32) - @test d_F64 == projector(d_F64)(d_C64) + @test d_F64 == project(typeof(d_F64), d_F64; info=preproject(d_F64)) + @test d_F64 == project(typeof(d_F64), d_F32; info=preproject(d_F64)) + @test d_F64 == project(typeof(d_F64), d_C64; info=preproject(d_F64)) # from Diagonal of AbstractTangent - @test d_F64 == projector(d_F64)(ZeroTangent()) - @test d_C64 == projector(d_C64)(ZeroTangent()) - @test d_F64 == projector(d_F64)(@thunk(ZeroTangent())) - @test d_F64 == projector(d_F64)(Diagonal([ZeroTangent(), ZeroTangent()])) - @test d_F64 == projector(d_F64)(Diagonal([ZeroTangent(), @thunk(ZeroTangent())])) + @test d_F64 == project(typeof(d_F64), ZeroTangent(); info=preproject(d_F64)) + @test d_C64 == project(typeof(d_C64), ZeroTangent(); info=preproject(d_C64)) + @test d_F64 == project(typeof(d_F64), @thunk(ZeroTangent()); info=preproject(d_F64)) + @test d_F64 == project(typeof(d_F64), Diagonal([ZeroTangent(), ZeroTangent()]); info=preproject(d_F64)) + @test d_F64 == project(typeof(d_F64), Diagonal([ZeroTangent(), @thunk(ZeroTangent())]); info=preproject(d_F64)) # from Diagonal of structs - @test d_Fred == projector(d_Fred)(ZeroTangent()) - @test d_Fred == projector(d_Fred)(@thunk(ZeroTangent())) - @test d_Fred == projector(d_Fred)(Diagonal([ZeroTangent(), ZeroTangent()])) + @test d_Fred == project(typeof(d_Fred), ZeroTangent(); info=preproject(d_Fred)) + @test d_Fred == project(typeof(d_Fred), 
@thunk(ZeroTangent()); info=preproject(d_Fred)) + @test d_Fred == project(typeof(d_Fred), Diagonal([ZeroTangent(), ZeroTangent()]); info=preproject(d_Fred)) # from Tangent - @test d_F64 == projector(d_F64)(Tangent{Diagonal}(;diag=[0.0, 0.0])) - @test d_F64 == projector(d_F64)(Tangent{Diagonal}(;diag=[0.0f0, 0.0f0])) - @test d_F64 == projector(d_F64)(Tangent{Diagonal}(;diag=[ZeroTangent(), @thunk(ZeroTangent())])) + @test d_F64 == project(typeof(d_F64), Tangent{Diagonal}(;diag=[0.0, 0.0]); info=preproject(d_F64)) + @test d_F64 == project(typeof(d_F64), Tangent{Diagonal}(;diag=[0.0f0, 0.0f0]); info=preproject(d_F64)) + @test d_F64 == project(typeof(d_F64), Tangent{Diagonal}(;diag=[ZeroTangent(), @thunk(ZeroTangent())]); info=preproject(d_F64)) end @testset "to Symmetric" begin data = [1.0 2; 3 4] - @test Symmetric(data) == projector(Symmetric(data))(data) - @test Symmetric(data, :L) == projector(Symmetric(data, :L))(data) - @test Symmetric(Diagonal(data)) == projector(Symmetric(data))(Diagonal(diag(data))) + x = Symmetric(data) + @test x == project(typeof(x), data; info=preproject(x)) + + x = Symmetric(data, :L) + @test x == project(typeof(x), data; info=preproject(x)) + + data = [1.0 0; 0 4] + x = Symmetric(data) + @test x == project(typeof(x), Diagonal([1.0, 4.0]); info=preproject(x)) - @test Symmetric(zeros(2, 2)) == projector(Symmetric(data))(ZeroTangent()) - @test Symmetric(zeros(2, 2)) == projector(Symmetric(data))(@thunk(ZeroTangent())) + data = [0.0 0; 0 0] + x = Symmetric(data) + @test x == project(typeof(x), ZeroTangent(); info=preproject(x)) + @test x == project(typeof(x), @thunk(ZeroTangent()); info=preproject(x)) end end From cc2f199a210cdc73bc69e9cff99dae7261fea99f Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Fri, 25 Jun 2021 15:50:05 +0100 Subject: [PATCH 17/43] remove getproperty for thunks --- src/differentials/thunks.jl | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/differentials/thunks.jl b/src/differentials/thunks.jl index 
00f8d18c7..781fff60c 100644 --- a/src/differentials/thunks.jl +++ b/src/differentials/thunks.jl @@ -188,8 +188,6 @@ end @inline unthunk(x::Thunk) = x.f() -Base.getproperty(a::Thunk, f::Symbol) = f === :f ? getfield(a, f) : getproperty(unthunk(a), f) - Base.show(io::IO, x::Thunk) = print(io, "Thunk($(repr(x.f)))") """ @@ -211,8 +209,6 @@ end unthunk(x::InplaceableThunk) = unthunk(x.val) -Base.getproperty(a::InplaceableThunk, f::Symbol) = f in (:val, :add!) ? getfield(a, f) : getproperty(unthunk(a), f) - function Base.show(io::IO, x::InplaceableThunk) return print(io, "InplaceableThunk($(repr(x.val)), $(repr(x.add!)))") end From 2aa385903fa7810b22ffc779d9798db2c80a40ce Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Fri, 25 Jun 2021 16:53:07 +0100 Subject: [PATCH 18/43] remove to Tangent --- src/differentials/abstract_zero.jl | 1 - src/projection.jl | 6 ------ test/projection.jl | 11 +---------- 3 files changed, 1 insertion(+), 17 deletions(-) diff --git a/src/differentials/abstract_zero.jl b/src/differentials/abstract_zero.jl index fb5342dd2..01dbfc8f3 100644 --- a/src/differentials/abstract_zero.jl +++ b/src/differentials/abstract_zero.jl @@ -27,7 +27,6 @@ Base.:/(z::AbstractZero, ::Any) = z Base.convert(::Type{T}, x::AbstractZero) where T <: Number = zero(T) Base.getindex(z::AbstractZero, k) = z -Base.getproperty(z::AbstractZero, f::Symbol) = z """ ZeroTangent() <: AbstractZero diff --git a/src/projection.jl b/src/projection.jl index 6e333cc87..a0fc0ad66 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -45,12 +45,6 @@ project(AT::Type{Array{T, N}}, dx::Array; info) where {T, N} = project.(T, dx) project(AT::Type{Array{T, N}}, dx::AbstractZero; info) where {T, N} = zeros(T, info.size...) project(AT::Type{Array{T, N}}, dx::AbstractThunk; info) where {T, N} = project(AT, unthunk(dx); info=info) -# Tangent # TODO: do we need this? 
-#function projector(::Type{<:Tangent}, x::T) where {T} -# project(dx) = Tangent{T}(; ((k, getfield(dx, k)) for k in fieldnames(T))...) -# return project -#end - # Diagonal preproject(x::Diagonal{<:Any, V}) where {V} = (; Vinfo=preproject(diag(x))) diff --git a/test/projection.jl b/test/projection.jl index bb6774059..eeb623ca2 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -67,16 +67,6 @@ Base.zero(::Type{Fred}) = Fred(0.0) @test x == project(typeof(x), Diagonal([Fred(1.0), Fred(4.0)]); info=preproject(x)) end - #@testset "to Tangent" begin - # @test Tangent{Fred}(; a = 3.2,) == project(Tangent, Fred(3.2), Fred(3.2)) - # @test Tangent{Fred}(; a = ZeroTangent(),) == project(Tangent, Fred(3.2), ZeroTangent()) - # @test Tangent{Fred}(; a = ZeroTangent(),) == project(Tangent, Fred(3.2), @thunk(ZeroTangent())) - - # @test project(Tangent, Diagonal(zeros(2)), Diagonal([1.0f0, 2.0f0])) isa Tangent - # @test project(Tangent, Diagonal(zeros(2)), ZeroTangent()) isa Tangent - # @test project(Tangent, Diagonal(zeros(2)), @thunk(ZeroTangent())) isa Tangent - #end - @testset "to Diagonal" begin d_F64 = Diagonal([0.0, 0.0]) d_F32 = Diagonal([0.0f0, 0.0f0]) @@ -113,6 +103,7 @@ Base.zero(::Type{Fred}) = Fred(0.0) @testset "to Symmetric" begin data = [1.0 2; 3 4] + x = Symmetric(data) @test x == project(typeof(x), data; info=preproject(x)) From 44ef266ad733bbef394d2194acf176fe86608adc Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Fri, 25 Jun 2021 17:03:20 +0100 Subject: [PATCH 19/43] fix docstrings --- src/projection.jl | 35 +++++++++++++++++++---------------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/src/projection.jl b/src/projection.jl index a0fc0ad66..f34f0e6d0 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -1,20 +1,29 @@ using LinearAlgebra: Diagonal, diag + """ - project([T::Type], dx; info) + preproject(x) -Returns a `project(dx)` closure which maps `dx` onto type `T`, such that it is the -same size as `x`. 
If `T` is not provided, it is assumed to be the type of `x`. +Returns a NamedTuple containing information needed to [`project`](@ref) a differential `dx` +onto the type `T` for a primal `x`. For example, when projecting `dx=ZeroTangent()` on an +array `T=Array{T, N}`, the size of `x` is not available from `T`. +""" +function preproject end -It's necessary to have `x` to ensure that it's possible to project e.g. `AbstractZero`s -onto `Array`s -- this wouldn't be possible with type information alone because the neither -`AbstractZero`s nor `T` know what size of `Array` to produce. -""" # TODO docstring -function project end +preproject(x::Array) = (; size=size(x), eltype=eltype(x)) + +preproject(x::Diagonal{<:Any, V}) where {V} = (; Vinfo=preproject(diag(x))) + +preproject(x::Symmetric{<:Any, M}) where {M} = (; uplo=Symbol(x.uplo), Minfo=preproject(parent(x))) """ -""" # TODO add docstring -function preproject end + project(T::Type, dx; [info]) + +Projects the differential `dx` for primal `x` onto type `T`. The kwarg `info` contains +information about the primal `x` that is needed to project onto `T`, e.g. the size of an +array. It is obtained from `preproject(x)`. 
+""" +function project end # fallback (structs) project(::Type{T}, dx::T) where T = dx @@ -37,8 +46,6 @@ project(::Type{T}, dx::AbstractZero) where {T<:Number} = zero(T) project(::Type{T}, dx::AbstractThunk) where {T<:Number} = project(T, unthunk(dx)) # Arrays -preproject(x::Array) = (; size=size(x), eltype=eltype(x)) - project(AT::Type{Array{T, N}}, dx::Array{T, N}; info) where {T, N} = dx project(AT::Type{Array{T, N}}, dx::AbstractArray; info) where {T, N} = project(AT, collect(dx); info=info) project(AT::Type{Array{T, N}}, dx::Array; info) where {T, N} = project.(T, dx) @@ -46,16 +53,12 @@ project(AT::Type{Array{T, N}}, dx::AbstractZero; info) where {T, N} = zeros(T, i project(AT::Type{Array{T, N}}, dx::AbstractThunk; info) where {T, N} = project(AT, unthunk(dx); info=info) # Diagonal -preproject(x::Diagonal{<:Any, V}) where {V} = (; Vinfo=preproject(diag(x))) - project(DT::Type{<:Diagonal{<:Any, V}}, dx::AbstractMatrix; info) where {V} = Diagonal(project(V, diag(dx); info=info.Vinfo)) project(DT::Type{<:Diagonal{<:Any, V}}, dx::Tangent; info) where {V} = Diagonal(project(V, dx.diag; info=info.Vinfo)) project(DT::Type{<:Diagonal{<:Any, V}}, dx::AbstractZero; info) where {V} = Diagonal(project(V, dx; info=info.Vinfo)) project(DT::Type{<:Diagonal{<:Any, V}}, dx::AbstractThunk; info) where {V} = project(DT, unthunk(dx); info=info) # Symmetric -preproject(x::Symmetric{<:Any, M}) where {M} = (; uplo=Symbol(x.uplo), Minfo=preproject(parent(x))) - project(ST::Type{<:Symmetric{<:Any, M}}, dx::AbstractMatrix; info) where {M} = Symmetric(project(M, dx; info=info.Minfo), info.uplo) project(ST::Type{<:Symmetric{<:Any, M}}, dx::Tangent; info) where {M} = Symmetric(project(M, dx.data; info=info.Minfo), info.uplo) project(ST::Type{<:Symmetric{<:Any, M}}, dx::AbstractZero; info) where {M} = Symmetric(project(M, dx; info=info.Minfo), info.uplo) From d8848f5d288c27d809cb7b4dd790473c5b0afd41 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Fri, 25 Jun 2021 18:28:02 +0100 Subject: 
[PATCH 20/43] project nested structs --- src/projection.jl | 45 ++++++++++++++++++++++++++++----------------- test/projection.jl | 25 ++++++++++++++++++++++++- 2 files changed, 52 insertions(+), 18 deletions(-) diff --git a/src/projection.jl b/src/projection.jl index f34f0e6d0..2ebf0601a 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -10,6 +10,14 @@ array `T=Array{T, N}`, the size of `x` is not available from `T`. """ function preproject end +function preproject(x::T) where {T} + fnames = fieldnames(T) + values = [getproperty(x, fn) for fn in fnames] + types = typeof.(values) + infos = preproject.(values) + return (; zip(fnames, collect(zip(types, infos)))...) +end + preproject(x::Array) = (; size=size(x), eltype=eltype(x)) preproject(x::Diagonal{<:Any, V}) where {V} = (; Vinfo=preproject(diag(x))) @@ -17,7 +25,7 @@ preproject(x::Diagonal{<:Any, V}) where {V} = (; Vinfo=preproject(diag(x))) preproject(x::Symmetric{<:Any, M}) where {M} = (; uplo=Symbol(x.uplo), Minfo=preproject(parent(x))) """ - project(T::Type, dx; [info]) + project(T::Type, dx; info) Projects the differential `dx` for primal `x` onto type `T`. The kwarg `info` contains information about the primal `x` that is needed to project onto `T`, e.g. the size of an @@ -26,29 +34,32 @@ array. It is obtained from `preproject(x)`. function project end # fallback (structs) -project(::Type{T}, dx::T) where T = dx -project(::Type{T}, dx::AbstractZero) where T = zero(T) -project(::Type{T}, dx::AbstractThunk) where T = project(T, unthunk(dx)) -function project(::Type{T}, dx::Tangent{<:T}) where T +project(::Type{T}, dx::T; info=nothing) where T = dx +project(::Type{T}, dx::AbstractZero; info=nothing) where T = zero(T) +project(::Type{T}, dx::AbstractThunk; info=nothing) where T = project(T, unthunk(dx)) +function project(::Type{T}, dx::Tangent; info) where {T} fnames = fieldnames(T) - values = [getproperty(dx, fn) for fn in fnames] - return T((; zip(fnames, values)...)...) 
-end # TODO: make Tangent work recursively + fdxs = [getproperty(dx, fn) for fn in fnames] + ftypes = [first(i) for i in info] + finfos = [last(i) for i in info] + proj_values = [project(t, dx; info=i) for (t, dx, i) in zip(ftypes, fdxs, finfos)] + return T((; zip(fnames, proj_values)...)...) +end # Real -project(::Type{T}, dx::Real) where {T<:Real} = T(dx) -project(::Type{T}, dx::Number) where {T<:Real} = T(real(dx)) -project(::Type{T}, dx::AbstractZero) where {T<:Real} = zero(T) -project(::Type{T}, dx::AbstractThunk) where {T<:Real} = project(T, unthunk(dx)) +project(::Type{T}, dx::Real; info=nothing) where {T<:Real} = T(dx) +project(::Type{T}, dx::Number; info=nothing) where {T<:Real} = T(real(dx)) +project(::Type{T}, dx::AbstractZero; info=nothing) where {T<:Real} = zero(T) +project(::Type{T}, dx::AbstractThunk; info=nothing) where {T<:Real} = project(T, unthunk(dx)) # Number -project(::Type{T}, dx::Number) where {T<:Number} = T(dx) -project(::Type{T}, dx::AbstractZero) where {T<:Number} = zero(T) -project(::Type{T}, dx::AbstractThunk) where {T<:Number} = project(T, unthunk(dx)) +project(::Type{T}, dx::Number; info=nothing) where {T<:Number} = T(dx) +project(::Type{T}, dx::AbstractZero; info=nothing) where {T<:Number} = zero(T) +project(::Type{T}, dx::AbstractThunk; info=nothing) where {T<:Number} = project(T, unthunk(dx)) # Arrays -project(AT::Type{Array{T, N}}, dx::Array{T, N}; info) where {T, N} = dx +project(AT::Type{Array{T, N}}, dx::Array{T, N}; info=nothing) where {T, N} = dx project(AT::Type{Array{T, N}}, dx::AbstractArray; info) where {T, N} = project(AT, collect(dx); info=info) -project(AT::Type{Array{T, N}}, dx::Array; info) where {T, N} = project.(T, dx) +project(AT::Type{Array{T, N}}, dx::Array; info=nothing) where {T, N} = project.(T, dx) project(AT::Type{Array{T, N}}, dx::AbstractZero; info) where {T, N} = zeros(T, info.size...) 
project(AT::Type{Array{T, N}}, dx::AbstractThunk; info) where {T, N} = project(AT, unthunk(dx); info=info) diff --git a/test/projection.jl b/test/projection.jl index eeb623ca2..6af3a02eb 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -1,15 +1,38 @@ struct Fred a::Float64 end - Base.zero(::Fred) = Fred(0.0) Base.zero(::Type{Fred}) = Fred(0.0) +struct Freddy{T, N} + a::Array{T, N} +end +Base.:(==)(a::Freddy, b::Freddy) = a.a == b.a + +struct Mary + a::Fred +end +#Base.zero(::Mary) = Mary(zero(Fred)) +#Base.zero(::Type{Mary}) = Mary(zero(Fred)) + @testset "projection" begin @testset "fallback" begin @test Fred(1.2) == project(Fred, Fred(1.2)) @test Fred(0.0) == project(Fred, ZeroTangent()) @test Fred(3.2) == project(Fred, @thunk(Fred(3.2))) + @test Fred(1.2) == project(Fred, Tangent{Fred}(;a=1.2); info=preproject(Fred(1.0))) + + # struct with complicated field + x = Freddy(zeros(2,2)) + dx = Tangent{Freddy}(; a=ZeroTangent()) + @test x == project(typeof(x), dx; info=preproject(x)) + + # nested structs + f = Fred(0.0) + tf = Tangent{Fred}(;a=ZeroTangent()) + m = Mary(f) + dm = Tangent{Mary}(;a=tf) + @test m == project(typeof(m), dm; info=preproject(m)) end @testset "to Real" begin From 88da9c67d5dc21735a37d773ce4164514b617bd1 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 29 Jun 2021 16:46:29 +0100 Subject: [PATCH 21/43] Change from preproject to ProjectTo functor --- src/ChainRulesCore.jl | 2 +- src/differentials/composite.jl | 1 + src/projection.jl | 126 +++++++++++++++++++-------------- test/projection.jl | 113 ++++++++++++++++------------- 4 files changed, 135 insertions(+), 107 deletions(-) diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index df10a5484..1582e7004 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -10,7 +10,7 @@ export RuleConfig, HasReverseMode, NoReverseMode, HasForwardsMode, NoForwardsMod export frule_via_ad, rrule_via_ad # definition helper macros export @non_differentiable, @scalar_rule, @thunk, 
@not_implemented -export canonicalize, extern, unthunk, project, preproject # differential operations +export ProjectTo, canonicalize, extern, unthunk # differential operations export add!! # gradient accumulation operations # differentials export Tangent, NoTangent, InplaceableThunk, Thunk, ZeroTangent, AbstractZero, AbstractThunk diff --git a/src/differentials/composite.jl b/src/differentials/composite.jl index af1a05d1f..48c1b7c44 100644 --- a/src/differentials/composite.jl +++ b/src/differentials/composite.jl @@ -129,6 +129,7 @@ backing(x::NamedTuple) = x backing(x::Dict) = x backing(x::Tangent) = getfield(x, :backing) +# For generic structs function backing(x::T)::NamedTuple where T # note: all computation outside the if @generated happens at runtime. # so the first 4 lines of the branchs look the same, but can not be moved out. diff --git a/src/projection.jl b/src/projection.jl index 2ebf0601a..a995529f8 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -1,76 +1,92 @@ using LinearAlgebra: Diagonal, diag +struct ProjectTo{P, D<:NamedTuple} + info::D +end +ProjectTo{P}(info::D) where {P,D<:NamedTuple} = ProjectTo{P,D}(info) +ProjectTo{P}(; kwargs...) where {P} = ProjectTo{P}(NamedTuple(kwargs)) -""" - preproject(x) +backing(project::ProjectTo) = getfield(project, :info) +Base.getproperty(p::ProjectTo, name::Symbol) = getproperty(backing(p), name) +Base.propertynames(p::ProjectTo) = propertynames(backing(p)) -Returns a NamedTuple containing information needed to [`project`](@ref) a differential `dx` -onto the type `T` for a primal `x`. For example, when projecting `dx=ZeroTangent()` on an -array `T=Array{T, N}`, the size of `x` is not available from `T`. -""" -function preproject end - -function preproject(x::T) where {T} - fnames = fieldnames(T) - values = [getproperty(x, fn) for fn in fnames] - types = typeof.(values) - infos = preproject.(values) - return (; zip(fnames, collect(zip(types, infos)))...) 
+function Base.show(io::IO, project::ProjectTo{T}) where T + print(io, "ProjectTo{") + show(io, T) + print(io, "}") + if isempty(backing(project)) + print(io, "()") + else + show(io, backing(project)) + end end -preproject(x::Array) = (; size=size(x), eltype=eltype(x)) -preproject(x::Diagonal{<:Any, V}) where {V} = (; Vinfo=preproject(diag(x))) +""" + ProjectTo(x) -preproject(x::Symmetric{<:Any, M}) where {M} = (; uplo=Symbol(x.uplo), Minfo=preproject(parent(x))) +Returns a `ProjectTo{P,...}` functor able to project a differential `dx` onto the type `T` +for a primal `x`. +This functor encloses over what ever is needed to be able to be able to do that projection. +For example, when projecting `dx=ZeroTangent()` on an array `P=Array{T, N}`, the size of `x` +is not available from `P`, so it is stored in the functor. +""" +function ProjectTo end """ - project(T::Type, dx; info) + (::ProjectTo{T})(dx) -Projects the differential `dx` for primal `x` onto type `T`. The kwarg `info` contains -information about the primal `x` that is needed to project onto `T`, e.g. the size of an -array. It is obtained from `preproject(x)`. +Projects the differential `dx` on the onto type `T`. +`ProjectTo{T}` is a functor that knows how to perform this projection. 
""" -function project end +function (::ProjectTo) end + +# Generic +(project::ProjectTo)(dx::AbstractThunk) = project(unthunk(dx)) +(::ProjectTo{T})(dx::T) where {T} = dx # not always true, but we can special case for when it isn't +(::ProjectTo{T})(dx::AbstractZero) where {T} = zero(T) # fallback (structs) -project(::Type{T}, dx::T; info=nothing) where T = dx -project(::Type{T}, dx::AbstractZero; info=nothing) where T = zero(T) -project(::Type{T}, dx::AbstractThunk; info=nothing) where T = project(T, unthunk(dx)) -function project(::Type{T}, dx::Tangent; info) where {T} - fnames = fieldnames(T) - fdxs = [getproperty(dx, fn) for fn in fnames] - ftypes = [first(i) for i in info] - finfos = [last(i) for i in info] - proj_values = [project(t, dx; info=i) for (t, dx, i) in zip(ftypes, fdxs, finfos)] - return T((; zip(fnames, proj_values)...)...) +function ProjectTo(x::T) where {T} + # Generic fallback for structs, recursively make `ProjectTo`s all their fields + fields_nt::NamedTuple = backing(x) + return ProjectTo{T}(map(ProjectTo, fields_nt)) +end +function (project::ProjectTo{T})(dx::Tangent) where {T} + sub_projects = backing(project) + sub_dxs = backing(dx) + _call(f, x) = f(x) + return construct(T, map(_call, sub_projects, sub_dxs)) end -# Real -project(::Type{T}, dx::Real; info=nothing) where {T<:Real} = T(dx) -project(::Type{T}, dx::Number; info=nothing) where {T<:Real} = T(real(dx)) -project(::Type{T}, dx::AbstractZero; info=nothing) where {T<:Real} = zero(T) -project(::Type{T}, dx::AbstractThunk; info=nothing) where {T<:Real} = project(T, unthunk(dx)) -# Number -project(::Type{T}, dx::Number; info=nothing) where {T<:Number} = T(dx) -project(::Type{T}, dx::AbstractZero; info=nothing) where {T<:Number} = zero(T) -project(::Type{T}, dx::AbstractThunk; info=nothing) where {T<:Number} = project(T, unthunk(dx)) +# Arrays +ProjectTo(xs::T) where {T<:Array} = ProjectTo{T}(; elements=map(ProjectTo, xs)) +function (project::ProjectTo{T})(dx::Array) where {T<:Array} + 
_call(f, x) = f(x) + return T(map(_call, project.elements, dx)) +end +(project::ProjectTo{<:Array{T}})(dx::AbstractZero) where {T} = zeros(T, size(project.elements)) +(project::ProjectTo{<:Array})(dx::AbstractArray) = project(collect(dx)) -# Arrays -project(AT::Type{Array{T, N}}, dx::Array{T, N}; info=nothing) where {T, N} = dx -project(AT::Type{Array{T, N}}, dx::AbstractArray; info) where {T, N} = project(AT, collect(dx); info=info) -project(AT::Type{Array{T, N}}, dx::Array; info=nothing) where {T, N} = project.(T, dx) -project(AT::Type{Array{T, N}}, dx::AbstractZero; info) where {T, N} = zeros(T, info.size...) -project(AT::Type{Array{T, N}}, dx::AbstractThunk; info) where {T, N} = project(AT, unthunk(dx); info=info) +# Arrays{<:Number}: optimized case so we don't need a projector per element +ProjectTo(x::T) where {T<:Array{<:Number}} = ProjectTo{T}(; size=size(x)) +(project::ProjectTo{<:Array{T}})(dx::Array) where {T<:Number} = ProjectTo(T).(dx) +(project::ProjectTo{<:Array{T}})(dx::AbstractZero) where {T<:Number} = zeros(T, project.size) # Diagonal -project(DT::Type{<:Diagonal{<:Any, V}}, dx::AbstractMatrix; info) where {V} = Diagonal(project(V, diag(dx); info=info.Vinfo)) -project(DT::Type{<:Diagonal{<:Any, V}}, dx::Tangent; info) where {V} = Diagonal(project(V, dx.diag; info=info.Vinfo)) -project(DT::Type{<:Diagonal{<:Any, V}}, dx::AbstractZero; info) where {V} = Diagonal(project(V, dx; info=info.Vinfo)) -project(DT::Type{<:Diagonal{<:Any, V}}, dx::AbstractThunk; info) where {V} = project(DT, unthunk(dx); info=info) +ProjectTo(x::T) where {T<:Diagonal} = ProjectTo{T}(; diag=ProjectTo(diag(x))) +(project::ProjectTo{T})(dx::AbstractMatrix) where {T<:Diagonal} = T(project.diag(diag(dx))) +(project::ProjectTo{T})(dx::AbstractZero) where {T<:Diagonal} = T(project.diag(dx)) # Symmetric -project(ST::Type{<:Symmetric{<:Any, M}}, dx::AbstractMatrix; info) where {M} = Symmetric(project(M, dx; info=info.Minfo), info.uplo) -project(ST::Type{<:Symmetric{<:Any, M}}, 
dx::Tangent; info) where {M} = Symmetric(project(M, dx.data; info=info.Minfo), info.uplo) -project(ST::Type{<:Symmetric{<:Any, M}}, dx::AbstractZero; info) where {M} = Symmetric(project(M, dx; info=info.Minfo), info.uplo) -project(ST::Type{<:Symmetric{<:Any, M}}, dx::AbstractThunk; info) where {M} = project(ST, unthunk(dx); info=info) +ProjectTo(x::T) where {T<:Symmetric} = ProjectTo{T}(; uplo=Symbol(x.uplo), parent=ProjectTo(parent(x))) +(project::ProjectTo{<:Symmetric})(dx::AbstractMatrix) = Symmetric(project.parent(dx), project.uplo) +(project::ProjectTo{<:Symmetric})(dx::AbstractZero) = Symmetric(project.parent(dx), project.uplo) + +# Number +ProjectTo(::T) where {T<:Number} = ProjectTo(T) +# As a special convience for `Number` subtypes we allow `ProjectTo` to be constructed from +# the type only. TODO: do we really want to allow this? +ProjectTo(::Type{T}) where {T<:Number} = ProjectTo{T}() +(::ProjectTo{<:T})(dx::Number) where {T<:Number} = convert(T, dx) +(::ProjectTo{<:T})(dx::Number) where {T<:Real} = convert(T, real(dx)) diff --git a/test/projection.jl b/test/projection.jl index 6af3a02eb..a119464c5 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -16,78 +16,89 @@ end #Base.zero(::Type{Mary}) = Mary(zero(Fred)) @testset "projection" begin + @testset "display" begin + @test startswith(repr(ProjectTo(Fred(1.1))), "ProjectTo{Fred}(") + @test repr(ProjectTo(1.1)) == "ProjectTo{Float64}()" + end + @testset "fallback" begin - @test Fred(1.2) == project(Fred, Fred(1.2)) - @test Fred(0.0) == project(Fred, ZeroTangent()) - @test Fred(3.2) == project(Fred, @thunk(Fred(3.2))) - @test Fred(1.2) == project(Fred, Tangent{Fred}(;a=1.2); info=preproject(Fred(1.0))) + @test Fred(1.2) == ProjectTo(Fred(1.1))(Fred(1.2)) + @test Fred(0.0) == ProjectTo(Fred(1.1))(ZeroTangent()) + @test Fred(3.2) == ProjectTo(Fred(1.1))(@thunk(Fred(3.2))) + @test Fred(1.2) == ProjectTo(Fred(1.1))(Tangent{Fred}(;a=1.2)) # struct with complicated field x = Freddy(zeros(2,2)) dx = 
Tangent{Freddy}(; a=ZeroTangent()) - @test x == project(typeof(x), dx; info=preproject(x)) + @test x == ProjectTo(x)(dx) # nested structs f = Fred(0.0) tf = Tangent{Fred}(;a=ZeroTangent()) m = Mary(f) dm = Tangent{Mary}(;a=tf) - @test m == project(typeof(m), dm; info=preproject(m)) + @test m == ProjectTo(m)(dm) end @testset "to Real" begin + # to type as shorthand for passing primal + @test ProjectTo(Float64) == ProjectTo(1.2) + # Float64 - @test 3.2 == project(Float64, 3.2) - @test 0.0 == project(Float64, ZeroTangent()) - @test 3.2 == project(Float64, @thunk(3.2)) + @test 3.2 == ProjectTo(Float64)(3.2) + @test 0.0 == ProjectTo(Float64)(ZeroTangent()) + @test 3.2 == ProjectTo(Float64)(@thunk(3.2)) # down - @test 3.2 == project(Float64, 3.2 + 3im) - @test 3.2f0 == project(Float32, 3.2) - @test 3.2f0 == project(Float32, 3.2 - 3im) + @test 3.2 == ProjectTo(Float64)(3.2 + 3im) + @test 3.2f0 == ProjectTo(Float32)(3.2) + @test 3.2f0 == ProjectTo(Float32)(3.2 - 3im) # up - @test 2.0 == project(Float64, 2.0f0) + @test 2.0 == ProjectTo(Float64)(2.0f0) end @testset "to Number" begin + # To type, as short-hand for passing primal + @test ProjectTo(ComplexF64) == ProjectTo(1.0 + 2.0im) + # Complex - @test 2.0 + 0.0im == project(ComplexF64, 2.0 + 0.0im) + @test 2.0 + 4.0im == ProjectTo(ComplexF64)(2.0 + 4.0im) # down - @test 2.0 + 0.0im == project(ComplexF64, 2.0) - @test 0.0 + 0.0im == project(ComplexF64, ZeroTangent()) - @test 0.0 + 0.0im == project(ComplexF64, @thunk(ZeroTangent())) + @test 2.0 + 0.0im == ProjectTo(ComplexF64)(2.0) + @test 0.0 + 0.0im == ProjectTo(ComplexF64)(ZeroTangent()) + @test 0.0 + 0.0im == ProjectTo(ComplexF64)(@thunk(ZeroTangent())) # up - @test 2.0 + 0.0im == project(ComplexF64, 2.0) + @test 2.0 + 0.0im == ProjectTo(ComplexF64)(2.0) end @testset "to Array" begin # to an array of numbers x = zeros(2, 2) - @test [1.0 2.0; 3.0 4.0] == project(typeof(x), [1.0 2.0; 3.0 4.0]; info=preproject(x)) - @test x == project(typeof(x), ZeroTangent(); 
info=preproject(x)) + @test [1.0 2.0; 3.0 4.0] == ProjectTo(x)([1.0 2.0; 3.0 4.0]) + @test x == ProjectTo(x)(ZeroTangent()) x = zeros(2) - @test x == project(typeof(x), @thunk(ZeroTangent()); info=preproject(x)) + @test x == ProjectTo(x)(@thunk(ZeroTangent())) x = zeros(Float32, 2, 2) - @test x == project(typeof(x), [0.0 0; 0 0]; info=preproject(x)) + @test x == ProjectTo(x)([0.0 0; 0 0]) x = [1.0 0; 0 4] - @test x == project(typeof(x), Diagonal([1.0, 4]); info=preproject(x)) + @test x == ProjectTo(x)(Diagonal([1.0, 4])) # to a array of structs x = [Fred(0.0), Fred(0.0)] - @test x == project(typeof(x), [Fred(0.0), Fred(0.0)]; info=preproject(x)) - @test x == project(typeof(x), [ZeroTangent(), ZeroTangent()]; info=preproject(x)) - @test x == project(typeof(x), [ZeroTangent(), @thunk(Fred(0.0))]; info=preproject(x)) - @test x == project(typeof(x), ZeroTangent(); info=preproject(x)) - @test x == project(typeof(x), @thunk(ZeroTangent()); info=preproject(x)) + @test x == ProjectTo(x)([Fred(0.0), Fred(0.0)]) + @test x == ProjectTo(x)([ZeroTangent(), ZeroTangent()]) + @test x == ProjectTo(x)([ZeroTangent(), @thunk(Fred(0.0))]) + @test x == ProjectTo(x)(ZeroTangent()) + @test x == ProjectTo(x)(@thunk(ZeroTangent())) x = [Fred(1.0) Fred(0.0); Fred(0.0) Fred(4.0)] - @test x == project(typeof(x), Diagonal([Fred(1.0), Fred(4.0)]); info=preproject(x)) + @test x == ProjectTo(x)(Diagonal([Fred(1.0), Fred(4.0)])) end @testset "to Diagonal" begin @@ -97,49 +108,49 @@ end d_Fred = Diagonal([Fred(0.0), Fred(0.0)]) # from Matrix - @test d_F64 == project(typeof(d_F64), zeros(2, 2); info=preproject(d_F64)) - @test d_F64 == project(typeof(d_F64), zeros(Float32, 2, 2); info=preproject(d_F64)) - @test d_F64 == project(typeof(d_F64), zeros(ComplexF64, 2, 2); info=preproject(d_F64)) + @test d_F64 == ProjectTo(d_F64)(zeros(2, 2)) + @test d_F64 == ProjectTo(d_F64)(zeros(Float32, 2, 2)) + @test d_F64 == ProjectTo(d_F64)(zeros(ComplexF64, 2, 2)) # from Diagonal of Numbers - @test d_F64 == 
project(typeof(d_F64), d_F64; info=preproject(d_F64)) - @test d_F64 == project(typeof(d_F64), d_F32; info=preproject(d_F64)) - @test d_F64 == project(typeof(d_F64), d_C64; info=preproject(d_F64)) + @test d_F64 == ProjectTo(d_F64)(d_F64) + @test d_F64 == ProjectTo(d_F64)(d_F32) + @test d_F64 == ProjectTo(d_F64)(d_C64) # from Diagonal of AbstractTangent - @test d_F64 == project(typeof(d_F64), ZeroTangent(); info=preproject(d_F64)) - @test d_C64 == project(typeof(d_C64), ZeroTangent(); info=preproject(d_C64)) - @test d_F64 == project(typeof(d_F64), @thunk(ZeroTangent()); info=preproject(d_F64)) - @test d_F64 == project(typeof(d_F64), Diagonal([ZeroTangent(), ZeroTangent()]); info=preproject(d_F64)) - @test d_F64 == project(typeof(d_F64), Diagonal([ZeroTangent(), @thunk(ZeroTangent())]); info=preproject(d_F64)) + @test d_F64 == ProjectTo(d_F64)(ZeroTangent()) + @test d_C64 == ProjectTo(d_C64)(ZeroTangent()) + @test d_F64 == ProjectTo(d_F64)(@thunk(ZeroTangent())) + @test d_F64 == ProjectTo(d_F64)(Diagonal([ZeroTangent(), ZeroTangent()])) + @test d_F64 == ProjectTo(d_F64)(Diagonal([ZeroTangent(), @thunk(ZeroTangent())])) # from Diagonal of structs - @test d_Fred == project(typeof(d_Fred), ZeroTangent(); info=preproject(d_Fred)) - @test d_Fred == project(typeof(d_Fred), @thunk(ZeroTangent()); info=preproject(d_Fred)) - @test d_Fred == project(typeof(d_Fred), Diagonal([ZeroTangent(), ZeroTangent()]); info=preproject(d_Fred)) + @test d_Fred == ProjectTo(d_Fred)(ZeroTangent()) + @test d_Fred == ProjectTo(d_Fred)(@thunk(ZeroTangent())) + @test d_Fred == ProjectTo(d_Fred)(Diagonal([ZeroTangent(), ZeroTangent()])) # from Tangent - @test d_F64 == project(typeof(d_F64), Tangent{Diagonal}(;diag=[0.0, 0.0]); info=preproject(d_F64)) - @test d_F64 == project(typeof(d_F64), Tangent{Diagonal}(;diag=[0.0f0, 0.0f0]); info=preproject(d_F64)) - @test d_F64 == project(typeof(d_F64), Tangent{Diagonal}(;diag=[ZeroTangent(), @thunk(ZeroTangent())]); info=preproject(d_F64)) + @test d_F64 == 
ProjectTo(d_F64)(Tangent{Diagonal}(;diag=[0.0, 0.0])) + @test d_F64 == ProjectTo(d_F64)(Tangent{Diagonal}(;diag=[0.0f0, 0.0f0])) + @test d_F64 == ProjectTo(d_F64)(Tangent{Diagonal}(;diag=[ZeroTangent(), @thunk(ZeroTangent())])) end @testset "to Symmetric" begin data = [1.0 2; 3 4] x = Symmetric(data) - @test x == project(typeof(x), data; info=preproject(x)) + @test x == ProjectTo(x)(data) x = Symmetric(data, :L) - @test x == project(typeof(x), data; info=preproject(x)) + @test x == ProjectTo(x)(data) data = [1.0 0; 0 4] x = Symmetric(data) - @test x == project(typeof(x), Diagonal([1.0, 4.0]); info=preproject(x)) + @test x == ProjectTo(x)(Diagonal([1.0, 4.0])) data = [0.0 0; 0 0] x = Symmetric(data) - @test x == project(typeof(x), ZeroTangent(); info=preproject(x)) - @test x == project(typeof(x), @thunk(ZeroTangent()); info=preproject(x)) + @test x == ProjectTo(x)(ZeroTangent()) + @test x == ProjectTo(x)(@thunk(ZeroTangent())) end end From e0318b37ae17dcaed92d35b088d984783416ef34 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 29 Jun 2021 17:11:09 +0100 Subject: [PATCH 22/43] Make sure Arrays of Arrays etc work --- src/projection.jl | 4 +++- test/projection.jl | 21 +++++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/src/projection.jl b/src/projection.jl index a995529f8..3ef9cea93 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -65,7 +65,9 @@ function (project::ProjectTo{T})(dx::Array) where {T<:Array} _call(f, x) = f(x) return T(map(_call, project.elements, dx)) end -(project::ProjectTo{<:Array{T}})(dx::AbstractZero) where {T} = zeros(T, size(project.elements)) +function (project::ProjectTo{T})(dx::AbstractZero) where {T<:Array} + return T(map(proj->proj(dx), project.elements)) +end (project::ProjectTo{<:Array})(dx::AbstractArray) = project(collect(dx)) # Arrays{<:Number}: optimized case so we don't need a projector per element diff --git a/test/projection.jl b/test/projection.jl index a119464c5..5941f812c 100644 --- 
a/test/projection.jl +++ b/test/projection.jl @@ -101,6 +101,27 @@ end @test x == ProjectTo(x)(Diagonal([Fred(1.0), Fred(4.0)])) end + @testset "To Array of Arrays" begin + # inner arrays have same type but different sizes + x = [[1.0, 2.0, 3.0], [4.0, 5.0]] + @test x == ProjectTo(x)(x) + @test x == ProjectTo(x)([[1.0 + 2im, 2.0, 3.0], [4.0 + 2im, 5.0]]) + + # This makes sure we don't fall for https://github.com/JuliaLang/julia/issues/38064 + @test [[0.0, 0.0, 0.0], [0.0, 0.0]] == ProjectTo(x)(ZeroTangent()) + end + + @testset "Array{Any} with really messy contents" begin + # inner arrays have same type but different sizes + x = [[1.0, 2.0, 3.0], [4.0+im 5.0], [[[Fred(1)]]]] + @test x == ProjectTo(x)(x) + @test x == ProjectTo(x)([[1.0+im, 2.0, 3.0], [4.0+im 5.0], [[[Fred(1)]]]]) + # using a different type for the 2nd element (Adjoint) + @test x == ProjectTo(x)([[1.0+im, 2.0, 3.0], [4.0-im, 5.0]', [[[Fred(1)]]]]) + + @test [[0.0, 0.0, 0.0], [0.0im 0.0], [[[Fred(0)]]]] == ProjectTo(x)(ZeroTangent()) + end + @testset "to Diagonal" begin d_F64 = Diagonal([0.0, 0.0]) d_F32 = Diagonal([0.0f0, 0.0f0]) From ce5d64698b4779fab0199e2a2315058257942dd3 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Wed, 30 Jun 2021 12:31:28 +0100 Subject: [PATCH 23/43] remove the special case ProjectTo(::Type{<:Number}) --- src/projection.jl | 24 +++++++++++------------- test/projection.jl | 30 ++++++++++++------------------ 2 files changed, 23 insertions(+), 31 deletions(-) diff --git a/src/projection.jl b/src/projection.jl index 3ef9cea93..fdb3dab03 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -41,11 +41,6 @@ Projects the differential `dx` on the onto type `T`. 
""" function (::ProjectTo) end -# Generic -(project::ProjectTo)(dx::AbstractThunk) = project(unthunk(dx)) -(::ProjectTo{T})(dx::T) where {T} = dx # not always true, but we can special case for when it isn't -(::ProjectTo{T})(dx::AbstractZero) where {T} = zero(T) - # fallback (structs) function ProjectTo(x::T) where {T} # Generic fallback for structs, recursively make `ProjectTo`s all their fields @@ -59,6 +54,16 @@ function (project::ProjectTo{T})(dx::Tangent) where {T} return construct(T, map(_call, sub_projects, sub_dxs)) end +# Generic +(project::ProjectTo)(dx::AbstractThunk) = project(unthunk(dx)) +(::ProjectTo{T})(dx::T) where {T} = dx # not always true, but we can special case for when it isn't +(::ProjectTo{T})(dx::AbstractZero) where {T} = zero(T) + +# Number +ProjectTo(::T) where {T<:Number} = ProjectTo{T}() +(::ProjectTo{T})(dx::Number) where {T<:Number} = convert(T, dx) +(::ProjectTo{T})(dx::Number) where {T<:Real} = convert(T, real(dx)) + # Arrays ProjectTo(xs::T) where {T<:Array} = ProjectTo{T}(; elements=map(ProjectTo, xs)) function (project::ProjectTo{T})(dx::Array) where {T<:Array} @@ -72,7 +77,7 @@ end # Arrays{<:Number}: optimized case so we don't need a projector per element ProjectTo(x::T) where {T<:Array{<:Number}} = ProjectTo{T}(; size=size(x)) -(project::ProjectTo{<:Array{T}})(dx::Array) where {T<:Number} = ProjectTo(T).(dx) +(project::ProjectTo{<:Array{T}})(dx::Array) where {T<:Number} = ProjectTo(zero(T)).(dx) (project::ProjectTo{<:Array{T}})(dx::AbstractZero) where {T<:Number} = zeros(T, project.size) # Diagonal @@ -85,10 +90,3 @@ ProjectTo(x::T) where {T<:Symmetric} = ProjectTo{T}(; uplo=Symbol(x.uplo), paren (project::ProjectTo{<:Symmetric})(dx::AbstractMatrix) = Symmetric(project.parent(dx), project.uplo) (project::ProjectTo{<:Symmetric})(dx::AbstractZero) = Symmetric(project.parent(dx), project.uplo) -# Number -ProjectTo(::T) where {T<:Number} = ProjectTo(T) -# As a special convience for `Number` subtypes we allow `ProjectTo` to be 
constructed from -# the type only. TODO: do we really want to allow this? -ProjectTo(::Type{T}) where {T<:Number} = ProjectTo{T}() -(::ProjectTo{<:T})(dx::Number) where {T<:Number} = convert(T, dx) -(::ProjectTo{<:T})(dx::Number) where {T<:Real} = convert(T, real(dx)) diff --git a/test/projection.jl b/test/projection.jl index 5941f812c..f76f05b5a 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -41,37 +41,31 @@ end end @testset "to Real" begin - # to type as shorthand for passing primal - @test ProjectTo(Float64) == ProjectTo(1.2) - # Float64 - @test 3.2 == ProjectTo(Float64)(3.2) - @test 0.0 == ProjectTo(Float64)(ZeroTangent()) - @test 3.2 == ProjectTo(Float64)(@thunk(3.2)) + @test 3.2 == ProjectTo(1.0)(3.2) + @test 0.0 == ProjectTo(1.0)(ZeroTangent()) + @test 3.2 == ProjectTo(1.0)(@thunk(3.2)) # down - @test 3.2 == ProjectTo(Float64)(3.2 + 3im) - @test 3.2f0 == ProjectTo(Float32)(3.2) - @test 3.2f0 == ProjectTo(Float32)(3.2 - 3im) + @test 3.2 == ProjectTo(1.0)(3.2 + 3im) + @test 3.2f0 == ProjectTo(1.0f0)(3.2) + @test 3.2f0 == ProjectTo(1.0f0)(3.2 - 3im) # up - @test 2.0 == ProjectTo(Float64)(2.0f0) + @test 2.0 == ProjectTo(1.0)(2.0f0) end @testset "to Number" begin - # To type, as short-hand for passing primal - @test ProjectTo(ComplexF64) == ProjectTo(1.0 + 2.0im) - # Complex - @test 2.0 + 4.0im == ProjectTo(ComplexF64)(2.0 + 4.0im) + @test 2.0 + 4.0im == ProjectTo(1.0im)(2.0 + 4.0im) # down - @test 2.0 + 0.0im == ProjectTo(ComplexF64)(2.0) - @test 0.0 + 0.0im == ProjectTo(ComplexF64)(ZeroTangent()) - @test 0.0 + 0.0im == ProjectTo(ComplexF64)(@thunk(ZeroTangent())) + @test 2.0 + 0.0im == ProjectTo(1.0im)(2.0) + @test 0.0 + 0.0im == ProjectTo(1.0im)(ZeroTangent()) + @test 0.0 + 0.0im == ProjectTo(1.0im)(@thunk(ZeroTangent())) # up - @test 2.0 + 0.0im == ProjectTo(ComplexF64)(2.0) + @test 2.0 + 0.0im == ProjectTo(1.0im)(2.0) end @testset "to Array" begin From f1a626085e93bb1ebb6104eb91ee085fefe5b203 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Wed, 
30 Jun 2021 14:59:08 +0100 Subject: [PATCH 24/43] add to_ prefix, add Transpose/Adjoint/SubArray --- src/projection.jl | 51 ++++++++++++++++++++++++++++++++++++---------- test/projection.jl | 8 ++++++++ 2 files changed, 48 insertions(+), 11 deletions(-) diff --git a/src/projection.jl b/src/projection.jl index fdb3dab03..66d33c286 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -65,28 +65,57 @@ ProjectTo(::T) where {T<:Number} = ProjectTo{T}() (::ProjectTo{T})(dx::Number) where {T<:Real} = convert(T, real(dx)) # Arrays -ProjectTo(xs::T) where {T<:Array} = ProjectTo{T}(; elements=map(ProjectTo, xs)) +ProjectTo(xs::T) where {T<:Array} = ProjectTo{T}(; to_elements=map(ProjectTo, xs)) function (project::ProjectTo{T})(dx::Array) where {T<:Array} _call(f, x) = f(x) - return T(map(_call, project.elements, dx)) + return T(map(_call, project.to_elements, dx)) end function (project::ProjectTo{T})(dx::AbstractZero) where {T<:Array} - return T(map(proj->proj(dx), project.elements)) + return T(map(proj->proj(dx), project.to_elements)) end (project::ProjectTo{<:Array})(dx::AbstractArray) = project(collect(dx)) # Arrays{<:Number}: optimized case so we don't need a projector per element -ProjectTo(x::T) where {T<:Array{<:Number}} = ProjectTo{T}(; size=size(x)) -(project::ProjectTo{<:Array{T}})(dx::Array) where {T<:Number} = ProjectTo(zero(T)).(dx) +ProjectTo(x::T) where {T<:Array{<:Number}} = ProjectTo{T}(; to_element=ProjectTo(zero(eltype(x))), size=size(x)) # TODO: how to do nested where? 
+(project::ProjectTo{<:Array{T}})(dx::Array) where {T<:Number} = project.to_element.(dx) (project::ProjectTo{<:Array{T}})(dx::AbstractZero) where {T<:Number} = zeros(T, project.size) # Diagonal -ProjectTo(x::T) where {T<:Diagonal} = ProjectTo{T}(; diag=ProjectTo(diag(x))) -(project::ProjectTo{T})(dx::AbstractMatrix) where {T<:Diagonal} = T(project.diag(diag(dx))) -(project::ProjectTo{T})(dx::AbstractZero) where {T<:Diagonal} = T(project.diag(dx)) +ProjectTo(x::T) where {T<:Diagonal} = ProjectTo{T}(; to_diag=ProjectTo(diag(x))) +(project::ProjectTo{T})(dx::AbstractMatrix) where {T<:Diagonal} = T(project.to_diag(diag(dx))) +(project::ProjectTo{T})(dx::AbstractZero) where {T<:Diagonal} = T(project.to_diag(dx)) +(project::ProjectTo{T})(dx::Tangent) where {T<:Diagonal} = T(project.to_diag(dx.diag)) # Symmetric -ProjectTo(x::T) where {T<:Symmetric} = ProjectTo{T}(; uplo=Symbol(x.uplo), parent=ProjectTo(parent(x))) -(project::ProjectTo{<:Symmetric})(dx::AbstractMatrix) = Symmetric(project.parent(dx), project.uplo) -(project::ProjectTo{<:Symmetric})(dx::AbstractZero) = Symmetric(project.parent(dx), project.uplo) +ProjectTo(x::T) where {T<:Symmetric} = ProjectTo{T}(; uplo=Symbol(x.uplo), to_parent=ProjectTo(parent(x))) +(project::ProjectTo{<:Symmetric})(dx::AbstractMatrix) = Symmetric(project.to_parent(dx), project.uplo) +(project::ProjectTo{<:Symmetric})(dx::AbstractZero) = Symmetric(project.to_parent(dx), project.uplo) +(project::ProjectTo{<:Symmetric})(dx::Tangent) = Symmetric(project.to_parent(dx.data), project.uplo) + +# Transpose +ProjectTo(x::T) where {T<:Transpose} = ProjectTo{T}(; to_parent=ProjectTo(parent(x))) +(project::ProjectTo{<:Transpose})(dx::AbstractMatrix) = transpose(project.to_parent(transpose(dx))) +(project::ProjectTo{<:Transpose})(dx::Adjoint) = transpose(project.to_parent(parent(dx))) + +# Adjoint +ProjectTo(x::T) where {T<:Adjoint} = ProjectTo{T}(; to_parent=ProjectTo(parent(x))) +(project::ProjectTo{<:Adjoint})(dx::AbstractMatrix) = 
adjoint(project.to_parent(adjoint(dx))) + +# SubArray +ProjectTo(x::T) where {T<:SubArray} = ProjectTo(collect(x)) # TODO: is this what we want? + +# TODO: ProjectTo Tuple and NamedTuple. Does this even make sense? How about the structs +# with Tuple or NamedTuple fields? + + + + + + + + + + + + diff --git a/test/projection.jl b/test/projection.jl index f76f05b5a..1efacf881 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -155,6 +155,7 @@ end x = Symmetric(data) @test x == ProjectTo(x)(data) + @test x == ProjectTo(x)(Tangent{typeof(x)}(; data=data, uplo=NoTangent())) x = Symmetric(data, :L) @test x == ProjectTo(x)(data) @@ -168,4 +169,11 @@ end @test x == ProjectTo(x)(ZeroTangent()) @test x == ProjectTo(x)(@thunk(ZeroTangent())) end + + @testset "to Transpose" begin # TODO: this one, plus Adjoint, and SubArray + x = rand(3, 4) + t = transpose(x) + + #@test x == ProjectTo(t)(rand(4, 3)) + end end From 06268a37abcf63de6454262da62f0b37c0e5547d Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Wed, 30 Jun 2021 16:09:56 +0100 Subject: [PATCH 25/43] add Adjoint and Transpose test --- src/projection.jl | 39 +++++++++++++++++++++------------------ test/projection.jl | 32 +++++++++++++++++++++++++++++--- 2 files changed, 50 insertions(+), 21 deletions(-) diff --git a/src/projection.jl b/src/projection.jl index 66d33c286..c38bbecf7 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -49,7 +49,9 @@ function ProjectTo(x::T) where {T} end function (project::ProjectTo{T})(dx::Tangent) where {T} sub_projects = backing(project) + #@show sub_projects sub_dxs = backing(dx) + #@show sub_dxs _call(f, x) = f(x) return construct(T, map(_call, sub_projects, sub_dxs)) end @@ -65,41 +67,42 @@ ProjectTo(::T) where {T<:Number} = ProjectTo{T}() (::ProjectTo{T})(dx::Number) where {T<:Real} = convert(T, real(dx)) # Arrays -ProjectTo(xs::T) where {T<:Array} = ProjectTo{T}(; to_elements=map(ProjectTo, xs)) +ProjectTo(xs::T) where {T<:Array} = ProjectTo{T}(; elements=map(ProjectTo, 
xs)) function (project::ProjectTo{T})(dx::Array) where {T<:Array} _call(f, x) = f(x) - return T(map(_call, project.to_elements, dx)) + return T(map(_call, project.elements, dx)) end function (project::ProjectTo{T})(dx::AbstractZero) where {T<:Array} - return T(map(proj->proj(dx), project.to_elements)) + return T(map(proj->proj(dx), project.elements)) end (project::ProjectTo{<:Array})(dx::AbstractArray) = project(collect(dx)) # Arrays{<:Number}: optimized case so we don't need a projector per element -ProjectTo(x::T) where {T<:Array{<:Number}} = ProjectTo{T}(; to_element=ProjectTo(zero(eltype(x))), size=size(x)) # TODO: how to do nested where? -(project::ProjectTo{<:Array{T}})(dx::Array) where {T<:Number} = project.to_element.(dx) +ProjectTo(x::T) where {T<:Array{<:Number}} = ProjectTo{T}(; element=ProjectTo(zero(eltype(x))), size=size(x)) # TODO: how to do nested where? +(project::ProjectTo{<:Array{T}})(dx::Array) where {T<:Number} = project.element.(dx) (project::ProjectTo{<:Array{T}})(dx::AbstractZero) where {T<:Number} = zeros(T, project.size) # Diagonal -ProjectTo(x::T) where {T<:Diagonal} = ProjectTo{T}(; to_diag=ProjectTo(diag(x))) -(project::ProjectTo{T})(dx::AbstractMatrix) where {T<:Diagonal} = T(project.to_diag(diag(dx))) -(project::ProjectTo{T})(dx::AbstractZero) where {T<:Diagonal} = T(project.to_diag(dx)) -(project::ProjectTo{T})(dx::Tangent) where {T<:Diagonal} = T(project.to_diag(dx.diag)) +ProjectTo(x::T) where {T<:Diagonal} = ProjectTo{T}(; diag=ProjectTo(diag(x))) +(project::ProjectTo{T})(dx::AbstractMatrix) where {T<:Diagonal} = T(project.diag(diag(dx))) +(project::ProjectTo{T})(dx::AbstractZero) where {T<:Diagonal} = T(project.diag(dx)) # Symmetric -ProjectTo(x::T) where {T<:Symmetric} = ProjectTo{T}(; uplo=Symbol(x.uplo), to_parent=ProjectTo(parent(x))) -(project::ProjectTo{<:Symmetric})(dx::AbstractMatrix) = Symmetric(project.to_parent(dx), project.uplo) -(project::ProjectTo{<:Symmetric})(dx::AbstractZero) = Symmetric(project.to_parent(dx), 
project.uplo) -(project::ProjectTo{<:Symmetric})(dx::Tangent) = Symmetric(project.to_parent(dx.data), project.uplo) +ProjectTo(x::T) where {T<:Symmetric} = ProjectTo{T}(; uplo=Symbol(x.uplo), parent=ProjectTo(parent(x))) +(project::ProjectTo{<:Symmetric})(dx::AbstractMatrix) = Symmetric(project.parent(dx), project.uplo) +(project::ProjectTo{<:Symmetric})(dx::AbstractZero) = Symmetric(project.parent(dx), project.uplo) +(project::ProjectTo{<:Symmetric})(dx::Tangent) = Symmetric(project.parent(dx.data), project.uplo) # Transpose -ProjectTo(x::T) where {T<:Transpose} = ProjectTo{T}(; to_parent=ProjectTo(parent(x))) -(project::ProjectTo{<:Transpose})(dx::AbstractMatrix) = transpose(project.to_parent(transpose(dx))) -(project::ProjectTo{<:Transpose})(dx::Adjoint) = transpose(project.to_parent(parent(dx))) +ProjectTo(x::T) where {T<:Transpose} = ProjectTo{T}(; parent=ProjectTo(parent(x))) +(project::ProjectTo{<:Transpose})(dx::AbstractMatrix) = transpose(project.parent(transpose(dx))) +(project::ProjectTo{<:Transpose})(dx::Adjoint) = transpose(project.parent(parent(dx))) +(project::ProjectTo{<:Transpose})(dx::AbstractZero) = transpose(project.parent(dx)) # Adjoint -ProjectTo(x::T) where {T<:Adjoint} = ProjectTo{T}(; to_parent=ProjectTo(parent(x))) -(project::ProjectTo{<:Adjoint})(dx::AbstractMatrix) = adjoint(project.to_parent(adjoint(dx))) +ProjectTo(x::T) where {T<:Adjoint} = ProjectTo{T}(; parent=ProjectTo(parent(x))) +(project::ProjectTo{<:Adjoint})(dx::AbstractMatrix) = adjoint(project.parent(adjoint(dx))) +(project::ProjectTo{<:Adjoint})(dx::ZeroTangent) = adjoint(project.parent(dx)) # SubArray ProjectTo(x::T) where {T<:SubArray} = ProjectTo(collect(x)) # TODO: is this what we want? 
diff --git a/test/projection.jl b/test/projection.jl index 1efacf881..1ae0e13e7 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -170,10 +170,36 @@ end @test x == ProjectTo(x)(@thunk(ZeroTangent())) end - @testset "to Transpose" begin # TODO: this one, plus Adjoint, and SubArray + @testset "to Transpose" begin x = rand(3, 4) t = transpose(x) - - #@test x == ProjectTo(t)(rand(4, 3)) + mt = collect(t) + a = adjoint(x) + ma = collect(a) + + @test t == ProjectTo(t)(mt) + @test t == ProjectTo(t)(ma) + @test zeros(4, 3) == ProjectTo(t)(ZeroTangent()) + @test zeros(4, 3) == ProjectTo(t)(Tangent{Transpose}(; parent=ZeroTangent())) + end + + @testset "to Adjoint" begin + x = rand(3, 4) + a = adjoint(x) + ma = collect(a) + + @test a == ProjectTo(a)(ma) + @test zeros(4, 3) == ProjectTo(a)(ZeroTangent()) + @test zeros(4, 3) == ProjectTo(a)(Tangent{Adjoint}(; parent=ZeroTangent())) + end + + @testset "to SubArray" begin + x = rand(3, 4) + sa = view(x, :, 1:2) + m = collect(sa) + + @test m == ProjectTo(sa)(m) + @test zeros(3, 2) == ProjectTo(sa)(ZeroTangent()) + @test_broken zeros(3, 2) == ProjectTo(sa)(Tangent{SubArray}(; parent=ZeroTangent())) # what do we want to do with SubArray? 
end end From a9812795798163e376e66b2200e74b4939de8cc1 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Wed, 30 Jun 2021 16:28:02 +0100 Subject: [PATCH 26/43] test Tangents with implicit zeros --- src/projection.jl | 10 +++++++++- test/projection.jl | 11 +++++++++-- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/src/projection.jl b/src/projection.jl index c38bbecf7..046f08dbf 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -44,18 +44,26 @@ function (::ProjectTo) end # fallback (structs) function ProjectTo(x::T) where {T} # Generic fallback for structs, recursively make `ProjectTo`s all their fields + #println() + #@show x + #@show T fields_nt::NamedTuple = backing(x) + #@show fields_nt return ProjectTo{T}(map(ProjectTo, fields_nt)) end function (project::ProjectTo{T})(dx::Tangent) where {T} sub_projects = backing(project) #@show sub_projects - sub_dxs = backing(dx) + sub_dxs = backing(canonicalize(dx)) #@show sub_dxs _call(f, x) = f(x) return construct(T, map(_call, sub_projects, sub_dxs)) end +# Tuple +ProjectTo(x::T) where {T<:Tuple} = ProjectTo{T}() +(::ProjectTo{T})(dx::T) where {T<:Tuple} = dx + # Generic (project::ProjectTo)(dx::AbstractThunk) = project(unthunk(dx)) (::ProjectTo{T})(dx::T) where {T} = dx # not always true, but we can special case for when it isn't diff --git a/test/projection.jl b/test/projection.jl index 1ae0e13e7..d54b22138 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -12,8 +12,11 @@ Base.:(==)(a::Freddy, b::Freddy) = a.a == b.a struct Mary a::Fred end -#Base.zero(::Mary) = Mary(zero(Fred)) -#Base.zero(::Type{Mary}) = Mary(zero(Fred)) + +struct TwoFields + a::Float64 + c::Float64 +end @testset "projection" begin @testset "display" begin @@ -38,6 +41,10 @@ end m = Mary(f) dm = Tangent{Mary}(;a=tf) @test m == ProjectTo(m)(dm) + + # two fields + tf = TwoFields(3.0, 0.0) + @test tf == ProjectTo(tf)(Tangent{TwoFields}(; a=3.0)) end @testset "to Real" begin From eefd84f8cf5e5243f19fc4323e41b18b147a1ae4 Mon 
Sep 17 00:00:00 2001 From: Miha Zgubic Date: Wed, 30 Jun 2021 16:40:15 +0100 Subject: [PATCH 27/43] throw error when ProjectTo to Tuple or NamedTuple --- src/projection.jl | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/src/projection.jl b/src/projection.jl index 046f08dbf..a624b7efa 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -1,5 +1,3 @@ -using LinearAlgebra: Diagonal, diag - struct ProjectTo{P, D<:NamedTuple} info::D end @@ -60,9 +58,10 @@ function (project::ProjectTo{T})(dx::Tangent) where {T} return construct(T, map(_call, sub_projects, sub_dxs)) end -# Tuple -ProjectTo(x::T) where {T<:Tuple} = ProjectTo{T}() -(::ProjectTo{T})(dx::T) where {T<:Tuple} = dx +# does not work for Tuples and NamedTuples +function ProjectTo(x::T) where {T<:Union{<:Tuple, NamedTuple}} + throw(ArgumentError("The `x` in `ProjectTo(x)` must be a valid differential, not $x")) +end # Generic (project::ProjectTo)(dx::AbstractThunk) = project(unthunk(dx)) @@ -115,10 +114,6 @@ ProjectTo(x::T) where {T<:Adjoint} = ProjectTo{T}(; parent=ProjectTo(parent(x))) # SubArray ProjectTo(x::T) where {T<:SubArray} = ProjectTo(collect(x)) # TODO: is this what we want? -# TODO: ProjectTo Tuple and NamedTuple. Does this even make sense? How about the structs -# with Tuple or NamedTuple fields? 
- - From 2facaea94e39fd74c5c0887794b8bda22465e1b9 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Wed, 30 Jun 2021 17:04:00 +0100 Subject: [PATCH 28/43] fix transpose bug --- src/projection.jl | 2 +- test/projection.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/projection.jl b/src/projection.jl index a624b7efa..832dc71e4 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -103,7 +103,7 @@ ProjectTo(x::T) where {T<:Symmetric} = ProjectTo{T}(; uplo=Symbol(x.uplo), paren # Transpose ProjectTo(x::T) where {T<:Transpose} = ProjectTo{T}(; parent=ProjectTo(parent(x))) (project::ProjectTo{<:Transpose})(dx::AbstractMatrix) = transpose(project.parent(transpose(dx))) -(project::ProjectTo{<:Transpose})(dx::Adjoint) = transpose(project.parent(parent(dx))) +(project::ProjectTo{<:Transpose})(dx::Adjoint) = transpose(project.parent(conj(parent(dx)))) (project::ProjectTo{<:Transpose})(dx::AbstractZero) = transpose(project.parent(dx)) # Adjoint diff --git a/test/projection.jl b/test/projection.jl index d54b22138..8410806ad 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -207,6 +207,6 @@ end @test m == ProjectTo(sa)(m) @test zeros(3, 2) == ProjectTo(sa)(ZeroTangent()) - @test_broken zeros(3, 2) == ProjectTo(sa)(Tangent{SubArray}(; parent=ZeroTangent())) # what do we want to do with SubArray? + @test_broken zeros(3, 2) == ProjectTo(sa)(Tangent{SubArray}(; parent=ZeroTangent())) # TODO: what do we want to do with SubArray? 
end end From 9787b1bb5872c7e503950a3c7637e47190d59949 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Thu, 1 Jul 2021 10:13:29 +0100 Subject: [PATCH 29/43] add test for TwoFields --- test/projection.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/test/projection.jl b/test/projection.jl index 8410806ad..0f1e2d837 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -15,7 +15,7 @@ end struct TwoFields a::Float64 - c::Float64 + b::Float64 end @testset "projection" begin @@ -43,8 +43,10 @@ end @test m == ProjectTo(m)(dm) # two fields - tf = TwoFields(3.0, 0.0) - @test tf == ProjectTo(tf)(Tangent{TwoFields}(; a=3.0)) + tfa = TwoFields(3.0, 0.0) + tfb = TwoFields(0.0, 3.0) + @test tfa == ProjectTo(tfa)(Tangent{TwoFields}(; a=3.0)) + @test tfb == ProjectTo(tfb)(Tangent{TwoFields}(; b=3.0)) end @testset "to Real" begin From 93c74896c367f1c5f952c2f5c6db64c580e01ad2 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Thu, 1 Jul 2021 10:27:47 +0100 Subject: [PATCH 30/43] test complex numbers too --- src/projection.jl | 2 +- test/projection.jl | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/projection.jl b/src/projection.jl index 832dc71e4..bf5bac9f4 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -58,7 +58,7 @@ function (project::ProjectTo{T})(dx::Tangent) where {T} return construct(T, map(_call, sub_projects, sub_dxs)) end -# does not work for Tuples and NamedTuples +# should not work for Tuples and NamedTuples, as not valid tangent types function ProjectTo(x::T) where {T<:Union{<:Tuple, NamedTuple}} throw(ArgumentError("The `x` in `ProjectTo(x)` must be a valid differential, not $x")) end diff --git a/test/projection.jl b/test/projection.jl index 0f1e2d837..78c8db180 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -180,20 +180,20 @@ end end @testset "to Transpose" begin - x = rand(3, 4) + x = rand(ComplexF64, 3, 4) t = transpose(x) mt = collect(t) a = adjoint(x) ma = collect(a) @test t == 
ProjectTo(t)(mt) - @test t == ProjectTo(t)(ma) + @test conj(t) == ProjectTo(t)(ma) @test zeros(4, 3) == ProjectTo(t)(ZeroTangent()) @test zeros(4, 3) == ProjectTo(t)(Tangent{Transpose}(; parent=ZeroTangent())) end @testset "to Adjoint" begin - x = rand(3, 4) + x = rand(ComplexF64, 3, 4) a = adjoint(x) ma = collect(a) From e7190b2805593c2a590fb5b691649cf43fb6aa84 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Thu, 1 Jul 2021 13:29:43 +0100 Subject: [PATCH 31/43] nested where --- src/projection.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/projection.jl b/src/projection.jl index bf5bac9f4..467163757 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -85,7 +85,7 @@ end (project::ProjectTo{<:Array})(dx::AbstractArray) = project(collect(dx)) # Arrays{<:Number}: optimized case so we don't need a projector per element -ProjectTo(x::T) where {T<:Array{<:Number}} = ProjectTo{T}(; element=ProjectTo(zero(eltype(x))), size=size(x)) # TODO: how to do nested where? 
+ProjectTo(x::T) where {E<:Number, T<:Array{E}} = ProjectTo{T}(; element=ProjectTo(zero(E)), size=size(x)) (project::ProjectTo{<:Array{T}})(dx::Array) where {T<:Number} = project.element.(dx) (project::ProjectTo{<:Array{T}})(dx::AbstractZero) where {T<:Number} = zeros(T, project.size) From 233d292425eedd9b34cd8f1b6d5d5bebaeec1f81 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Thu, 1 Jul 2021 13:36:39 +0100 Subject: [PATCH 32/43] fix SubArray --- src/projection.jl | 3 ++- test/projection.jl | 7 ++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/projection.jl b/src/projection.jl index 467163757..b0b795112 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -88,6 +88,7 @@ end ProjectTo(x::T) where {E<:Number, T<:Array{E}} = ProjectTo{T}(; element=ProjectTo(zero(E)), size=size(x)) (project::ProjectTo{<:Array{T}})(dx::Array) where {T<:Number} = project.element.(dx) (project::ProjectTo{<:Array{T}})(dx::AbstractZero) where {T<:Number} = zeros(T, project.size) +(project::ProjectTo{<:Array{T}})(dx::Tangent{<:SubArray}) where {T<:Number} = project(dx.parent) # Diagonal ProjectTo(x::T) where {T<:Diagonal} = ProjectTo{T}(; diag=ProjectTo(diag(x))) @@ -112,7 +113,7 @@ ProjectTo(x::T) where {T<:Adjoint} = ProjectTo{T}(; parent=ProjectTo(parent(x))) (project::ProjectTo{<:Adjoint})(dx::ZeroTangent) = adjoint(project.parent(dx)) # SubArray -ProjectTo(x::T) where {T<:SubArray} = ProjectTo(collect(x)) # TODO: is this what we want? +ProjectTo(x::T) where {T<:SubArray} = ProjectTo(copy(x)) # don't project on to a view, but onto matching copy diff --git a/test/projection.jl b/test/projection.jl index 78c8db180..dd45ca3fc 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -207,8 +207,9 @@ end sa = view(x, :, 1:2) m = collect(sa) - @test m == ProjectTo(sa)(m) - @test zeros(3, 2) == ProjectTo(sa)(ZeroTangent()) - @test_broken zeros(3, 2) == ProjectTo(sa)(Tangent{SubArray}(; parent=ZeroTangent())) # TODO: what do we want to do with SubArray? 
+ # make sure it converts the view to the parent type + @test ProjectTo(sa)(m) isa Matrix + @test ProjectTo(sa)(ZeroTangent()) + @test ProjectTo(sa)(Tangent{SubArray}(; parent=ZeroTangent())) isa Matrix end end From 4c25f32dcaf491c35219d733ca189ec5e8df9b3c Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Fri, 2 Jul 2021 12:28:55 +0100 Subject: [PATCH 33/43] add Hermitian --- src/projection.jl | 14 +++++++++----- test/projection.jl | 18 +++++++++--------- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/src/projection.jl b/src/projection.jl index b0b795112..258af99b8 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -95,11 +95,15 @@ ProjectTo(x::T) where {T<:Diagonal} = ProjectTo{T}(; diag=ProjectTo(diag(x))) (project::ProjectTo{T})(dx::AbstractMatrix) where {T<:Diagonal} = T(project.diag(diag(dx))) (project::ProjectTo{T})(dx::AbstractZero) where {T<:Diagonal} = T(project.diag(dx)) -# Symmetric -ProjectTo(x::T) where {T<:Symmetric} = ProjectTo{T}(; uplo=Symbol(x.uplo), parent=ProjectTo(parent(x))) -(project::ProjectTo{<:Symmetric})(dx::AbstractMatrix) = Symmetric(project.parent(dx), project.uplo) -(project::ProjectTo{<:Symmetric})(dx::AbstractZero) = Symmetric(project.parent(dx), project.uplo) -(project::ProjectTo{<:Symmetric})(dx::Tangent) = Symmetric(project.parent(dx.data), project.uplo) +# Symmetric and Hermitian +for SymHerm = (:Symmetric, :Hermitian) + @eval begin + ProjectTo(x::T) where {T<:$SymHerm} = ProjectTo{T}(; uplo=Symbol(x.uplo), parent=ProjectTo(parent(x))) + (project::ProjectTo{<:$SymHerm})(dx::AbstractMatrix) = $SymHerm(project.parent(dx), project.uplo) + (project::ProjectTo{<:$SymHerm})(dx::AbstractZero) = $SymHerm(project.parent(dx), project.uplo) + (project::ProjectTo{<:$SymHerm})(dx::Tangent) = $SymHerm(project.parent(dx.data), project.uplo) + end +end # Transpose ProjectTo(x::T) where {T<:Transpose} = ProjectTo{T}(; parent=ProjectTo(parent(x))) diff --git a/test/projection.jl b/test/projection.jl index 
dd45ca3fc..bd6870cff 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -159,22 +159,22 @@ end @test d_F64 == ProjectTo(d_F64)(Tangent{Diagonal}(;diag=[ZeroTangent(), @thunk(ZeroTangent())])) end - @testset "to Symmetric" begin - data = [1.0 2; 3 4] + @testset "to $SymHerm" for SymHerm in (Symmetric, Hermitian) + data = [1.0+1im 2-2im; 3 4] - x = Symmetric(data) + x = SymHerm(data) @test x == ProjectTo(x)(data) @test x == ProjectTo(x)(Tangent{typeof(x)}(; data=data, uplo=NoTangent())) - x = Symmetric(data, :L) + x = SymHerm(data, :L) @test x == ProjectTo(x)(data) - data = [1.0 0; 0 4] - x = Symmetric(data) - @test x == ProjectTo(x)(Diagonal([1.0, 4.0])) + data = [1.0-2im 0; 0 4] + x = SymHerm(data) + @test x == ProjectTo(x)(Diagonal([1.0-2im, 4.0])) - data = [0.0 0; 0 0] - x = Symmetric(data) + data = [0.0+0im 0; 0 0] + x = SymHerm(data) @test x == ProjectTo(x)(ZeroTangent()) @test x == ProjectTo(x)(@thunk(ZeroTangent())) end From 029cb6909782fa0e4d66c9d1f7af0fe264571855 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Fri, 2 Jul 2021 17:32:06 +0100 Subject: [PATCH 34/43] remove debug statements --- src/projection.jl | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/projection.jl b/src/projection.jl index 258af99b8..29549a01e 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -42,18 +42,12 @@ function (::ProjectTo) end # fallback (structs) function ProjectTo(x::T) where {T} # Generic fallback for structs, recursively make `ProjectTo`s all their fields - #println() - #@show x - #@show T fields_nt::NamedTuple = backing(x) - #@show fields_nt return ProjectTo{T}(map(ProjectTo, fields_nt)) end function (project::ProjectTo{T})(dx::Tangent) where {T} sub_projects = backing(project) - #@show sub_projects sub_dxs = backing(canonicalize(dx)) - #@show sub_dxs _call(f, x) = f(x) return construct(T, map(_call, sub_projects, sub_dxs)) end From b73e2464e2c439b5ba165daf39e1d6f3c8362683 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Fri, 2 Jul 2021 
17:56:41 +0100 Subject: [PATCH 35/43] add Upper and LowerTriangular --- src/projection.jl | 8 ++++++++ test/projection.jl | 14 ++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/src/projection.jl b/src/projection.jl index 29549a01e..c181743bc 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -98,6 +98,14 @@ for SymHerm = (:Symmetric, :Hermitian) (project::ProjectTo{<:$SymHerm})(dx::Tangent) = $SymHerm(project.parent(dx.data), project.uplo) end end +for UL = (:UpperTriangular, :LowerTriangular) + @eval begin + ProjectTo(x::T) where {T<:$UL} = ProjectTo{T}(; parent=ProjectTo(parent(x))) + (project::ProjectTo{<:$UL})(dx::AbstractMatrix) = $UL(project.parent(dx)) + (project::ProjectTo{<:$UL})(dx::AbstractZero) = $UL(project.parent(dx)) + (project::ProjectTo{<:$UL})(dx::Tangent) = $UL(project.parent(dx.data)) + end +end # Transpose ProjectTo(x::T) where {T<:Transpose} = ProjectTo{T}(; parent=ProjectTo(parent(x))) diff --git a/test/projection.jl b/test/projection.jl index bd6870cff..5d7ddb576 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -179,6 +179,20 @@ end @test x == ProjectTo(x)(@thunk(ZeroTangent())) end + @testset "to $UL" for UL in (UpperTriangular, LowerTriangular) + data = [1.0+1im 2-2im; 3 4] + + x = UL(data) + @test x == ProjectTo(x)(data) + @test x == ProjectTo(x)(Tangent{typeof(x)}(; data=data)) + + data = [0.0+0im 0; 0 0] + x = UL(data) + @test x == ProjectTo(x)(Diagonal(zeros(2))) + @test x == ProjectTo(x)(ZeroTangent()) + @test x == ProjectTo(x)(@thunk(ZeroTangent())) + end + @testset "to Transpose" begin x = rand(ComplexF64, 3, 4) t = transpose(x) From 9d665c0b2f23bcc4bb9bf933e9c846bd0e51c5ab Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Fri, 2 Jul 2021 19:49:37 +0100 Subject: [PATCH 36/43] PermutedDimsArray --- src/projection.jl | 12 +++++++++--- test/projection.jl | 9 +++++++++ 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/src/projection.jl b/src/projection.jl index c181743bc..f8738f800 100644 --- 
a/src/projection.jl +++ b/src/projection.jl @@ -89,7 +89,7 @@ ProjectTo(x::T) where {T<:Diagonal} = ProjectTo{T}(; diag=ProjectTo(diag(x))) (project::ProjectTo{T})(dx::AbstractMatrix) where {T<:Diagonal} = T(project.diag(diag(dx))) (project::ProjectTo{T})(dx::AbstractZero) where {T<:Diagonal} = T(project.diag(dx)) -# Symmetric and Hermitian +# :data, :uplo fields for SymHerm = (:Symmetric, :Hermitian) @eval begin ProjectTo(x::T) where {T<:$SymHerm} = ProjectTo{T}(; uplo=Symbol(x.uplo), parent=ProjectTo(parent(x))) @@ -98,6 +98,8 @@ for SymHerm = (:Symmetric, :Hermitian) (project::ProjectTo{<:$SymHerm})(dx::Tangent) = $SymHerm(project.parent(dx.data), project.uplo) end end + +# :data field for UL = (:UpperTriangular, :LowerTriangular) @eval begin ProjectTo(x::T) where {T<:$UL} = ProjectTo{T}(; parent=ProjectTo(parent(x))) @@ -110,13 +112,17 @@ end # Transpose ProjectTo(x::T) where {T<:Transpose} = ProjectTo{T}(; parent=ProjectTo(parent(x))) (project::ProjectTo{<:Transpose})(dx::AbstractMatrix) = transpose(project.parent(transpose(dx))) -(project::ProjectTo{<:Transpose})(dx::Adjoint) = transpose(project.parent(conj(parent(dx)))) (project::ProjectTo{<:Transpose})(dx::AbstractZero) = transpose(project.parent(dx)) # Adjoint ProjectTo(x::T) where {T<:Adjoint} = ProjectTo{T}(; parent=ProjectTo(parent(x))) (project::ProjectTo{<:Adjoint})(dx::AbstractMatrix) = adjoint(project.parent(adjoint(dx))) -(project::ProjectTo{<:Adjoint})(dx::ZeroTangent) = adjoint(project.parent(dx)) +(project::ProjectTo{<:Adjoint})(dx::AbstractZero) = adjoint(project.parent(dx)) + +# PermutedDimsArray +ProjectTo(x::P) where {P<:PermutedDimsArray} = ProjectTo{P}(; parent=ProjectTo(parent(x))) +(project::ProjectTo{<:PermutedDimsArray{T,N,perm,iperm,AA}})(dx::AbstractArray) where {T, N, perm, iperm, AA} = PermutedDimsArray{T,N,perm,iperm,AA}(permutedims(project.parent(dx), perm)) +(project::ProjectTo{<:PermutedDimsArray{T,N,perm,iperm,AA}})(dx::AbstractZero) where {T, N, perm, iperm, AA} = 
PermutedDimsArray{T,N,perm,iperm,AA}(project.parent(dx)) # SubArray ProjectTo(x::T) where {T<:SubArray} = ProjectTo(copy(x)) # don't project on to a view, but onto matching copy diff --git a/test/projection.jl b/test/projection.jl index 5d7ddb576..b9bdaebc7 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -216,6 +216,15 @@ end @test zeros(4, 3) == ProjectTo(a)(Tangent{Adjoint}(; parent=ZeroTangent())) end + @testset "to PermutedDimsArray" begin + a = zeros(3, 5, 4) + b = PermutedDimsArray(a, (2, 1, 3)) + bc = collect(b) + + @test b == ProjectTo(b)(bc) + @test b == ProjectTo(b)(ZeroTangent()) + end + @testset "to SubArray" begin x = rand(3, 4) sa = view(x, :, 1:2) From 030d636b84b5d9a84a58f08abd4e4c0f763c749a Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Fri, 2 Jul 2021 19:54:50 +0100 Subject: [PATCH 37/43] Update test/projection.jl Co-authored-by: Lyndon White --- test/projection.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/projection.jl b/test/projection.jl index b9bdaebc7..199662e2e 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -232,7 +232,7 @@ end # make sure it converts the view to the parent type @test ProjectTo(sa)(m) isa Matrix - @test ProjectTo(sa)(ZeroTangent()) + @test zeros(3, 2) == ProjectTo(sa)(ZeroTangent()) @test ProjectTo(sa)(Tangent{SubArray}(; parent=ZeroTangent())) isa Matrix end end From b87368f6a223be36395c683bd462c27bf8188864 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Mon, 5 Jul 2021 12:51:55 +0100 Subject: [PATCH 38/43] fix docs --- docs/Manifest.toml | 2 +- docs/src/api.md | 5 +++++ src/projection.jl | 33 +++++++++++++-------------------- 3 files changed, 19 insertions(+), 21 deletions(-) diff --git a/docs/Manifest.toml b/docs/Manifest.toml index be1de955b..36e47a930 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -13,7 +13,7 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" deps = ["Compat", "LinearAlgebra", "SparseArrays"] path = ".." 
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.10.7" +version = "0.10.9" [[Compat]] deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] diff --git a/docs/src/api.md b/docs/src/api.md index 5fcd53400..6ba5363ad 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -41,6 +41,11 @@ Pages = ["config.jl"] Private = false ``` +## ProjectTo +```@docs +ProjectTo +``` + ## Internal ```@docs ChainRulesCore.AbstractTangent diff --git a/src/projection.jl b/src/projection.jl index f8738f800..6ad5908ca 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -1,3 +1,16 @@ +""" + ProjectTo(x) + +Returns a `ProjectTo{P,...}` functor able to project a differential `dx` onto the type `P` +for a primal `x`. +This functor encloses over what ever is needed to be able to be able to do that projection. +For example, when projecting `dx=ZeroTangent()` on an array `P=Array{T, N}`, the size of `x` +is not available from `P`, so it is stored in the functor. + + (::ProjectTo{P})(dx) + +Projects the differential `dx` on the onto type `P`. +""" struct ProjectTo{P, D<:NamedTuple} info::D end @@ -19,26 +32,6 @@ function Base.show(io::IO, project::ProjectTo{T}) where T end end - -""" - ProjectTo(x) - -Returns a `ProjectTo{P,...}` functor able to project a differential `dx` onto the type `T` -for a primal `x`. -This functor encloses over what ever is needed to be able to be able to do that projection. -For example, when projecting `dx=ZeroTangent()` on an array `P=Array{T, N}`, the size of `x` -is not available from `P`, so it is stored in the functor. -""" -function ProjectTo end - -""" - (::ProjectTo{T})(dx) - -Projects the differential `dx` on the onto type `T`. -`ProjectTo{T}` is a functor that knows how to perform this projection. 
-""" -function (::ProjectTo) end - # fallback (structs) function ProjectTo(x::T) where {T} # Generic fallback for structs, recursively make `ProjectTo`s all their fields From 0f09ab93924c26cf065e7aa930c1463ef215b7ed Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Mon, 5 Jul 2021 13:00:09 +0100 Subject: [PATCH 39/43] JuliaFormatter --- src/projection.jl | 74 +++++++++++++++++++++++++++------------------- test/projection.jl | 28 ++++++++++-------- 2 files changed, 59 insertions(+), 43 deletions(-) diff --git a/src/projection.jl b/src/projection.jl index 6ad5908ca..aae01a662 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -11,7 +11,7 @@ is not available from `P`, so it is stored in the functor. Projects the differential `dx` on the onto type `P`. """ -struct ProjectTo{P, D<:NamedTuple} +struct ProjectTo{P,D<:NamedTuple} info::D end ProjectTo{P}(info::D) where {P,D<:NamedTuple} = ProjectTo{P,D}(info) @@ -21,7 +21,7 @@ backing(project::ProjectTo) = getfield(project, :info) Base.getproperty(p::ProjectTo, name::Symbol) = getproperty(backing(p), name) Base.propertynames(p::ProjectTo) = propertynames(backing(p)) -function Base.show(io::IO, project::ProjectTo{T}) where T +function Base.show(io::IO, project::ProjectTo{T}) where {T} print(io, "ProjectTo{") show(io, T) print(io, "}") @@ -46,13 +46,15 @@ function (project::ProjectTo{T})(dx::Tangent) where {T} end # should not work for Tuples and NamedTuples, as not valid tangent types -function ProjectTo(x::T) where {T<:Union{<:Tuple, NamedTuple}} - throw(ArgumentError("The `x` in `ProjectTo(x)` must be a valid differential, not $x")) +function ProjectTo(x::T) where {T<:Union{<:Tuple,NamedTuple}} + return throw( + ArgumentError("The `x` in `ProjectTo(x)` must be a valid differential, not $x") + ) end # Generic (project::ProjectTo)(dx::AbstractThunk) = project(unthunk(dx)) -(::ProjectTo{T})(dx::T) where {T} = dx # not always true, but we can special case for when it isn't +(::ProjectTo{T})(dx::T) where {T} = dx # not 
always true, but we can special case for when it isn't (::ProjectTo{T})(dx::AbstractZero) where {T} = zero(T) # Number @@ -67,15 +69,21 @@ function (project::ProjectTo{T})(dx::Array) where {T<:Array} return T(map(_call, project.elements, dx)) end function (project::ProjectTo{T})(dx::AbstractZero) where {T<:Array} - return T(map(proj->proj(dx), project.elements)) + return T(map(proj -> proj(dx), project.elements)) end (project::ProjectTo{<:Array})(dx::AbstractArray) = project(collect(dx)) # Arrays{<:Number}: optimized case so we don't need a projector per element -ProjectTo(x::T) where {E<:Number, T<:Array{E}} = ProjectTo{T}(; element=ProjectTo(zero(E)), size=size(x)) +function ProjectTo(x::T) where {E<:Number,T<:Array{E}} + return ProjectTo{T}(; element=ProjectTo(zero(E)), size=size(x)) +end (project::ProjectTo{<:Array{T}})(dx::Array) where {T<:Number} = project.element.(dx) -(project::ProjectTo{<:Array{T}})(dx::AbstractZero) where {T<:Number} = zeros(T, project.size) -(project::ProjectTo{<:Array{T}})(dx::Tangent{<:SubArray}) where {T<:Number} = project(dx.parent) +function (project::ProjectTo{<:Array{T}})(dx::AbstractZero) where {T<:Number} + return zeros(T, project.size) +end +function (project::ProjectTo{<:Array{T}})(dx::Tangent{<:SubArray}) where {T<:Number} + return project(dx.parent) +end # Diagonal ProjectTo(x::T) where {T<:Diagonal} = ProjectTo{T}(; diag=ProjectTo(diag(x))) @@ -83,17 +91,25 @@ ProjectTo(x::T) where {T<:Diagonal} = ProjectTo{T}(; diag=ProjectTo(diag(x))) (project::ProjectTo{T})(dx::AbstractZero) where {T<:Diagonal} = T(project.diag(dx)) # :data, :uplo fields -for SymHerm = (:Symmetric, :Hermitian) +for SymHerm in (:Symmetric, :Hermitian) @eval begin - ProjectTo(x::T) where {T<:$SymHerm} = ProjectTo{T}(; uplo=Symbol(x.uplo), parent=ProjectTo(parent(x))) - (project::ProjectTo{<:$SymHerm})(dx::AbstractMatrix) = $SymHerm(project.parent(dx), project.uplo) - (project::ProjectTo{<:$SymHerm})(dx::AbstractZero) = $SymHerm(project.parent(dx), 
project.uplo) - (project::ProjectTo{<:$SymHerm})(dx::Tangent) = $SymHerm(project.parent(dx.data), project.uplo) + function ProjectTo(x::T) where {T<:$SymHerm} + return ProjectTo{T}(; uplo=Symbol(x.uplo), parent=ProjectTo(parent(x))) + end + function (project::ProjectTo{<:$SymHerm})(dx::AbstractMatrix) + return $SymHerm(project.parent(dx), project.uplo) + end + function (project::ProjectTo{<:$SymHerm})(dx::AbstractZero) + return $SymHerm(project.parent(dx), project.uplo) + end + function (project::ProjectTo{<:$SymHerm})(dx::Tangent) + return $SymHerm(project.parent(dx.data), project.uplo) + end end end # :data field -for UL = (:UpperTriangular, :LowerTriangular) +for UL in (:UpperTriangular, :LowerTriangular) @eval begin ProjectTo(x::T) where {T<:$UL} = ProjectTo{T}(; parent=ProjectTo(parent(x))) (project::ProjectTo{<:$UL})(dx::AbstractMatrix) = $UL(project.parent(dx)) @@ -104,7 +120,9 @@ end # Transpose ProjectTo(x::T) where {T<:Transpose} = ProjectTo{T}(; parent=ProjectTo(parent(x))) -(project::ProjectTo{<:Transpose})(dx::AbstractMatrix) = transpose(project.parent(transpose(dx))) +function (project::ProjectTo{<:Transpose})(dx::AbstractMatrix) + return transpose(project.parent(transpose(dx))) +end (project::ProjectTo{<:Transpose})(dx::AbstractZero) = transpose(project.parent(dx)) # Adjoint @@ -114,20 +132,16 @@ ProjectTo(x::T) where {T<:Adjoint} = ProjectTo{T}(; parent=ProjectTo(parent(x))) # PermutedDimsArray ProjectTo(x::P) where {P<:PermutedDimsArray} = ProjectTo{P}(; parent=ProjectTo(parent(x))) -(project::ProjectTo{<:PermutedDimsArray{T,N,perm,iperm,AA}})(dx::AbstractArray) where {T, N, perm, iperm, AA} = PermutedDimsArray{T,N,perm,iperm,AA}(permutedims(project.parent(dx), perm)) -(project::ProjectTo{<:PermutedDimsArray{T,N,perm,iperm,AA}})(dx::AbstractZero) where {T, N, perm, iperm, AA} = PermutedDimsArray{T,N,perm,iperm,AA}(project.parent(dx)) +function (project::ProjectTo{<:PermutedDimsArray{T,N,perm,iperm,AA}})( + dx::AbstractArray +) where 
{T,N,perm,iperm,AA} + return PermutedDimsArray{T,N,perm,iperm,AA}(permutedims(project.parent(dx), perm)) +end +function (project::ProjectTo{<:PermutedDimsArray{T,N,perm,iperm,AA}})( + dx::AbstractZero +) where {T,N,perm,iperm,AA} + return PermutedDimsArray{T,N,perm,iperm,AA}(project.parent(dx)) +end # SubArray ProjectTo(x::T) where {T<:SubArray} = ProjectTo(copy(x)) # don't project on to a view, but onto matching copy - - - - - - - - - - - - diff --git a/test/projection.jl b/test/projection.jl index 199662e2e..7bd8d266a 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -4,8 +4,8 @@ end Base.zero(::Fred) = Fred(0.0) Base.zero(::Type{Fred}) = Fred(0.0) -struct Freddy{T, N} - a::Array{T, N} +struct Freddy{T,N} + a::Array{T,N} end Base.:(==)(a::Freddy, b::Freddy) = a.a == b.a @@ -28,18 +28,18 @@ end @test Fred(1.2) == ProjectTo(Fred(1.1))(Fred(1.2)) @test Fred(0.0) == ProjectTo(Fred(1.1))(ZeroTangent()) @test Fred(3.2) == ProjectTo(Fred(1.1))(@thunk(Fred(3.2))) - @test Fred(1.2) == ProjectTo(Fred(1.1))(Tangent{Fred}(;a=1.2)) + @test Fred(1.2) == ProjectTo(Fred(1.1))(Tangent{Fred}(; a=1.2)) # struct with complicated field - x = Freddy(zeros(2,2)) + x = Freddy(zeros(2, 2)) dx = Tangent{Freddy}(; a=ZeroTangent()) @test x == ProjectTo(x)(dx) # nested structs f = Fred(0.0) - tf = Tangent{Fred}(;a=ZeroTangent()) + tf = Tangent{Fred}(; a=ZeroTangent()) m = Mary(f) - dm = Tangent{Mary}(;a=tf) + dm = Tangent{Mary}(; a=tf) @test m == ProjectTo(m)(dm) # two fields @@ -116,11 +116,11 @@ end @testset "Array{Any} with really messy contents" begin # inner arrays have same type but different sizes - x = [[1.0, 2.0, 3.0], [4.0+im 5.0], [[[Fred(1)]]]] + x = [[1.0, 2.0, 3.0], [4.0 + im 5.0], [[[Fred(1)]]]] @test x == ProjectTo(x)(x) - @test x == ProjectTo(x)([[1.0+im, 2.0, 3.0], [4.0+im 5.0], [[[Fred(1)]]]]) + @test x == ProjectTo(x)([[1.0 + im, 2.0, 3.0], [4.0 + im 5.0], [[[Fred(1)]]]]) # using a different type for the 2nd element (Adjoint) - @test x == ProjectTo(x)([[1.0+im, 
2.0, 3.0], [4.0-im, 5.0]', [[[Fred(1)]]]]) + @test x == ProjectTo(x)([[1.0 + im, 2.0, 3.0], [4.0 - im, 5.0]', [[[Fred(1)]]]]) @test [[0.0, 0.0, 0.0], [0.0im 0.0], [[[Fred(0)]]]] == ProjectTo(x)(ZeroTangent()) end @@ -154,9 +154,11 @@ end @test d_Fred == ProjectTo(d_Fred)(Diagonal([ZeroTangent(), ZeroTangent()])) # from Tangent - @test d_F64 == ProjectTo(d_F64)(Tangent{Diagonal}(;diag=[0.0, 0.0])) - @test d_F64 == ProjectTo(d_F64)(Tangent{Diagonal}(;diag=[0.0f0, 0.0f0])) - @test d_F64 == ProjectTo(d_F64)(Tangent{Diagonal}(;diag=[ZeroTangent(), @thunk(ZeroTangent())])) + @test d_F64 == ProjectTo(d_F64)(Tangent{Diagonal}(; diag=[0.0, 0.0])) + @test d_F64 == ProjectTo(d_F64)(Tangent{Diagonal}(; diag=[0.0f0, 0.0f0])) + @test d_F64 == ProjectTo(d_F64)( + Tangent{Diagonal}(; diag=[ZeroTangent(), @thunk(ZeroTangent())]) + ) end @testset "to $SymHerm" for SymHerm in (Symmetric, Hermitian) @@ -171,7 +173,7 @@ end data = [1.0-2im 0; 0 4] x = SymHerm(data) - @test x == ProjectTo(x)(Diagonal([1.0-2im, 4.0])) + @test x == ProjectTo(x)(Diagonal([1.0 - 2im, 4.0])) data = [0.0+0im 0; 0 0] x = SymHerm(data) From 3a47f6fd3fd0029dcffd3afa216c3b272c81adc2 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Mon, 5 Jul 2021 13:33:12 +0100 Subject: [PATCH 40/43] simplify one of the PermutedDimsArray --- src/projection.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/projection.jl b/src/projection.jl index aae01a662..c2368498f 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -137,10 +137,8 @@ function (project::ProjectTo{<:PermutedDimsArray{T,N,perm,iperm,AA}})( ) where {T,N,perm,iperm,AA} return PermutedDimsArray{T,N,perm,iperm,AA}(permutedims(project.parent(dx), perm)) end -function (project::ProjectTo{<:PermutedDimsArray{T,N,perm,iperm,AA}})( - dx::AbstractZero -) where {T,N,perm,iperm,AA} - return PermutedDimsArray{T,N,perm,iperm,AA}(project.parent(dx)) +function (project::ProjectTo{P})(dx::AbstractZero) where {P<:PermutedDimsArray} + return 
P(project.parent(dx)) end # SubArray From ce022d5257435517d27f64f008ba1e5cb689c8a2 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Mon, 5 Jul 2021 18:08:26 +0100 Subject: [PATCH 41/43] document when to use ProjectTo --- docs/src/writing_good_rules.md | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/docs/src/writing_good_rules.md b/docs/src/writing_good_rules.md index 872a160b4..12ef1b799 100644 --- a/docs/src/writing_good_rules.md +++ b/docs/src/writing_good_rules.md @@ -63,6 +63,39 @@ Examples being: - There is only one derivative being returned, so from the fact that the user called `frule`/`rrule` they clearly will want to use that one. +## Use `ProjectTo` appropriately + +Rules with abstractly-typed arguments may return incorrect answers when called with certain concrete types. +A classic example is the matrix-matrix multiplication rule +``` +function rrule(::typeof(*), A::AbstractMatrix, B::AbstractMatrix) + function times_pullback(ȳ) + dA = ȳ * B' + dB = A' * ȳ + return NoTangent(), dA, dB + end + return A * B, times_pullback +end +``` +When computing `*(A, B)`, where `A isa Diagonal` and `B isa Matrix`, the output will be a `Matrix`. +As a result, `ȳ` in the pullback will be a `Matrix`, and consequently `dA` for a `A isa Diagonal` will be a `Matrix`, which is wrong. +While a specialised rules can indeed be written for the `Diagonal` case, there are many other types and we don't want to be forced to write a rule for each of them. +Instead, `project_A = ProjectTo(A)` can be used (outside the pullback) to extract an object that knows how to project onto the type of `A` (e.g. also knows the size of the array). +This object can be called with a tangent, `project_A(ȳ * B')`, to project it on the type `A`. 
+The correct rule then looks like +``` +function rrule(::typeof(*), A::AbstractMatrix, B::AbstractMatrix) + project_A = ProjectTo(A) + project_B = ProjectTo(B) + function times_pullback(ȳ) + dA = ȳ * B' + dB = A' * ȳ + return NoTangent(), project_A(dA), project_B(dB) + end + return A * B, times_pullback +end +``` + ## Structs: constructors and functors To define an `frule` or `rrule` for a _function_ `foo` we dispatch on the type of `foo`, which is `typeof(foo)`. From 4106232dc85f26070bdd3de0ba199e0253281c82 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Tue, 6 Jul 2021 10:16:34 +0100 Subject: [PATCH 42/43] Apply suggestions from code review Co-authored-by: Lyndon White --- docs/src/writing_good_rules.md | 11 ++++++----- src/projection.jl | 11 +++++------ 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/docs/src/writing_good_rules.md b/docs/src/writing_good_rules.md index 12ef1b799..6b982ae9b 100644 --- a/docs/src/writing_good_rules.md +++ b/docs/src/writing_good_rules.md @@ -63,11 +63,11 @@ Examples being: - There is only one derivative being returned, so from the fact that the user called `frule`/`rrule` they clearly will want to use that one. -## Use `ProjectTo` appropriately +## Ensure you remain in the primal's subspace (i.e. use `ProjectTo` appropriately) Rules with abstractly-typed arguments may return incorrect answers when called with certain concrete types. -A classic example is the matrix-matrix multiplication rule -``` +A classic example is the matrix-matrix multiplication rule, a naive definition of which follows: +```julia function rrule(::typeof(*), A::AbstractMatrix, B::AbstractMatrix) function times_pullback(ȳ) dA = ȳ * B' @@ -79,11 +79,12 @@ end ``` When computing `*(A, B)`, where `A isa Diagonal` and `B isa Matrix`, the output will be a `Matrix`. As a result, `ȳ` in the pullback will be a `Matrix`, and consequently `dA` for a `A isa Diagonal` will be a `Matrix`, which is wrong. 
+Not only is it the wrong type, but it can contain non-zeros off the diagonal, which is not possible for a `Diagonal`: such a tangent is outside of the subspace. While a specialised rules can indeed be written for the `Diagonal` case, there are many other types and we don't want to be forced to write a rule for each of them. Instead, `project_A = ProjectTo(A)` can be used (outside the pullback) to extract an object that knows how to project onto the type of `A` (e.g. also knows the size of the array). -This object can be called with a tangent, `project_A(ȳ * B')`, to project it on the type `A`. +This object can be called with a tangent `ȳ * B'`, by doing `project_A(ȳ * B')`, to project it onto the tangent space of `A`. The correct rule then looks like -``` +```julia function rrule(::typeof(*), A::AbstractMatrix, B::AbstractMatrix) project_A = ProjectTo(A) project_B = ProjectTo(B) diff --git a/src/projection.jl b/src/projection.jl index c2368498f..786ec170d 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -1,15 +1,14 @@ """ - ProjectTo(x) + ProjectTo(x::T) -Returns a `ProjectTo{P,...}` functor able to project a differential `dx` onto the type `P` -for a primal `x`. +Returns a `ProjectTo{T,...}` functor able to project a differential `dx` onto the same tangent space as `x`. This functor encloses over what ever is needed to be able to be able to do that projection. For example, when projecting `dx=ZeroTangent()` on an array `P=Array{T, N}`, the size of `x` is not available from `P`, so it is stored in the functor. - (::ProjectTo{P})(dx) + (::ProjectTo{T})(dx) -Projects the differential `dx` on the onto type `P`. +Projects the differential `dx` onto the tangent space used to create the `ProjectTo`.
""" struct ProjectTo{P,D<:NamedTuple} info::D @@ -142,4 +141,4 @@ function (project::ProjectTo{P})(dx::AbstractZero) where {P<:PermutedDimsArray} end # SubArray -ProjectTo(x::T) where {T<:SubArray} = ProjectTo(copy(x)) # don't project on to a view, but onto matching copy +ProjectTo(x::T) where {T<:SubArray} = ProjectTo(copy(x)) # don't project on to a view, but onto matching copy From 04a4e87df1772d9160a335e06de67008d20cdd86 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Tue, 6 Jul 2021 10:17:21 +0100 Subject: [PATCH 43/43] Update docs/Manifest.toml --- docs/Manifest.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/Manifest.toml b/docs/Manifest.toml index 36e47a930..8c7b26a87 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -13,7 +13,7 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" deps = ["Compat", "LinearAlgebra", "SparseArrays"] path = ".." uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.10.9" +version = "0.10.11" [[Compat]] deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]