Commit
✨ Add basic docs skeleton and README integration
1 parent 3946a68, commit ee3e0b8
Showing 20 changed files with 842 additions and 50 deletions.
This file was deleted.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
### Adding new models to MLJFlux (advanced) | ||
|
||
This section is mainly for MLJFlux developers. It assumes familiarity | ||
with the [MLJ model | ||
API](https://alan-turing-institute.github.io/MLJ.jl/dev/adding_models_for_general_use/). | ||
|
||
If one subtypes a new model type as either | ||
`MLJFlux.MLJFluxProbabilistic` or `MLJFlux.MLJFluxDeterministic`, then | ||
instead of defining new methods for `MLJModelInterface.fit` and | ||
`MLJModelInterface.update` one can make use of fallbacks by | ||
implementing the lower level methods `shape`, `build`, and | ||
`fitresult`. See the [classifier source code](/src/classifier.jl) for | ||
an example. | ||
|
||
One still needs to implement a new `predict` method. |
File renamed without changes.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
## Image Classification Example | ||
An expanded version of this example, with early stopping and | ||
snapshots, is available [here](/examples/mnist). | ||
|
||
We define a builder that builds a chain with six alternating | ||
convolution and max-pool layers, and a final dense layer, which we | ||
apply to the MNIST image dataset. | ||
|
||
First we define a generic builder (working for any image size, color | ||
or gray): | ||
|
||
```julia | ||
using MLJ | ||
using Flux | ||
using MLDatasets | ||
|
||
# helper function | ||
function flatten(x::AbstractArray) | ||
    return reshape(x, :, size(x)[end]) | ||
end | ||
|
||
import MLJFlux | ||
mutable struct MyConvBuilder | ||
    filter_size::Int | ||
    channels1::Int | ||
    channels2::Int | ||
    channels3::Int | ||
end | ||
|
||
function MLJFlux.build(b::MyConvBuilder, rng, n_in, n_out, n_channels) | ||
|
||
    k, c1, c2, c3 = b.filter_size, b.channels1, b.channels2, b.channels3 | ||
|
||
    mod(k, 2) == 1 || error("`filter_size` must be odd.") | ||
|
||
    # padding to preserve image size on convolution: | ||
    p = div(k - 1, 2) | ||
|
||
    front = Chain( | ||
        Conv((k, k), n_channels => c1, pad=(p, p), relu), | ||
        MaxPool((2, 2)), | ||
        Conv((k, k), c1 => c2, pad=(p, p), relu), | ||
        MaxPool((2, 2)), | ||
        Conv((k, k), c2 => c3, pad=(p, p), relu), | ||
        MaxPool((2, 2)), | ||
        flatten) | ||
    d = Flux.outputsize(front, (n_in..., n_channels, 1)) |> first | ||
    return Chain(front, Dense(d, n_out)) | ||
end | ||
``` | ||
Next, we load some of the MNIST data and check that the scientific types | ||
conform to those in the table above: | ||
|
||
```julia | ||
N = 500 | ||
Xraw, yraw = MNIST.traindata(); | ||
Xraw = Xraw[:,:,1:N]; | ||
yraw = yraw[1:N]; | ||
|
||
scitype(Xraw) | ||
``` | ||
```julia | ||
scitype(yraw) | ||
``` | ||
|
||
Inputs should have element scitype `GrayImage`: | ||
|
||
```julia | ||
X = coerce(Xraw, GrayImage); | ||
``` | ||
|
||
For classifiers, the target must have element scitype `<: Finite`: | ||
|
||
```julia | ||
y = coerce(yraw, Multiclass); | ||
``` | ||
|
||
Instantiating an image classifier model: | ||
|
||
```julia | ||
ImageClassifier = @load ImageClassifier | ||
clf = ImageClassifier(builder=MyConvBuilder(3, 16, 32, 32), | ||
                      epochs=10, | ||
                      loss=Flux.crossentropy) | ||
``` | ||
|
||
And evaluating the accuracy of the model on a 30% holdout set: | ||
|
||
```julia | ||
mach = machine(clf, X, y) | ||
|
||
evaluate!(mach, | ||
          resampling=Holdout(rng=123, fraction_train=0.7), | ||
          operation=predict_mode, | ||
          measure=misclassification_rate) | ||
``` |
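
To make the role of `p` and of the dense layer's input dimension `d` in the builder above more concrete, the following plain-Julia sketch redoes the size arithmetic by hand for 28×28 MNIST images and `MyConvBuilder(3, 16, 32, 32)`. It assumes stride-1 convolutions with symmetric padding `p`, and 2×2 max-pooling with stride 2 and no padding. The helpers `conv_out`, `pool_out` and `final_width` are hypothetical names introduced for illustration; they are not part of Flux or MLJFlux, and in the builder itself `Flux.outputsize` does this work.

```julia
# Hypothetical helpers (not part of Flux or MLJFlux) for the size arithmetic.
# Assumption: stride-1 convolution with symmetric padding p.
conv_out(w, k, p) = w + 2p - k + 1
# Assumption: 2x2 max-pooling with stride 2 and no padding.
pool_out(w) = div(w, 2)

# Width after `n_blocks` convolution + max-pool pairs.
function final_width(w, k, p; n_blocks=3)
    for _ in 1:n_blocks
        w = pool_out(conv_out(w, k, p))  # conv preserves size, pool halves it
    end
    return w
end

k = 3                      # filter_size (must be odd)
p = div(k - 1, 2)          # padding chosen in the builder
c3 = 32                    # channels3

w = final_width(28, k, p)  # 28 -> 14 -> 7 -> 3
d = w * w * c3             # length of each flattened column
println((w, d))

# The `flatten` helper from the example collapses every dimension
# except the last (batch) dimension:
flatten(x::AbstractArray) = reshape(x, :, size(x)[end])
x = zeros(Float32, w, w, c3, 5)  # a batch of 5 feature maps
println(size(flatten(x)))
```

Under these assumptions each image reaches the dense layer as a column of length 288, which is the value `d` should take for this configuration.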
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
|
||
```@docs | ||
MLJFlux.Linear | ||
``` | ||
|
||
```@docs | ||
MLJFlux.Short | ||
``` | ||
|
||
```@docs | ||
MLJFlux.MLP | ||
``` | ||
|
||
```@docs | ||
MLJFlux.@builder | ||
``` |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
```@docs | ||
MLJFlux.NeuralNetworkClassifier | ||
``` |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
### Defining Custom Builders | ||
|
||
The following example defines a new builder for a simple | ||
fully-connected neural network with two hidden layers, with `n1` nodes | ||
in the first hidden layer and `n2` nodes in the second, for use in | ||
any of the first three models in Table 1. The definition consists of one | ||
mutable struct and one method: | ||
|
||
```julia | ||
mutable struct MyBuilder <: MLJFlux.Builder | ||
    n1 :: Int | ||
    n2 :: Int | ||
end | ||
|
||
function MLJFlux.build(nn::MyBuilder, rng, n_in, n_out) | ||
    init = Flux.glorot_uniform(rng) | ||
    return Chain(Dense(n_in, nn.n1, init=init), | ||
                 Dense(nn.n1, nn.n2, init=init), | ||
                 Dense(nn.n2, n_out, init=init)) | ||
end | ||
``` | ||
|
||
Note here that `n_in` and `n_out` depend on the size of the data (see | ||
Table 1). | ||
|
||
For a concrete image classification example, see | ||
[examples/mnist](examples/mnist). | ||
|
||
More generally, defining a new builder means defining a new struct | ||
sub-typing `MLJFlux.Builder` and defining a new `MLJFlux.build` method | ||
with one of these signatures: | ||
|
||
```julia | ||
MLJFlux.build(builder::MyBuilder, rng, n_in, n_out) | ||
MLJFlux.build(builder::MyBuilder, rng, n_in, n_out, n_channels) # for use with `ImageClassifier` | ||
``` | ||
|
||
This method must return a `Flux.Chain` instance, `chain`, subject to the | ||
following conditions: | ||
|
||
- `chain(x)` must make sense: | ||
|
||
- for any `x::Array{<:AbstractFloat, 2}` of size `(n_in, | ||
batch_size)`, where `batch_size` is any integer (for use with one | ||
of the first three model types); or | ||
|
||
- for any `x::Array{<:Float32, 4}` of size `(W, H, n_channels, | ||
batch_size)`, where `(W, H) = n_in`, `n_channels` is 1 or 3, and | ||
`batch_size` is any integer (for use with `ImageClassifier`). | ||
|
||
- The object returned by `chain(x)` must be an `AbstractFloat` vector | ||
of length `n_out`. | ||
|
||
Alternatively, use `MLJFlux.@builder(neural_net)` to automatically create a builder for | ||
any valid Flux chain expression `neural_net`, where the symbols `n_in`, `n_out`, | ||
`n_channels` and `rng` can appear literally, with the interpretations explained above. For | ||
example, | ||
|
||
```julia | ||
builder = MLJFlux.@builder Chain(Dense(n_in, 128), Dense(128, n_out, tanh)) | ||
``` |
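
The shape conditions on `chain` listed above can be checked mechanically before a builder is handed to a model. The sketch below is plain Julia and deliberately avoids Flux: `toy_dense` and `toy_chain` are hypothetical stand-ins for `Dense` and `Chain`, and `satisfies_contract` is an illustrative helper, not an MLJFlux function. It covers only the non-image case and reads the output condition per observation, that is, each column of `chain(x)` has length `n_out`; that reading is an assumption.

```julia
using Random

# Hypothetical stand-in for a dense layer: a closure x -> W*x .+ b.
function toy_dense(rng, n_in, n_out)
    W = randn(rng, Float32, n_out, n_in)
    b = zeros(Float32, n_out)
    return x -> W * x .+ b
end

# Hypothetical stand-in for `Chain`: apply the layers left to right.
toy_chain(layers...) = x -> foldl((acc, layer) -> layer(acc), layers; init=x)

# Illustrative check (an assumption about how to read the conditions):
# a float matrix of size (n_in, batch_size) must map to a float
# matrix of size (n_out, batch_size).
function satisfies_contract(chain, n_in, n_out; batch_size=4)
    x = rand(Float32, n_in, batch_size)
    y = chain(x)
    return eltype(y) <: AbstractFloat && size(y) == (n_out, batch_size)
end

rng = MersenneTwister(123)
n_in, n_out, n1, n2 = 5, 3, 8, 6

# Same layout as `MyBuilder`: two hidden layers, then the output layer.
chain = toy_chain(toy_dense(rng, n_in, n1),
                  toy_dense(rng, n1, n2),
                  toy_dense(rng, n2, n_out))

println(satisfies_contract(chain, n_in, n_out))
```

With Flux loaded, the same check should apply unchanged to the chain returned by a real `MLJFlux.build` method, since it only calls `chain(x)` and inspects the result.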
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
```@docs | ||
MLJFlux.ImageClassifier | ||
``` |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
```@docs | ||
MLJFlux.MultitargetNeuralNetworkRegressor | ||
``` |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
```@docs | ||
MLJFlux.NeuralNetworkRegressor | ||
``` |