From 4fea1f902d8d5f4075e270415e4ec06b760ff272 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 5 Dec 2024 00:06:03 +0100 Subject: [PATCH] misc stuff for v0.15 release (#2534) Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com> --- .buildkite/pipeline.yml | 35 +------ .github/workflows/ci.yml | 11 +-- .github/workflows/pr_comment.yml | 4 +- NEWS.md | 5 +- Project.toml | 2 +- docs/make.jl | 26 ++--- docs/src/guide/gpu.md | 29 ++++-- docs/src/guide/models/basics.md | 5 - docs/src/guide/models/custom_layers.md | 4 +- docs/src/guide/models/quickstart.md | 2 +- docs/src/guide/training/training.md | 2 +- docs/src/index.md | 4 +- docs/src/reference/data/mldatadevices.md | 7 +- docs/src/reference/data/mlutils.md | 6 ++ docs/src/reference/data/onehot.md | 6 +- docs/src/reference/destructure.md | 20 ++-- docs/src/reference/models/activation.md | 11 ++- docs/src/reference/models/functors.md | 7 +- docs/src/reference/models/layers.md | 16 +-- docs/src/reference/models/losses.md | 6 +- docs/src/reference/models/nnlib.md | 4 + docs/src/reference/outputsize.md | 4 + docs/src/reference/training/callbacks.md | 3 + docs/src/reference/training/enzyme.md | 8 +- docs/src/reference/training/reference.md | 4 + docs/src/reference/utilities.md | 5 + ext/FluxCUDAcuDNNExt/FluxCUDAcuDNNExt.jl | 16 --- src/Flux.jl | 5 +- src/deprecations.jl | 25 ++--- src/functor.jl | 34 ++----- src/gradient.jl | 4 +- src/layers/conv.jl | 120 +++++++++++------------ src/layers/normalise.jl | 4 +- src/layers/recurrent.jl | 15 +-- src/loading.jl | 2 +- test/Project.toml | 11 ++- test/deprecations.jl | 4 + test/ext_metal/runtests.jl | 1 - test/runtests.jl | 7 +- 39 files changed, 238 insertions(+), 246 deletions(-) create mode 100644 test/deprecations.jl diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 70bb7951a9..9fd87b0e00 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -1,5 +1,5 @@ steps: - - label: "CUDA GPU with julia v1" + - label: "CUDA - Julia 1" plugins: - JuliaCI/julia#v1: version: "1" @@ -17,17 +17,7 @@ steps: FLUX_TEST_ENZYME: "false" timeout_in_minutes: 60 - # - label: "GPU nightly" - # plugins: - # - JuliaCI/julia#v1: - # version: "nightly" - # - JuliaCI/julia-test#v1: ~ - # agents: - # queue: "juliagpu" - # cuda: "*" - # timeout_in_minutes: 60 - - - label: "Metal with julia v1" + - label: "Metal - Julia 1" plugins: - JuliaCI/julia#v1: version: "1" @@ -41,32 +31,18 @@ steps: queue: "juliaecosystem" os: "macos" arch: "aarch64" - commands: | - julia --project -e ' - # make sure the 1.8-era Manifest works on this Julia version - using Pkg - Pkg.resolve()' - commands: | - printf "[Flux]\ngpu_backend = \"Metal\"\n" > LocalPreferences.toml - if: build.message !~ /\[skip tests\]/ timeout_in_minutes: 60 env: FLUX_TEST_METAL: "true" FLUX_TEST_CPU: "false" FLUX_TEST_ENZYME: "false" - matrix: - setup: - julia: - # - "1.9" - - "1" - # - "nightly" - - label: "AMD GPU with Julia 1" + - label: "AMDGPU - Julia 1" plugins: - JuliaCI/julia#v1: version: "1" - - JuliaCI/julia-test#v1: + - JuliaCI/julia-test#v1: ~ - JuliaCI/julia-coverage#v1: dirs: - src @@ -75,8 +51,6 @@ steps: queue: "juliagpu" rocm: "*" rocmgpu: "*" - commands: | - printf "[Flux]\ngpu_backend = \"AMDGPU\"\n" > LocalPreferences.toml timeout_in_minutes: 60 env: JULIA_AMDGPU_CORE_MUST_LOAD: "1" @@ -86,5 +60,6 @@ steps: FLUX_TEST_CPU: "false" FLUX_TEST_ENZYME: "false" JULIA_NUM_THREADS: 4 + env: SECRET_CODECOV_TOKEN: 
"fAV/xwuaV0l5oaIYSAXRQIor8h7yHdlrpLUZFwNVnchn7rDk9UZoz0oORG9vlKLc1GK2HhaPRAy+fTkJ3GM/8Y0phHh3ANK8f5UsGm2DUTNsnf6u9izgnwnoRTcsWu+vSO0fyYrxBvBCoJwljL+yZbDFz3oE16DP7HPIzxfQagm+o/kMEszVuoUXhuLXXH0LxT6pXl214qjqs04HfMRmKIIiup48NB6fBLdhGlQz64MdMNHBfgDa/fafB7eNvn0X6pEOxysoy6bDQLUhKelOXgcDx1UsTo34Yiqr+QeJPAeKcO//PWurwQhPoUoHfLad2da9DN4uQk4YQLqAlcIuAA==;U2FsdGVkX1+mRXF2c9soCXT7DYymY3msM+vrpaifiTp8xA+gMpbQ0G63WY3tJ+6V/fJcVnxYoKZVXbjcg8fl4Q==" diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 94df9ddec4..acd279d709 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -52,18 +52,9 @@ jobs: ${{ runner.os }}-test- ${{ runner.os }}- - uses: julia-actions/julia-buildpkg@v1 - - name: "Run test without coverage report" - uses: julia-actions/julia-runtest@v1 - if: matrix.version != '1' || matrix.os != 'ubuntu-latest' - with: - coverage: false - - name: "Run test with coverage report" - uses: julia-actions/julia-runtest@v1 - if: matrix.version == '1' && matrix.os == 'ubuntu-latest' + - uses: julia-actions/julia-runtest@v1 - uses: julia-actions/julia-processcoverage@v1 - if: matrix.version == '1' && matrix.os == 'ubuntu-latest' - uses: codecov/codecov-action@v5 - if: matrix.version == '1' && matrix.os == 'ubuntu-latest' with: files: lcov.info diff --git a/.github/workflows/pr_comment.yml b/.github/workflows/pr_comment.yml index 0e1a03a545..cd28471b16 100644 --- a/.github/workflows/pr_comment.yml +++ b/.github/workflows/pr_comment.yml @@ -8,7 +8,7 @@ jobs: steps: - name: Create PR comment if: github.event_name == 'pull_request' && github.repository == github.event.pull_request.head.repo.full_name && github.event.label.name == 'documentation' # if this is a pull request build AND the pull request is NOT made from a fork - uses: thollander/actions-comment-pull-request@fabd468d3a1a0b97feee5f6b9e499eab0dd903f6 + uses: thollander/actions-comment-pull-request@24bffb9b452ba05a4f3f77933840a6a841d1b32b with: message: 'Once the build has completed, you can preview any updated documentation at this URL: https://fluxml.ai/Flux.jl/previews/PR${{ github.event.number }}/ in ~20 minutes' - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + github-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/NEWS.md b/NEWS.md index 7439cb6ec8..eaee9cab83 100644 --- a/NEWS.md +++ b/NEWS.md @@ -13,10 +13,11 @@ See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a compl The module is still available for now, but will be removed in a future release. * Most Flux layers will [re-use memory via `NNlib.bias_act!`](https://github.com/FluxML/Flux.jl/pull/2327), when possible. * Further support for Enzyme.jl, via methods of `Flux.gradient(loss, Duplicated(model))`. - Flux now owns & exports `gradient`, but without `Duplicated` this still defaults to calling Zygote.jl. + Flux now owns & exports `gradient` and `withgradient`, but without `Duplicated` this still defaults to calling Zygote.jl. * `Flux.params` has been deprecated. Use Zygote's explicit differentiation instead, `gradient(m -> loss(m, x, y), model)`, or use `Flux.trainables(model)` to get the trainable parameters. -* Flux now requires Functors.jl v0.5. This new release of Functors assumes all types to be functors by default. Therefore, applying `@layer` or `@functor` to a type is no longer strictly necessary for Flux's models. However, it is still recommended to use `@layer Model` for additional functionality like pretty printing. +* Flux now requires Functors.jl v0.5. 
This new release of Functors assumes all types to be functors by default. Therefore, applying `Flux.@layer` or `Functors.@functor` to a type is no longer strictly necessary for Flux's models. However, it is still recommended to use `@layer Model` for additional functionality like pretty printing. +* `@layer Model`now behaves the same as `@layer :expand Model`, which means that the model is expanded into its sublayers (if there are any) when printed. To force compact printing, use `@layer :noexpand Model`. ## v0.14.22 * Data movement between devices is now provided by [MLDataDevices.jl](https://github.com/LuxDL/MLDataDevices.jl). diff --git a/Project.toml b/Project.toml index 56950af15e..926e19d4eb 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Flux" uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" -version = "0.15.0-DEV" +version = "0.15.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/docs/make.jl b/docs/make.jl index 370772ce3c..6bdfcbb638 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -2,11 +2,25 @@ using Documenter, Flux, NNlib, Functors, MLUtils, BSON, Optimisers, OneHotArrays, Zygote, ChainRulesCore, Plots, MLDatasets, Statistics, DataFrames, JLD2, MLDataDevices + DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive = true) makedocs( + ## This should be + ## modules = [Flux], checkdocs = :all, + ## but we get errors. modules = [Flux, NNlib, Functors, MLUtils, Zygote, OneHotArrays, Optimisers, ChainRulesCore, MLDataDevices], sitename = "Flux", + doctest = false, # done later + checkdocs = :none, # :all, :exports, :none + # checkdocs_ignored_modules = [NNlib, Functors, MLUtils, Zygote, OneHotArrays, Optimisers, ChainRulesCore, MLDataDevices], + warnonly = [:cross_references], + format = Documenter.HTML( + sidebar_sitename = false, + analytics = "UA-36890222-9", + assets = ["assets/flux.css"], + prettyurls = get(ENV, "CI", nothing) == "true" + ), pages = [ "Welcome" => "index.md", "Guide" => [ @@ -58,17 +72,7 @@ makedocs( "Deep Convolutional GAN" => "tutorials/2021-10-08-dcgan-mnist.md", =# ], - ], - format = Documenter.HTML( - sidebar_sitename = false, - analytics = "UA-36890222-9", - assets = ["assets/flux.css"], - prettyurls = get(ENV, "CI", nothing) == "true" - ), - doctest = false, # done later - checkdocs = :none, # :exports # Do not check if all functions appear in the docs - # since it considers all packages - warnonly = [:cross_references] + ] ) doctest(Flux) # only test Flux modules diff --git a/docs/src/guide/gpu.md b/docs/src/guide/gpu.md index 6e0ed95b3b..7adaa7ebe0 100644 --- a/docs/src/guide/gpu.md +++ b/docs/src/guide/gpu.md @@ -3,7 +3,9 @@ Most work on neural networks involves the use of GPUs, as they can typically perform the required computation much faster. This page describes how Flux co-operates with various other packages, which talk to GPU hardware. -## Basic GPU use: from `Array` to `CuArray` with `cu` +For those in a hurry, see the [quickstart](@ref man-quickstart) page. Or do `using CUDA` and then call `gpu` on both the model and the data. + +## Basic GPU use: from `Array` to `CuArray` Julia's GPU packages work with special array types, in place of the built-in `Array`. The most used is `CuArray` provided by [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl), for GPUs made by NVIDIA. @@ -119,7 +121,7 @@ model = Chain(...) |> device The reason they work on Flux models is that `Flux.@layer Layer` defines methods of `Adapt.adapt_structure(to, lay::Layer)`. 
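For a custom layer outside Flux, the same hook can be written by hand. The following is only an illustrative sketch (the `Affine` struct and its field names are hypothetical, not part of Flux):

```julia
using Adapt

# A hypothetical user-defined layer, shown only to illustrate the mechanism.
struct Affine{W,B}
    weight::W
    bias::B
end

# Roughly what `Flux.@layer Affine` provides: rebuild the struct with every
# array converted by `adapt`, so that `Affine(...) |> gpu` holds GPU arrays.
Adapt.adapt_structure(to, a::Affine) =
    Affine(Adapt.adapt(to, a.weight), Adapt.adapt(to, a.bias))
```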
-## Automatic GPU choice with `gpu` +## Automatic GPU choice with `gpu` and `gpu_device` Flux also provides a more automatic way of choosing which GPU (or none) to use. This is the function `gpu`: * By default it does nothing. @@ -131,19 +133,28 @@ Flux also provides a more automatic way of choosing which GPU (or none) to use. For the most part, this means that a script which says `model |> gpu` and `data |> gpu` will just work. It should always run, and if a GPU package is loaded (and finds the correct hardware) then that will be used. -The function `gpu` uses a lower-level function called `get_device()` from [MLDataDevices.jl](https://github.com/LuxDL/MLDataDevices.jl), -which checks what to do & then returns some device object. In fact, the entire implementation is just this: +The function `gpu` uses a lower-level function called [`gpu_device`](@ref) from MLDataDevices.jl, +which checks what to do and then returns some device object. In fact, the entire implementation is just this: ```julia gpu(x) = gpu_device()(x) cpu(x) = cpu_device()(x) ``` +Automatic backend selection through `gpu` is not type-stable. That doesn't matter if you do it once, or once per large batch -- it costs a few microseconds. But it might matter if you do it within some loop. -## Manually selecting devices +To avoid this, you can first obtain a "device object" with `device = gpu_device()`, once, and then use this as the function to transfer data. Something like this: +```julia +to_device = gpu_device() +gpu_model = model |> to_device -I thought there was a whole `Flux.gpu_backend!` and Preferences.jl story we had to tell?? +for epoch in 1:num_epochs + for (x, y) in dataloader + x_gpu, y_gpu = (x, y) |> to_device + # training code... +``` +Finally, setting a backend prefence with [`gpu_backend!`](@ref) gives type stability to the whole pipeline. ## Transferring Training Data @@ -408,7 +419,7 @@ julia> set_preferences!("Flux", "FluxDistributedMPICUDAAware" => true) By default, Flux will run the checks on your system to see if it can support GPU functionality. You can check if Flux identified a valid GPU setup by typing the following: -```julia +```julia-repl julia> using CUDA julia> CUDA.functional() @@ -417,7 +428,7 @@ true For AMD GPU: -```julia +```julia-repl julia> using AMDGPU julia> AMDGPU.functional() @@ -429,7 +440,7 @@ true For Metal GPU: -```julia +```julia-repl julia> using Metal julia> Metal.functional() diff --git a/docs/src/guide/models/basics.md b/docs/src/guide/models/basics.md index 3bb4358afe..85d54ee58b 100644 --- a/docs/src/guide/models/basics.md +++ b/docs/src/guide/models/basics.md @@ -13,11 +13,6 @@ julia> df(x) = gradient(f, x)[1]; # df/dx = 6x + 2 julia> df(2) 14.0 - -julia> d2f(x) = gradient(df, x)[1]; # d²f/dx² = 6 - -julia> d2f(2) -6.0 ``` When a function has many parameters, we can get gradients of each one at the same time: diff --git a/docs/src/guide/models/custom_layers.md b/docs/src/guide/models/custom_layers.md index 723016dd00..bb5f44b9e5 100644 --- a/docs/src/guide/models/custom_layers.md +++ b/docs/src/guide/models/custom_layers.md @@ -109,7 +109,9 @@ Join(combine, paths...) = Join(combine, paths) ``` Notice again that we parameterized the type of the `combine` and `paths` fields. In addition to the performance considerations of concrete types, this allows either field to be `Vector`s, `Tuple`s, or one of each - we don't need to pay attention to which. -The next step is to use [`Flux.@layer`](@ref) to make our struct behave like a Flux layer. 
This is important so that calling `Flux.setup` on a `Join` maps over the underlying trainable arrays on each path. +The next step is to use [`Flux.@layer`](@ref) to make our struct behave like a Flux layer. +In Flux < v0.15 this used to be important so that calling `Flux.setup` on a `Join` maps over the underlying trainable arrays on each path. Since Flux v0.15, this is no longer necessary, since now Functors.jl automatically traverses custom types. However, [`Flux.@layer`](@ref) is still recommended for pretty printing and other niceties. + ```julia Flux.@layer Join ``` diff --git a/docs/src/guide/models/quickstart.md b/docs/src/guide/models/quickstart.md index a0c92e0ef3..1539eca412 100644 --- a/docs/src/guide/models/quickstart.md +++ b/docs/src/guide/models/quickstart.md @@ -67,7 +67,7 @@ plot(p_true, p_raw, p_done, layout=(1,3), size=(1000,330)) ``` ```@raw html - + ``` Here's the loss during training: diff --git a/docs/src/guide/training/training.md b/docs/src/guide/training/training.md index 02e17eee41..1b79a3e8f4 100644 --- a/docs/src/guide/training/training.md +++ b/docs/src/guide/training/training.md @@ -159,7 +159,7 @@ first(data) isa Tuple{AbstractMatrix, AbstractVector} # true Here each iteration will use one matrix `x` (an image, perhaps) and one vector `y`. It is very common to instead train on *batches* of such inputs (or *mini-batches*, the two words mean the same thing) both for efficiency and for better results. -This can be easily done using the [`DataLoader`](@ref Flux.Data.DataLoader): +This can be easily done using the [`DataLoader`](@ref Flux.DataLoader): ```julia data = Flux.DataLoader((X, Y), batchsize=32) diff --git a/docs/src/index.md b/docs/src/index.md index 3c17ff4c9d..16b18ed0fc 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -8,14 +8,14 @@ Flux is a library for machine learning. It comes "batteries-included" with many ### Installation -Download [Julia 1.9](https://julialang.org/downloads/) or later, preferably the current stable release. You can add Flux using Julia's package manager, by typing `] add Flux` in the Julia prompt. +Download [Julia 1.10](https://julialang.org/downloads/) or later, preferably the current stable release. You can add Flux using Julia's package manager, by typing `] add Flux` in the Julia prompt. For Nvidia GPU support, you will also need to install the `CUDA` and the `cuDNN` packages. For AMD GPU support, install the `AMDGPU` package. For acceleration on Apple Silicon, install the `Metal` package. ### Learning Flux The **[quick start](@ref man-quickstart)** page trains a simple neural network. -This rest of the **guide** provides a from-scratch introduction to Flux's take on models and how they work, starting with [fitting a line](@ref man-overview). Once you understand these docs, congratulations, you also understand [Flux's source code](https://github.com/FluxML/Flux.jl), which is intended to be concise, legible and a good reference for more advanced concepts. +The rest of the **guide** provides a from-scratch introduction to Flux's take on models and how they work, starting with [fitting a line](@ref man-overview). Once you understand these docs, congratulations, you also understand [Flux's source code](https://github.com/FluxML/Flux.jl), which is intended to be concise, legible and a good reference for more advanced concepts. There are some **tutorials** about building particular models. The **[model zoo](https://github.com/FluxML/model-zoo/)** has starting points for many other common ones. 
And finally, the **[ecosystem page](ecosystem.md)** lists packages which define Flux models. diff --git a/docs/src/reference/data/mldatadevices.md b/docs/src/reference/data/mldatadevices.md index 86b0e474f6..7f0261d15a 100644 --- a/docs/src/reference/data/mldatadevices.md +++ b/docs/src/reference/data/mldatadevices.md @@ -1,6 +1,11 @@ +```@meta +CurrentModule = MLDataDevices +CollapsedDocStrings = true +``` + # Transferring data across devices -Flux relies on the [MLDataDevices.jl](https://github.com/LuxDL/MLDataDevices.jl/blob/main/src/public.jl) package to manage devices and transfer data across them. You don't have to explicitly use the package, as Flux re-exports the necessary functions and types. +Flux relies on the MLDataDevices.jl package to manage devices and transfer data across them. You don't have to explicitly use the package, as Flux re-exports the necessary functions and types. ```@docs MLDataDevices.cpu_device diff --git a/docs/src/reference/data/mlutils.md b/docs/src/reference/data/mlutils.md index 7ae288d837..042a53fe44 100644 --- a/docs/src/reference/data/mlutils.md +++ b/docs/src/reference/data/mlutils.md @@ -1,3 +1,8 @@ +```@meta +CurrentModule = Flux +CollapsedDocStrings = true +``` + # Working with Data, using MLUtils.jl Flux re-exports the `DataLoader` type and utility functions for working with @@ -25,6 +30,7 @@ MLUtils.chunk MLUtils.eachobs MLUtils.fill_like MLUtils.filterobs +Flux.flatten MLUtils.flatten MLUtils.getobs MLUtils.getobs! diff --git a/docs/src/reference/data/onehot.md b/docs/src/reference/data/onehot.md index ef0efcdc35..e96db5c79e 100644 --- a/docs/src/reference/data/onehot.md +++ b/docs/src/reference/data/onehot.md @@ -1,3 +1,7 @@ +```@meta +CollapsedDocStrings = true +``` + # One-Hot Encoding with OneHotArrays.jl It's common to encode categorical variables (like `true`, `false` or `cat`, `dog`) in "one-of-k" or ["one-hot"](https://en.wikipedia.org/wiki/One-hot) form. [OneHotArrays.jl](https://github.com/FluxML/OneHotArrays.jl) provides the `onehot` function to make this easy. @@ -51,7 +55,7 @@ julia> onecold(ans, [:a, :b, :c]) Note that these operations returned `OneHotVector` and `OneHotMatrix` rather than `Array`s. `OneHotVector`s behave like normal vectors but avoid any unnecessary cost compared to using an integer index directly. For example, multiplying a matrix with a one-hot vector simply slices out the relevant row of the matrix under the hood. -### Function listing +## Function listing ```@docs OneHotArrays.onehot diff --git a/docs/src/reference/destructure.md b/docs/src/reference/destructure.md index 469a1465b1..d60452d968 100644 --- a/docs/src/reference/destructure.md +++ b/docs/src/reference/destructure.md @@ -1,9 +1,14 @@ +```@meta +CurrentModule = Flux +CollapsedDocStrings = true +``` + # [Flat vs. Nested Structures](@id man-destructure) A Flux model is a nested structure, with parameters stored within many layers. Sometimes you may want a flat representation of them, to interact with functions expecting just one vector. This is provided by `destructure`: -```julia +```julia-repl julia> model = Chain(Dense(2=>1, tanh), Dense(1=>1)) Chain( Dense(2 => 1, tanh), # 3 parameters @@ -22,7 +27,7 @@ Chain( Both `destructure` and the `Restructure` function can be used within gradient computations. For instance, this computes the Hessian `∂²L/∂θᵢ∂θⱼ` of some loss function, with respect to all parameters of the Flux model. 
The resulting matrix has off-diagonal entries, which cannot really be expressed in a nested structure: -```julia +```julia-repl julia> x = rand(Float32, 2, 16); julia> grad = gradient(m -> sum(abs2, m(x)), model) # nested gradient @@ -51,7 +56,7 @@ julia> Flux.destructure(grad) # acts on non-models, too In order to collect all parameters of a model into a list instead, you can use the `trainables` function: -```julia +```julia-repl julia> Flux.trainables(model) 5-element Vector{AbstractArray}: [0.863101 1.2454957] @@ -61,7 +66,7 @@ julia> Flux.trainables(model) ``` Any mutation of the elements of the resulting list will affect the model's parameters. -### All Parameters +## All Parameters The functions `destructure` and `trainables` live in [`Optimisers.jl`](https://github.com/FluxML/Optimisers.jl). @@ -71,9 +76,10 @@ Optimisers.destructure Optimisers.trainable Optimisers.trainables Optimisers.isnumeric +Flux.params ``` -### All Layers +## All Layers Another kind of flat view of a nested model is provided by the `modules` command. This extracts a list of all layers: @@ -81,14 +87,14 @@ Another kind of flat view of a nested model is provided by the `modules` command Flux.modules ``` -### Save and Load +## Save and Load ```@docs Flux.state Flux.loadmodel! ``` -### KeyPath +## KeyPath ```@docs Functors.KeyPath diff --git a/docs/src/reference/models/activation.md b/docs/src/reference/models/activation.md index 9d1d02f24d..11527a950c 100644 --- a/docs/src/reference/models/activation.md +++ b/docs/src/reference/models/activation.md @@ -1,3 +1,8 @@ +```@meta +CollapsedDocStrings = true +``` + + # [Activation Functions from NNlib.jl](@id man-activations) These non-linearities used between layers of your model are exported by the [NNlib](https://github.com/FluxML/NNlib.jl) package. @@ -6,7 +11,7 @@ Note that, unless otherwise stated, activation functions operate on scalars. To Functions like [`softmax`](@ref) are sometimes described as activation functions, but not by Flux. They must see all the outputs, and hence cannot be broadcasted. See the next page for details. -### Alphabetical Listing +## Alphabetical Listing ```@docs celu @@ -35,13 +40,13 @@ tanh_fast trelu ``` -### One More +## One More Julia's `Base.Math` also provides `tanh`, which can be used as an activation function. Note that many Flux layers will automatically replace this with [`NNlib.tanh_fast`](@ref) when called, as Base's `tanh` is slow enough to sometimes be a bottleneck. -```julia +```julia-repl julia> using UnicodePlots julia> lineplot(tanh, -3, 3, height=7) diff --git a/docs/src/reference/models/functors.md b/docs/src/reference/models/functors.md index 1768c99f93..6b232734f0 100644 --- a/docs/src/reference/models/functors.md +++ b/docs/src/reference/models/functors.md @@ -6,19 +6,20 @@ CollapsedDocStrings = true Flux models are deeply nested structures, and [Functors.jl](https://github.com/FluxML/Functors.jl) provides tools needed to explore such objects, apply functions to the parameters they contain (e.g. for moving them to gpu), and re-build them. -!!! compat "Flux ≤ 0.14" +!!! compat "Flux ≤ v0.14" All layers were previously defined with the `Functors.@functor` macro. This still works, but it is recommended that you use the new [`Flux.@layer`](@ref Flux.@layer) macro instead. Both allow [`Flux.setup`](@ref Flux.setup) to see the parameters inside, and [`gpu`](@ref) to move them to the GPU, but [`Flux.@layer`](@ref Flux.@layer) also overloads printing, and offers a way to define `trainable` at the same time. -!!! 
compat "Functors 0.5" - With Functors.jl v0.5, which is required by Flux v0.15 and later, every custom type is a functor by default. This means that applying `Flux.@layer` to a type is no longer strictly necessary, but it is still recommended for addictional features like pretty-printing and `trainable`. +!!! compat "Functors v0.5" + With Functors.jl v0.5, which is required by Flux v0.15 and later, every custom type is a functor by default. This means that applying `Flux.@layer` to a type is no longer strictly necessary, but it is still recommended for addictional features like pretty-printing. `Functors.jl` has its own [notes on basic usage](https://fluxml.ai/Functors.jl/stable/#Basic-Usage-and-Implementation) for more details. Additionally, the [Advanced Model Building and Customisation](@ref man-advanced) page covers the use cases of `Functors` in greater details. ```@docs Flux.@layer +Functors.@leaf Functors.@functor Functors.fmap Functors.fmap_with_path diff --git a/docs/src/reference/models/layers.md b/docs/src/reference/models/layers.md index e2f680e500..b798a35291 100644 --- a/docs/src/reference/models/layers.md +++ b/docs/src/reference/models/layers.md @@ -1,3 +1,8 @@ +```@meta +CurrentModule = Flux +CollapsedDocStrings = true +``` + # Built-in Layer Types If you started at the beginning of the guide, then you have already met the @@ -40,14 +45,10 @@ To understand how strides and padding work, the article by [Dumoulin & Visin](ht ```@docs Conv -Conv(weight::AbstractArray) ConvTranspose -ConvTranspose(weight::AbstractArray) CrossCor -CrossCor(weight::AbstractArray) DepthwiseConv SamePad -Flux.flatten ``` ## MultiHeadAttention @@ -108,9 +109,13 @@ PairwiseFusion Much like the core layers above, but can be used to process sequence data (as well as other kinds of structured data). ```@docs +RNNCell RNN +LSTMCell LSTM +GRUCell GRU +GRUv3Cell GRUv3 ``` @@ -140,7 +145,6 @@ Several normalisation layers behave differently under training and inference (te The functions `Flux.trainmode!` and `Flux.testmode!` let you manually specify which behaviour you want. When called on a model, they will place all layers within the model into the specified mode. ```@docs -testmode!(::Any) -testmode!(::Any, ::Any) +testmode! trainmode! ``` diff --git a/docs/src/reference/models/losses.md b/docs/src/reference/models/losses.md index ae0efac186..b8f42cd0fd 100644 --- a/docs/src/reference/models/losses.md +++ b/docs/src/reference/models/losses.md @@ -1,3 +1,7 @@ +```@meta +CollapsedDocStrings = true +``` + # [Loss Functions](@id man-losses) Flux provides a large number of common loss functions used for training machine learning models. @@ -21,7 +25,7 @@ loss(ŷ, y, agg=x->mean(w .* x)) # weighted mean loss(ŷ, y, agg=identity) # no aggregation. ``` -### Function listing +## Function listing ```@docs Flux.Losses.mae diff --git a/docs/src/reference/models/nnlib.md b/docs/src/reference/models/nnlib.md index e7739f0ebf..3e54d0ac95 100644 --- a/docs/src/reference/models/nnlib.md +++ b/docs/src/reference/models/nnlib.md @@ -1,3 +1,7 @@ +```@meta +CollapsedDocStrings = true +``` + # Neural Network primitives from NNlib.jl Flux re-exports all of the functions exported by the [NNlib](https://github.com/FluxML/NNlib.jl) package. This includes activation functions, described on [their own page](@ref man-activations). Many of the functions on this page exist primarily as the internal implementation of Flux layer, but can also be used independently. 
diff --git a/docs/src/reference/outputsize.md b/docs/src/reference/outputsize.md index 9376db9ab8..9e16900ffd 100644 --- a/docs/src/reference/outputsize.md +++ b/docs/src/reference/outputsize.md @@ -1,3 +1,7 @@ +```@meta +CollapsedDocStrings = true +``` + # Shape Inference Flux has some tools to help generate models in an automated fashion, by inferring the size diff --git a/docs/src/reference/training/callbacks.md b/docs/src/reference/training/callbacks.md index 148aa02128..1dedd4a943 100644 --- a/docs/src/reference/training/callbacks.md +++ b/docs/src/reference/training/callbacks.md @@ -1,3 +1,6 @@ +```@meta +CollapsedDocStrings = true +``` # [Callback Helpers](@id man-callback-helpers) ```@docs diff --git a/docs/src/reference/training/enzyme.md b/docs/src/reference/training/enzyme.md index e096e9fd78..77875c9da3 100644 --- a/docs/src/reference/training/enzyme.md +++ b/docs/src/reference/training/enzyme.md @@ -11,7 +11,7 @@ Calling `Duplicated` on any Flux model which was defined using `@layer` will all and passing that to `gradient` (or `withgradient`, or `train!`) will then use Enzyme instead of Zygote. The gradient functions still return the gradient as usual, which can then be passed to `update!`: -```julia +```julia-repl julia> using Flux, Enzyme julia> model = Chain(Dense(28^2 => 32, sigmoid), Dense(32 => 10), softmax); # from model zoo @@ -47,7 +47,7 @@ The gradient `grads_f[1]` can be passed to `update!` as usual. But for convenience, you may also use what is stored within `Duplicated`. These are equivalent ways to perform an update step: -```julia +```julia-repl julia> opt_state = Flux.setup(Adam(), model) julia> ans == Flux.setup(Adam(), dup_model) @@ -60,7 +60,7 @@ julia> Flux.update!(opt_state, dup_model) # equivlent new path, Enzyme only Instead of using these FLux functions, you can also use Enzyme's own functions directly. `Enzyme.gradient` works like this: -```julia +```julia-repl julia> grads_e = Enzyme.gradient(Reverse, (m,x,y) -> sum(abs2, m(x) .- y), model, Const(x1), Const(y1)) (Chain(Dense(784 => 32, σ), Dense(32 => 10), softmax), nothing, nothing) @@ -73,7 +73,7 @@ But its fields contain the same gradient. There is also a method of `train!` which similarly takes `Duplicated(model)`: -```julia +```julia-repl julia> opt_state = Flux.setup(Adam(0), model); julia> Flux.train!((m,x,y) -> sum(abs2, m(x) .- y), dup_model, [(x1, y1)], opt_state) diff --git a/docs/src/reference/training/reference.md b/docs/src/reference/training/reference.md index aa55ec0927..a9ee35e4c0 100644 --- a/docs/src/reference/training/reference.md +++ b/docs/src/reference/training/reference.md @@ -1,3 +1,7 @@ +```@meta +CollapsedDocStrings = true +``` + # Training API Reference The new version of Flux's training code was written as an independent package, [Optimisers.jl](https://github.com/FluxML/Optimisers.jl). diff --git a/docs/src/reference/utilities.md b/docs/src/reference/utilities.md index 53d2d7b9ef..eac4f78f61 100644 --- a/docs/src/reference/utilities.md +++ b/docs/src/reference/utilities.md @@ -1,3 +1,8 @@ +```@meta +CurrentModule = Flux +CollapsedDocStrings = true +``` + # [Random Weight Initialisation](@id man-init-funcs) Flux initialises convolutional layers and recurrent cells with `glorot_uniform` by default. 
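To use a different scheme, pass an initialisation function via a layer's `init` keyword, or call one directly to produce an array. A brief sketch:

```julia
using Flux

layer = Dense(3 => 4; init = Flux.kaiming_uniform)   # override the default initialiser
W = Flux.glorot_uniform(4, 3)                        # or build a weight array directly
```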
diff --git a/ext/FluxCUDAcuDNNExt/FluxCUDAcuDNNExt.jl b/ext/FluxCUDAcuDNNExt/FluxCUDAcuDNNExt.jl index b354e50b5d..5980d96a1e 100644 --- a/ext/FluxCUDAcuDNNExt/FluxCUDAcuDNNExt.jl +++ b/ext/FluxCUDAcuDNNExt/FluxCUDAcuDNNExt.jl @@ -4,22 +4,6 @@ using Flux using CUDA, cuDNN using NNlib -const USE_CUDNN = Ref{Union{Nothing, Bool}}(nothing) - -function check_use_cudnn() - if !isnothing(USE_CUDNN[]) - return - end - - USE_CUDNN[] = cuDNN.has_cudnn() - if !USE_CUDNN[] - @warn """ - cuDNN.jl didn't found libcudnn, some Flux functionality will not be available. - """ maxlog=1 - end - return -end - function (BN::Flux.BatchNorm)(x::Union{CuArray{T,2},CuArray{T,4},CuArray{T,5}}, cache=nothing) where T<:Union{Float32, Float64} diff --git a/src/Flux.jl b/src/Flux.jl index db77891913..272284ef46 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -48,7 +48,7 @@ export Chain, Dense, Embedding, EmbeddingBag, fmap, cpu, gpu, f32, f64, f16, rand32, randn32, zeros32, ones32, testmode!, trainmode! -@compat(public, ( # mark unexported symbols as API, on Julia 1.11 +@compat(public, ( # unexported symbols marked as API, on Julia 1.11 # modules Losses, Train, # layers @@ -61,6 +61,9 @@ export Chain, Dense, Embedding, EmbeddingBag, setup, train!, # from Optimsers.jl destructure, freeze!, thaw!, adjust!, trainables, update!, trainable, + # from Zygote.jl + hessian, diaghessian, jacobian, withjacobian, pullback, + # AD functions withgradient, # init glorot_uniform, diff --git a/src/deprecations.jl b/src/deprecations.jl index b1e7ccec92..530301fbd1 100644 --- a/src/deprecations.jl +++ b/src/deprecations.jl @@ -67,7 +67,7 @@ function reset!(x) return x end -function params!(p::Zygote.Params, x, seen = IdSet()) +function params!(p::Zygote.Params, x, seen = Base.IdSet()) if x isa AbstractArray{<:Number} && Functors.isleaf(x) return push!(p, x) elseif x in seen @@ -85,10 +85,8 @@ end Returns a `Zygote.Params` object containing all parameter arrays from the model. This is deprecated! - This function was the cornerstone of how Flux used Zygote's implicit mode gradients, but since Flux 0.13 we use explicit mode `gradient(m -> loss(m, x, y), model)` instead. - To collect all the parameter arrays for other purposes, use `Flux.trainables(model)`. """ function params(m...) @@ -99,27 +97,16 @@ function params(m...) return ps end - -""" - @functor MyLayer - -Flux used to require the use of `Functors.@functor` to mark any new layer-like struct. -This allowed it to explore inside the struct, and update any trainable parameters within. -Flux@0.15 removes this requirement. This is because Functors@0.5 changed ist behaviour -to be opt-out instead of opt-in. Arbitrary structs will now be explored without special marking. -Hence calling `@functor` is no longer required. - -Calling `Flux.@layer MyLayer` is, however, still recommended. This adds various convenience methods -for your layer type, such as pretty printing, and use with Adapt.jl. -""" -macro functor(ex) +macro functor(args...) @warn """The use of `Flux.@functor` is deprecated. Most likely, you should write `Flux.@layer MyLayer` which will add various convenience methods for your type, - such as pretty-printing, and use with Adapt.jl. + such as pretty-printing and use with Adapt.jl. However, this is not required. Flux.jl v0.15 uses Functors.jl v0.5, which makes exploration of most nested `struct`s opt-out instead of opt-in... so Flux will automatically see inside any custom struct definitions. 
+ If you really want to apply the `@functor` macro to a custom struct, use `Functors.@functor` instead. """ maxlog=1 - _layer_macro(ex) + + return Functors.functorm(args...) end # Allows caching of the parameters when params is called within gradient() to fix #2040. diff --git a/src/functor.jl b/src/functor.jl index af28ca4906..9919e973c0 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -29,33 +29,7 @@ Dropout(0.3) """ testmode!(m) = testmode!(m, true) -""" - trainmode!(model) -> model - -Set a layer, or all layers in a model, to training mode. -Opposite to [`testmode!`](@ref), see further details there. -""" -trainmode!(m) = testmode!(m, false) -trainmode!(m, mode::Symbol) = testmode!(m, mode) -trainmode!(m, ::Nothing) = testmode!(m, nothing) # why do we have so much API? - -""" - testmode!(model, inactive) -This two-argument method is largely internal. It recurses into the `model`, -and until a method like `testmode!(d::Dropout, inactive)` alters the activity of a layer. -Custom layers can support manual `testmode!` / `trainmode!` switching -by defining such a method. - -Possible values of `inactive` are: -- `true` for testing, i.e. `active=false` -- `false` for training, same as [`trainmode!`](@ref)`(m)` -- `:auto` or `nothing` for Flux to detect training automatically. - -!!! compat - This method may be removed in a future breaking change, to separate - the user-facing `testmode!` from the internal recursion. -""" function testmode!(m, mode) inactive = if mode isa Symbol mode === :auto || throw(ArgumentError(lazy"testmode! accepts only the symbol :auto, got :$mode")) @@ -69,7 +43,15 @@ function testmode!(m, mode) m end +""" + trainmode!(model) -> model +Set a layer, or all layers in a model, to training mode. +Opposite to [`testmode!`](@ref), see further details there. +""" +trainmode!(m) = testmode!(m, false) +trainmode!(m, mode::Symbol) = testmode!(m, mode) +trainmode!(m, ::Nothing) = testmode!(m, nothing) # why do we have so much API? diff --git a/src/gradient.jl b/src/gradient.jl index 40d58cf933..aa88828545 100644 --- a/src/gradient.jl +++ b/src/gradient.jl @@ -196,7 +196,7 @@ Only available when Enzyme is loaded! # Example -```julia +```julia-repl julia> using Flux, Enzyme julia> model = Chain(Embedding([1.1 2.2 3.3]), Dense([4.4;;]), only); @@ -215,7 +215,7 @@ The function `f` may return Tuple or NamedTuple, with the loss as the first elem The gradient is then `grad = gradient(first∘f, args...)` but the returned value is `val = f(args...)`: -```julia +```julia-repl julia> Flux.withgradient(m -> (m(3), "aux"), Duplicated(model)) (val = (14.52, "aux"), grad = ((layers = ((weight = [0.0 0.0 4.4],), (weight = [3.3;;], bias = [1.0], σ = nothing), nothing),),)) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index a5a7734313..1f8a81ee16 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -59,6 +59,7 @@ end """ Conv(filter, in => out, σ = identity; stride = 1, pad = 0, dilation = 1, groups = 1, [bias, init]) + Conv(weight, [bias, activation; stride, pad, dilation]) Standard convolutional layer. `filter` is a tuple of integers specifying the size of the convolutional kernel; @@ -91,11 +92,15 @@ Keywords to control initialization of the layer: * `bias` - The initial bias vector is all zero by default. Trainable bias can be disabled entirely by setting this to `false`, or another vector can be provided such as `bias = randn(Float32, out)`. +The second form of the constructor allows you to pass in a pre-constructed weight matrix +and bias vector. 
This is useful when you want to initialize the weights yourself. + See also [`ConvTranspose`](@ref), [`DepthwiseConv`](@ref), [`CrossCor`](@ref). # Examples + ```jldoctest -julia> xs = rand32(100, 100, 3, 50); # a batch of 50 RGB images +julia> xs = rand(Float32, 100, 100, 3, 50); # a batch of 50 RGB images julia> layer = Conv((5,5), 3 => 7, relu; bias = false) Conv((5, 5), 3 => 7, relu, bias=false) # 525 parameters @@ -115,39 +120,32 @@ julia> Conv((1,1), 3 => 7; pad = (20,10,0,0))(xs) |> size julia> Conv((5,5), 3 => 7; stride = 2, dilation = 4)(xs) |> size (42, 42, 7, 50) ``` -""" -struct Conv{N,M,F,A,V} - σ::F - weight::A - bias::V - stride::NTuple{N,Int} - pad::NTuple{M,Int} - dilation::NTuple{N,Int} - groups::Int -end - -""" - Conv(weight::AbstractArray, [bias, activation; stride, pad, dilation]) - -Constructs a convolutional layer with the given weight and bias. -Accepts the same keywords and has the same defaults as -[`Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ; ...)`](@ref Conv). ```jldoctest -julia> weight = rand(3, 4, 5); +julia> weight = rand(Float32, 3, 4, 5); -julia> bias = zeros(5); +julia> bias = zeros(Float32, 5); julia> layer = Conv(weight, bias, sigmoid) # expects 1 spatial dimension Conv((3,), 4 => 5, σ) # 65 parameters -julia> layer(randn(100, 4, 64)) |> size +julia> layer(randn(Float32, 100, 4, 64)) |> size (98, 5, 64) julia> Flux.trainables(layer) |> length 2 ``` """ +struct Conv{N,M,F,A,V} + σ::F + weight::A + bias::V + stride::NTuple{N,Int} + pad::NTuple{M,Int} + dilation::NTuple{N,Int} + groups::Int +end + function Conv(w::AbstractArray{T,N}, b = true, σ = identity; stride = 1, pad = 0, dilation = 1, groups = 1) where {T,N} @@ -223,6 +221,7 @@ end """ ConvTranspose(filter, in => out, σ=identity; stride=1, pad=0, outpad=0, dilation=1, [bias, init]) + ConvTranspose(weight, [bias, activation; stride, pad, outpad, dilation]) Standard convolutional transpose layer. `filter` is a tuple of integers specifying the size of the convolutional kernel, while @@ -237,11 +236,14 @@ of the output in the desired dimensions. Whereas `pad` is used to zero-pad the i Parameters are controlled by additional keywords, with defaults `init=glorot_uniform` and `bias=true`. +The second form of the constructor allows you to pass in a pre-constructed weight matrix +and bias vector. This is useful when you want to initialize the weights yourself. + See also [`Conv`](@ref) for more detailed description of keywords. 
# Examples ```jldoctest -julia> xs = rand32(100, 100, 3, 50); # a batch of 50 RGB images +julia> xs = rand(Float32, 100, 100, 3, 50); # a batch of 50 RGB images julia> layer = ConvTranspose((5,5), 3 => 7, relu) ConvTranspose((5, 5), 3 => 7, relu) # 532 parameters @@ -258,6 +260,21 @@ julia> ConvTranspose((5,5), 3 => 7, stride=2, outpad=1)(xs) |> size julia> ConvTranspose((5,5), 3 => 7, stride=3, pad=SamePad())(xs) |> size (300, 300, 7, 50) ``` + +```jldoctest +julia> weight = rand(Float32, 3, 4, 5); + +julia> bias = zeros(Float32, 4); + +julia> layer = ConvTranspose(weight, bias, sigmoid) +ConvTranspose((3,), 5 => 4, σ) # 64 parameters + +julia> layer(randn(Float32, 100, 5, 64)) |> size # transposed convolution will increase the dimension size (upsampling) +(102, 4, 64) + +julia> Flux.trainables(layer) |> length +2 +``` """ struct ConvTranspose{N,M,F,A,V} σ::F @@ -273,29 +290,6 @@ end _channels_in(l::ConvTranspose) = size(l.weight)[end] _channels_out(l::ConvTranspose) = size(l.weight)[end-1]*l.groups -""" - ConvTranspose(weight::AbstractArray, [bias, activation; stride, pad, outpad, dilation, groups]) - -Constructs a ConvTranspose layer with the given weight and bias. -Accepts the same keywords and has the same defaults as -[`ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ; ...)`](@ref ConvTranspose). - -# Examples -```jldoctest -julia> weight = rand(3, 4, 5); - -julia> bias = zeros(4); - -julia> layer = ConvTranspose(weight, bias, sigmoid) -ConvTranspose((3,), 5 => 4, σ) # 64 parameters - -julia> layer(randn(100, 5, 64)) |> size # transposed convolution will increase the dimension size (upsampling) -(102, 4, 64) - -julia> Flux.trainables(layer) |> length -2 -``` -""" function ConvTranspose(w::AbstractArray{T,N}, bias = true, σ = identity; stride = 1, pad = 0, outpad = 0, dilation = 1, groups = 1) where {T,N} stride = expand(Val(N-2), stride) @@ -403,6 +397,7 @@ end """ CrossCor(filter, in => out, σ=identity; stride=1, pad=0, dilation=1, [bias, init]) + CrossCor(weight::AbstractArray, [bias, activation; stride, pad, dilation]) Standard cross correlation layer. `filter` is a tuple of integers specifying the size of the convolutional kernel; @@ -411,6 +406,9 @@ specifying the size of the convolutional kernel; Parameters are controlled by additional keywords, with defaults `init=glorot_uniform` and `bias=true`. +The second form of the constructor allows you to pass in a pre-constructed weight matrix +and bias vector. This is useful when you want to initialize the weights yourself + See also [`Conv`](@ref) for more detailed description of keywords. # Examples @@ -427,6 +425,18 @@ julia> layer(xs) |> size julia> CrossCor((5,5), 3 => 7, stride=3, pad=(2,0))(xs) |> size (34, 32, 7, 50) ``` + +```jldoctest +julia> weight = rand(Float32, 3, 4, 5); + +julia> bias = zeros(Float32, 5); + +julia> layer = CrossCor(weight, bias, relu) +CrossCor((3,), 4 => 5, relu) # 65 parameters + +julia> layer(randn(Float32, 100, 4, 64)) |> size +(98, 5, 64) +``` """ struct CrossCor{N,M,F,A,V} σ::F @@ -439,26 +449,6 @@ end _channels_in(l::CrossCor) = size(l.weight, ndims(l.weight)-1) -""" - CrossCor(weight::AbstractArray, [bias, activation; stride, pad, dilation]) - -Constructs a CrossCor layer with the given weight and bias. -Accepts the same keywords and has the same defaults as -[`CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ; ...)`](@ref CrossCor). 
- -# Examples -```jldoctest -julia> weight = rand(3, 4, 5); - -julia> bias = zeros(5); - -julia> layer = CrossCor(weight, bias, relu) -CrossCor((3,), 4 => 5, relu) # 65 parameters - -julia> layer(randn(100, 4, 64)) |> size -(98, 5, 64) -``` -""" function CrossCor(w::AbstractArray{T,N}, bias = true, σ = identity; stride = 1, pad = 0, dilation = 1) where {T,N} stride = expand(Val(N-2), stride) diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index dded9ab306..85cece9477 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -29,7 +29,7 @@ Keyword `rng` lets you specify a custom random number generator. (Only supported on the CPU.) # Examples -```julia +```julia-repl julia> m = Chain(Dense(ones(3,2)), Dropout(0.4)) Chain( Dense(2 => 3), # 9 parameters @@ -297,7 +297,7 @@ that will be used to renormalize the input in test phase. Use [`testmode!`](@ref) during inference. # Examples -```julia +```julia-repl julia> using Statistics julia> xs = rand(3, 3, 3, 2); # a batch of 2 images, each having 3 channels diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index 25642f2187..750141db1b 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -123,12 +123,14 @@ See [`RNNCell`](@ref) for a layer that processes a single time step. # Forward - rnn(x, h) + rnn(x, [h]) The arguments of the forward pass are: - `x`: The input to the RNN. It should be a matrix size `in x len` or an array of size `in x len x batch_size`. -- `h`: The initial hidden state of the RNN. It should be a vector of size `out` or a matrix of size `out x batch_size`. +- `h`: The initial hidden state of the RNN. + If given, it is a vector of size `out` or a matrix of size `out x batch_size`. + If not provided, it is assumed to be a vector of zeros. Returns all new hidden states `h_t` as an array of size `out x len x batch_size`. @@ -506,8 +508,7 @@ See [`GRUCell`](@ref) for a layer that processes a single time step. # Forward - gru(x, h) - gru(x) + gru(x, [h]) The arguments of the forward pass are: @@ -584,8 +585,7 @@ See [`GRU`](@ref) and [`GRUCell`](@ref) for variants of this layer. # Forward - gruv3cell(x, h) - gruv3cell(x) + gruv3cell(x, [h]) The arguments of the forward pass are: - `x`: The input to the GRU. It should be a vector of size `in` or a matrix of size `in x batch_size`. @@ -658,6 +658,9 @@ for all `len` steps `t` in the input sequence. See [`GRUv3Cell`](@ref) for a layer that processes a single time step. See [`GRU`](@ref) and [`GRUCell`](@ref) for variants of this layer. +Notice that `GRUv3` is not a more advanced version of [`GRU`](@ref) +but only a less popular variant. + # Arguments - `in => out`: The input and output dimensions of the layer. diff --git a/src/loading.jl b/src/loading.jl index 3bc4f6e58e..14b13f2295 100644 --- a/src/loading.jl +++ b/src/loading.jl @@ -50,7 +50,7 @@ Zero bias vectors and `bias=false` are considered equivalent See also [`Flux.state`](@ref). 
# Examples -```julia +```julia-repl julia> dst = Chain(Dense(Flux.ones32(2, 5), Flux.ones32(2), tanh), Dense(2 => 1; bias = [1f0])) Chain( Dense(5 => 2, tanh), # 12 parameters diff --git a/test/Project.toml b/test/Project.toml index 99f1d7175a..eb2fe438d0 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -5,20 +5,25 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" +Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" +GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55" IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e" -MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] +Enzyme = "0.13" FiniteDifferences = "0.12" +GPUArraysCore = "0.1" +GPUCompiler = "0.27" Tracker = "0.2.33" -Enzyme = "0.13" diff --git a/test/deprecations.jl b/test/deprecations.jl new file mode 100644 index 0000000000..bee1822f22 --- /dev/null +++ b/test/deprecations.jl @@ -0,0 +1,4 @@ +@testset "params" begin + ps = Flux.params([2,3]) + @test length(ps) == 1 +end diff --git a/test/ext_metal/runtests.jl b/test/ext_metal/runtests.jl index 86e1068cf3..6ce7943025 100644 --- a/test/ext_metal/runtests.jl +++ b/test/ext_metal/runtests.jl @@ -3,7 +3,6 @@ using Metal using Flux using Random, Statistics using Zygote -Flux.gpu_backend!("Metal") # needs a restart @testset "data movement" begin metal_device = Flux.gpu_device() diff --git a/test/runtests.jl b/test/runtests.jl index ca1f116691..f9936fd3ae 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -69,6 +69,10 @@ Random.seed!(0) @testset "functors" begin include("functors.jl") end + + @testset "deprecations" begin + include("deprecations.jl") + end else @info "Skipping CPU tests." end @@ -76,7 +80,6 @@ Random.seed!(0) if get(ENV, "FLUX_TEST_CUDA", "false") == "true" Pkg.add(["CUDA", "cuDNN"]) using CUDA, cuDNN - Flux.gpu_backend!("CUDA") if CUDA.functional() @testset "CUDA" begin @@ -92,7 +95,6 @@ Random.seed!(0) if get(ENV, "FLUX_TEST_AMDGPU", "false") == "true" Pkg.add("AMDGPU") using AMDGPU - Flux.gpu_backend!("AMDGPU") if AMDGPU.functional() && AMDGPU.functional(:MIOpen) @testset "AMDGPU" begin @@ -108,7 +110,6 @@ Random.seed!(0) if get(ENV, "FLUX_TEST_METAL", "false") == "true" Pkg.add("Metal") using Metal - Flux.gpu_backend!("Metal") if Metal.functional() @testset "Metal" begin