From 41d5f16271adde8b9a59b6f0fc9b8581b3e05650 Mon Sep 17 00:00:00 2001
From: JingYu Ning
Date: Fri, 15 Jul 2022 22:15:12 +0800
Subject: [PATCH] Remove redundant code

---
 example/FlowOverCircle/src/FlowOverCircle.jl | 18 ------------------
 1 file changed, 18 deletions(-)

diff --git a/example/FlowOverCircle/src/FlowOverCircle.jl b/example/FlowOverCircle/src/FlowOverCircle.jl
index 920019af..915026ae 100644
--- a/example/FlowOverCircle/src/FlowOverCircle.jl
+++ b/example/FlowOverCircle/src/FlowOverCircle.jl
@@ -54,24 +54,6 @@ function get_dataloader(; ts::AbstractRange = LinRange(100, 11000, 10000),
     return loader_train, loader_test
 end
 
-function FluxTraining.step!(learner, phase::FluxTraining.TrainingPhase, batch)
-    xs, ys = batch
-    FluxTraining.runstep(learner, phase, (; xs = xs, ys = ys)) do handle, state
-        state.grads = FluxTraining._gradient(learner.optimizer, learner.model, learner.params) do model
-            state.ŷs = model(state.xs)
-            handle(FluxTraining.LossBegin())
-            state.loss = learner.lossfn(state.ŷs, state.ys)
-            handle(FluxTraining.BackwardBegin())
-
-            return state.loss
-        end
-
-        handle(FluxTraining.BackwardEnd())
-        learner.params, learner.model = FluxTraining._update!(
-            learner.optimizer, learner.params, learner.model, state.grads)
-    end
-end
-
 function train(; cuda = true, η₀ = 1.0f-3, λ = 1.0f-4, epochs = 50)
     if cuda && CUDA.has_cuda()
         device = gpu