Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: handle zero_tangent in presence of cyclic structures (v1) #654

Closed
wants to merge 34 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
7f15d11
rename files
oxinabox Aug 4, 2023
c9b4938
move functionality up to StructuralTangent
oxinabox Aug 4, 2023
a51e51e
Formatting
oxinabox Aug 4, 2023
93af90b
WIP mutable Tangent (squash me)
oxinabox Aug 21, 2023
418b5ce
wip
oxinabox Sep 15, 2023
724ba1b
First pass at something that maybe works
oxinabox Sep 15, 2023
c9a65df
accept int index
oxinabox Sep 18, 2023
06b51f5
add == and hash for MutableTangent
oxinabox Sep 26, 2023
ed3aa1d
add and test zero_tangent
oxinabox Sep 26, 2023
f45fbc7
export StructuralTangent
oxinabox Sep 28, 2023
b2bdb26
Style
oxinabox Oct 2, 2023
0438217
handle unassigned a bit more
oxinabox Oct 4, 2023
4852c91
add some more test cases to zero_tangent
oxinabox Oct 4, 2023
0f82019
style
oxinabox Oct 4, 2023
5574691
Handle Structs with undef fields
oxinabox Oct 6, 2023
e9cc221
overhaul zero_tangent and MutableTangent for type stability
oxinabox Dec 22, 2023
baea9d3
set MutableTangent setproperty! on index
oxinabox Dec 27, 2023
a27f1b6
formatting
oxinabox Dec 27, 2023
4cfce0b
handle abstract fields right in mutable tangents outside of zero tangent
oxinabox Dec 28, 2023
ad9a5af
formatting
oxinabox Dec 28, 2023
8b3d525
Add docs for forward mutation support
oxinabox Dec 28, 2023
c09ff91
use ismutabletype from Compat
oxinabox Dec 29, 2023
59fc470
wrap structural tangent tests in a common testset
oxinabox Dec 29, 2023
ade0c3d
Support types that have no tangent space in zero_tangent
oxinabox Dec 29, 2023
e068cb6
define zero_tangent for Tangent
oxinabox Jan 16, 2024
45de6a7
Add structural zero tangent code for higher order
oxinabox Jan 17, 2024
780ed05
Formatting
oxinabox Jan 17, 2024
2795872
overload show for mutable tangent
oxinabox Jan 17, 2024
7e9e778
formatting
oxinabox Jan 19, 2024
e912e46
move show code to `Common` area
oxinabox Jan 23, 2024
e478e7f
docs more consistent
oxinabox Jan 23, 2024
7d95866
Update src/tangent_types/structural_tangent.jl
oxinabox Jan 23, 2024
fe63c33
Update test/tangent_types/structural_tangent.jl
oxinabox Jan 23, 2024
5fbbe5b
Handle circular references with-in mutable structs
oxinabox Jan 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
handle abstract fields right in mutable tangents outside of zero tangent
  • Loading branch information
oxinabox committed Jan 23, 2024
commit 4cfce0bc4346648f66933d8104b27b71f2726ed8
10 changes: 6 additions & 4 deletions src/tangent_types/structural_tangent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,6 @@ It itself is also mutable.
struct MutableTangent{P,F} <: StructuralTangent{P}
backing::F

function MutableTangent{P}(fieldvals) where P
backing = map(Ref, fieldvals)
return new{P, typeof(backing)}(backing)
end
function MutableTangent{P}(
any_mask::NamedTuple{names, <:NTuple{<:Any, Bool}}, fvals::NamedTuple{names}
) where {names, P}
Expand All @@ -91,8 +87,14 @@ struct MutableTangent{P,F} <: StructuralTangent{P}
end
return new{P, typeof(backing)}(backing)
end

function MutableTangent{P}(fvals) where P
any_mask = NamedTuple{fieldnames(P)}((!isconcretetype).(fieldtypes(P)))
return MutableTangent{P}(any_mask, fvals)
end
end


####################################################################
# StructuralTangent Common

Expand Down
41 changes: 36 additions & 5 deletions test/tangent_types/structural_tangent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ struct Foo
y::Float64
end

mutable struct MFoo
x::Float64
y
end

# For testing Primal + Tangent performance
struct Bar
x::Float64
Expand Down Expand Up @@ -452,14 +457,40 @@ end
end

@testset "== and hash" begin
@test MutableTangent{Any}(; x=1.0) == MutableTangent{MDemo}(; x=1.0)
@test MutableTangent{MDemo}(; x=1.0) == MutableTangent{Any}(; x=1.0)
@test MutableTangent{Any}(; x=2.0) != MutableTangent{MDemo}(; x=1.0)
@test MutableTangent{MDemo}(; x=1.0) != MutableTangent{Any}(; x=2.0)
@test MutableTangent{MDemo}(; x=1f0) == MutableTangent{MDemo}(; x=1.0)
@test MutableTangent{MDemo}(; x=1.0) == MutableTangent{MDemo}(; x=1f0)
@test MutableTangent{MDemo}(; x=2.0) != MutableTangent{MDemo}(; x=1.0)
@test MutableTangent{MDemo}(; x=1.0) != MutableTangent{MDemo}(; x=2.0)

nt = (; x=1.0)
@test MutableTangent{typeof(nt)}(nt) != MutableTangent{MDemo}(; x=1.0)

@test hash(MutableTangent{Any}(; x=1.0)) == hash(MutableTangent{MDemo}(; x=1.0))
@test hash(MutableTangent{MDemo}(; x=1f0)) == hash(MutableTangent{MDemo}(; x=1.0))
end

@testset "Mutation" begin
v = MutableTangent{MFoo}(x=1.5, y=2.4)
v.x = 1.6
@test v == MutableTangent{MFoo}(x=1.6, y=2.4)
v.y = [1.0, 2.0] # change type, because primal can change type
@test v == MutableTangent{MFoo}(x=1.6, y=[1.0, 2.0])
end
end

@testset "map" begin
@testset "Tangent" begin
∂foo = Tangent{Foo}(x=1.5, y=2.4)
@test map(v->2*v, ∂foo) == Tangent{Foo}(x=3.0, y=4.8)

∂foo = Tangent{Foo}(x=1.5)
@test map(v->2*v, ∂foo) == Tangent{Foo}(x=3.0)
end
@testset "MutableTangent" begin
∂foo = MutableTangent{MFoo}(x=1.5, y=2.4)
∂foo2 = map(v->2*v, ∂foo)
@test ∂foo2 == MutableTangent{MFoo}(x=3.0, y=4.8)
# Check can still be mutated to new typ
∂foo2.y=[1.0, 2.0]
@test ∂foo2 == MutableTangent{MFoo}(x=3.0, y=[1.0, 2.0])
end
end