From 281f8c9b50b3ad727ec660124c763399d4e58944 Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Wed, 4 Dec 2024 15:06:16 -0400 Subject: [PATCH] Remove usage of global variables in linear and logistic regression tutorial training functions (#2537) --- docs/src/tutorials/linear_regression.md | 46 +++++++++++------------ docs/src/tutorials/logistic_regression.md | 46 +++++++++++------------ src/gradient.jl | 4 +- 3 files changed, 48 insertions(+), 48 deletions(-) diff --git a/docs/src/tutorials/linear_regression.md b/docs/src/tutorials/linear_regression.md index 1852ad7441..f56a688575 100644 --- a/docs/src/tutorials/linear_regression.md +++ b/docs/src/tutorials/linear_regression.md @@ -6,7 +6,7 @@ Flux is a pure Julia ML stack that allows you to build predictive models. Here a - Build a model with configurable parameters to make predictions - Iteratively train the model by tweaking the parameters to improve predictions - Verify your model - + Under the hood, Flux uses a technique called automatic differentiation to take gradients that help improve predictions. Flux is also fully written in Julia so you can easily replace any layer of Flux with your own code to improve your understanding or satisfy special requirements. The following page contains a step-by-step walkthrough of the linear regression algorithm in `Julia` using `Flux`! We will start by creating a simple linear regression model for dummy data and then move on to a real dataset. The first part would involve writing some parts of the model on our own, which will later be replaced by `Flux`. @@ -104,9 +104,9 @@ julia> custom_model(W, b, x)[1], y[1] It does! But the predictions are way off. We need to train the model to improve the predictions, but before training the model we need to define the loss function. The loss function would ideally output a quantity that we will try to minimize during the entire training process. Here we will use the mean sum squared error loss function. ```jldoctest linear_regression_simple; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?" -julia> function custom_loss(W, b, x, y) - ŷ = custom_model(W, b, x) - sum((y .- ŷ).^2) / length(x) +julia> function custom_loss(weights, biases, features, labels) + ŷ = custom_model(weights, biases, features) + sum((labels .- ŷ).^2) / length(weights) end; julia> custom_loss(W, b, x, y) @@ -115,7 +115,7 @@ julia> custom_loss(W, b, x, y) Calling the loss function on our `x`s and `y`s shows how far our predictions (`ŷ`) are from the real labels. More precisely, it calculates the sum of the squares of residuals and divides it by the total number of data points. -We have successfully defined our model and the loss function, but surprisingly, we haven't used `Flux` anywhere till now. Let's see how we can write the same code using `Flux`. +We have successfully defined our model and the loss function, but surprisingly, we haven't used `Flux` anywhere till now. Let's see how we can write the same code using `Flux`. ```jldoctest linear_regression_simple julia> flux_model = Dense(1 => 1) @@ -142,9 +142,9 @@ julia> flux_model(x)[1], y[1] It is! The next step would be defining the loss function using `Flux`'s functions - ```jldoctest linear_regression_simple; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?" -julia> function flux_loss(flux_model, x, y) - ŷ = flux_model(x) - Flux.mse(ŷ, y) +julia> function flux_loss(flux_model, features, labels) + ŷ = flux_model(features) + Flux.mse(ŷ, labels) end; julia> flux_loss(flux_model, x, y) @@ -214,13 +214,13 @@ The loss went down! This means that we successfully trained our model for one ep Let's plug our super training logic inside a function and test it again - ```jldoctest linear_regression_simple; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?" -julia> function train_custom_model() - dLdW, dLdb, _, _ = gradient(custom_loss, W, b, x, y) - @. W = W - 0.1 * dLdW - @. b = b - 0.1 * dLdb +julia> function train_custom_model!(f_loss, weights, biases, features, labels) + dLdW, dLdb, _, _ = gradient(f_loss, weights, biases, features, labels) + @. weights = weights - 0.1 * dLdW + @. biases = biases - 0.1 * dLdb end; -julia> train_custom_model(); +julia> train_custom_model!(custom_loss, W, b, x, y); julia> W, b, custom_loss(W, b, x, y) (Float32[2.340657], Float32[0.7516814], 13.64972f0) @@ -230,7 +230,7 @@ It works, and the loss went down again! This was the second epoch of our trainin ```jldoctest linear_regression_simple; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?" julia> for i = 1:40 - train_custom_model() + train_custom_model!(custom_loss, W, b, x, y) end julia> W, b, custom_loss(W, b, x, y) @@ -266,7 +266,7 @@ julia> using Flux, Statistics, MLDatasets, DataFrames ``` ## Gathering real data -Let's start by initializing our dataset. We will be using the [`BostonHousing`](https://juliaml.github.io/MLDatasets.jl/stable/datasets/misc/#MLDatasets.BostonHousing) dataset consisting of `506` data points. Each of these data points has `13` features and a corresponding label, the house's price. The `x`s are still mapped to a single `y`, but now, a single `x` data point has 13 features. +Let's start by initializing our dataset. We will be using the [`BostonHousing`](https://juliaml.github.io/MLDatasets.jl/stable/datasets/misc/#MLDatasets.BostonHousing) dataset consisting of `506` data points. Each of these data points has `13` features and a corresponding label, the house's price. The `x`s are still mapped to a single `y`, but now, a single `x` data point has 13 features. ```jldoctest linear_regression_complex julia> dataset = BostonHousing(); @@ -314,9 +314,9 @@ Dense(13 => 1) # 14 parameters Same as before, our next step would be to define a loss function to quantify our accuracy somehow. The lower the loss, the better the model! ```jldoctest linear_regression_complex; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?" -julia> function loss(model, x, y) - ŷ = model(x) - Flux.mse(ŷ, y) +julia> function loss(model, features, labels) + ŷ = model(features) + Flux.mse(ŷ, labels) end; julia> loss(model, x_train_n, y_train) @@ -330,8 +330,8 @@ We can now proceed to the training phase! The training procedure would make use of the same mathematics, but now we can pass in the model inside the `gradient` call and let `Flux` and `Zygote` handle the derivatives! ```jldoctest linear_regression_complex -julia> function train_model() - dLdm, _, _ = gradient(loss, model, x_train_n, y_train) +julia> function train_model!(f_loss, model, features, labels) + dLdm, _, _ = gradient(f_loss, model, features, labels) @. model.weight = model.weight - 0.000001 * dLdm.weight @. model.bias = model.bias - 0.000001 * dLdm.bias end; @@ -344,7 +344,7 @@ We can write such custom training loops effortlessly using `Flux` and plain `Jul julia> loss_init = Inf; julia> while true - train_model() + train_model!(loss, model, x_train_n, y_train) if loss_init == Inf loss_init = loss(model, x_train_n, y_train) continue @@ -385,9 +385,9 @@ The loss is not as small as the loss of the training data, but it looks good! Th --- -Summarising this tutorial, we started by generating a random yet correlated dataset for our `custom model`. We then saw how a simple linear regression model could be built with and without `Flux`, and how they were almost identical. +Summarising this tutorial, we started by generating a random yet correlated dataset for our `custom model`. We then saw how a simple linear regression model could be built with and without `Flux`, and how they were almost identical. -Next, we trained the model by manually writing down the Gradient Descent algorithm and optimising the loss. We also saw how `Flux` provides various wrapper functionalities and keeps the API extremely intuitive and simple for the users. +Next, we trained the model by manually writing down the Gradient Descent algorithm and optimising the loss. We also saw how `Flux` provides various wrapper functionalities and keeps the API extremely intuitive and simple for the users. After getting familiar with the basics of `Flux` and `Julia`, we moved ahead to build a machine learning model for a real dataset. We repeated the exact same steps, but this time with a lot more features and data points, and by harnessing `Flux`'s full capabilities. In the end, we developed a training loop that was smarter than the hardcoded one and ran the model on our normalised dataset to conclude the tutorial. diff --git a/docs/src/tutorials/logistic_regression.md b/docs/src/tutorials/logistic_regression.md index 01f45a5fcd..51302a5cbf 100644 --- a/docs/src/tutorials/logistic_regression.md +++ b/docs/src/tutorials/logistic_regression.md @@ -1,6 +1,6 @@ # Logistic Regression -The following page contains a step-by-step walkthrough of the logistic regression algorithm in Julia using Flux. We will then create a simple logistic regression model without any usage of Flux and compare the different working parts with Flux's implementation. +The following page contains a step-by-step walkthrough of the logistic regression algorithm in Julia using Flux. We will then create a simple logistic regression model without any usage of Flux and compare the different working parts with Flux's implementation. Let's start by importing the required Julia packages. @@ -9,7 +9,7 @@ julia> using Flux, Statistics, MLDatasets, DataFrames, OneHotArrays ``` ## Dataset -Let's start by importing a dataset from MLDatasets.jl. We will use the `Iris` dataset that contains the data of three different `Iris` species. The data consists of 150 data points (`x`s), each having four features. Each of these `x` is mapped to `y`, the name of a particular `Iris` specie. The following code will download the `Iris` dataset when run for the first time. +Let's start by importing a dataset from MLDatasets.jl. We will use the `Iris` dataset that contains the data of three different `Iris` species. The data consists of 150 data points (`x`s), each having four features. Each of these `x` is mapped to a label (or target) `y`, the name of a particular `Iris` species. The following code will download the `Iris` dataset when run for the first time. ```jldoctest logistic_regression julia> Iris() @@ -141,7 +141,7 @@ julia> flux_model = Chain(Dense(4 => 3), softmax) Chain( Dense(4 => 3), # 15 parameters softmax, -) +) ``` A [`Dense(4 => 3)`](@ref Dense) layer denotes a layer with four inputs (four features in every data point) and three outputs (three classes or labels). This layer is the same as the mathematical model defined by us above. Under the hood, Flux too calculates the output using the same expression, but we don't have to initialize the parameters ourselves this time, instead Flux does it for us. @@ -170,9 +170,9 @@ julia> custom_logitcrossentropy(ŷ, y) = mean(.-sum(y .* logsoftmax(ŷ; dims = 1 Now we can wrap the `custom_logitcrossentropy` inside a function that takes in the model parameters, `x`s, and `y`s, and returns the loss value. ```jldoctest logistic_regression; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?" -julia> function custom_loss(W, b, x, y) - ŷ = custom_model(W, b, x) - custom_logitcrossentropy(ŷ, y) +julia> function custom_loss(weights, biases, features, labels_onehot) + ŷ = custom_model(weights, biases, features) + custom_logitcrossentropy(ŷ, labels_onehot) end; julia> custom_loss(W, b, x, custom_y_onehot) @@ -184,9 +184,9 @@ The loss function works! Flux provides us with many minimal yet elegant loss functions. In fact, the `custom_logitcrossentropy` defined above has been taken directly from Flux. The functions present in Flux includes sanity checks, ensures efficient performance, and behaves well with the overall FluxML ecosystem. ```jldoctest logistic_regression; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?" -julia> function flux_loss(flux_model, x, y) - ŷ = flux_model(x) - Flux.logitcrossentropy(ŷ, y) +julia> function flux_loss(flux_model, features, labels_onehot) + ŷ = flux_model(features) + Flux.logitcrossentropy(ŷ, labels_onehot) end; julia> flux_loss(flux_model, x, flux_y_onehot) @@ -214,9 +214,9 @@ julia> max_idx = [x[1] for x in argmax(custom_y_onehot; dims=1)] Now we can write a function that calculates the indices of the maximum element in each column, and maps them to a class name. ```jldoctest logistic_regression -julia> function custom_onecold(custom_y_onehot) - max_idx = [x[1] for x in argmax(custom_y_onehot; dims=1)] - vec(classes[max_idx]) +julia> function custom_onecold(labels_onehot) + max_idx = [x[1] for x in argmax(labels_onehot; dims=1)] + return vec(classes[max_idx]) end; julia> custom_onecold(custom_y_onehot) @@ -313,10 +313,10 @@ julia> custom_loss(W, b, x, custom_y_onehot) The loss went down! Let's plug our super training logic inside a function. ```jldoctest logistic_regression -julia> function train_custom_model() - dLdW, dLdb, _, _ = gradient(custom_loss, W, b, x, custom_y_onehot) - W .= W .- 0.1 .* dLdW - b .= b .- 0.1 .* dLdb +julia> function train_custom_model!(f_loss, weights, biases, features, labels_onehot) + dLdW, dLdb, _, _ = gradient(f_loss, weights, biases, features, labels_onehot) + weights .= weights .- 0.1 .* dLdW + biases .= biases .- 0.1 .* dLdb end; ``` @@ -324,10 +324,10 @@ We can plug the training function inside a loop and train the model for more epo ```jldoctest logistic_regression; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?" julia> for i = 1:500 - train_custom_model(); + train_custom_model!(custom_loss, W, b, x, custom_y_onehot); custom_accuracy(W, b, x, y) >= 0.98 && break end - + julia> @show custom_accuracy(W, b, x, y); custom_accuracy(W, b, x, y) = 0.98 ``` @@ -347,14 +347,14 @@ We can write a similar-looking training loop for our `flux_model` and train it s julia> flux_loss(flux_model, x, flux_y_onehot) 1.215731131385928 -julia> function train_flux_model() - dLdm, _, _ = gradient(flux_loss, flux_model, x, flux_y_onehot) - @. flux_model[1].weight = flux_model[1].weight - 0.1 * dLdm[:layers][1][:weight] - @. flux_model[1].bias = flux_model[1].bias - 0.1 * dLdm[:layers][1][:bias] +julia> function train_flux_model!(f_loss, model, features, labels_onehot) + dLdm, _, _ = gradient(f_loss, model, features, labels_onehot) + @. model[1].weight = model[1].weight - 0.1 * dLdm[:layers][1][:weight] + @. model[1].bias = model[1].bias - 0.1 * dLdm[:layers][1][:bias] end; julia> for i = 1:500 - train_flux_model(); + train_flux_model!(flux_loss, flux_model, x, flux_y_onehot); flux_accuracy(x, y) >= 0.98 && break end ``` diff --git a/src/gradient.jl b/src/gradient.jl index 57b08416a0..40d58cf933 100644 --- a/src/gradient.jl +++ b/src/gradient.jl @@ -37,7 +37,7 @@ function gradient(f, args...; zero::Bool=true) end if Zygote.isderiving() error("""`Flux.gradient` does not support use within a Zygote gradient. - If what you are doing worked on Flux < 0.14, then calling `Zygote.gradiet` directly should still work. + If what you are doing worked on Flux < 0.14, then calling `Zygote.gradient` directly should still work. If you are writing new code, then Zygote over Zygote is heavily discouraged. """) end @@ -175,7 +175,7 @@ function withgradient(f, args...; zero::Bool=true) end if Zygote.isderiving() error("""`Flux.withgradient` does not support use within a Zygote gradient. - If what you are doing worked on Flux < 0.14, then calling `Zygote.gradiet` directly should still work. + If what you are doing worked on Flux < 0.14, then calling `Zygote.withgradient` directly should still work. If you are writing new code, then Zygote over Zygote is heavily discouraged. """) end