Skip to content

Commit

Permalink
update docstrings for train and train_epoch
Browse files Browse the repository at this point in the history
  • Loading branch information
ablaom committed May 29, 2024
1 parent 4af84e5 commit 943965d
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions src/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,19 @@ end

"""
train_epoch(
model::MLJFlux.MLJFluxModel,
model,
chain,
optimiser,
optimiser_state,
X,
y,
) -> updated_chain, updated_optimiser_state, training_loss
A private method that can be overloaded for custom models.
Update the parameters of a Flux `chain`, where:
- `model` is typically an `MLJFluxModel` instance, but could be any object such that
`model.loss` is a Flux.jl loss function.
- the loss function `(yhat, y) -> loss(yhat, y)` is inferred from the
`model`
Expand All @@ -36,7 +37,7 @@ Update the parameters of a Flux `chain`, where:
"""
function train_epoch(
model::MLJFlux.MLJFluxModel,
model,
chain,
optimiser,
optimiser_state,
Expand Down Expand Up @@ -66,7 +67,7 @@ end

"""
train(
model::MLJFlux.MLJFluxModel,
model,
chain,
optimiser,
optimiser_state,
Expand All @@ -76,10 +77,9 @@ end
y,
) -> (updated_chain, updated_optimiser_state, history)
A private method that can be overloaded for custom models.
Optimize a Flux model `chain`, where `(yhat, y) -> loss(yhat, y)` is the loss function
inferred from the `model`.
inferred from the `model`. Typically, `model` will be an `MLJFluxModel` instance, but it
could be any object such that `model.loss` is a Flux.jl loss function.
Here `chain` is a `Flux.Chain` object, or other Flux model such that
`Flux.params(chain)` returns the parameters to be optimized.
Expand Down Expand Up @@ -111,7 +111,7 @@ initial (no-train) loss.
"""
function train(
model::MLJFlux.MLJFluxModel,
model,
chain,
optimiser,
optimiser_state,
Expand Down

0 comments on commit 943965d

Please sign in to comment.