
Commit

indent by two spaces
mcabbott committed Nov 8, 2024
1 parent 5ea9e76 commit ab89ed4
Showing 2 changed files with 21 additions and 21 deletions.
4 changes: 2 additions & 2 deletions ext/OptimisersEnzymeCoreExt.jl
@@ -16,8 +16,8 @@ For use with Enzyme's Duplicated, this just calls `setup(rule, model_grad.val)`.
setup(rule::AbstractRule, model_grad::Duplicated) = setup(rule, model_grad.val)

_setup(rule, x::Duplicated; cache) = throw(ArgumentError(
"""Objects of type `Duplicated` are only supported by Optimisers.jl at top level,
they may not appear deep inside other objects."""
"""Objects of type `Duplicated` are only supported by Optimisers.jl at top level,
they may not appear deep inside other objects."""
))

"""
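For orientation, here is a minimal sketch of how this extension behaves from the user side, assuming Optimisers.jl and EnzymeCore.jl are installed; the array values are invented for illustration:

using Optimisers, EnzymeCore

x_dx = Duplicated([1.0, 2.0, 3.0], [0.1, 0.0, -0.4])  # a value and its gradient, as Enzyme would produce

st = Optimisers.setup(Descent(0.1), x_dx)   # allowed at top level: unwraps to setup(rule, x_dx.val)
Optimisers.update!(st, x_dx)                # uses x_dx.dval as the gradient and mutates x_dx.val

# A Duplicated nested inside another container hits the ArgumentError defined above:
# Optimisers.setup(Descent(0.1), (; a = [1.0], b = x_dx))  # would throw ArgumentError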
38 changes: 19 additions & 19 deletions test/runtests.jl
@@ -536,25 +536,25 @@ end
end

@testset "Enzyme Duplicated" begin
-x_dx = Duplicated(Float16[1,2,3], Float16[1,0,-4])
-st = Optimisers.setup(Momentum(1/9), x_dx) # acts only on x not on dx
-@test st isa Optimisers.Leaf
-@test nothing === Optimisers.update!(st, x_dx) # mutates both arguments
-@test x_dx.val ≈ Float16[0.8887, 2.0, 3.445]
-
-shared = [1.0]
-model = (x=shared, y=shared)
-grad = deepcopy(model) # Enzyme produces something like this, grad.x === grad.y, already accumulated.
-dup = Duplicated(model, model)
-st2 = Optimisers.setup(Descent(0.1), model)
-Optimisers.update!(st2, dup)
-@test model.x ≈ [0.9]
-shared .= 1
-Optimisers.update!(st2, model, grad)
-model.x ≈ [0.8] # This is wrong, but don't make it a test.
-# Ideally, perhaps the 3-arg update! could notice that grad.x===grad.y, and not accumulate the gradient in this case?
-
-@test_throws ArgumentError Optimisers.setup(Adam(), (; a=[1,2,3.], b=x_dx)) # Duplicated deep inside is not allowed
+  x_dx = Duplicated(Float16[1,2,3], Float16[1,0,-4])
+  st = Optimisers.setup(Momentum(1/9), x_dx) # acts only on x not on dx
+  @test st isa Optimisers.Leaf
+  @test nothing === Optimisers.update!(st, x_dx) # mutates both arguments
+  @test x_dx.val ≈ Float16[0.8887, 2.0, 3.445]
+
+  shared = [1.0]
+  model = (x=shared, y=shared)
+  grad = deepcopy(model) # Enzyme produces something like this, grad.x === grad.y, already accumulated.
+  dup = Duplicated(model, model)
+  st2 = Optimisers.setup(Descent(0.1), model)
+  Optimisers.update!(st2, dup)
+  @test model.x ≈ [0.9]
+  shared .= 1
+  Optimisers.update!(st2, model, grad)
+  model.x ≈ [0.8] # This is wrong, but don't make it a test.
+  # Ideally, perhaps the 3-arg update! could notice that grad.x===grad.y, and not accumulate the gradient in this case?
+
+  @test_throws ArgumentError Optimisers.setup(Adam(), (; a=[1,2,3.], b=x_dx)) # Duplicated deep inside is not allowed
end
end
@testset verbose=true "Destructure" begin
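The shared-parameter comment in the hunk above ("This is wrong, but don't make it a test") comes down to a small piece of arithmetic. A hedged reading of the test, not a statement of Optimisers internals: because `model.x === model.y`, the two fields share one optimiser state, and the 3-arg `update!` adds up the gradient it sees for each field, even though Enzyme has already accumulated everything into the single shared `grad` array:

# Hedged arithmetic only; the values are those used in the test above.
eta = 0.1             # Descent(0.1)
g   = 1.0             # grad.x === grad.y, already accumulated by Enzyme
1.0 - eta * (g + g)   # == 0.8, gradient counted once per tied field (the value the comment calls wrong)
1.0 - eta * g         # == 0.9, gradient used once (the value the comment would prefer)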
