Skip to content

Commit

Permalink
unbreak l2 test
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Oct 17, 2024
1 parent aeb421b commit 4c109ac
Showing 1 changed file with 8 additions and 12 deletions.
20 changes: 8 additions & 12 deletions test/train.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,18 +155,14 @@ for (trainfn!, name) in ((Flux.train!, "Zygote"), (train_enzyme!, "Enzyme"))
pen2(x::AbstractArray) = sum(abs2, x)/2
opt = Flux.setup(Adam(0.1), model)

@test begin
trainfn!(model, data, opt) do m, x, y
err = Flux.mse(m(x), y)
l2 = sum(pen2, Flux.params(m))
err + 0.33 * l2
end

diff2 = model.weight .- init_weight
@test diff1 diff2

true
end broken = VERSION >= v"1.11"
trainfn!(model, data, opt) do m, x, y
err = Flux.mse(m(x), y)
l2 = sum(pen2, Flux.params(m))
err + 0.33 * l2
end

diff2 = model.weight .- init_weight
@test diff1 diff2
end

# Take 3: using WeightDecay instead. Need the /2 above, to match exactly.
Expand Down

0 comments on commit 4c109ac

Please sign in to comment.