test for shared params + minimal docs
mcabbott committed Nov 8, 2024
1 parent 3cf0579 commit cc5cc4e
Showing 2 changed files with 22 additions and 1 deletion.
9 changes: 9 additions & 0 deletions docs/src/index.md
@@ -358,3 +358,12 @@ julia> Optimisers.update!(opt_state, x, g);
julia> opt_state # the state in `a` and `b` differ
(a = Leaf(Adam(0.1, (0.9, 0.999), 1.0e-8), ([0.09, 0.09], [0.000999, 0.000999], (0.729, 0.997003))), b = Leaf(Adam(0.1, (0.9, 0.999), 1.0e-8), ([0.1, 0.1], [0.001, 0.001], (0.81, 0.998001))))
```

## Usage with [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl)

Enzyme.jl is a new automatic differentiation package, an alternative to Zygote.jl.
It likes to store the model and the gradient together, as an object `Duplicated(x, dx)`.

Optimisers.jl now has some methods to handle this, shown in the sketch after this list:
* `update!(opt_state, Duplicated(model, grad))` uses the gradient to update both the model and the optimiser state, and
* `setup(::AbstractRule, ::Duplicated)` ignores the gradient and returns `setup(rule, model)`.
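
Here is a minimal sketch of how these fit together. It assumes Enzyme.jl (which defines `Duplicated`) is loaded; the model, the gradient values, and the `Descent(0.1)` rule are made up for illustration, and the gradient is written by hand rather than computed by Enzyme:

```julia
using Optimisers, Enzyme  # `Duplicated` comes from Enzyme (EnzymeCore)

model = (a = [1.0, 2.0], b = [3.0])
grad  = (a = [0.1, 0.1], b = [0.5])   # pretend Enzyme wrote this, already accumulated
dup   = Duplicated(model, grad)

state = Optimisers.setup(Descent(0.1), dup)  # same as setup(Descent(0.1), model)
Optimisers.update!(state, dup)               # uses dup.dval to update dup.val in place

model.a  # ≈ [0.99, 1.99], i.e. a .- 0.1 .* grad.a
```

Note that this method of `update!` returns `nothing`, mutating both the optimiser state and the model in place.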
14 changes: 13 additions & 1 deletion test/runtests.jl
@@ -542,7 +542,19 @@ end
@test nothing === Optimisers.update!(st, x_dx) # mutates both arguments
@test x_dx.val ≈ Float16[0.8887, 2.0, 3.445]

@test_throws ArgumentError Optimisers.setup(Adam(), (; a=[1,2,3.], b=x_dx))
shared = [1.0]
model = (x=shared, y=shared)
grad = deepcopy(model) # Enzyme produces something like this, grad.x === grad.y, already accumulated.
dup = Duplicated(model, grad)
st2 = Optimisers.setup(Descent(0.1), model)
Optimisers.update!(st2, dup)
@test model.x ≈ [0.9]
shared .= 1
Optimisers.update!(st2, model, grad)
model.x ≈ [0.8] # This is wrong (the shared gradient gets added twice: 1.0 - 0.1*(1.0 + 1.0) = 0.8), but don't make it a test.
# Ideally, perhaps the 3-arg update! could notice that grad.x===grad.y, and not accumulate the gradient in this case?

@test_throws ArgumentError Optimisers.setup(Adam(), (; a=[1,2,3.], b=x_dx)) # Duplicated deep inside is not allowed
end
end
@testset verbose=true "Destructure" begin
