From 943965d523d58785d980b1001a36fe4ba79fb902 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Thu, 30 May 2024 10:41:14 +1200 Subject: [PATCH] update docstrings for `train` and `train_epoch` --- src/core.jl | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/core.jl b/src/core.jl index 66202e04..bd49933b 100644 --- a/src/core.jl +++ b/src/core.jl @@ -16,7 +16,7 @@ end """ train_epoch( - model::MLJFlux.MLJFluxModel, + model, chain, optimiser, optimiser_state, @@ -24,10 +24,11 @@ end y, ) -> updated_chain, updated_optimiser_state, training_loss -A private method that can be overloaded for custom models. - Update the parameters of a Flux `chain`, where: +- `model` is typically an `MLJFluxModel` instance, but could be any object such that + `model.loss` is a Flux.jl loss function. + - the loss function `(yhat, y) -> loss(yhat, y)` is inferred from the `model` @@ -36,7 +37,7 @@ Update the parameters of a Flux `chain`, where: """ function train_epoch( - model::MLJFlux.MLJFluxModel, + model, chain, optimiser, optimiser_state, @@ -66,7 +67,7 @@ end """ train( - model::MLJFlux.MLJFluxModel, + model, chain, optimiser, optimiser_state, @@ -76,10 +77,9 @@ end y, ) -> (updated_chain, updated_optimiser_state, history) -A private method that can be overloaded for custom models. - Optimize a Flux model `chain`, where `(yhat, y) -> loss(yhat, y)` is the loss function -inferred from the `model`. +inferred from the `model`. Typically, `model` will be an `MLJFluxModel` instance, but it +could be any object such that `model.loss` is a Flux.jl loss function. Here `chain` is a `Flux.Chain` object, or other Flux model such that `Flux.params(chain)` returns the parameters to be optimized. @@ -111,7 +111,7 @@ initial (no-train) loss. """ function train( - model::MLJFlux.MLJFluxModel, + model, chain, optimiser, optimiser_state,