Skip to content

Commit

Permalink
narrower non_differentiable params
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Nov 21, 2022
1 parent 8d948e8 commit 9cc8e25
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ ProgressLogging = "0.1"
Reexport = "0.2, 1.0"
SpecialFunctions = "1.8.2, 2.1.2"
StatsBase = "0.33"
Zygote = "0.6.34"
Zygote = "0.6.49"
julia = "1.6"

[extras]
Expand Down
5 changes: 4 additions & 1 deletion src/functor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,10 @@ function params(m...)
end

# Allows caching of the parameters when params is called within gradient() to fix #2040.
@non_differentiable params(m...)
# @non_differentiable params(m...) # https://github.com/FluxML/Flux.jl/pull/2054
# That speeds up implicit use, and silently breaks explicit use.
# From @macroexpand Zygote.@nograd params(m...) and https://github.com/FluxML/Zygote.jl/pull/1248
Zygote._pullback(::Zygote.Context{true}, ::typeof(params), m...) = params(m), _ -> nothing

struct FluxCUDAAdaptor end
adapt_storage(to::FluxCUDAAdaptor, x) = CUDA.cu(x)
Expand Down
14 changes: 14 additions & 0 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,20 @@ end
@test size.(Flux.params(m)) == [(2,), (1, 2)]
end

@testset "params gradient" begin
m = (x=[1,2.0], y=[3.0]);

# Explicit -- was broken by #2054
gnew = gradient(m -> (sum(norm, Flux.params(m))), m)[1]
@test gnew.x [0.4472135954999579, 0.8944271909999159]
@test gnew.y [1.0]

# Implicit
gold = gradient(() -> (sum(norm, Flux.params(m))), Flux.params(m))
@test gold[m.x] [0.4472135954999579, 0.8944271909999159]
@test gold[m.y] [1.0]
end

@testset "Precision" begin
m = Chain(Dense(10, 5, relu), Dense(5, 2))
x64 = rand(Float64, 10)
Expand Down

0 comments on commit 9cc8e25

Please sign in to comment.