Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Nov 8, 2024
1 parent fac710f commit 3cf0579
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 4 deletions.
10 changes: 7 additions & 3 deletions ext/OptimisersEnzymeCoreExt.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
module OptimisersEnzymeCoreExt

import Optimisers: trainable, setup, update!, isnumeric, AbstractRule
import Optimisers: trainable, setup, update!, isnumeric, AbstractRule, _setup
import EnzymeCore: Duplicated, Const

using Functors: fmapstructure

println("loaded!")

trainable(x::Duplicated) = (; val = x.val)
trainable(x::Const) = (;)

"""
setup(rule::AbstractRule, model_grad::Duplicated)
Expand All @@ -16,6 +15,11 @@ For use with Enzyme's Duplicated, this just calls `setup(rule, model_grad.val)`.
"""
setup(rule::AbstractRule, model_grad::Duplicated) = setup(rule, model_grad.val)

_setup(rule, x::Duplicated; cache) = throw(ArgumentError(
"""Objects of type `Duplicated` are only supported by Optimisers.jl at top level,
they may not appear deep inside other objects."""
))

"""
update!(opt_state, model_grad::Duplicated)
Expand Down
12 changes: 11 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using Optimisers
using ChainRulesCore, Functors, StaticArrays, Zygote
using ChainRulesCore, Functors, StaticArrays, Zygote, EnzymeCore
using LinearAlgebra, Statistics, Test, Random
using Optimisers: @.., @lazy
using Base.Broadcast: broadcasted, instantiate, Broadcasted
Expand Down Expand Up @@ -534,6 +534,16 @@ end
@test Optimisers._norm(bc2, p) isa Float64
end
end

@testset "Enzyme Duplicated" begin
x_dx = Duplicated(Float16[1,2,3], Float16[1,0,-4])
st = Optimisers.setup(Momentum(1/9), x_dx) # acts only on x not on dx
@test st isa Optimisers.Leaf
@test nothing === Optimisers.update!(st, x_dx) # mutates both arguments
@test x_dx.val Float16[0.8887, 2.0, 3.445]

@test_throws ArgumentError Optimisers.setup(Adam(), (; a=[1,2,3.], b=x_dx))
end
end
@testset verbose=true "Destructure" begin
include("destructure.jl")
Expand Down

0 comments on commit 3cf0579

Please sign in to comment.