Skip to content

Commit

Permalink
Merge pull request #189 from FluxML/dev
Browse files Browse the repository at this point in the history
For a 0.2.5 release
  • Loading branch information
ablaom authored Sep 28, 2021
2 parents edd7047 + deb23ad commit cb267a5
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 31 deletions.
1 change: 0 additions & 1 deletion .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ steps:
queue: "juliagpu"
cuda: "*"
timeout_in_minutes: 60
if: build.pull_request.base_branch == "master" || build.pull_request.base_branch == null

env:
JULIA_PKG_SERVER: "" # it often struggles with our large artifacts
Expand Down
2 changes: 0 additions & 2 deletions .github/workflows/ci_nightly.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@ name: CI (Julia nightly)
on:
pull_request:
branches:
- master
- dev
push:
branches:
- master
- dev
tags: '*'
jobs:
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJFlux"
uuid = "094fc8d1-fd35-5302-93ea-dabda2abf845"
authors = ["Anthony D. Blaom <[email protected]>", "Ayush Shridhar <[email protected]>"]
version = "0.2.4"
version = "0.2.5"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Expand Down
47 changes: 20 additions & 27 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ class of Flux models that can used, at least in the medium term. For
example, online learning, re-enforcement learning, and adversarial
networks are currently out of scope.

Currently MLJFlux is also limited to training models in the case that all
training data fits into memory.


### Basic idea

Each MLJFlux model has a *builder* hyperparameter, an object encoding
Expand Down Expand Up @@ -168,7 +172,7 @@ Flux "models" used in MLJFLux are `Flux.Chain` objects, we call them
MLJFlux provides four model types, for use with input features `X` and
targets `y` of the [scientific
type](https://alan-turing-institute.github.io/MLJScientificTypes.jl/dev/)
indicated in the table below. The parameters `n_in` and `n_out`
indicated in the table below. The parameters `n_in`, `n_out` and `n_channels`
refer to information passed to the builder, as described under
[Defining a new builder](defining-a-new-builder) below.
Expand Down Expand Up @@ -257,32 +261,14 @@ GPU (i.e., `acceleration isa CUDALibs`) one must additionally call
### Built-in builders
MLJ provides two simple builders out of the box. In all cases weights
are intitialized using `glorot_uniform(rng)` where `rng` is the RNG
(or `MersenneTwister` seed) specified by the MLJFlux model.
- `MLJFlux.Linear(σ=...)` builds a fully connected two layer network
with `n_in` inputs and `n_out` outputs, with activation function
`σ`, defaulting to a `MLJFlux.relu`.
The following builders are provided out-of-the-box. Query their
doc-strings for advanced options and further details.
- `MLJFlux.Short(n_hidden=..., dropout=..., σ=...)` builds a
full-connected three-layer network with `n_in` inputs and `n_out`
outputs using `n_hidden` nodes in the hidden layer and the specified
`dropout` (defaulting to 0.5). An activation function `σ` is applied
between the hidden and final layers. If `n_hidden=0` (the default)
then `n_hidden` is the geometric mean of the number of input and
output nodes.
See Table 1 above to see how `n_in` and `n_out` relate to the data.
Alternatively, use `MLJFlux.@builder(neural_net)` to automatically create a builder for
any valid Flux chain expression `neural_net`, where the symbols `n_in`, `n_out`,
`n_channels` and `rng` can appear literally, with the interpretations explained above. For
example,
```
builder = MLJFlux.@builder Chain(Dense(n_in, 128), Dense(128, n_out, tanh))
```
|builder | description |
|:-------------------------|:-----------------------------------------------------|
| `MLJFlux.Linear(σ=relu)` | vanilla linear network with activation function `σ` |
| `MLJFlux.Short(n_hidden=0, dropout=0.5, σ=sigmoid)` | fully connected network with one hidden layer and dropout|
| `MLJFlux.MLP(hidden=(10,))` | general multi-layer perceptron |
### Model hyperparameters.
Expand Down Expand Up @@ -387,7 +373,14 @@ following conditions:
- The object returned by `chain(x)` must be an `AbstractFloat` vector
of length `n_out`.
See also `MLJFlux.@builder` for an automated way to create generic builders.
Alternatively, use `MLJFlux.@builder(neural_net)` to automatically create a builder for
any valid Flux chain expression `neural_net`, where the symbols `n_in`, `n_out`,
`n_channels` and `rng` can appear literally, with the interpretations explained above. For
example,
```
builder = MLJFlux.@builder Chain(Dense(n_in, 128), Dense(128, n_out, tanh))
```
### Loss functions
Expand Down
34 changes: 34 additions & 0 deletions src/builders.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,40 @@ function build(builder::Short, rng, n, m)
Flux.Dense(n_hidden, m, init=init))
end

"""
MLP(; hidden=(100,), σ=Flux.relu, rng=GLOBAL_RNG)
MLJFlux builder that constructs a Multi-layer perceptron network. The
ith element of `hidden` represents the number of neurons in the ith
hidden layer. An activation function `σ` is applied between each
layer.
The each layer is initialized using `Flux.glorot_uniform(rng)`. If
`rng` is an integer, it is instead used as the seed for a
`MersenneTwister`.
"""
mutable struct MLP{N} <: MLJFlux.Builder
hidden::NTuple{N, Int} # count first and last layer
σ
end
MLP(; hidden=(100,), σ=Flux.relu) = MLP(hidden, σ)
function MLJFlux.build(mlp::MLP, rng, n_in, n_out)
init=Flux.glorot_uniform(rng)

hidden = [Flux.Dense(n_in, mlp.hidden[1], mlp.σ, init=init)]
for i 2:length(mlp.hidden)
push!(hidden, Flux.Dense(mlp.hidden[i-1],
mlp.hidden[i],
mlp.σ,
init=init))
end
push!(hidden, Flux.Dense(mlp.hidden[end], n_out, init=init))

return Flux.Chain(hidden... )
end


struct GenericBuilder{F} <: Builder
apply::F
end
Expand Down

0 comments on commit cb267a5

Please sign in to comment.