Skip to content

Commit

Permalink
fix method ambiguity in `update!` by adding a specialized method for `(opt::Optimisers.AbstractRule, model::Chain, grad::Tuple)`
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Oct 14, 2024
1 parent 25eeee0 commit f91a4dd
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 5 deletions.
6 changes: 2 additions & 4 deletions src/Flux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,8 @@ export Chain, Dense, Embedding, EmbeddingBag,
))

include("optimise/Optimise.jl")
using .Optimise: Optimise, update!
export Optimise, ClipValue, Optimiser

export ClipGrad, OptimiserChain # these are const defined in deprecations, for ClipValue, Optimiser
using .Optimise: Optimise
export ClipValue # this is const defined in deprecations, for ClipGrad

include("train.jl")
using .Train
Expand Down
5 changes: 5 additions & 0 deletions src/deprecations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,11 @@ function update!(opt::Optimisers.AbstractRule, model, grad)
`update!(state, model, grad)` needs `state = Flux.setup(opt, model)`.
""")
end
# More specific method matching the one above (opt::AbstractRule, model, grad):
# added to resolve a method ambiguity for the Chain/Tuple argument combination
# (per the commit title "fix ambiguity"), while raising the same informative
# error telling the user to call `Flux.setup` first.
function update!(opt::Optimisers.AbstractRule, model::Chain, grad::Tuple)
    # Deliberately always errors: passing a bare optimiser rule to `update!`
    # is invalid in the explicit-state API; state must come from `Flux.setup`.
    error("""Invalid input to `update!`.
        `update!(state, model, grad)` needs `state = Flux.setup(opt, model)`.
        """)
end

# An easy error to make is to pass result of explicit gradient(...), not gradient(...)[1]
# Can't catch every case, but can catch many simple Flux models:
Expand Down
2 changes: 1 addition & 1 deletion test/train.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ end

@testset "Explicit Flux.update! features" begin
m = Chain(Dense(2=>3, tanh), Dense(3=>1), only)
x = rand(2)
x = rand(Float32, 2)
y1 = m(x) # before

# Explicit gradient
Expand Down

0 comments on commit f91a4dd

Please sign in to comment.