diff --git a/docs/src/index.md b/docs/src/index.md
index 3cb32f8..a595d70 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -358,3 +358,12 @@ julia> Optimisers.update!(opt_state, x, g);
 julia> opt_state  # the state in `a` and `b` differ
 (a = Leaf(Adam(0.1, (0.9, 0.999), 1.0e-8), ([0.09, 0.09], [0.000999, 0.000999], (0.729, 0.997003))), b = Leaf(Adam(0.1, (0.9, 0.999), 1.0e-8), ([0.1, 0.1], [0.001, 0.001], (0.81, 0.998001))))
 ```
+
+## Usage with [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl)
+
+Enzyme.jl is a new automatic differentiation package, an alternative to Zygote.jl.
+It likes to store the model and the gradient together, as an object `Duplicated(x, dx)`.
+
+Optimisers.jl now has some methods to handle this:
+* `update!(opt_state, Duplicated(model, grad))` uses the gradient to update both the model and the optimiser state, and
+* `setup(::AbstractRule, ::Duplicated)` ignores the gradient and returns `setup(rule, model)`.
diff --git a/test/runtests.jl b/test/runtests.jl
index 509d9cb..e17553c 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -542,7 +542,19 @@ end
 
     @test nothing === Optimisers.update!(st, x_dx)  # mutates both arguments
     @test x_dx.val ≈ Float16[0.8887, 2.0, 3.445]
-    @test_throws ArgumentError Optimisers.setup(Adam(), (; a=[1,2,3.], b=x_dx))
+    shared = [1.0]
+    model = (x=shared, y=shared)
+    grad = deepcopy(model)  # Enzyme produces something like this, grad.x === grad.y, already accumulated.
+    dup = Duplicated(model, model)
+    st2 = Optimisers.setup(Descent(0.1), model)
+    Optimisers.update!(st2, dup)
+    @test model.x ≈ [0.9]
+    shared .= 1
+    Optimisers.update!(st2, model, grad)
+    model.x ≈ [0.8]  # This is wrong, but don't make it a test.
+    # Ideally, perhaps the 3-arg update! could notice that grad.x === grad.y, and not accumulate the gradient in this case?
+
+    @test_throws ArgumentError Optimisers.setup(Adam(), (; a=[1,2,3.], b=x_dx))  # Duplicated deep inside is not allowed
   end
 end
 @testset verbose=true "Destructure" begin
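
For reference, the workflow the new docs section describes looks roughly like the sketch below. It is not part of the patch: the model and gradient values are invented, and the gradient is written by hand where, in real use, Enzyme would accumulate it into `dup.dval` during reverse-mode differentiation.

```julia
using Optimisers, Enzyme

# A toy "model": a NamedTuple of parameter arrays (values invented for illustration).
model = (w = [1.0, 2.0], b = [0.5])

# Hand-written gradient of the same shape; in real use Enzyme would populate this.
grad = (w = [0.1, 0.1], b = [0.2])
dup = Duplicated(model, grad)   # Enzyme's model-plus-gradient wrapper

# setup on a Duplicated ignores the gradient half, as described above.
state = Optimisers.setup(Descent(0.1), dup)

# update! with a Duplicated mutates both the model and the optimiser state,
# and (like the array case tested above) returns nothing.
Optimisers.update!(state, dup)

dup.val.w  # ≈ [0.99, 1.99] after one Descent(0.1) step
```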