blitz update and model zoo page
CarloLucibello committed Apr 10, 2024
1 parent fcf6236 commit 15727cb
Showing 7 changed files with 101 additions and 66 deletions.
2 changes: 2 additions & 0 deletions Project.toml
@@ -7,6 +7,8 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534"
ImageShow = "4e3cecfd-b093-5904-9786-8bbb286a6a31"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
6 changes: 3 additions & 3 deletions docs/make.jl
@@ -47,14 +47,14 @@ makedocs(
# Or perhaps those should just be trashed, model zoo versions are newer & more useful.
"Linear Regression" => "tutorials/linear_regression.md",
"Logistic Regression" => "tutorials/logistic_regression.md",
"Multi-layer Perceptron" => "tutorials/mlp.md",
"Model Zoo" => "tutorials/model_zoo.md",
#=
"Julia & Flux: 60 Minute Blitz" => "tutorials/2020-09-15-deep-learning-flux.md",
# "Multi-layer Perceptron" => "tutorials/mlp.md",
# "Julia & Flux: 60 Minute Blitz" => "tutorials/blitz.md",
"Simple ConvNet" => "tutorials/2021-02-07-convnet.md",
"Generative Adversarial Net" => "tutorials/2021-10-14-vanilla-gan.md",
"Deep Convolutional GAN" => "tutorials/2021-10-08-dcgan-mnist.md",
=#
# Not really sure where this belongs... some in Fluxperimental, aim to delete?
],
],
format = Documenter.HTML(
@@ -78,7 +78,7 @@ We can see Julia tile the column vector `1:5` across all rows of the larger array.
zeros(5,5) .+ (1:5)
```

The `x'` syntax is used to transpose a column `1:5` into an equivalent row, and Julia will tile that across columns.

```julia
zeros(5,5) .+ (1:5)'
@@ -181,16 +181,19 @@ x = rand(Float32, 10)
We can easily get the parameters of any layer or model with `trainables`.

```julia
Flux.trainables(m)
```
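For a concrete sense of what `trainables` returns, here is a small sketch with a throwaway two-layer chain (the `demo_model` name and layer sizes are only for illustration):

```julia
demo_model = Chain(Dense(10 => 5, relu), Dense(5 => 2))
ps = Flux.trainables(demo_model)  # a vector of parameter arrays
length(ps)                        # 4: two weight matrices and two bias vectors
sum(length, ps)                   # 67 scalar parameters in total
```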

It is very easy to calculate the gradient for all parameters in a network, even if it has many of them. The function `gradient` is not limited to arrays: it can compute gradients with respect to generic composite types.

```julia
using Flux
using Flux: logitcrossentropy, trainables, getkeypath

x = rand(Float32, 10)
model = Chain(Dense(10 => 5, relu), Dense(5 => 2))
loss(model, x) = logitcrossentropy(model(x), [0.5, 0.5])
grad = gradient(m -> loss(m, x), model)[1]
for (k, p) in trainables(model, path=true)
    println("$k => $(getkeypath(grad, k))")
@@ -203,8 +206,8 @@ The next step is to update our weights and perform optimisation. As you might be

```julia
η = 0.1
for (k, p) in trainables(model, path=true)
    p .+= -η * getkeypath(grad, k)
end
```
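Flux's optimiser interface can perform the same update for you. A minimal sketch, assuming the `model` and `grad` from above (the `opt_state_sgd` name is just for this example, and `Descent(0.1)` is plain gradient descent with η = 0.1):

```julia
opt_state_sgd = Flux.setup(Descent(0.1), model)  # optimiser state for every trainable array
Flux.update!(opt_state_sgd, model, grad)         # one in-place gradient-descent step
```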

@@ -220,22 +223,24 @@ Training a network reduces to iterating over a dataset multiple times, perfor

```julia
data, labels = rand(10, 100), fill(0.5, 2, 100)
loss(m, x, y) = logitcrossentropy(m(x), y)
Flux.train!(loss, model, [(data, labels)], opt_state)
```
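`train!` makes a single pass over the data you give it; for several passes you can simply wrap it in a loop. A sketch reusing the `loss`, `model`, `data`, `labels`, and `opt_state` objects above:

```julia
for epoch in 1:3
    Flux.train!(loss, model, [(data, labels)], opt_state)
end
```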

You don't have to use `train!`. In cases where arbitrary logic might be better suited, you could open up this training loop like so:

```julia
for d in training_set  # assuming d looks like (data, labels)
    # our super logic
    g = gradient(model) do model
        l = loss(model, d...)
    end[1]
    Flux.update!(opt_state, model, g)
end
```

The `do` block is a closure, which is a way of defining a function inline. It's a very powerful feature of Julia, and you can learn more about it [here](https://docs.julialang.org/en/v1/manual/functions/#Do-Block-Syntax-for-Function-Arguments).
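For illustration, the two calls below are equivalent; this is a minimal sketch using `map`, unrelated to the training code:

```julia
map(x -> x^2, 1:3)  # anonymous function passed explicitly

map(1:3) do x       # do block: the same anonymous function, passed as the first argument
    x^2
end
```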

## Training a Classifier

Getting a real classifier to work might help cement the workflow a bit more. [CIFAR10](https://www.cs.toronto.edu/~kriz/cifar.html) is a dataset of 50k tiny training images split into 10 classes.
@@ -254,38 +259,41 @@ We will do the following steps in order:
using Statistics
using Flux
using MLDatasets: CIFAR10
using ImageCore: colorview, RGB
using Flux: onehotbatch, onecold, DataLoader
using Plots: plot
using MLUtils: splitobs, numobs

# using CUDA # Uncomment if you have CUDA installed. Can also use AMDGPU or Metal instead
# using AMDGPU
# using Metal
```

This image will give us an idea of what we are dealing with.

![title](https://pytorch.org/tutorials/_images/cifar10.png)

```julia
train_x, train_y = CIFAR10(:train)[:]
labels = onehotbatch(train_y, 0:9)
```
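As a reminder of what `onehotbatch` produces, here is a tiny sketch on a few hand-picked labels (the values are illustrative):

```julia
onehotbatch([0, 3, 9], 0:9)  # a 10×3 one-hot matrix, one column per label
```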

`train_x` contains 50,000 images stored as 32 × 32 × 3 arrays, with the third dimension holding the 3 colour channels (R, G, B). Let's take a look at a random image from `train_x`. For this, we need to permute the dimensions to 3 × 32 × 32 and use `colorview` to convert it back to an image.

```julia
using Plots
image(x) = colorview(RGB, permutedims(x, (3, 2, 1)))
plot(image(train_x[:,:,:,rand(1:end)]))
```

We can now arrange the training data in batches of, say, 256 and keep a validation set to track our progress. This process is called minibatch learning, a popular method for training large neural networks. Rather than sending the entire dataset through at once, we break it into smaller chunks (called minibatches), typically chosen at random, and train on those. Minibatching has been shown to help with escaping [saddle points](https://en.wikipedia.org/wiki/Saddle_point).

The first 45k images (in batches of 256) will be our training set, and the rest is for validation. `DataLoader` will help us load the data in batches.

```julia
trainset, valset = splitobs((train_x, labels), at = 45000)
trainloader = DataLoader(trainset, batchsize = 256, shuffle = true)
valloader = DataLoader(valset, batchsize = 256)
```
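As a quick sanity check on the loaders (a sketch assuming the objects just defined), you can peek at the first minibatch:

```julia
x, y = first(trainloader)
size(x)  # (32, 32, 3, 256): a minibatch of images
size(y)  # (10, 256): the matching one-hot labels
```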

### Defining the Classifier
@@ -295,30 +303,40 @@ Now we can define our Convolutional Neural Network (CNN).
A convolutional neural network defines a kernel and slides it across the input to build intermediate representations from which features are extracted. Deeper layers build higher-order features, which makes CNNs well suited to images, where the structure of the subject is what helps determine which class it belongs to.

```julia
model = Chain(
    Conv((5, 5), 3 => 16, relu),
    MaxPool((2, 2)),
    Conv((5, 5), 16 => 8, relu),
    MaxPool((2, 2)),
    x -> reshape(x, :, size(x, 4)),
    Dense(200 => 120),
    Dense(120 => 84),
    Dense(84 => 10)) |> gpu
```
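To check that the flattened feature size really is 200, you can push a dummy batch through the network; a quick sketch assuming the `model` above:

```julia
dummy = rand(Float32, 32, 32, 3, 1) |> gpu  # one fake 32×32 RGB image
size(model(dummy))                          # (10, 1): ten class scores for one image
```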

We will use a crossentropy loss and a `Momentum` optimiser here. Crossentropy is a good option when working with multiple classes. Momentum smooths the updates by accumulating a running average of past gradients, which adds a bit of stability to the optimisation and helps prevent us from overshooting our desired destination.

```julia
using Flux: logitcrossentropy, Momentum

loss(m, x, y) = logitcrossentropy(m(x), y)
opt_state = Flux.setup(Momentum(0.01), model)
```
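Before training, it can be reassuring to evaluate the loss on a single batch; a sketch reusing `trainloader`, `model`, and `loss` from above:

```julia
x, y = first(trainloader) |> gpu
loss(model, x, y)  # an untrained 10-class model should sit near -log(1/10) ≈ 2.3
```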

We can now write our training loop, keeping track of some basic accuracy numbers for our model. We can define an `accuracy` function for this like so:

```julia
function accuracy(model, loader)
    n = 0
    acc = 0
    for batch in loader
        x, y = batch |> gpu
        ŷ = model(x)
        acc += sum(onecold(ŷ) .== onecold(y))
        n += numobs(x)
    end
    return acc / n
end
```
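Called on the untrained model, this should hover around chance level; a sketch assuming the `valloader` from earlier:

```julia
accuracy(model, valloader)  # ≈ 0.1 for ten balanced classes before any training
```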

### Training the Classifier
@@ -329,14 +347,15 @@ Training is where we do a bunch of the interesting operations we defined earlier
```julia
epochs = 10

for epoch in 1:epochs
    for batch in trainloader
        x, y = batch |> gpu
        g = gradient(model) do m
            loss(m, x, y)
        end[1]
        Flux.update!(opt_state, model, g)
    end
    @show accuracy(model, valloader)
end
```

@@ -355,10 +374,9 @@ We will check this by predicting the class label that the neural network outputs
Okay, first step. Let us perform the exact same preprocessing on this set as we did on our training set.

```julia
test_x, test_y = CIFAR10(:test)[:]
test_labels = onehotbatch(test_y, 0:9)

testloader = DataLoader((test_x, test_labels), batchsize = 1000, shuffle = true)
```

Next, display an image from the test set.
@@ -367,44 +385,42 @@ Next, display an image from the test set.
plot(image(test_x[:,:,:,rand(1:end)]))
```

The outputs of the network are unnormalised scores (logits) for the 10 classes. The higher the score for a class, the more the network thinks the image belongs to that class. Every column corresponds to the output for one image, with the 10 floats in the column being the class scores.

Let's see how the model fared.

```julia
ids = rand(1:10000, 5)
rand_test = test_x[:,:,:,ids] |> gpu
rand_truth = test_y[ids]
model(rand_test)
```
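If you prefer probabilities over raw scores, `softmax` normalises each column. A small sketch reusing `rand_test` from above (moving back to the CPU before `onecold`):

```julia
probs = softmax(model(rand_test))  # softmax is exported by Flux; each column now sums to 1
onecold(cpu(probs), 0:9)           # predicted labels for the five sampled images
```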

This looks similar to how we would expect the results to be. At this point, it's a good idea to see how our net actually performs on the new data we have prepared.

```julia
accuracy(model, testloader)
```

This is much better than random chance, which sits at 10% (since we only have 10 classes), and not bad at all for a small hand-written network like ours.

Let's take a look at how the net performed on each class individually.

```julia
confusion_matrix = zeros(Int, 10, 10)
m = model |> cpu
for batch in testloader
    x, y = batch
    preds = m(x)
    ŷ = onecold(preds)
    y = onecold(y)
    for (yi, ŷi) in zip(y, ŷ)
        confusion_matrix[yi, ŷi] += 1
    end
end

confusion_matrix
```
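Per-class accuracy can be read straight off the confusion matrix, since rows hold the true classes and columns the predictions (a sketch assuming `confusion_matrix` from above):

```julia
using LinearAlgebra: diag

class_accuracy = vec(diag(confusion_matrix) ./ sum(confusion_matrix, dims = 2))
```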

The spread seems pretty good, with certain classes performing significantly better than the others. Why should that be?
File renamed without changes.
7 changes: 7 additions & 0 deletions docs/old_tutorials/README.md
@@ -0,0 +1,7 @@
These tutorials are hard to maintain
and overlap with the model-zoo examples.

Some of the tutorials are outdated.

Maintenance would be simplified by moving them
to Literate.jl and CI-testing them.
10 changes: 10 additions & 0 deletions docs/src/tutorials/model_zoo.md
@@ -0,0 +1,10 @@
# Model Zoo

The [model zoo](https://github.com/FluxML/model-zoo) is a collection of examples that demonstrate how to build and train models using Flux. The examples are organised by domain and include vision, text, and audio. Each example includes a description of the model, the data used, and the training process.

Some of the examples are pedagogical; see for instance
- [Multilayer Perceptron](https://github.com/FluxML/model-zoo/tree/master/vision/mlp_mnist)
- [Simple Convolutional Neural Network](https://github.com/FluxML/model-zoo/tree/master/vision/conv_mnist)

Others are more advanced; see for instance
- [Variational Autoencoder](https://github.com/FluxML/model-zoo/blob/master/vision/vae_mnist)
2 changes: 1 addition & 1 deletion src/Flux.jl
@@ -63,7 +63,7 @@ include("functor.jl")
# from OneHotArrays.jl
onehot, onehotbatch, onecold,
# from Functors.jl
functor, @functor, KeyPath, haskeypath, getkeypath,
# from Optimise/Train/Optimisers.jl
setup, update!, destructure, freeze!, adjust!, params, trainable, trainables
))
