Skip to content

Commit

Permalink
use a functor for projection (#385)
Browse files Browse the repository at this point in the history
  • Loading branch information
oxinabox authored Jul 6, 2021
1 parent b4f2cfa commit 3acd962
Show file tree
Hide file tree
Showing 8 changed files with 428 additions and 2 deletions.
2 changes: 1 addition & 1 deletion docs/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
path = ".."
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "0.10.7"
version = "0.10.11"

[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
Expand Down
5 changes: 5 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ Pages = ["config.jl"]
Private = false
```

## ProjectTo
```@docs
ProjectTo
```

## Internal
```@docs
ChainRulesCore.AbstractTangent
Expand Down
34 changes: 34 additions & 0 deletions docs/src/writing_good_rules.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,40 @@ Examples being:
- There is only one derivative being returned, so from the fact that the user called
`frule`/`rrule` they clearly will want to use that one.

## Ensure you remain in the primal's subspace (i.e. use `ProjectTo` appropriately)

Rules with abstractly-typed arguments may return incorrect answers when called with certain concrete types.
A classic example is the matrix-matrix multiplication rule, a naive definition of which follows:
```julia
function rrule(::typeof(*), A::AbstractMatrix, B::AbstractMatrix)
function times_pullback(ȳ)
dA =* B'
dB = A' *
return NoTangent(), dA, dB
end
return A * B, times_pullback
end
```
When computing `*(A, B)`, where `A isa Diagonal` and `B isa Matrix`, the output will be a `Matrix`.
As a result, `` in the pullback will be a `Matrix`, and consequently `dA` for a `A isa Diagonal` will be a `Matrix`, which is wrong.
Not only is it the wrong type, but it can contain non-zeros off the diagonal, which is not possible, it is outside of the subspace.
While a specialised rules can indeed be written for the `Diagonal` case, there are many other types and we don't want to be forced to write a rule for each of them.
Instead, `project_A = ProjectTo(A)` can be used (outside the pullback) to extract an object that knows how to project onto the type of `A` (e.g. also knows the size of the array).
This object can be called with a tangent `ȳ * B'`, by doing `project_A(ȳ * B')`, to project it on the tangent space of `A`.
The correct rule then looks like
```julia
function rrule(::typeof(*), A::AbstractMatrix, B::AbstractMatrix)
project_A = ProjectTo(A)
project_B = ProjectTo(B)
function times_pullback(ȳ)
dA =* B'
dB = A' *
return NoTangent(), project_A(dA), project_B(dB)
end
return A * B, times_pullback
end
```

## Structs: constructors and functors

To define an `frule` or `rrule` for a _function_ `foo` we dispatch on the type of `foo`, which is `typeof(foo)`.
Expand Down
3 changes: 2 additions & 1 deletion src/ChainRulesCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ export RuleConfig, HasReverseMode, NoReverseMode, HasForwardsMode, NoForwardsMod
export frule_via_ad, rrule_via_ad
# definition helper macros
export @non_differentiable, @scalar_rule, @thunk, @not_implemented
export canonicalize, extern, unthunk # differential operations
export ProjectTo, canonicalize, extern, unthunk # differential operations
export add!! # gradient accumulation operations
# differentials
export Tangent, NoTangent, InplaceableThunk, Thunk, ZeroTangent, AbstractZero, AbstractThunk
Expand All @@ -26,6 +26,7 @@ include("differentials/notimplemented.jl")

include("differential_arithmetic.jl")
include("accumulation.jl")
include("projection.jl")

include("config.jl")
include("rules.jl")
Expand Down
1 change: 1 addition & 0 deletions src/differentials/composite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ backing(x::NamedTuple) = x
backing(x::Dict) = x
backing(x::Tangent) = getfield(x, :backing)

# For generic structs
function backing(x::T)::NamedTuple where T
# note: all computation outside the if @generated happens at runtime.
# so the first 4 lines of the branchs look the same, but can not be moved out.
Expand Down
144 changes: 144 additions & 0 deletions src/projection.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
"""
ProjectTo(x::T)
Returns a `ProjectTo{T,...}` functor able to project a differential `dx` onto the same tangent space as `x`.
This functor encloses over what ever is needed to be able to be able to do that projection.
For example, when projecting `dx=ZeroTangent()` on an array `P=Array{T, N}`, the size of `x`
is not available from `P`, so it is stored in the functor.
(::ProjectTo{T})(dx)
Projects the differential `dx` on the onto the tangent space used to create the `ProjectTo`.
"""
struct ProjectTo{P,D<:NamedTuple}
info::D
end
ProjectTo{P}(info::D) where {P,D<:NamedTuple} = ProjectTo{P,D}(info)
ProjectTo{P}(; kwargs...) where {P} = ProjectTo{P}(NamedTuple(kwargs))

backing(project::ProjectTo) = getfield(project, :info)
Base.getproperty(p::ProjectTo, name::Symbol) = getproperty(backing(p), name)
Base.propertynames(p::ProjectTo) = propertynames(backing(p))

function Base.show(io::IO, project::ProjectTo{T}) where {T}
print(io, "ProjectTo{")
show(io, T)
print(io, "}")
if isempty(backing(project))
print(io, "()")
else
show(io, backing(project))
end
end

# fallback (structs)
function ProjectTo(x::T) where {T}
# Generic fallback for structs, recursively make `ProjectTo`s all their fields
fields_nt::NamedTuple = backing(x)
return ProjectTo{T}(map(ProjectTo, fields_nt))
end
function (project::ProjectTo{T})(dx::Tangent) where {T}
sub_projects = backing(project)
sub_dxs = backing(canonicalize(dx))
_call(f, x) = f(x)
return construct(T, map(_call, sub_projects, sub_dxs))
end

# should not work for Tuples and NamedTuples, as not valid tangent types
function ProjectTo(x::T) where {T<:Union{<:Tuple,NamedTuple}}
return throw(
ArgumentError("The `x` in `ProjectTo(x)` must be a valid differential, not $x")
)
end

# Generic
(project::ProjectTo)(dx::AbstractThunk) = project(unthunk(dx))
(::ProjectTo{T})(dx::T) where {T} = dx # not always true, but we can special case for when it isn't
(::ProjectTo{T})(dx::AbstractZero) where {T} = zero(T)

# Number
ProjectTo(::T) where {T<:Number} = ProjectTo{T}()
(::ProjectTo{T})(dx::Number) where {T<:Number} = convert(T, dx)
(::ProjectTo{T})(dx::Number) where {T<:Real} = convert(T, real(dx))

# Arrays
ProjectTo(xs::T) where {T<:Array} = ProjectTo{T}(; elements=map(ProjectTo, xs))
function (project::ProjectTo{T})(dx::Array) where {T<:Array}
_call(f, x) = f(x)
return T(map(_call, project.elements, dx))
end
function (project::ProjectTo{T})(dx::AbstractZero) where {T<:Array}
return T(map(proj -> proj(dx), project.elements))
end
(project::ProjectTo{<:Array})(dx::AbstractArray) = project(collect(dx))

# Arrays{<:Number}: optimized case so we don't need a projector per element
function ProjectTo(x::T) where {E<:Number,T<:Array{E}}
return ProjectTo{T}(; element=ProjectTo(zero(E)), size=size(x))
end
(project::ProjectTo{<:Array{T}})(dx::Array) where {T<:Number} = project.element.(dx)
function (project::ProjectTo{<:Array{T}})(dx::AbstractZero) where {T<:Number}
return zeros(T, project.size)
end
function (project::ProjectTo{<:Array{T}})(dx::Tangent{<:SubArray}) where {T<:Number}
return project(dx.parent)
end

# Diagonal
ProjectTo(x::T) where {T<:Diagonal} = ProjectTo{T}(; diag=ProjectTo(diag(x)))
(project::ProjectTo{T})(dx::AbstractMatrix) where {T<:Diagonal} = T(project.diag(diag(dx)))
(project::ProjectTo{T})(dx::AbstractZero) where {T<:Diagonal} = T(project.diag(dx))

# :data, :uplo fields
for SymHerm in (:Symmetric, :Hermitian)
@eval begin
function ProjectTo(x::T) where {T<:$SymHerm}
return ProjectTo{T}(; uplo=Symbol(x.uplo), parent=ProjectTo(parent(x)))
end
function (project::ProjectTo{<:$SymHerm})(dx::AbstractMatrix)
return $SymHerm(project.parent(dx), project.uplo)
end
function (project::ProjectTo{<:$SymHerm})(dx::AbstractZero)
return $SymHerm(project.parent(dx), project.uplo)
end
function (project::ProjectTo{<:$SymHerm})(dx::Tangent)
return $SymHerm(project.parent(dx.data), project.uplo)
end
end
end

# :data field
for UL in (:UpperTriangular, :LowerTriangular)
@eval begin
ProjectTo(x::T) where {T<:$UL} = ProjectTo{T}(; parent=ProjectTo(parent(x)))
(project::ProjectTo{<:$UL})(dx::AbstractMatrix) = $UL(project.parent(dx))
(project::ProjectTo{<:$UL})(dx::AbstractZero) = $UL(project.parent(dx))
(project::ProjectTo{<:$UL})(dx::Tangent) = $UL(project.parent(dx.data))
end
end

# Transpose
ProjectTo(x::T) where {T<:Transpose} = ProjectTo{T}(; parent=ProjectTo(parent(x)))
function (project::ProjectTo{<:Transpose})(dx::AbstractMatrix)
return transpose(project.parent(transpose(dx)))
end
(project::ProjectTo{<:Transpose})(dx::AbstractZero) = transpose(project.parent(dx))

# Adjoint
ProjectTo(x::T) where {T<:Adjoint} = ProjectTo{T}(; parent=ProjectTo(parent(x)))
(project::ProjectTo{<:Adjoint})(dx::AbstractMatrix) = adjoint(project.parent(adjoint(dx)))
(project::ProjectTo{<:Adjoint})(dx::AbstractZero) = adjoint(project.parent(dx))

# PermutedDimsArray
ProjectTo(x::P) where {P<:PermutedDimsArray} = ProjectTo{P}(; parent=ProjectTo(parent(x)))
function (project::ProjectTo{<:PermutedDimsArray{T,N,perm,iperm,AA}})(
dx::AbstractArray
) where {T,N,perm,iperm,AA}
return PermutedDimsArray{T,N,perm,iperm,AA}(permutedims(project.parent(dx), perm))
end
function (project::ProjectTo{P})(dx::AbstractZero) where {P<:PermutedDimsArray}
return P(project.parent(dx))
end

# SubArray
ProjectTo(x::T) where {T<:SubArray} = ProjectTo(copy(x)) # don't project on to a view, but onto matching copy
Loading

2 comments on commit 3acd962

@mzgubic
Copy link
Member

@mzgubic mzgubic commented on 3acd962 Jul 6, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Error while trying to register: "Tag with name v0.10.10 already exists and points to a different commit"

Please sign in to comment.