From 3acd962ad79e3679398b3d162180dd107036a75b Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 6 Jul 2021 11:18:57 +0100 Subject: [PATCH] use a functor for projection (#385) --- docs/Manifest.toml | 2 +- docs/src/api.md | 5 + docs/src/writing_good_rules.md | 34 +++++ src/ChainRulesCore.jl | 3 +- src/differentials/composite.jl | 1 + src/projection.jl | 144 ++++++++++++++++++++ test/projection.jl | 240 +++++++++++++++++++++++++++++++++ test/runtests.jl | 1 + 8 files changed, 428 insertions(+), 2 deletions(-) create mode 100644 src/projection.jl create mode 100644 test/projection.jl diff --git a/docs/Manifest.toml b/docs/Manifest.toml index be1de955b..8c7b26a87 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -13,7 +13,7 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" deps = ["Compat", "LinearAlgebra", "SparseArrays"] path = ".." uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.10.7" +version = "0.10.11" [[Compat]] deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] diff --git a/docs/src/api.md b/docs/src/api.md index 5fcd53400..6ba5363ad 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -41,6 +41,11 @@ Pages = ["config.jl"] Private = false ``` +## ProjectTo +```@docs +ProjectTo +``` + ## Internal ```@docs ChainRulesCore.AbstractTangent diff --git a/docs/src/writing_good_rules.md b/docs/src/writing_good_rules.md index 872a160b4..6b982ae9b 100644 --- a/docs/src/writing_good_rules.md +++ b/docs/src/writing_good_rules.md @@ -63,6 +63,40 @@ Examples being: - There is only one derivative being returned, so from the fact that the user called `frule`/`rrule` they clearly will want to use that one. +## Ensure you remain in the primal's subspace (i.e. use `ProjectTo` appropriately) + +Rules with abstractly-typed arguments may return incorrect answers when called with certain concrete types. +A classic example is the matrix-matrix multiplication rule, a naive definition of which follows: +```julia +function rrule(::typeof(*), A::AbstractMatrix, B::AbstractMatrix) + function times_pullback(ȳ) + dA = ȳ * B' + dB = A' * ȳ + return NoTangent(), dA, dB + end + return A * B, times_pullback +end +``` +When computing `*(A, B)`, where `A isa Diagonal` and `B isa Matrix`, the output will be a `Matrix`. +As a result, `ȳ` in the pullback will be a `Matrix`, and consequently `dA` for a `A isa Diagonal` will be a `Matrix`, which is wrong. +Not only is it the wrong type, but it can contain non-zeros off the diagonal, which is not possible, it is outside of the subspace. +While a specialised rules can indeed be written for the `Diagonal` case, there are many other types and we don't want to be forced to write a rule for each of them. +Instead, `project_A = ProjectTo(A)` can be used (outside the pullback) to extract an object that knows how to project onto the type of `A` (e.g. also knows the size of the array). +This object can be called with a tangent `ȳ * B'`, by doing `project_A(ȳ * B')`, to project it on the tangent space of `A`. +The correct rule then looks like +```julia +function rrule(::typeof(*), A::AbstractMatrix, B::AbstractMatrix) + project_A = ProjectTo(A) + project_B = ProjectTo(B) + function times_pullback(ȳ) + dA = ȳ * B' + dB = A' * ȳ + return NoTangent(), project_A(dA), project_B(dB) + end + return A * B, times_pullback +end +``` + ## Structs: constructors and functors To define an `frule` or `rrule` for a _function_ `foo` we dispatch on the type of `foo`, which is `typeof(foo)`. diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index e3bd53deb..1582e7004 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -10,7 +10,7 @@ export RuleConfig, HasReverseMode, NoReverseMode, HasForwardsMode, NoForwardsMod export frule_via_ad, rrule_via_ad # definition helper macros export @non_differentiable, @scalar_rule, @thunk, @not_implemented -export canonicalize, extern, unthunk # differential operations +export ProjectTo, canonicalize, extern, unthunk # differential operations export add!! # gradient accumulation operations # differentials export Tangent, NoTangent, InplaceableThunk, Thunk, ZeroTangent, AbstractZero, AbstractThunk @@ -26,6 +26,7 @@ include("differentials/notimplemented.jl") include("differential_arithmetic.jl") include("accumulation.jl") +include("projection.jl") include("config.jl") include("rules.jl") diff --git a/src/differentials/composite.jl b/src/differentials/composite.jl index 2e47691cd..9a8b9870a 100644 --- a/src/differentials/composite.jl +++ b/src/differentials/composite.jl @@ -133,6 +133,7 @@ backing(x::NamedTuple) = x backing(x::Dict) = x backing(x::Tangent) = getfield(x, :backing) +# For generic structs function backing(x::T)::NamedTuple where T # note: all computation outside the if @generated happens at runtime. # so the first 4 lines of the branchs look the same, but can not be moved out. diff --git a/src/projection.jl b/src/projection.jl new file mode 100644 index 000000000..786ec170d --- /dev/null +++ b/src/projection.jl @@ -0,0 +1,144 @@ +""" + ProjectTo(x::T) + +Returns a `ProjectTo{T,...}` functor able to project a differential `dx` onto the same tangent space as `x`. +This functor encloses over what ever is needed to be able to be able to do that projection. +For example, when projecting `dx=ZeroTangent()` on an array `P=Array{T, N}`, the size of `x` +is not available from `P`, so it is stored in the functor. + + (::ProjectTo{T})(dx) + +Projects the differential `dx` on the onto the tangent space used to create the `ProjectTo`. +""" +struct ProjectTo{P,D<:NamedTuple} + info::D +end +ProjectTo{P}(info::D) where {P,D<:NamedTuple} = ProjectTo{P,D}(info) +ProjectTo{P}(; kwargs...) where {P} = ProjectTo{P}(NamedTuple(kwargs)) + +backing(project::ProjectTo) = getfield(project, :info) +Base.getproperty(p::ProjectTo, name::Symbol) = getproperty(backing(p), name) +Base.propertynames(p::ProjectTo) = propertynames(backing(p)) + +function Base.show(io::IO, project::ProjectTo{T}) where {T} + print(io, "ProjectTo{") + show(io, T) + print(io, "}") + if isempty(backing(project)) + print(io, "()") + else + show(io, backing(project)) + end +end + +# fallback (structs) +function ProjectTo(x::T) where {T} + # Generic fallback for structs, recursively make `ProjectTo`s all their fields + fields_nt::NamedTuple = backing(x) + return ProjectTo{T}(map(ProjectTo, fields_nt)) +end +function (project::ProjectTo{T})(dx::Tangent) where {T} + sub_projects = backing(project) + sub_dxs = backing(canonicalize(dx)) + _call(f, x) = f(x) + return construct(T, map(_call, sub_projects, sub_dxs)) +end + +# should not work for Tuples and NamedTuples, as not valid tangent types +function ProjectTo(x::T) where {T<:Union{<:Tuple,NamedTuple}} + return throw( + ArgumentError("The `x` in `ProjectTo(x)` must be a valid differential, not $x") + ) +end + +# Generic +(project::ProjectTo)(dx::AbstractThunk) = project(unthunk(dx)) +(::ProjectTo{T})(dx::T) where {T} = dx # not always true, but we can special case for when it isn't +(::ProjectTo{T})(dx::AbstractZero) where {T} = zero(T) + +# Number +ProjectTo(::T) where {T<:Number} = ProjectTo{T}() +(::ProjectTo{T})(dx::Number) where {T<:Number} = convert(T, dx) +(::ProjectTo{T})(dx::Number) where {T<:Real} = convert(T, real(dx)) + +# Arrays +ProjectTo(xs::T) where {T<:Array} = ProjectTo{T}(; elements=map(ProjectTo, xs)) +function (project::ProjectTo{T})(dx::Array) where {T<:Array} + _call(f, x) = f(x) + return T(map(_call, project.elements, dx)) +end +function (project::ProjectTo{T})(dx::AbstractZero) where {T<:Array} + return T(map(proj -> proj(dx), project.elements)) +end +(project::ProjectTo{<:Array})(dx::AbstractArray) = project(collect(dx)) + +# Arrays{<:Number}: optimized case so we don't need a projector per element +function ProjectTo(x::T) where {E<:Number,T<:Array{E}} + return ProjectTo{T}(; element=ProjectTo(zero(E)), size=size(x)) +end +(project::ProjectTo{<:Array{T}})(dx::Array) where {T<:Number} = project.element.(dx) +function (project::ProjectTo{<:Array{T}})(dx::AbstractZero) where {T<:Number} + return zeros(T, project.size) +end +function (project::ProjectTo{<:Array{T}})(dx::Tangent{<:SubArray}) where {T<:Number} + return project(dx.parent) +end + +# Diagonal +ProjectTo(x::T) where {T<:Diagonal} = ProjectTo{T}(; diag=ProjectTo(diag(x))) +(project::ProjectTo{T})(dx::AbstractMatrix) where {T<:Diagonal} = T(project.diag(diag(dx))) +(project::ProjectTo{T})(dx::AbstractZero) where {T<:Diagonal} = T(project.diag(dx)) + +# :data, :uplo fields +for SymHerm in (:Symmetric, :Hermitian) + @eval begin + function ProjectTo(x::T) where {T<:$SymHerm} + return ProjectTo{T}(; uplo=Symbol(x.uplo), parent=ProjectTo(parent(x))) + end + function (project::ProjectTo{<:$SymHerm})(dx::AbstractMatrix) + return $SymHerm(project.parent(dx), project.uplo) + end + function (project::ProjectTo{<:$SymHerm})(dx::AbstractZero) + return $SymHerm(project.parent(dx), project.uplo) + end + function (project::ProjectTo{<:$SymHerm})(dx::Tangent) + return $SymHerm(project.parent(dx.data), project.uplo) + end + end +end + +# :data field +for UL in (:UpperTriangular, :LowerTriangular) + @eval begin + ProjectTo(x::T) where {T<:$UL} = ProjectTo{T}(; parent=ProjectTo(parent(x))) + (project::ProjectTo{<:$UL})(dx::AbstractMatrix) = $UL(project.parent(dx)) + (project::ProjectTo{<:$UL})(dx::AbstractZero) = $UL(project.parent(dx)) + (project::ProjectTo{<:$UL})(dx::Tangent) = $UL(project.parent(dx.data)) + end +end + +# Transpose +ProjectTo(x::T) where {T<:Transpose} = ProjectTo{T}(; parent=ProjectTo(parent(x))) +function (project::ProjectTo{<:Transpose})(dx::AbstractMatrix) + return transpose(project.parent(transpose(dx))) +end +(project::ProjectTo{<:Transpose})(dx::AbstractZero) = transpose(project.parent(dx)) + +# Adjoint +ProjectTo(x::T) where {T<:Adjoint} = ProjectTo{T}(; parent=ProjectTo(parent(x))) +(project::ProjectTo{<:Adjoint})(dx::AbstractMatrix) = adjoint(project.parent(adjoint(dx))) +(project::ProjectTo{<:Adjoint})(dx::AbstractZero) = adjoint(project.parent(dx)) + +# PermutedDimsArray +ProjectTo(x::P) where {P<:PermutedDimsArray} = ProjectTo{P}(; parent=ProjectTo(parent(x))) +function (project::ProjectTo{<:PermutedDimsArray{T,N,perm,iperm,AA}})( + dx::AbstractArray +) where {T,N,perm,iperm,AA} + return PermutedDimsArray{T,N,perm,iperm,AA}(permutedims(project.parent(dx), perm)) +end +function (project::ProjectTo{P})(dx::AbstractZero) where {P<:PermutedDimsArray} + return P(project.parent(dx)) +end + +# SubArray +ProjectTo(x::T) where {T<:SubArray} = ProjectTo(copy(x)) # don't project on to a view, but onto matching copy diff --git a/test/projection.jl b/test/projection.jl new file mode 100644 index 000000000..7bd8d266a --- /dev/null +++ b/test/projection.jl @@ -0,0 +1,240 @@ +struct Fred + a::Float64 +end +Base.zero(::Fred) = Fred(0.0) +Base.zero(::Type{Fred}) = Fred(0.0) + +struct Freddy{T,N} + a::Array{T,N} +end +Base.:(==)(a::Freddy, b::Freddy) = a.a == b.a + +struct Mary + a::Fred +end + +struct TwoFields + a::Float64 + b::Float64 +end + +@testset "projection" begin + @testset "display" begin + @test startswith(repr(ProjectTo(Fred(1.1))), "ProjectTo{Fred}(") + @test repr(ProjectTo(1.1)) == "ProjectTo{Float64}()" + end + + @testset "fallback" begin + @test Fred(1.2) == ProjectTo(Fred(1.1))(Fred(1.2)) + @test Fred(0.0) == ProjectTo(Fred(1.1))(ZeroTangent()) + @test Fred(3.2) == ProjectTo(Fred(1.1))(@thunk(Fred(3.2))) + @test Fred(1.2) == ProjectTo(Fred(1.1))(Tangent{Fred}(; a=1.2)) + + # struct with complicated field + x = Freddy(zeros(2, 2)) + dx = Tangent{Freddy}(; a=ZeroTangent()) + @test x == ProjectTo(x)(dx) + + # nested structs + f = Fred(0.0) + tf = Tangent{Fred}(; a=ZeroTangent()) + m = Mary(f) + dm = Tangent{Mary}(; a=tf) + @test m == ProjectTo(m)(dm) + + # two fields + tfa = TwoFields(3.0, 0.0) + tfb = TwoFields(0.0, 3.0) + @test tfa == ProjectTo(tfa)(Tangent{TwoFields}(; a=3.0)) + @test tfb == ProjectTo(tfb)(Tangent{TwoFields}(; b=3.0)) + end + + @testset "to Real" begin + # Float64 + @test 3.2 == ProjectTo(1.0)(3.2) + @test 0.0 == ProjectTo(1.0)(ZeroTangent()) + @test 3.2 == ProjectTo(1.0)(@thunk(3.2)) + + # down + @test 3.2 == ProjectTo(1.0)(3.2 + 3im) + @test 3.2f0 == ProjectTo(1.0f0)(3.2) + @test 3.2f0 == ProjectTo(1.0f0)(3.2 - 3im) + + # up + @test 2.0 == ProjectTo(1.0)(2.0f0) + end + + @testset "to Number" begin + # Complex + @test 2.0 + 4.0im == ProjectTo(1.0im)(2.0 + 4.0im) + + # down + @test 2.0 + 0.0im == ProjectTo(1.0im)(2.0) + @test 0.0 + 0.0im == ProjectTo(1.0im)(ZeroTangent()) + @test 0.0 + 0.0im == ProjectTo(1.0im)(@thunk(ZeroTangent())) + + # up + @test 2.0 + 0.0im == ProjectTo(1.0im)(2.0) + end + + @testset "to Array" begin + # to an array of numbers + x = zeros(2, 2) + @test [1.0 2.0; 3.0 4.0] == ProjectTo(x)([1.0 2.0; 3.0 4.0]) + @test x == ProjectTo(x)(ZeroTangent()) + + x = zeros(2) + @test x == ProjectTo(x)(@thunk(ZeroTangent())) + + x = zeros(Float32, 2, 2) + @test x == ProjectTo(x)([0.0 0; 0 0]) + + x = [1.0 0; 0 4] + @test x == ProjectTo(x)(Diagonal([1.0, 4])) + + # to a array of structs + x = [Fred(0.0), Fred(0.0)] + @test x == ProjectTo(x)([Fred(0.0), Fred(0.0)]) + @test x == ProjectTo(x)([ZeroTangent(), ZeroTangent()]) + @test x == ProjectTo(x)([ZeroTangent(), @thunk(Fred(0.0))]) + @test x == ProjectTo(x)(ZeroTangent()) + @test x == ProjectTo(x)(@thunk(ZeroTangent())) + + x = [Fred(1.0) Fred(0.0); Fred(0.0) Fred(4.0)] + @test x == ProjectTo(x)(Diagonal([Fred(1.0), Fred(4.0)])) + end + + @testset "To Array of Arrays" begin + # inner arrays have same type but different sizes + x = [[1.0, 2.0, 3.0], [4.0, 5.0]] + @test x == ProjectTo(x)(x) + @test x == ProjectTo(x)([[1.0 + 2im, 2.0, 3.0], [4.0 + 2im, 5.0]]) + + # This makes sure we don't fall for https://github.com/JuliaLang/julia/issues/38064 + @test [[0.0, 0.0, 0.0], [0.0, 0.0]] == ProjectTo(x)(ZeroTangent()) + end + + @testset "Array{Any} with really messy contents" begin + # inner arrays have same type but different sizes + x = [[1.0, 2.0, 3.0], [4.0 + im 5.0], [[[Fred(1)]]]] + @test x == ProjectTo(x)(x) + @test x == ProjectTo(x)([[1.0 + im, 2.0, 3.0], [4.0 + im 5.0], [[[Fred(1)]]]]) + # using a different type for the 2nd element (Adjoint) + @test x == ProjectTo(x)([[1.0 + im, 2.0, 3.0], [4.0 - im, 5.0]', [[[Fred(1)]]]]) + + @test [[0.0, 0.0, 0.0], [0.0im 0.0], [[[Fred(0)]]]] == ProjectTo(x)(ZeroTangent()) + end + + @testset "to Diagonal" begin + d_F64 = Diagonal([0.0, 0.0]) + d_F32 = Diagonal([0.0f0, 0.0f0]) + d_C64 = Diagonal([0.0 + 0im, 0.0]) + d_Fred = Diagonal([Fred(0.0), Fred(0.0)]) + + # from Matrix + @test d_F64 == ProjectTo(d_F64)(zeros(2, 2)) + @test d_F64 == ProjectTo(d_F64)(zeros(Float32, 2, 2)) + @test d_F64 == ProjectTo(d_F64)(zeros(ComplexF64, 2, 2)) + + # from Diagonal of Numbers + @test d_F64 == ProjectTo(d_F64)(d_F64) + @test d_F64 == ProjectTo(d_F64)(d_F32) + @test d_F64 == ProjectTo(d_F64)(d_C64) + + # from Diagonal of AbstractTangent + @test d_F64 == ProjectTo(d_F64)(ZeroTangent()) + @test d_C64 == ProjectTo(d_C64)(ZeroTangent()) + @test d_F64 == ProjectTo(d_F64)(@thunk(ZeroTangent())) + @test d_F64 == ProjectTo(d_F64)(Diagonal([ZeroTangent(), ZeroTangent()])) + @test d_F64 == ProjectTo(d_F64)(Diagonal([ZeroTangent(), @thunk(ZeroTangent())])) + + # from Diagonal of structs + @test d_Fred == ProjectTo(d_Fred)(ZeroTangent()) + @test d_Fred == ProjectTo(d_Fred)(@thunk(ZeroTangent())) + @test d_Fred == ProjectTo(d_Fred)(Diagonal([ZeroTangent(), ZeroTangent()])) + + # from Tangent + @test d_F64 == ProjectTo(d_F64)(Tangent{Diagonal}(; diag=[0.0, 0.0])) + @test d_F64 == ProjectTo(d_F64)(Tangent{Diagonal}(; diag=[0.0f0, 0.0f0])) + @test d_F64 == ProjectTo(d_F64)( + Tangent{Diagonal}(; diag=[ZeroTangent(), @thunk(ZeroTangent())]) + ) + end + + @testset "to $SymHerm" for SymHerm in (Symmetric, Hermitian) + data = [1.0+1im 2-2im; 3 4] + + x = SymHerm(data) + @test x == ProjectTo(x)(data) + @test x == ProjectTo(x)(Tangent{typeof(x)}(; data=data, uplo=NoTangent())) + + x = SymHerm(data, :L) + @test x == ProjectTo(x)(data) + + data = [1.0-2im 0; 0 4] + x = SymHerm(data) + @test x == ProjectTo(x)(Diagonal([1.0 - 2im, 4.0])) + + data = [0.0+0im 0; 0 0] + x = SymHerm(data) + @test x == ProjectTo(x)(ZeroTangent()) + @test x == ProjectTo(x)(@thunk(ZeroTangent())) + end + + @testset "to $UL" for UL in (UpperTriangular, LowerTriangular) + data = [1.0+1im 2-2im; 3 4] + + x = UL(data) + @test x == ProjectTo(x)(data) + @test x == ProjectTo(x)(Tangent{typeof(x)}(; data=data)) + + data = [0.0+0im 0; 0 0] + x = UL(data) + @test x == ProjectTo(x)(Diagonal(zeros(2))) + @test x == ProjectTo(x)(ZeroTangent()) + @test x == ProjectTo(x)(@thunk(ZeroTangent())) + end + + @testset "to Transpose" begin + x = rand(ComplexF64, 3, 4) + t = transpose(x) + mt = collect(t) + a = adjoint(x) + ma = collect(a) + + @test t == ProjectTo(t)(mt) + @test conj(t) == ProjectTo(t)(ma) + @test zeros(4, 3) == ProjectTo(t)(ZeroTangent()) + @test zeros(4, 3) == ProjectTo(t)(Tangent{Transpose}(; parent=ZeroTangent())) + end + + @testset "to Adjoint" begin + x = rand(ComplexF64, 3, 4) + a = adjoint(x) + ma = collect(a) + + @test a == ProjectTo(a)(ma) + @test zeros(4, 3) == ProjectTo(a)(ZeroTangent()) + @test zeros(4, 3) == ProjectTo(a)(Tangent{Adjoint}(; parent=ZeroTangent())) + end + + @testset "to PermutedDimsArray" begin + a = zeros(3, 5, 4) + b = PermutedDimsArray(a, (2, 1, 3)) + bc = collect(b) + + @test b == ProjectTo(b)(bc) + @test b == ProjectTo(b)(ZeroTangent()) + end + + @testset "to SubArray" begin + x = rand(3, 4) + sa = view(x, :, 1:2) + m = collect(sa) + + # make sure it converts the view to the parent type + @test ProjectTo(sa)(m) isa Matrix + @test zeros(3, 2) == ProjectTo(sa)(ZeroTangent()) + @test ProjectTo(sa)(Tangent{SubArray}(; parent=ZeroTangent())) isa Matrix + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 090a85828..f98e971b2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -16,6 +16,7 @@ using Test end include("accumulation.jl") + include("projection.jl") include("rules.jl") include("rule_definition_tools.jl")