Commit
fix: remove unnecessary patches for now
avik-pal committed Dec 20, 2024
1 parent 431f3d7 commit d0fb186
Showing 4 changed files with 13 additions and 323 deletions.
29 changes: 0 additions & 29 deletions examples/ConditionalVAE/Project.toml

This file was deleted.

291 changes: 0 additions & 291 deletions examples/ConditionalVAE/main.jl

This file was deleted.

5 changes: 2 additions & 3 deletions examples/ConvMixer/main.jl
@@ -73,10 +73,10 @@ function accuracy(model, ps, st, dataloader)
     return total_correct / total
 end
 
-Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::Int=8,
+Comonicon.@main function main(; batchsize::Int=64, hidden_dim::Int=256, depth::Int=8,
         patch_size::Int=2, kernel_size::Int=5, weight_decay::Float64=1e-5,
         clip_norm::Bool=false, seed::Int=42, epochs::Int=25, lr_max::Float64=0.01,
-        backend::String="gpu_if_available")
+        backend::String="reactant")
     rng = StableRNG(seed)
 
     if backend == "gpu_if_available"
@@ -130,7 +130,6 @@ Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::
         (_, _, _, train_state) = Training.single_train_step!(
             adtype, loss, (x, y), train_state
         )
-        @show i, time() - stime
     end
     ttime = time() - stime
 
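For context, Comonicon.@main also exposes the keyword arguments above as command-line flags, so this change flips the example's defaults to a batch size of 64 and the Reactant backend. A hypothetical invocation restoring the previous behavior (the flag names assume Comonicon's usual one-to-one mapping of keyword arguments, and the project path is illustrative):

julia --project=examples/ConvMixer examples/ConvMixer/main.jl --batchsize 512 --backend gpu_if_available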
11 changes: 11 additions & 0 deletions ext/LuxReactantExt/training.jl
@@ -36,14 +36,20 @@ for inplace in ("!", "")
 
     @eval function Lux.Training.$(fname)(backend::ReactantBackend, objective_function::F,
             data, ts::Training.TrainState) where {F}
+        @show 1213
+
         compiled_grad_and_step_function = @compile $(internal_fn)(
             objective_function, ts.model, data, ts.parameters, ts.states,
             ts.optimizer_state)
 
+        @show Lux.Functors.fmap(typeof, ts.states)
+
         grads, ps, loss, stats, st, opt_state = compiled_grad_and_step_function(
             objective_function, ts.model, data, ts.parameters, ts.states,
             ts.optimizer_state)
 
+        @show Lux.Functors.fmap(typeof, st)
+
         cache = TrainingBackendCache(
             backend, False(), nothing, (; compiled_grad_and_step_function))
         @set! ts.cache = cache
@@ -53,11 +59,16 @@
         @set! ts.optimizer_state = opt_state
         @set! ts.step = ts.step + 1
 
+        @show Lux.Functors.fmap(typeof, ts.states)
+
         return grads, loss, stats, ts
     end
 
     @eval function Lux.Training.$(fname)(::ReactantBackend, obj_fn::F, data,
             ts::Training.TrainState{<:TrainingBackendCache{ReactantBackend}, F}) where {F}
+        @show Lux.Functors.fmap(typeof, ts.parameters)
+        @show Lux.Functors.fmap(typeof, ts.states)
+
         grads, ps, loss, stats, st, opt_state = ts.cache.extras.compiled_grad_and_step_function(
             obj_fn, ts.model, data, ts.parameters, ts.states, ts.optimizer_state)
 
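The @show lines added above are debug probes: Lux.Functors.fmap(typeof, x) maps typeof over every leaf of a nested parameter/state structure, which is a quick way to check whether each array is still a plain Array or has already been converted to Reactant's on-device array type before and after the compiled step. A minimal standalone sketch of the idea, assuming only the Functors package (the NamedTuple is made up for illustration):

using Functors

ps = (dense=(weight=rand(Float32, 4, 4), bias=zeros(Float32, 4)), rng_state=nothing)
@show fmap(typeof, ps)
# Prints a NamedTuple of the same shape with every leaf replaced by its type,
# e.g. (dense = (weight = Matrix{Float32}, bias = Vector{Float32}), rng_state = Nothing)

Note also the compile-once design these two methods implement: the first method runs @compile on the initial train step and stashes the compiled function in a TrainingBackendCache, while the second method, which dispatches on a TrainState carrying that cache, reuses the compiled function on every subsequent step.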
