Omnibus PR, including switch to explicit style differentiation #251

Merged 18 commits on Jun 10, 2024
4 changes: 3 additions & 1 deletion Project.toml
@@ -10,6 +10,7 @@ ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
@@ -22,6 +23,7 @@ ComputationalResources = "0.3.2"
Flux = "0.14"
MLJModelInterface = "1.1.1"
Metalhead = "0.9.3"
Optimisers = "0.3.2"
ProgressMeter = "1.7.1"
StatisticalMeasures = "0.1"
Statistics = "<0.0.1, 1"
@@ -30,14 +32,14 @@ julia = "1.9"

[extras]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[targets]
test = ["CUDA", "cuDNN", "LinearAlgebra", "MLJBase", "Random", "StableRNGs", "StatisticalMeasures", "StatsBase", "Test"]
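The new Optimisers entries above reflect this PR's switch from implicit-style `Flux.Optimise` optimisers to explicit Optimisers.jl rules. A minimal user-facing sketch, assuming the `optimiser` hyperparameter of MLJFlux models now accepts an `Optimisers.AbstractRule` (the learning rate and epoch count below are arbitrary):

```julia
using MLJFlux, Optimisers

# Assumption: `optimiser` now takes an Optimisers.jl rule such as Optimisers.Adam,
# rather than a Flux.Optimise optimiser.
model = NeuralNetworkRegressor(
    optimiser = Optimisers.Adam(0.001),
    epochs = 10,
)
```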
9 changes: 0 additions & 9 deletions README.md
@@ -250,15 +250,6 @@ to builders for the purposes of weight initialization. This can be
any `AbstractRNG` or the seed (integer) for a `MersenneTwister` that
will be reset on every cold restart of model (machine) training.

Until there is a [mechanism for
doing so](https://github.com/FluxML/Flux.jl/issues/1617) `rng` is *not*
passed to dropout layers and one must manually seed the `GLOBAL_RNG`
for reproducibility purposes, when using a builder that includes
`Dropout` (such as `MLJFlux.Short`). If training models on a
GPU (i.e., `acceleration isa CUDALibs`) one must additionally call
`CUDA.seed!(...)`.


### Built-in builders

The following builders are provided out-of-the-box. Query their
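The paragraph removed above warned that `rng` was not passed to dropout layers. With `rng` now forwarded to `Dropout` (see the `src/builders.jl` change below), seeding the model-level `rng` should be enough for reproducible CPU training with a dropout-containing builder such as `MLJFlux.Short`. A minimal sketch under that assumption (seed and builder settings are made up):

```julia
using MLJFlux, StableRNGs

# Assumption: the model's `rng` now also seeds Dropout layers, so no separate
# seeding of the global RNG is required for reproducibility on the CPU.
model = NeuralNetworkClassifier(
    builder = MLJFlux.Short(n_hidden=32, dropout=0.5),
    rng = StableRNG(123),
)
```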
4 changes: 2 additions & 2 deletions src/MLJFlux.jl
@@ -1,4 +1,4 @@
module MLJFlux
module MLJFlux

export CUDALibs, CPU1

@@ -14,11 +14,11 @@ using ColorTypes
using ComputationalResources
using Random
import Metalhead
import Optimisers

include("utilities.jl")
const MMI=MLJModelInterface

include("penalizers.jl")
include("builders.jl")
include("metalhead.jl")
include("types.jl")
39 changes: 18 additions & 21 deletions src/builders.jl
@@ -14,13 +14,12 @@
abstract type Builder <: MLJModelInterface.MLJType end

"""
Linear(; σ=Flux.relu, rng=Random.GLOBAL_RNG)
Review comment (collaborator): Is it worth deprecating this?

Linear(; σ=Flux.relu)

MLJFlux builder that constructs a fully connected two layer network
with activation function `σ`. The number of input and output nodes is
determined from the data. The bias and coefficients are initialized
using `Flux.glorot_uniform(rng)`. If `rng` is an integer, it is
instead used as the seed for a `MersenneTwister`.
MLJFlux builder that constructs a fully connected two layer network with activation
function `σ`. The number of input and output nodes is determined from the data. Weights
are initialized using `Flux.glorot_uniform(rng)`, where `rng` is inferred from the `rng`
field of the MLJFlux model.

"""
mutable struct Linear <: Builder
@@ -31,7 +30,7 @@ build(builder::Linear, rng, n::Integer, m::Integer) =
Flux.Chain(Flux.Dense(n, m, builder.σ, init=Flux.glorot_uniform(rng)))

"""
Short(; n_hidden=0, dropout=0.5, σ=Flux.sigmoid, rng=GLOBAL_RNG)
Short(; n_hidden=0, dropout=0.5, σ=Flux.sigmoid)

MLJFlux builder that constructs a fully connected three-layer network
using `n_hidden` nodes in the hidden layer and the specified `dropout`
@@ -40,9 +39,8 @@ hidden and final layers. If `n_hidden=0` (the default) then `n_hidden`
is the geometric mean of the number of input and output nodes. The
number of input and output nodes is determined from the data.

The each layer is initialized using `Flux.glorot_uniform(rng)`. If
`rng` is an integer, it is instead used as the seed for a
`MersenneTwister`.
Each layer is initialized using `Flux.glorot_uniform(rng)`, where `rng` is inferred from
the `rng` field of the MLJFlux model.

"""
mutable struct Short <: Builder
@@ -57,22 +55,19 @@ function build(builder::Short, rng, n, m)
init=Flux.glorot_uniform(rng)
Flux.Chain(
Flux.Dense(n, n_hidden, builder.σ, init=init),
# TODO: fix next after https://github.com/FluxML/Flux.jl/issues/1617
Flux.Dropout(builder.dropout),
Flux.Dropout(builder.dropout; rng),
Flux.Dense(n_hidden, m, init=init))
end

"""
MLP(; hidden=(100,), σ=Flux.relu, rng=GLOBAL_RNG)
MLP(; hidden=(100,), σ=Flux.relu)

MLJFlux builder that constructs a Multi-layer perceptron network. The
ith element of `hidden` represents the number of neurons in the ith
hidden layer. An activation function `σ` is applied between each
layer.
MLJFlux builder that constructs a Multi-layer perceptron network. The ith element of
`hidden` represents the number of neurons in the ith hidden layer. An activation function
`σ` is applied between each layer.

The each layer is initialized using `Flux.glorot_uniform(rng)`. If
`rng` is an integer, it is instead used as the seed for a
`MersenneTwister`.
Each layer is initialized using `Flux.glorot_uniform(rng)`, where `rng` is inferred from
the `rng` field of the MLJFlux model.

"""
mutable struct MLP{N} <: MLJFlux.Builder
@@ -110,6 +105,7 @@ Creates a builder for `neural_net`. The variables `rng`, `n_in`, `n_out` and
input and output sizes `n_in` and `n_out` and number of input channels `n_channels`.

# Examples

```jldoctest
julia> import MLJFlux: @builder;

Expand All @@ -132,4 +128,5 @@ macro builder(ex)
end)
end

build(b::GenericBuilder, rng, n_in, n_out, n_channels = 1) = b.apply(rng, n_in, n_out, n_channels)
build(b::GenericBuilder, rng, n_in, n_out, n_channels = 1) =
b.apply(rng, n_in, n_out, n_channels)
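For reference, a builder (built-in, or generated with `@builder` as in the docstring above) is turned into a chain by calling `build` with an RNG and the inferred dimensions. A small sketch with made-up layer sizes and seed:

```julia
using Flux
import MLJFlux, Random

# A custom builder; `rng`, `n_in` and `n_out` are supplied by MLJFlux at build time.
builder = MLJFlux.@builder Flux.Chain(
    Flux.Dense(n_in => 16, relu, init = Flux.glorot_uniform(rng)),
    Flux.Dense(16 => n_out, init = Flux.glorot_uniform(rng)),
)

rng = Random.MersenneTwister(42)
chain = MLJFlux.build(builder, rng, 4, 2)  # n_in = 4, n_out = 2 (arbitrary)
```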
110 changes: 75 additions & 35 deletions src/core.jl
@@ -2,7 +2,7 @@

# make the optimiser structs "transparent" so that their field values
# are exposed by calls to MLJ.params:
MLJModelInterface.istransparent(m::Flux.Optimise.AbstractOptimiser) = true
MLJModelInterface.istransparent(m::Optimisers.AbstractRule) = true


## GENERAL METHOD TO OPTIMIZE A CHAIN
@@ -15,47 +15,71 @@ end
(::Mover{<:CUDALibs})(data) = Flux.gpu(data)

"""
train!(model::MLJFlux.MLJFluxModel, penalty, chain, optimiser, X, y)

A private method that can be overloaded for custom models.
train_epoch(
model,
chain,
optimiser,
optimiser_state,
X,
y,
) -> updated_chain, updated_optimiser_state, training_loss

Update the parameters of a Flux `chain`, where:

- `model` is typically an `MLJFluxModel` instance, but could be any object such that
`model.loss` is a Flux.jl loss function.

- the loss function `(yhat, y) -> loss(yhat, y)` is inferred from the
`model`

- `params -> penalty(params)` is a regularization penalty function

- `X` and `y` are vectors of batches of the training data, as detailed
in the [`MLJFlux.fit!`](@ref) document string.
in the [`MLJFlux.train`](@ref) document string.

"""
function train!(model::MLJFlux.MLJFluxModel, penalty, chain, optimiser, X, y)
function train_epoch(
model,
chain,
optimiser,
optimiser_state,
X,
y,
)

loss = model.loss
n_batches = length(y)
training_loss = zero(Float32)

for i in 1:n_batches
parameters = Flux.params(chain)
gs = Flux.gradient(parameters) do
yhat = chain(X[i])
batch_loss = loss(yhat, y[i]) + penalty(parameters) / n_batches
training_loss += batch_loss
return batch_loss
batch_loss, gs = Flux.withgradient(chain) do m
yhat = m(X[i])
loss(yhat, y[i])
end
Flux.update!(optimiser, parameters, gs)
training_loss += batch_loss
# The `do` syntax above means `gs` is a tuple of length one we need to unwrap to
# get the actual gradient:
∇ = first(gs)
optimiser_state, chain = Optimisers.update(optimiser_state, chain, ∇)
end
return training_loss / n_batches

return chain, optimiser_state, training_loss / n_batches
end
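In isolation, the explicit-style step used in the loop above (`Flux.withgradient` over the chain itself, followed by `Optimisers.update`) looks roughly like this; the chain, batch, and learning rate are invented for the sketch:

```julia
using Flux, Optimisers

chain = Flux.Chain(Flux.Dense(4 => 8, relu), Flux.Dense(8 => 1))
X = rand(Float32, 4, 16)   # one batch of 16 observations
y = rand(Float32, 1, 16)
loss = Flux.mse

# Explicit style: per-parameter optimiser state is set up once ...
state = Optimisers.setup(Optimisers.Adam(0.01), chain)

# ... and each step differentiates with respect to the chain itself:
batch_loss, gs = Flux.withgradient(chain) do m
    loss(m(X), y)
end
∇ = first(gs)                                      # gradient w.r.t. the single argument
state, chain = Optimisers.update(state, chain, ∇)  # returns updated state and chain
```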


"""
fit!(model::MLJFlux.MLJFluxModel, penalty, chain, optimiser, epochs, verbosity, X, y)
Review comment (collaborator): I think this, too, might be worth deprecating. If I understand this correctly, existing extensions that overload MLJFlux.fit! like here won't work anymore? As in, they should now be overloading the train function?

Review comment (collaborator): So just to double-check @ablaom, fit! will become train and train! will become train_epoch?


A private method that can be overloaded for custom models.

Optimize a Flux model `chain`, where `(yhat, y) -> loss(yhat, y)` is
the loss function inferred from the `model`, and `parameters -> penalty(parameters)` is the
regularization penalty function.
train(
model,
chain,
optimiser,
optimiser_state,
epochs,
verbosity,
X,
y,
) -> (updated_chain, updated_optimiser_state, history)

Optimize a Flux model `chain`, where `(yhat, y) -> loss(yhat, y)` is the loss function
inferred from the `model`. Typically, `model` will be an `MLJFluxModel` instance, but it
could be any object such that `model.loss` is a Flux.jl loss function.

Here `chain` is a `Flux.Chain` object, or other Flux model such that
`Flux.params(chain)` returns the parameters to be optimized.
@@ -76,17 +100,26 @@ batches. Specifically, it is expected that:
total number of training batches.

The `chain` and the data `(X, y)` must either both live on a CPU or both
live on a GPU. This `fit!` method takes no responsibility for data
live on a GPU. This `train` method takes no responsibility for data
movement.

### Return value
# Return value

`(chain_trained, history)`, where `chain_trained` is a trained version
of `chain` and `history` is a vector of penalized losses - one initial
loss, and one loss per epoch.
Returns `(updated_chain, updated_optimiser_state, history)`, where `updated_chain` is a
trained version of `chain` and `history` is a vector of losses, including the
initial (no-train) loss.

"""
function fit!(model::MLJFlux.MLJFluxModel, penalty, chain, optimiser, epochs, verbosity, X, y)
function train(
model,
chain,
optimiser,
optimiser_state,
epochs,
verbosity,
X,
y,
)

loss = model.loss

@@ -98,20 +131,25 @@ function fit!(model::MLJFlux.MLJFluxModel, penalty, chain, optimiser, epochs, ve
# initiate history:
n_batches = length(y)

parameters = Flux.params(chain)
losses = (loss(chain(X[i]), y[i]) +
penalty(parameters) / n_batches for i in 1:n_batches)
losses = (loss(chain(X[i]), y[i]) for i in 1:n_batches)
history = [mean(losses),]

for i in 1:epochs
current_loss = train!(model::MLJFlux.MLJFluxModel, penalty, chain, optimiser, X, y)
chain, optimiser_state, current_loss = train_epoch(
model,
chain,
optimiser,
optimiser_state,
X,
y,
)
verbosity < 2 ||
@info "Loss is $(round(current_loss; sigdigits=4))"
verbosity != 1 || next!(meter)
push!(history, current_loss)
end

return chain, history
return chain, optimiser_state, history

end
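Putting the pieces together, the new entry point is expected to be called roughly as follows, with `Optimisers.setup` supplying the initial optimiser state. The chain, batched data, and hyperparameters are made up, and the named tuple stands in for a model, since only `model.loss` is used here (per the docstring above):

```julia
using Flux, Optimisers
import MLJFlux

chain = Flux.Chain(Flux.Dense(3 => 8, relu), Flux.Dense(8 => 1))

# `train` expects pre-batched data: vectors of input and target batches.
X = [rand(Float32, 3, 32) for _ in 1:5]
y = [rand(Float32, 1, 32) for _ in 1:5]

model = (; loss = Flux.mse)           # stand-in object exposing a `loss` field
rule  = Optimisers.Adam()
state = Optimisers.setup(rule, chain)

# 10 epochs, verbosity 0:
chain, state, history = MLJFlux.train(model, chain, rule, state, 10, 0, X, y)
```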

@@ -221,7 +259,9 @@ _get(X::AbstractArray{<:Any,4}, b) = X[:, :, :, b]
"""
collate(model, X, y)

Return the Flux-friendly data object required by `MLJFlux.fit!`, given
**Private method**

Return the Flux-friendly data object required by `MLJFlux.train`, given
input `X` and target `y` in the form required by
`MLJModelInterface.input_scitype(X)` and
`MLJModelInterface.target_scitype(y)`. (The batch size used is given
2 changes: 1 addition & 1 deletion src/metalhead.jl
@@ -130,7 +130,7 @@ function VGGHack(
depth in keys(Metalhead.VGG_CONFIGS),
"depth must be from one in $(sort(collect(keys(Metalhead.VGG_CONFIGS))))"
)
model = Metalhead.VGG(imsize;
model = Metalhead.vgg(imsize;
config = Metalhead.VGG_CONFIGS[depth],
inchannels,
batchnorm,