Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Oct 20, 2024
1 parent 0360155 commit 9c2f285
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 21 deletions.
4 changes: 2 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ using Functors: fmapstructure_with_path

## Uncomment below to change the default test settings
# ENV["FLUX_TEST_AMDGPU"] = "true"
ENV["FLUX_TEST_CUDA"] = "true"
# ENV["FLUX_TEST_CUDA"] = "true"
# ENV["FLUX_TEST_METAL"] = "true"
ENV["FLUX_TEST_CPU"] = "false"
# ENV["FLUX_TEST_CPU"] = "false"
# ENV["FLUX_TEST_DISTRIBUTED_MPI"] = "true"
# ENV["FLUX_TEST_DISTRIBUTED_NCCL"] = "true"
ENV["FLUX_TEST_ENZYME"] = "false" # We temporarily disable Enzyme tests since they are failing
Expand Down
21 changes: 9 additions & 12 deletions test/train.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,18 +155,15 @@ for (trainfn!, name) in ((Flux.train!, "Zygote"), (train_enzyme!, "Enzyme"))
pen2(x::AbstractArray) = sum(abs2, x)/2
opt = Flux.setup(Adam(0.1), model)

@test begin
trainfn!(model, data, opt) do m, x, y
err = Flux.mse(m(x), y)
l2 = sum(pen2, Flux.params(m))
err + 0.33 * l2
end

diff2 = model.weight .- init_weight
@test diff1 diff2

true
end broken = VERSION >= v"1.11"
trainfn!(model, data, opt) do m, x, y
err = Flux.mse(m(x), y)
l2 = sum(pen2, Flux.params(m))
err + 0.33 * l2
end

diff2 = model.weight .- init_weight
@test diff1 diff2

end

# Take 3: using WeightDecay instead. Need the /2 above, to match exactly.
Expand Down
12 changes: 5 additions & 7 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -273,13 +273,11 @@ end
@testset "params gradient" begin
m = (x=[1,2.0], y=[3.0]);

@test begin
# Explicit -- was broken by #2054 / then fixed / now broken again on julia v1.11
gnew = gradient(m -> (sum(norm, Flux.params(m))), m)[1]
@test gnew.x [0.4472135954999579, 0.8944271909999159]
@test gnew.y [1.0]
true
end broken = VERSION >= v"1.11"
# Explicit -- was broken by #2054 / then fixed / now broken again on julia v1.11
gnew = gradient(m -> (sum(norm, Flux.params(m))), m)[1]
@test gnew.x [0.4472135954999579, 0.8944271909999159]
@test gnew.y [1.0]


# Implicit
gold = gradient(() -> (sum(norm, Flux.params(m))), Flux.params(m))
Expand Down

0 comments on commit 9c2f285

Please sign in to comment.