Commit 3946a68 (1 parent: 56e65fb). Showing 6 changed files with 92 additions and 9 deletions.

@@ -1,4 +1,5 @@
[deps]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DocumenterTools = "35a29f4d-8980-5a13-9543-d66fff28ecb8"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845"

@@ -1,3 +1,73 @@

# MLJFlux.jl

Documentation for MLJFlux.jl

A Julia package integrating deep learning Flux models with MLJ.

### Objectives

- Provide a user-friendly and high-level interface to fundamental [Flux](https://fluxml.ai/Flux.jl/stable/) deep learning models, while remaining extensible through support for custom models written in Flux

- Make building deep learning models more convenient for users already familiar with the MLJ workflow

- Make it easier to apply machine learning techniques provided by MLJ, such as out-of-sample performance evaluation, hyper-parameter optimization, and iteration control, to deep learning models

!!! note "MLJFlux Coverage"
    MLJFlux support is focused on fundamental and widely used deep learning models; sophisticated architectures and techniques, such as online learning, reinforcement learning, and adversarial networks, are currently beyond its scope.

    Note also that MLJFlux only supports training models when all of the training data fits into memory, although it does support automatic batching of that data.

### Installation

```julia
import Pkg
Pkg.activate("my_environment", shared=true)
Pkg.add(["MLJ", "MLJFlux", "Flux"])
```

You only need `Flux` if you want to build a custom architecture, or to experiment with different optimizers, loss functions, and activations.

### Quick Start

First, load the data, then load and instantiate the model:

```@example
using MLJ, Flux, MLJFlux
import RDatasets

# 1. Load the data
iris = RDatasets.dataset("datasets", "iris");
y, X = unpack(iris, ==(:Species), colname -> true, rng=123);

# 2. Load and instantiate the model
NeuralNetworkClassifier = @load NeuralNetworkClassifier pkg="MLJFlux"
clf = NeuralNetworkClassifier(
    builder=MLJFlux.MLP(; hidden=(5,4), σ=Flux.relu),
    optimiser=Flux.ADAM(0.01),
    batch_size=8,
    epochs=100,
    acceleration=CUDALibs()
)

# 3. Wrap it in a machine and fit
mach = machine(clf, X, y)
fit!(mach)

# 4. Evaluate the model
cv = CV(nfolds=5)
evaluate!(mach, resampling=cv, measure=accuracy)
```

As you can see, we were able to use MLJ functionality (e.g., cross-validation) with a Flux deep learning model. All of the arguments provided above also have defaults.

Notice that we were also able to define the neural network in a high-level fashion by specifying only the number of neurons in each hidden layer and the activation function. Meanwhile, `MLJFlux` was able to infer the input and output layers, as well as suitable defaults for the loss function and output activation, given the classification task.

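This high-level style does not preclude custom architectures. As a rough sketch of the extensibility mentioned in the objectives, here is roughly what a hand-rolled builder equivalent to the `MLP` builder used above might look like, based on MLJFlux's documented builder interface; the struct name `MyMLP` and its fields `n1`, `n2` are illustrative, and the details are a sketch rather than MLJFlux's exact implementation:

```julia
using MLJFlux, Flux

# An illustrative custom builder (names `MyMLP`, `n1`, `n2` are made up):
mutable struct MyMLP <: MLJFlux.Builder
    n1::Int  # number of neurons in the first hidden layer
    n2::Int  # number of neurons in the second hidden layer
end

# MLJFlux calls `build` with an RNG and the input/output dimensions
# `n_in` and `n_out` that it infers from the data, so the user only
# specifies the hidden architecture:
function MLJFlux.build(builder::MyMLP, rng, n_in, n_out)
    init = Flux.glorot_uniform(rng)
    return Flux.Chain(
        Flux.Dense(n_in, builder.n1, Flux.relu, init=init),
        Flux.Dense(builder.n1, builder.n2, Flux.relu, init=init),
        Flux.Dense(builder.n2, n_out, init=init),
    )
end

# Such a builder would then be passed as usual:
# clf = NeuralNetworkClassifier(builder=MyMLP(5, 4))
```
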
### Flux or MLJFlux?

[Flux](https://fluxml.ai/Flux.jl/stable/) is a deep learning framework in Julia that comes with everything you need to build deep learning models (i.e., GPU support, automatic differentiation, layers, activations, losses, optimizers, etc.). [MLJFlux](https://github.com/FluxML/MLJFlux.jl) wraps models built with Flux, providing a higher-level interface for building and training them. More importantly, it empowers Flux models by extending to them many of the common machine learning workflows that MLJ makes possible, such as:

- **Estimating performance** of your model using a holdout set or another resampling strategy (e.g., cross-validation), as measured by one or more metrics (e.g., loss functions) that need not have been used in training, as already demonstrated in the Quick Start above

- **Optimizing hyper-parameters**, such as a regularization parameter (e.g., dropout) or the width/height/number of channels of a convolutional layer; see the tuning sketch after this list

- **Composing with other models**, such as introducing data pre-processing steps (e.g., missing-data imputation) into a pipeline; it might make sense to include non-deep-learning models in such a pipeline. Other kinds of model composition include blending the predictions of a deep learner with those of some other kind of model (as in “model stacking”). Models composed with MLJ can also be tuned as a single unit; see the pipeline sketch after this list

- **Controlling iteration**, by adding an early-stopping criterion based on an out-of-sample estimate of the loss, dynamically changing the learning rate (e.g., cyclic learning rates), periodically saving snapshots of the model, or generating live plots of sample weights to judge training progress (as in TensorBoard); see the iteration sketch after this list

- **Comparing** your model with non-deep-learning models

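For example, here is a minimal sketch of hyper-parameter optimization using MLJ's `TunedModel` wrapper, reusing `clf`, `X`, and `y` from the Quick Start; the parameter tuned, the grid size, and the other settings are illustrative choices, not recommendations:

```julia
using MLJ  # assumes `clf`, `X`, and `y` from the Quick Start above

# Define a range for one hyper-parameter and grid-search over it:
r = range(clf, :batch_size, lower=1, upper=32)
tuned_clf = TunedModel(
    model=clf,
    tuning=Grid(goal=5),
    resampling=CV(nfolds=3),
    range=r,
    measure=cross_entropy,
)
mach = machine(tuned_clf, X, y)
fit!(mach)  # searches the grid, then retrains the best model on all data
```
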
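Similarly, a minimal sketch of pipeline composition, assuming an MLJ version that supports the `|>` pipeline syntax:

```julia
using MLJ  # assumes `clf`, `X`, and `y` from the Quick Start above

# A pipeline that standardizes the inputs before they reach the
# network; the composite behaves as a single model and can be
# evaluated, or tuned, as one unit:
pipe = Standardizer() |> clf
mach = machine(pipe, X, y)
evaluate!(mach, resampling=CV(nfolds=5), measure=accuracy)
```
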
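And a minimal sketch of iteration control with early stopping, using MLJ's `IteratedModel` wrapper; the particular controls shown are illustrative:

```julia
using MLJ  # assumes `clf`, `X`, and `y` from the Quick Start above

# Grow the number of epochs step by step, stopping early when the
# out-of-sample log-loss has not improved for 5 consecutive steps:
iterated_clf = IteratedModel(
    model=clf,
    controls=[Step(1), Patience(5), NumberLimit(100)],
    resampling=Holdout(fraction_train=0.7),
    measure=log_loss,
)
mach = machine(iterated_clf, X, y)
fit!(mach)
```
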
Thus, for a model that could be implemented in either `Flux` or `MLJFlux`, one might choose `MLJFlux` over `Flux` when any of the functionality above is wanted without implementing it from scratch, or when a higher-level interface, equivalent to that of MLJ, is preferred for the task.