Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tangent Interface Docs #434

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Mooncake"
uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
authors = ["Will Tebbutt, Hong Ge, and contributors"]
version = "0.4.70"
version = "0.4.71"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
1 change: 1 addition & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ makedocs(;
"Developer Documentation" => [
joinpath("developer_documentation", "running_tests_locally.md"),
joinpath("developer_documentation", "developer_tools.md"),
joinpath("developer_documentation", "tangents.md"),
joinpath("developer_documentation", "forwards_mode_design.md"),
joinpath("developer_documentation", "internal_docstrings.md"),
],
Expand Down
47 changes: 47 additions & 0 deletions docs/src/developer_documentation/tangents.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Tangents

As discussed in [Representing Gradients](@ref), Mooncake requires that each "primal" type be associated to a unique "tangent" type, given by the function [tangent_type](@ref).
Moreover, we must be able to "split" a given tangent into its _fdata_ ("forwards-data") and _rdata_ ("reverse-data"), whose types are given by [fdata_type](@ref) and `rdata_type` respectively.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe miss ref for rdata_type?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, yes. For some reason, it didn't like me doing that. I'm not 100% sure why.


Very occassionally it may be necessary to specify your own tangent type.
This is not an entirely trivial undertaking, as there is quite a lot of functionality that must be added to make it work properly.
So, before diving in to add your own custom type, seriously consider whether it is worth the effort, and whether the default definition given by Mooncake are really inadequate for your use-case.

## Testing Functionality

The interface is given in the form of three functions:
```@docs
Mooncake.TestUtils.test_tangent_interface
Mooncake.TestUtils.test_tangent_splitting
Mooncake.TestUtils.test_rule_and_type_interactions
```

You can call all three of these functions at once using
```@docs
Mooncake.TestUtils.test_data
```

If all the tests in these functions pass, then you have satisfied the interface.

## Interface

Below are the docstrings for each function tested by [`Mooncake.TestUtils.test_tangent_interface`](@ref) and [`Mooncake.TestUtils.test_tangent_splitting`](@ref).

```@docs
Mooncake.tangent_type
Mooncake.zero_tangent
Mooncake.randn_tangent
Mooncake.TestUtils.has_equal_data
Mooncake.increment!!
Mooncake.set_to_zero!!
Mooncake._add_to_primal
Mooncake._diff
Mooncake._dot
Mooncake._scale
Mooncake.TestUtils.populate_address_map
Mooncake.fdata_type
Mooncake.rdata_type
Mooncake.fdata
Mooncake.rdata
Mooncake.uninit_fdata
```
6 changes: 5 additions & 1 deletion docs/src/understanding_mooncake/rule_system.md
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,6 @@ Conversely, we refer to both inputs and outputs to the adjoint of this derivativ
Note, however, that the sets involved are the same whether dealing with a derivative or its adjoint.
Consequently, we use the same type to represent both.


_**Representing Gradients**_

This package assigns to each type in Julia a unique `tangent_type`, the purpose of which is to contain the gradients computed during reverse mode AD.
Expand All @@ -299,6 +298,11 @@ The following docstring provides the best in-depth explanation.
Mooncake.fdata_type(T)
```

_**More Info**_

See [Tangents](@ref) for complete information on what you must do if you wish to implement your own tangent type.
(In the vast majority of cases this is unnecessary).

_**CoDuals**_

CoDuals are simply used to bundle together a primal and an associated fdata, depending upon context.
Expand Down
5 changes: 5 additions & 0 deletions src/fwds_rvs_data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,11 @@ end
return :(tuple_map(fdata, t))
end

"""
uninit_fdata(p)

Equivalent to `fdata(uninit_tangent(p))`.
"""
uninit_fdata(p) = fdata(uninit_tangent(p))

"""
Expand Down
160 changes: 42 additions & 118 deletions src/tangents.jl
Original file line number Diff line number Diff line change
Expand Up @@ -874,125 +874,49 @@
N_large = 33
_names = Tuple(map(n -> Symbol("x$n"), 1:N_large))

abs_test_cases = vcat(
[
(sin, NoTangent(), NoTangent(), NoTangent()),
(map(Float16, (5.0, 4.0, 3.1, 7.1))...),
(5.0f0, 4.0f0, 3.0f0, 7.0f0),
(5.1, 4.0, 3.0, 7.0),
(svec(5.0), Any[4.0], Any[3.0], Any[7.0]),
([3.0, 2.0], [1.0, 2.0], [2.0, 3.0], [3.0, 5.0]),
(Float64[], Float64[], Float64[], Float64[]),
(
[1, 2],
[NoTangent(), NoTangent()],
[NoTangent(), NoTangent()],
[NoTangent(), NoTangent()],
),
(
[[1.0], [1.0, 2.0]],
[[2.0], [2.0, 3.0]],
[[3.0], [4.0, 5.0]],
[[5.0], [6.0, 8.0]],
),
(
setindex!(Vector{Vector{Float64}}(undef, 2), [1.0], 1),
setindex!(Vector{Vector{Float64}}(undef, 2), [2.0], 1),
setindex!(Vector{Vector{Float64}}(undef, 2), [3.0], 1),
setindex!(Vector{Vector{Float64}}(undef, 2), [5.0], 1),
),
(
setindex!(Vector{Vector{Float64}}(undef, 2), [1.0], 2),
setindex!(Vector{Vector{Float64}}(undef, 2), [2.0], 2),
setindex!(Vector{Vector{Float64}}(undef, 2), [3.0], 2),
setindex!(Vector{Vector{Float64}}(undef, 2), [5.0], 2),
),
((6.0, [1.0, 2.0]), (5.0, [3.0, 4.0]), (4.0, [4.0, 3.0]), (9.0, [7.0, 7.0])),
((), NoTangent(), NoTangent(), NoTangent()),
((1,), NoTangent(), NoTangent(), NoTangent()),
((2, 3), NoTangent(), NoTangent(), NoTangent()),
(
Mooncake.tuple_fill(5.0, Val(N_large)),
Mooncake.tuple_fill(6.0, Val(N_large)),
Mooncake.tuple_fill(7.0, Val(N_large)),
Mooncake.tuple_fill(13.0, Val(N_large)),
),
(
(a=6.0, b=[1.0, 2.0]),
(a=5.0, b=[3.0, 4.0]),
(a=4.0, b=[4.0, 3.0]),
(a=9.0, b=[7.0, 7.0]),
),
((;), NoTangent(), NoTangent(), NoTangent()),
(
NamedTuple{_names}(Mooncake.tuple_fill(5.0, Val(N_large))),
NamedTuple{_names}(Mooncake.tuple_fill(6.0, Val(N_large))),
NamedTuple{_names}(Mooncake.tuple_fill(7.0, Val(N_large))),
NamedTuple{_names}(Mooncake.tuple_fill(13.0, Val(N_large))),
),
(
TestResources.TypeStableMutableStruct{Float64}(5.0, 3.0),
build_tangent(TestResources.TypeStableMutableStruct{Float64}, 5.0, 4.0),
build_tangent(TestResources.TypeStableMutableStruct{Float64}, 3.0, 3.0),
build_tangent(TestResources.TypeStableMutableStruct{Float64}, 8.0, 7.0),
),
( # complete init
TestResources.StructFoo(6.0, [1.0, 2.0]),
build_tangent(TestResources.StructFoo, 5.0, [3.0, 4.0]),
build_tangent(TestResources.StructFoo, 3.0, [2.0, 1.0]),
build_tangent(TestResources.StructFoo, 8.0, [5.0, 5.0]),
),
( # partial init
TestResources.StructFoo(6.0),
build_tangent(TestResources.StructFoo, 5.0),
build_tangent(TestResources.StructFoo, 4.0),
build_tangent(TestResources.StructFoo, 9.0),
),
( # complete init
TestResources.MutableFoo(6.0, [1.0, 2.0]),
build_tangent(TestResources.MutableFoo, 5.0, [3.0, 4.0]),
build_tangent(TestResources.MutableFoo, 3.0, [2.0, 1.0]),
build_tangent(TestResources.MutableFoo, 8.0, [5.0, 5.0]),
),
( # partial init
TestResources.MutableFoo(6.0),
build_tangent(TestResources.MutableFoo, 5.0),
build_tangent(TestResources.MutableFoo, 4.0),
build_tangent(TestResources.MutableFoo, 9.0),
),
(
TestResources.StructNoFwds(5.0),
build_tangent(TestResources.StructNoFwds, 5.0),
build_tangent(TestResources.StructNoFwds, 4.0),
build_tangent(TestResources.StructNoFwds, 9.0),
),
(
TestResources.StructNoRvs([5.0]),
build_tangent(TestResources.StructNoRvs, [5.0]),
build_tangent(TestResources.StructNoRvs, [4.0]),
build_tangent(TestResources.StructNoRvs, [9.0]),
),
(UnitRange{Int}(5, 7), NoTangent(), NoTangent(), NoTangent()),
],
map([
LowerTriangular{Float64,Matrix{Float64}},
UpperTriangular{Float64,Matrix{Float64}},
UnitLowerTriangular{Float64,Matrix{Float64}},
UnitUpperTriangular{Float64,Matrix{Float64}},
]) do T
return (
T(randn(2, 2)),
build_tangent(T, [1.0 2.0; 3.0 4.0]),
build_tangent(T, [2.0 1.0; 5.0 4.0]),
build_tangent(T, [3.0 3.0; 8.0 8.0]),
)
end,
[
(p, NoTangent(), NoTangent(), NoTangent()) for
p in [Array, Float64, Union{Float64,Float32}, Union, UnionAll, typeof(<:)]
],
)
abs_test_cases = [

Check warning on line 877 in src/tangents.jl

View check run for this annotation

Codecov / codecov/patch

src/tangents.jl#L877

Added line #L877 was not covered by tests
(sin, NoTangent),
(Float16(5.0), Float16),
(5.0f0, Float32),
(5.1, Float64),
(svec(5.0), Vector{Any}),
([3.0, 2.0], Vector{Float64}),
(Float64[], Vector{Float64}),
([1, 2], Vector{NoTangent}),
([[1.0], [1.0, 2.0]], Vector{Vector{Float64}}),
(setindex!(Vector{Vector{Float64}}(undef, 2), [1.0], 1), Vector{Vector{Float64}}),
(setindex!(Vector{Vector{Float64}}(undef, 2), [1.0], 2), Vector{Vector{Float64}}),
((6.0, [1.0, 2.0]), Tuple{Float64,Vector{Float64}}),
((), NoTangent),
((1,), NoTangent),
((2, 3), NoTangent),
(Mooncake.tuple_fill(5.0, Val(N_large)), NTuple{N_large,Float64}),
((a=6.0, b=[1.0, 2.0]), @NamedTuple{a::Float64, b::Vector{Float64}}),

Check warning on line 894 in src/tangents.jl

View check run for this annotation

Codecov / codecov/patch

src/tangents.jl#L894

Added line #L894 was not covered by tests
((;), NoTangent),
(
NamedTuple{_names}(Mooncake.tuple_fill(5.0, Val(N_large))),
NamedTuple{_names,NTuple{N_large,Float64}},
),
(UnitRange{Int}(5, 7), NoTangent),
(Array, NoTangent),
(Float64, NoTangent),
(Union{Float64,Float32}, NoTangent),
(Union, NoTangent),
(UnionAll, NoTangent),
(typeof(<:), NoTangent),
]
rel_test_cases = Any[
TestResources.StructFoo(6.0, [1.0, 2.0]),
TestResources.StructFoo(6.0),
TestResources.MutableFoo(6.0, [1.0, 2.0]),
TestResources.MutableFoo(6.0),
TestResources.StructNoFwds(5.0),
TestResources.StructNoRvs([5.0]),
TestResources.TypeStableMutableStruct{Float64}(5.0, 3.0),
LowerTriangular{Float64,Matrix{Float64}}(randn(2, 2)),
UpperTriangular{Float64,Matrix{Float64}}(randn(2, 2)),
UnitLowerTriangular{Float64,Matrix{Float64}}(randn(2, 2)),
UnitUpperTriangular{Float64,Matrix{Float64}}(randn(2, 2)),
(2.0, 3),
(3, 2.0),
(2.0, 1.0),
Expand Down
Loading