Merge pull request #251 from FluxML/refactor-regularization
Omnibus PR, including switch to explicit style differentiation
ablaom authored Jun 10, 2024
2 parents 01ad08e + 1bd58dd commit f38e0cf
Showing 21 changed files with 554 additions and 433 deletions.
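For readers unfamiliar with the phrase in the PR description: "explicit style differentiation" refers to taking gradients with respect to the model object itself (via `Flux.withgradient`) and applying updates with Optimisers.jl rules, rather than through the implicit `Flux.params`/`Flux.Optimise` API. Below is a minimal sketch of the new pattern, adapted from the `src/core.jl` changes further down; it is illustrative only and not itself part of the diff:

```julia
using Flux, Optimisers

chain = Flux.Chain(Flux.Dense(4 => 2))                  # toy model
optimiser_state = Optimisers.setup(Optimisers.Adam(), chain)
x, y = rand(Float32, 4, 8), rand(Float32, 2, 8)         # one batch

# explicit style: differentiate with respect to the model, not an implicit param set
batch_loss, gs = Flux.withgradient(chain) do m
    Flux.mse(m(x), y)
end
∇ = first(gs)                                           # unwrap the length-one gradient tuple
optimiser_state, chain = Optimisers.update(optimiser_state, chain, ∇)
```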
4 changes: 3 additions & 1 deletion Project.toml
@@ -10,6 +10,7 @@ ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
@@ -22,6 +23,7 @@ ComputationalResources = "0.3.2"
Flux = "0.14"
MLJModelInterface = "1.1.1"
Metalhead = "0.9.3"
Optimisers = "0.3.2"
ProgressMeter = "1.7.1"
StatisticalMeasures = "0.1"
Statistics = "<0.0.1, 1"
@@ -30,14 +32,14 @@ julia = "1.9"

[extras]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[targets]
test = ["CUDA", "cuDNN", "LinearAlgebra", "MLJBase", "Random", "StableRNGs", "StatisticalMeasures", "StatsBase", "Test"]
9 changes: 0 additions & 9 deletions README.md
@@ -246,15 +246,6 @@ to builders for the purposes of weight initialization. This can be
any `AbstractRNG` or the seed (integer) for a `MersenneTwister` that
will be reset on every cold restart of model (machine) training.
Until there is a [mechanism for
doing so](https://github.com/FluxML/Flux.jl/issues/1617) `rng` is *not*
passed to dropout layers and one must manually seed the `GLOBAL_RNG`
for reproducibility purposes, when using a builder that includes
`Dropout` (such as `MLJFlux.Short`). If training models on a
GPU (i.e., `acceleration isa CUDALibs`) one must additionally call
`CUDA.seed!(...)`.
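As an illustrative usage sketch (not part of the README diff, and assuming the standard MLJ model-loading workflow), an integer `rng` seed can be passed to an MLJFlux model for reproducible weight initialization:

```julia
using MLJ, MLJFlux

NeuralNetworkClassifier = @load NeuralNetworkClassifier pkg=MLJFlux
clf = NeuralNetworkClassifier(builder=MLJFlux.Short(n_hidden=8), rng=123)
```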
### Built-in builders
The following builders are provided out-of-the-box. Query their
5 changes: 3 additions & 2 deletions src/MLJFlux.jl
@@ -1,4 +1,4 @@
module MLJFlux
module MLJFlux

export CUDALibs, CPU1

@@ -14,11 +14,11 @@ using ColorTypes
using ComputationalResources
using Random
import Metalhead
import Optimisers

include("utilities.jl")
const MMI=MLJModelInterface

include("penalizers.jl")
include("builders.jl")
include("metalhead.jl")
include("types.jl")
@@ -32,6 +32,7 @@ export NeuralNetworkRegressor, MultitargetNeuralNetworkRegressor
export NeuralNetworkClassifier, ImageClassifier
export CUDALibs, CPU1

include("deprecated.jl")


end #module
39 changes: 18 additions & 21 deletions src/builders.jl
@@ -14,13 +14,12 @@
abstract type Builder <: MLJModelInterface.MLJType end

"""
Linear(; σ=Flux.relu, rng=Random.GLOBAL_RNG)
Linear(; σ=Flux.relu)
MLJFlux builder that constructs a fully connected two layer network
with activation function `σ`. The number of input and output nodes is
determined from the data. The bias and coefficients are initialized
using `Flux.glorot_uniform(rng)`. If `rng` is an integer, it is
instead used as the seed for a `MersenneTwister`.
MLJFlux builder that constructs a fully connected two layer network with activation
function `σ`. The number of input and output nodes is determined from the data. Weights
are initialized using `Flux.glorot_uniform(rng)`, where `rng` is inferred from the `rng`
field of the MLJFlux model.
"""
mutable struct Linear <: Builder
@@ -31,7 +30,7 @@ build(builder::Linear, rng, n::Integer, m::Integer) =
Flux.Chain(Flux.Dense(n, m, builder.σ, init=Flux.glorot_uniform(rng)))
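A brief, hedged illustration (not part of the diff) of calling `build` on a `Linear` builder directly, assuming 4 input and 2 output nodes:

```julia
using Flux, Random
import MLJFlux

builder = MLJFlux.Linear(σ=Flux.relu)
chain = MLJFlux.build(builder, Random.MersenneTwister(123), 4, 2)
# chain is a Flux.Chain wrapping a single Dense(4 => 2, relu) layer
```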

"""
Short(; n_hidden=0, dropout=0.5, σ=Flux.sigmoid, rng=GLOBAL_RNG)
Short(; n_hidden=0, dropout=0.5, σ=Flux.sigmoid)
MLJFlux builder that constructs a fully connected three-layer network
using `n_hidden` nodes in the hidden layer and the specified `dropout`
@@ -40,9 +39,8 @@ hidden and final layers. If `n_hidden=0` (the default) then `n_hidden`
is the geometric mean of the number of input and output nodes. The
number of input and output nodes is determined from the data.
The each layer is initialized using `Flux.glorot_uniform(rng)`. If
`rng` is an integer, it is instead used as the seed for a
`MersenneTwister`.
Each layer is initialized using `Flux.glorot_uniform(rng)`, where `rng` is inferred from
the `rng` field of the MLJFlux model.
"""
mutable struct Short <: Builder
@@ -57,22 +55,19 @@ function build(builder::Short, rng, n, m)
init=Flux.glorot_uniform(rng)
Flux.Chain(
Flux.Dense(n, n_hidden, builder.σ, init=init),
# TODO: fix next after https://github.com/FluxML/Flux.jl/issues/1617
Flux.Dropout(builder.dropout),
Flux.Dropout(builder.dropout; rng),
Flux.Dense(n_hidden, m, init=init))
end

"""
MLP(; hidden=(100,), σ=Flux.relu, rng=GLOBAL_RNG)
MLP(; hidden=(100,), σ=Flux.relu)
MLJFlux builder that constructs a Multi-layer perceptron network. The
ith element of `hidden` represents the number of neurons in the ith
hidden layer. An activation function `σ` is applied between each
layer.
MLJFlux builder that constructs a Multi-layer perceptron network. The ith element of
`hidden` represents the number of neurons in the ith hidden layer. An activation function
`σ` is applied between each layer.
The each layer is initialized using `Flux.glorot_uniform(rng)`. If
`rng` is an integer, it is instead used as the seed for a
`MersenneTwister`.
Each layer is initialized using `Flux.glorot_uniform(rng)`, where `rng` is inferred from
the `rng` field of the MLJFlux model.
"""
mutable struct MLP{N} <: MLJFlux.Builder
@@ -110,6 +105,7 @@ Creates a builder for `neural_net`. The variables `rng`, `n_in`, `n_out` and
input and output sizes `n_in` and `n_out` and number of input channels `n_channels`.
# Examples
```jldoctest
julia> import MLJFlux: @builder;
@@ -132,4 +128,5 @@ macro builder(ex)
end)
end

build(b::GenericBuilder, rng, n_in, n_out, n_channels = 1) = b.apply(rng, n_in, n_out, n_channels)
build(b::GenericBuilder, rng, n_in, n_out, n_channels = 1) =
b.apply(rng, n_in, n_out, n_channels)
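Since the `jldoctest` above is truncated in this view, here is a hedged sketch of typical `@builder` usage; the variables `n_in`, `n_out`, and `rng` are those injected by MLJFlux at build time, and the layer sizes are arbitrary:

```julia
using Flux
import MLJFlux

builder = MLJFlux.@builder Flux.Chain(
    Flux.Dense(n_in => 64, Flux.relu, init=Flux.glorot_uniform(rng)),
    Flux.Dense(64 => n_out, init=Flux.glorot_uniform(rng)),
)
```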
110 changes: 75 additions & 35 deletions src/core.jl
@@ -2,7 +2,7 @@

# make the optimiser structs "transparent" so that their field values
# are exposed by calls to MLJ.params:
MLJModelInterface.istransparent(m::Flux.Optimise.AbstractOptimiser) = true
MLJModelInterface.istransparent(m::Optimisers.AbstractRule) = true


## GENERAL METHOD TO OPTIMIZE A CHAIN
@@ -15,47 +15,71 @@ end
(::Mover{<:CUDALibs})(data) = Flux.gpu(data)

"""
train!(model::MLJFlux.MLJFluxModel, penalty, chain, optimiser, X, y)
A private method that can be overloaded for custom models.
train_epoch(
model,
chain,
optimiser,
optimiser_state,
X,
y,
) -> updated_chain, updated_optimiser_state, training_loss
Update the parameters of a Flux `chain`, where:
- `model` is typically an `MLJFluxModel` instance, but could be any object such that
`model.loss` is a Flux.jl loss function.
- the loss function `(yhat, y) -> loss(yhat, y)` is inferred from the
`model`
- `params -> penalty(params)` is a regularization penalty function
- `X` and `y` are vectors of batches of the training data, as detailed
in the [`MLJFlux.fit!`](@ref) document string.
in the [`MLJFlux.train`](@ref) document string.
"""
function train!(model::MLJFlux.MLJFluxModel, penalty, chain, optimiser, X, y)
function train_epoch(
model,
chain,
optimiser,
optimiser_state,
X,
y,
)

loss = model.loss
n_batches = length(y)
training_loss = zero(Float32)

for i in 1:n_batches
parameters = Flux.params(chain)
gs = Flux.gradient(parameters) do
yhat = chain(X[i])
batch_loss = loss(yhat, y[i]) + penalty(parameters) / n_batches
training_loss += batch_loss
return batch_loss
batch_loss, gs = Flux.withgradient(chain) do m
yhat = m(X[i])
loss(yhat, y[i])
end
Flux.update!(optimiser, parameters, gs)
training_loss += batch_loss
# The `do` syntax above means `gs` is a tuple of length one we need to unwrap to
# get the actual gradient:
∇ = first(gs)
optimiser_state, chain = Optimisers.update(optimiser_state, chain, ∇)
end
return training_loss / n_batches

return chain, optimiser_state, training_loss / n_batches
end
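A hedged, self-contained sketch (not part of the diff) of driving a single epoch with the new signature; note the optimiser state created with `Optimisers.setup`, and that `model` here is just a stand-in object with a `loss` field, as the docstring allows:

```julia
using Flux, Optimisers
import MLJFlux

chain = Flux.Chain(Flux.Dense(4 => 2))
optimiser = Optimisers.Adam()
optimiser_state = Optimisers.setup(optimiser, chain)
X = [rand(Float32, 4, 8) for _ in 1:3]   # three input batches
y = [rand(Float32, 2, 8) for _ in 1:3]   # three target batches
model = (; loss=Flux.mse)                # any object with a `loss` field

chain, optimiser_state, mean_loss =
    MLJFlux.train_epoch(model, chain, optimiser, optimiser_state, X, y)
```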


"""
fit!(model::MLJFlux.MLJFluxModel, penalty, chain, optimiser, epochs, verbosity, X, y)
A private method that can be overloaded for custom models.
Optimize a Flux model `chain`, where `(yhat, y) -> loss(yhat, y)` is
the loss function inferred from the `model`, and `parameters -> penalty(parameters)` is the
regularization penalty function.
train(
model,
chain,
optimiser,
optimiser_state,
epochs,
verbosity,
X,
y,
) -> (updated_chain, updated_optimiser_state, history)
Optimize a Flux model `chain`, where `(yhat, y) -> loss(yhat, y)` is the loss function
inferred from the `model`. Typically, `model` will be an `MLJFluxModel` instance, but it
could be any object such that `model.loss` is a Flux.jl loss function.
Here `chain` is a `Flux.Chain` object, or other Flux model such that
`Flux.params(chain)` returns the parameters to be optimized.
@@ -76,17 +100,26 @@ batches. Specifically, it is expected that:
total number of training batches.
Both the `chain` and the data `(X, y)` must both live on a CPU or both
live on a GPU. This `fit!` method takes no responsibility for data
live on a GPU. This `train` method takes no responsibility for data
movement.
### Return value
# Return value
`(chain_trained, history)`, where `chain_trained` is a trained version
of `chain` and `history` is a vector of penalized losses - one initial
loss, and one loss per epoch.
Returns `(updated_chain, updated_optimiser_state, history)`, where `updated_chain` is a
trained version of `chain` and `history` is a vector of losses, including the
initial (no-train) loss.
"""
function fit!(model::MLJFlux.MLJFluxModel, penalty, chain, optimiser, epochs, verbosity, X, y)
function train(
model,
chain,
optimiser,
optimiser_state,
epochs,
verbosity,
X,
y,
)

loss = model.loss

@@ -98,20 +131,25 @@ function fit!(model::MLJFlux.MLJFluxModel, penalty, chain, optimiser, epochs, ve
# initiate history:
n_batches = length(y)

parameters = Flux.params(chain)
losses = (loss(chain(X[i]), y[i]) +
penalty(parameters) / n_batches for i in 1:n_batches)
losses = (loss(chain(X[i]), y[i]) for i in 1:n_batches)
history = [mean(losses),]

for i in 1:epochs
current_loss = train!(model::MLJFlux.MLJFluxModel, penalty, chain, optimiser, X, y)
chain, optimiser_state, current_loss = train_epoch(
model,
chain,
optimiser,
optimiser_state,
X,
y,
)
verbosity < 2 ||
@info "Loss is $(round(current_loss; sigdigits=4))"
verbosity != 1 || next!(meter)
push!(history, current_loss)
end

return chain, history
return chain, optimiser_state, history

end
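And a similar hedged sketch for the full `train` loop (again, not part of the diff); per the docstring, the returned `history` should contain one entry per epoch plus the initial no-train loss:

```julia
using Flux, Optimisers
import MLJFlux

chain = Flux.Chain(Flux.Dense(4 => 2))
optimiser = Optimisers.Adam()
optimiser_state = Optimisers.setup(optimiser, chain)
X = [rand(Float32, 4, 8) for _ in 1:3]
y = [rand(Float32, 2, 8) for _ in 1:3]
model = (; loss=Flux.mse)

chain, optimiser_state, history =
    MLJFlux.train(model, chain, optimiser, optimiser_state, 5, 0, X, y)
length(history)  # 6: one initial loss plus one per epoch
```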

@@ -221,7 +259,9 @@ _get(X::AbstractArray{<:Any,4}, b) = X[:, :, :, b]
"""
collate(model, X, y)
Return the Flux-friendly data object required by `MLJFlux.fit!`, given
**Private method**
Return the Flux-friendly data object required by `MLJFlux.train`, given
input `X` and target `y` in the form required by
`MLJModelInterface.input_scitype(X)` and
`MLJModelInterface.target_scitype(y)`. (The batch size used is given
19 changes: 19 additions & 0 deletions src/deprecated.jl
@@ -0,0 +1,19 @@
Base.@deprecate fit!(model::MLJFlux.MLJFluxModel, penalty, chain, optimiser, epochs, verbosity, X, y) train(
model::MLJFlux.MLJFluxModel,
chain,
optimiser,
optimiser_state,
epochs,
verbosity,
X,
y,
) false

Base.@deprecate train!(model::MLJFlux.MLJFluxModel, penalty, chain, optimiser, X, y) train_epoch(
model::MLJFlux.MLJFluxModel,
chain,
optimiser,
optimiser_state,
X,
y,
) false
2 changes: 1 addition & 1 deletion src/metalhead.jl
@@ -130,7 +130,7 @@ function VGGHack(
depth in keys(Metalhead.VGG_CONFIGS),
"depth must be from one in $(sort(collect(keys(Metalhead.VGG_CONFIGS))))"
)
model = Metalhead.VGG(imsize;
model = Metalhead.vgg(imsize;
config = Metalhead.VGG_CONFIGS[depth],
inchannels,
batchnorm,
