Skip to content

Commit

Permalink
teach Flux.state about Duplicated
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Nov 24, 2024
1 parent 126e7bd commit 55d24eb
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 0 deletions.
11 changes: 11 additions & 0 deletions src/loading.jl
Original file line number Diff line number Diff line change
Expand Up @@ -176,3 +176,14 @@ const STATE_TYPES = Union{AbstractArray, Number, Nothing, AbstractString, Symbol

# Leaf rule: a value whose type belongs to `STATE_TYPES` is returned unchanged.
function _state(x::STATE_TYPES)
    return x
end

# Fallback rule: any other value is replaced by an empty tuple.
function _state(x)
    return ()
end

# Design note: starting from `gradient(f, m) == gradient(f, Duplicated(m))`,
# `Duplicated` is regarded as some kind of label, not as part of the model tree,
# so no outer NamedTuple like `(; val=..., dval=...)` is produced.
# We certainly don't want to save model gradients alongside parameters/settings,
# hence only the `val` field is passed on to `state`:
function state(x::EnzymeCore.Duplicated)
    return state(x.val)
end

# Duplicated <- Duplicated: call the `loadmodel!(::Any, ::Any)` method explicitly,
# forwarding every keyword argument untouched.
function loadmodel!(dst::EnzymeCore.Duplicated, src::EnzymeCore.Duplicated; kw...)
    return @invoke loadmodel!(dst::Any, src::Any; kw...)
end

# Duplicated <- anything else: load `src` into the wrapped `dst.val`,
# then return the wrapper `dst` itself (not `dst.val`).
function loadmodel!(dst::EnzymeCore.Duplicated, src; kw...)
    loadmodel!(dst.val, src; kw...)
    return dst
end
9 changes: 9 additions & 0 deletions test/ext_enzyme/enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,15 @@ end
# setup understands Duplicated:
# the optimiser state built from the wrapper `m1` must equal the one built
# directly from the wrapped model `m1.val`.
@test Flux.setup(Adam(), m1) == Flux.setup(Adam(), m1.val)

# state, loadmodel do too -- all ignore the dval branch, no outer (; val, dval) namedtuple
@test Flux.state(m1) == Flux.state(m1.val)

# Snapshot the wrapper and its state before the weights are overwritten below.
oldmodel = deepcopy(m1)
oldpar = deepcopy(Flux.state(m1))

# Loading from another Duplicated restores the weights and returns the wrapper
# (hence the `.val` access on the result).
m1.val.weight .= 0
@test Flux.loadmodel!(m1, oldmodel).val.weight ≈ oldpar.weight

# Loading from a plain state object restores the weights as well.
m1.val.weight .= 0
@test Flux.loadmodel!(m1, oldpar).val.weight ≈ oldpar.weight

# At least one Duplicated is required:
# calling `Flux.gradient` with only a `Const`-wrapped model (with or without an
# extra plain array argument) is expected to throw an ArgumentError.
@test_throws ArgumentError Flux.gradient(m -> sum(m.bias), Const(m1.val))
@test_throws ArgumentError Flux.gradient((m,x) -> sum(m(x)), Const(m1.val), [1,2,3f0])
Expand Down

0 comments on commit 55d24eb

Please sign in to comment.