Skip to content

Commit

Permalink
fix: update reactant version
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 20, 2024
1 parent d0fb186 commit 438fede
Show file tree
Hide file tree
Showing 8 changed files with 10 additions and 19 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Lux"
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.4.2"
version = "1.4.3"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -110,7 +110,7 @@ NNlib = "0.9.24"
Optimisers = "0.4.1"
Preferences = "1.4.3"
Random = "1.10"
Reactant = "0.2.8"
Reactant = "0.2.11"
Reexport = "1.2.2"
ReverseDiff = "1.15"
SIMDTypes = "0.1"
Expand Down
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ Optimisers = "0.4.1"
Pkg = "1.10"
Printf = "1.10"
Random = "1.10"
Reactant = "0.2.8"
Reactant = "0.2.11"
StableRNGs = "1"
StaticArrays = "1"
WeightInitializers = "1"
Expand Down
2 changes: 1 addition & 1 deletion examples/ConvMixer/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ PreferenceTools = "0.1.2"
Printf = "1.10"
ProgressBars = "1.5.1"
Random = "1.10"
Reactant = "0.2.8"
Reactant = "0.2.11"
StableRNGs = "1.0.2"
Statistics = "1.10"
Zygote = "0.6.70"
2 changes: 1 addition & 1 deletion examples/ConvMixer/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ function accuracy(model, ps, st, dataloader)
return total_correct / total
end

Comonicon.@main function main(; batchsize::Int=64, hidden_dim::Int=256, depth::Int=8,
Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::Int=8,
patch_size::Int=2, kernel_size::Int=5, weight_decay::Float64=1e-5,
clip_norm::Bool=false, seed::Int=42, epochs::Int=25, lr_max::Float64=0.01,
backend::String="reactant")
Expand Down
2 changes: 1 addition & 1 deletion ext/LuxReactantExt/patches.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
Utils.vec(x::AnyTracedRArray) = Reactant.materialize_traced_array(vec(x))
Utils.vec(x::AnyTracedRArray) = Reactant.TracedUtils.materialize_traced_array(vec(x))

# XXX: Use PoolDims once EnzymeJAX supports stablehlo.reduce_window adjoint
Lux.calculate_pool_dims(g::Lux.GlobalPoolMode, ::TracedRArray) = g
13 changes: 2 additions & 11 deletions ext/LuxReactantExt/training.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,20 +36,14 @@ for inplace in ("!", "")

@eval function Lux.Training.$(fname)(backend::ReactantBackend, objective_function::F,
data, ts::Training.TrainState) where {F}
@show 1213

compiled_grad_and_step_function = @compile $(internal_fn)(
objective_function, ts.model, data, ts.parameters, ts.states,
ts.optimizer_state)

@show Lux.Functors.fmap(typeof, ts.states)

grads, ps, loss, stats, st, opt_state = compiled_grad_and_step_function(
objective_function, ts.model, data, ts.parameters, ts.states,
ts.optimizer_state)

@show Lux.Functors.fmap(typeof, st)

cache = TrainingBackendCache(
backend, False(), nothing, (; compiled_grad_and_step_function))
@set! ts.cache = cache
Expand All @@ -59,16 +53,13 @@ for inplace in ("!", "")
@set! ts.optimizer_state = opt_state
@set! ts.step = ts.step + 1

@show Lux.Functors.fmap(typeof, ts.states)

return grads, loss, stats, ts
end

# XXX: Should we add a check to ensure the inputs to this function is same as the one
# used in the compiled function? We can re-trigger the compilation with a warning
@eval function Lux.Training.$(fname)(::ReactantBackend, obj_fn::F, data,
ts::Training.TrainState{<:TrainingBackendCache{ReactantBackend}, F}) where {F}
@show Lux.Functors.fmap(typeof, ts.parameters)
@show Lux.Functors.fmap(typeof, ts.states)

grads, ps, loss, stats, st, opt_state = ts.cache.extras.compiled_grad_and_step_function(
obj_fn, ts.model, data, ts.parameters, ts.states, ts.optimizer_state)

Expand Down
2 changes: 1 addition & 1 deletion lib/LuxCore/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ EnzymeCore = "0.8.6"
Functors = "0.5"
MLDataDevices = "1.6"
Random = "1.10"
Reactant = "0.2.6"
Reactant = "0.2.11"
ReverseDiff = "1.15"
Setfield = "1"
Tracker = "0.2.36"
Expand Down
2 changes: 1 addition & 1 deletion lib/MLDataDevices/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ Metal = "1"
OneHotArrays = "0.2.5"
Preferences = "1.4"
Random = "1.10"
Reactant = "0.2.6"
Reactant = "0.2.11"
RecursiveArrayTools = "3.8"
ReverseDiff = "1.15"
SparseArrays = "1.10"
Expand Down

0 comments on commit 438fede

Please sign in to comment.