Omnibus PR, including switch to explicit style differentiation #251
Changes from 17 commits
8ed2d15
8924a61
52c0078
1ed8b49
14615f9
bf8f461
bd5aa8b
c806dc5
99100b8
f60be3f
3c39a26
bf66132
14b1993
ffb20bd
b3b41ac
4af84e5
943965d
1bd58dd
@@ -2,7 +2,7 @@
 # make the optimiser structs "transparent" so that their field values
 # are exposed by calls to MLJ.params:
-MLJModelInterface.istransparent(m::Flux.Optimise.AbstractOptimiser) = true
+MLJModelInterface.istransparent(m::Optimisers.AbstractRule) = true


 ## GENERAL METHOD TO OPTIMIZE A CHAIN
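For context on the change above: the explicit-style optimisers from Optimisers.jl all subtype `Optimisers.AbstractRule`, which is what the new `istransparent` method dispatches on. A minimal sketch, in which the `rule` variable and the learning rate are purely illustrative:

```julia
using Optimisers
import MLJModelInterface

# Explicit-style optimisers from Optimisers.jl, such as Adam, are subtypes
# of Optimisers.AbstractRule:
rule = Optimisers.Adam(0.001)
@assert rule isa Optimisers.AbstractRule

# With the method from this PR in scope (e.g. after `using MLJFlux`), the rule
# is "transparent", so MLJ can report its fields as hyperparameters:
# MLJModelInterface.istransparent(rule)  # expected to return true
```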
@@ -15,47 +15,71 @@ end
 (::Mover{<:CUDALibs})(data) = Flux.gpu(data)

 """
-    train!(model::MLJFlux.MLJFluxModel, penalty, chain, optimiser, X, y)
-
-A private method that can be overloaded for custom models.
+    train_epoch(
+        model,
+        chain,
+        optimiser,
+        optimiser_state,
+        X,
+        y,
+    ) -> updated_chain, updated_optimiser_state, training_loss

 Update the parameters of a Flux `chain`, where:

+- `model` is typically an `MLJFluxModel` instance, but could be any object such that
+  `model.loss` is a Flux.jl loss function.
+
 - the loss function `(yhat, y) -> loss(yhat, y)` is inferred from the
   `model`

-- `params -> penalty(params)` is a regularization penalty function
-
 - `X` and `y` are vectors of batches of the training data, as detailed
-  in the [`MLJFlux.fit!`](@ref) document string.
+  in the [`MLJFlux.train`](@ref) document string.

 """
-function train!(model::MLJFlux.MLJFluxModel, penalty, chain, optimiser, X, y)
+function train_epoch(
+    model,
+    chain,
+    optimiser,
+    optimiser_state,
+    X,
+    y,
+)

     loss = model.loss
     n_batches = length(y)
     training_loss = zero(Float32)

     for i in 1:n_batches
-        parameters = Flux.params(chain)
-        gs = Flux.gradient(parameters) do
-            yhat = chain(X[i])
-            batch_loss = loss(yhat, y[i]) + penalty(parameters) / n_batches
-            training_loss += batch_loss
-            return batch_loss
+        batch_loss, gs = Flux.withgradient(chain) do m
+            yhat = m(X[i])
+            loss(yhat, y[i])
         end
-        Flux.update!(optimiser, parameters, gs)
+        training_loss += batch_loss
+        # The `do` syntax above means `gs` is a tuple of length one we need to unwrap to
+        # get the actual gradient:
+        ∇ = first(gs)
+        optimiser_state, chain = Optimisers.update(optimiser_state, chain, ∇)
     end
-    return training_loss / n_batches

+    return chain, optimiser_state, training_loss / n_batches
 end
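The heart of the new `train_epoch` is the explicit-style `withgradient`/`update` pattern visible in the diff above. The following self-contained sketch shows that pattern on a toy chain and a single batch; the layer sizes, learning rate, and random data are made up for illustration and are not part of the PR:

```julia
using Flux, Optimisers

# A toy model and one batch of data (3 features, 16 observations):
chain = Flux.Chain(Flux.Dense(3 => 8, Flux.relu), Flux.Dense(8 => 1))
x = rand(Float32, 3, 16)
y = rand(Float32, 1, 16)

# Explicit style: the optimiser state is set up once and threaded through:
optimiser = Optimisers.Adam(0.01)
state = Optimisers.setup(optimiser, chain)

# `withgradient` returns the loss and the gradient in one pass; because of the
# `do m ... end` form, the gradient comes back as a one-element tuple:
batch_loss, gs = Flux.withgradient(chain) do m
    Flux.Losses.mse(m(x), y)
end
∇ = first(gs)

# `Optimisers.update` is non-mutating: it returns a new state and a new model:
state, chain = Optimisers.update(state, chain, ∇)
```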
""" | ||
fit!(model::MLJFlux.MLJFluxModel, penalty, chain, optimiser, epochs, verbosity, X, y) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this, too, might be worth deprecating. If I understand this correctly, existing extensions that overload There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So just to double-check @ablaom, |
||
|
||
A private method that can be overloaded for custom models. | ||
|
||
-Optimize a Flux model `chain`, where `(yhat, y) -> loss(yhat, y)` is
-the loss function inferred from the `model`, and `parameters -> penalty(parameters)` is the
-regularization penalty function.
+    train(
+        model,
+        chain,
+        optimiser,
+        optimiser_state,
+        epochs,
+        verbosity,
+        X,
+        y,
+    ) -> (updated_chain, updated_optimiser_state, history)
+
+Optimize a Flux model `chain`, where `(yhat, y) -> loss(yhat, y)` is the loss function
+inferred from the `model`. Typically, `model` will be an `MLJFluxModel` instance, but it
+could be any object such that `model.loss` is a Flux.jl loss function.

 Here `chain` is a `Flux.Chain` object, or other Flux model such that
 `Flux.params(chain)` returns the parameters to be optimized.
@@ -76,17 +100,26 @@ batches. Specifically, it is expected that:
   total number of training batches.

 Both the `chain` and the data `(X, y)` must both live on a CPU or both
-live on a GPU. This `fit!` method takes no responsibility for data
+live on a GPU. This `train` method takes no responsibility for data
 movement.

-### Return value
+# Return value

-`(chain_trained, history)`, where `chain_trained` is a trained version
-of `chain` and `history` is a vector of penalized losses - one initial
-loss, and one loss per epoch.
+Returns `(updated_chain, updated_optimiser_state, history)`, where `updated_chain` is a
+trained version of `chain` and `history` is a vector of losses, including the
+initial (no-train) loss.

 """
-function fit!(model::MLJFlux.MLJFluxModel, penalty, chain, optimiser, epochs, verbosity, X, y)
+function train(
+    model,
+    chain,
+    optimiser,
+    optimiser_state,
+    epochs,
+    verbosity,
+    X,
+    y,
+)

     loss = model.loss
@@ -98,20 +131,25 @@ function fit!(model::MLJFlux.MLJFluxModel, penalty, chain, optimiser, epochs, ve
     # initiate history:
     n_batches = length(y)

-    parameters = Flux.params(chain)
-    losses = (loss(chain(X[i]), y[i]) +
-              penalty(parameters) / n_batches for i in 1:n_batches)
+    losses = (loss(chain(X[i]), y[i]) for i in 1:n_batches)
     history = [mean(losses),]

     for i in 1:epochs
-        current_loss = train!(model::MLJFlux.MLJFluxModel, penalty, chain, optimiser, X, y)
+        chain, optimiser_state, current_loss = train_epoch(
+            model,
+            chain,
+            optimiser,
+            optimiser_state,
+            X,
+            y,
+        )
         verbosity < 2 ||
             @info "Loss is $(round(current_loss; sigdigits=4))"
         verbosity != 1 || next!(meter)
         push!(history, current_loss)
     end

-    return chain, history
+    return chain, optimiser_state, history

 end
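For orientation, here is a hedged sketch of how a caller might drive the new `train` function. It assumes `model`, `chain`, `X`, and `y` already exist, with `X` and `y` being the vectors of batches described in the docstring, and it assumes the model's `optimiser` and `epochs` fields hold an Optimisers.jl rule and an epoch count; none of this driver code is part of the PR itself:

```julia
import Optimisers
import MLJFlux

# Set up explicit-style optimiser state for the model's rule (assumed to be an
# Optimisers.jl rule such as Optimisers.Adam()):
optimiser_state = Optimisers.setup(model.optimiser, chain)

# One call trains for `model.epochs` epochs and returns the updated chain, the
# updated optimiser state, and the loss history (initial loss plus one per epoch):
chain, optimiser_state, history = MLJFlux.train(
    model,
    chain,
    model.optimiser,
    optimiser_state,
    model.epochs,
    1,      # verbosity
    X,
    y,
)
```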
@@ -221,7 +259,9 @@ _get(X::AbstractArray{<:Any,4}, b) = X[:, :, :, b]
 """
     collate(model, X, y)

-Return the Flux-friendly data object required by `MLJFlux.fit!`, given
+**Private method**
+
+Return the Flux-friendly data object required by `MLJFlux.train`, given
 input `X` and target `y` in the form required by
 `MLJModelInterface.input_scitype(X)` and
 `MLJModelInterface.target_scitype(y)`. (The batch size used is given
[Review comment] Is it worth deprecating this?
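For reference while weighing that question, here is a hedged usage sketch of `collate` as it currently stands. The model, column table, and target vector are illustrative, and the return value is assumed to be the pair of batch vectors consumed by `train`:

```julia
using MLJFlux

# An illustrative model and some MLJ-style data (a column table plus a vector):
model = MLJFlux.NeuralNetworkRegressor(batch_size=32)
X = (x1 = rand(Float32, 100), x2 = rand(Float32, 100))
y = rand(Float32, 100)

# `collate` reformats and batches the data for `train`; `Xb` and `yb` are
# assumed to be vectors of batches, each with `batch_size` observations:
Xb, yb = MLJFlux.collate(model, X, y)
```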