Creates a LocalPreferences.toml file with the desired GPU backend.
If backend == "", then the gpu_backend preference is deleted. Otherwise, backend is validated to be one of the possible backends and the preference is set to backend.
If a new backend is successfully set, then the Julia session must be restarted for the change to take effect.
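A minimal sketch of typical use (the backend name depends on what your system supports):

julia
using MLDataDevices

gpu_backend!("CUDA")  # persist the preference in LocalPreferences.toml; restart Julia afterwards
gpu_backend!("")      # delete the gpu_backend preference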
Selects GPU device based on the following criteria:
If gpu_backend preference is set and the backend is functional on the system, then that device is selected.
Otherwise, an automatic selection algorithm is used. We go over possible device backends in the order specified by supported_gpu_backends() and select the first functional backend.
If no GPU device is functional and force is false, then cpu_device() is invoked.
If nothing works, an error is thrown.
Arguments
device_id::Union{Nothing, Integer}: The device id to select. If nothing, then we return the last selected device or if none was selected then we run the autoselection and choose the current device using CUDA.device() or AMDGPU.device() or similar. If Integer, then we select the device with the given id. Note that this is 1-indexed, in contrast to the 0-indexed CUDA.jl. For example, id = 4 corresponds to CUDA.device!(3).
Warning
device_id is only applicable for CUDA and AMDGPU backends. For Metal, oneAPI and CPU backends, device_id is ignored and a warning is printed.
Warning
gpu_device won't select a CUDA device unless both CUDA.jl and cuDNN.jl are loaded. This is to ensure that deep learning operations work correctly. Nonetheless, if cuDNN is not loaded you can still manually create a CUDADevice object and use it (e.g. dev = CUDADevice()).
Keyword Arguments
force::Bool: If true, then an error is thrown if no functional GPU device is found.
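A short usage sketch (the GPU-specific lines assume a functional CUDA/AMDGPU setup):

julia
using MLDataDevices

dev = gpu_device()          # automatic selection; may fall back to cpu_device()
x = rand(Float32, 3, 4)
x_dev = dev(x)              # move the array (or a nested structure) to the device

dev2 = gpu_device(2)        # select the second device (1-indexed); CUDA/AMDGPU only
gpu_device(; force=true)    # error if no functional GPU is found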
If all arrays (on the leaves of the structure) are on the same device, we return that device. Otherwise, we throw an error. If the object is device agnostic, we return nothing.
Note
Trigger Packages must be loaded for this to return the correct device.
Special Returned Values
nothing – denotes that the object is device agnostic. For example, scalar, abstract range, etc.
UnknownDevice() – denotes that the device type is unknown
See also get_device_type for a faster alternative that can be used for dispatch based on device type.
Similar to get_device but returns the type of the device instead of the device itself. This value is often a compile-time constant and is recommended over get_device wherever you define dispatches based on the device type.
Note
Trigger Packages must be loaded for this to return the correct device.
Special Returned Values
Nothing – denotes that the object is device agnostic. For example, scalar, abstract range, etc.
UnknownDevice – denotes that the device type is unknown
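A small sketch showing both functions on a plain CPU array (the dispatch pattern is illustrative):

julia
using MLDataDevices

x = rand(Float32, 4)
get_device(x)        # CPUDevice()
get_device_type(x)   # CPUDevice, usable as a compile-time dispatch target

process(y) = process(get_device_type(y), y)
process(::Type{CPUDevice}, y) = sum(y)   # CPU-specific method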
Checks if the device is functional. This is used to determine if the device can be used for computation. Note that even if the backend is loaded (as checked via MLDataDevices.loaded), the device may not be functional.
Note that while this function is not exported, it is considered part of the public API.
Returns true if x is a leaf node in the data structure.
Defining MLDataDevices.isleaf(x::T) = true for custom types can be used to customize the data-movement behavior when an object with nested structure containing the type is transferred to a device.
Adapt.adapt_structure(::AbstractDevice, x::T) or Adapt.adapt_storage(::AbstractDevice, x::T) will be called during data movement if isleaf(x::T) == true.
If MLDataDevices.isleaf(x::T) is not defined, then it will fall back to Functors.isleaf(x).
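A hedged sketch of marking a custom type as a leaf (BoundingBox is a hypothetical type):

julia
using MLDataDevices, Adapt

struct BoundingBox
    mins::Vector{Float32}
    maxs::Vector{Float32}
end

# Treat the whole struct as a single unit during device transfer.
MLDataDevices.isleaf(::BoundingBox) = true

# Adapt.adapt_structure must then be defined to perform the actual transfer.
Adapt.adapt_structure(dev::MLDataDevices.AbstractDevice, bb::BoundingBox) =
    BoundingBox(dev(bb.mins), dev(bb.maxs))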
Set the device for the given type. This is a no-op for CPUDevice. For CUDADevice and AMDGPUDevice, it prints a warning if the corresponding trigger package is not loaded.
Currently, MetalDevice and oneAPIDevice don't support setting the device.
Arguments
T::Type{<:AbstractDevice}: The device type to set.
dev_or_id: Can be the device from the corresponding package. For example for CUDA it can be a CuDevice. If it is an integer, it is the device id to set. This is 1-indexed.
Danger
This specific function should be considered experimental at this point and is currently provided to support distributed training in Lux. As such, please use Lux.DistributedUtils instead of calling this function directly.
Set the device for the given type. This is a no-op for CPUDevice. For CUDADevice and AMDGPUDevice, it prints a warning if the corresponding trigger package is not loaded.
Currently, MetalDevice and oneAPIDevice don't support setting the device.
Arguments
T::Type{<:AbstractDevice}: The device type to set.
rank::Integer: Local Rank of the process. This is applicable for distributed training and must be 0-indexed.
Danger
This specific function should be considered experimental at this point and is currently provided to support distributed training in Lux. As such, please use Lux.DistributedUtils instead of calling this function directly.
Create a DeviceIterator that iterates through the provided iterator via iterate. Upon each iteration, the current batch is copied to the device dev, and the previous iteration is marked as freeable from GPU memory (via unsafe_free!) (no-op for a CPU device).
The conversion follows the same semantics as dev(<item from iterator>).
Similarity to CUDA.CuIterator
The design inspiration was taken from CUDA.CuIterator and was generalized to work with other backends and more complex iterators (using Functors).
MLUtils.DataLoader
Calling dev(::MLUtils.DataLoader) will automatically convert the dataloader to use the same semantics as DeviceIterator. This is generally preferred over looping over the dataloader directly and transferring the data to the device.
Examples
The following was run on a computer with an NVIDIA GPU.
julia
julia> using MLDataDevices, MLUtils

julia> X = rand(Float64, 3, 33);

julia> dataloader = DataLoader(X; batchsize=13, shuffle=false);

julia> for (i, x) in enumerate(dataloader)
           @show i, summary(x)
       end
(i, summary(x)) = (1, "3×13 Matrix{Float64}")
(i, summary(x)) = (2, "3×13 Matrix{Float64}")
(i, summary(x)) = (3, "3×7 Matrix{Float64}")

julia> for (i, x) in enumerate(CUDADevice()(dataloader))
           @show i, summary(x)
       end
(i, summary(x)) = (1, "3×13 CuArray{Float32, 2, CUDA.DeviceMemory}")
(i, summary(x)) = (2, "3×13 CuArray{Float32, 2, CUDA.DeviceMemory}")
(i, summary(x)) = (3, "3×7 CuArray{Float32, 2, CUDA.DeviceMemory}")
LuxCore.jl defines the abstract layers for Lux. It allows users to be compatible with the entirety of Lux.jl without taking on a heavy dependency. If you are depending on Lux.jl directly, you do not need to depend on LuxCore.jl (all the functionality is exported via Lux.jl).
Users implementing their custom layer must implement
initialparameters(rng::AbstractRNG, layer::CustomAbstractLuxLayer) – This returns a NamedTuple containing the trainable parameters for the layer.
initialstates(rng::AbstractRNG, layer::CustomAbstractLuxLayer) – This returns a NamedTuple containing the current state for the layer. For most layers this is typically empty. Layers that would potentially contain this include BatchNorm, LSTM, GRU, etc.
Optionally:
parameterlength(layer::CustomAbstractLuxLayer) – These can be automatically calculated, but it is recommended that the user defines these.
statelength(layer::CustomAbstractLuxLayer) – These can be automatically calculated, but it is recommended that the user defines these.
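A minimal sketch of a custom layer implementing this interface (names are illustrative):

julia
using LuxCore, Random

struct MyDense <: LuxCore.AbstractLuxLayer
    in_dims::Int
    out_dims::Int
end

function LuxCore.initialparameters(rng::AbstractRNG, l::MyDense)
    return (; weight=randn(rng, Float32, l.out_dims, l.in_dims),
            bias=zeros(Float32, l.out_dims))
end

LuxCore.initialstates(::AbstractRNG, ::MyDense) = NamedTuple()

# Optional, but avoids traversing the parameters/states to count them.
LuxCore.parameterlength(l::MyDense) = l.out_dims * (l.in_dims + 1)
LuxCore.statelength(::MyDense) = 0

(l::MyDense)(x, ps, st) = (ps.weight * x .+ ps.bias, st)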
Additionally, on calling initialparameters and initialstates, the parameters and states are not wrapped in a NamedTuple with the same name as the field.
As a convenience, we define the fallback call (l::AbstractLuxWrapperLayer)(x, ps, st), which calls getfield(l, layer)(x, ps, st).
abstract type AbstractLuxContainerLayer{layers} <: AbstractLuxLayer
Abstract Container Type for certain Lux Layers. layers is a tuple containing fieldnames for the layer, and constructs the parameters and states using those.
Users implementing their custom layer can extend the same functions as in AbstractLuxLayer.
Advanced Structure Manipulation
Advanced structure manipulation of these layers post construction is possible via Functors.fmap. For a more flexible interface, we recommend using Lux.Experimental.@layer_map.
fmap Support
fmap support needs to be explicitly enabled by loading Functors.jl and Setfield.jl.
Changes from Pre-1.0 Behavior
Previously if layers was a singleton tuple, initialparameters and initialstates would return the parameters and states for the single field layers. From v1.0.0 onwards, even for singleton tuples, the parameters/states are wrapped in a NamedTuple with the same name as the field. See AbstractLuxWrapperLayer to replicate the previous behavior of singleton tuples.
In most cases this function simply calls model(x, ps, st). However, it is still recommended to call apply instead of model(x, ps, st) directly. Some of the reasons for this include:
For certain types of inputs x, we might want to perform preprocessing before calling model. For example, if x is an Array of ReverseDiff.TrackedReals, calling model(x, ps, st) directly can cause significant regressions (since it won't hit any of the BLAS dispatches). In those cases, we would automatically convert x to a ReverseDiff.TrackedArray.
Certain user-defined inputs need to be applied to specific layers, but we want the datatype to propagate through all the layers (even unsupported ones). In these cases, we can unpack the input in apply, pass it to the appropriate layer, and then repack it before returning. See the Lux manual on Custom Input Types for a motivating example.
Tip
apply is integrated with DispatchDoctor.jl, which allows automatic verification of type stability. By default this is disabled. For more information, see the documentation.
Calls apply and only returns the first argument. This function requires that model has an empty state of NamedTuple(). Behavior of other kinds of models is undefined, and it is the responsibility of the user to ensure that the model has an empty state.
Recursively update all occurrences of the key in the state st with the value. exclude is a function that is passed to Functors.fmap_with_path's exclude keyword.
The fallback implementation of this function assumes the inputs were batched, i.e., if any of the outputs are Arrays with ndims(A) > 1, it will return size(A)[1:(end - 1)]. If this behavior is undesirable, provide a custom outputsize(layer, x, rng) implementation.
Fallback Implementation
The fallback implementation of this function is defined once Lux.jl is loaded.
Changes from Pre-1.0 Behavior
Previously it was possible to override this function by defining outputsize(layer). However, this can potentially introduce a bug that is hard to bypass. See this PR for more information.
Compute σ.(x) with the best possible implementation available. On CPUs we unroll the loop and use LoopVectorization.jl to vectorize the computation. On GPUs we simply use broadcasting.
Note
This function doesn't replace σ with NNlib.fast_act(σ, ...), that needs to be done by the user if needed.
fast_activation!!(σ::F, x::AbstractArray) where {F}
Compute σ.(x) with the best possible implementation available. If it is possible to rewrite x in-place, it does so. If x is an immutable array, it falls back to the generic implementation.
Note
This function doesn't replace σ with NNlib.fast_act(σ, ...), that needs to be done by the user if needed.
Load SLEEFPirates.jl to get faster activations
Certain activation functions are replaced with specialized implementations from SLEEFPirates.jl for FP32. This might lead to faster performance but can cause slight decrease in accuracy (in the floating point limit).
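A short sketch of both variants:

julia
using LuxLib

x = randn(Float32, 128, 32)
y = fast_activation(tanh, x)      # out-of-place
y2 = fast_activation!!(tanh, x)   # may overwrite x; use the returned value, not x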
Computes the batched matrix multiplication of x and y. For more details see the NNlib documentation on NNlib.batched_mul. This function is mostly a wrapper around batched_mul but attempts to be faster on CPUs.
Load LoopVectorization.jl to get faster batched matrix multiplication
On CPUs loading LoopVectorization adds faster implementations of batched matrix multiplication.
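A small sketch (sizes are illustrative):

julia
using LuxLib

x = randn(Float32, 4, 3, 8)   # batch of eight 4×3 matrices
y = randn(Float32, 3, 5, 8)   # batch of eight 3×5 matrices
z = batched_matmul(x, y)      # size (4, 5, 8)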
Applies the activation function σ elementwise to the result of broadcasted addition of x and bias along the penultimate dimension. A vector x is treated as a matrix with a single last dimension.
Same as bias_activation but might update x in-place if possible. Users should not rely on x being mutated, it is recommended to use it like y = bias_activation!!(σ, x, bias). If x is updated in-place, y aliases x.
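A small sketch (the bias length matches the penultimate dimension of x):

julia
using LuxLib, NNlib

x = randn(Float32, 4, 8)
bias = randn(Float32, 4)
y = bias_activation(relu, x, bias)     # relu.(x .+ bias)
y2 = bias_activation!!(relu, x, bias)  # may mutate x; use the returned value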
fused_conv_bias_activation(σ::F, weight::AbstractArray, x::AbstractArray, b::Optional{<:AbstractVector}, cdims::ConvDims) where {F}
Computes σ.(conv(x, weight, cdims) .+ b) (b is not exactly broadcasted like this, rather it is reshaped and broadcasted to the penultimate dimension) with the best possible implementation available. This operation fuses operations into a single kernel if possible, and minimizes reallocations by reusing the output buffer for multiple operations.
Arguments
σ: Activation function
weight: Weight tensor
x: Input tensor
b: Bias tensor (can be nothing)
cdims: ConvDims object
Notes on implementation
For CUDA Arrays, this uses fused CUDNN kernels when the activation is identity or relu. For other activations, it tries to fuse the operations on the Julia side.
If any of the inputs don't support setindex! (i.e., immutable arrays), we fall back to the generic non-mutating implementation.
Maximum memory reuse and operation fusion is guaranteed for ChainRules compatible AD backends or backends that support mutation. Backends like Tracker and ReverseDiff fallback to the generic implementation.
For Mixed-Precision Inputs on GPU, we type promote the inputs to the highest precision, with a warning.
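A hedged sketch; DenseConvDims comes from NNlib and the sizes are illustrative:

julia
using LuxLib, NNlib

weight = randn(Float32, 3, 3, 4, 8)   # 3×3 kernel, 4 → 8 channels
x = randn(Float32, 16, 16, 4, 2)      # WHCN input, batch of 2
b = randn(Float32, 8)
cdims = DenseConvDims(x, weight)
y = fused_conv_bias_activation(relu, weight, x, b, cdims)  # size (14, 14, 8, 2)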
alpha_dropout(rng::AbstractRNG, x, p, training)
alpha_dropout(rng::AbstractRNG, x, p, training, α, A, B)
Alpha Dropout: Dropout ensuring that the mean and variance of the output remain the same as those of the input. For details see [1]. Use the second call signature to avoid recomputing the constants for a fixed dropout probability.
Arguments
rng: Random number generator
x: Input Array
p: Probability of an element to be dropped out
training: Set to Val(true) or True() if running in training mode. Can be set to nothing to automatically determine if the function is being called within an autodiff context
α: -1.7580993408473766. Computed as the limit of selu(x) as x tends to -∞, i.e. α = -λ_selu × α_selu.
A: Scaling factor for the mean
B: Scaling factor for the variance
Returns
Output Array after applying alpha dropout
Updated state for the random number generator
References
[1] Klambauer, Günter, et al. "Self-normalizing neural networks." Advances in neural information processing systems 30 (2017).
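A short sketch using the first call signature:

julia
using LuxLib, Random

rng = Random.default_rng()
x = randn(Float32, 10, 4)
y, rng_new = alpha_dropout(rng, x, 0.5f0, Val(true))   # training mode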
Dropout: A Simple Way to Prevent Neural Networks from Overfitting. For details see [1].
Arguments
rng: Random number generator
x: Input Array
mask: Dropout Mask. If not used then it is constructed automatically
p: Probability of an element to be dropped out
training: Set to Val(true) or True() if running in training mode. Can be set to nothing to automatically determine if the function is being called within an autodiff context
update_mask: If Val(true) or True() then the mask is generated and used. Else, the mask provided is directly used
invp: Inverse multiplied to the mask. Calculated as invp = 1 / (1 - p).
Returns
Output Array after applying dropout
Dropout Mask (if training == false, the returned value is meaningless)
Updated state for the random number generator
References
[1] Srivastava, Nitish, et al. "Dropout: a simple way to prevent neural networks from overfitting." The journal of machine learning research 15.1 (2014): 1929-1958.
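A hedged sketch following the argument list above (the values for p, invp, and dims are illustrative):

julia
using LuxLib, Random

rng = Random.default_rng()
x = randn(Float32, 10, 4)
p = 0.5f0
y, mask, rng_new = dropout(rng, x, p, Val(true), 1 / (1 - p), :)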
fused_dense_bias_activation(σ::F, weight::AbstractMatrix, x::AbstractMatrix,
+ b::Optional{<:AbstractVector}) where {F}
Compute σ.(weight * x .+ b) with the best possible implementation available. Currently this implementation attempts to minimize reallocations by reusing the output buffer for multiple operations.
Arguments
σ: Activation function
weight: Weight matrix
x: Input matrix
b: Bias vector (can be nothing)
Notes on implementation
If any of the inputs don't support setindex! (i.e., immutable arrays), we fall back to the generic non-mutating implementation.
Maximum memory reuse and operation fusion is guaranteed for ChainRules compatible AD backends or backends that support mutation. Backends like Tracker and ReverseDiff fallback to the generic implementation.
For CUDA Arrays, this uses a special fused implementation via cuBLASLt.
For small CPU Arrays, we use LoopVectorization.jl. On x86_64 we use Octavian for medium sized matrices. This is overridden if special BLAS implementations are loaded (currently MKL, AppleAccelerate, and BLISBLAS).
Load Octavian.jl
Loading Octavian.jl enables a polyalgorithm that uses different backends based on the input sizes.
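A small sketch (this computes approximately relu.(weight * x .+ b)):

julia
using LuxLib, NNlib

weight = randn(Float32, 8, 4)
x = randn(Float32, 4, 16)
b = randn(Float32, 8)
y = fused_dense_bias_activation(relu, weight, x, b)   # size (8, 16)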
Batch Normalization computes the mean and variance for each input slice and normalises the input accordingly.
Arguments
x: Input to be Normalized
scale: Scale factor (γ) (can be nothing)
bias: Bias factor (β) (can be nothing)
running_mean: Running mean (can be nothing)
running_var: Running variance (can be nothing)
training: Set to Val(true) or True() if running in training mode. Can be set to nothing to automatically determine if the function is being called within an autodiff context
σ: Activation function (default: identity)
momentum: Momentum for updating running mean and variance (default: 0.1f0)
epsilon: Value added to the denominator for numerical stability (default: eps(eltype(x)) ^ (5 / 7))
Returns
Normalized array of the same size as x, and a NamedTuple containing the updated running mean and variance.
References
[1] Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating deep network training by reducing internal covariate shift." International conference on machine learning. PMLR, 2015.
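A hedged sketch following the argument order listed above (features × batch input):

julia
using LuxLib, Random

x = randn(Float32, 4, 8)
scale, bias = ones(Float32, 4), zeros(Float32, 4)
rμ, rσ² = zeros(Float32, 4), ones(Float32, 4)
y, stats = batchnorm(x, scale, bias, rμ, rσ², Val(true), identity, 0.1f0, 1f-5)
# stats holds the updated running mean and variance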
This op is similar to batch normalization, but statistics are shared across equally-sized groups of channels and not shared across batch dimension. Thus, group normalization does not depend on the batch composition and does not require maintaining internal state for storing statistics.
Arguments
x: Input to be Normalized
scale: Scale factor (γ) (can be nothing)
bias: Bias factor (β) (can be nothing)
groups: Number of groups
σ: Activation function (default: identity)
epsilon: Value added to the denominator for numerical stability (default: eps(eltype(x)) ^ (5 / 7))
Returns
The normalized array is returned.
References
[1] Wu, Yuxin, and Kaiming He. "Group normalization." Proceedings of the European conference on computer vision (ECCV). 2018.
Instance Normalization computes the mean and variance for each input slice and normalises the input accordingly.
Arguments
x: Input to be Normalized (must be at least 3D)
scale: Scale factor (γ) (can be nothing)
bias: Bias factor (β) (can be nothing)
running_mean: Running mean (can be nothing)
running_var: Running variance (can be nothing)
training: Set to Val(true) or True() if running in training mode. Can be set to nothing to automatically determine if the function is being called within an autodiff context
σ: Activation function (default: identity)
epsilon: Value added to the denominator for numerical stability (default: eps(eltype(x)) ^ (5 / 7))
momentum: Momentum for updating running mean and variance (default: 0.1f0)
Returns
Normalized array of the same size as x, and a NamedTuple containing the updated running mean and variance.
References
[1] Ulyanov, Dmitry, Andrea Vedaldi, and Victor Lempitsky. "Instance normalization: The missing ingredient for fast stylization." arXiv preprint arXiv:1607.08022 (2016).
Layer Normalization. Given an input array x, this computes y = (x − mean(x; dims)) / sqrt(var(x; dims) + epsilon) .* scale .+ bias, and applies the activation function σ elementwise to y.
Arguments
x: Input to be Normalized
scale: Scale factor (γ) (can be nothing)
bias: Bias factor (β) (can be nothing)
σ: Activation function (default: identity)
dims: Dimensions along which the mean and std of x is computed. If nothing is passed, the dims are inferred based on the dimensions of scale and bias. For example, if x is N dimensional and scale and bias are M dimensional, then the dims will be 1:(N - M).
epsilon: Value added to the denominator for numerical stability (default: eps(eltype(x)) ^ (5 / 7))
Returns
Normalized array of the same size as x.
References
[1] Ba, Jimmy Lei, Jamie Ryan Kiros, and Geoffrey E. Hinton. "Layer normalization." arXiv preprint arXiv:1607.06450 (2016).
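A hedged sketch following the argument list above, normalizing each column of a matrix:

julia
using LuxLib

x = randn(Float32, 4, 8)
scale, bias = ones(Float32, 4), zeros(Float32, 4)
y = layernorm(x, scale, bias, identity, 1, 1f-5)   # mean/var computed over dims = 1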
Returns the internal operation mode for the given array(s). This is useful to define custom implementations using different backends like simple Julia broadcasting, Kernel Abstractions, Loop Vectorization, etc.
Currently supported modes are:
GenericBroadcastOp: This is the fallback for most types. For the following types this is the preferred mode:
Arrays with fast_scalar_indexing set to false.
Static Arrays
ReverseDiff Arrays
Tracker Arrays
ForwardDiff.Dual Arrays
GPUBroadcastOp{dev}: GPU Arrays where dev is obtained from get_device_type(xs). Dispatches on this option should preferably use KernelAbstractions.jl or specialized vendor kernels.
LoopedArrayOp: CPU arrays that can be optimized using SIMD Loops, ideally using LoopVectorization.jl or Polyester.jl.
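A short sketch (the returned mode depends on the array type and the loaded packages):

julia
using LuxLib

x = rand(Float32, 8, 8)
mode = LuxLib.internal_operation_mode(x)   # e.g. LoopedArrayOp() for a plain CPU Array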
glorot_normal([::AbstractRNG=Utils.default_rng()], [T=Float32], size...; gain = 1) -> AbstractArray{T, length(size)}
Return an AbstractArray{T} of the given size containing random numbers drawn from a normal distribution with standard deviation gain * sqrt(2 / (fan_in + fan_out)). This method is described in [1] and also known as Xavier initialization.
References
[1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep feedforward neural networks." Proceedings of the thirteenth international conference on artificial intelligence and statistics. 2010.
glorot_uniform([::AbstractRNG=Utils.default_rng()], [T=Float32], size...; gain = 1) -> AbstractArray{T, length(size)}
Return an AbstractArray{T} of the given size containing random numbers drawn from a uniform distribution on the interval [-x, x], where x = gain * sqrt(6 / (fan_in + fan_out)). This method is described in [1] and also known as Xavier initialization.
References
[1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep feedforward neural networks." Proceedings of the thirteenth international conference on artificial intelligence and statistics. 2010.
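A small sketch of both initializers (fan_in and fan_out are derived from the size):

julia
using WeightInitializers, Random

W = glorot_uniform(Xoshiro(0), Float32, 3, 4)   # fan_out = 3, fan_in = 4
Wn = glorot_normal(Float32, 3, 4; gain=0.5)     # default RNG, custom gain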
Constructs an array that aims to provide an identity mapping when used as parameters in most layers of a neural network. The identity mapping is scaled by the gain parameter.
Behavior
1D: Returns a Vector of zeros (useful for biases in layers where input_size == output_size).
2D: Returns an identity matrix (useful for fully connected layers with equal input and output sizes).
More than 2D: Returns a tensor where the central slice along the last two dimensions is an identity matrix, and the rest are zeros (useful for convolutional layers, simulating an identity convolution).
Caveats
Not all layers will result in an identity mapping when using this initializer. Exceptions include recurrent and normalization layers.
Layers must have input_size == output_size for a perfect identity mapping. In cases where this condition is not met, the function pads extra dimensions with zeros.
For convolutional layers to achieve an identity mapping, kernel sizes must be odd, and appropriate padding must be applied to ensure the output feature maps are the same size as the input feature maps.
Arguments
rng::AbstractRNG: An optional random number generator, included for consistency with other initializers but ignored since the output is deterministic.
T::Type{<:Number}: The numeric type of the array elements.
size...: The dimensions of the array to be initialized.
gain::Number=1: A scaling factor applied to the identity mapping.
shift::Union{Integer, Tuple{Integer, Integer}}=0: An integer or a tuple specifying the circular shift applied to the output array.
Returns
AbstractArray{T}: An array initialized to represent an identity mapping, scaled by gain and optionally shifted by shift.
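A small sketch of the dense and convolutional cases:

julia
using WeightInitializers, Random

W = identity_init(Xoshiro(0), Float32, 3, 3; gain=2)   # 2 .* identity matrix
K = identity_init(Float32, 3, 3, 4, 4)                 # identity-like 3×3 conv kernel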
kaiming_normal([::AbstractRNG=Utils.default_rng()], [T=Float32], size...; gain = √T(2)) -> AbstractArray{T, length(size)}
Return an AbstractArray{T} of the given size containing random numbers drawn from a normal distribution with standard deviation gain / sqrt(fan_in).
References
[1] He, Kaiming, et al. "Delving deep into rectifiers: Surpassing human-level performance on imagenet classification." Proceedings of the IEEE international conference on computer vision. 2015.
kaiming_uniform([::AbstractRNG=Utils.default_rng()], [T=Float32], size...; gain = √T(2)) -> AbstractArray{T, length(size)}
Return an AbstractArray{T} of the given size containing random numbers drawn from a uniform distribution on the interval [-x, x], where x = gain * sqrt(3/fan_in).
References
[1] He, Kaiming, et al. "Delving deep into rectifiers: Surpassing human-level performance on imagenet classification." Proceedings of the IEEE international conference on computer vision. 2015.
Creates a sparsely initialized weight matrix with a specified proportion of zeroed elements, using random numbers drawn from a normal distribution for the non-zero elements. This method was introduced in [1].
Note
The sparsity parameter controls the proportion of the matrix that will be zeroed. For example, a sparsity of 0.3 means that approximately 30% of the elements will be set to zero. The non-zero elements are distributed according to a normal distribution, scaled by the std parameter.
Arguments
rng::AbstractRNG: The random number generator to use.
T::Type{<:Number}: The numeric type of the elements in the returned array.
dims::Integer...: The dimensions of the weight matrix to be generated.
sparsity::Number: The proportion of elements to be zeroed. Must be between 0 and 1.
std::Number=0.01: The standard deviation of the normal distribution before applying gain.
Returns
AbstractArray{T}: A sparsely initialized weight matrix of dimensions dims and type T.
Examples
julia
julia> y = sparse_init(Xoshiro(123), Float32, 5, 5; sparsity=0.3, std=0.01);

julia> y isa Matrix{Float32}
true

julia> size(y) == (5, 5)
true
References
[1] Martens, J, "Deep learning via Hessian-free optimization" Proceedings of the 27th International Conference on International Conference on Machine Learning. 2010.
truncated_normal([::AbstractRNG=Utils.default_rng()], [T=Float32], size...; mean = 0, std = 1, lo = -2, hi = 2) -> AbstractArray{T, length(size)}
Return an AbstractArray{T} of the given size where each element is drawn from a truncated normal distribution. The numbers are distributed like filter(x -> lo ≤ x ≤ hi, mean .+ std .* randn(100)).
orthogonal([::AbstractRNG=Utils.default_rng()], [T=Float32], dims::Integer...; gain = 1) -> AbstractArray{T, length(dims)}
Return an AbstractArray{T} of the given dimensions (dims) which is a (semi) orthogonal matrix, as described in [1].
The function constructs an orthogonal or semi-orthogonal matrix depending on the specified dimensions. For two dimensions, it returns a matrix where dims = (rows, cols). For more than two dimensions, it computes an orthogonal matrix of size prod(dims[1:(end - 1)]) by dims[end] before reshaping it to the original dimensions.
Cannot construct a vector, i.e., length(dims) == 1 is forbidden.
Arguments
rng::AbstractRNG: Random number generator.
T::Type{<:Real}: The type of the elements in the array.
dims::Integer...: The dimensions of the array.
gain::Number: Scaling factor for the elements of the orthogonal matrix.
References
[1] Saxe, McClelland, Ganguli. "Exact solutions to the nonlinear dynamics of learning in deep linear neural networks", ICLR 2014, https://arxiv.org/abs/1312.6120
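A small sketch verifying (semi) orthogonality:

julia
using WeightInitializers, Random, LinearAlgebra

W = orthogonal(Xoshiro(0), Float32, 4, 4)
W' * W ≈ I   # true up to floating-point error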
Compute the Jacobian-Vector Product (∂f/∂x) u. This is a wrapper around AD backends but allows us to compute gradients of Jacobian-vector products efficiently using mixed-mode AD.
Backends & AD Packages
Supported backends (and packages needed): AutoForwardDiff (no additional packages required).
Warning
Gradient wrt u in the reverse pass is always dropped.
Arguments
f: The function to compute the jacobian of.
backend: The backend to use for computing the JVP.
Compute the Vector-Jacobian Product (∂f/∂x)ᵀ u. This is a wrapper around AD backends but allows us to compute gradients of vector-Jacobian products efficiently using mixed-mode AD.
Backends & AD Packages
Supported backends (and packages needed): AutoZygote (requires Zygote.jl).
Warning
Gradient wrt u in the reverse pass is always dropped.
Arguments
f: The function to compute the jacobian of.
backend: The backend to use for computing the VJP.
Computes the Jacobian of a function f with respect to a batch of inputs x. This expects the following properties for y = f(x):
ndims(y) ≥ 2
size(y, ndims(y)) == size(x, ndims(x))
Backends & AD Packages
Supported backends (and packages needed): AutoForwardDiff (no additional packages required); AutoZygote (requires Zygote.jl).
Arguments
f: The function to compute the jacobian of.
backend: The backend to use for computing the jacobian.
x: The input to the function. Must have ndims(x) ≥ 2.
Returns
J: The Jacobian of f with respect to x. This will be a 3D Array. If the dimensions of x are (N₁, N₂, ..., Nₙ, B) and of y are (M₁, M₂, ..., Mₘ, B), then J will be a ((M₁ × M₂ × ... × Mₘ), (N₁ × N₂ × ... × Nₙ), B) Array.
Danger
f(x) must not mix data across batch dimensions, else the result will be incorrect. For example, if f contains operations like batch normalization, then the result will be incorrect.
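A hedged sketch (assumes AutoForwardDiff, re-exported from ADTypes, is in scope):

julia
using Lux

f(x) = x .^ 2                                  # acts elementwise, so batches don't mix
x = randn(Float32, 3, 4)                       # 3 features, batch of 4
J = batched_jacobian(f, AutoForwardDiff(), x)  # size (3, 3, 4)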
All features listed on this page are experimental which means:
No SemVer Guarantees. We use code here to iterate fast. That said, historically we have never broken any code in this module and have always provided a deprecation period.
Expect edge-cases and report them. It will help us move these features out of experimental sooner.
Freeze the parameters with name which_params of the layer l.
Use Lux.Experimental.freeze instead
It is always recommended to use the Lux.Experimental.freeze function instead of directly using the FrozenLayer constructor.
No checks for which_params
There are no checks for which_params. For example, if the original layer has parameters named (:weight, :bias), and which_params is set to (:myweight,) then none of the parameters are frozen and no error is thrown.
Arguments
l: Lux AbstractLuxLayer.
which_params: Parameter Names to be Frozen. Can be set to nothing, in which case all parameters are frozen.
Extended Help
Parameters
Parameters of the layer l excluding which_params.
States
frozen_params: Parameters that are frozen, i.e., which_params.
states: The state of the inner layer l.
Note on Internal Layer Implementation
The inner layer should work with NamedTuple parameters. In order to support custom parameter types, users need to implement Lux.Utils.merge(::CustomParamType, ::NamedTuple) or extend Lux.Utils.named_tuple(::CustomParamType) to return a NamedTuple.
Map the function f over the model l, with the parameters ps and states st. This is different from Functors.fmap since it zips the layers, parameters, and states and invokes the function on all of them together.
KeyPath provided to the function
The KeyPath depends on the structure of the parameters and states. This is of consequence exclusively for AbstractLuxWrapperLayer, where the structure of the layer doesn't match the structure of the parameters and states. In the example provided below, the KeyPath is (:chain, :dense_1) for the first layer (following the structure in ps), while accessing the same layer in the chain is done with (:chain, :layers, :dense_1).
Call Signature for f
Must take 4 inputs – AbstractLuxLayer, Corresponding Parameters, Corresponding States, and the Functors.KeyPath to the layer.
Must return a tuple of 3 elements – AbstractLuxLayer, new parameters and the new states.
Extended Help
Example
julia
julia> using Lux, Random

julia> c = Parallel(
           +; chain=Chain(; dense_1=Dense(2 => 3), bn=BatchNorm(3), dense_2=Dense(3 => 5)),
           dense_3=Dense(5 => 1));

julia> rng = Random.default_rng();

julia> ps, st = Lux.setup(rng, c);

julia> # Makes parameters of Dense Layers inside Chain zero
       function zero_dense_params(l, ps, st, name)
           if l isa Dense
               println("zeroing params of $name")
               ps = merge(ps, (; weight=zero.(ps.weight), bias=zero.(ps.bias)))
           end
           return l, ps, st
       end;

julia> _, ps_new, _ = Lux.Experimental.layer_map(zero_dense_params, c, ps, st);
zeroing params of KeyPath(:chain, :dense_1)
zeroing params of KeyPath(:chain, :dense_2)
zeroing params of KeyPath(:dense_3,)

julia> all(iszero, (ps_new.chain.dense_1.weight, ps_new.chain.dense_1.bias,
                    ps_new.chain.dense_2.weight, ps_new.chain.dense_2.bias,
                    ps_new.dense_3.weight, ps_new.dense_3.bias))
true
A wrapper over Lux layers that adds checks for NaNs and errors. This is useful for debugging.
Arguments
layer: The layer to be wrapped.
Extended Help
Keyword Arguments
nan_check: Whether to check for NaNs in the input, parameters, and states. Can be :both, :forward, :backward, or :none.
error_check: Whether to check for errors in the layer. If true, will throw an error if the layer fails.
location: The location of the layer. Use Lux.Experimental.@debug_mode to construct this layer to populate this value correctly.
Input / Output
Inputs and outputs are the same as the layer unless one of the nan_check or error_check criteria is met.
If nan_check is enabled and NaNs are detected then a DomainError is thrown. If error_check is enabled, then any errors in the layer are thrown with useful information to track where the error originates.
ChainRules Compatible Reverse Mode AD Tools
nan_check for the backward mode only works with ChainRules Compatible Reverse Mode AD Tools currently.
Disable After Debugging
This layer is only meant to be used for debugging. If used for actual training or inference, it will lead to extremely bad performance.
Updates the parameters in ps with a common set of parameters new_parameters that are shared between each list in the nested list sharing. (That was kind of a mouthful, the example should make it clear).
Arguments
ps: Original parameters.
sharing: A nested list of lists of accessors of ps which need to share the parameters (see the example for details). (Each list in the list must be disjoint)
new_parameters: If passed the length of new_parameters must be equal to the length of sharing. For each vector in sharing the corresponding parameter in new_parameters will be used. (If not passed, the parameters corresponding to the first element of each vector in sharing will be used).
Returns
Updated Parameters having the same structure as ps.
ComponentArrays doesn't allow sharing parameters. Converting the returned parameters to a ComponentArray will silently cause the parameter sharing to be undone.
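A hedged sketch (layer names and the sharing spec are illustrative):

julia
using Lux, Random

model = Chain(; d1=Dense(4 => 4), d2=Dense(4 => 4))
ps, st = Lux.setup(Random.default_rng(), model)

# Tie d2's weight to d1's weight; the first entry of each group supplies the value.
ps = Lux.Experimental.share_parameters(ps, (("d1.weight", "d2.weight"),))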
Initialize the given backend. Users can supply cuda_devices and amdgpu_devices to initialize the backend with the given devices. These can be set to missing to prevent initialization of the given device type. If set to nothing, and the backend is functional we assign GPUs in a round-robin fashion. Finally, a list of integers can be supplied to initialize the backend with the given devices.
Possible values for backend are:
MPIBackend: MPI backend for distributed training. Requires MPI.jl to be installed.
NCCLBackend: NCCL backend for CUDA distributed training. Requires CUDA.jl, MPI.jl, and NCCL.jl to be installed. This also wraps MPI backend for non-CUDA communications.
Get the distributed backend for the given backend type. Possible values are:
MPIBackend: MPI backend for distributed training. Requires MPI.jl to be installed.
NCCLBackend: NCCL backend for CUDA distributed training. Requires CUDA.jl, MPI.jl, and NCCL.jl to be installed. This also wraps MPI backend for non-CUDA communications.
Danger
initialize(backend; kwargs...) must be called before calling this function.
Backend Agnostic API to broadcast the given buffer sendrecvbuf or sendbuf to all workers into recvbuf. The value at root will be broadcasted to all other workers.
data must be compatible with MLUtils interface. The returned container is compatible with MLUtils interface and is used to partition the dataset across the available processes.
Load MLUtils.jl
MLUtils.jl must be installed and loaded before using this.
This always ignores the active field of some of the Flux layers. This is almost never going to be supported.
Keyword Arguments
preserve_ps_st: Set to true to preserve the states and parameters of the layer. This attempts the best possible way to preserve the original model. But it might fail. If you need to override possible failures, set force_preserve to true.
force_preserve: Some of the transformations with state and parameters preservation haven't been implemented yet. In these cases, if force_preserve is false, a warning will be printed and a core Lux layer will be returned. Else, it will create a FluxLayer.
Example
julia
julia> import Flux

julia> using Adapt, Lux, Random

julia> m = Flux.Chain(Flux.Dense(2 => 3, relu), Flux.Dense(3 => 2));

julia> m2 = adapt(FromFluxAdaptor(), m); # or FromFluxAdaptor()(m.layers)

julia> x = randn(Float32, 2, 32);

julia> ps, st = Lux.setup(Random.default_rng(), m2);

julia> size(first(m2(x, ps, st)))
(2, 32)
SimpleChains.jl provides a way to train Small Neural Networks really fast on CPUs. See this blog post for more details. This section describes how to convert Lux models to SimpleChains models while preserving the layer interface.
Tip
Accessing these functions requires manually loading SimpleChains, i.e., using SimpleChains must be present somewhere in the code for these to be used.
Adaptor for converting a Lux Model to SimpleChains. The returned model is still a Lux model and satisfies the AbstractLuxLayer interface, but all internal calculations are performed using SimpleChains.
Warning
There is no way to preserve trained parameters and states when converting to SimpleChains.jl.
Warning
Any kind of initialization function is not preserved when converting to SimpleChains.jl.
Arguments
input_dims: Tuple of input dimensions excluding the batch dimension. These must be of static type as SimpleChains expects.
convert_to_array: SimpleChains.jl by default outputs StrideArraysCore.StrideArray, but this might not compose well with other packages. If convert_to_array is set to true, the output will be converted to a regular Array.
Create a layer which passes an input to each path in layers, before reducing the output with connection.
Arguments
connection: An N-argument function that is called after passing the input through each layer. If connection = nothing, we return a tuple Parallel(nothing, f, g)(x, y) = (f(x), g(y))
Layers can be specified in two formats:
A list of N Lux layers
Specified as N keyword arguments.
Extended Help
Inputs
x: If x is not a tuple, then the output is computed as connection([l(x) for l in layers]...). Otherwise, each element of the tuple is passed to the corresponding layer, thus Parallel(+, f, g)(x, y) = f(x) + g(y).
Returns
See the Inputs section for how the output is computed
Updated state of the layers
Parameters
Parameters of each layer wrapped in a NamedTuple with fields = layer_1, layer_2, ..., layer_N (naming changes if using the kwargs API)
States
States of each layer wrapped in a NamedTuple with fields = layer_1, layer_2, ..., layer_N (naming changes if using the kwargs API)
See also SkipConnection which is Parallel with one identity.
Create a skip connection which consists of a layer or Chain of consecutive layers and a shortcut connection linking the block's input to the output through a user-supplied 2-argument callable. The first argument to the callable will be propagated through the given layer while the second is the unchanged, "skipped" input.
The simplest "ResNet"-type connection is just SkipConnection(layer, +).
Arguments
layer: Layer or Chain of layers to be applied to the input
connection:
A 2-argument function that takes layer(input) and the input OR
An AbstractLuxLayer that takes (layer(input), input) as input
Extended Help
Inputs
x: Will be passed directly to layer
Returns
Output of connection(layer(input), input)
Updated state of layer
Parameters
Parameters of layer OR
If connection is an AbstractLuxLayer, then NamedTuple with fields :layers and :connection
States
States of layer OR
If connection is an AbstractLuxLayer, then NamedTuple with fields :layers and :connection
Iteratively applies model for repeats number of times. The initial input is passed into the model repeatedly if input_injection = Val(true). This layer unrolls the computation; however, semantically it is the same as:
input_injection = Val(false)
julia
res = x
for i in 1:repeats
    res, st = model(res, ps, st)
end
input_injection = Val(true)
julia
res = x
for i in 1:repeats
    res, st = model((res, x), ps, st)
end
It is expected that repeats will be a reasonable number below 20, beyond that compile times for gradients might be unreasonably high.
Arguments
model must be an AbstractLuxLayer
Keyword Arguments
repeats: Number of times to apply the model
input_injection: If true, then the input is passed to the model along with the output
Image data should be stored in WHCN order (width, height, channels, batch). In other words, a 100 x 100 RGB image would be a 100 x 100 x 3 x 1 array, and a batch of 50 would be a 100 x 100 x 3 x 50 array. This has N = 2 spatial dimensions, and needs a kernel size like (5, 5), a 2-tuple of integers. To take convolutions along N feature dimensions, this layer expects as input an array with ndims(x) == N + 2, where size(x, N + 1) == in_chs is the number of input channels, and size(x, ndims(x)) is the number of observations in a batch.
Warning
Frameworks like Pytorch perform cross-correlation in their convolution layers. Pass cross_correlation=true to use cross-correlation instead.
Arguments
k: Tuple of integers specifying the size of the convolutional kernel. Eg, for 2D convolutions length(k) == 2
in_chs: Number of input channels
out_chs: Number of output channels
activation: Activation Function
Extended Help
Keyword Arguments
init_weight: Controls the initialization of the weight parameter. If nothing, then we use kaiming_uniform with gain computed on the basis of the activation function (taken from Pytorch nn.init.calculate_gain).
init_bias: Controls the initialization of the bias parameter. If nothing, then we use uniform distribution with bounds -bound and bound where bound = inv(sqrt(fan_in)).
stride: Should be either a single integer, or a tuple with N integers
dilation: Should be either a single integer, or a tuple with N integers
pad: Specifies the number of elements added to the borders of the data array. It can be
a single integer for equal padding all around,
a tuple of N integers, to apply the same padding at begin/end of each spatial dimension,
a tuple of 2*N integers, for asymmetric padding, or
the singleton SamePad(), to calculate padding such that size(output,d) == size(x,d) / stride (possibly rounded) for each spatial dimension.
Periodic padding can be achieved by preceding the layer with a WrappedFunction(x -> NNlib.circular_pad(x, N_pad; dims=pad_dims))
groups: Expected to be an Int. It specifies the number of groups to divide a convolution into (set groups = in_chs for Depthwise Convolutions). in_chs and out_chs must be divisible by groups.
use_bias: Trainable bias can be disabled entirely by setting this to false.
cross_correlation: If true, perform cross-correlation instead of convolution. Prior to v1, Lux used to have a CrossCor layer which performed cross-correlation. This was removed in v1 in favor of Conv with cross_correlation=true.
Inputs
x: Data satisfying ndims(x) == N + 2 && size(x, N + 1) == in_chs, i.e. size(x) = (I_N, ..., I_1, C_in, B)
Returns
Output of the convolution y of size (O_N, ..., O_1, C_out, B), where O_i = floor((I_i + pad_i + pad_(i + N) − dilation_i × (k_i − 1) − 1) / stride_i) + 1 (the standard convolution output-size arithmetic).
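A short sketch of a 2D convolution (sizes are illustrative):

julia
using Lux, Random

layer = Conv((3, 3), 4 => 8, relu)   # 3×3 kernel, 4 → 8 channels
ps, st = Lux.setup(Random.default_rng(), layer)
x = randn(Float32, 16, 16, 4, 2)     # WHCN input, batch of 2
y, st = layer(x, ps, st)             # size (14, 14, 8, 2)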
k: Tuple of integers specifying the size of the convolutional kernel. Eg, for 2D convolutions length(k) == 2
in_chs: Number of input channels
out_chs: Number of output channels
activation: Activation Function
Keyword Arguments
init_weight: Controls the initialization of the weight parameter. If nothing, then we use kaiming_uniform with gain computed on the basis of the activation function (taken from Pytorch nn.init.calculate_gain).
init_bias: Controls the initialization of the bias parameter. If nothing, then we use uniform distribution with bounds -bound and bound where bound = inv(sqrt(fan_in)).
stride: Should be either a single integer, or a tuple with N integers
dilation: Should be either a single integer, or a tuple with N integers
pad: Specifies the number of elements added to the borders of the data array. It can be
a single integer for equal padding all around,
a tuple of N integers, to apply the same padding at begin/end of each spatial dimension,
a tuple of 2*N integers, for asymmetric padding, or
the singleton SamePad(), to calculate padding such that size(output,d) == size(x,d) * stride (possibly rounded) for each spatial dimension.
groups: Expected to be an Int. It specifies the number of groups to divide a convolution into (set groups = in_chs for Depthwise Convolutions). in_chs and out_chs must be divisible by groups.
use_bias: Trainable bias can be disabled entirely by setting this to false.
cross_correlation: If true, perform transposed cross-correlation instead of transposed convolution.
outpad: To preserve the invertibility of Conv when stride > 1, outpad can be used to increase the size of the output in the desired dimensions. Whereas pad is used to zero-pad the input, outpad only affects the output shape.
Extended Help
Inputs
x: Data satisfying ndims(x) == N + 2 && size(x, N + 1) == in_chs, i.e. size(x) = (I_N, ..., I_1, C_in, B)
Returns
Output of the convolution transpose y of size (O_N, ..., O_1, C_out, B), where O_i = (I_i − 1) × stride_i − (pad_i + pad_(i + N)) + dilation_i × (k_i − 1) + outpad_i + 1 (the standard transposed-convolution output-size arithmetic).
p: Probability of Dropout (if p = 0 then NoOpLayer is returned)
Keyword Arguments
To apply dropout along certain dimension(s), specify the dims keyword. e.g. Dropout(p; dims = (3,4)) will randomly zero out entire channels on WHCN input (also called 2D dropout).
Inputs
x: Must be an AbstractArray
Returns
x with dropout mask applied if training=Val(true) else just x
State with updated rng
States
rng: Pseudo Random Number Generator
training: Used to check if training/inference mode
p: Probability of Dropout (if p = 0 then NoOpLayer is returned)
Keyword Arguments
To apply dropout along certain dimension(s), specify the dims keyword. e.g. VariationalHiddenDropout(p; dims = 3) will randomly zero out entire channels on WHCN input (also called 2D dropout).
Inputs
x: Must be an AbstractArray
Returns
x with dropout mask applied if training=Val(true) else just x
State with updated rng
States
rng: Pseudo Random Number Generator
training: Used to check if training/inference mode
mask: Dropout mask. Initially set to nothing. After every run, contains the mask applied in that call
update_mask: Stores whether new mask needs to be generated in the current call
Global LP Pooling layer. Transforms (w, h, c, b)-shaped input into (1, 1, c, b)-shaped output, by performing LP-norm pooling on the complete (w, h)-shaped feature maps.
GPU Support
This layer is currently only supported on CPU.
Inputs
x: Data satisfying ndims(x) > 2, i.e. size(x) = (I_N, ..., I_1, C, N)
Global Max Pooling layer. Transforms (w, h, c, b)-shaped input into (1, 1, c, b)-shaped output, by performing max pooling on the complete (w, h)-shaped feature maps.
Inputs
x: Data satisfying ndims(x) > 2, i.e. size(x) = (I_N, ..., I_1, C, N)
Global Mean Pooling layer. Transforms (w, h, c, b)-shaped input into (1, 1, c, b)-shaped output, by performing mean pooling on the complete (w, h)-shaped feature maps.
Inputs
x: Data satisfying ndims(x) > 2, i.e. size(x) = (I_N, ..., I_1, C, N)
train_state: Trainable initial hidden state can be activated by setting this to true
init_bias: Initializer for bias. Must be a tuple containing 3 functions. If a single value is passed, it is copied into a 3 element tuple. If nothing, then we use uniform distribution with bounds -bound and bound where bound = inv(sqrt(out_dims)).
init_weight: Initializer for weight. Must be a tuple containing 3 functions. If a single value is passed, it is copied into a 3 element tuple. If nothing, then we use uniform distribution with bounds -bound and bound where bound = inv(sqrt(out_dims)).
init_state: Initializer for hidden state
Inputs
Case 1a: Only a single input x of shape (in_dims, batch_size), train_state is set to false - Creates a hidden state using init_state and proceeds to Case 2.
Case 1b: Only a single input x of shape (in_dims, batch_size), train_state is set to true - Repeats hidden_state from parameters to match the shape of x and proceeds to Case 2.
Case 2: Tuple (x, (h, )) is provided, then the output and a tuple containing the updated hidden state is returned.
Returns
Tuple containing
Output of shape (out_dims, batch_size)
Tuple containing new hidden state
Updated model state
Parameters
weight_ih: Concatenated Weights to map from input space {W_ir, W_iz, W_in}.
weight_hh: Concatenated Weights to map from hidden space {W_hr, W_hz, W_hn}.
bias_ih: Concatenated Bias vector for the input space (not present if use_bias=false).
bias_hh: Concatenated Bias vector for the hidden space (not present if use_bias=false).
hidden_state: Initial hidden state vector (not present if train_state=false).
States
rng: Controls the randomness (if any) in the initial state generation
out_dims: Output (Hidden State & Memory) Dimension
use_bias: Set to false to deactivate bias
train_state: Trainable initial hidden state can be activated by setting this to true
train_memory: Trainable initial memory can be activated by setting this to true
init_bias: Initializer for bias. Must be a tuple containing 4 functions. If a single value is passed, it is copied into a 4 element tuple. If nothing, then we use uniform distribution with bounds -bound and bound where bound = inv(sqrt(out_dims)).
init_weight: Initializer for weight. Must be a tuple containing 4 functions. If a single value is passed, it is copied into a 4 element tuple. If nothing, then we use uniform distribution with bounds -bound and bound where bound = inv(sqrt(out_dims)).
init_state: Initializer for hidden state
init_memory: Initializer for memory
Inputs
Case 1a: Only a single input x of shape (in_dims, batch_size), train_state is set to false, train_memory is set to false - Creates a hidden state using init_state, hidden memory using init_memory and proceeds to Case 2.
Case 1b: Only a single input x of shape (in_dims, batch_size), train_state is set to true, train_memory is set to false - Repeats hidden_state vector from the parameters to match the shape of x, creates hidden memory using init_memory and proceeds to Case 2.
Case 1c: Only a single input x of shape (in_dims, batch_size), train_state is set to false, train_memory is set to true - Creates a hidden state using init_state, repeats the memory vector from parameters to match the shape of x and proceeds to Case 2.
Case 1d: Only a single input x of shape (in_dims, batch_size), train_state is set to true, train_memory is set to true - Repeats the hidden state and memory vectors from the parameters to match the shape of x and proceeds to Case 2.
Case 2: Tuple (x, (h, c)) is provided, then the output and a tuple containing the updated hidden state and memory is returned.
Returns
Tuple Containing
Output of shape (out_dims, batch_size)
Tuple containing new hidden state and new memory
Updated model state
Parameters
weight_ih: Concatenated Weights to map from input space {W_ii, W_if, W_ig, W_io}.
weight_hh: Concatenated Weights to map from hidden space {W_hi, W_hf, W_hg, W_ho}.
bias_ih: Concatenated Bias vector for the input-hidden connection (not present if use_bias=false).
bias_hh: Concatenated Bias vector for the hidden-hidden connection (not present if use_bias=false).
hidden_state: Initial hidden state vector (not present if train_state=false)
memory: Initial memory vector (not present if train_memory=false)
States
rng: Controls the randomness (if any) in the initial state generation
An Elman RNN cell with activation (typically set to tanh or relu).
Arguments
in_dims: Input Dimension
out_dims: Output (Hidden State) Dimension
activation: Activation function
use_bias: Set to false to deactivate bias
train_state: Trainable initial hidden state can be activated by setting this to true
init_bias: Initializer for bias. If nothing, then we use uniform distribution with bounds -bound and bound where bound = inv(sqrt(out_dims)).
init_weight: Initializer for weight. If nothing, then we use uniform distribution with bounds -bound and bound where bound = inv(sqrt(out_dims)).
init_state: Initializer for hidden state
Inputs
Case 1a: Only a single input x of shape (in_dims, batch_size), train_state is set to false - Creates a hidden state using init_state and proceeds to Case 2.
Case 1b: Only a single input x of shape (in_dims, batch_size), train_state is set to true - Repeats hidden_state from parameters to match the shape of x and proceeds to Case 2.
Case 2: Tuple (x, (h, )) is provided, then the output and a tuple containing the updated hidden state is returned.
Returns
Tuple containing
Output of shape (out_dims, batch_size)
Tuple containing new hidden state
Updated model state
Parameters
weight_ih: Maps the input to the hidden state.
weight_hh: Maps the hidden state to the hidden state.
bias_ih: Bias vector for the input-hidden connection (not present if use_bias=false)
bias_hh: Bias vector for the hidden-hidden connection (not present if use_bias=false)
hidden_state: Initial hidden state vector (not present if train_state=false)
States
rng: Controls the randomness (if any) in the initial state generation
Wraps a recurrent cell (like RNNCell, LSTMCell, GRUCell) to automatically operate over a sequence of inputs.
Relation to Flux.Recur
This is completely distinct from Flux.Recur. It doesn't make the cell stateful; rather, it allows operating on an entire sequence of inputs at once. See StatefulRecurrentCell for functionality similar to Flux.Recur.
Arguments
cell: A recurrent cell. See RNNCell, LSTMCell, GRUCell, for how the inputs/outputs of a recurrent cell must be structured.
Keyword Arguments
return_sequence: If true returns the entire sequence of outputs, else returns only the last output. Defaults to false.
ordering: The ordering of the batch and time dimensions in the input. Defaults to BatchLastIndex(). Alternatively can be set to TimeLastIndex().
Extended Help
Inputs
If x is a
Tuple or Vector: Each element is fed to the cell sequentially.
Array (except a Vector): It is spliced along the penultimate dimension and each slice is fed to the cell sequentially.
Returns
Output of the cell for the entire sequence.
Update state of the cell.
Tip
Frameworks like TensorFlow have special implementations of StackedRNNCells to handle sequentially composed RNN cells. In Lux, one can simply stack multiple Recurrence blocks in a Chain to achieve the same.
To avoid undefined behavior, once the processing of a single sequence of data is complete, update the state with Lux.update_state(st, :carry, nothing).
Arguments
cell: A recurrent cell. See RNNCell, LSTMCell, GRUCell, for how the inputs/outputs of a recurrent cell must be structured.
cell: A recurrent cell. See RNNCell, LSTMCell, GRUCell, for how the inputs/outputs of a recurrent cell must be structured.
backward_cell: An optional backward recurrent cell. If backward_cell is nothing, the rnn layer instance passed as the cell argument will be used to generate the backward layer automatically. in_dims of backward_cell should be consistent with in_dims of cell.
Keyword Arguments
merge_mode: Function by which outputs of the forward and backward RNNs will be combined. Default value is vcat. If nothing, the outputs will not be combined.
ordering: The ordering of the batch and time dimensions in the input. Defaults to BatchLastIndex(). Alternatively can be set to TimeLastIndex().
Extended Help
Inputs
If x is a
Tuple or Vector: Each element is fed to the cell sequentially.
Array (except a Vector): It is spliced along the penultimate dimension and each slice is fed to the cell sequentially.
Returns
Merged output of the cell and backward_cell for the entire sequence.
Create a fully connected layer between two inputs and an output, and otherwise similar to Dense. Its output, given vectors x & y, is another vector z with, for all i in 1:out:
z[i] = activation(x' * W[i, :, :] * y + bias[i])
If x and y are matrices, then each column of the output z = B(x, y) is of this form, with B the Bilinear layer.
Arguments
in1_dims: number of input dimensions of x
in2_dims: number of input dimensions of y
in12_dims: If specified, then in1_dims = in2_dims = in12_dims
out: number of output dimensions
activation: activation function
Keyword Arguments
init_weight: initializer for the weight matrix (weight = init_weight(rng, out_dims, in1_dims, in2_dims)). If nothing, then we use uniform distribution with bounds -bound and bound where bound = inv(sqrt(in1_dims)).
init_bias: initializer for the bias vector (ignored if use_bias=false). If nothing, then we use uniform distribution with bounds -bound and bound where bound = inv(sqrt(in1_dims)).
use_bias: Trainable bias can be disabled entirely by setting this to false
Input
A 2-Tuple containing
x must be an AbstractArray with size(x, 1) == in1_dims
y must be an AbstractArray with size(y, 1) == in2_dims
If the input is an AbstractArray, then x = y
Returns
AbstractArray with dimensions (out_dims, size(x, 2))
Empty NamedTuple()
Parameters
weight: Weight Matrix of size (out_dims, in1_dims, in2_dims)
bias: Bias of size (out_dims, 1) (present if use_bias=true)
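A brief sketch of the API described above (all sizes arbitrary):
julia
using Lux, Random

layer = Bilinear((2, 3) => 4)  # in1_dims = 2, in2_dims = 3, out = 4

rng = Random.default_rng()
ps, st = Lux.setup(rng, layer)

x = randn(rng, Float32, 2, 7)
y = randn(rng, Float32, 3, 7)
z, _ = layer((x, y), ps, st)   # size(z) == (4, 7)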
Create a traditional fully connected layer, whose forward pass is given by: y = activation.(weight * x .+ bias)
Arguments
in_dims: number of input dimensions
out_dims: number of output dimensions
activation: activation function
Keyword Arguments
init_weight: initializer for the weight matrix (weight = init_weight(rng, out_dims, in_dims)). If nothing, then we use kaiming_uniform with gain computed on the basis of the activation function (taken from Pytorch nn.init.calculate_gain).
init_bias: initializer for the bias vector (ignored if use_bias=false). If nothing, then we use uniform distribution with bounds -bound and bound where bound = inv(sqrt(in_dims)).
use_bias: Trainable bias can be disabled entirely by setting this to false
Input
x must be an AbstractArray with size(x, 1) == in_dims
Returns
AbstractArray with dimensions (out_dims, ...) where ... are the dimensions of x
Empty NamedTuple()
Parameters
weight: Weight Matrix of size (out_dims, in_dims)
bias: Bias of size (out_dims, 1) (present if use_bias=true)
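A minimal usage sketch with arbitrary sizes:
julia
using Lux, Random

layer = Dense(3 => 5, relu)

rng = Random.default_rng()
ps, st = Lux.setup(rng, layer)

x = randn(rng, Float32, 3, 8)  # 8 samples with 3 features each
y, _ = layer(x, ps, st)        # size(y) == (5, 8)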
A lookup table that stores embeddings of dimension out_dims for a vocabulary of size in_dims. When the vocabulary is multi-dimensional, the input is expected to be a tuple of Cartesian indices.
This layer is often used to store word embeddings and retrieve them using indices.
Arguments
in_dims: number(s) of input dimensions
out_dims: number of output dimensions
Keyword Arguments
init_weight: initializer for the weight matrix (weight = init_weight(rng, out_dims, in_dims...))
Input
Integer OR
Abstract Vector of Integers OR
Abstract Array of Integers OR
Tuple of Integers OR
Tuple of Abstract Vectors of Integers OR
Tuple of Abstract Arrays of Integers
Returns
Returns the embedding corresponding to each index in the input. For an N dimensional input, an N + 1 dimensional output is returned.
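For example, a lookup table with a vocabulary of 26 tokens and 8-dimensional embeddings (a sketch; sizes arbitrary):
julia
using Lux, Random

layer = Embedding(26 => 8)

rng = Random.default_rng()
ps, st = Lux.setup(rng, layer)

y, _ = layer([3, 7, 11], ps, st)  # size(y) == (8, 3): one embedding per index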
Create a Sparsely Connected Layer with a very specific structure (only Diagonal Elements are non-zero). The forward pass is given by: y = activation.(weight .* x .+ bias)
Arguments
dims: size of the learnable scale and bias parameters.
activation: activation function
Keyword Arguments
init_weight: initializer for the weight matrix (weight = init_weight(rng, out_dims, in_dims))
init_bias: initializer for the bias vector (ignored if use_bias=false)
use_bias: Trainable bias can be disabled entirely by setting this to false
Input
x must be an Array of size (dims..., B) or of size (dims[1], ..., dims[k]) for k ≤ length(dims)
Returns
Array of size (dims..., B) or of size (dims[1], ..., dims[k]) for k ≤ length(dims)
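A short sketch of Scale with arbitrary sizes:
julia
using Lux, Random

layer = Scale(5, relu)  # learnable elementwise scale and bias over 5 features

rng = Random.default_rng()
ps, st = Lux.setup(rng, layer)

x = randn(rng, Float32, 5, 4)
y, _ = layer(x, ps, st)  # size(y) == (5, 4)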
N: Flatten the first N dimensions of the input array. If nothing, then all dimensions (except the last) are flattened. Note that the batch dimension is never flattened.
Inputs
x: AbstractArray
Returns
AbstractMatrix of size (:, size(x, ndims(x))) if N is nothing else the first N dimensions of the input array are flattened.
Empty NamedTuple()
Example
julia
julia> model = FlattenLayer()
FlattenLayer{Nothing}(nothing)

julia> rng = Random.default_rng();
       Random.seed!(rng, 0);
       ps, st = Lux.setup(rng, model);
       x = randn(rng, Float32, (2, 2, 2, 2));

julia> y, st_new = model(x, ps, st);
       size(y)
(8, 2)
This contains a number of internal layers, each of which receives the same input. Its output is the elementwise maximum of the internal layers' outputs.
Maxout over linear dense layers satisfies the universal approximation theorem. See [1].
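A minimal sketch, assuming the Maxout(f, n_alts) constructor that builds n_alts layers by calling f() (sizes arbitrary):
julia
using Lux, Random

layer = Maxout(() -> Dense(2 => 4), 3)  # elementwise max over 3 Dense layers

rng = Random.default_rng()
ps, st = Lux.setup(rng, layer)

x = randn(rng, Float32, 2, 6)
y, _ = layer(x, ps, st)  # size(y) == (4, 6)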
Return a view of all the data of the input x where the index for dimension dim equals i. Equivalent to view(x,:,:,...,i,:,:,...) where i is in position dim.
Arguments
dim: Dimension for indexing
i: Index for dimension dim
Inputs
x: AbstractArray that can be indexed with view(x,:,:,...,i,:,:,...)
Returns
view(x,:,:,...,i,:,:,...) where i is in position dim
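For instance (a sketch; arguments as described above):
julia
using Lux, Random

layer = SelectDim(3, 2)  # slice along dimension 3 at index 2

rng = Random.default_rng()
ps, st = Lux.setup(rng, layer)

x = randn(rng, Float32, 4, 4, 5)
y, _ = layer(x, ps, st)  # equivalent to view(x, :, :, 2)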
Wraps a stateless and parameter-less function. Might be used when a function is added to a Chain. For example, Chain(x -> relu.(x)) would not work; the right thing to do would be Chain((x, ps, st) -> (relu.(x), st)). An easier alternative is Chain(WrappedFunction(Base.Fix1(broadcast, relu))).
Reverse the specified dimension dim of the passed array.
Arguments
dim: Dimension that needs to be reversed. If nothing, for an AbstractVector{T} it reverses itself (dimension 1); for other arrays, it reverses the dimension ndims(x) - 1.
Inputs
x: AbstractArray.
Returns
AbstractArray with the same dimensions as the input
Empty NamedTuple()
Example
julia
julia> model = ReverseSequence()
ReverseSequence{Nothing}(nothing)

julia> rng = Random.default_rng();
       Random.seed!(rng, 0);
       ps, st = Lux.setup(rng, model);
       x = [1.0, 2.0, 3.0];

julia> y, st_new = model(x, ps, st)
([3.0, 2.0, 1.0], NamedTuple())
BatchNorm computes the mean and variance for each input slice and normalises the input accordingly.
Arguments
chs: Size of the channel dimension in your data. Given an array with N dimensions, call the N-1th the channel dimension. For a batch of feature vectors this is just the data dimension, for WHCN images it's the usual channel dimension.
activation: After normalization, elementwise activation activation is applied.
Keyword Arguments
If track_stats=true, accumulates mean and variance statistics in the training phase that will be used to renormalize the input in the test phase.
epsilon: a value added to the denominator for numerical stability
momentum: the value used for the running_mean and running_var computation
If affine=true, it also applies a shift and a rescale to the input through learnable per-channel bias and scale parameters.
init_bias: Controls how the bias is initialized
init_scale: Controls how the scale is initialized
Extended Help
Inputs
x: Array where size(x, N - 1) = chs
Returns
y: Normalized Array
Update model state
Parameters
affine=true
bias: Bias of shape (chs,)
scale: Scale of shape (chs,)
affine=false - Empty NamedTuple()
States
Statistics if track_stats=true
running_mean: Running mean of shape (chs,)
running_var: Running variance of shape (chs,)
Statistics if track_stats=false
running_mean: nothing
running_var: nothing
training: Used to check if training/inference mode
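A typical usage sketch (sizes arbitrary), switching to test mode so the running statistics are used at inference:
julia
using Lux, Random

model = Chain(Dense(3 => 6), BatchNorm(6, relu))

rng = Random.default_rng()
ps, st = Lux.setup(rng, model)

x = randn(rng, Float32, 3, 16)
y, st_train = model(x, ps, st)     # training mode: batch statistics
st_test = Lux.testmode(st_train)   # flip the `training` state
y_test, _ = model(x, ps, st_test)  # inference mode: running statistics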
chs: Size of the channel dimension in your data. Given an array with N dimensions, call the N-1th the channel dimension. For a batch of feature vectors this is just the data dimension, for WHCN images it's the usual channel dimension.
groups is the number of groups along which the statistics are computed. The number of channels must be an integer multiple of the number of groups.
activation: After normalization, elementwise activation activation is applied.
Keyword Arguments
epsilon: a value added to the denominator for numerical stability
If affine=true, it also applies a shift and a rescale to the input through learnable per-channel bias and scale parameters.
init_bias: Controls how the bias is initialized
init_scale: Controls how the scale is initialized
Extended Help
Inputs
x: Array where size(x, N - 1) = chs and ndims(x) > 2
Returns
y: Normalized Array
Update model state
Parameters
affine=true
bias: Bias of shape (chs,)
scale: Scale of shape (chs,)
affine=false - Empty NamedTuple()
States
training: Used to check if training/inference mode
Instance Normalization computes the mean and variance for each D_1 × ... × D_{N-2} × 1 × 1 input slice and normalises the input accordingly.
Arguments
chs: Size of the channel dimension in your data. Given an array with N dimensions, call the N-1th the channel dimension. For a batch of feature vectors this is just the data dimension, for WHCN images it's the usual channel dimension.
activation: After normalization, elementwise activation activation is applied.
Keyword Arguments
If track_stats=true, accumulates mean and variance statistics in the training phase that will be used to renormalize the input in the test phase.
epsilon: a value added to the denominator for numerical stability
momentum: the value used for the running_mean and running_var computation
If affine=true, it also applies a shift and a rescale to the input through learnable per-channel bias and scale parameters.
init_bias: Controls how the bias is initialized
init_scale: Controls how the scale is initialized
Extended Help
Inputs
x: Array where size(x, N - 1) = chs and ndims(x) > 2
Returns
y: Normalized Array
Update model state
Parameters
affine=true
bias: Bias of shape (chs,)
scale: Scale of shape (chs,)
affine=false - Empty NamedTuple()
States
Statistics if track_stats=true
running_mean: Running mean of shape (chs,)
running_var: Running variance of shape (chs,)
Statistics if track_stats=false
running_mean: nothing
running_var: nothing
training: Used to check if training/inference mode
[1] Ulyanov, Dmitry, Andrea Vedaldi, and Victor Lempitsky. "Instance normalization: The missing ingredient for fast stylization." arXiv preprint arXiv:1607.08022 (2016).
Computes mean and standard deviation over the whole input array, and uses these to normalize the whole array. Optionally applies an elementwise affine transformation afterwards.
Given an input array x, this layer computes
y = (x .- mean(x)) ./ sqrt.(var(x) .+ epsilon) .* γ .+ β
where γ & β are trainable parameters if affine=true.
Arguments
shape: Broadcastable shape of input array excluding the batch dimension.
activation: After normalization, elementwise activation activation is applied.
Keyword Arguments
epsilon: a value added to the denominator for numerical stability.
dims: Dimensions to normalize the array over.
If affine=true, it also applies a shift and a rescale to the input through learnable per-element bias and scale parameters.
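A short sketch of LayerNorm over per-sample feature maps (shape chosen arbitrarily):
julia
using Lux, Random

layer = LayerNorm((5, 2))  # normalize over the leading (5, 2) dimensions

rng = Random.default_rng()
ps, st = Lux.setup(rng, layer)

x = randn(rng, Float32, 5, 2, 3)  # batch of 3
y, _ = layer(x, ps, st)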
Weight normalization is a reparameterization that decouples the magnitude of a weight tensor from its direction. This updates the parameters in which_params (e.g. weight) using two parameters: one specifying the magnitude (e.g. weight_g) and one specifying the direction (e.g. weight_v).
Arguments
layer whose parameters are being reparameterized
which_params: parameter names for the parameters being reparameterized
By default, a norm over the entire array is computed. Pass dims to modify the dimension.
Inputs
x: Should be of valid type for input to layer
Returns
Output from layer
Updated model state of layer
Parameters
normalized: Parameters of layer that are being normalized
unnormalized: Parameters of layer that are not being normalized
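A minimal sketch of wrapping a Dense layer so that its weight is split into a magnitude and a direction parameter:
julia
using Lux, Random

layer = WeightNorm(Dense(3 => 5), (:weight,))

rng = Random.default_rng()
ps, st = Lux.setup(rng, layer)

keys(ps.normalized)    # contains :weight_g and :weight_v
keys(ps.unnormalized)  # the remaining parameters, e.g. :bias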
Pixel shuffling layer with upscale factor r. Usually used for generating higher resolution images while upscaling them.
See NNlib.pixel_shuffle for more details.
Arguments
r: Upscale factor
Inputs
x: For 4D-arrays representing N images, the operation converts input size(x) == (W, H, r² × C, N) to output of size (r × W, r × H, C, N). For D-dimensional data, it expects ndims(x) == D + 2 with channel and batch dimensions, and divides the number of channels by rᴰ.
Returns
Output of size (r × W, r × H, C, N) for 4D-arrays, and (r × W, r × H, ..., C, N) for D-dimensional data, where D = ndims(x) - 2
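For example (a sketch with arbitrary sizes):
julia
using Lux, Random

layer = PixelShuffle(2)  # r = 2

rng = Random.default_rng()
ps, st = Lux.setup(rng, layer)

x = randn(rng, Float32, 3, 3, 8, 1)  # 8 channels = r² × C with C = 2
y, _ = layer(x, ps, st)              # size(y) == (6, 6, 2, 1)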
Option 1
mode: Set to :nearest, :linear, :bilinear or :trilinear
Exactly one of two keywords must be specified:
If scale is a number, this applies to all but the last two dimensions (channel and batch) of the input. It may also be a tuple, to control dimensions individually.
Alternatively, keyword size accepts a tuple, to directly specify the leading dimensions of the output.
Option 2
If scale is a number, this applies to all but the last two dimensions (channel and batch) of the input. It may also be a tuple, to control dimensions individually.
mode: Set to :nearest, :bilinear or :trilinear
Currently supported upsampling modes and corresponding NNlib's methods are:
:nearest -> NNlib.upsample_nearest
:bilinear -> NNlib.upsample_bilinear
:trilinear -> NNlib.upsample_trilinear
Extended Help
Other Keyword Arguments
align_corners: If true, the corner pixels of the input and output tensors are aligned, thus preserving the values at those pixels. This only has an effect when mode is one of :bilinear or :trilinear.
Inputs
x: For the input dimensions look into the documentation for the corresponding NNlib function
As a rule of thumb, :nearest should work with arrays of arbitrary dimensions
:bilinear works with 4D Arrays
:trilinear works with 5D Arrays
Returns
Upsampled Input of size size or of size (I_1 x scale[1], ..., I_N x scale[N], C, N)
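A short sketch using the first construction option (sizes arbitrary):
julia
using Lux, Random

layer = Upsample(:bilinear; scale=2)

rng = Random.default_rng()
ps, st = Lux.setup(rng, layer)

x = randn(rng, Float32, 4, 4, 3, 1)  # a WHCN image
y, _ = layer(x, ps, st)              # size(y) == (8, 8, 3, 1)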
Helper Functions making it easier to train Lux.jl models.
Training is meant to be simple and provide extremely basic functionality. We provide basic building blocks which can be seamlessly composed to create complex training pipelines.
Compute the gradients of the objective function wrt parameters stored in ts.
Backends & AD Packages
| Supported Backends | Packages Needed |
| --- | --- |
| AutoZygote | Zygote.jl |
| AutoReverseDiff(; compile) | ReverseDiff.jl |
| AutoTracker | Tracker.jl |
| AutoEnzyme | Enzyme.jl |
Arguments
ad: Backend (from ADTypes.jl) used to compute the gradients.
objective_function: Objective function. The function must take 4 inputs – model, parameters, states and data. The function must return 3 values – loss, updated_state, and any computed statistics.
data: Data used to compute the gradients.
ts: Current Training State.
Returns
grads: Computed gradients wrt the parameters.
loss: Loss from the objective function.
stats: Any computed statistics from the objective function.
ts: Updated Training State.
Known Limitations
AutoReverseDiff(; compile=true) is not supported for Lux models with non-empty state st. Additionally, the returned stats must be empty (NamedTuple()). We catch these issues in most cases and throw an error.
Aliased Gradients
grads returned by this function might be aliased by the implementation of the gradient backend. For example, if you cache the grads from step i, the new gradients returned in step i + 1 might be aliased by the old gradients. If you want to prevent this, simply use copy(grads) or deepcopy(grads) to make a copy of the gradients.
Returned values are the same as compute_gradients. Note that despite the !, only the parameters in ts are updated in place. Users should use the returned ts object for further training steps, otherwise there is no caching and performance will be suboptimal (and absolutely terrible for backends like AutoReactant).
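Putting these pieces together, here is a minimal training-loop sketch; the model, optimiser, and data are arbitrary placeholders:
julia
using Lux, Optimisers, Random, Zygote

rng = Random.default_rng()
model = Dense(2 => 1)
ps, st = Lux.setup(rng, model)
ts = Training.TrainState(model, ps, st, Adam(0.01f0))

x, y = randn(rng, Float32, 2, 32), randn(rng, Float32, 1, 32)

function train(ts, data; epochs=100)
    for _ in 1:epochs
        # single_train_step! returns (grads, loss, stats, ts); reuse the returned TrainState
        _, loss, _, ts = Training.single_train_step!(AutoZygote(), MSELoss(), data, ts)
    end
    return ts
end

ts = train(ts, (x, y))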
Loss functions in Lux can be called in two ways:
with ŷ and y, where ŷ is the predicted output and y is the target output.
with model, ps, st, (x, y), where model is the model, ps are the parameters, st are the states and (x, y) are the input and target pair. In this form they return the loss, updated states, and an empty named tuple, which makes them compatible with the Training API.
Warning
When using ChainRules.jl compatible AD (like Zygote), we only compute the gradients wrt the inputs and drop any gradients wrt the targets.
Takes any function loss_fn that maps 2 number inputs to a single number output. Additionally, array inputs are efficiently broadcasted and aggregated using agg.
Binary Cross Entropy Loss with optional label smoothing and fused logit computation.
Returns the binary cross entropy loss computed as:
If logits is either false or Val(false):
loss = agg(-ỹ .* log.(ŷ .+ ϵ) .- (1 .- ỹ) .* log.(1 .- ŷ .+ ϵ))
If logits is true or Val(true):
loss = agg((1 .- ỹ) .* ŷ .- logσ.(ŷ))
The value of ỹ is computed using label smoothing. If label_smoothing is nothing, then no label smoothing is applied. If label_smoothing is a real number α ∈ [0, 1], then the value of ỹ is:
ỹ = (1 - α) * y + α / 2
Return the cross entropy loss which is used in multi-class classification tasks. The input, ŷ, is expected to be normalized (i.e. softmax output) if logits is false or Val(false).
The loss is calculated as:
loss = agg(-sum(ỹ .* log.(ŷ .+ ϵ); dims=1))
where ϵ is added for numerical stability. The value of ỹ is computed using label smoothing. If label_smoothing is nothing, then no label smoothing is applied. If label_smoothing is a real number α ∈ [0, 1], then the value of ỹ is calculated as:
ỹ = (1 - α) * y + α / size(y, 1)
[1] Milletari, Fausto, Nassir Navab, and Seyed-Ahmad Ahmadi. "V-net: Fully convolutional neural networks for volumetric medical image segmentation." 2016 fourth international conference on 3D vision (3DV). IEEE, 2016.
Return the focal loss [1] which can be used in classification tasks with highly imbalanced classes. It down-weights well-classified examples and focuses on hard examples. The input, ŷ, is expected to be normalized (i.e. softmax output).
The modulating factor (1 - ŷ)^γ, where γ is the gamma keyword argument, controls the down-weighting strength. For γ = 0 this is equivalent to CrossEntropyLoss.
Return the Kullback-Leibler Divergence loss between the predicted distribution ŷ and the true distribution y, computed as the aggregate of y .* log.(y ./ ŷ).
The KL divergence is a measure of how much one probability distribution is different from the other. It is always non-negative, and zero only when both the distributions are equal.
This module is a part of Lux.jl. It contains operations that are useful in a deep learning context. Additionally, certain operations here alias Base functions to behave more sensibly with GPUArrays.
Similar to fmap(f, args...) but with restricted support for the notion of "leaf" types. However, this allows for more efficient and type stable implementations of recursive operations.
How does this work?
For the following types it directly defines recursion rules:
AbstractArray: If eltype is isbitstype, then f is applied to the array, else we recurse on the array.
Tuple/NamedTuple: We recurse on the values.
Number/Val/Nothing: We directly apply f.
For all other types, we recurse on the fields using Functors.fmap.
Note
In most cases, users should gravitate towards Functors.fmap if it is being used outside of hot loops. Even for other cases, it is always recommended to verify the correctness of this implementation for specific use cases.
Recursively add the leaves of two nested structures x and y. In Functor language, this is equivalent to doing fmap(+, x, y), but this implementation uses type stable code for common cases.
Any leaves of x that are arrays and allow in-place addition will be modified in place.
Recursively copy the leaves of two nested structures x and y. In Functor language, this is equivalent to doing fmap(copyto!, x, y), but this implementation uses type stable code for common cases. Note that any immutable leaf will lead to an error.
Recursively determine the element type of a nested structure x. This is equivalent to doing fmap(Lux.Utils.eltype, x), but this implementation uses type stable code for common cases.
For ambiguous inputs like nothing and Val types we return Bool as the eltype.
If unwrap_ad_types is set to Val(true) then for tracing and operator overloading based ADs (ForwardDiff, ReverseDiff, Tracker), this function will return the eltype of the unwrapped value.
Recursively create a zero value for a nested structure x. This is equivalent to doing fmap(zero, x), but this implementation uses type stable code for common cases.
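Assuming these operations are the Lux.recursive_* family (the specific names are inferred here rather than quoted from this page), a small sketch:
julia
using Lux

ps = (; weight=ones(Float32, 2, 2), bias=zeros(Float32, 2))

Lux.recursive_eltype(ps)          # Float32
gs = Lux.recursive_make_zero(ps)  # same structure, all-zero leaves
Lux.recursive_add!!(gs, ps)       # adds ps into gs, in place where possible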
By default, Lux uses Float32 for all parameters and states. To update the precision simply pass the parameters / states / arrays into one of the following functions.
Helper function to "maybe" (see below) match the element type of args... with the element type of the layer's parameters and states. This is useful for debugging purposes, to track down accidental type-promotions inside Lux layers.
Extended Help
Controlling the Behavior via Preferences
Behavior of this function is controlled via the eltype_mismatch_handling preference. The following options are supported:
"none": This is the default behavior. In this case, this function is a no-op, i.e., it simply returns args....
"warn": This option will issue a warning if the element type of args... does not match the element type of the layer's parameters and states. The warning will contain information about the layer and the element type mismatch.
"convert": This option is same as "warn", but it will also convert the element type of args... to match the element type of the layer's parameters and states (for the cases listed below).
"error": Same as "warn", but instead of issuing a warning, it will throw an error.
Warning
We print the warning for type-mismatch only once.
Element Type Conversions
For "convert" only the following conversions are done:
A convenience wrapper over Lux layers which stores the parameters and states internally. This is meant to be used in internal implementations of layers.
Use Cases
Internal implementation of @compact heavily uses this layer.
In SciML codebases where propagating the state might involve boxing. For a motivating example, see the Neural ODE tutorial.
Facilitates Nested AD support in Lux. For more details on this feature, see the Nested AD Manual Page.
Static Parameters
If FT = true then the type of the state is fixed, i.e., typeof(last(model(x, ps, st))) == typeof(st).
If FT = false then the type of the state might change. Note that while this works in all cases, it will introduce type instability.
Arguments
model: A Lux layer
ps: The parameters of the layer. This can be set to nothing, if the user provides the parameters on function call
st: The state of the layer
Inputs
x: The input to the layer
ps: The parameters of the layer. Optional, defaults to s.ps
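A brief sketch of wrapping a layer (sizes arbitrary):
julia
using Lux, Random

rng = Random.default_rng()
model = Dense(3 => 5, tanh)
ps, st = Lux.setup(rng, model)

smodel = StatefulLuxLayer{true}(model, ps, st)  # FT = true: state type is fixed
y = smodel(randn(rng, Float32, 3, 4))           # no need to thread ps/st manually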
julia
@compact(kw...) do x
    ...
    @return y # optional (but recommended for best performance)
end
@compact(kw...) do x, p
    ...
    @return y # optional (but recommended for best performance)
end
@compact(forward::Function; name=nothing, dispatch=nothing, parameters...)
Creates a layer by specifying some parameters, in the form of keywords, and (usually as a do block) a function for the forward pass. You may think of @compact as a specialized let block creating local variables that are trainable in Lux. Declared variable names may be used within the body of the forward function. Note that unlike typical Lux models, the forward function doesn't need to explicitly manage states.
Defining the version with p allows you to access the parameters in the forward pass. This is useful when using it with SciML tools which require passing in the parameters explicitly.
Reserved Kwargs:
name: The name of the layer.
dispatch: The constructed layer has the type Lux.CompactLuxLayer{dispatch} which can be used for custom dispatches.
Tip
Check the Lux tutorials for more examples of using @compact.
If you are passing in kwargs by splatting them, they will be passed as-is to the function body. This means that if your splatted kwargs contain a Lux layer, it won't be registered in the CompactLuxLayer.
Special Syntax
@return: This macro doesn't really exist, but is used to return a value from the @compact block. Without the presence of this macro, we need to rely on closures which can lead to performance penalties in the reverse pass.
Having statements after the last @return macro might lead to incorrect code.
Don't do things like @return return x. This will generate nonsensical code like <new var> = return x. Essentially, @return <expr> supports any expression that can be assigned to a variable.
Since this macro doesn't "exist", it cannot be imported as using Lux: @return. Simply use it in code, and @compact will understand it.
@init_fn: Provide a function that will be used to initialize the layer's parameters or state. See the docs of @init_fn for more details.
@non_trainable: Mark a value as non-trainable. This bypasses the regular checks and places the value into the state of the layer. See the docs of @non_trainable for more details.
Extended Help
Examples
Here is a linear model:
julia
julia> using Lux, Random

julia> r = @compact(w=ones(3)) do x
           @return w .* x
       end
@compact(
    w = 3-element Vector{Float64},
) do x
    return w .* x
end       # Total: 3 parameters,
          #        plus 0 states.

julia> ps, st = Lux.setup(Xoshiro(0), r);

julia> r([1, 2, 3], ps, st)  # x is set to [1, 2, 3].
([1.0, 2.0, 3.0], NamedTuple())
Here is a linear model with bias and activation:
julia
julia> d_in = 5
5

julia> d_out = 3
3

julia> d = @compact(W=ones(d_out, d_in), b=zeros(d_out), act=relu) do x
           y = W * x
           @return act.(y .+ b)
       end
@compact(
    W = 3×5 Matrix{Float64},
    b = 3-element Vector{Float64},
    act = relu,
) do x
    y = W * x
    return act.(y .+ b)
end       # Total: 18 parameters,
          #        plus 1 states.

julia> ps, st = Lux.setup(Xoshiro(0), d);

julia> d(ones(5, 2), ps, st)[1]  # 3×2 Matrix as output.
3×2 Matrix{Float64}:
 5.0  5.0
 5.0  5.0
 5.0  5.0

julia> ps_dense = (; weight=ps.W, bias=ps.b);

julia> first(d([1, 2, 3, 4, 5], ps, st)) ≈
       first(Dense(d_in => d_out, relu)([1, 2, 3, 4, 5], ps_dense, NamedTuple()))  # equivalent to a Dense layer
true
Finally, here is a simple MLP. We can train this model just like any Lux model:
julia
julia> n_in = 1;

julia> n_out = 1;

julia> nlayers = 3;

julia> model = @compact(w1=Dense(n_in, 128),
           w2=[Dense(128, 128) for i in 1:nlayers], w3=Dense(128, n_out), act=relu) do x
           embed = act.(w1(x))
           for w in w2
               embed = act.(w(embed))
           end
           out = w3(embed)
           @return out
       end
@compact(
    w1 = Dense(1 => 128),              # 256 parameters
    w2 = NamedTuple(
        1 = Dense(128 => 128),         # 16_512 parameters
        2 = Dense(128 => 128),         # 16_512 parameters
        3 = Dense(128 => 128),         # 16_512 parameters
    ),
    w3 = Dense(128 => 1),              # 129 parameters
    act = relu,
) do x
    embed = act.(w1(x))
    for w = w2
        embed = act.(w(embed))
    end
    out = w3(embed)
    return out
end       # Total: 49_921 parameters,
          #        plus 1 states.

julia> ps, st = Lux.setup(Xoshiro(0), model);

julia> size(first(model(randn(n_in, 32), ps, st)))  # 1×32 Matrix as output.
(1, 32)

julia> using Optimisers, Zygote

julia> x_data = collect(-2.0f0:0.1f0:2.0f0)';

julia> y_data = 2 .* x_data .- x_data .^ 3;

julia> optim = Optimisers.setup(Adam(), ps);

julia> loss_initial = sum(abs2, first(model(x_data, ps, st)) .- y_data);

julia> for epoch in 1:1000
           loss, gs = Zygote.withgradient(
               ps -> sum(abs2, first(model(x_data, ps, st)) .- y_data), ps)
           Optimisers.update!(optim, ps, gs[1])
       end;

julia> loss_final = sum(abs2, first(model(x_data, ps, st)) .- y_data);

julia> loss_initial > loss_final
true
You may also specify a name for the model, which will be used instead of the default printout, which gives a verbatim representation of the code used to construct the model:
julia
julia> model = @compact(w=rand(3), name="Linear(3 => 1)") do x
           @return sum(w .* x)
       end
Linear(3 => 1)    # 3 parameters
This can be useful when using @compact to hierarchically construct complex models to be used inside a Chain.
Type Stability
If your input function f is type-stable but the generated model is not type stable, it should be treated as a bug. We will appreciate issues if you find such cases.
Parameter Count
Array Parameters don't print the number of parameters on the side. However, they do account for the total number of parameters printed at the bottom.
Create an initializer function for a parameter or state to be used in a Compact Lux Layer created using @compact.
Arguments
fn: The function to be used for initializing the parameter or state. This only takes a single argument rng.
kind: If set to :parameter, the initializer function will be used to initialize the parameters of the layer. If set to :state, the initializer function will be used to initialize the states of the layer.
Examples
julia
julia> using Lux, Random

julia> r = @compact(w=@init_fn(rng->randn32(rng, 3, 2)),
           b=@init_fn(rng->randn32(rng, 3), :state)) do x
           @return w * x .+ b
       end;

julia> ps, st = Lux.setup(Xoshiro(0), r);

julia> size(ps.w)
(3, 2)

julia> size(st.b)
(3,)

julia> size(r([1, 2], ps, st)[1])
(3,)
Set the dispatch doctor preference for LuxCore and LuxLib packages.
mode can be "disable", "warn", or "error". For details on the different modes, see the DispatchDoctor.jl documentation.
If the preferences are already set, then no action is taken. Otherwise the preference is set. For changes to take effect, the Julia session must be restarted.
This is a testing package. Hence, we don't use features like weak dependencies to reduce load times. It is recommended that you use this package exclusively for testing and do not add it as a dependency in your main package's Project.toml.
Implements utilities for testing gradient correctness and dynamic dispatch of Lux.jl models.
Test the gradients of f with respect to args using the specified backends.
| Backend | ADType | CPU | GPU | Notes |
| --- | --- | --- | --- | --- |
| Zygote.jl | AutoZygote() | ✔ | ✔ | |
| Tracker.jl | AutoTracker() | ✔ | ✔ | |
| ReverseDiff.jl | AutoReverseDiff() | ✔ | ✖ | |
| ForwardDiff.jl | AutoForwardDiff() | ✔ | ✖ | len ≤ 100 |
| FiniteDiff.jl | AutoFiniteDiff() | ✔ | ✖ | len ≤ 100 |
| Enzyme.jl | AutoEnzyme() | ✔ | ✖ | Only Reverse Mode |
Arguments
f: The function to test the gradients of.
args: The arguments to test the gradients of. Only AbstractArrays are considered for gradient computation. Gradients wrt all other arguments are assumed to be NoTangent().
Keyword Arguments
skip_backends: A list of backends to skip.
broken_backends: A list of backends to treat as broken.
soft_fail: If true, then the test will be recorded as a soft_fail test. This overrides any broken kwargs. Alternatively, a list of backends can be passed to soft_fail to allow soft_fail tests for only those backends.
enzyme_set_runtime_activity: If true, then runtime activity is activated for Enzyme.
kwargs: Additional keyword arguments to pass to check_approx.
Example
julia
julia> f(x, y, z) = x .+ sum(abs2, y.t) + sum(y.x.z)

julia> x = (; t=rand(10), x=(z=[2.0],))

julia> test_gradients(f, 1.0, x, nothing)
Evaluate expr and record a test result. If expr throws an exception, the test result will be recorded as an error. If expr returns a value, and it is not a boolean, the test result will be recorded as an error.
If the test result is false then the test will be recorded as a broken test, else it will be recorded as a pass.
+
+
+
+
\ No newline at end of file
diff --git a/previews/PR1023/assets/api_Accelerator_Support_MLDataDevices.md.Ba-hugak.js b/previews/PR1023/assets/api_Accelerator_Support_MLDataDevices.md.Ba-hugak.js
new file mode 100644
index 0000000000..ec0b96b22f
--- /dev/null
+++ b/previews/PR1023/assets/api_Accelerator_Support_MLDataDevices.md.Ba-hugak.js
@@ -0,0 +1,25 @@
+import{_ as l,c as p,j as i,a as e,G as t,a2 as n,B as d,o as h}from"./chunks/framework.DFwXuivk.js";const T=JSON.parse('{"title":"MLDataDevices","description":"","frontmatter":{},"headers":[],"relativePath":"api/Accelerator_Support/MLDataDevices.md","filePath":"api/Accelerator_Support/MLDataDevices.md","lastUpdated":null}'),o={name:"api/Accelerator_Support/MLDataDevices.md"},r={class:"jldocstring custom-block"},c={class:"jldocstring custom-block"},k={class:"jldocstring custom-block"},g={class:"jldocstring custom-block"},u={class:"jldocstring custom-block"},E={class:"jldocstring custom-block"},y={class:"jldocstring custom-block"},v={class:"jldocstring custom-block"},f={class:"jldocstring custom-block"},b={class:"jldocstring custom-block"},D={class:"jldocstring custom-block"},F={class:"jldocstring custom-block"},C={class:"jldocstring custom-block"},m={class:"jldocstring custom-block"};function L(j,s,A,x,M,B){const a=d("Badge");return h(),p("div",null,[s[42]||(s[42]=i("h1",{id:"MLDataDevices-API",tabindex:"-1"},[e("MLDataDevices "),i("a",{class:"header-anchor",href:"#MLDataDevices-API","aria-label":'Permalink to "MLDataDevices {#MLDataDevices-API}"'},"")],-1)),s[43]||(s[43]=i("p",null,[i("code",null,"MLDataDevices.jl"),e(" is a lightweight package defining rules for transferring data across devices.")],-1)),s[44]||(s[44]=i("h2",{id:"preferences",tabindex:"-1"},[e("Preferences "),i("a",{class:"header-anchor",href:"#preferences","aria-label":'Permalink to "Preferences"'},"")],-1)),i("details",r,[i("summary",null,[s[0]||(s[0]=i("a",{id:"MLDataDevices.gpu_backend!",href:"#MLDataDevices.gpu_backend!"},[i("span",{class:"jlbinding"},"MLDataDevices.gpu_backend!")],-1)),s[1]||(s[1]=e()),t(a,{type:"info",class:"jlObjectType jlFunction",text:"Function"})]),s[2]||(s[2]=n(`
Creates a LocalPreferences.toml file with the desired GPU backend.
If backend == "", then the gpu_backend preference is deleted. Otherwise, backend is validated to be one of the possible backends and the preference is set to backend.
If a new backend is successfully set, then the Julia session must be restarted for the change to take effect.
`,5))]),s[45]||(s[45]=i("h2",{id:"Data-Transfer",tabindex:"-1"},[e("Data Transfer "),i("a",{class:"header-anchor",href:"#Data-Transfer","aria-label":'Permalink to "Data Transfer {#Data-Transfer}"'},"")],-1)),i("details",c,[i("summary",null,[s[3]||(s[3]=i("a",{id:"MLDataDevices.cpu_device",href:"#MLDataDevices.cpu_device"},[i("span",{class:"jlbinding"},"MLDataDevices.cpu_device")],-1)),s[4]||(s[4]=e()),t(a,{type:"info",class:"jlObjectType jlFunction",text:"Function"})]),s[5]||(s[5]=n('
julia
cpu_device() -> CPUDevice()
Return a CPUDevice object which can be used to transfer data to CPU.
Selects GPU device based on the following criteria:
If gpu_backend preference is set and the backend is functional on the system, then that device is selected.
Otherwise, an automatic selection algorithm is used. We go over possible device backends in the order specified by supported_gpu_backends() and select the first functional backend.
If no GPU device is functional and force is false, then cpu_device() is invoked.
If nothing works, an error is thrown.
Arguments
device_id::Union{Nothing, Integer}: The device id to select. If nothing, then we return the last selected device or if none was selected then we run the autoselection and choose the current device using CUDA.device() or AMDGPU.device() or similar. If Integer, then we select the device with the given id. Note that this is 1-indexed, in contrast to the 0-indexed CUDA.jl. For example, id = 4 corresponds to CUDA.device!(3).
Warning
device_id is only applicable for CUDA and AMDGPU backends. For Metal, oneAPI and CPU backends, device_id is ignored and a warning is printed.
Warning
gpu_device won't select a CUDA device unless both CUDA.jl and cuDNN.jl are loaded. This is to ensure that deep learning operations work correctly. Nonetheless, if cuDNN is not loaded you can still manually create a CUDADevice object and use it (e.g. dev = CUDADevice()).
Keyword Arguments
force::Bool: If true, then an error is thrown if no functional GPU device is found.
',4))]),s[46]||(s[46]=i("h2",{id:"miscellaneous",tabindex:"-1"},[e("Miscellaneous "),i("a",{class:"header-anchor",href:"#miscellaneous","aria-label":'Permalink to "Miscellaneous"'},"")],-1)),i("details",u,[i("summary",null,[s[12]||(s[12]=i("a",{id:"MLDataDevices.reset_gpu_device!",href:"#MLDataDevices.reset_gpu_device!"},[i("span",{class:"jlbinding"},"MLDataDevices.reset_gpu_device!")],-1)),s[13]||(s[13]=e()),t(a,{type:"info",class:"jlObjectType jlFunction",text:"Function"})]),s[14]||(s[14]=n('
julia
reset_gpu_device!()
Resets the selected GPU device. This is useful when automatic GPU selection needs to be run again.
If all arrays (on the leaves of the structure) are on the same device, we return that device. Otherwise, we throw an error. If the object is device agnostic, we return nothing.
Note
Trigger Packages must be loaded for this to return the correct device.
Special Retuened Values
nothing – denotes that the object is device agnostic. For example, scalar, abstract range, etc.
UnknownDevice() – denotes that the device type is unknown
See also get_device_type for a faster alternative that can be used for dispatch based on device type.
Similar to get_device but returns the type of the device instead of the device itself. This value is often a compile time constant and is recommended to be used instead of get_device where ever defining dispatches based on the device type.
Note
Trigger Packages must be loaded for this to return the correct device.
Special Retuened Values
Nothing – denotes that the object is device agnostic. For example, scalar, abstract range, etc.
UnknownDevice – denotes that the device type is unknown
Checks if the device is functional. This is used to determine if the device can be used for computation. Note that even if the backend is loaded (as checked via MLDataDevices.loaded), the device may not be functional.
Note that while this function is not exported, it is considered part of the public API.
Returns true if x is a leaf node in the data structure.
Defining MLDataDevices.isleaf(x::T) = true for custom types can be used to customize the behavior the data movement behavior when an object with nested structure containing the type is transferred to a device.
Adapt.adapt_structure(::AbstractDevice, x::T) or Adapt.adapt_structure(::AbstractDevice, x::T) will be called during data movement if isleaf(x::T) == true.
If MLDataDevices.isleaf(x::T) is not defined, then it will fall back to Functors.isleaf(x).
',6))]),s[47]||(s[47]=i("h2",{id:"Multi-GPU-Support",tabindex:"-1"},[e("Multi-GPU Support "),i("a",{class:"header-anchor",href:"#Multi-GPU-Support","aria-label":'Permalink to "Multi-GPU Support {#Multi-GPU-Support}"'},"")],-1)),i("details",C,[i("summary",null,[s[36]||(s[36]=i("a",{id:"MLDataDevices.set_device!",href:"#MLDataDevices.set_device!"},[i("span",{class:"jlbinding"},"MLDataDevices.set_device!")],-1)),s[37]||(s[37]=e()),t(a,{type:"info",class:"jlObjectType jlFunction",text:"Function"})]),s[38]||(s[38]=n('
julia
set_device!(T::Type{<:AbstractDevice}, dev_or_id)
Set the device for the given type. This is a no-op for CPUDevice. For CUDADevice and AMDGPUDevice, it prints a warning if the corresponding trigger package is not loaded.
Currently, MetalDevice and oneAPIDevice don't support setting the device.
Arguments
T::Type{<:AbstractDevice}: The device type to set.
dev_or_id: Can be the device from the corresponding package. For example for CUDA it can be a CuDevice. If it is an integer, it is the device id to set. This is 1-indexed.
Danger
This specific function should be considered experimental at this point and is currently provided to support distributed training in Lux. As such please use Lux.DistributedUtils instead of using this function.
Set the device for the given type. This is a no-op for CPUDevice. For CUDADevice and AMDGPUDevice, it prints a warning if the corresponding trigger package is not loaded.
Currently, MetalDevice and oneAPIDevice don't support setting the device.
Arguments
T::Type{<:AbstractDevice}: The device type to set.
rank::Integer: Local Rank of the process. This is applicable for distributed training and must be 0-indexed.
Danger
This specific function should be considered experimental at this point and is currently provided to support distributed training in Lux. As such please use Lux.DistributedUtils instead of using this function.
',14))]),s[48]||(s[48]=i("h2",{id:"iteration",tabindex:"-1"},[e("Iteration "),i("a",{class:"header-anchor",href:"#iteration","aria-label":'Permalink to "Iteration"'},"")],-1)),i("details",m,[i("summary",null,[s[39]||(s[39]=i("a",{id:"MLDataDevices.DeviceIterator",href:"#MLDataDevices.DeviceIterator"},[i("span",{class:"jlbinding"},"MLDataDevices.DeviceIterator")],-1)),s[40]||(s[40]=e()),t(a,{type:"info",class:"jlObjectType jlType",text:"Type"})]),s[41]||(s[41]=n(`
julia
DeviceIterator(dev::AbstractDevice, iterator)
Create a DeviceIterator that iterates through the provided iterator via iterate. Upon each iteration, the current batch is copied to the device dev, and the previous iteration is marked as freeable from GPU memory (via unsafe_free!) (no-op for a CPU device).
The conversion follows the same semantics as dev(<item from iterator>).
Similarity to CUDA.CuIterator
The design inspiration was taken from CUDA.CuIterator and was generalized to work with other backends and more complex iterators (using Functors).
MLUtils.DataLoader
Calling dev(::MLUtils.DataLoader) will automatically convert the dataloader to use the same semantics as DeviceIterator. This is generally preferred over looping over the dataloader directly and transferring the data to the device.
Examples
The following was run on a computer with an NVIDIA GPU.
julia
julia> using MLDataDevices, MLUtils
+
+julia> X = rand(Float64, 3, 33);
+
+julia> dataloader = DataLoader(X; batchsize=13, shuffle=false);
+
+julia> for (i, x) in enumerate(dataloader)
+ @show i, summary(x)
+ end
+(i, summary(x)) = (1, "3×13 Matrix{Float64}")
+(i, summary(x)) = (2, "3×13 Matrix{Float64}")
+(i, summary(x)) = (3, "3×7 Matrix{Float64}")
+
+julia> for (i, x) in enumerate(CUDADevice()(dataloader))
+ @show i, summary(x)
+ end
+(i, summary(x)) = (1, "3×13 CuArray{Float32, 2, CUDA.DeviceMemory}")
+(i, summary(x)) = (2, "3×13 CuArray{Float32, 2, CUDA.DeviceMemory}")
+(i, summary(x)) = (3, "3×7 CuArray{Float32, 2, CUDA.DeviceMemory}")
`,9))])])}const w=l(o,[["render",L]]);export{T as __pageData,w as default};
diff --git a/previews/PR1023/assets/api_Accelerator_Support_MLDataDevices.md.Ba-hugak.lean.js b/previews/PR1023/assets/api_Accelerator_Support_MLDataDevices.md.Ba-hugak.lean.js
new file mode 100644
index 0000000000..ec0b96b22f
--- /dev/null
+++ b/previews/PR1023/assets/api_Accelerator_Support_MLDataDevices.md.Ba-hugak.lean.js
@@ -0,0 +1,25 @@
+import{_ as l,c as p,j as i,a as e,G as t,a2 as n,B as d,o as h}from"./chunks/framework.DFwXuivk.js";const T=JSON.parse('{"title":"MLDataDevices","description":"","frontmatter":{},"headers":[],"relativePath":"api/Accelerator_Support/MLDataDevices.md","filePath":"api/Accelerator_Support/MLDataDevices.md","lastUpdated":null}'),o={name:"api/Accelerator_Support/MLDataDevices.md"},r={class:"jldocstring custom-block"},c={class:"jldocstring custom-block"},k={class:"jldocstring custom-block"},g={class:"jldocstring custom-block"},u={class:"jldocstring custom-block"},E={class:"jldocstring custom-block"},y={class:"jldocstring custom-block"},v={class:"jldocstring custom-block"},f={class:"jldocstring custom-block"},b={class:"jldocstring custom-block"},D={class:"jldocstring custom-block"},F={class:"jldocstring custom-block"},C={class:"jldocstring custom-block"},m={class:"jldocstring custom-block"};function L(j,s,A,x,M,B){const a=d("Badge");return h(),p("div",null,[s[42]||(s[42]=i("h1",{id:"MLDataDevices-API",tabindex:"-1"},[e("MLDataDevices "),i("a",{class:"header-anchor",href:"#MLDataDevices-API","aria-label":'Permalink to "MLDataDevices {#MLDataDevices-API}"'},"")],-1)),s[43]||(s[43]=i("p",null,[i("code",null,"MLDataDevices.jl"),e(" is a lightweight package defining rules for transferring data across devices.")],-1)),s[44]||(s[44]=i("h2",{id:"preferences",tabindex:"-1"},[e("Preferences "),i("a",{class:"header-anchor",href:"#preferences","aria-label":'Permalink to "Preferences"'},"")],-1)),i("details",r,[i("summary",null,[s[0]||(s[0]=i("a",{id:"MLDataDevices.gpu_backend!",href:"#MLDataDevices.gpu_backend!"},[i("span",{class:"jlbinding"},"MLDataDevices.gpu_backend!")],-1)),s[1]||(s[1]=e()),t(a,{type:"info",class:"jlObjectType jlFunction",text:"Function"})]),s[2]||(s[2]=n(`
Creates a LocalPreferences.toml file with the desired GPU backend.
If backend == "", then the gpu_backend preference is deleted. Otherwise, backend is validated to be one of the possible backends and the preference is set to backend.
If a new backend is successfully set, then the Julia session must be restarted for the change to take effect.
`,5))]),s[45]||(s[45]=i("h2",{id:"Data-Transfer",tabindex:"-1"},[e("Data Transfer "),i("a",{class:"header-anchor",href:"#Data-Transfer","aria-label":'Permalink to "Data Transfer {#Data-Transfer}"'},"")],-1)),i("details",c,[i("summary",null,[s[3]||(s[3]=i("a",{id:"MLDataDevices.cpu_device",href:"#MLDataDevices.cpu_device"},[i("span",{class:"jlbinding"},"MLDataDevices.cpu_device")],-1)),s[4]||(s[4]=e()),t(a,{type:"info",class:"jlObjectType jlFunction",text:"Function"})]),s[5]||(s[5]=n('
julia
cpu_device() -> CPUDevice()
Return a CPUDevice object which can be used to transfer data to CPU.
Selects GPU device based on the following criteria:
If gpu_backend preference is set and the backend is functional on the system, then that device is selected.
Otherwise, an automatic selection algorithm is used. We go over possible device backends in the order specified by supported_gpu_backends() and select the first functional backend.
If no GPU device is functional and force is false, then cpu_device() is invoked.
If nothing works, an error is thrown.
Arguments
device_id::Union{Nothing, Integer}: The device id to select. If nothing, then we return the last selected device or if none was selected then we run the autoselection and choose the current device using CUDA.device() or AMDGPU.device() or similar. If Integer, then we select the device with the given id. Note that this is 1-indexed, in contrast to the 0-indexed CUDA.jl. For example, id = 4 corresponds to CUDA.device!(3).
Warning
device_id is only applicable for CUDA and AMDGPU backends. For Metal, oneAPI and CPU backends, device_id is ignored and a warning is printed.
Warning
gpu_device won't select a CUDA device unless both CUDA.jl and cuDNN.jl are loaded. This is to ensure that deep learning operations work correctly. Nonetheless, if cuDNN is not loaded you can still manually create a CUDADevice object and use it (e.g. dev = CUDADevice()).
Keyword Arguments
force::Bool: If true, then an error is thrown if no functional GPU device is found.
',4))]),s[46]||(s[46]=i("h2",{id:"miscellaneous",tabindex:"-1"},[e("Miscellaneous "),i("a",{class:"header-anchor",href:"#miscellaneous","aria-label":'Permalink to "Miscellaneous"'},"")],-1)),i("details",u,[i("summary",null,[s[12]||(s[12]=i("a",{id:"MLDataDevices.reset_gpu_device!",href:"#MLDataDevices.reset_gpu_device!"},[i("span",{class:"jlbinding"},"MLDataDevices.reset_gpu_device!")],-1)),s[13]||(s[13]=e()),t(a,{type:"info",class:"jlObjectType jlFunction",text:"Function"})]),s[14]||(s[14]=n('
julia
reset_gpu_device!()
Resets the selected GPU device. This is useful when automatic GPU selection needs to be run again.
If all arrays (on the leaves of the structure) are on the same device, we return that device. Otherwise, we throw an error. If the object is device agnostic, we return nothing.
Note
Trigger Packages must be loaded for this to return the correct device.
Special Retuened Values
nothing – denotes that the object is device agnostic. For example, scalar, abstract range, etc.
UnknownDevice() – denotes that the device type is unknown
See also get_device_type for a faster alternative that can be used for dispatch based on device type.
Similar to get_device but returns the type of the device instead of the device itself. This value is often a compile time constant and is recommended to be used instead of get_device where ever defining dispatches based on the device type.
Note
Trigger Packages must be loaded for this to return the correct device.
Special Retuened Values
Nothing – denotes that the object is device agnostic. For example, scalar, abstract range, etc.
UnknownDevice – denotes that the device type is unknown
Checks if the device is functional. This is used to determine if the device can be used for computation. Note that even if the backend is loaded (as checked via MLDataDevices.loaded), the device may not be functional.
Note that while this function is not exported, it is considered part of the public API.
Returns true if x is a leaf node in the data structure.
Defining MLDataDevices.isleaf(x::T) = true for custom types can be used to customize the behavior the data movement behavior when an object with nested structure containing the type is transferred to a device.
Adapt.adapt_structure(::AbstractDevice, x::T) or Adapt.adapt_structure(::AbstractDevice, x::T) will be called during data movement if isleaf(x::T) == true.
If MLDataDevices.isleaf(x::T) is not defined, then it will fall back to Functors.isleaf(x).
',6))]),s[47]||(s[47]=i("h2",{id:"Multi-GPU-Support",tabindex:"-1"},[e("Multi-GPU Support "),i("a",{class:"header-anchor",href:"#Multi-GPU-Support","aria-label":'Permalink to "Multi-GPU Support {#Multi-GPU-Support}"'},"")],-1)),i("details",C,[i("summary",null,[s[36]||(s[36]=i("a",{id:"MLDataDevices.set_device!",href:"#MLDataDevices.set_device!"},[i("span",{class:"jlbinding"},"MLDataDevices.set_device!")],-1)),s[37]||(s[37]=e()),t(a,{type:"info",class:"jlObjectType jlFunction",text:"Function"})]),s[38]||(s[38]=n('
julia
set_device!(T::Type{<:AbstractDevice}, dev_or_id)
Set the device for the given type. This is a no-op for CPUDevice. For CUDADevice and AMDGPUDevice, it prints a warning if the corresponding trigger package is not loaded.
Currently, MetalDevice and oneAPIDevice don't support setting the device.
Arguments
T::Type{<:AbstractDevice}: The device type to set.
dev_or_id: Can be the device from the corresponding package. For example for CUDA it can be a CuDevice. If it is an integer, it is the device id to set. This is 1-indexed.
Danger
This specific function should be considered experimental at this point and is currently provided to support distributed training in Lux. As such please use Lux.DistributedUtils instead of using this function.
Set the device for the given type. This is a no-op for CPUDevice. For CUDADevice and AMDGPUDevice, it prints a warning if the corresponding trigger package is not loaded.
Currently, MetalDevice and oneAPIDevice don't support setting the device.
Arguments
T::Type{<:AbstractDevice}: The device type to set.
rank::Integer: Local Rank of the process. This is applicable for distributed training and must be 0-indexed.
Danger
This specific function should be considered experimental at this point and is currently provided to support distributed training in Lux. As such please use Lux.DistributedUtils instead of using this function.
',14))]),s[48]||(s[48]=i("h2",{id:"iteration",tabindex:"-1"},[e("Iteration "),i("a",{class:"header-anchor",href:"#iteration","aria-label":'Permalink to "Iteration"'},"")],-1)),i("details",m,[i("summary",null,[s[39]||(s[39]=i("a",{id:"MLDataDevices.DeviceIterator",href:"#MLDataDevices.DeviceIterator"},[i("span",{class:"jlbinding"},"MLDataDevices.DeviceIterator")],-1)),s[40]||(s[40]=e()),t(a,{type:"info",class:"jlObjectType jlType",text:"Type"})]),s[41]||(s[41]=n(`
julia
DeviceIterator(dev::AbstractDevice, iterator)
Create a DeviceIterator that iterates through the provided iterator via iterate. Upon each iteration, the current batch is copied to the device dev, and the previous iteration is marked as freeable from GPU memory (via unsafe_free!) (no-op for a CPU device).
The conversion follows the same semantics as dev(<item from iterator>).
Similarity to CUDA.CuIterator
The design inspiration was taken from CUDA.CuIterator and was generalized to work with other backends and more complex iterators (using Functors).
MLUtils.DataLoader
Calling dev(::MLUtils.DataLoader) will automatically convert the dataloader to use the same semantics as DeviceIterator. This is generally preferred over looping over the dataloader directly and transferring the data to the device.
Examples
The following was run on a computer with an NVIDIA GPU.
julia
julia> using MLDataDevices, MLUtils
+
+julia> X = rand(Float64, 3, 33);
+
+julia> dataloader = DataLoader(X; batchsize=13, shuffle=false);
+
+julia> for (i, x) in enumerate(dataloader)
+ @show i, summary(x)
+ end
+(i, summary(x)) = (1, "3×13 Matrix{Float64}")
+(i, summary(x)) = (2, "3×13 Matrix{Float64}")
+(i, summary(x)) = (3, "3×7 Matrix{Float64}")
+
+julia> for (i, x) in enumerate(CUDADevice()(dataloader))
+ @show i, summary(x)
+ end
+(i, summary(x)) = (1, "3×13 CuArray{Float32, 2, CUDA.DeviceMemory}")
+(i, summary(x)) = (2, "3×13 CuArray{Float32, 2, CUDA.DeviceMemory}")
+(i, summary(x)) = (3, "3×7 CuArray{Float32, 2, CUDA.DeviceMemory}")
LuxCore.jl defines the abstract layers for Lux. Allows users to be compatible with the entirely of Lux.jl without having such a heavy dependency. If you are depending on Lux.jl directly, you do not need to depend on LuxCore.jl (all the functionality is exported via Lux.jl).
Users implementing their custom layer, must implement
initialparameters(rng::AbstractRNG, layer::CustomAbstractLuxLayer) – This returns a NamedTuple containing the trainable parameters for the layer.
initialstates(rng::AbstractRNG, layer::CustomAbstractLuxLayer) – This returns a NamedTuple containing the current state for the layer. For most layers this is typically empty. Layers that would potentially contain this include BatchNorm, LSTM, GRU, etc.
Optionally:
parameterlength(layer::CustomAbstractLuxLayer) – These can be automatically calculated, but it is recommended that the user defines these.
statelength(layer::CustomAbstractLuxLayer) – These can be automatically calculated, but it is recommended that the user defines these.
Additionally, on calling initialparameters and initialstates, the parameters and states are not wrapped in a NamedTuple with the same name as the field.
As a convenience, we define the fallback call (::AbstractLuxWrapperLayer)(x, ps, st), which calls getfield(x, layer)(x, ps, st).
abstract type AbstractLuxContainerLayer{layers} <: AbstractLuxLayer
Abstract Container Type for certain Lux Layers. layers is a tuple containing fieldnames for the layer, and constructs the parameters and states using those.
Users implementing their custom layer can extend the same functions as in AbstractLuxLayer.
Advanced Structure Manipulation
Advanced structure manipulation of these layers post construction is possible via Functors.fmap. For a more flexible interface, we recommend using Lux.Experimental.@layer_map.
fmap Support
fmap support needs to be explicitly enabled by loading Functors.jl and Setfield.jl.
Changes from Pre-1.0 Behavior
LuxCore.jl defines the abstract layers for Lux. This allows users to be compatible with the entirety of Lux.jl without taking on its heavy dependency. If you are depending on Lux.jl directly, you do not need to depend on LuxCore.jl (all the functionality is exported via Lux.jl).
Users implementing their own custom layer must implement (a minimal sketch follows this list):
initialparameters(rng::AbstractRNG, layer::CustomAbstractLuxLayer) – This returns a NamedTuple containing the trainable parameters for the layer.
initialstates(rng::AbstractRNG, layer::CustomAbstractLuxLayer) – This returns a NamedTuple containing the current state for the layer. For most layers this is typically empty. Layers that would potentially contain this include BatchNorm, LSTM, GRU, etc.
Optionally:
parameterlength(layer::CustomAbstractLuxLayer) – These can be automatically calculated, but it is recommended that the user defines these.
statelength(layer::CustomAbstractLuxLayer) – These can be automatically calculated, but it is recommended that the user defines these.
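Putting the interface together, here is a minimal hedged sketch. MyDense is a hypothetical layer used only for illustration, not part of LuxCore:
julia
using LuxCore, Random

struct MyDense <: LuxCore.AbstractLuxLayer
    in_dims::Int
    out_dims::Int
end

function LuxCore.initialparameters(rng::AbstractRNG, d::MyDense)
    return (weight=randn(rng, Float32, d.out_dims, d.in_dims),
            bias=zeros(Float32, d.out_dims))
end

LuxCore.initialstates(::AbstractRNG, ::MyDense) = NamedTuple()

# Optional, but recommended:
LuxCore.parameterlength(d::MyDense) = d.out_dims * (d.in_dims + 1)
LuxCore.statelength(::MyDense) = 0

# The forward pass returns (output, updated_state).
(d::MyDense)(x, ps, st) = (ps.weight * x .+ ps.bias, st)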
Additionally, for AbstractLuxWrapperLayer, calling initialparameters and initialstates returns the parameters and states of the wrapped field directly; they are not wrapped in a NamedTuple with the same name as the field.
As a convenience, we define the fallback call (::AbstractLuxWrapperLayer)(x, ps, st), which calls getfield(model, layer)(x, ps, st).
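A hedged sketch of such a wrapper (MyWrapper is hypothetical); the type parameter names the wrapped field:
julia
# Forwards (x, ps, st) to the inner layer and shares its
# parameters/states directly (no extra NamedTuple nesting).
struct MyWrapper{L} <: LuxCore.AbstractLuxWrapperLayer{:layer}
    layer::L
end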
julia
abstract type AbstractLuxContainerLayer{layers} <: AbstractLuxLayer
Abstract Container Type for certain Lux Layers. layers is a tuple containing the fieldnames for the layer; the parameters and states are constructed using those names. A sketch follows the notes below.
Users implementing their custom layer can extend the same functions as in AbstractLuxLayer.
Advanced Structure Manipulation
Advanced structure manipulation of these layers post construction is possible via Functors.fmap. For a more flexible interface, we recommend using Lux.Experimental.@layer_map.
fmap Support
fmap support needs to be explicitly enabled by loading Functors.jl and Setfield.jl.
Changes from Pre-1.0 Behavior
Previously if layers was a singleton tuple, initialparameters and initialstates would return the parameters and states for the single field layers. From v1.0.0 onwards, even for singleton tuples, the parameters/states are wrapped in a NamedTuple with the same name as the field. See AbstractLuxWrapperLayer to replicate the previous behavior of singleton tuples.
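To make this concrete, a minimal hedged sketch of a two-field container layer (MyChain is hypothetical):
julia
struct MyChain{L1, L2} <: LuxCore.AbstractLuxContainerLayer{(:layer1, :layer2)}
    layer1::L1
    layer2::L2
end

# ps and st are NamedTuples keyed by the field names given above.
function (c::MyChain)(x, ps, st)
    y1, st1 = c.layer1(x, ps.layer1, st.layer1)
    y2, st2 = c.layer2(y1, ps.layer2, st.layer2)
    return y2, (layer1=st1, layer2=st2)
end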
General
LuxCore.apply
julia
apply(model, x, ps, st)
In most cases this function simply calls model(x, ps, st). However, it is still recommended to call apply instead of model(x, ps, st) directly. Some of the reasons for this include:
For certain types of inputs x, we might want to perform preprocessing before calling model. For example, if x is an Array of ReverseDiff.TrackedReal numbers, calling model(x, ps, st) directly can cause significant performance regressions (since it won't hit any of the BLAS dispatches). In those cases, we would automatically convert x to a ReverseDiff.TrackedArray.
Certain user-defined inputs need to be applied to specific layers, but we want the datatype to propagate through all the layers (even unsupported ones). In these cases, we can unpack the input in apply, pass it to the appropriate layer, and then repack it before returning. See the Lux manual on Custom Input Types for a motivating example.
Tip
apply is integrated with DispatchDoctor.jl, which allows automatic verification of type stability. By default this check is set to "disable". For more information, see the documentation.
Calls apply and only returns the first argument. This function requires that model has an empty state of NamedTuple(). The behavior of other kinds of models is undefined, and it is the responsibility of the user to ensure that the model has an empty state.
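A hedged usage sketch, reusing the hypothetical MyDense layer from the earlier example:
julia
using LuxCore, Random

rng = Xoshiro(0)
model = MyDense(4, 2)
ps = LuxCore.initialparameters(rng, model)
st = LuxCore.initialstates(rng, model)
x = randn(rng, Float32, 4, 8)

y, st_new = LuxCore.apply(model, x, ps, st)
# For models with an empty state, only the output is returned:
y2 = LuxCore.stateless_apply(model, x, ps)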
Parameters
LuxCore.initialparameters
States
LuxCore.initialstates
update_state recursively updates all occurrences of the key in the state st with the value. exclude is a function that is passed to the exclude keyword of Functors.fmap_with_path.
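A hedged sketch of the behavior described above, assuming the update_state(st, key, value) form (the exact signature may differ):
julia
st = (layer1=(training=Val(false),), layer2=(training=Val(false),))
st_new = LuxCore.update_state(st, :training, Val(true))
# every nested `training` entry is now Val(true)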
Layer size
LuxCore.outputsize
julia
outputsize(layer, x, rng)
Return the output size of the layer.
The fallback implementation of this function assumes the inputs were batched, i.e., if any of the outputs is an Array A with ndims(A) > 1, it will return size(A)[1:(end - 1)]. If this behavior is undesirable, provide a custom outputsize(layer, x, rng) implementation.
Fallback Implementation
The fallback implementation of this function is defined once Lux.jl is loaded.
Changes from Pre-1.0 Behavior
Previously it was possible to override this function by defining outputsize(layer). However, this can potentially introduce a bug that is hard to bypass. See this PR for more information.
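A hedged sketch of the batched fallback, again using the hypothetical MyDense layer (and assuming Lux.jl is loaded so the fallback is defined):
julia
model = MyDense(4, 2)
x = randn(Xoshiro(0), Float32, 4, 8)      # batch of 8
LuxCore.outputsize(model, x, Xoshiro(0))  # (2,): the batch dimension is dropped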
LuxLib
Backend for Lux.jl
Apply Activation
LuxLib.API.fast_activation
julia
fast_activation(σ::F, x::AbstractArray) where {F}
Compute σ.(x) with the best possible implementation available. On CPUs we unroll the loop and use LoopVectorization.jl to vectorize the computation. On GPUs we simply use broadcasting.
Note
This function doesn't replace σ with NNlib.fast_act(σ, ...); that substitution needs to be done by the user if desired.
julia
fast_activation!!(σ::F, x::AbstractArray) where {F}
Compute σ.(x) with the best possible implementation available. If it is possible to rewrite x in-place, it does so. If x is an immutable array, it falls back to the generic implementation.
Note
This function doesn't replace σ with NNlib.fast_act(σ, ...); that substitution needs to be done by the user if desired.
Load SLEEFPirates.jl to get faster activations
Certain activation functions are replaced with specialized implementations from SLEEFPirates.jl for FP32. This might improve performance but can cause a slight decrease in accuracy (in the floating point limit).
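A short usage sketch (gelu comes from NNlib):
julia
using LuxLib, NNlib

x = randn(Float32, 128, 32)
y = fast_activation(gelu, x)            # out-of-place
y2 = fast_activation!!(gelu, copy(x))   # may reuse its input's buffer if mutable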
Batched Operations
LuxLib.API.batched_matmul
julia
batched_matmul(x, y)
Computes the batched matrix multiplication of x and y. For more details see the NNlib documentation on NNlib.batched_mul. This function is mostly a wrapper around batched_mul but attempts to be faster on CPUs.
Load LoopVectorization.jl to get faster batched matrix multiplication
On CPUs loading LoopVectorization adds faster implementations of batched matrix multiplication.
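For example, multiplying 16 pairs of matrices in one call:
julia
x = randn(Float32, 4, 8, 16)   # 16 batched 4x8 matrices
y = randn(Float32, 8, 2, 16)
z = batched_matmul(x, y)       # size(z) == (4, 2, 16)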
Bias Activation
LuxLib.API.bias_activation
julia
bias_activation(σ, x, bias)
Applies the activation function σ elementwise to the result of broadcasted addition of x and bias along the penultimate dimension. A vector x is treated as a matrix with a single last dimension.
julia
bias_activation!!(σ, x, bias)
Same as bias_activation but might update x in-place if possible. Users should not rely on x being mutated; it is recommended to use it like y = bias_activation!!(σ, x, bias). If x is updated in-place, y aliases x.
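A short sketch; the bias length matches the penultimate dimension of x:
julia
x = randn(Float32, 4, 16)         # features x batch
b = randn(Float32, 4)
y = bias_activation(tanh, x, b)   # tanh.(x .+ b), fused where possible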
Convolutional Layers
LuxLib.API.fused_conv_bias_activation
julia
fused_conv_bias_activation(σ::F, weight::AbstractArray, x::AbstractArray,
    b::Optional{<:AbstractVector}, cdims::ConvDims) where {F}
Computes σ.(conv(x, weight, cdims) .+ b) (b is not exactly broadcasted like this, rather it is reshaped and broadcasted to the penultimate dimension) with the best possible implementation available. This operation fuses operations into a single kernel if possible, and minimizes reallocations by reusing the output buffer for multiple operations.
Arguments
σ: Activation function
weight: Weight tensor
x: Input tensor
b: Bias tensor (can be nothing)
cdims: ConvDims object
Notes on implementation
For CUDA Arrays, this uses fused CUDNN kernels when the activation is identity or relu. For other activations, it tries to fuse the operations on the Julia side.
If any of the inputs don't support setindex! (i.e., immutable arrays), we fall back to the generic non-mutating implementation.
Maximum memory reuse and operation fusion is guaranteed for ChainRules-compatible AD backends or backends that support mutation. Backends like Tracker and ReverseDiff fall back to the generic implementation.
For Mixed-Precision Inputs on GPU, we type promote the inputs to the highest precision, with a warning.
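A hedged usage sketch; DenseConvDims and relu come from NNlib:
julia
using LuxLib, NNlib

x = randn(Float32, 8, 8, 3, 2)   # W x H x C_in x N
w = randn(Float32, 3, 3, 3, 4)   # 3x3 kernel, 3 -> 4 channels
b = randn(Float32, 4)
cdims = DenseConvDims(x, w)
y = fused_conv_bias_activation(relu, w, x, b, cdims)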
Dropout
LuxLib.API.alpha_dropout
julia
alpha_dropout(rng::AbstractRNG, x, p, training)
alpha_dropout(rng::AbstractRNG, x, p, training, α, A, B)
Alpha Dropout: Dropout ensuring that the mean and variance of the output remain the same as those of the input. For details see [1]. Use the second call signature to avoid recomputing the constants for a fixed dropout probability.
Arguments
rng: Random number generator
x: Input Array
p: Probability of an element to be dropped out
training: Set to Val(true) or True() if running in training mode. Can be set to nothing to automatically determine if the function is being called within an autodiff context.
α: -1.7580993408473766, the limit of selu(x) as x tends to negative infinity (i.e. -λα for the SELU constants λ and α).
A: Scaling factor for the mean
B: Scaling factor for the variance
Returns
Output Array after applying alpha dropout
Updated state for the random number generator
References
[1] Klambauer, Günter, et al. "Self-normalizing neural networks." Advances in neural information processing systems 30 (2017).
Dropout: A Simple Way to Prevent Neural Networks from Overfitting. For details see [1].
Arguments
rng: Random number generator
x: Input Array
mask: Dropout Mask. If not used then it is constructed automatically
p: Probability of an element to be dropped out
training: Set to Val(true) or True() if running in training mode. Can be set to nothing to automatically determine if the function is being called within an autodiff context
update_mask: If Val(true) or True() then the mask is generated and used. Else, the mask provided is directly used
invp: Inverse of the keep probability, multiplied into the mask. Calculated as invp = 1 / (1 - p).
Returns
Output Array after applying dropout
Dropout Mask (if training == false, the returned value is meaningless)
Updated state for the random number generator
References
[1] Srivastava, Nitish, et al. "Dropout: a simple way to prevent neural networks from overfitting." The journal of machine learning research 15.1 (2014): 1929-1958.
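A hedged usage sketch for alpha_dropout following the first call signature above:
julia
using LuxLib, Random

rng = Xoshiro(0)
x = randn(rng, Float32, 4, 16)
y, rng_new = alpha_dropout(rng, x, 0.5f0, Val(true))   # training mode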
Fully Connected Layers
LuxLib.API.fused_dense_bias_activation
julia
fused_dense_bias_activation(σ::F, weight::AbstractMatrix, x::AbstractMatrix,
    b::Optional{<:AbstractVector}) where {F}
Compute σ.(weight * x .+ b) with the best possible implementation available. Currently this implementation attempts to minimize reallocations by reusing the output buffer for multiple operations.
Arguments
σ: Activation function
weight: Weight matrix
x: Input matrix
b: Bias vector (can be nothing)
Notes on implementation
If any of the inputs don't support setindex! (i.e., immutable arrays), we fall back to the generic non-mutating implementation.
Maximum memory reuse and operation fusion is guaranteed for ChainRules-compatible AD backends or backends that support mutation. Backends like Tracker and ReverseDiff fall back to the generic implementation.
For CUDA Arrays, this uses a special fused implementation via cuBLASLt.
For small CPU Arrays, we use LoopVectorization.jl. On x86_64 we use Octavian for medium-sized matrices. This is overridden if special BLAS implementations are loaded (currently MKL, AppleAccelerate, and BLISBLAS).
!!! tip "Load Octavian.jl
Loading `Octavian.jl` enables a polyalgorithm that uses different backends based on the\ninput sizes.
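A short sketch of the fused dense path:
julia
w = randn(Float32, 32, 64)
x = randn(Float32, 64, 16)
b = randn(Float32, 32)
y = fused_dense_bias_activation(tanh, w, x, b)   # tanh.(w * x .+ b), fused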
Normalization
LuxLib.API.batchnorm
training: Set to Val(true) or True() if running in training mode. Can be set to nothing to automatically determine if the function is being called within an autodiff context
σ: Activation function (default: identity)
momentum: Momentum for updating running mean and variance (default: 0.1f0)
epsilon: Value added to the denominator for numerical stability (default: eps(eltype(x)) ^ (5 / 7))
",6))]),i[51]||(i[51]=t("p",null,[t("strong",null,"Returns")],-1)),i[52]||(i[52]=t("p",null,[s("Normalized Array of same size as "),t("code",null,"x"),s(". And a Named Tuple containing the updated running mean and variance.")],-1)),i[53]||(i[53]=t("p",null,[t("strong",null,"References")],-1)),i[54]||(i[54]=t("p",null,'[1] Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating deep network training by reducing internal covariate shift." International conference on machine learning. PMLR, 2015.',-1)),i[55]||(i[55]=t("p",null,[t("a",{href:"https://github.com/LuxDL/Lux.jl/blob/2ffa2f745c551ad2880316402fff5c9ff367ea40/lib/LuxLib/src/api/batchnorm.jl#L1",target:"_blank",rel:"noreferrer"},"source")],-1))]),t("details",F,[t("summary",null,[i[56]||(i[56]=t("a",{id:"LuxLib.API.groupnorm",href:"#LuxLib.API.groupnorm"},[t("span",{class:"jlbinding"},"LuxLib.API.groupnorm")],-1)),i[57]||(i[57]=s()),n(a,{type:"info",class:"jlObjectType jlFunction",text:"Function"})]),i[74]||(i[74]=e(`
This op is similar to batch normalization, but statistics are shared across equally-sized groups of channels and not shared across batch dimension. Thus, group normalization does not depend on the batch composition and does not require maintaining internal state for storing statistics.
epsilon: Value added to the denominator for numerical stability (default: eps(eltype(x)) ^ (5 / 7))
",3))]),i[75]||(i[75]=t("p",null,[t("strong",null,"Returns")],-1)),i[76]||(i[76]=t("p",null,"The normalized array is returned.",-1)),i[77]||(i[77]=t("p",null,[t("strong",null,"References")],-1)),i[78]||(i[78]=t("p",null,'[1] Wu, Yuxin, and Kaiming He. "Group normalization." Proceedings of the European conference on computer vision (ECCV). 2018.',-1)),i[79]||(i[79]=t("p",null,[t("a",{href:"https://github.com/LuxDL/Lux.jl/blob/2ffa2f745c551ad2880316402fff5c9ff367ea40/lib/LuxLib/src/api/groupnorm.jl#L1",target:"_blank",rel:"noreferrer"},"source")],-1))]),t("details",D,[t("summary",null,[i[80]||(i[80]=t("a",{id:"LuxLib.API.instancenorm",href:"#LuxLib.API.instancenorm"},[t("span",{class:"jlbinding"},"LuxLib.API.instancenorm")],-1)),i[81]||(i[81]=s()),n(a,{type:"info",class:"jlObjectType jlFunction",text:"Function"})]),i[102]||(i[102]=e(`
training: Set to Val(true) or True() if running in training mode. Can be set to nothing to automatically determine if the function is being called within an autodiff context
σ: Activation function (default: identity)
epsilon: Value added to the denominator for numerical stability (default: eps(eltype(x)) ^ (5 / 7))
momentum: Momentum for updating running mean and variance (default: 0.1f0)
",6))]),i[104]||(i[104]=t("p",null,[t("strong",null,"Returns")],-1)),i[105]||(i[105]=t("p",null,[s("Normalized Array of same size as "),t("code",null,"x"),s(". And a Named Tuple containing the updated running mean and variance.")],-1)),i[106]||(i[106]=t("p",null,[t("strong",null,"References")],-1)),i[107]||(i[107]=t("p",null,'[1] Ulyanov, Dmitry, Andrea Vedaldi, and Victor Lempitsky. "Instance normalization: The missing ingredient for fast stylization." arXiv preprint arXiv:1607.08022 (2016).',-1)),i[108]||(i[108]=t("p",null,[t("a",{href:"https://github.com/LuxDL/Lux.jl/blob/2ffa2f745c551ad2880316402fff5c9ff367ea40/lib/LuxLib/src/api/instancenorm.jl#L1",target:"_blank",rel:"noreferrer"},"source")],-1))]),t("details",z,[t("summary",null,[i[109]||(i[109]=t("a",{id:"LuxLib.API.layernorm",href:"#LuxLib.API.layernorm"},[t("span",{class:"jlbinding"},"LuxLib.API.layernorm")],-1)),i[110]||(i[110]=s()),n(a,{type:"info",class:"jlObjectType jlFunction",text:"Function"})]),i[133]||(i[133]=e(`
dims: Dimensions along which the mean and std of x is computed. If nothing is passed, the dims are inferred based on the dimensions of scale and bias. For example, if x is N dimensional and scale and bias are M dimensional, then the dims will be 1:(N - M).
epsilon: Value added to the denominator for numerical stability (default: eps(eltype(x)) ^ (5 / 7))
",3))]),i[136]||(i[136]=t("p",null,[t("strong",null,"Returns")],-1)),i[137]||(i[137]=t("p",null,[s("Normalized Array of same size as "),t("code",null,"x"),s(".")],-1)),i[138]||(i[138]=t("p",null,[t("strong",null,"References")],-1)),i[139]||(i[139]=t("p",null,'[1] Ba, Jimmy Lei, Jamie Ryan Kiros, and Geoffrey E. Hinton. "Layer normalization." arXiv preprint arXiv:1607.06450 (2016).',-1)),i[140]||(i[140]=t("p",null,[t("a",{href:"https://github.com/LuxDL/Lux.jl/blob/2ffa2f745c551ad2880316402fff5c9ff367ea40/lib/LuxLib/src/api/layernorm.jl#L1",target:"_blank",rel:"noreferrer"},"source")],-1))]),i[153]||(i[153]=t("h2",{id:"Helper-Functions",tabindex:"-1"},[s("Helper Functions "),t("a",{class:"header-anchor",href:"#Helper-Functions","aria-label":'Permalink to "Helper Functions {#Helper-Functions}"'},"")],-1)),t("details",X,[t("summary",null,[i[141]||(i[141]=t("a",{id:"LuxLib.internal_operation_mode",href:"#LuxLib.internal_operation_mode"},[t("span",{class:"jlbinding"},"LuxLib.internal_operation_mode")],-1)),i[142]||(i[142]=s()),n(a,{type:"info",class:"jlObjectType jlFunction",text:"Function"})]),i[143]||(i[143]=e(`
Returns the internal operation mode for the given array(s). This is useful for defining custom implementations using different backends like simple Julia broadcasting, KernelAbstractions, LoopVectorization, etc. A dispatch sketch follows the list of modes below.
Currently supported modes are:
GenericBroadcastOp: This is the fallback for most types. For the following types this is the preferred mode:
Arrays with fast_scalar_indexing set to False.
Static Arrays
ReverseDiff Arrays
Tracker Arrays
ForwardDiff.Dual Arrays
GPUBroadcastOp{dev}: GPU Arrays where dev is obtained from get_device_type(xs). Dispatches for this option should preferably use KernelAbstractions or specialized vendor kernels.
LoopedArrayOp: CPU arrays that can be optimized using SIMD Loops, ideally using LoopVectorization.jl or Polyester.jl.
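A hedged sketch of dispatching on the mode (my_scale! is a hypothetical helper, not part of LuxLib):
julia
using LuxLib

x = randn(Float32, 16, 16)
mode = LuxLib.internal_operation_mode(x)   # e.g. LoopedArrayOp() on CPU

my_scale!(::LuxLib.GenericBroadcastOp, y, x, α) = (y .= α .* x)
function my_scale!(::LuxLib.LoopedArrayOp, y, x, α)
    @simd ivdep for i in eachindex(y, x)
        @inbounds y[i] = α * x[i]
    end
    return y
end

my_scale!(mode, similar(x), x, 2.0f0)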
WeightInitializers
WeightInitializers.glorot_normal
julia
glorot_normal([::AbstractRNG=Utils.default_rng()], [T=Float32], size...;
    gain = 1) -> AbstractArray{T, length(size)}
Return an AbstractArray{T} of the given size containing random numbers drawn from a normal distribution with standard deviation gain * sqrt(2 / (fan_in + fan_out)). This method is described in [1] and also known as Xavier initialization.
References
[1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep feedforward neural networks." Proceedings of the thirteenth international conference on artificial intelligence and statistics. 2010.
WeightInitializers.glorot_uniform
julia
glorot_uniform([::AbstractRNG=Utils.default_rng()], [T=Float32], size...;
    gain = 1) -> AbstractArray{T, length(size)}
`,1)),s("p",null,[i[7]||(i[7]=a("Return an ")),i[8]||(i[8]=s("code",null,"AbstractArray{T}",-1)),i[9]||(i[9]=a(" of the given ")),i[10]||(i[10]=s("code",null,"size",-1)),i[11]||(i[11]=a(" containing random numbers drawn from a uniform distribution on the interval ")),s("mjx-container",g,[(h(),e("svg",E,i[5]||(i[5]=[t('',1)]))),i[6]||(i[6]=s("mjx-assistive-mml",{unselectable:"on",display:"inline",style:{top:"0px",left:"0px",clip:"rect(1px, 1px, 1px, 1px)","-webkit-touch-callout":"none","-webkit-user-select":"none","-khtml-user-select":"none","-moz-user-select":"none","-ms-user-select":"none","user-select":"none",position:"absolute",padding:"1px 0px 0px 0px",border:"0px",display:"block",width:"auto",overflow:"hidden"}},[s("math",{xmlns:"http://www.w3.org/1998/Math/MathML"},[s("mo",{stretchy:"false"},"["),s("mo",null,"−"),s("mi",null,"x"),s("mo",null,","),s("mi",null,"x"),s("mo",{stretchy:"false"},"]")])],-1))]),i[12]||(i[12]=a(", where ")),i[13]||(i[13]=s("code",null,"x = gain * sqrt(6 / (fan_in + fan_out))",-1)),i[14]||(i[14]=a(". This method is described in [1] and also known as Xavier initialization."))]),i[16]||(i[16]=s("p",null,[s("strong",null,"References")],-1)),i[17]||(i[17]=s("p",null,[a('[1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep feedforward neural networks." '),s("em",null,"Proceedings of the thirteenth international conference on artificial intelligence and statistics"),a(". 2010.")],-1)),i[18]||(i[18]=s("p",null,[s("a",{href:"https://github.com/LuxDL/Lux.jl/blob/2ffa2f745c551ad2880316402fff5c9ff367ea40/lib/WeightInitializers/src/initializers.jl#L13-L27",target:"_blank",rel:"noreferrer"},"source")],-1))]),s("details",y,[s("summary",null,[i[19]||(i[19]=s("a",{id:"WeightInitializers.identity_init",href:"#WeightInitializers.identity_init"},[s("span",{class:"jlbinding"},"WeightInitializers.identity_init")],-1)),i[20]||(i[20]=a()),l(n,{type:"info",class:"jlObjectType jlFunction",text:"Function"})]),i[21]||(i[21]=t(`
Constructs an array that aims to provide an identity mapping when used as parameters in most layers of a neural network. The identity mapping is scaled by the gain parameter.
Behavior
1D: Returns a Vector of zeros (useful for biases in layers where input_size == output_size).
2D: Returns an identity matrix (useful for fully connected layers with equal input and output sizes).
More than 2D: Returns a tensor where the central slice along the last two dimensions is an identity matrix, and the rest are zeros (useful for convolutional layers, simulating an identity convolution).
Caveats
Not all layers will result in an identity mapping when using this initializer. Exceptions include recurrent and normalization layers.
Layers must have input_size == output_size for a perfect identity mapping. In cases where this condition is not met, the function pads extra dimensions with zeros.
For convolutional layers to achieve an identity mapping, kernel sizes must be odd, and appropriate padding must be applied to ensure the output feature maps are the same size as the input feature maps.
Arguments
rng::AbstractRNG: An optional random number generator, included for consistency with other initializers but ignored since the output is deterministic.
T::Type{<:Number}: The numeric type of the array elements.
size...: The dimensions of the array to be initialized.
gain::Number=1: A scaling factor applied to the identity mapping.
shift::Union{Integer, Tuple{Integer, Integer}}=0: An integer or a tuple specifying the circular shift applied to the output array.
Returns
AbstractArray{T}: An array initialized to represent an identity mapping, scaled by gain and optionally shifted by shift.
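A short sketch of the 2D and convolutional cases:
julia
w = identity_init(Float32, 3, 3)                 # 3x3 identity matrix
k = identity_init(Float32, 3, 3, 4, 4; gain=2)   # 3x3 identity conv kernel, 4 -> 4 channels, scaled by 2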
julia
kaiming_normal([::AbstractRNG=Utils.default_rng()], [T=Float32], size...;
    gain = √T(2)) -> AbstractArray{T, length(size)}
Return an AbstractArray{T} of the given size containing random numbers taken from a normal distribution with standard deviation gain / sqrt(fan_in).
References
[1] He, Kaiming, et al. "Delving deep into rectifiers: Surpassing human-level performance on imagenet classification." Proceedings of the IEEE international conference on computer vision. 2015.
julia
kaiming_uniform([::AbstractRNG=Utils.default_rng()], [T=Float32], size...;
    gain = √T(2)) -> AbstractArray{T, length(size)}
Return an AbstractArray{T} of the given size containing random numbers drawn from a uniform distribution on the interval [-x, x], where x = gain * sqrt(3/fan_in).
References
[1] He, Kaiming, et al. "Delving deep into rectifiers: Surpassing human-level performance on imagenet classification." Proceedings of the IEEE international conference on computer vision. 2015.
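A short usage sketch of the Kaiming initializers:
julia
w = kaiming_normal(Xoshiro(0), Float32, 256, 128)   # std = gain / sqrt(fan_in), gain defaults to √2
w2 = kaiming_uniform(256, 128)                      # defaults for RNG and eltype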
Creates a sparsely initialized weight matrix with a specified proportion of zeroed elements, using random numbers drawn from a normal distribution for the non-zero elements. This method was introduced in [1].
Note
The sparsity parameter controls the proportion of the matrix that will be zeroed. For example, a sparsity of 0.3 means that approximately 30% of the elements will be set to zero. The non-zero elements are distributed according to a normal distribution, scaled by the std parameter.
Arguments
rng::AbstractRNG: The random number generator to use.
T::Type{<:Number}: The numeric type of the elements in the returned array.
dims::Integer...: The dimensions of the weight matrix to be generated.
sparsity::Number: The proportion of elements to be zeroed. Must be between 0 and 1.
std::Number=0.01: The standard deviation of the normal distribution before applying gain.
Returns
AbstractArray{T}: A sparsely initialized weight matrix of dimensions dims and type T.
Examples
julia
julia> y = sparse_init(Xoshiro(123), Float32, 5, 5; sparsity=0.3, std=0.01);

julia> y isa Matrix{Float32}
true

julia> size(y) == (5, 5)
true
References
[1] Martens, J, "Deep learning via Hessian-free optimization" Proceedings of the 27th International Conference on International Conference on Machine Learning. 2010.
julia
truncated_normal([::AbstractRNG=Utils.default_rng()], [T=Float32], size...; mean = 0,
    std = 1, lo = -2, hi = 2) -> AbstractArray{T, length(size)}
Return an AbstractArray{T} of the given size where each element is drawn from a truncated normal distribution. The numbers are distributed like filter(x -> lo ≤ x ≤ hi, mean .+ std .* randn(100)).
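A short usage sketch:
julia
w = truncated_normal(Xoshiro(0), Float32, 8, 8; mean=0, std=1)
extrema(w)   # all samples lie within [lo, hi] = [-2, 2] by default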
julia
orthogonal([::AbstractRNG=Utils.default_rng()], [T=Float32], dims::Integer...;
    gain = 1) -> AbstractArray{T, length(dims)}
Return an AbstractArray{T} of the given dimensions (dims) which is a (semi) orthogonal matrix, as described in [1].
The function constructs an orthogonal or semi-orthogonal matrix depending on the specified dimensions. For two dimensions, it returns a matrix where dims = (rows, cols). For more than two dimensions, it computes an orthogonal matrix of size prod(dims[1:(end - 1)]) by dims[end] before reshaping it to the original dimensions.
Cannot construct a vector, i.e., length(dims) == 1 is forbidden.
Arguments
rng::AbstractRNG: Random number generator.
T::Type{<:Real}: The type of the elements in the array.
dims::Integer...: The dimensions of the array.
gain::Number: Scaling factor for the elements of the orthogonal matrix.
References
[1] Saxe, McClelland, Ganguli. "Exact solutions to the nonlinear dynamics of learning in deep linear neural networks", ICLR 2014, https://arxiv.org/abs/1312.6120
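As a sketch, the semi-orthogonality can be checked with LinearAlgebra (assuming the returned matrix has orthonormal columns when rows ≥ cols):
julia
julia> using WeightInitializers, Random, LinearAlgebra

julia> W = orthogonal(Xoshiro(123), Float32, 5, 3);

julia> W' * W ≈ I
true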
Other Convenience Functions
Beware
Unlike the other functions these ones don't take a type argument.
WeightInitializers.zeros16
',1)),t("p",null,[e[4]||(e[4]=a("Compute the Jacobian-Vector Product ")),t("mjx-container",c,[(n(),i("svg",T,e[2]||(e[2]=[s('',1)]))),e[3]||(e[3]=t("mjx-assistive-mml",{unselectable:"on",display:"inline",style:{top:"0px",left:"0px",clip:"rect(1px, 1px, 1px, 1px)","-webkit-touch-callout":"none","-webkit-user-select":"none","-khtml-user-select":"none","-moz-user-select":"none","-ms-user-select":"none","user-select":"none",position:"absolute",padding:"1px 0px 0px 0px",border:"0px",display:"block",width:"auto",overflow:"hidden"}},[t("math",{xmlns:"http://www.w3.org/1998/Math/MathML"},[t("mrow",{"data-mjx-texclass":"INNER"},[t("mo",{"data-mjx-texclass":"OPEN"},"("),t("mfrac",null,[t("mrow",null,[t("mi",null,"∂"),t("mi",null,"f")]),t("mrow",null,[t("mi",null,"∂"),t("mi",null,"x")])]),t("mo",{"data-mjx-texclass":"CLOSE"},")")]),t("mi",null,"u")])],-1))]),e[5]||(e[5]=a(". This is a wrapper around AD backends but allows us to compute gradients of jacobian-vector products efficiently using mixed-mode AD."))]),e[7]||(e[7]=s('
Backends & AD Packages
Supported Backends: AutoForwardDiff
Warning
Gradient wrt u in the reverse pass is always dropped.
Arguments
f: The function to compute the jacobian of.
backend: The backend to use for computing the JVP.
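A minimal sketch of a call, assuming the positional argument order f, backend, x, u; for f(x) = x .^ 2 the Jacobian is diagonal, so the JVP equals 2 .* x .* u:
julia
julia> using ADTypes, Lux

julia> f(x) = x .^ 2;

julia> x = Float32[1, 2, 3]; u = ones(Float32, 3);

julia> Lux.jacobian_vector_product(f, AutoForwardDiff(), x, u) ≈ 2 .* x .* u
true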
',1)),t("p",null,[e[12]||(e[12]=a("Compute the Vector-Jacobian Product ")),t("mjx-container",u,[(n(),i("svg",h,e[10]||(e[10]=[s('',1)]))),e[11]||(e[11]=t("mjx-assistive-mml",{unselectable:"on",display:"inline",style:{top:"0px",left:"0px",clip:"rect(1px, 1px, 1px, 1px)","-webkit-touch-callout":"none","-webkit-user-select":"none","-khtml-user-select":"none","-moz-user-select":"none","-ms-user-select":"none","user-select":"none",position:"absolute",padding:"1px 0px 0px 0px",border:"0px",display:"block",width:"auto",overflow:"hidden"}},[t("math",{xmlns:"http://www.w3.org/1998/Math/MathML"},[t("msup",null,[t("mrow",{"data-mjx-texclass":"INNER"},[t("mo",{"data-mjx-texclass":"OPEN"},"("),t("mfrac",null,[t("mrow",null,[t("mi",null,"∂"),t("mi",null,"f")]),t("mrow",null,[t("mi",null,"∂"),t("mi",null,"x")])]),t("mo",{"data-mjx-texclass":"CLOSE"},")")]),t("mi",null,"T")]),t("mi",null,"u")])],-1))]),e[13]||(e[13]=a(". This is a wrapper around AD backends but allows us to compute gradients of vector-jacobian products efficiently using mixed-mode AD."))]),e[15]||(e[15]=s('
Backends & AD Packages
Supported Backends: AutoZygote (requires Zygote.jl)
Warning
Gradient wrt u in the reverse pass is always dropped.
Arguments
f: The function to compute the jacobian of.
backend: The backend to use for computing the VJP.
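Correspondingly, a minimal sketch for the VJP (assuming Zygote.jl is installed and the same argument order as above); for f(x) = x .^ 2 the VJP also equals 2 .* x .* u:
julia
julia> using ADTypes, Lux, Zygote

julia> f(x) = x .^ 2;

julia> x = Float32[1, 2, 3]; u = ones(Float32, 3);

julia> Lux.vector_jacobian_product(f, AutoZygote(), x, u) ≈ 2 .* x .* u
true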
Batched AD
Lux.batched_jacobian
Computes the Jacobian of a function f with respect to a batch of inputs x. This expects the following properties for y = f(x):
ndims(y) ≥ 2
size(y, ndims(y)) == size(x, ndims(x))
Backends & AD Packages
Supported Backends: AutoForwardDiff, AutoZygote (AutoZygote requires Zygote.jl)
Arguments
f: The function to compute the jacobian of.
backend: The backend to use for computing the jacobian.
x: The input to the function. Must have ndims(x) ≥ 2.
Returns
J: The Jacobian of f with respect to x. This will be a 3D Array. If the dimensions of x are (N₁, N₂, ..., Nₙ, B) and of y are (M₁, M₂, ..., Mₘ, B), then J will be a ((M₁ × M₂ × ... × Mₘ), (N₁ × N₂ × ... × Nₙ), B) Array.
Danger
f(x) must not be inter-mixing the batch dimensions, else the result will be incorrect. For example, if f contains operations like batch normalization, then the result will be incorrect.
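A sketch of the shape contract, assuming an elementwise f so that each per-sample Jacobian is 3×3:
julia
julia> using ADTypes, Lux, Random

julia> f(x) = x .^ 2;

julia> x = randn(Xoshiro(123), Float32, 3, 2);  # batch of 2 samples

julia> J = Lux.batched_jacobian(f, AutoForwardDiff(), x);

julia> size(J)
(3, 3, 2)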
Nested 2nd Order AD
Consult the manual page on Nested AD for information on nested automatic differentiation.
Experimental Features
All features listed on this page are experimental which means:
No SemVer Guarantees. We use code here to iterate fast. That said, historically we have never broken any code in this module and have always provided a deprecation period.
Expect edge-cases and report them. It will help us move these features out of experimental sooner.
Lux.Experimental.FrozenLayer
Freeze the parameters with name which_params of the layer l.
Use Lux.Experimental.freeze instead
It is always recommended to use the Lux.Experimental.freeze function instead of directly using the FrozenLayer constructor.
No checks for which_params
There are no checks for which_params. For example, if the original layer has parameters named (:weight, :bias), and which_params is set to (:myweight,) then none of the parameters are frozen and no error is thrown.
Arguments
l: Lux AbstractLuxLayer.
which_params: Parameter Names to be Frozen. Can be set to nothing, in which case all parameters are frozen.
Extended Help
Parameters
Parameters of the layer l excluding which_params.
States
frozen_params: Parameters that are frozen, i.e., which_params.
states: The state of the inner layer l.
Note on Internal Layer Implementation
The inner layer should work with NamedTuple parameters. In order to support custom parameter types, users need to implement Lux.Utils.merge(::CustomParamType, ::NamedTuple) or extend Lux.Utils.named_tuple(::CustomParamType) to return a NamedTuple.
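A minimal sketch of freezing a parameter with Lux.Experimental.freeze, as recommended above; the assumption here is that the frozen parameter moves from ps into st.frozen_params:
julia
julia> using Lux, Random

julia> frozen = Lux.Experimental.freeze(Dense(2 => 2), (:weight,));

julia> ps, st = Lux.setup(Xoshiro(0), frozen);

julia> keys(ps), keys(st.frozen_params)
((:bias,), (:weight,))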
For a detailed usage example, look at the manual page.
Map over Layer
Lux.Experimental.layer_map
Map the function f over the model l, with the parameters ps and states st. This is different from Functors.fmap since it zips the layers, parameters, and states and invokes the function on all of them together.
KeyPath provided to the function
The KeyPath depends on the structure of the parameters and states. This is of consequence exclusively for AbstractLuxWrapperLayer where the structure of the layer doesn't match the structure of the parameters and states. In the example provided below, the KeyPath is (:chain, :dense_1) for the first layer (following the structure in ps) while accessing the same layer in the chain is done with (:chain, :layers, :dense_1).
Call Signature for f
Must take 4 inputs – AbstractLuxLayer, Corresponding Parameters, Corresponding States, and the Functors.KeyPath to the layer.
Must return a tuple of 3 elements – AbstractLuxLayer, new parameters and the new states.
Extended Help
Example
julia
julia> using Lux, Random

julia> c = Parallel(
           +; chain=Chain(; dense_1=Dense(2 => 3), bn=BatchNorm(3), dense_2=Dense(3 => 5)),
           dense_3=Dense(5 => 1));

julia> rng = Random.default_rng();

julia> ps, st = Lux.setup(rng, c);

julia> # Makes parameters of Dense Layers inside Chain zero
       function zero_dense_params(l, ps, st, name)
           if l isa Dense
               println("zeroing params of $name")
               ps = merge(ps, (; weight=zero.(ps.weight), bias=zero.(ps.bias)))
           end
           return l, ps, st
       end;

julia> _, ps_new, _ = Lux.Experimental.layer_map(zero_dense_params, c, ps, st);
zeroing params of KeyPath(:chain, :dense_1)
zeroing params of KeyPath(:chain, :dense_2)
zeroing params of KeyPath(:dense_3,)

julia> all(iszero, (ps_new.chain.dense_1.weight, ps_new.chain.dense_1.bias,
                    ps_new.chain.dense_2.weight, ps_new.chain.dense_2.bias,
                    ps_new.dense_3.weight, ps_new.dense_3.bias))
true
Debugging Functionality
Model not working properly! Here are some functionalities to help you debug your Lux model.
Lux.Experimental.@debug_mode
julia
@debug_mode layer kwargs...
Recurses into the layer and replaces the innermost non-container layers with a Lux.Experimental.DebugLayer.
Lux.Experimental.DebugLayer
A wrapper over Lux layers that adds checks for NaNs and errors. This is useful for debugging.
Arguments
layer: The layer to be wrapped.
Extended Help
Keyword Arguments
nan_check: Whether to check for NaNs in the input, parameters, and states. Can be :both, :forward, :backward, or :none.
error_check: Whether to check for errors in the layer. If true, will throw an error if the layer fails.
location: The location of the layer. Use Lux.Experimental.@debug_mode to construct this layer to populate this value correctly.
Input / Output
Inputs and outputs are the same as the layer unless one of the nan_check or error_check criteria is met.
If nan_check is enabled and NaNs are detected then a DomainError is thrown. If error_check is enabled, then any errors in the layer are thrown with useful information to track where the error originates.
ChainRules Compatible Reverse Mode AD Tools
nan_check for the backward mode only works with ChainRules Compatible Reverse Mode AD Tools currently.
Disable After Debugging
This layer is only meant to be used for debugging. If used for actual training or inference, will lead to extremely bad performance.
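A sketch of wrapping a model for debugging; outputs should be unchanged unless one of the checks fires:
julia
julia> using Lux, Random

julia> model = Chain(Dense(2 => 3, relu), Dense(3 => 1));

julia> dmodel = Lux.Experimental.@debug_mode model;

julia> ps, st = Lux.setup(Xoshiro(0), dmodel);

julia> y, _ = dmodel(randn(Xoshiro(1), Float32, 2, 4), ps, st); size(y)
(1, 4)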
Lux.Experimental.share_parameters
Updates the parameters in ps with a common set of parameters new_parameters that are shared between each list in the nested list sharing. (That was kind of a mouthful, the example should make it clear).
Arguments
ps: Original parameters.
sharing: A nested list of lists of accessors of ps which need to share the parameters (see the example for details). (Each list in the list must be disjoint.)
new_parameters: If passed the length of new_parameters must be equal to the length of sharing. For each vector in sharing the corresponding parameter in new_parameters will be used. (If not passed, the parameters corresponding to the first element of each vector in sharing will be used).
Returns
Updated Parameters having the same structure as ps.
ComponentArrays doesn't allow sharing parameters. Converting the returned parameters to a ComponentArray will silently cause the parameter sharing to be undone.
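A sketch under the assumption that string accessors such as "d1.weight" index into ps; the tied parameters must have matching sizes:
julia
julia> using Lux, Random

julia> model = Chain(; d1=Dense(2 => 2), d2=Dense(2 => 2));

julia> ps, st = Lux.setup(Xoshiro(0), model);

julia> ps = Lux.Experimental.share_parameters(ps, (("d1.weight", "d2.weight"),));

julia> ps.d1.weight === ps.d2.weight
true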
Distributed Utils
Initialization
Lux.DistributedUtils.initialize
Initialize the given backend. Users can supply cuda_devices and amdgpu_devices to initialize the backend with the given devices. These can be set to missing to prevent initialization of the given device type. If set to nothing and the backend is functional, we assign GPUs in a round-robin fashion. Finally, a list of integers can be supplied to initialize the backend with the given devices.
Possible values for backend are:
MPIBackend: MPI backend for distributed training. Requires MPI.jl to be installed.
NCCLBackend: NCCL backend for CUDA distributed training. Requires CUDA.jl, MPI.jl, and NCCL.jl to be installed. This also wraps MPI backend for non-CUDA communications.
Lux.DistributedUtils.get_distributed_backend
Get the distributed backend for the given backend type. Possible values are:
MPIBackend: MPI backend for distributed training. Requires MPI.jl to be installed.
NCCLBackend: NCCL backend for CUDA distributed training. Requires CUDA.jl, MPI.jl, and NCCL.jl to be installed. This also wraps MPI backend for non-CUDA communications.
Danger
initialize(backend; kwargs...) must be called before calling this function.
Lux.DistributedUtils.bcast!
Backend Agnostic API to broadcast the given buffer sendrecvbuf or sendbuf to all workers into recvbuf. The value at root will be broadcasted to all other workers.
Lux.DistributedUtils.DistributedDataContainer
data must be compatible with MLUtils interface. The returned container is compatible with MLUtils interface and is used to partition the dataset across the available processes.
Load MLUtils.jl
MLUtils.jl must be installed and loaded before using this.
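A minimal sketch of the typical flow, assuming MPI.jl is installed and the script is launched via mpiexec:
julia
using Lux, MPI

# MPIBackend is assumed here to be exported by Lux
DistributedUtils.initialize(MPIBackend)
backend = DistributedUtils.get_distributed_backend(MPIBackend)

# Partition a dataset across the processes (requires `using MLUtils`):
# dist_data = DistributedUtils.DistributedDataContainer(backend, data)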
Interoperability between Lux and other packages
Lux.FromFluxAdaptor
Adaptor for converting a Flux model to a Lux model.
Warning
This always ignores the active field of some of the Flux layers. This is almost never going to be supported.
Keyword Arguments
preserve_ps_st: Set to true to preserve the states and parameters of the layer. This attempts to preserve the original model as closely as possible, but it might fail. If you need to override possible failures, set force_preserve to true.
force_preserve: Some of the transformations with state and parameters preservation haven't been implemented yet. In these cases, if force_preserve is false, a warning will be printed and a core Lux layer will be returned. Else, it will create a FluxLayer.
Example
julia
julia> import Flux
+
+julia> using Adapt, Lux, Random
+
+julia> m = Flux.Chain(Flux.Dense(2 => 3, relu), Flux.Dense(3 => 2));
+
+julia> m2 = adapt(FromFluxAdaptor(), m); # or FromFluxAdaptor()(m.layers)
+
+julia> x = randn(Float32, 2, 32);
+
+julia> ps, st = Lux.setup(Random.default_rng(), m2);
+
+julia> size(first(m2(x, ps, st)))
+(2, 32)
SimpleChains.jl provides a way to train Small Neural Networks really fast on CPUs. See this blog post for more details. This section describes how to convert Lux models to SimpleChains models while preserving the layer interface.
Tip
Accessing these functions require manually loading SimpleChains, i.e., using SimpleChains must be present somewhere in the code for these to be used.
Adaptor for converting a Lux Model to SimpleChains. The returned model is still a Lux model, and satisfies the AbstractLuxLayer interface, but all internal calculations are performed using SimpleChains.
Warning
There is no way to preserve trained parameters and states when converting to SimpleChains.jl.
Warning
Any kind of initialization function is not preserved when converting to SimpleChains.jl.
Arguments
input_dims: Tuple of input dimensions excluding the batch dimension. These must be of static type as SimpleChains expects.
convert_to_array: SimpleChains.jl by default outputs StrideArraysCore.StrideArray, but this might not compose well with other packages. If convert_to_array is set to true, the output will be converted to a regular Array.
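A sketch of converting a small model, assuming SimpleChains.jl is installed and that static (from Static.jl) is available via SimpleChains:
julia
julia> using Adapt, Lux, SimpleChains

julia> lux_model = Chain(Dense(2 => 4, relu), Dense(4 => 2));

julia> adaptor = ToSimpleChainsAdaptor((static(2),));

julia> simple_chains_model = adapt(adaptor, lux_model);  # still an AbstractLuxLayer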
Built-In Layers
Lux.Parallel
Create a layer which passes an input to each path in layers, before reducing the output with connection.
Arguments
connection: An N-argument function that is called after passing the input through each layer. If connection = nothing, we return a tuple Parallel(nothing, f, g)(x, y) = (f(x), g(y))
Layers can be specified in two formats:
A list of N Lux layers
Specified as N keyword arguments.
Extended Help
Inputs
x: If x is not a tuple, then return is computed as connection([l(x) for l in layers]...). Else one is passed to each layer, thus Parallel(+, f, g)(x, y) = f(x) + g(y).
Returns
See the Inputs section for how the output is computed
Updated state of the layers
Parameters
Parameters of each layer wrapped in a NamedTuple with fields = layer_1, layer_2, ..., layer_N (naming changes if using the kwargs API)
States
States of each layer wrapped in a NamedTuple with fields = layer_1, layer_2, ..., layer_N (naming changes if using the kwargs API)
See also SkipConnection which is Parallel with one identity.
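A sketch of the non-tuple input case; both branches receive x and the outputs are reduced with +:
julia
julia> using Lux, Random

julia> model = Parallel(+, Dense(2 => 3), Dense(2 => 3));

julia> ps, st = Lux.setup(Xoshiro(0), model);

julia> y, _ = model(randn(Xoshiro(1), Float32, 2, 4), ps, st); size(y)
(3, 4)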
Lux.SkipConnection
Create a skip connection which consists of a layer or Chain of consecutive layers and a shortcut connection linking the block's input to the output through a user-supplied 2-argument callable. The first argument to the callable will be propagated through the given layer while the second is the unchanged, "skipped" input.
The simplest "ResNet"-type connection is just SkipConnection(layer, +).
Arguments
layer: Layer or Chain of layers to be applied to the input
connection:
A 2-argument function that takes layer(input) and the input OR
An AbstractLuxLayer that takes (layer(input), input) as input
Extended Help
Inputs
x: Will be passed directly to layer
Returns
Output of connection(layer(input), input)
Updated state of layer
Parameters
Parameters of layer OR
If connection is an AbstractLuxLayer, then NamedTuple with fields :layers and :connection
States
States of layer OR
If connection is an AbstractLuxLayer, then NamedTuple with fields :layers and :connection
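A sketch of the simplest residual block mentioned above:
julia
julia> using Lux, Random

julia> block = SkipConnection(Dense(4 => 4, relu), +);

julia> ps, st = Lux.setup(Xoshiro(0), block);

julia> x = randn(Xoshiro(1), Float32, 4, 2);

julia> y, _ = block(x, ps, st); size(y) == size(x)
true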
Lux.RepeatedLayer
Iteratively applies model for repeats number of times. The initial input is passed into the model repeatedly if input_injection = Val(true). This layer unrolls the computation, however, semantically this is the same as:
input_injection = Val(false)
julia
res = x
for i in 1:repeats
    res, st = model(res, ps, st)
end
input_injection = Val(true)
julia
res = x
for i in 1:repeats
    res, st = model((res, x), ps, st)
end
It is expected that repeats will be a reasonable number below 20; beyond that, compile times for gradients might be unreasonably high.
Arguments
model must be an AbstractLuxLayer
Keyword Arguments
repeats: Number of times to apply the model
input_injection: If true, then the input is passed to the model along with the output
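A sketch of constructing and running the unrolled layer; note that repeats is passed as a Val:
julia
julia> using Lux, Random

julia> model = RepeatedLayer(Dense(2 => 2, tanh); repeats=Val(3));

julia> ps, st = Lux.setup(Xoshiro(0), model);

julia> y, _ = model(randn(Xoshiro(1), Float32, 2, 4), ps, st); size(y)
(2, 4)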
Lux.Conv
Image data should be stored in WHCN order (width, height, channels, batch). In other words, a 100 x 100 RGB image would be a 100 x 100 x 3 x 1 array, and a batch of 50 would be a 100 x 100 x 3 x 50 array. This has N = 2 spatial dimensions, and needs a kernel size like (5, 5), a 2-tuple of integers. To take convolutions along N feature dimensions, this layer expects as input an array with ndims(x) == N + 2, where size(x, N + 1) == in_chs is the number of input channels, and size(x, ndims(x)) is the number of observations in a batch.
Warning
Frameworks like Pytorch perform cross-correlation in their convolution layers. Pass cross_correlation=true to use cross-correlation instead.
Arguments
k: Tuple of integers specifying the size of the convolutional kernel. Eg, for 2D convolutions length(k) == 2
in_chs: Number of input channels
out_chs: Number of output channels
activation: Activation Function
Extended Help
Keyword Arguments
init_weight: Controls the initialization of the weight parameter. If nothing, then we use kaiming_uniform with gain computed on the basis of the activation function (taken from Pytorch nn.init.calculate_gain).
init_bias: Controls the initialization of the bias parameter. If nothing, then we use uniform distribution with bounds -bound and bound where bound = inv(sqrt(fan_in)).
stride: Should each be either single integer, or a tuple with N integers
dilation: Should each be either single integer, or a tuple with N integers
pad: Specifies the number of elements added to the borders of the data array. It can be
a single integer for equal padding all around,
a tuple of N integers, to apply the same padding at begin/end of each spatial dimension,
a tuple of 2*N integers, for asymmetric padding, or
the singleton SamePad(), to calculate padding such that size(output,d) == size(x,d) / stride (possibly rounded) for each spatial dimension.
Periodic padding can be achieved by pre-empting the layer with a WrappedFunction(x -> NNlib.circular_pad(x, N_pad; dims=pad_dims))
groups: Expected to be an Int. It specifies the number of groups to divide a convolution into (set groups = in_chs for Depthwise Convolutions). in_chs and out_chs must be divisible by groups.
use_bias: Trainable bias can be disabled entirely by setting this to false.
cross_correlation: If true, perform cross-correlation instead of convolution. Prior to v1, Lux used to have a CrossCor layer which performed cross-correlation. This was removed in v1 in favor of Conv with cross_correlation=true.
Inputs
x: Data satisfying ndims(x) == N + 2 && size(x, N + 1) == in_chs, i.e., size(x) = (I_N, ..., I_1, C_in, B)
Returns
Output of the convolution y of size (O_N, ..., O_1, C_out, B), where each O_i is determined by the input size, kernel size, stride, dilation, and padding along that spatial dimension.
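A sketch of the shape behavior with SamePad; at stride 1 the spatial size is preserved:
julia
julia> using Lux, Random

julia> conv = Conv((3, 3), 3 => 8, relu; pad=SamePad());

julia> ps, st = Lux.setup(Xoshiro(0), conv);

julia> x = randn(Xoshiro(1), Float32, 32, 32, 3, 4);

julia> y, _ = conv(x, ps, st); size(y)
(32, 32, 8, 4)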
Lux.ConvTranspose
Arguments
k: Tuple of integers specifying the size of the convolutional kernel. E.g., for 2D convolutions length(k) == 2
in_chs: Number of input channels
out_chs: Number of output channels
activation: Activation Function
Keyword Arguments
init_weight: Controls the initialization of the weight parameter. If nothing, then we use kaiming_uniform with gain computed on the basis of the activation function (taken from Pytorch nn.init.calculate_gain).
init_bias: Controls the initialization of the bias parameter. If nothing, then we use uniform distribution with bounds -bound and bound where bound = inv(sqrt(fan_in)).
stride: Should each be either single integer, or a tuple with N integers
dilation: Should each be either single integer, or a tuple with N integers
pad: Specifies the number of elements added to the borders of the data array. It can be
a single integer for equal padding all around,
a tuple of N integers, to apply the same padding at begin/end of each spatial dimension,
a tuple of 2*N integers, for asymmetric padding, or
the singleton SamePad(), to calculate padding such that size(output,d) == size(x,d) * stride (possibly rounded) for each spatial dimension.
groups: Expected to be an Int. It specifies the number of groups to divide a convolution into (set groups = in_chs for Depthwise Convolutions). in_chs and out_chs must be divisible by groups.
use_bias: Trainable bias can be disabled entirely by setting this to false.
cross_correlation: If true, perform transposed cross-correlation instead of transposed convolution.
outpad: To preserve Conv invertibility when stride > 1, outpad can be used to increase the size of the output in the desired dimensions. Whereas pad is used to zero-pad the input, outpad only affects the output shape.
Extended Help
Inputs
x: Data satisfying ndims(x) == N + 2 && size(x, N - 1) == in_chs, i.e. size(x) = (I_N, ..., I_1, C_in, N)
Returns
Output of the convolution transpose y of size (O_N, ..., O_1, C_out, N) where
O_i = (I_i - 1) * stride[i] - pad[i] - pad[(i + N) % length(pad)] + dilation[i] * (k[i] - 1) + 1 + outpad[i]
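A minimal sketch of the upsampling behaviour (the sizes below are illustrative):
julia
using Lux, Random

rng = Random.default_rng()
model = ConvTranspose((4, 4), 16 => 3; stride=2, pad=1)

ps, st = Lux.setup(rng, model)
x = randn(rng, Float32, 16, 16, 16, 1)
y, st_new = model(x, ps, st)
size(y)  # (32, 32, 3, 1): (16 - 1) * 2 + 4 - 2 * 1 = 32 per spatial dimension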
p: Probability of Dropout (if p = 0 then NoOpLayer is returned)
Keyword Arguments
To apply dropout along certain dimension(s), specify the dims keyword. e.g. Dropout(p; dims = (3,4)) will randomly zero out entire channels on WHCN input (also called 2D dropout).
Inputs
x: Must be an AbstractArray
Returns
x with dropout mask applied if training=Val(true) else just x
State with updated rng
States
rng: Pseudo Random Number Generator
training: Used to check if training/inference mode
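A minimal sketch contrasting training and inference behaviour (shapes illustrative); Lux.testmode switches the state to inference mode:
julia
using Lux, Random

rng = Random.default_rng()
model = Dropout(0.5)

ps, st = Lux.setup(rng, model)
x = ones(Float32, 4, 4)

y_train, st_train = model(x, ps, st)        # mask applied; kept entries scaled by 1/(1 - p)
y_test, _ = model(x, ps, Lux.testmode(st))  # identity in inference mode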
p: Probability of Dropout (if p = 0 then NoOpLayer is returned)
Keyword Arguments
To apply dropout along certain dimension(s), specify the dims keyword. e.g. VariationalHiddenDropout(p; dims = 3) will randomly zero out entire channels on WHCN input (also called 2D dropout).
Inputs
x: Must be an AbstractArray
Returns
x with dropout mask applied if training=Val(true) else just x
State with updated rng
States
rng: Pseudo Random Number Generator
training: Used to check if training/inference mode
mask: Dropout mask. Initially set to nothing. After every run, contains the mask applied in that call
update_mask: Stores whether new mask needs to be generated in the current call
Global LP Pooling layer. Transforms (w, h, c, b)-shaped input into (1, 1, c, b)-shaped output, by performing LP norm pooling on the complete (w, h)-shaped feature maps.
GPU Support
This layer is currently only supported on CPU.
Inputs
x: Data satisfying ndims(x) > 2, i.e. size(x) = (I_N, ..., I_1, C, N)
Global Max Pooling layer. Transforms (w, h, c, b)-shaped input into (1, 1, c, b)-shaped output, by performing max pooling on the complete (w, h)-shaped feature maps.
Inputs
x: Data satisfying ndims(x) > 2, i.e. size(x) = (I_N, ..., I_1, C, N)
Global Mean Pooling layer. Transforms (w, h, c, b)-shaped input into (1, 1, c, b)-shaped output, by performing mean pooling on the complete (w, h)-shaped feature maps.
Inputs
x: Data satisfying ndims(x) > 2, i.e. size(x) = (I_N, ..., I_1, C, N)
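A minimal sketch (shapes illustrative):
julia
using Lux, Random

rng = Random.default_rng()
model = GlobalMeanPool()

ps, st = Lux.setup(rng, model)
x = randn(rng, Float32, 8, 8, 16, 2)
y, _ = model(x, ps, st)
size(y)  # (1, 1, 16, 2): each 8×8 feature map reduced to its mean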
Lux.GRUCell
Arguments
train_state: Trainable initial hidden state can be activated by setting this to true
init_bias: Initializer for bias. Must be a tuple containing 3 functions. If a single value is passed, it is copied into a 3 element tuple. If nothing, then we use uniform distribution with bounds -bound and bound where bound = inv(sqrt(out_dims)).
init_weight: Initializer for weight. Must be a tuple containing 3 functions. If a single value is passed, it is copied into a 3 element tuple. If nothing, then we use uniform distribution with bounds -bound and bound where bound = inv(sqrt(out_dims)).
init_state: Initializer for hidden state
Inputs
Case 1a: Only a single input x of shape (in_dims, batch_size), train_state is set to false - Creates a hidden state using init_state and proceeds to Case 2.
Case 1b: Only a single input x of shape (in_dims, batch_size), train_state is set to true - Repeats hidden_state from parameters to match the shape of x and proceeds to Case 2.
Case 2: Tuple (x, (h, )) is provided, then the output and a tuple containing the updated hidden state is returned.
Returns
",5)),t("ul",null,[t("li",null,[a[87]||(a[87]=t("p",null,"Tuple containing",-1)),t("ul",null,[t("li",null,[t("p",null,[a[81]||(a[81]=s("Output ")),t("mjx-container",I,[(Q(),n("svg",S,a[79]||(a[79]=[i('',1)]))),a[80]||(a[80]=t("mjx-assistive-mml",{unselectable:"on",display:"inline",style:{top:"0px",left:"0px",clip:"rect(1px, 1px, 1px, 1px)","-webkit-touch-callout":"none","-webkit-user-select":"none","-khtml-user-select":"none","-moz-user-select":"none","-ms-user-select":"none","user-select":"none",position:"absolute",padding:"1px 0px 0px 0px",border:"0px",display:"block",width:"auto",overflow:"hidden"}},[t("math",{xmlns:"http://www.w3.org/1998/Math/MathML"},[t("msub",null,[t("mi",null,"h"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"n"),t("mi",null,"e"),t("mi",null,"w")])])])],-1))]),a[82]||(a[82]=s(" of shape ")),a[83]||(a[83]=t("code",null,"(out_dims, batch_size)",-1))])]),t("li",null,[t("p",null,[a[86]||(a[86]=s("Tuple containing new hidden state ")),t("mjx-container",_,[(Q(),n("svg",G,a[84]||(a[84]=[i('',1)]))),a[85]||(a[85]=t("mjx-assistive-mml",{unselectable:"on",display:"inline",style:{top:"0px",left:"0px",clip:"rect(1px, 1px, 1px, 1px)","-webkit-touch-callout":"none","-webkit-user-select":"none","-khtml-user-select":"none","-moz-user-select":"none","-ms-user-select":"none","user-select":"none",position:"absolute",padding:"1px 0px 0px 0px",border:"0px",display:"block",width:"auto",overflow:"hidden"}},[t("math",{xmlns:"http://www.w3.org/1998/Math/MathML"},[t("msub",null,[t("mi",null,"h"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"n"),t("mi",null,"e"),t("mi",null,"w")])])])],-1))])])])])]),a[88]||(a[88]=t("li",null,[t("p",null,"Updated model state")],-1))]),a[122]||(a[122]=t("p",null,[t("strong",null,"Parameters")],-1)),t("ul",null,[t("li",null,[t("p",null,[a[91]||(a[91]=t("code",null,"weight_ih",-1)),a[92]||(a[92]=s(": Concatenated Weights to map from input space ")),t("mjx-container",W,[(Q(),n("svg",X,a[89]||(a[89]=[i('',1)]))),a[90]||(a[90]=t("mjx-assistive-mml",{unselectable:"on",display:"inline",style:{top:"0px",left:"0px",clip:"rect(1px, 1px, 1px, 1px)","-webkit-touch-callout":"none","-webkit-user-select":"none","-khtml-user-select":"none","-moz-user-select":"none","-ms-user-select":"none","user-select":"none",position:"absolute",padding:"1px 0px 0px 0px",border:"0px",display:"block",width:"auto",overflow:"hidden"}},[t("math",{xmlns:"http://www.w3.org/1998/Math/MathML"},[t("mo",{fence:"false",stretchy:"false"},"{"),t("msub",null,[t("mi",null,"W"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"i"),t("mi",null,"r")])]),t("mo",null,","),t("msub",null,[t("mi",null,"W"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"i"),t("mi",null,"z")])]),t("mo",null,","),t("msub",null,[t("mi",null,"W"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"i"),t("mi",null,"n")])]),t("mo",{fence:"false",stretchy:"false"},"}")])],-1))]),a[93]||(a[93]=s("."))])]),t("li",null,[t("p",null,[a[96]||(a[96]=t("code",null,"weight_hh",-1)),a[97]||(a[97]=s(": Concatenated Weights to map from hidden space ")),t("mjx-container",U,[(Q(),n("svg",q,a[94]||(a[94]=[i('',1)]))),a[95]||(a[95]=t("mjx-assistive-mml",{unselectable:"on",display:"inline",style:{top:"0px",left:"0px",clip:"rect(1px, 1px, 1px, 1px)","-webkit-touch-callout":"none","-webkit-user-select":"none","-khtml-user-select":"none","-moz-user-select":"none","-ms-user-select":"none","user-select":"none",position:"absolute",padding:"1px 0px 0px 
0px",border:"0px",display:"block",width:"auto",overflow:"hidden"}},[t("math",{xmlns:"http://www.w3.org/1998/Math/MathML"},[t("mo",{fence:"false",stretchy:"false"},"{"),t("msub",null,[t("mi",null,"W"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"h"),t("mi",null,"r")])]),t("mo",null,","),t("msub",null,[t("mi",null,"W"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"h"),t("mi",null,"z")])]),t("mo",null,","),t("msub",null,[t("mi",null,"W"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"h"),t("mi",null,"n")])]),t("mo",{fence:"false",stretchy:"false"},"}")])],-1))]),a[98]||(a[98]=s("."))])]),t("li",null,[t("p",null,[a[101]||(a[101]=t("code",null,"bias_ih",-1)),a[102]||(a[102]=s(": Concatenated Bias vector for the input space ")),t("mjx-container",J,[(Q(),n("svg",K,a[99]||(a[99]=[i('',1)]))),a[100]||(a[100]=t("mjx-assistive-mml",{unselectable:"on",display:"inline",style:{top:"0px",left:"0px",clip:"rect(1px, 1px, 1px, 1px)","-webkit-touch-callout":"none","-webkit-user-select":"none","-khtml-user-select":"none","-moz-user-select":"none","-ms-user-select":"none","user-select":"none",position:"absolute",padding:"1px 0px 0px 0px",border:"0px",display:"block",width:"auto",overflow:"hidden"}},[t("math",{xmlns:"http://www.w3.org/1998/Math/MathML"},[t("mo",{fence:"false",stretchy:"false"},"{"),t("msub",null,[t("mi",null,"b"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"i"),t("mi",null,"r")])]),t("mo",null,","),t("msub",null,[t("mi",null,"b"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"i"),t("mi",null,"z")])]),t("mo",null,","),t("msub",null,[t("mi",null,"b"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"i"),t("mi",null,"n")])]),t("mo",{fence:"false",stretchy:"false"},"}")])],-1))]),a[103]||(a[103]=s(" (not present if ")),a[104]||(a[104]=t("code",null,"use_bias=false",-1)),a[105]||(a[105]=s(")."))])]),t("li",null,[t("p",null,[a[108]||(a[108]=t("code",null,"bias_hh",-1)),a[109]||(a[109]=s(": Concatenated Bias vector for the hidden space ")),t("mjx-container",$,[(Q(),n("svg",Y,a[106]||(a[106]=[i('',1)]))),a[107]||(a[107]=t("mjx-assistive-mml",{unselectable:"on",display:"inline",style:{top:"0px",left:"0px",clip:"rect(1px, 1px, 1px, 1px)","-webkit-touch-callout":"none","-webkit-user-select":"none","-khtml-user-select":"none","-moz-user-select":"none","-ms-user-select":"none","user-select":"none",position:"absolute",padding:"1px 0px 0px 0px",border:"0px",display:"block",width:"auto",overflow:"hidden"}},[t("math",{xmlns:"http://www.w3.org/1998/Math/MathML"},[t("mo",{fence:"false",stretchy:"false"},"{"),t("msub",null,[t("mi",null,"b"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"h"),t("mi",null,"r")])]),t("mo",null,","),t("msub",null,[t("mi",null,"b"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"h"),t("mi",null,"z")])]),t("mo",null,","),t("msub",null,[t("mi",null,"b"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"h"),t("mi",null,"n")])]),t("mo",{fence:"false",stretchy:"false"},"}")])],-1))]),a[110]||(a[110]=s(" (not present if ")),a[111]||(a[111]=t("code",null,"use_bias=false",-1)),a[112]||(a[112]=s(")."))])]),t("li",null,[t("p",null,[a[115]||(a[115]=t("code",null,"hidden_state",-1)),a[116]||(a[116]=s(": Initial hidden state vector (not present if ")),a[117]||(a[117]=t("code",null,"train_state=false",-1)),a[118]||(a[118]=s(") ")),t("mjx-container",t1,[(Q(),n("svg",a1,a[113]||(a[113]=[i('',1)]))),a[114]||(a[114]=t("mjx-assistive-mml",{unselectable:"on",display:"inline",style:{top:"0px",left:"0px",clip:"rect(1px, 1px, 1px, 
1px)","-webkit-touch-callout":"none","-webkit-user-select":"none","-khtml-user-select":"none","-moz-user-select":"none","-ms-user-select":"none","user-select":"none",position:"absolute",padding:"1px 0px 0px 0px",border:"0px",display:"block",width:"auto",overflow:"hidden"}},[t("math",{xmlns:"http://www.w3.org/1998/Math/MathML"},[t("mo",{fence:"false",stretchy:"false"},"{"),t("msub",null,[t("mi",null,"b"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"h"),t("mi",null,"r")])]),t("mo",null,","),t("msub",null,[t("mi",null,"b"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"h"),t("mi",null,"z")])]),t("mo",null,","),t("msub",null,[t("mi",null,"b"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"h"),t("mi",null,"n")])]),t("mo",{fence:"false",stretchy:"false"},"}")])],-1))]),a[119]||(a[119]=s("."))])])]),a[123]||(a[123]=t("p",null,[t("strong",null,"States")],-1)),a[124]||(a[124]=t("ul",null,[t("li",null,[t("code",null,"rng"),s(": Controls the randomness (if any) in the initial state generation")])],-1)),a[125]||(a[125]=t("p",null,[t("a",{href:"https://github.com/LuxDL/Lux.jl/blob/2ffa2f745c551ad2880316402fff5c9ff367ea40/src/layers/recurrent.jl#L488",target:"_blank",rel:"noreferrer"},"source")],-1))]),t("details",s1,[t("summary",null,[a[126]||(a[126]=t("a",{id:"Lux.LSTMCell",href:"#Lux.LSTMCell"},[t("span",{class:"jlbinding"},"Lux.LSTMCell")],-1)),a[127]||(a[127]=s()),l(e,{type:"info",class:"jlObjectType jlType",text:"Type"})]),a[153]||(a[153]=i(`
out_dims: Output (Hidden State & Memory) Dimension
use_bias: Set to false to deactivate bias
train_state: Trainable initial hidden state can be activated by setting this to true
train_memory: Trainable initial memory can be activated by setting this to true
init_bias: Initializer for bias. Must be a tuple containing 4 functions. If a single value is passed, it is copied into a 4 element tuple. If nothing, then we use uniform distribution with bounds -bound and bound where bound = inv(sqrt(out_dims)).
init_weight: Initializer for weight. Must be a tuple containing 4 functions. If a single value is passed, it is copied into a 4 element tuple. If nothing, then we use uniform distribution with bounds -bound and bound where bound = inv(sqrt(out_dims)).
init_state: Initializer for hidden state
init_memory: Initializer for memory
Inputs
Case 1a: Only a single input x of shape (in_dims, batch_size), train_state is set to false, train_memory is set to false - Creates a hidden state using init_state, hidden memory using init_memory and proceeds to Case 2.
Case 1b: Only a single input x of shape (in_dims, batch_size), train_state is set to true, train_memory is set to false - Repeats hidden_state vector from the parameters to match the shape of x, creates hidden memory using init_memory and proceeds to Case 2.
Case 1c: Only a single input x of shape (in_dims, batch_size), train_state is set to false, train_memory is set to true - Creates a hidden state using init_state, repeats the memory vector from parameters to match the shape of x and proceeds to Case 2.
Case 1d: Only a single input x of shape (in_dims, batch_size), train_state is set to true, train_memory is set to true - Repeats the hidden state and memory vectors from the parameters to match the shape of x and proceeds to Case 2.
Case 2: Tuple (x, (h, c)) is provided, then the output and a tuple containing the updated hidden state and memory is returned.
Returns
",5)),t("ul",null,[t("li",null,[a[141]||(a[141]=t("p",null,"Tuple Containing",-1)),t("ul",null,[t("li",null,[t("p",null,[a[132]||(a[132]=s("Output ")),t("mjx-container",l1,[(Q(),n("svg",n1,a[130]||(a[130]=[i('',1)]))),a[131]||(a[131]=t("mjx-assistive-mml",{unselectable:"on",display:"inline",style:{top:"0px",left:"0px",clip:"rect(1px, 1px, 1px, 1px)","-webkit-touch-callout":"none","-webkit-user-select":"none","-khtml-user-select":"none","-moz-user-select":"none","-ms-user-select":"none","user-select":"none",position:"absolute",padding:"1px 0px 0px 0px",border:"0px",display:"block",width:"auto",overflow:"hidden"}},[t("math",{xmlns:"http://www.w3.org/1998/Math/MathML"},[t("msub",null,[t("mi",null,"h"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"n"),t("mi",null,"e"),t("mi",null,"w")])])])],-1))]),a[133]||(a[133]=s(" of shape ")),a[134]||(a[134]=t("code",null,"(out_dims, batch_size)",-1))])]),t("li",null,[t("p",null,[a[139]||(a[139]=s("Tuple containing new hidden state ")),t("mjx-container",Q1,[(Q(),n("svg",T1,a[135]||(a[135]=[i('',1)]))),a[136]||(a[136]=t("mjx-assistive-mml",{unselectable:"on",display:"inline",style:{top:"0px",left:"0px",clip:"rect(1px, 1px, 1px, 1px)","-webkit-touch-callout":"none","-webkit-user-select":"none","-khtml-user-select":"none","-moz-user-select":"none","-ms-user-select":"none","user-select":"none",position:"absolute",padding:"1px 0px 0px 0px",border:"0px",display:"block",width:"auto",overflow:"hidden"}},[t("math",{xmlns:"http://www.w3.org/1998/Math/MathML"},[t("msub",null,[t("mi",null,"h"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"n"),t("mi",null,"e"),t("mi",null,"w")])])])],-1))]),a[140]||(a[140]=s(" and new memory ")),t("mjx-container",d1,[(Q(),n("svg",o1,a[137]||(a[137]=[i('',1)]))),a[138]||(a[138]=t("mjx-assistive-mml",{unselectable:"on",display:"inline",style:{top:"0px",left:"0px",clip:"rect(1px, 1px, 1px, 1px)","-webkit-touch-callout":"none","-webkit-user-select":"none","-khtml-user-select":"none","-moz-user-select":"none","-ms-user-select":"none","user-select":"none",position:"absolute",padding:"1px 0px 0px 0px",border:"0px",display:"block",width:"auto",overflow:"hidden"}},[t("math",{xmlns:"http://www.w3.org/1998/Math/MathML"},[t("msub",null,[t("mi",null,"c"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"n"),t("mi",null,"e"),t("mi",null,"w")])])])],-1))])])])])]),a[142]||(a[142]=t("li",null,[t("p",null,"Updated model state")],-1))]),a[155]||(a[155]=t("p",null,[t("strong",null,"Parameters")],-1)),t("ul",null,[t("li",null,[t("p",null,[a[145]||(a[145]=t("code",null,"weight_ih",-1)),a[146]||(a[146]=s(": Concatenated Weights to map from input space ")),t("mjx-container",r1,[(Q(),n("svg",p1,a[143]||(a[143]=[i('',1)]))),a[144]||(a[144]=t("mjx-assistive-mml",{unselectable:"on",display:"inline",style:{top:"0px",left:"0px",clip:"rect(1px, 1px, 1px, 1px)","-webkit-touch-callout":"none","-webkit-user-select":"none","-khtml-user-select":"none","-moz-user-select":"none","-ms-user-select":"none","user-select":"none",position:"absolute",padding:"1px 0px 0px 
0px",border:"0px",display:"block",width:"auto",overflow:"hidden"}},[t("math",{xmlns:"http://www.w3.org/1998/Math/MathML"},[t("mo",{fence:"false",stretchy:"false"},"{"),t("msub",null,[t("mi",null,"W"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"i"),t("mi",null,"i")])]),t("mo",null,","),t("msub",null,[t("mi",null,"W"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"i"),t("mi",null,"f")])]),t("mo",null,","),t("msub",null,[t("mi",null,"W"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"i"),t("mi",null,"g")])]),t("mo",null,","),t("msub",null,[t("mi",null,"W"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"i"),t("mi",null,"o")])]),t("mo",{fence:"false",stretchy:"false"},"}")])],-1))]),a[147]||(a[147]=s("."))])]),t("li",null,[t("p",null,[a[150]||(a[150]=t("code",null,"weight_hh",-1)),a[151]||(a[151]=s(": Concatenated Weights to map from hidden space ")),t("mjx-container",h1,[(Q(),n("svg",m1,a[148]||(a[148]=[i('',1)]))),a[149]||(a[149]=t("mjx-assistive-mml",{unselectable:"on",display:"inline",style:{top:"0px",left:"0px",clip:"rect(1px, 1px, 1px, 1px)","-webkit-touch-callout":"none","-webkit-user-select":"none","-khtml-user-select":"none","-moz-user-select":"none","-ms-user-select":"none","user-select":"none",position:"absolute",padding:"1px 0px 0px 0px",border:"0px",display:"block",width:"auto",overflow:"hidden"}},[t("math",{xmlns:"http://www.w3.org/1998/Math/MathML"},[t("mo",{fence:"false",stretchy:"false"},"{"),t("msub",null,[t("mi",null,"W"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"h"),t("mi",null,"i")])]),t("mo",null,","),t("msub",null,[t("mi",null,"W"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"h"),t("mi",null,"f")])]),t("mo",null,","),t("msub",null,[t("mi",null,"W"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"h"),t("mi",null,"g")])]),t("mo",null,","),t("msub",null,[t("mi",null,"W"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"h"),t("mi",null,"o")])]),t("mo",{fence:"false",stretchy:"false"},"}")])],-1))])])]),a[152]||(a[152]=i("
bias_ih: Concatenated Bias vector for the input-hidden connection (not present if use_bias=false)
bias_hh: Concatenated Bias vector for the hidden-hidden connection (not present if use_bias=false)
hidden_state: Initial hidden state vector (not present if train_state=false)
memory: Initial memory vector (not present if train_memory=false)
",4))]),a[156]||(a[156]=t("p",null,[t("strong",null,"States")],-1)),a[157]||(a[157]=t("ul",null,[t("li",null,[t("code",null,"rng"),s(": Controls the randomness (if any) in the initial state generation")])],-1)),a[158]||(a[158]=t("p",null,[t("a",{href:"https://github.com/LuxDL/Lux.jl/blob/2ffa2f745c551ad2880316402fff5c9ff367ea40/src/layers/recurrent.jl#L309",target:"_blank",rel:"noreferrer"},"source")],-1))]),t("details",g1,[t("summary",null,[a[159]||(a[159]=t("a",{id:"Lux.RNNCell",href:"#Lux.RNNCell"},[t("span",{class:"jlbinding"},"Lux.RNNCell")],-1)),a[160]||(a[160]=s()),l(e,{type:"info",class:"jlObjectType jlType",text:"Type"})]),a[173]||(a[173]=i(`
train_state: Trainable initial hidden state can be activated by setting this to true
init_bias: Initializer for bias. If nothing, then we use uniform distribution with bounds -bound and bound where bound = inv(sqrt(out_dims)).
init_weight: Initializer for weight. If nothing, then we use uniform distribution with bounds -bound and bound where bound = inv(sqrt(out_dims)).
init_state: Initializer for hidden state
Inputs
Case 1a: Only a single input x of shape (in_dims, batch_size), train_state is set to false - Creates a hidden state using init_state and proceeds to Case 2.
Case 1b: Only a single input x of shape (in_dims, batch_size), train_state is set to true - Repeats hidden_state from parameters to match the shape of x and proceeds to Case 2.
Case 2: Tuple (x, (h, )) is provided, then the output and a tuple containing the updated hidden state is returned.
Returns
",5)),t("ul",null,[t("li",null,[a[171]||(a[171]=t("p",null,"Tuple containing",-1)),t("ul",null,[t("li",null,[t("p",null,[a[165]||(a[165]=s("Output ")),t("mjx-container",u1,[(Q(),n("svg",y1,a[163]||(a[163]=[i('',1)]))),a[164]||(a[164]=t("mjx-assistive-mml",{unselectable:"on",display:"inline",style:{top:"0px",left:"0px",clip:"rect(1px, 1px, 1px, 1px)","-webkit-touch-callout":"none","-webkit-user-select":"none","-khtml-user-select":"none","-moz-user-select":"none","-ms-user-select":"none","user-select":"none",position:"absolute",padding:"1px 0px 0px 0px",border:"0px",display:"block",width:"auto",overflow:"hidden"}},[t("math",{xmlns:"http://www.w3.org/1998/Math/MathML"},[t("msub",null,[t("mi",null,"h"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"n"),t("mi",null,"e"),t("mi",null,"w")])])])],-1))]),a[166]||(a[166]=s(" of shape ")),a[167]||(a[167]=t("code",null,"(out_dims, batch_size)",-1))])]),t("li",null,[t("p",null,[a[170]||(a[170]=s("Tuple containing new hidden state ")),t("mjx-container",E1,[(Q(),n("svg",f1,a[168]||(a[168]=[i('',1)]))),a[169]||(a[169]=t("mjx-assistive-mml",{unselectable:"on",display:"inline",style:{top:"0px",left:"0px",clip:"rect(1px, 1px, 1px, 1px)","-webkit-touch-callout":"none","-webkit-user-select":"none","-khtml-user-select":"none","-moz-user-select":"none","-ms-user-select":"none","user-select":"none",position:"absolute",padding:"1px 0px 0px 0px",border:"0px",display:"block",width:"auto",overflow:"hidden"}},[t("math",{xmlns:"http://www.w3.org/1998/Math/MathML"},[t("msub",null,[t("mi",null,"h"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"n"),t("mi",null,"e"),t("mi",null,"w")])])])],-1))])])])])]),a[172]||(a[172]=t("li",null,[t("p",null,"Updated model state")],-1))]),a[175]||(a[175]=i('
Parameters
weight_ih: Maps the input to the hidden state.
weight_hh: Maps the hidden state to the hidden state.
bias_ih: Bias vector for the input-hidden connection (not present if use_bias=false)
bias_hh: Bias vector for the hidden-hidden connection (not present if use_bias=false)
hidden_state: Initial hidden state vector (not present if train_state=false)
States
rng: Controls the randomness (if any) in the initial state generation
Wraps a recurrent cell (like RNNCell, LSTMCell, GRUCell) to automatically operate over a sequence of inputs.
Relation to Flux.Recur
This is completely distinct from Flux.Recur. It doesn't make the cell stateful, rather allows operating on an entire sequence of inputs at once. See StatefulRecurrentCell for functionality similar to Flux.Recur.
Arguments
cell: A recurrent cell. See RNNCell, LSTMCell, GRUCell, for how the inputs/outputs of a recurrent cell must be structured.
Keyword Arguments
return_sequence: If true returns the entire sequence of outputs, else returns only the last output. Defaults to false.
ordering: The ordering of the batch and time dimensions in the input. Defaults to BatchLastIndex(). Alternatively can be set to TimeLastIndex().
Extended Help
Inputs
If x is a
Tuple or Vector: Each element is fed to the cell sequentially.
Array (except a Vector): It is spliced along the penultimate dimension and each slice is fed to the cell sequentially.
Returns
Output of the cell for the entire sequence.
Updated state of the cell.
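A minimal sketch of running a cell over a full sequence; the in_dims, sequence length, and batch size are illustrative:
julia
using Lux, Random

rng = Random.default_rng()
model = Recurrence(LSTMCell(3 => 5))  # return_sequence=false: only the last output

ps, st = Lux.setup(rng, model)
x = randn(rng, Float32, 3, 16, 8)  # (in_dims, time, batch) with BatchLastIndex()
y, st_new = model(x, ps, st)
size(y)  # (5, 8)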
Tip
Frameworks like Tensorflow have special implementations of StackedRNNCells to handle sequentially composed RNN Cells. In Lux, one can simply stack multiple Recurrence blocks in a Chain to achieve the same.
To avoid undefined behavior, once the processing of a single sequence of data is complete, update the state with Lux.update_state(st, :carry, nothing).
Arguments
cell: A recurrent cell. See RNNCell, LSTMCell, GRUCell, for how the inputs/outputs of a recurrent cell must be structured.
cell: A recurrent cell. See RNNCell, LSTMCell, GRUCell, for how the inputs/outputs of a recurrent cell must be structured.
backward_cell: An optional backward recurrent cell. If backward_cell is nothing, the RNN layer instance passed as the cell argument will be used to generate the backward layer automatically. in_dims of backward_cell should be consistent with in_dims of cell
Keyword Arguments
merge_mode: Function by which outputs of the forward and backward RNNs will be combined. Defaults to vcat. If nothing, the outputs will not be combined.
ordering: The ordering of the batch and time dimensions in the input. Defaults to BatchLastIndex(). Alternatively can be set to TimeLastIndex().
Extended Help
Inputs
If x is a
Tuple or Vector: Each element is fed to the cell sequentially.
Array (except a Vector): It is spliced along the penultimate dimension and each slice is fed to the cell sequentially.
Returns
Merged output of the cell and backward_cell for the entire sequence.
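A minimal sketch (sizes illustrative); the per-step outputs are vcat-merged by default:
julia
using Lux, Random

rng = Random.default_rng()
model = BidirectionalRNN(RNNCell(3 => 5))  # backward cell generated automatically

ps, st = Lux.setup(rng, model)
x = randn(rng, Float32, 3, 10, 4)  # (in_dims, time, batch)
y, st_new = model(x, ps, st)       # merged forward/backward outputs per time step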
Create a fully connected layer between two inputs and an output, and otherwise similar to Dense. Its output, given vectors x & y, is another vector z with, for all i in 1:out:
z[i] = activation(x' * W[i, :, :] * y + bias[i])
If x and y are matrices, then each column of the output z = B(x, y) is of this form, with B the Bilinear layer.
Arguments
in1_dims: number of input dimensions of x
in2_dims: number of input dimensions of y
in12_dims: If specified, then in1_dims = in2_dims = in12_dims
out: number of output dimensions
activation: activation function
Keyword Arguments
init_weight: initializer for the weight matrix (weight = init_weight(rng, out_dims, in1_dims, in2_dims)). If nothing, then we use uniform distribution with bounds -bound and bound where bound = inv(sqrt(in1_dims)).
init_bias: initializer for the bias vector (ignored if use_bias=false). If nothing, then we use uniform distribution with bounds -bound and bound where bound = inv(sqrt(in1_dims)).
use_bias: Trainable bias can be disabled entirely by setting this to false
Input
A 2-Tuple containing
x must be an AbstractArray with size(x, 1) == in1_dims
y must be an AbstractArray with size(y, 1) == in2_dims
If the input is an AbstractArray, then x = y
Returns
AbstractArray with dimensions (out_dims, size(x, 2))
Empty NamedTuple()
Parameters
weight: Weight Matrix of size (out_dims, in1_dims, in2_dims)
bias: Bias of size (out_dims, 1) (present if use_bias=true)
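A minimal sketch with two differently sized inputs (dimensions illustrative):
julia
using Lux, Random

rng = Random.default_rng()
model = Bilinear((4, 3) => 2)

ps, st = Lux.setup(rng, model)
x = randn(rng, Float32, 4, 7)
y = randn(rng, Float32, 3, 7)
z, _ = model((x, y), ps, st)
size(z)  # (2, 7)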
Create a traditional fully connected layer, whose forward pass is given by: y = activation.(weight * x .+ bias)
Arguments
in_dims: number of input dimensions
out_dims: number of output dimensions
activation: activation function
Keyword Arguments
init_weight: initializer for the weight matrix (weight = init_weight(rng, out_dims, in_dims)). If nothing, then we use kaiming_uniform with gain computed on the basis of the activation function (taken from Pytorch nn.init.calculate_gain).
init_bias: initializer for the bias vector (ignored if use_bias=false). If nothing, then we use uniform distribution with bounds -bound and bound where bound = inv(sqrt(in_dims)).
use_bias: Trainable bias can be disabled entirely by setting this to false
Input
x must be an AbstractArray with size(x, 1) == in_dims
Returns
AbstractArray with dimensions (out_dims, ...) where ... are the dimensions of x
Empty NamedTuple()
Parameters
weight: Weight Matrix of size (out_dims, in_dims)
bias: Bias of size (out_dims, 1) (present if use_bias=true)
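A minimal sketch (dimensions illustrative):
julia
using Lux, Random

rng = Random.default_rng()
model = Dense(10 => 5, relu)

ps, st = Lux.setup(rng, model)
x = randn(rng, Float32, 10, 32)
y, _ = model(x, ps, st)
size(y)          # (5, 32)
size(ps.weight)  # (5, 10)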
A lookup table that stores embeddings of dimension out_dims for a vocabulary of size in_dims. When the vocabulary is multi-dimensional, the input is expected to be a tuple of Cartesian indices.
This layer is often used to store word embeddings and retrieve them using indices.
Arguments
in_dims: number(s) of input dimensions
out_dims: number of output dimensions
Keyword Arguments
init_weight: initializer for the weight matrix (weight = init_weight(rng, out_dims, in_dims...))
Input
Integer OR
Abstract Vector of Integers OR
Abstract Array of Integers OR
Tuple of Integers OR
Tuple of Abstract Vectors of Integers OR
Tuple of Abstract Arrays of Integers
Returns
Returns the embedding corresponding to each index in the input. For an N dimensional input, an N + 1 dimensional output is returned.
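A minimal sketch of index lookup (vocabulary size and embedding width are illustrative):
julia
using Lux, Random

rng = Random.default_rng()
model = Embedding(26 => 8)  # vocabulary of 26 tokens, 8-dimensional embeddings

ps, st = Lux.setup(rng, model)
emb, _ = model([1, 5, 26], ps, st)
size(emb)  # (8, 3): one embedding column per index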
Create a Sparsely Connected Layer with a very specific structure (only Diagonal Elements are non-zero). The forward pass is given by: y = activation.(weight .* x .+ bias)
Arguments
dims: size of the learnable scale and bias parameters.
activation: activation function
Keyword Arguments
init_weight: initializer for the weight matrix (weight = init_weight(rng, out_dims, in_dims))
init_bias: initializer for the bias vector (ignored if use_bias=false)
use_bias: Trainable bias can be disabled entirely by setting this to false
Input
x must be an Array of size (dims..., B) or (dims...[0], ..., dims[k]) for k ≤ size(dims)
Returns
Array of size (dims..., B) or (dims...[0], ..., dims[k]) for k ≤ size(dims)
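A minimal sketch (dims illustrative):
julia
using Lux, Random

rng = Random.default_rng()
model = Scale(4, relu)

ps, st = Lux.setup(rng, model)
x = randn(rng, Float32, 4, 10)
y, _ = model(x, ps, st)  # relu.(ps.weight .* x .+ ps.bias)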
Misc. Helper Layers
Lux.FlattenLayer
julia
FlattenLayer(; N = nothing)
Flattens the passed array into a matrix.
Keyword Arguments
N: Flatten the first N dimensions of the input array. If nothing, then all dimensions (except the last) are flattened. Note that the batch dimension is never flattened.
Inputs
x: AbstractArray
Returns
AbstractMatrix of size (:, size(x, ndims(x))) if N is nothing else the first N dimensions of the input array are flattened.
Empty NamedTuple()
Example
julia
julia> model = FlattenLayer()
FlattenLayer{Nothing}(nothing)

julia> rng = Random.default_rng();
       Random.seed!(rng, 0);
       ps, st = Lux.setup(rng, model);
       x = randn(rng, Float32, (2, 2, 2, 2));

julia> y, st_new = model(x, ps, st);
       size(y)
(8, 2)
This contains a number of internal layers, each of which receives the same input. Its output is the elementwise maximum of the internal layers' outputs.
Maxout over linear dense layers satisfies the universal approximation theorem. See [1].
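A minimal sketch of maxout over two linear maps (sizes illustrative):
julia
using Lux, Random

rng = Random.default_rng()
model = Maxout(Dense(4 => 8), Dense(4 => 8))  # elementwise max over two linear maps

ps, st = Lux.setup(rng, model)
x = randn(rng, Float32, 4, 3)
y, _ = model(x, ps, st)
size(y)  # (8, 3)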
Return a view of all the data of the input x where the index for dimension dim equals i. Equivalent to view(x,:,:,...,i,:,:,...) where i is in position dim.
Arguments
dim: Dimension for indexing
i: Index for dimension dim
Inputs
x: AbstractArray that can be indexed with view(x,:,:,...,i,:,:,...)
Returns
view(x,:,:,...,i,:,:,...) where i is in position dim
Wraps a stateless and parameterless function. Might be used when a function is added to Chain. For example, Chain(x -> relu.(x)) would not work and the right thing to do would be Chain((x, ps, st) -> (relu.(x), st)). An easier thing to do would be Chain(WrappedFunction(Base.Fix1(broadcast, relu))).
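A minimal sketch of that pattern (shapes illustrative):
julia
using Lux, Random

rng = Random.default_rng()
model = Chain(Dense(4 => 4), WrappedFunction(Base.Fix1(broadcast, relu)))

ps, st = Lux.setup(rng, model)
x = randn(rng, Float32, 4, 2)
y, _ = model(x, ps, st)  # equivalent to applying Dense(4 => 4) and then relu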
Reverse the specified dimension dims of the passed array
Arguments
dim: Dimension that needs to be reversed. If nothing, then for an AbstractVector{T} dimension 1 is reversed; for other arrays, dimension ndims(x) - 1 is reversed.
Inputs
x: AbstractArray.
Returns
AbstractArray with the same dimensions as the input
Empty NamedTuple()
Example
julia
julia> model = ReverseSequence()
ReverseSequence{Nothing}(nothing)

julia> rng = Random.default_rng();
       Random.seed!(rng, 0);
       ps, st = Lux.setup(rng, model);
       x = [1.0, 2.0, 3.0];

julia> y, st_new = model(x, ps, st)
([3.0, 2.0, 1.0], NamedTuple())
`,2)),t("p",null,[a[222]||(a[222]=t("code",null,"BatchNorm",-1)),a[223]||(a[223]=s(" computes the mean and variance for each ")),t("mjx-container",R1,[(Q(),n("svg",N1,a[220]||(a[220]=[i('',1)]))),a[221]||(a[221]=t("mjx-assistive-mml",{unselectable:"on",display:"inline",style:{top:"0px",left:"0px",clip:"rect(1px, 1px, 1px, 1px)","-webkit-touch-callout":"none","-webkit-user-select":"none","-khtml-user-select":"none","-moz-user-select":"none","-ms-user-select":"none","user-select":"none",position:"absolute",padding:"1px 0px 0px 0px",border:"0px",display:"block",width:"auto",overflow:"hidden"}},[t("math",{xmlns:"http://www.w3.org/1998/Math/MathML"},[t("msub",null,[t("mi",null,"D"),t("mn",null,"1")]),t("mi",null,"×"),t("mo",null,"."),t("mo",null,"."),t("mo",null,"."),t("mi",null,"×"),t("msub",null,[t("mi",null,"D"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"N"),t("mo",null,"−"),t("mn",null,"2")])]),t("mi",null,"×"),t("mn",null,"1"),t("mi",null,"×"),t("msub",null,[t("mi",null,"D"),t("mi",null,"N")])])],-1))]),a[224]||(a[224]=s(" input slice and normalises the input accordingly."))]),a[226]||(a[226]=i(`
Arguments
chs: Size of the channel dimension in your data. Given an array with N dimensions, call the N-1th the channel dimension. For a batch of feature vectors this is just the data dimension, for WHCN images it's the usual channel dimension.
activation: After normalization, elementwise activation activation is applied.
Keyword Arguments
track_stats: If track_stats=true, accumulates mean and variance statistics in training phase that will be used to renormalize the input in test phase.
epsilon: a value added to the denominator for numerical stability
momentum: the value used for the running_mean and running_var computation
affine: If affine=true, it also applies a shift and a rescale to the input through learnable per-channel bias and scale parameters.
init_bias: Controls how the bias is initialized
init_scale: Controls how the scale is initialized
Extended Help
Inputs
x: Array where size(x, N - 1) = chs
Returns
y: Normalized Array
Update model state
Parameters
affine=true
bias: Bias of shape (chs,)
scale: Scale of shape (chs,)
affine=false - Empty NamedTuple()
States
Statistics if track_stats=true
running_mean: Running mean of shape (chs,)
running_var: Running variance of shape (chs,)
Statistics if track_stats=false
running_mean: nothing
running_var: nothing
training: Used to check if training/inference mode
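A minimal sketch of BatchNorm after a Dense layer (sizes illustrative); Lux.testmode switches to the running statistics:
julia
using Lux, Random

rng = Random.default_rng()
model = Chain(Dense(10 => 16), BatchNorm(16, relu))

ps, st = Lux.setup(rng, model)
x = randn(rng, Float32, 10, 32)
y, st_new = model(x, ps, st)                # training mode: batch statistics
y_test, _ = model(x, ps, Lux.testmode(st))  # inference mode: running statistics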
chs: Size of the channel dimension in your data. Given an array with N dimensions, call the N-1th the channel dimension. For a batch of feature vectors this is just the data dimension, for WHCN images it's the usual channel dimension.
groups is the number of groups along which the statistics are computed. The number of channels must be an integer multiple of the number of groups.
activation: After normalization, elementwise activation activation is applied.
Keyword Arguments
epsilon: a value added to the denominator for numerical stability
affine: If affine=true, it also applies a shift and a rescale to the input through learnable per-channel bias and scale parameters.
init_bias: Controls how the bias is initialized
init_scale: Controls how the scale is initialized
Extended Help
Inputs
x: Array where size(x, N - 1) = chs and ndims(x) > 2
Returns
y: Normalized Array
Update model state
Parameters
affine=true
bias: Bias of shape (chs,)
scale: Scale of shape (chs,)
affine=false - Empty NamedTuple()
States
training: Used to check if training/inference mode
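A minimal sketch (channel count, group count, and input shape are illustrative; the channel count must be divisible by the group count):
julia
using Lux, Random

rng = Random.default_rng()
model = GroupNorm(16, 4)  # 16 channels split into 4 groups

ps, st = Lux.setup(rng, model)
x = randn(rng, Float32, 8, 8, 16, 2)
y, _ = model(x, ps, st)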
`,2)),t("p",null,[a[234]||(a[234]=s("Instance Normalization computes the mean and variance for each ")),t("mjx-container",P1,[(Q(),n("svg",I1,a[232]||(a[232]=[i('',1)]))),a[233]||(a[233]=t("mjx-assistive-mml",{unselectable:"on",display:"inline",style:{top:"0px",left:"0px",clip:"rect(1px, 1px, 1px, 1px)","-webkit-touch-callout":"none","-webkit-user-select":"none","-khtml-user-select":"none","-moz-user-select":"none","-ms-user-select":"none","user-select":"none",position:"absolute",padding:"1px 0px 0px 0px",border:"0px",display:"block",width:"auto",overflow:"hidden"}},[t("math",{xmlns:"http://www.w3.org/1998/Math/MathML"},[t("msub",null,[t("mi",null,"D"),t("mn",null,"1")]),t("mo",null,"×"),t("mo",null,"."),t("mo",null,"."),t("mo",null,"."),t("mo",null,"×"),t("msub",null,[t("mi",null,"D"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"N"),t("mo",null,"−"),t("mn",null,"2")])]),t("mo",null,"×"),t("mn",null,"1"),t("mo",null,"×"),t("mn",null,"1")])],-1))]),a[235]||(a[235]=s("` input slice and normalises the input accordingly."))]),a[237]||(a[237]=i(`
Arguments
chs: Size of the channel dimension in your data. Given an array with N dimensions, call the N-1th the channel dimension. For a batch of feature vectors this is just the data dimension, for WHCN images it's the usual channel dimension.
activation: After normalization, elementwise activation activation is applied.
Keyword Arguments
track_stats: If track_stats=true, accumulates mean and variance statistics in training phase that will be used to renormalize the input in test phase.
epsilon: a value added to the denominator for numerical stability
momentum: the value used for the running_mean and running_var computation
affine: If affine=true, it also applies a shift and a rescale to the input through learnable per-channel bias and scale parameters.
init_bias: Controls how the bias is initialized
init_scale: Controls how the scale is initialized
Extended Help
Inputs
x: Array where size(x, N - 1) = chs and ndims(x) > 2
Returns
y: Normalized Array
Update model state
Parameters
affine=true
bias: Bias of shape (chs,)
scale: Scale of shape (chs,)
affine=false - Empty NamedTuple()
States
Statistics if track_stats=true
running_mean: Running mean of shape (chs,)
running_var: Running variance of shape (chs,)
Statistics if track_stats=false
running_mean: nothing
running_var: nothing
training: Used to check if training/inference mode
[1] Ulyanov, Dmitry, Andrea Vedaldi, and Victor Lempitsky. "Instance normalization: The missing ingredient for fast stylization." arXiv preprint arXiv:1607.08022 (2016).
Computes mean and standard deviation over the whole input array, and uses these to normalize the whole array. Optionally applies an elementwise affine transformation afterwards.
Weight normalization is a reparameterization that decouples the magnitude of a weight tensor from its direction. This updates the parameters in which_params (e.g. weight) using two parameters: one specifying the magnitude (e.g. weight_g) and one specifying the direction (e.g. weight_v).
Arguments
layer: The layer whose parameters are being reparameterized
which_params: parameter names for the parameters being reparameterized
By default, a norm over the entire array is computed. Pass dims to modify the dimension.
Inputs
x: Should be of valid type for input to layer
Returns
Output from layer
Updated model state of layer
Parameters
normalized: Parameters of layer that are being normalized
unnormalized: Parameters of layer that are not being normalized
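A minimal sketch reparameterizing the weight of a Dense layer (layer size illustrative); per the Parameters section above, ps splits into normalized and unnormalized fields:
julia
using Lux, Random

rng = Random.default_rng()
model = WeightNorm(Dense(4 => 4), (:weight,))

ps, st = Lux.setup(rng, model)
keys(ps)  # (:normalized, :unnormalized)
x = randn(rng, Float32, 4, 3)
y, _ = model(x, ps, st)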
Upsampling
Lux.PixelShuffle
julia
PixelShuffle(r::Int)
Pixel shuffling layer with upscale factor r. Usually used for generating higher resolution images while upscaling them.
See NNlib.pixel_shuffle for more details.
Arguments
r: Upscale factor
Inputs
x: For 4D-arrays representing N images, the operation converts input size(x) == (W, H, r² x C, N) to output of size (r x W, r x H, C, N). For D-dimensional data, it expects ndims(x) == D + 2 with channel and batch dimensions, and divides the number of channels by rᴰ.
Returns
Output of size (r x W, r x H, C, N) for 4D-arrays, and (r x W, r x H, ..., C, N) for D-dimensional data, where D = ndims(x) - 2
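A minimal sketch with upscale factor 2 (input shape illustrative):
julia
using Lux, Random

rng = Random.default_rng()
model = PixelShuffle(2)

ps, st = Lux.setup(rng, model)
x = randn(rng, Float32, 8, 8, 12, 1)  # 12 = 2^2 * 3 channels
y, _ = model(x, ps, st)
size(y)  # (16, 16, 3, 1)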
Lux.Upsample
Option 1
mode: Set to :nearest, :linear, :bilinear or :trilinear
Exactly one of two keywords must be specified:
If scale is a number, this applies to all but the last two dimensions (channel and batch) of the input. It may also be a tuple, to control dimensions individually.
Alternatively, keyword size accepts a tuple, to directly specify the leading dimensions of the output.
Option 2
scale: If scale is a number, this applies to all but the last two dimensions (channel and batch) of the input. It may also be a tuple, to control dimensions individually.
mode: Set to :nearest, :bilinear or :trilinear
Currently supported upsampling modes and corresponding NNlib's methods are:
:nearest -> NNlib.upsample_nearest
:bilinear -> NNlib.upsample_bilinear
:trilinear -> NNlib.upsample_trilinear
Extended Help
Other Keyword Arguments
align_corners: If true, the corner pixels of the input and output tensors are aligned, and thus preserving the values at those pixels. This only has effect when mode is one of :bilinear or :trilinear.
Inputs
x: For the input dimensions look into the documentation for the corresponding NNlib function
As a rule of thumb, :nearest should work with arrays of arbitrary dimensions
:bilinear works with 4D Arrays
:trilinear works with 5D Arrays
Returns
Upsampled Input of size size or of size (I_1 x scale[1], ..., I_N x scale[N], C, N)
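A minimal sketch (mode, scale, and input shape are illustrative):
julia
using Lux, Random

rng = Random.default_rng()
model = Upsample(:bilinear; scale=2)

ps, st = Lux.setup(rng, model)
x = randn(rng, Float32, 8, 8, 3, 1)
y, _ = model(x, ps, st)
size(y)  # (16, 16, 3, 1)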
Create a layer which passes an input to each path in layers, before reducing the output with connection.
Arguments
connection: An N-argument function that is called after passing the input through each layer. If connection = nothing, we return a tuple Parallel(nothing, f, g)(x, y) = (f(x), g(y))
Layers can be specified in two formats:
A list of N Lux layers
Specified as N keyword arguments.
Extended Help
Inputs
x: If x is not a tuple, then return is computed as connection([l(x) for l in layers]...). Else one is passed to each layer, thus Parallel(+, f, g)(x, y) = f(x) + g(y).
Returns
See the Inputs section for how the output is computed
Updated state of the layers
Parameters
Parameters of each layer wrapped in a NamedTuple with fields = layer_1, layer_2, ..., layer_N (naming changes if using the kwargs API)
States
States of each layer wrapped in a NamedTuple with fields = layer_1, layer_2, ..., layer_N (naming changes if using the kwargs API)
See also SkipConnection which is Parallel with one identity.
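A minimal sketch with a single (non-tuple) input (sizes illustrative):
julia
using Lux, Random

rng = Random.default_rng()
model = Parallel(+, Dense(4 => 2), Dense(4 => 2))

ps, st = Lux.setup(rng, model)
x = randn(rng, Float32, 4, 3)
y, _ = model(x, ps, st)  # f(x) + g(x)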
Create a skip connection which consists of a layer or Chain of consecutive layers and a shortcut connection linking the block's input to the output through a user-supplied 2-argument callable. The first argument to the callable will be propagated through the given layer while the second is the unchanged, "skipped" input.
The simplest "ResNet"-type connection is just SkipConnection(layer, +).
Arguments
layer: Layer or Chain of layers to be applied to the input
connection:
A 2-argument function that takes layer(input) and the input OR
An AbstractLuxLayer that takes (layer(input), input) as input
Extended Help
Inputs
x: Will be passed directly to layer
Returns
Output of connection(layer(input), input)
Updated state of layer
Parameters
Parameters of layer OR
If connection is an AbstractLuxLayer, then NamedTuple with fields :layers and :connection
States
States of layer OR
If connection is an AbstractLuxLayer, then NamedTuple with fields :layers and :connection
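A minimal sketch of the simplest ResNet-style connection (sizes illustrative):
julia
using Lux, Random

rng = Random.default_rng()
model = SkipConnection(Dense(4 => 4, relu), +)

ps, st = Lux.setup(rng, model)
x = randn(rng, Float32, 4, 3)
y, _ = model(x, ps, st)  # layer(x) + x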
Iteratively applies model for repeats number of times. The initial input is passed into the model repeatedly if input_injection = Val(true). This layer unrolls the computation; semantically, however, it is the same as:
input_injection = Val(false)
julia
res = x
for i in 1:repeats
    res, st = model(res, ps, st)
end
input_injection = Val(true)
julia
res = x
for i in 1:repeats
    res, st = model((res, x), ps, st)
end
It is expected that repeats will be a reasonable number below 20; beyond that, compile times for gradients might be unreasonably high.
Arguments
model must be an AbstractLuxLayer
Keyword Arguments
repeats: Number of times to apply the model
input_injection: If true, then the input is passed to the model along with the output
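A minimal sketch (layer size and repeat count are illustrative):
julia
using Lux, Random

rng = Random.default_rng()
model = RepeatedLayer(Dense(4 => 4, tanh); repeats=Val(3))

ps, st = Lux.setup(rng, model)
x = randn(rng, Float32, 4, 2)
y, _ = model(x, ps, st)  # one parameter set reused across all 3 applications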
Image data should be stored in WHCN order (width, height, channels, batch). In other words, a 100 x 100 RGB image would be a 100 x 100 x 3 x 1 array, and a batch of 50 would be a 100 x 100 x 3 x 50 array. This has N = 2 spatial dimensions, and needs a kernel size like (5, 5), a 2-tuple of integers. To take convolutions along N feature dimensions, this layer expects as input an array with ndims(x) == N + 2, where size(x, N + 1) == in_chs is the number of input channels, and size(x, ndims(x)) is the number of observations in a batch.
Warning
Frameworks like Pytorch perform cross-correlation in their convolution layers. Pass cross_correlation=true to use cross-correlation instead.
Arguments
k: Tuple of integers specifying the size of the convolutional kernel. Eg, for 2D convolutions length(k) == 2
in_chs: Number of input channels
out_chs: Number of input and output channels
activation: Activation Function
Extended Help
Keyword Arguments
init_weight: Controls the initialization of the weight parameter. If nothing, then we use kaiming_uniform with gain computed on the basis of the activation function (taken from Pytorch nn.init.calculate_gain).
init_bias: Controls the initialization of the bias parameter. If nothing, then we use uniform distribution with bounds -bound and bound where bound = inv(sqrt(fan_in)).
stride: Should each be either single integer, or a tuple with N integers
dilation: Should each be either single integer, or a tuple with N integers
pad: Specifies the number of elements added to the borders of the data array. It can be
a single integer for equal padding all around,
a tuple of N integers, to apply the same padding at begin/end of each spatial dimension,
a tuple of 2*N integers, for asymmetric padding, or
the singleton SamePad(), to calculate padding such that size(output,d) == size(x,d) / stride (possibly rounded) for each spatial dimension.
Periodic padding can achieved by pre-empting the layer with a WrappedFunction(x -> NNlib.circular_pad(x, N_pad; dims=pad_dims))
groups: Expected to be an Int. It specifies the number of groups to divide a convolution into (set groups = in_chs for Depthwise Convolutions). in_chs and out_chs must be divisible by groups.
use_bias: Trainable bias can be disabled entirely by setting this to false.
cross_correlation: If true, perform cross-correlation instead of convolution. Prior to v1, Lux used to have a CrossCor layer which performed cross-correlation. This was removed in v1 in favor of Conv with cross_correlation=true.
Inputs
x: Data satisfying ndims(x) == N + 2 && size(x, N - 1) == in_chs, i.e. size(x) = (I_N, ..., I_1, C_in, N)
Returns
Output of the convolution y of size (O_N, ..., O_1, C_out, N) where
k: Tuple of integers specifying the size of the convolutional kernel. Eg, for 2D convolutions length(k) == 2
in_chs: Number of input channels
out_chs: Number of input and output channels
activation: Activation Function
Keyword Arguments
init_weight: Controls the initialization of the weight parameter. If nothing, then we use kaiming_uniform with gain computed on the basis of the activation function (taken from Pytorch nn.init.calculate_gain).
init_bias: Controls the initialization of the bias parameter. If nothing, then we use uniform distribution with bounds -bound and bound where bound = inv(sqrt(fan_in)).
stride: Should each be either single integer, or a tuple with N integers
dilation: Should each be either single integer, or a tuple with N integers
pad: Specifies the number of elements added to the borders of the data array. It can be
a single integer for equal padding all around,
a tuple of N integers, to apply the same padding at begin/end of each spatial dimension,
a tuple of 2*N integers, for asymmetric padding, or
the singleton SamePad(), to calculate padding such that size(output,d) == size(x,d) * stride (possibly rounded) for each spatial dimension.
groups: Expected to be an Int. It specifies the number of groups to divide a convolution into (set groups = in_chs for Depthwise Convolutions). in_chs and out_chs must be divisible by groups.
use_bias: Trainable bias can be disabled entirely by setting this to false.
cross_correlation: If true, perform transposed cross-correlation instead of transposed convolution.
outpad: To converse Conv inversability when stride > 1, outpad can be used to increase the size of the output in the desired dimensions. Whereas pad is used to zero-pad the input, outpad only affects the output shape.
Extended Help
Inputs
x: Data satisfying ndims(x) == N + 2 && size(x, N - 1) == in_chs, i.e. size(x) = (I_N, ..., I_1, C_in, N)
Returns
Output of the convolution transpose y of size (O_N, ..., O_1, C_out, N) where
p: Probability of Dropout (if p = 0 then NoOpLayer is returned)
Keyword Arguments
To apply dropout along certain dimension(s), specify the dims keyword. e.g. Dropout(p; dims = (3,4)) will randomly zero out entire channels on WHCN input (also called 2D dropout).
Inputs
x: Must be an AbstractArray
Returns
x with dropout mask applied if training=Val(true) else just x
State with updated rng
States
rng: Pseudo Random Number Generator
training: Used to check if training/inference mode
p: Probability of Dropout (if p = 0 then NoOpLayer is returned)
Keyword Arguments
To apply dropout along certain dimension(s), specify the dims keyword. e.g. VariationalHiddenDropout(p; dims = 3) will randomly zero out entire channels on WHCN input (also called 2D dropout).
Inputs
x: Must be an AbstractArray
Returns
x with dropout mask applied if training=Val(true) else just x
State with updated rng
States
rng: Pseudo Random Number Generator
training: Used to check if training/inference mode
mask: Dropout mask. Initilly set to nothing. After every run, contains the mask applied in that call
update_mask: Stores whether new mask needs to be generated in the current call
Global LP Pooling layer. Transforms (w, h, c, b)-shaped input into (1, 1, c, b)-shaped output, by performing mean pooling on the complete (w, h)-shaped feature maps.
GPU Support
This layer is currently only supported on CPU.
Inputs
x: Data satisfying ndims(x) > 2, i.e. size(x) = (I_N, ..., I_1, C, N)
Global Max Pooling layer. Transforms (w, h, c, b)-shaped input into (1, 1, c, b)-shaped output, by performing mean pooling on the complete (w, h)-shaped feature maps.
Inputs
x: Data satisfying ndims(x) > 2, i.e. size(x) = (I_N, ..., I_1, C, N)
Global Mean Pooling layer. Transforms (w, h, c, b)-shaped input into (1, 1, c, b)-shaped output, by performing mean pooling on the complete (w, h)-shaped feature maps.
Inputs
x: Data satisfying ndims(x) > 2, i.e. size(x) = (I_N, ..., I_1, C, N)
train_state: Trainable initial hidden state can be activated by setting this to true
init_bias: Initializer for bias. Must be a tuple containing 3 functions. If a single value is passed, it is copied into a 3 element tuple. If nothing, then we use uniform distribution with bounds -bound and bound where bound = inv(sqrt(out_dims)).
init_weight: Initializer for weight. Must be a tuple containing 3 functions. If a single value is passed, it is copied into a 3 element tuple. If nothing, then we use uniform distribution with bounds -bound and bound where bound = inv(sqrt(out_dims)).
init_state: Initializer for hidden state
Inputs
Case 1a: Only a single input x of shape (in_dims, batch_size), train_state is set to false - Creates a hidden state using init_state and proceeds to Case 2.
Case 1b: Only a single input x of shape (in_dims, batch_size), train_state is set to true - Repeats hidden_state from parameters to match the shape of x and proceeds to Case 2.
Case 2: Tuple (x, (h, )) is provided, then the output and a tuple containing the updated hidden state is returned.
Returns
",5)),t("ul",null,[t("li",null,[a[87]||(a[87]=t("p",null,"Tuple containing",-1)),t("ul",null,[t("li",null,[t("p",null,[a[81]||(a[81]=s("Output ")),t("mjx-container",I,[(Q(),n("svg",S,a[79]||(a[79]=[i('',1)]))),a[80]||(a[80]=t("mjx-assistive-mml",{unselectable:"on",display:"inline",style:{top:"0px",left:"0px",clip:"rect(1px, 1px, 1px, 1px)","-webkit-touch-callout":"none","-webkit-user-select":"none","-khtml-user-select":"none","-moz-user-select":"none","-ms-user-select":"none","user-select":"none",position:"absolute",padding:"1px 0px 0px 0px",border:"0px",display:"block",width:"auto",overflow:"hidden"}},[t("math",{xmlns:"http://www.w3.org/1998/Math/MathML"},[t("msub",null,[t("mi",null,"h"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"n"),t("mi",null,"e"),t("mi",null,"w")])])])],-1))]),a[82]||(a[82]=s(" of shape ")),a[83]||(a[83]=t("code",null,"(out_dims, batch_size)",-1))])]),t("li",null,[t("p",null,[a[86]||(a[86]=s("Tuple containing new hidden state ")),t("mjx-container",_,[(Q(),n("svg",G,a[84]||(a[84]=[i('',1)]))),a[85]||(a[85]=t("mjx-assistive-mml",{unselectable:"on",display:"inline",style:{top:"0px",left:"0px",clip:"rect(1px, 1px, 1px, 1px)","-webkit-touch-callout":"none","-webkit-user-select":"none","-khtml-user-select":"none","-moz-user-select":"none","-ms-user-select":"none","user-select":"none",position:"absolute",padding:"1px 0px 0px 0px",border:"0px",display:"block",width:"auto",overflow:"hidden"}},[t("math",{xmlns:"http://www.w3.org/1998/Math/MathML"},[t("msub",null,[t("mi",null,"h"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"n"),t("mi",null,"e"),t("mi",null,"w")])])])],-1))])])])])]),a[88]||(a[88]=t("li",null,[t("p",null,"Updated model state")],-1))]),a[122]||(a[122]=t("p",null,[t("strong",null,"Parameters")],-1)),t("ul",null,[t("li",null,[t("p",null,[a[91]||(a[91]=t("code",null,"weight_ih",-1)),a[92]||(a[92]=s(": Concatenated Weights to map from input space ")),t("mjx-container",W,[(Q(),n("svg",X,a[89]||(a[89]=[i('',1)]))),a[90]||(a[90]=t("mjx-assistive-mml",{unselectable:"on",display:"inline",style:{top:"0px",left:"0px",clip:"rect(1px, 1px, 1px, 1px)","-webkit-touch-callout":"none","-webkit-user-select":"none","-khtml-user-select":"none","-moz-user-select":"none","-ms-user-select":"none","user-select":"none",position:"absolute",padding:"1px 0px 0px 0px",border:"0px",display:"block",width:"auto",overflow:"hidden"}},[t("math",{xmlns:"http://www.w3.org/1998/Math/MathML"},[t("mo",{fence:"false",stretchy:"false"},"{"),t("msub",null,[t("mi",null,"W"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"i"),t("mi",null,"r")])]),t("mo",null,","),t("msub",null,[t("mi",null,"W"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"i"),t("mi",null,"z")])]),t("mo",null,","),t("msub",null,[t("mi",null,"W"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"i"),t("mi",null,"n")])]),t("mo",{fence:"false",stretchy:"false"},"}")])],-1))]),a[93]||(a[93]=s("."))])]),t("li",null,[t("p",null,[a[96]||(a[96]=t("code",null,"weight_hh",-1)),a[97]||(a[97]=s(": Concatenated Weights to map from hidden space ")),t("mjx-container",U,[(Q(),n("svg",q,a[94]||(a[94]=[i('',1)]))),a[95]||(a[95]=t("mjx-assistive-mml",{unselectable:"on",display:"inline",style:{top:"0px",left:"0px",clip:"rect(1px, 1px, 1px, 1px)","-webkit-touch-callout":"none","-webkit-user-select":"none","-khtml-user-select":"none","-moz-user-select":"none","-ms-user-select":"none","user-select":"none",position:"absolute",padding:"1px 0px 0px 
0px",border:"0px",display:"block",width:"auto",overflow:"hidden"}},[t("math",{xmlns:"http://www.w3.org/1998/Math/MathML"},[t("mo",{fence:"false",stretchy:"false"},"{"),t("msub",null,[t("mi",null,"W"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"h"),t("mi",null,"r")])]),t("mo",null,","),t("msub",null,[t("mi",null,"W"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"h"),t("mi",null,"z")])]),t("mo",null,","),t("msub",null,[t("mi",null,"W"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"h"),t("mi",null,"n")])]),t("mo",{fence:"false",stretchy:"false"},"}")])],-1))]),a[98]||(a[98]=s("."))])]),t("li",null,[t("p",null,[a[101]||(a[101]=t("code",null,"bias_ih",-1)),a[102]||(a[102]=s(": Concatenated Bias vector for the input space ")),t("mjx-container",J,[(Q(),n("svg",K,a[99]||(a[99]=[i('',1)]))),a[100]||(a[100]=t("mjx-assistive-mml",{unselectable:"on",display:"inline",style:{top:"0px",left:"0px",clip:"rect(1px, 1px, 1px, 1px)","-webkit-touch-callout":"none","-webkit-user-select":"none","-khtml-user-select":"none","-moz-user-select":"none","-ms-user-select":"none","user-select":"none",position:"absolute",padding:"1px 0px 0px 0px",border:"0px",display:"block",width:"auto",overflow:"hidden"}},[t("math",{xmlns:"http://www.w3.org/1998/Math/MathML"},[t("mo",{fence:"false",stretchy:"false"},"{"),t("msub",null,[t("mi",null,"b"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"i"),t("mi",null,"r")])]),t("mo",null,","),t("msub",null,[t("mi",null,"b"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"i"),t("mi",null,"z")])]),t("mo",null,","),t("msub",null,[t("mi",null,"b"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"i"),t("mi",null,"n")])]),t("mo",{fence:"false",stretchy:"false"},"}")])],-1))]),a[103]||(a[103]=s(" (not present if ")),a[104]||(a[104]=t("code",null,"use_bias=false",-1)),a[105]||(a[105]=s(")."))])]),t("li",null,[t("p",null,[a[108]||(a[108]=t("code",null,"bias_hh",-1)),a[109]||(a[109]=s(": Concatenated Bias vector for the hidden space ")),t("mjx-container",$,[(Q(),n("svg",Y,a[106]||(a[106]=[i('',1)]))),a[107]||(a[107]=t("mjx-assistive-mml",{unselectable:"on",display:"inline",style:{top:"0px",left:"0px",clip:"rect(1px, 1px, 1px, 1px)","-webkit-touch-callout":"none","-webkit-user-select":"none","-khtml-user-select":"none","-moz-user-select":"none","-ms-user-select":"none","user-select":"none",position:"absolute",padding:"1px 0px 0px 0px",border:"0px",display:"block",width:"auto",overflow:"hidden"}},[t("math",{xmlns:"http://www.w3.org/1998/Math/MathML"},[t("mo",{fence:"false",stretchy:"false"},"{"),t("msub",null,[t("mi",null,"b"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"h"),t("mi",null,"r")])]),t("mo",null,","),t("msub",null,[t("mi",null,"b"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"h"),t("mi",null,"z")])]),t("mo",null,","),t("msub",null,[t("mi",null,"b"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"h"),t("mi",null,"n")])]),t("mo",{fence:"false",stretchy:"false"},"}")])],-1))]),a[110]||(a[110]=s(" (not present if ")),a[111]||(a[111]=t("code",null,"use_bias=false",-1)),a[112]||(a[112]=s(")."))])]),t("li",null,[t("p",null,[a[115]||(a[115]=t("code",null,"hidden_state",-1)),a[116]||(a[116]=s(": Initial hidden state vector (not present if ")),a[117]||(a[117]=t("code",null,"train_state=false",-1)),a[118]||(a[118]=s(") ")),t("mjx-container",t1,[(Q(),n("svg",a1,a[113]||(a[113]=[i('',1)]))),a[114]||(a[114]=t("mjx-assistive-mml",{unselectable:"on",display:"inline",style:{top:"0px",left:"0px",clip:"rect(1px, 1px, 1px, 
1px)","-webkit-touch-callout":"none","-webkit-user-select":"none","-khtml-user-select":"none","-moz-user-select":"none","-ms-user-select":"none","user-select":"none",position:"absolute",padding:"1px 0px 0px 0px",border:"0px",display:"block",width:"auto",overflow:"hidden"}},[t("math",{xmlns:"http://www.w3.org/1998/Math/MathML"},[t("mo",{fence:"false",stretchy:"false"},"{"),t("msub",null,[t("mi",null,"b"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"h"),t("mi",null,"r")])]),t("mo",null,","),t("msub",null,[t("mi",null,"b"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"h"),t("mi",null,"z")])]),t("mo",null,","),t("msub",null,[t("mi",null,"b"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"h"),t("mi",null,"n")])]),t("mo",{fence:"false",stretchy:"false"},"}")])],-1))]),a[119]||(a[119]=s("."))])])]),a[123]||(a[123]=t("p",null,[t("strong",null,"States")],-1)),a[124]||(a[124]=t("ul",null,[t("li",null,[t("code",null,"rng"),s(": Controls the randomness (if any) in the initial state generation")])],-1)),a[125]||(a[125]=t("p",null,[t("a",{href:"https://github.com/LuxDL/Lux.jl/blob/2ffa2f745c551ad2880316402fff5c9ff367ea40/src/layers/recurrent.jl#L488",target:"_blank",rel:"noreferrer"},"source")],-1))]),t("details",s1,[t("summary",null,[a[126]||(a[126]=t("a",{id:"Lux.LSTMCell",href:"#Lux.LSTMCell"},[t("span",{class:"jlbinding"},"Lux.LSTMCell")],-1)),a[127]||(a[127]=s()),l(e,{type:"info",class:"jlObjectType jlType",text:"Type"})]),a[153]||(a[153]=i(`
out_dims: Output (Hidden State & Memory) Dimension
use_bias: Set to false to deactivate bias
train_state: Trainable initial hidden state can be activated by setting this to true
train_memory: Trainable initial memory can be activated by setting this to true
init_bias: Initializer for bias. Must be a tuple containing 4 functions. If a single value is passed, it is copied into a 4 element tuple. If nothing, then we use a uniform distribution with bounds -bound and bound where bound = inv(sqrt(out_dims)).
init_weight: Initializer for weight. Must be a tuple containing 4 functions. If a single value is passed, it is copied into a 4 element tuple. If nothing, then we use a uniform distribution with bounds -bound and bound where bound = inv(sqrt(out_dims)).
init_state: Initializer for hidden state
init_memory: Initializer for memory
Inputs
Case 1a: Only a single input x of shape (in_dims, batch_size), train_state is set to false, train_memory is set to false - Creates a hidden state using init_state, hidden memory using init_memory and proceeds to Case 2.
Case 1b: Only a single input x of shape (in_dims, batch_size), train_state is set to true, train_memory is set to false - Repeats hidden_state vector from the parameters to match the shape of x, creates hidden memory using init_memory and proceeds to Case 2.
Case 1c: Only a single input x of shape (in_dims, batch_size), train_state is set to false, train_memory is set to true - Creates a hidden state using init_state, repeats the memory vector from parameters to match the shape of x and proceeds to Case 2.
Case 1d: Only a single input x of shape (in_dims, batch_size), train_state is set to true, train_memory is set to true - Repeats the hidden state and memory vectors from the parameters to match the shape of x and proceeds to Case 2.
Case 2: Tuple (x, (h, c)) is provided, then the output and a tuple containing the updated hidden state and memory is returned.
Returns
",5)),t("ul",null,[t("li",null,[a[141]||(a[141]=t("p",null,"Tuple Containing",-1)),t("ul",null,[t("li",null,[t("p",null,[a[132]||(a[132]=s("Output ")),t("mjx-container",l1,[(Q(),n("svg",n1,a[130]||(a[130]=[i('',1)]))),a[131]||(a[131]=t("mjx-assistive-mml",{unselectable:"on",display:"inline",style:{top:"0px",left:"0px",clip:"rect(1px, 1px, 1px, 1px)","-webkit-touch-callout":"none","-webkit-user-select":"none","-khtml-user-select":"none","-moz-user-select":"none","-ms-user-select":"none","user-select":"none",position:"absolute",padding:"1px 0px 0px 0px",border:"0px",display:"block",width:"auto",overflow:"hidden"}},[t("math",{xmlns:"http://www.w3.org/1998/Math/MathML"},[t("msub",null,[t("mi",null,"h"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"n"),t("mi",null,"e"),t("mi",null,"w")])])])],-1))]),a[133]||(a[133]=s(" of shape ")),a[134]||(a[134]=t("code",null,"(out_dims, batch_size)",-1))])]),t("li",null,[t("p",null,[a[139]||(a[139]=s("Tuple containing new hidden state ")),t("mjx-container",Q1,[(Q(),n("svg",T1,a[135]||(a[135]=[i('',1)]))),a[136]||(a[136]=t("mjx-assistive-mml",{unselectable:"on",display:"inline",style:{top:"0px",left:"0px",clip:"rect(1px, 1px, 1px, 1px)","-webkit-touch-callout":"none","-webkit-user-select":"none","-khtml-user-select":"none","-moz-user-select":"none","-ms-user-select":"none","user-select":"none",position:"absolute",padding:"1px 0px 0px 0px",border:"0px",display:"block",width:"auto",overflow:"hidden"}},[t("math",{xmlns:"http://www.w3.org/1998/Math/MathML"},[t("msub",null,[t("mi",null,"h"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"n"),t("mi",null,"e"),t("mi",null,"w")])])])],-1))]),a[140]||(a[140]=s(" and new memory ")),t("mjx-container",d1,[(Q(),n("svg",o1,a[137]||(a[137]=[i('',1)]))),a[138]||(a[138]=t("mjx-assistive-mml",{unselectable:"on",display:"inline",style:{top:"0px",left:"0px",clip:"rect(1px, 1px, 1px, 1px)","-webkit-touch-callout":"none","-webkit-user-select":"none","-khtml-user-select":"none","-moz-user-select":"none","-ms-user-select":"none","user-select":"none",position:"absolute",padding:"1px 0px 0px 0px",border:"0px",display:"block",width:"auto",overflow:"hidden"}},[t("math",{xmlns:"http://www.w3.org/1998/Math/MathML"},[t("msub",null,[t("mi",null,"c"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"n"),t("mi",null,"e"),t("mi",null,"w")])])])],-1))])])])])]),a[142]||(a[142]=t("li",null,[t("p",null,"Updated model state")],-1))]),a[155]||(a[155]=t("p",null,[t("strong",null,"Parameters")],-1)),t("ul",null,[t("li",null,[t("p",null,[a[145]||(a[145]=t("code",null,"weight_ih",-1)),a[146]||(a[146]=s(": Concatenated Weights to map from input space ")),t("mjx-container",r1,[(Q(),n("svg",p1,a[143]||(a[143]=[i('',1)]))),a[144]||(a[144]=t("mjx-assistive-mml",{unselectable:"on",display:"inline",style:{top:"0px",left:"0px",clip:"rect(1px, 1px, 1px, 1px)","-webkit-touch-callout":"none","-webkit-user-select":"none","-khtml-user-select":"none","-moz-user-select":"none","-ms-user-select":"none","user-select":"none",position:"absolute",padding:"1px 0px 0px 
0px",border:"0px",display:"block",width:"auto",overflow:"hidden"}},[t("math",{xmlns:"http://www.w3.org/1998/Math/MathML"},[t("mo",{fence:"false",stretchy:"false"},"{"),t("msub",null,[t("mi",null,"W"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"i"),t("mi",null,"i")])]),t("mo",null,","),t("msub",null,[t("mi",null,"W"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"i"),t("mi",null,"f")])]),t("mo",null,","),t("msub",null,[t("mi",null,"W"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"i"),t("mi",null,"g")])]),t("mo",null,","),t("msub",null,[t("mi",null,"W"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"i"),t("mi",null,"o")])]),t("mo",{fence:"false",stretchy:"false"},"}")])],-1))]),a[147]||(a[147]=s("."))])]),t("li",null,[t("p",null,[a[150]||(a[150]=t("code",null,"weight_hh",-1)),a[151]||(a[151]=s(": Concatenated Weights to map from hidden space ")),t("mjx-container",h1,[(Q(),n("svg",m1,a[148]||(a[148]=[i('',1)]))),a[149]||(a[149]=t("mjx-assistive-mml",{unselectable:"on",display:"inline",style:{top:"0px",left:"0px",clip:"rect(1px, 1px, 1px, 1px)","-webkit-touch-callout":"none","-webkit-user-select":"none","-khtml-user-select":"none","-moz-user-select":"none","-ms-user-select":"none","user-select":"none",position:"absolute",padding:"1px 0px 0px 0px",border:"0px",display:"block",width:"auto",overflow:"hidden"}},[t("math",{xmlns:"http://www.w3.org/1998/Math/MathML"},[t("mo",{fence:"false",stretchy:"false"},"{"),t("msub",null,[t("mi",null,"W"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"h"),t("mi",null,"i")])]),t("mo",null,","),t("msub",null,[t("mi",null,"W"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"h"),t("mi",null,"f")])]),t("mo",null,","),t("msub",null,[t("mi",null,"W"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"h"),t("mi",null,"g")])]),t("mo",null,","),t("msub",null,[t("mi",null,"W"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"h"),t("mi",null,"o")])]),t("mo",{fence:"false",stretchy:"false"},"}")])],-1))])])]),a[152]||(a[152]=i("
bias_ih: Concatenated Bias vector for the input-hidden connection (not present if use_bias=false)
bias_hh: Concatenated Bias vector for the hidden-hidden connection (not present if use_bias=false)
hidden_state: Initial hidden state vector (not present if train_state=false)
memory: Initial memory vector (not present if train_memory=false)
",4))]),a[156]||(a[156]=t("p",null,[t("strong",null,"States")],-1)),a[157]||(a[157]=t("ul",null,[t("li",null,[t("code",null,"rng"),s(": Controls the randomness (if any) in the initial state generation")])],-1)),a[158]||(a[158]=t("p",null,[t("a",{href:"https://github.com/LuxDL/Lux.jl/blob/2ffa2f745c551ad2880316402fff5c9ff367ea40/src/layers/recurrent.jl#L309",target:"_blank",rel:"noreferrer"},"source")],-1))]),t("details",g1,[t("summary",null,[a[159]||(a[159]=t("a",{id:"Lux.RNNCell",href:"#Lux.RNNCell"},[t("span",{class:"jlbinding"},"Lux.RNNCell")],-1)),a[160]||(a[160]=s()),l(e,{type:"info",class:"jlObjectType jlType",text:"Type"})]),a[173]||(a[173]=i(`
train_state: Trainable initial hidden state can be activated by setting this to true
init_bias: Initializer for bias. If nothing, then we use a uniform distribution with bounds -bound and bound where bound = inv(sqrt(out_dims)).
init_weight: Initializer for weight. If nothing, then we use a uniform distribution with bounds -bound and bound where bound = inv(sqrt(out_dims)).
init_state: Initializer for hidden state
Inputs
Case 1a: Only a single input x of shape (in_dims, batch_size), train_state is set to false - Creates a hidden state using init_state and proceeds to Case 2.
Case 1b: Only a single input x of shape (in_dims, batch_size), train_state is set to true - Repeats hidden_state from parameters to match the shape of x and proceeds to Case 2.
Case 2: Tuple (x, (h, )) is provided, then the output and a tuple containing the updated hidden state is returned.
Returns
",5)),t("ul",null,[t("li",null,[a[171]||(a[171]=t("p",null,"Tuple containing",-1)),t("ul",null,[t("li",null,[t("p",null,[a[165]||(a[165]=s("Output ")),t("mjx-container",u1,[(Q(),n("svg",y1,a[163]||(a[163]=[i('',1)]))),a[164]||(a[164]=t("mjx-assistive-mml",{unselectable:"on",display:"inline",style:{top:"0px",left:"0px",clip:"rect(1px, 1px, 1px, 1px)","-webkit-touch-callout":"none","-webkit-user-select":"none","-khtml-user-select":"none","-moz-user-select":"none","-ms-user-select":"none","user-select":"none",position:"absolute",padding:"1px 0px 0px 0px",border:"0px",display:"block",width:"auto",overflow:"hidden"}},[t("math",{xmlns:"http://www.w3.org/1998/Math/MathML"},[t("msub",null,[t("mi",null,"h"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"n"),t("mi",null,"e"),t("mi",null,"w")])])])],-1))]),a[166]||(a[166]=s(" of shape ")),a[167]||(a[167]=t("code",null,"(out_dims, batch_size)",-1))])]),t("li",null,[t("p",null,[a[170]||(a[170]=s("Tuple containing new hidden state ")),t("mjx-container",E1,[(Q(),n("svg",f1,a[168]||(a[168]=[i('',1)]))),a[169]||(a[169]=t("mjx-assistive-mml",{unselectable:"on",display:"inline",style:{top:"0px",left:"0px",clip:"rect(1px, 1px, 1px, 1px)","-webkit-touch-callout":"none","-webkit-user-select":"none","-khtml-user-select":"none","-moz-user-select":"none","-ms-user-select":"none","user-select":"none",position:"absolute",padding:"1px 0px 0px 0px",border:"0px",display:"block",width:"auto",overflow:"hidden"}},[t("math",{xmlns:"http://www.w3.org/1998/Math/MathML"},[t("msub",null,[t("mi",null,"h"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"n"),t("mi",null,"e"),t("mi",null,"w")])])])],-1))])])])])]),a[172]||(a[172]=t("li",null,[t("p",null,"Updated model state")],-1))]),a[175]||(a[175]=i('
Parameters
weight_ih: Maps the input to the hidden state.
weight_hh: Maps the hidden state to the hidden state.
bias_ih: Bias vector for the input-hidden connection (not present if use_bias=false)
bias_hh: Bias vector for the hidden-hidden connection (not present if use_bias=false)
hidden_state: Initial hidden state vector (not present if train_state=false)
States
rng: Controls the randomness (if any) in the initial state generation
Wraps a recurrent cell (like RNNCell, LSTMCell, GRUCell) to automatically operate over a sequence of inputs.
Relation to Flux.Recur
This is completely distinct from Flux.Recur. It doesn't make the cell stateful; rather, it allows operating on an entire sequence of inputs at once. See StatefulRecurrentCell for functionality similar to Flux.Recur.
Arguments
cell: A recurrent cell. See RNNCell, LSTMCell, GRUCell, for how the inputs/outputs of a recurrent cell must be structured.
Keyword Arguments
return_sequence: If true, returns the entire sequence of outputs; else returns only the last output. Defaults to false.
ordering: The ordering of the batch and time dimensions in the input. Defaults to BatchLastIndex(). Alternatively can be set to TimeLastIndex().
Extended Help
Inputs
If x is a
Tuple or Vector: Each element is fed to the cell sequentially.
Array (except a Vector): It is spliced along the penultimate dimension and each slice is fed to the cell sequentially.
Returns
Output of the cell for the entire sequence.
Updated state of the cell.
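A minimal usage sketch (the cell type, dimensions, and sequence length here are illustrative assumptions, not part of the docstring):
julia
using Lux, Random

rnn = Recurrence(LSTMCell(4 => 8))      # BatchLastIndex() ordering by default
ps, st = Lux.setup(Xoshiro(0), rnn)
x = randn(Float32, 4, 16, 32)           # (in_dims, seq_len, batch_size)
y, st_new = rnn(x, ps, st)              # last output only: size(y) == (8, 32)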
Tip
Frameworks like TensorFlow have special implementations of StackedRNNCells to handle sequentially composed RNN cells. In Lux, one can simply stack multiple Recurrence blocks in a Chain to achieve the same, as sketched below.
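A hedged sketch of such a stack (dimensions are arbitrary); note the first block must return the full sequence so the second block receives one input per timestep:
julia
using Lux

stacked = Chain(
    Recurrence(LSTMCell(4 => 8); return_sequence=true),  # emit the whole sequence
    Recurrence(LSTMCell(8 => 8)),                         # consume it; emit last output
)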
To avoid undefined behavior, once the processing of a single sequence of data is complete, update the state with Lux.update_state(st, :carry, nothing).
Arguments
cell: A recurrent cell. See RNNCell, LSTMCell, GRUCell, for how the inputs/outputs of a recurrent cell must be structured.
Lux.BidirectionalRNN
Arguments
cell: A recurrent cell. See RNNCell, LSTMCell, GRUCell, for how the inputs/outputs of a recurrent cell must be structured.
backward_cell: An optional backward recurrent cell. If backward_cell is nothing, the rnn layer instance passed as the cell argument will be used to generate the backward layer automatically. The in_dims of backward_cell should be consistent with the in_dims of cell.
Keyword Arguments
merge_mode: Function by which the outputs of the forward and backward RNNs will be combined. Defaults to vcat. If nothing, the outputs will not be combined.
ordering: The ordering of the batch and time dimensions in the input. Defaults to BatchLastIndex(). Alternatively can be set to TimeLastIndex().
Extended Help
Inputs
If x is a
Tuple or Vector: Each element is fed to the cell sequentially.
Array (except a Vector): It is spliced along the penultimate dimension and each slice is fed to the cell sequentially.
Returns
Merged output of the cell and backward_cell for the entire sequence.
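A hedged usage sketch (the cell, dimensions, and reliance on the default merge_mode=vcat are assumptions for illustration):
julia
using Lux, Random

brnn = BidirectionalRNN(RNNCell(4 => 8))   # backward cell generated automatically
ps, st = Lux.setup(Xoshiro(0), brnn)
x = randn(Float32, 4, 10, 32)              # (in_dims, seq_len, batch_size)
y, st_new = brnn(x, ps, st)                # merged (vcat) outputs for the whole sequence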
Create a fully connected layer between two inputs and an output, and otherwise similar to Dense. Its output, given vectors x & y, is another vector z with, for all i in 1:out:
z[i] = activation(x' * W[i, :, :] * y + bias[i])
If x and y are matrices, then each column of the output z = B(x, y) is of this form, with B the Bilinear layer.
Arguments
in1_dims: number of input dimensions of x
in2_dims: number of input dimensions of y
in12_dims: If specified, then in1_dims = in2_dims = in12_dims
out: number of output dimensions
activation: activation function
Keyword Arguments
init_weight: initializer for the weight matrix (weight = init_weight(rng, out_dims, in1_dims, in2_dims)). If nothing, then we use a uniform distribution with bounds -bound and bound where bound = inv(sqrt(in1_dims)).
init_bias: initializer for the bias vector (ignored if use_bias=false). If nothing, then we use a uniform distribution with bounds -bound and bound where bound = inv(sqrt(in1_dims)).
use_bias: Trainable bias can be disabled entirely by setting this to false
Input
A 2-Tuple containing
x must be an AbstractArray with size(x, 1) == in1_dims
y must be an AbstractArray with size(y, 1) == in2_dims
If the input is an AbstractArray, then x = y
Returns
AbstractArray with dimensions (out_dims, size(x, 2))
Empty NamedTuple()
Parameters
weight: Weight Matrix of size (out_dims, in1_dims, in2_dims)
bias: Bias of size (out_dims, 1) (present if use_bias=true)
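A short, hedged sketch of the layer in action (dimensions are arbitrary assumptions):
julia
using Lux, Random

b = Bilinear((5, 4) => 3)                  # in1_dims=5, in2_dims=4, out=3
ps, st = Lux.setup(Xoshiro(0), b)
x, y = randn(Float32, 5, 7), randn(Float32, 4, 7)
z, _ = b((x, y), ps, st)                   # size(z) == (3, 7)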
Create a traditional fully connected layer, whose forward pass is given by: y = activation.(weight * x .+ bias)
Arguments
in_dims: number of input dimensions
out_dims: number of output dimensions
activation: activation function
Keyword Arguments
init_weight: initializer for the weight matrix (weight = init_weight(rng, out_dims, in_dims)). If nothing, then we use kaiming_uniform with gain computed on the basis of the activation function (taken from PyTorch nn.init.calculate_gain).
init_bias: initializer for the bias vector (ignored if use_bias=false). If nothing, then we use a uniform distribution with bounds -bound and bound where bound = inv(sqrt(in_dims)).
use_bias: Trainable bias can be disabled entirely by setting this to false
Input
x must be an AbstractArray with size(x, 1) == in_dims
Returns
AbstractArray with dimensions (out_dims, ...) where ... are the dimensions of x
Empty NamedTuple()
Parameters
weight: Weight Matrix of size (out_dims, in_dims)
bias: Bias of size (out_dims, 1) (present if use_bias=true)
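For instance, a minimal sketch (dimensions are arbitrary assumptions):
julia
using Lux, Random

d = Dense(4 => 2, relu)
ps, st = Lux.setup(Xoshiro(0), d)
y, _ = d(randn(Float32, 4, 8), ps, st)     # size(y) == (2, 8)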
A lookup table that stores embeddings of dimension out_dims for a vocabulary of size in_dims. When the vocabulary is multi-dimensional, the input is expected to be a tuple of Cartesian indices.
This layer is often used to store word embeddings and retrieve them using indices.
Arguments
in_dims: number(s) of input dimensions
out_dims: number of output dimensions
Keyword Arguments
init_weight: initializer for the weight matrix (weight = init_weight(rng, out_dims, in_dims...))
Input
Integer OR
Abstract Vector of Integers OR
Abstract Array of Integers OR
Tuple of Integers OR
Tuple of Abstract Vectors of Integers OR
Tuple of Abstract Arrays of Integers
Returns
Returns the embedding corresponding to each index in the input. For an N dimensional input, an N + 1 dimensional output is returned.
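A minimal sketch (vocabulary size and embedding dimension are arbitrary assumptions):
julia
using Lux, Random

emb = Embedding(10 => 4)                   # vocabulary of 10, embeddings of length 4
ps, st = Lux.setup(Xoshiro(0), emb)
y, _ = emb([1, 3, 5], ps, st)              # size(y) == (4, 3): one column per index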
Create a Sparsely Connected Layer with a very specific structure (only Diagonal Elements are non-zero). The forward pass is given by: y = activation.(weight .* x .+ bias)
Arguments
dims: size of the learnable scale and bias parameters.
activation: activation function
Keyword Arguments
init_weight: initializer for the weight matrix (weight = init_weight(rng, out_dims, in_dims))
init_bias: initializer for the bias vector (ignored if use_bias=false)
use_bias: Trainable bias can be disabled entirely by setting this to false
Input
x must be an Array of size (dims..., B) or (dims...[0], ..., dims[k]) for k ≤ size(dims)
Returns
Array of size (dims..., B) or (dims...[0], ..., dims[k]) for k ≤ size(dims)
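A minimal sketch (dims is an arbitrary assumption):
julia
using Lux, Random

sc = Scale(3)                              # learnable elementwise weight .* x .+ bias
ps, st = Lux.setup(Xoshiro(0), sc)
y, _ = sc(ones(Float32, 3, 2), ps, st)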
Misc. Helper Layers
Lux.FlattenLayer
julia
FlattenLayer(; N = nothing)
Flattens the passed array into a matrix.
Keyword Arguments
N: Flatten the first N dimensions of the input array. If nothing, then all dimensions (except the last) are flattened. Note that the batch dimension is never flattened.
Inputs
x: AbstractArray
Returns
AbstractMatrix of size (:, size(x, ndims(x))) if N is nothing; otherwise, the first N dimensions of the input array are flattened.
Empty NamedTuple()
Example
julia
julia> model = FlattenLayer()
FlattenLayer{Nothing}(nothing)

julia> rng = Random.default_rng();
       Random.seed!(rng, 0);
       ps, st = Lux.setup(rng, model);
       x = randn(rng, Float32, (2, 2, 2, 2));

julia> y, st_new = model(x, ps, st);
       size(y)
(8, 2)
This contains a number of internal layers, each of which receives the same input. Its output is the elementwise maximum of the internal layers' outputs.
Maxout over linear dense layers satisfies the universal approximation theorem. See [1].
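For illustration, a hedged sketch using the Maxout(f, n_alts) constructor form (dimensions are arbitrary assumptions):
julia
using Lux, Random

m = Maxout(() -> Dense(4 => 8), 3)         # elementwise max over 3 Dense layers
ps, st = Lux.setup(Xoshiro(0), m)
y, _ = m(randn(Float32, 4, 2), ps, st)     # size(y) == (8, 2)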
Return a view of all the data of the input x where the index for dimension dim equals i. Equivalent to view(x,:,:,...,i,:,:,...) where i is in position dim.
Arguments
dim: Dimension for indexing
i: Index for dimension dim
Inputs
x: AbstractArray that can be indexed with view(x,:,:,...,i,:,:,...)
Returns
view(x,:,:,...,i,:,:,...) where i is in position dim
Wraps a stateless and parameterless function. Might be used when a function is added to a Chain. For example, Chain(x -> relu.(x)) would not work; the right thing to do would be Chain((x, ps, st) -> (relu.(x), st)). An easier alternative is Chain(WrappedFunction(Base.Fix1(broadcast, relu)))
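For example, a one-line sketch of the pattern described above:
julia
using Lux

model = Chain(Dense(4 => 4), WrappedFunction(Base.Fix1(broadcast, relu)))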
Reverse the specified dimension dim of the passed array
Arguments
dim: Dimension that needs to be reversed. If nothing, for AbstractVector{T} it reverses itself (dimension 1); for other arrays, it reverses the dimension ndims(x) - 1.
Inputs
x: AbstractArray.
Returns
AbstractArray with the same dimensions as the input
Empty NamedTuple()
Example
julia
julia> model = ReverseSequence()
ReverseSequence{Nothing}(nothing)

julia> rng = Random.default_rng();
       Random.seed!(rng, 0);
       ps, st = Lux.setup(rng, model);
       x = [1.0, 2.0, 3.0];

julia> y, st_new = model(x, ps, st)
([3.0, 2.0, 1.0], NamedTuple())
`,2)),t("p",null,[a[222]||(a[222]=t("code",null,"BatchNorm",-1)),a[223]||(a[223]=s(" computes the mean and variance for each ")),t("mjx-container",R1,[(Q(),n("svg",N1,a[220]||(a[220]=[i('',1)]))),a[221]||(a[221]=t("mjx-assistive-mml",{unselectable:"on",display:"inline",style:{top:"0px",left:"0px",clip:"rect(1px, 1px, 1px, 1px)","-webkit-touch-callout":"none","-webkit-user-select":"none","-khtml-user-select":"none","-moz-user-select":"none","-ms-user-select":"none","user-select":"none",position:"absolute",padding:"1px 0px 0px 0px",border:"0px",display:"block",width:"auto",overflow:"hidden"}},[t("math",{xmlns:"http://www.w3.org/1998/Math/MathML"},[t("msub",null,[t("mi",null,"D"),t("mn",null,"1")]),t("mi",null,"×"),t("mo",null,"."),t("mo",null,"."),t("mo",null,"."),t("mi",null,"×"),t("msub",null,[t("mi",null,"D"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"N"),t("mo",null,"−"),t("mn",null,"2")])]),t("mi",null,"×"),t("mn",null,"1"),t("mi",null,"×"),t("msub",null,[t("mi",null,"D"),t("mi",null,"N")])])],-1))]),a[224]||(a[224]=s(" input slice and normalises the input accordingly."))]),a[226]||(a[226]=i(`
Arguments
chs: Size of the channel dimension in your data. Given an array with N dimensions, call the N-1th the channel dimension. For a batch of feature vectors this is just the data dimension, for WHCN images it's the usual channel dimension.
activation: After normalization, elementwise activation activation is applied.
Keyword Arguments
track_stats: If track_stats=true, accumulates mean and variance statistics in training phase that will be used to renormalize the input in test phase.
epsilon: a value added to the denominator for numerical stability
momentum: the value used for the running_mean and running_var computation
affine: If affine=true, it also applies a shift and a rescale to the input through learnable per-channel bias and scale parameters.
init_bias: Controls how the bias is initialized
init_scale: Controls how the scale is initialized
Extended Help
Inputs
x: Array where size(x, N - 1) = chs
Returns
y: Normalized Array
Update model state
Parameters
affine=true
bias: Bias of shape (chs,)
scale: Scale of shape (chs,)
affine=false - Empty NamedTuple()
States
Statistics if track_stats=true
running_mean: Running mean of shape (chs,)
running_var: Running variance of shape (chs,)
Statistics if track_stats=false
running_mean: nothing
running_var: nothing
training: Used to check if training/inference mode
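A hedged sketch of typical usage, including switching to test mode so the running statistics are used (dimensions are arbitrary assumptions):
julia
using Lux, Random

m = Chain(Dense(3 => 6), BatchNorm(6, relu))
ps, st = Lux.setup(Xoshiro(0), m)
y, st_train = m(randn(Float32, 3, 16), ps, st)
st_test = Lux.testmode(st_train)           # use running statistics at inference time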
chs: Size of the channel dimension in your data. Given an array with N dimensions, call the N-1th the channel dimension. For a batch of feature vectors this is just the data dimension, for WHCN images it's the usual channel dimension.
groups: The number of groups along which the statistics are computed. The number of channels must be an integer multiple of the number of groups.
activation: After normalization, elementwise activation activation is applied.
Keyword Arguments
epsilon: a value added to the denominator for numerical stability
affine: If affine=true, it also applies a shift and a rescale to the input through learnable per-channel bias and scale parameters.
init_bias: Controls how the bias is initialized
init_scale: Controls how the scale is initialized
Extended Help
Inputs
x: Array where size(x, N - 1) = chs and ndims(x) > 2
Returns
y: Normalized Array
Update model state
Parameters
affine=true
bias: Bias of shape (chs,)
scale: Scale of shape (chs,)
affine=false - Empty NamedTuple()
States
training: Used to check if training/inference mode
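A minimal sketch (channel count, group count, and WHCN shape are arbitrary assumptions):
julia
using Lux, Random

g = GroupNorm(12, 4)                       # 12 channels split into 4 groups
ps, st = Lux.setup(Xoshiro(0), g)
y, _ = g(randn(Float32, 8, 8, 12, 2), ps, st)  # WHCN input with ndims(x) > 2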
`,2)),t("p",null,[a[234]||(a[234]=s("Instance Normalization computes the mean and variance for each ")),t("mjx-container",P1,[(Q(),n("svg",I1,a[232]||(a[232]=[i('',1)]))),a[233]||(a[233]=t("mjx-assistive-mml",{unselectable:"on",display:"inline",style:{top:"0px",left:"0px",clip:"rect(1px, 1px, 1px, 1px)","-webkit-touch-callout":"none","-webkit-user-select":"none","-khtml-user-select":"none","-moz-user-select":"none","-ms-user-select":"none","user-select":"none",position:"absolute",padding:"1px 0px 0px 0px",border:"0px",display:"block",width:"auto",overflow:"hidden"}},[t("math",{xmlns:"http://www.w3.org/1998/Math/MathML"},[t("msub",null,[t("mi",null,"D"),t("mn",null,"1")]),t("mo",null,"×"),t("mo",null,"."),t("mo",null,"."),t("mo",null,"."),t("mo",null,"×"),t("msub",null,[t("mi",null,"D"),t("mrow",{"data-mjx-texclass":"ORD"},[t("mi",null,"N"),t("mo",null,"−"),t("mn",null,"2")])]),t("mo",null,"×"),t("mn",null,"1"),t("mo",null,"×"),t("mn",null,"1")])],-1))]),a[235]||(a[235]=s("` input slice and normalises the input accordingly."))]),a[237]||(a[237]=i(`
Arguments
chs: Size of the channel dimension in your data. Given an array with N dimensions, call the N-1th the channel dimension. For a batch of feature vectors this is just the data dimension, for WHCN images it's the usual channel dimension.
activation: After normalization, elementwise activation activation is applied.
Keyword Arguments
track_stats: If track_stats=true, accumulates mean and variance statistics in training phase that will be used to renormalize the input in test phase.
epsilon: a value added to the denominator for numerical stability
momentum: the value used for the running_mean and running_var computation
affine: If affine=true, it also applies a shift and a rescale to the input through learnable per-channel bias and scale parameters.
init_bias: Controls how the bias is initialized
init_scale: Controls how the scale is initialized
Extended Help
Inputs
x: Array where size(x, N - 1) = chs and ndims(x) > 2
Returns
y: Normalized Array
Update model state
Parameters
affine=true
bias: Bias of shape (chs,)
scale: Scale of shape (chs,)
affine=false - Empty NamedTuple()
States
Statistics if track_stats=true
running_mean: Running mean of shape (chs,)
running_var: Running variance of shape (chs,)
Statistics if track_stats=false
running_mean: nothing
running_var: nothing
training: Used to check if training/inference mode
[1] Ulyanov, Dmitry, Andrea Vedaldi, and Victor Lempitsky. "Instance normalization: The missing ingredient for fast stylization." arXiv preprint arXiv:1607.08022 (2016).
Computes mean and standard deviation over the whole input array, and uses these to normalize the whole array. Optionally applies an elementwise affine transformation afterwards.
Weight normalization is a reparameterization that decouples the magnitude of a weight tensor from its direction. This updates the parameters in which_params (e.g. weight) using two parameters: one specifying the magnitude (e.g. weight_g) and one specifying the direction (e.g. weight_v).
Arguments
layer: The layer whose parameters are being reparameterized
which_params: parameter names for the parameters being reparameterized
dims: By default, a norm over the entire array is computed. Pass dims to modify the dimension.
Inputs
x: Should be of valid type for input to layer
Returns
Output from layer
Updated model state of layer
Parameters
normalized: Parameters of layer that are being normalized
unnormalized: Parameters of layer that are not being normalized
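A minimal sketch (the wrapped layer and dimensions are arbitrary assumptions):
julia
using Lux, Random

wn = WeightNorm(Dense(3 => 2), (:weight,))
ps, st = Lux.setup(Xoshiro(0), wn)
# ps.normalized holds weight_g / weight_v; ps.unnormalized holds the bias
y, _ = wn(randn(Float32, 3, 4), ps, st)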
Upsampling
Lux.PixelShuffle
julia
PixelShuffle(r::Int)
Pixel shuffling layer with upscale factor r. Usually used for generating higher resolution images while upscaling them.
See NNlib.pixel_shuffle for more details.
Arguments
r: Upscale factor
Inputs
x: For 4D-arrays representing N images, the operation converts input size(x) == (W, H, r² × C, N) to output of size (r × W, r × H, C, N). For D-dimensional data, it expects ndims(x) == D + 2 with channel and batch dimensions, and divides the number of channels by rᴰ.
Returns
Output of size (r × W, r × H, C, N) for 4D-arrays, and (r × W, r × H, ..., C, N) for D-dimensional data, where D = ndims(x) - 2
Lux.Upsample
Arguments
Option 1
mode: Set to :nearest, :linear, :bilinear or :trilinear
Exactly one of the following two keywords must also be specified:
scale: If scale is a number, this applies to all but the last two dimensions (channel and batch) of the input. It may also be a tuple, to control dimensions individually.
size: Alternatively, the keyword size accepts a tuple, to directly specify the leading dimensions of the output.
Option 2
scale: If scale is a number, this applies to all but the last two dimensions (channel and batch) of the input. It may also be a tuple, to control dimensions individually.
mode: Set to :nearest, :bilinear or :trilinear
Currently supported upsampling modes and the corresponding NNlib methods are:
:nearest -> NNlib.upsample_nearest
:bilinear -> NNlib.upsample_bilinear
:trilinear -> NNlib.upsample_trilinear
Extended Help
Other Keyword Arguments
align_corners: If true, the corner pixels of the input and output tensors are aligned, and thus preserving the values at those pixels. This only has effect when mode is one of :bilinear or :trilinear.
Inputs
x: For the input dimensions look into the documentation for the corresponding NNlib function
As a rule of thumb, :nearest should work with arrays of arbitrary dimensions
:bilinear works with 4D Arrays
:trilinear works with 5D Arrays
Returns
Upsampled Input of size size or of size (I_1 × scale[1], ..., I_N × scale[N], C, N)
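A minimal sketch using the mode-plus-scale form (shapes are arbitrary assumptions):
julia
using Lux, Random

up = Upsample(:bilinear; scale=2)
ps, st = Lux.setup(Xoshiro(0), up)
x = randn(Float32, 4, 4, 3, 1)             # WHCN
y, _ = up(x, ps, st)                       # size(y) == (8, 8, 3, 1)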
Compute the gradients of the objective function wrt parameters stored in ts.
Backends & AD Packages
Supported Backends and the packages needed to use them:
AutoZygote: Zygote.jl
AutoReverseDiff(; compile): ReverseDiff.jl
AutoTracker: Tracker.jl
AutoEnzyme: Enzyme.jl
Arguments
ad: Backend (from ADTypes.jl) used to compute the gradients.
objective_function: Objective function. The function must take 4 inputs – model, parameters, states and data. The function must return 3 values – loss, updated_state, and any computed statistics.
data: Data used to compute the gradients.
ts: Current Training State.
Returns
A 4-Tuple containing:
grads: Computed gradients.
loss: Loss from the objective function.
stats: Any computed statistics from the objective function.
ts: Updated Training State.
Known Limitations
AutoReverseDiff(; compile=true) is not supported for Lux models with non-empty state st. Additionally the returned stats must be empty (NamedTuple()). We catch these issues in most cases and throw an error.
Aliased Gradients
grads returned by this function might be aliased by the implementation of the gradient backend. For example, if you cache the grads from step i, the new gradients returned in step i + 1 might be aliased by the old gradients. If you want to prevent this, simply use copy(grads) or deepcopy(grads) to make a copy of the gradients.
Returned values are the same as compute_gradients. Note that despite the !, only the parameters in ts are updated inplace. Users should use the returned ts object for further training steps, else there is no caching and performance will be suboptimal (and absolutely terrible for backends like AutoReactant).
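Putting this together, a hedged sketch of one gradient computation and update (the TrainState(model, ps, st, opt) constructor form, the use of MSELoss as the objective, and all dimensions are assumptions for illustration):
julia
using Lux, Optimisers, Zygote, Random

model = Dense(2 => 1)
ps, st = Lux.setup(Xoshiro(0), model)
ts = Training.TrainState(model, ps, st, Adam(0.01f0))
x, y = randn(Float32, 2, 8), randn(Float32, 1, 8)
grads, loss, stats, ts = Training.compute_gradients(AutoZygote(), MSELoss(), (x, y), ts)
ts = Training.apply_gradients!(ts, grads)  # reuse the returned ts for the next step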
Loss functions in Lux can be called with either of two signatures:
ŷ and y, where ŷ is the predicted output and y is the target output.
model, ps, st, (x, y), where model is the model, ps are the parameters, st are the states and (x, y) are the input and target pair. Then it returns the loss, updated states, and an empty named tuple. This makes them compatible with the Training API.
Warning
When using ChainRules.jl compatible AD (like Zygote), we only compute the gradients wrt the inputs and drop any gradients wrt the targets.
Takes any function loss_fn that maps 2 number inputs to a single number output. Additionally, array inputs are efficiently broadcasted and aggregated using agg.
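For example, a hedged sketch of wrapping a custom pointwise loss (aggregation defaults to mean):
julia
using Lux

l = GenericLossFunction((ŷ, y) -> abs2(ŷ - y))
l(Float32[1, 2, 3], Float32[1, 2, 2])      # mean of elementwise squared errors ≈ 0.3333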
',1)),s("p",null,[i[58]||(i[58]=a("Return the binary focal loss [1]. The model input, ")),s("mjx-container",V,[(h(),n("svg",A,i[56]||(i[56]=[t('',1)]))),i[57]||(i[57]=s("mjx-assistive-mml",{unselectable:"on",display:"inline",style:{top:"0px",left:"0px",clip:"rect(1px, 1px, 1px, 1px)","-webkit-touch-callout":"none","-webkit-user-select":"none","-khtml-user-select":"none","-moz-user-select":"none","-ms-user-select":"none","user-select":"none",position:"absolute",padding:"1px 0px 0px 0px",border:"0px",display:"block",width:"auto",overflow:"hidden"}},[s("math",{xmlns:"http://www.w3.org/1998/Math/MathML"},[s("mrow",{"data-mjx-texclass":"ORD"},[s("mover",null,[s("mi",null,"y"),s("mo",{stretchy:"false"},"^")])])])],-1))]),i[59]||(i[59]=a(", is expected to be normalized (i.e. softmax output)."))]),s("p",null,[i[62]||(i[62]=a("For ")),s("mjx-container",Z,[(h(),n("svg",O,i[60]||(i[60]=[t('',1)]))),i[61]||(i[61]=s("mjx-assistive-mml",{unselectable:"on",display:"inline",style:{top:"0px",left:"0px",clip:"rect(1px, 1px, 1px, 1px)","-webkit-touch-callout":"none","-webkit-user-select":"none","-khtml-user-select":"none","-moz-user-select":"none","-ms-user-select":"none","user-select":"none",position:"absolute",padding:"1px 0px 0px 0px",border:"0px",display:"block",width:"auto",overflow:"hidden"}},[s("math",{xmlns:"http://www.w3.org/1998/Math/MathML"},[s("mi",null,"γ"),s("mo",null,"="),s("mn",null,"0")])],-1))]),i[63]||(i[63]=a(" this is equivalent to ")),i[64]||(i[64]=s("a",{href:"/previews/PR1023/api/Lux/utilities#Lux.BinaryCrossEntropyLoss"},[s("code",null,"BinaryCrossEntropyLoss")],-1)),i[65]||(i[65]=a("."))]),i[67]||(i[67]=t(`
[1] Milletari, Fausto, Nassir Navab, and Seyed-Ahmad Ahmadi. "V-net: Fully convolutional neural networks for volumetric medical image segmentation." 2016 fourth international conference on 3D vision (3DV). Ieee, 2016.
The KL divergence is a measure of how much one probability distribution is different from the other. It is always non-negative, and zero only when both the distributions are equal.
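A small sketch illustrating the zero-at-equality property (assuming the KLDivergenceLoss functor with its default column-wise convention):
julia
using Lux

loss = KLDivergenceLoss()
p = Float32[0.25 0.25; 0.25 0.25; 0.5 0.5]  # each column is a probability distribution
loss(p, p)                                  # 0, since the two distributions are equal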
LuxOps Module
Lux.LuxOps
julia
LuxOps
This module is a part of Lux.jl. It contains operations that are useful in a deep learning context. Additionally, certain operations here alias Base functions to behave more sensibly with GPUArrays.
Recursive Operations
Lux.recursive_map
julia
recursive_map(f, x, args...)
Similar to fmap(f, args...) but with restricted support for the notion of "leaf" types. However, this allows for more efficient and type stable implementations of recursive operations.
How does this work?
For the following types it directly defines recursion rules:
AbstractArray: If eltype is isbitstype, then f is applied to the array, else we recurse on the array.
Tuple/NamedTuple: We recurse on the values.
Number/Val/Nothing: We directly apply f.
For all other types, we recurse on the fields using Functors.fmap.
Note
In most cases, users should gravitate towards Functors.fmap if it is being used outside of hot loops. Even for other cases, it is always recommended to verify the correctness of this implementation for specific usecases.
Recursively add the leaves of two nested structures x and y. In Functor language, this is equivalent to doing fmap(+, x, y), but this implementation uses type stable code for common cases.
Any leaves of x that are arrays and allow in-place addition will be modified in place.
Recursively copy the leaves of two nested structures x and y. In Functor language, this is equivalent to doing fmap(copyto!, x, y), but this implementation uses type stable code for common cases. Note that any immutable leaf will lead to an error.
Recursively determine the element type of a nested structure x. This is equivalent to doing fmap(Lux.Utils.eltype, x), but this implementation uses type stable code for common cases.
For ambiguous inputs like nothing and Val types we return Bool as the eltype.
If unwrap_ad_types is set to Val(true) then for tracing and operator overloading based ADs (ForwardDiff, ReverseDiff, Tracker), this function will return the eltype of the unwrapped value.
Recursively create a zero value for a nested structure x. This is equivalent to doing fmap(zero, x), but this implementation uses type stable code for common cases.
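A small sketch of these helpers on a parameter NamedTuple (accessing them as Lux.recursive_eltype and Lux.recursive_make_zero is an assumption for illustration):
julia
using Lux

ps = (; weight=ones(Float32, 2, 2), bias=zeros(Float32, 2))
Lux.recursive_eltype(ps)                   # Float32
zs = Lux.recursive_make_zero(ps)           # same structure, all-zero arrays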
Updating Floating Point Precision
By default, Lux uses Float32 for all parameters and states. To update the precision simply pass the parameters / states / arrays into one of the following functions.
Lux.f16
julia
f16(m)
Converts the eltype of m's floating point values to Float16. Recurses into structs marked with Functors.@functor.
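For example, a minimal sketch on a parameter NamedTuple:
julia
using Lux, Random

ps, st = Lux.setup(Xoshiro(0), Dense(3 => 2))
ps16 = f16(ps)
eltype(ps16.weight)                        # Float16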
Element Type Matching
Lux.match_eltype
julia
match_eltype(layer, ps, st, args...)
Helper function to "maybe" (see below) match the element type of args... with the element type of the layer's parameters and states. This is useful for debugging purposes, to track down accidental type-promotions inside Lux layers.
Extended Help
Controlling the Behavior via Preferences
Behavior of this function is controlled via the eltype_mismatch_handling preference. The following options are supported:
"none": This is the default behavior. In this case, this function is a no-op, i.e., it simply returns args....
"warn": This option will issue a warning if the element type of args... does not match the element type of the layer's parameters and states. The warning will contain information about the layer and the element type mismatch.
"convert": This option is same as "warn", but it will also convert the element type of args... to match the element type of the layer's parameters and states (for the cases listed below).
"error": Same as "warn", but instead of issuing a warning, it will throw an error.
Warning
We print the warning for type-mismatch only once.
Element Type Conversions
For "convert" only the following conversions are done:
Stateful Layer
Lux.StatefulLuxLayer
julia
StatefulLuxLayer{FT}(model, ps, st)
Warning
This is not a Lux.AbstractLuxLayer
A convenience wrapper over Lux layers which stores the parameters and states internally. This is meant to be used in internal implementations of layers.
Usecases
Internal implementation of @compact heavily uses this layer.
In SciML codebases where propagating state might involve boxing. For a motivating example, see the Neural ODE tutorial.
Facilitates Nested AD support in Lux. For more details on this feature, see the Nested AD Manual Page.
Static Parameters
If FT = true then the type of the state is fixed, i.e., typeof(last(model(x, ps, st))) == st.
If FT = false then type of the state might change. Note that while this works in all cases, it will introduce type instability.
Arguments
model: A Lux layer
ps: The parameters of the layer. This can be set to nothing, if the user provides the parameters on function call
st: The state of the layer
Inputs
x: The input to the layer
ps: The parameters of the layer. Optional, defaults to s.ps
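A minimal usage sketch (assuming the StatefulLuxLayer{FT}(model, ps, st) constructor with FT=true; dimensions are arbitrary):
julia
using Lux, Random

model = Dense(3 => 2)
ps, st = Lux.setup(Xoshiro(0), model)
smodel = StatefulLuxLayer{true}(model, ps, st)
y = smodel(randn(Float32, 3, 4))           # ps/st are threaded through internally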
Compact Layer
Lux.@compact
julia
@compact(kw...) do x
    ...
    @return y # optional (but recommended for best performance)
end
@compact(kw...) do x, p
    ...
    @return y # optional (but recommended for best performance)
end
@compact(forward::Function; name=nothing, dispatch=nothing, parameters...)
Creates a layer by specifying some parameters, in the form of keywords, and (usually as a do block) a function for the forward pass. You may think of @compact as a specialized let block creating local variables that are trainable in Lux. Declared variable names may be used within the body of the forward function. Note that unlike typical Lux models, the forward function doesn't need to explicitly manage states.
Defining the version with p allows you to access the parameters in the forward pass. This is useful when using it with SciML tools which require passing in the parameters explicitly.
Reserved Kwargs:
name: The name of the layer.
dispatch: The constructed layer has the type Lux.CompactLuxLayer{dispatch} which can be used for custom dispatches.
Tip
Check the Lux tutorials for more examples of using @compact.
If you are passing in kwargs by splatting them, they will be passed as-is to the function body. This means that if your splatted kwargs contain a Lux layer, it won't be registered in the CompactLuxLayer.
Special Syntax
@return: This macro doesn't really exist, but is used to return a value from the @compact block. Without the presence of this macro, we need to rely on closures which can lead to performance penalties in the reverse pass.
Having statements after the last @return macro might lead to incorrect code.
Don't do things like @return return x. This will generate nonsensical code like <new var> = return x. Essentially, @return <expr> supports any expression that can be assigned to a variable.
Since this macro doesn't "exist", it cannot be imported as using Lux: @return. Simply use it in code, and @compact will understand it.
@init_fn: Provide a function that will be used to initialize the layer's parameters or state. See the docs of @init_fn for more details.
@non_trainable: Mark a value as non-trainable. This bypasses the regular checks and places the value into the state of the layer. See the docs of @non_trainable for more details.
Extended Help
Examples
Here is a linear model:
julia
julia> using Lux, Random

julia> r = @compact(w=ones(3)) do x
           @return w .* x
       end
@compact(
    w = 3-element Vector{Float64},
) do x
    return w .* x
end       # Total: 3 parameters,
          #        plus 0 states.

julia> ps, st = Lux.setup(Xoshiro(0), r);

julia> r([1, 2, 3], ps, st)  # x is set to [1, 1, 1].
([1.0, 2.0, 3.0], NamedTuple())
Here is a linear model with bias and activation:
julia
julia> d_in = 5
5

julia> d_out = 3
3

julia> d = @compact(W=ones(d_out, d_in), b=zeros(d_out), act=relu) do x
           y = W * x
           @return act.(y .+ b)
       end
@compact(
    W = 3×5 Matrix{Float64},
    b = 3-element Vector{Float64},
    act = relu,
) do x
    y = W * x
    return act.(y .+ b)
end       # Total: 18 parameters,
          #        plus 1 states.

julia> ps, st = Lux.setup(Xoshiro(0), d);

julia> d(ones(5, 2), ps, st)[1]  # 3×2 Matrix as output.
3×2 Matrix{Float64}:
 5.0  5.0
 5.0  5.0
 5.0  5.0

julia> ps_dense = (; weight=ps.W, bias=ps.b);

julia> first(d([1, 2, 3, 4, 5], ps, st)) ≈
       first(Dense(d_in => d_out, relu)([1, 2, 3, 4, 5], ps_dense, NamedTuple()))  # Equivalent to a dense layer
true
Finally, here is a simple MLP. We can train this model just like any Lux model:
julia
julia> n_in = 1;

julia> n_out = 1;

julia> nlayers = 3;

julia> model = @compact(w1=Dense(n_in, 128),
           w2=[Dense(128, 128) for i in 1:nlayers], w3=Dense(128, n_out), act=relu) do x
           embed = act.(w1(x))
           for w in w2
               embed = act.(w(embed))
           end
           out = w3(embed)
           @return out
       end
@compact(
    w1 = Dense(1 => 128),                  # 256 parameters
    w2 = NamedTuple(
        1 = Dense(128 => 128),             # 16_512 parameters
        2 = Dense(128 => 128),             # 16_512 parameters
        3 = Dense(128 => 128),             # 16_512 parameters
    ),
    w3 = Dense(128 => 1),                  # 129 parameters
    act = relu,
) do x
    embed = act.(w1(x))
    for w = w2
        embed = act.(w(embed))
    end
    out = w3(embed)
    return out
end       # Total: 49_921 parameters,
          #        plus 1 states.

julia> ps, st = Lux.setup(Xoshiro(0), model);

julia> size(first(model(randn(n_in, 32), ps, st)))  # 1×32 Matrix as output.
(1, 32)

julia> using Optimisers, Zygote

julia> x_data = collect(-2.0f0:0.1f0:2.0f0)';

julia> y_data = 2 .* x_data .- x_data .^ 3;

julia> optim = Optimisers.setup(Adam(), ps);

julia> loss_initial = sum(abs2, first(model(x_data, ps, st)) .- y_data);

julia> for epoch in 1:1000
           loss, gs = Zygote.withgradient(
               ps -> sum(abs2, first(model(x_data, ps, st)) .- y_data), ps)
           Optimisers.update!(optim, ps, gs[1])
       end;

julia> loss_final = sum(abs2, first(model(x_data, ps, st)) .- y_data);

julia> loss_initial > loss_final
true
You may also specify a name for the model, which will be used instead of the default printout, which gives a verbatim representation of the code used to construct the model:
julia
julia> model = @compact(w=rand(3), name="Linear(3 => 1)") do x
           @return sum(w .* x)
       end
Linear(3 => 1)    # 3 parameters
This can be useful when using @compact to hierarchically construct complex models to be used inside a Chain.
Type Stability
If your input function f is type-stable but the generated model is not type stable, it should be treated as a bug. We will appreciate issues if you find such cases.
Parameter Count
Array parameters don't print the number of parameters on the side. However, they do account for the total number of parameters printed at the bottom.
Create an initializer function for a parameter or state to be used in a Compact Lux Layer created using @compact.
Arguments
fn: The function to be used for initializing the parameter or state. This only takes a single argument rng.
kind: If set to :parameter, the initializer function will be used to initialize the parameters of the layer. If set to :state, the initializer function will be used to initialize the states of the layer.
Examples
julia
julia> using Lux, Random

julia> r = @compact(w=@init_fn(rng->randn32(rng, 3, 2)),
                    b=@init_fn(rng->randn32(rng, 3), :state)) do x
           @return w * x .+ b
       end;

julia> ps, st = Lux.setup(Xoshiro(0), r);

julia> size(ps.w)
(3, 2)

julia> size(st.b)
(3,)

julia> size(r([1, 2], ps, st)[1])
(3,)
Miscellaneous
Lux.set_dispatch_doctor_preferences!
Set the dispatch doctor preference for LuxCore and LuxLib packages.
mode can be "disable", "warn", or "error". For details on the different modes, see the DispatchDoctor.jl documentation.
If the preferences are already set, then no action is taken. Otherwise the preference is set. For changes to take effect, the Julia session must be restarted.
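A minimal sketch, assuming the single-string form sets the preference for both packages:
julia
using Lux

Lux.set_dispatch_doctor_preferences!("error")
# Restart the Julia session for the new preference to take effect.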
Compute the gradients of the objective function wrt parameters stored in ts.
Backends & AD Packages
Supported Backends
Packages Needed
AutoZygote
Zygote.jl
AutoReverseDiff(; compile)
ReverseDiff.jl
AutoTracker
Tracker.jl
AutoEnzyme
Enzyme.jl
Arguments
ad: Backend (from ADTypes.jl) used to compute the gradients.
objective_function: Objective function. The function must take 4 inputs – model, parameters, states and data. The function must return 3 values – loss, updated_state, and any computed statistics.
stats: Any computed statistics from the objective function.
ts: Updated Training State.
Known Limitations
AutoReverseDiff(; compile=true) is not supported for Lux models with non-empty state st. Additionally the returned stats must be empty (NamedTuple()). We catch these issues in most cases and throw an error.
Aliased Gradients
grads returned by this function might be aliased by the implementation of the gradient backend. For example, if you cache the grads from step i, the new gradients returned in step i + 1 might be aliased by the old gradients. If you want to prevent this, simply use copy(grads) or deepcopy(grads) to make a copy of the gradients.
Returned values are the same as compute_gradients. Note that despite the !, only the parameters in ts are updated inplace. Users should be using the returned ts object for further training steps, else there is no caching and performance will be suboptimal (and absolutely terrible for backends like AutoReactant).
ŷ and y where ŷ is the predicted output and y is the target output.
model, ps, st, (x, y) where model is the model, ps are the parameters, st are the states and (x, y) are the input and target pair. Then it returns the loss, updated states, and an empty named tuple. This makes them compatible with the Training API.
Warning
When using ChainRules.jl compatible AD (like Zygote), we only compute the gradients wrt the inputs and drop any gradients wrt the targets.
Takes any function loss_fn that maps 2 number inputs to a single number output. Additionally, array inputs are efficiently broadcasted and aggregated using agg.
',1)),s("p",null,[i[58]||(i[58]=a("Return the binary focal loss [1]. The model input, ")),s("mjx-container",V,[(h(),n("svg",A,i[56]||(i[56]=[t('',1)]))),i[57]||(i[57]=s("mjx-assistive-mml",{unselectable:"on",display:"inline",style:{top:"0px",left:"0px",clip:"rect(1px, 1px, 1px, 1px)","-webkit-touch-callout":"none","-webkit-user-select":"none","-khtml-user-select":"none","-moz-user-select":"none","-ms-user-select":"none","user-select":"none",position:"absolute",padding:"1px 0px 0px 0px",border:"0px",display:"block",width:"auto",overflow:"hidden"}},[s("math",{xmlns:"http://www.w3.org/1998/Math/MathML"},[s("mrow",{"data-mjx-texclass":"ORD"},[s("mover",null,[s("mi",null,"y"),s("mo",{stretchy:"false"},"^")])])])],-1))]),i[59]||(i[59]=a(", is expected to be normalized (i.e. softmax output)."))]),s("p",null,[i[62]||(i[62]=a("For ")),s("mjx-container",Z,[(h(),n("svg",O,i[60]||(i[60]=[t('',1)]))),i[61]||(i[61]=s("mjx-assistive-mml",{unselectable:"on",display:"inline",style:{top:"0px",left:"0px",clip:"rect(1px, 1px, 1px, 1px)","-webkit-touch-callout":"none","-webkit-user-select":"none","-khtml-user-select":"none","-moz-user-select":"none","-ms-user-select":"none","user-select":"none",position:"absolute",padding:"1px 0px 0px 0px",border:"0px",display:"block",width:"auto",overflow:"hidden"}},[s("math",{xmlns:"http://www.w3.org/1998/Math/MathML"},[s("mi",null,"γ"),s("mo",null,"="),s("mn",null,"0")])],-1))]),i[63]||(i[63]=a(" this is equivalent to ")),i[64]||(i[64]=s("a",{href:"/previews/PR1023/api/Lux/utilities#Lux.BinaryCrossEntropyLoss"},[s("code",null,"BinaryCrossEntropyLoss")],-1)),i[65]||(i[65]=a("."))]),i[67]||(i[67]=t(`
[1] Milletari, Fausto, Nassir Navab, and Seyed-Ahmad Ahmadi. "V-net: Fully convolutional neural networks for volumetric medical image segmentation." 2016 fourth international conference on 3D vision (3DV). Ieee, 2016.
The KL divergence is a measure of how much one probability distribution is different from the other. It is always non-negative, and zero only when both the distributions are equal.
`,4))]),i[291]||(i[291]=s("h2",{id:"LuxOps-Module",tabindex:"-1"},[a("LuxOps Module "),s("a",{class:"header-anchor",href:"#LuxOps-Module","aria-label":'Permalink to "LuxOps Module {#LuxOps-Module}"'},"")],-1)),s("details",U1,[s("summary",null,[i[217]||(i[217]=s("a",{id:"Lux.LuxOps",href:"#Lux.LuxOps"},[s("span",{class:"jlbinding"},"Lux.LuxOps")],-1)),i[218]||(i[218]=a()),l(e,{type:"info",class:"jlObjectType jlModule",text:"Module"})]),i[219]||(i[219]=t('
julia
LuxOps
This module is a part of Lux.jl. It contains operations that are useful in DL context. Additionally certain operations here alias Base functions to behave more sensibly with GPUArrays.
',3))]),i[292]||(i[292]=s("h2",{id:"Recursive-Operations",tabindex:"-1"},[a("Recursive Operations "),s("a",{class:"header-anchor",href:"#Recursive-Operations","aria-label":'Permalink to "Recursive Operations {#Recursive-Operations}"'},"")],-1)),s("details",a2,[s("summary",null,[i[241]||(i[241]=s("a",{id:"Lux.recursive_map",href:"#Lux.recursive_map"},[s("span",{class:"jlbinding"},"Lux.recursive_map")],-1)),i[242]||(i[242]=a()),l(e,{type:"info",class:"jlObjectType jlFunction",text:"Function"})]),i[243]||(i[243]=t('
julia
recursive_map(f, x, args...)
Similar to fmap(f, args...) but with restricted support for the notion of "leaf" types. However, this allows for more efficient and type stable implementations of recursive operations.
How this works?
For the following types it directly defines recursion rules:
AbstractArray: If eltype is isbitstype, then f is applied to the array, else we recurse on the array.
Tuple/NamedTuple: We recurse on the values.
Number/Val/Nothing: We directly apply f.
For all other types, we recurse on the fields using Functors.fmap.
Note
In most cases, users should gravitate towards Functors.fmap if it is being used outside of hot loops. Even for other cases, it is always recommended to verify the correctness of this implementation for specific usecases.
Recursively add the leaves of two nested structures x and y. In Functor language, this is equivalent to doing fmap(+, x, y), but this implementation uses type stable code for common cases.
Any leaves of x that are arrays and allow in-place addition will be modified in place.
Recursively copy the leaves of two nested structures x and y. In Functor language, this is equivalent to doing fmap(copyto!, x, y), but this implementation uses type stable code for common cases. Note that any immutable leaf will lead to an error.
Recursively determine the element type of a nested structure x. This is equivalent to doing fmap(Lux.Utils.eltype, x), but this implementation uses type stable code for common cases.
For ambiguous inputs like nothing and Val types we return Bool as the eltype.
If unwrap_ad_types is set to Val(true) then for tracing and operator overloading based ADs (ForwardDiff, ReverseDiff, Tracker), this function will return the eltype of the unwrapped value.
Recursively create a zero value for a nested structure x. This is equivalent to doing fmap(zero, x), but this implementation uses type stable code for common cases.
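A short sketch of these helpers; the names recursive_eltype, recursive_make_zero, and recursive_add!! are assumed from the standard Lux API, since the flattened docs above dropped the signatures:
julia
using Lux

x = (a=ones(Float32, 2), b=(c=2.0f0,))
y = (a=ones(Float32, 2), b=(c=3.0f0,))

Lux.recursive_eltype(x)         # Float32
z = Lux.recursive_make_zero(x)  # same structure, zeroed leaves
s = Lux.recursive_add!!(x, y)   # leafwise addition; array leaves updated in place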
Updating Floating Point Precision
By default, Lux uses Float32 for all parameters and states. To update the precision, simply pass the parameters / states / arrays into one of the following functions.
julia
f16(m)
Converts the eltype of m's floating point values to Float16. Recurses into structs marked with Functors.@functor.
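A quick sketch of the conversion on a parameter NamedTuple:
julia
using Lux, Random

ps, st = Lux.setup(Xoshiro(0), Dense(2 => 3))  # Float32 parameters by default
ps16 = Lux.f16(ps)
eltype(ps16.weight)  # Float16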
Element Type Matching
julia
match_eltype(layer, ps, st, args...)
Helper function to "maybe" (see below) match the element type of args... with the element type of the layer's parameters and states. This is useful for debugging purposes, to track down accidental type-promotions inside Lux layers.
Extended Help
Controlling the Behavior via Preferences
Behavior of this function is controlled via the eltype_mismatch_handling preference. The following options are supported:
"none": This is the default behavior. In this case, this function is a no-op, i.e., it simply returns args....
"warn": This option will issue a warning if the element type of args... does not match the element type of the layer's parameters and states. The warning will contain information about the layer and the element type mismatch.
"convert": This option is same as "warn", but it will also convert the element type of args... to match the element type of the layer's parameters and states (for the cases listed below).
"error": Same as "warn", but instead of issuing a warning, it will throw an error.
Warning
We print the warning for type-mismatch only once.
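As a rough sketch of a call site (illustrative only; the exact return handling follows the args... pass-through described above):
julia
using Lux, Random

layer = Dense(2 => 3)                  # Float32 parameters by default
ps, st = Lux.setup(Xoshiro(0), layer)
x64 = rand(Float64, 2, 4)

# With the default "none" preference this simply passes `x64` through;
# under "convert" it would instead return a Float32 array matching `ps`.
x = Lux.match_eltype(layer, ps, st, x64)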
Element Type Conversions
For "convert" only the following conversions are done:
Stateful Layer
julia
StatefulLuxLayer{FT}(model, ps, st)
Warning
This is not a Lux.AbstractLuxLayer
A convenience wrapper over Lux layers which stores the parameters and states internally. This is meant to be used in internal implementation of layers.
Usecases
Internal implementation of @compact heavily uses this layer.
In SciML codebases where propagating state might involve boxing. For a motivating example, see the Neural ODE tutorial.
Facilitates Nested AD support in Lux. For more details on this feature, see the Nested AD Manual Page.
Static Parameters
If FT = true then the type of the state is fixed, i.e., typeof(last(model(x, ps, st))) == typeof(st).
If FT = false then the type of the state might change. Note that while this works in all cases, it will introduce type instability.
Arguments
model: A Lux layer
ps: The parameters of the layer. This can be set to nothing, if the user provides the parameters on function call
st: The state of the layer
Inputs
x: The input to the layer
ps: The parameters of the layer. Optional, defaults to s.ps
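A minimal sketch of the wrapper in action, using the constructor form shown above:
julia
using Lux, Random

model = Dense(2 => 3)
ps, st = Lux.setup(Xoshiro(0), model)

smodel = StatefulLuxLayer{true}(model, ps, st)  # FT = true: state type is fixed
y = smodel(ones(Float32, 2, 4))                 # ps/st are threaded internally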
Compact Layer
julia
@compact(kw...) do x
    ...
    @return y # optional (but recommended for best performance)
end
@compact(kw...) do x, p
    ...
    @return y # optional (but recommended for best performance)
end
@compact(forward::Function; name=nothing, dispatch=nothing, parameters...)
Creates a layer by specifying some parameters, in the form of keywords, and (usually as a do block) a function for the forward pass. You may think of @compact as a specialized let block creating local variables that are trainable in Lux. Declared variable names may be used within the body of the forward function. Note that unlike typical Lux models, the forward function doesn't need to explicitly manage states.
Defining the version with p allows you to access the parameters in the forward pass. This is useful when using it with SciML tools which require passing in the parameters explicitly.
Reserved Kwargs:
name: The name of the layer.
dispatch: The constructed layer has the type Lux.CompactLuxLayer{dispatch} which can be used for custom dispatches.
Tip
Check the Lux tutorials for more examples of using @compact.
If you are passing in kwargs by splatting them, they will be passed as-is to the function body. This means that if your splatted kwargs contain a Lux layer, it won't be registered in the CompactLuxLayer.
Special Syntax
@return: This macro doesn't really exist, but is used to return a value from the @compact block. Without the presence of this macro, we need to rely on closures which can lead to performance penalties in the reverse pass.
Having statements after the last @return macro might lead to incorrect code.
Don't do things like @return return x. This will generate nonsensical code like <new var> = return x. Essentially, @return <expr> supports any expression that can be assigned to a variable.
Since this macro doesn't "exist", it cannot be imported as using Lux: @return. Simply use it in code, and @compact will understand it.
@init_fn: Provide a function that will be used to initialize the layer's parameters or state. See the docs of @init_fn for more details.
@non_trainable: Mark a value as non-trainable. This bypasses the regular checks and places the value into the state of the layer. See the docs of @non_trainable for more details.
Extended Help
Examples
Here is a linear model:
julia
julia> using Lux, Random

julia> r = @compact(w=ones(3)) do x
           @return w .* x
       end
@compact(
    w = 3-element Vector{Float64},
) do x
    return w .* x
end                  # Total: 3 parameters,
                     #        plus 0 states.

julia> ps, st = Lux.setup(Xoshiro(0), r);

julia> r([1, 2, 3], ps, st) # computes w .* x elementwise
([1.0, 2.0, 3.0], NamedTuple())
Here is a linear model with bias and activation:
julia
julia> d_in = 5
5

julia> d_out = 3
3

julia> d = @compact(W=ones(d_out, d_in), b=zeros(d_out), act=relu) do x
           y = W * x
           @return act.(y .+ b)
       end
@compact(
    W = 3×5 Matrix{Float64},
    b = 3-element Vector{Float64},
    act = relu,
) do x
    y = W * x
    return act.(y .+ b)
end                  # Total: 18 parameters,
                     #        plus 1 states.

julia> ps, st = Lux.setup(Xoshiro(0), d);

julia> d(ones(5, 2), ps, st)[1] # 3×2 Matrix as output.
3×2 Matrix{Float64}:
 5.0  5.0
 5.0  5.0
 5.0  5.0

julia> ps_dense = (; weight=ps.W, bias=ps.b);

julia> first(d([1, 2, 3, 4, 5], ps, st)) ≈
       first(Dense(d_in => d_out, relu)([1, 2, 3, 4, 5], ps_dense, NamedTuple())) # Equivalent to a dense layer
true
Finally, here is a simple MLP. We can train this model just like any Lux model:
julia
julia> n_in = 1;

julia> n_out = 1;

julia> nlayers = 3;

julia> model = @compact(w1=Dense(n_in, 128),
           w2=[Dense(128, 128) for i in 1:nlayers], w3=Dense(128, n_out), act=relu) do x
           embed = act.(w1(x))
           for w in w2
               embed = act.(w(embed))
           end
           out = w3(embed)
           @return out
       end
@compact(
    w1 = Dense(1 => 128),          # 256 parameters
    w2 = NamedTuple(
        1 = Dense(128 => 128),     # 16_512 parameters
        2 = Dense(128 => 128),     # 16_512 parameters
        3 = Dense(128 => 128),     # 16_512 parameters
    ),
    w3 = Dense(128 => 1),          # 129 parameters
    act = relu,
) do x
    embed = act.(w1(x))
    for w = w2
        embed = act.(w(embed))
    end
    out = w3(embed)
    return out
end                  # Total: 49_921 parameters,
                     #        plus 1 states.

julia> ps, st = Lux.setup(Xoshiro(0), model);

julia> size(first(model(randn(n_in, 32), ps, st))) # 1×32 Matrix as output.
(1, 32)

julia> using Optimisers, Zygote

julia> x_data = collect(-2.0f0:0.1f0:2.0f0)';

julia> y_data = 2 .* x_data .- x_data .^ 3;

julia> optim = Optimisers.setup(Adam(), ps);

julia> loss_initial = sum(abs2, first(model(x_data, ps, st)) .- y_data);

julia> for epoch in 1:1000
           loss, gs = Zygote.withgradient(
               ps -> sum(abs2, first(model(x_data, ps, st)) .- y_data), ps)
           Optimisers.update!(optim, ps, gs[1])
       end;

julia> loss_final = sum(abs2, first(model(x_data, ps, st)) .- y_data);

julia> loss_initial > loss_final
true
You may also specify a name for the model, which will be used instead of the default printout, which gives a verbatim representation of the code used to construct the model:
julia
julia> model = @compact(w=rand(3), name="Linear(3 => 1)") do x
           @return sum(w .* x)
       end
Linear(3 => 1)      # 3 parameters
This can be useful when using @compact to hierarchically construct complex models to be used inside a Chain.
Type Stability
If your input function f is type-stable but the generated model is not type stable, it should be treated as a bug. We will appreciate issues if you find such cases.
Parameter Count
Array parameters don't print the number of parameters on the side. However, they are accounted for in the total number of parameters printed at the bottom.
Create an initializer function for a parameter or state to be used in a Compact Lux Layer created using @compact.
Arguments
fn: The function to be used for initializing the parameter or state. This only takes a single argument rng.
kind: If set to :parameter, the initializer function will be used to initialize the parameters of the layer. If set to :state, the initializer function will be used to initialize the states of the layer.
Examples
julia
julia> using Lux, Random

julia> r = @compact(w=@init_fn(rng->randn32(rng, 3, 2)),
           b=@init_fn(rng->randn32(rng, 3), :state)) do x
           @return w * x .+ b
       end;

julia> ps, st = Lux.setup(Xoshiro(0), r);

julia> size(ps.w)
(3, 2)

julia> size(st.b)
(3,)

julia> size(r([1, 2], ps, st)[1])
(3,)
Miscellaneous
Set the dispatch doctor preference for LuxCore and LuxLib packages.
mode can be "disable", "warn", or "error". For details on the different modes, see the DispatchDoctor.jl documentation.
If the preferences are already set, then no action is taken. Otherwise the preference is set. For changes to take effect, the Julia session must be restarted.
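A short sketch (assuming the single-string form, which sets the same mode for both packages):
julia
using Lux

# Make LuxCore and LuxLib error on dispatch-doctor violations.
# Restart the Julia session for the new preference to take effect.
Lux.set_dispatch_doctor_preferences!("error")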
LuxTestUtils
This is a testing package. Hence, we don't use features like weak dependencies to reduce load times. It is recommended that you exclusively use this package for testing and not add a dependency to it in your main package Project.toml.
Implements utilities for testing gradient correctness and dynamic dispatch of Lux.jl models.
Test the gradients of f with respect to args using the specified backends.
| Backend | ADType | CPU | GPU | Notes |
|:--- |:--- |:---:|:---:|:--- |
| Zygote.jl | AutoZygote() | ✔ | ✔ | |
| Tracker.jl | AutoTracker() | ✔ | ✔ | |
| ReverseDiff.jl | AutoReverseDiff() | ✔ | ✖ | |
| ForwardDiff.jl | AutoForwardDiff() | ✔ | ✖ | len ≤ 100 |
| FiniteDiff.jl | AutoFiniteDiff() | ✔ | ✖ | len ≤ 100 |
| Enzyme.jl | AutoEnzyme() | ✔ | ✖ | Only Reverse Mode |
Arguments
f: The function to test the gradients of.
args: The arguments to test the gradients of. Only AbstractArrays are considered for gradient computation. Gradients wrt all other arguments are assumed to be NoTangent().
Keyword Arguments
skip_backends: A list of backends to skip.
broken_backends: A list of backends to treat as broken.
soft_fail: If true, then the test will be recorded as a soft_fail test. This overrides any broken kwargs. Alternatively, a list of backends can be passed to soft_fail to allow soft_fail tests for only those backends.
enzyme_set_runtime_activity: If true, then activate runtime activity for Enzyme.
kwargs: Additional keyword arguments to pass to check_approx.
Example
julia
julia> f(x, y, z) = x .+ sum(abs2, y.t) + sum(y.x.z)

julia> x = (; t=rand(10), x=(z=[2.0],))

julia> test_gradients(f, 1.0, x, nothing)
Extensions to @test
julia
@test_softfail expr
Evaluate expr and record a test result. If expr throws an exception, the test result will be recorded as an error. If expr returns a value, and it is not a boolean, the test result will be recorded as an error.
If the test result is false then the test will be recorded as a broken test, else it will be recorded as a pass.
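For example, a minimal sketch inside a test set:
julia
using Test, LuxTestUtils

@testset "softfail demo" begin
    @test_softfail 1 + 1 == 2  # true: recorded as a pass
    @test_softfail 1 + 1 == 3  # false: recorded as broken, not a failure
end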
Ecosystem
Frameworks Extending Lux.jl
DiffEqFlux.jl: Universal neural differential equations with O(1) backprop, GPUs, and stiff+non-stiff DE solvers, demonstrating scientific machine learning (SciML) and physics-informed machine learning methods
SciMLSensitivity.jl: A component of the DiffEq ecosystem for enabling sensitivity analysis for scientific machine learning (SciML). Optimize-then-discretize, discretize-then-optimize, adjoint methods, and more for ODEs, SDEs, DDEs, DAEs, etc.
NeuralPDE.jl: Physics-Informed Neural Networks (PINN) and Deep BSDE Solvers of Differential Equations for Scientific Machine Learning (SciML) accelerated simulation
NeuralLyapunov.jl: A library for searching for neural Lyapunov functions in Julia
DeepEquilibriumNetworks.jl: Implicit Layer Machine Learning via Deep Equilibrium Networks, O(1) backpropagation with accelerated convergence
AbstractCosmologicalEmulators.jl: Repository containing the abstract interface to the emulators used in the CosmologicalEmulators organization
ContinuousNormalizingFlows.jl: Implementations of Infinitesimal Continuous Normalizing Flows Algorithms in Julia
Sophon.jl: Efficient, Accurate, and Streamlined Training of Physics-Informed Neural Networks
DataDrivenDiffEq.jl: Data driven modeling and automated discovery of dynamical systems for the SciML Scientific Machine Learning organization
NeuralGraphPDE.jl: Integrating Neural Ordinary Differential Equations, the Method of Lines, and Graph Neural Networks
Solaris.jl: Lightweight module for fusing physical and neural models
Boltz.jl: Accelerate your ML research using pre-built Deep Learning Models with Lux
GeometricMachineLearning.jl: Structure Preserving Machine Learning Models in Julia
Want to add your package? Open a PR in LuxDL/Lux.jl.
Automatic Differentiation
Zygote.jl: Lux.jl's default choice for AD
Tracker.jl: Well tested and robust AD library (might fail on edge cases)
ForwardDiff.jl: For forward mode AD support
ReverseDiff.jl: Tape based reverse mode AD (might fail on edge cases and doesn't work on GPU)
Enzyme.jl: Experimental support, but will become the future default
Data Manipulation, Data Loading & Datasets
MLUtils.jl: Utilities and abstractions for Machine Learning tasks
MLDatasets.jl: Utility package for accessing common Machine Learning datasets in Julia
Images.jl: An image library for Julia
DataAugmentation.jl: Flexible data augmentation library for machine and deep learning
Neural Network Primitives
NNlib.jl: Neural Network primitives with multiple backends
LuxLib.jl: Backend for Lux.jl
Optimization
Optimization.jl: Unified API for Optimization in Julia
Optimisers.jl: Defines many standard optimisers and utilities for learning loops
ParameterSchedulers.jl: Common hyperparameter scheduling for ML
Parameter Manipulation
Functors.jl: Parameterise all the things
ComponentArrays.jl: Arrays with arbitrarily nested named components
Serialization
Serialization.jl: Provides serialization of Julia objects
JLD2.jl: HDF5-compatible file format in pure Julia
Testing Utilities
FiniteDiff.jl: Fast non-allocating calculations of gradients, Jacobians, and Hessians with sparsity support
FiniteDifferences.jl: High accuracy derivatives, estimated via numerical finite differences (formerly FDM.jl)
JET.jl: JET employs Julia's type inference system to detect potential bugs and type instabilities
LuxTestUtils.jl: Collection of functions useful for testing various packages in the Lux Ecosystem
Training Visualization & Logging
MLFlowClient.jl: Julia client for MLFlow
TensorBoardLogger.jl: Easy peasy logging to TensorBoard with Julia
Wandb.jl: Unofficial Julia bindings for logging experiments to wandb.ai
LuxDL Docs: Elegant & Performant Scientific Machine Learning in JuliaLang. A Pure Julia Deep Learning Framework designed for Scientific Machine Learning.
🚀 Fast & Extendible: Lux.jl is written in Julia itself, making it extremely extendible. CUDA and AMDGPU are supported first-class, with experimental support for Metal Hardware.
🧑‍🔬 SciML ❤️ Lux: Lux is the default choice for many SciML packages, including DiffEqFlux.jl, NeuralPDE.jl, and more.
🧩 Uniquely Composable: Lux.jl natively supports Arbitrary Parameter Types, making it uniquely composable with other Julia packages (and even Non-Julia packages).
🧪 Well Tested: Lux.jl tests every supported Automatic Differentiation Framework with every supported hardware backend against Finite Differences to prevent sneaky 🐛 in your code.
Installation
It's easy to install Lux.jl. Since Lux.jl is registered in the Julia General registry, you can simply run the following command in the Julia REPL:
julia
julia> using Pkg
+julia> Pkg.add("Lux")
If you want to use the latest unreleased version of Lux.jl, you can run the following command (in most cases the released version will be the same as the version on GitHub):
julia
julia> using Pkg
julia> Pkg.add(url="https://github.com/LuxDL/Lux.jl")
julia
using Reactant, Lux
Reactant.set_default_backend("cpu") # default

const dev = xla_device()
julia
using Reactant, Lux
Reactant.set_default_backend("gpu")

const dev = xla_device()
julia
using Reactant, Lux
Reactant.set_default_backend("tpu")

const dev = xla_device()
Citation
If you found this library to be useful in academic work, then please cite:
bibtex
@software{pal2023lux,
  author    = {Pal, Avik},
  title     = {{Lux: Explicit Parameterization of Deep Neural Networks in Julia}},
  month     = {April},
  year      = 2023,
  note      = {If you use this software, please cite it as below.},
  publisher = {Zenodo},
  version   = {v0.5.0},
  doi       = {10.5281/zenodo.7808904},
  url       = {https://doi.org/10.5281/zenodo.7808904}
}
bibtex
@thesis{pal2023efficient,
  title  = {{On Efficient Training \& Inference of Neural Differential Equations}},
  author = {Pal, Avik},
  year   = {2023},
  school = {Massachusetts Institute of Technology}
}
Getting Started
Install Julia v1.10 or above. Lux.jl is available through the Julia package manager. You can enter it by pressing ] in the REPL and then typing add Lux. Alternatively, you can also do
julia
import Pkg
+Pkg.add("Lux")
Update to v1
If you are using a pre-v1 version of Lux.jl, please see the Updating to v1 section for instructions on how to update.
Models don't hold parameters and states, so they need to be initialized separately. From there on, we can just use our standard AD and Optimisers API. However, here we will show how to use Lux's Training API, which provides a uniform API over all supported AD systems.
julia
# Get the device determined by Lux
dev = gpu_device()

# Parameter and State Variables
ps, st = Lux.setup(rng, model) |> dev

# Dummy Input
x = rand(rng, Float32, 128, 2) |> dev

# Run the model
y, st = Lux.apply(model, x, ps, st)

# Gradients
## First construct a TrainState
train_state = Lux.Training.TrainState(model, ps, st, Adam(0.0001f0))

## We can compute the gradients using Training.compute_gradients
gs, loss, stats, train_state = Lux.Training.compute_gradients(AutoZygote(), MSELoss(),
    (x, dev(rand(rng, Float32, 10, 2))), train_state)

## Optimization
train_state = Training.apply_gradients!(train_state, gs) # or Training.apply_gradients (no `!` at the end)

# Both these steps can be combined into a single call
gs, loss, stats, train_state = Training.single_train_step!(AutoZygote(), MSELoss(),
    (x, dev(rand(rng, Float32, 10, 2))), train_state)
julia
using Lux, Random, Optimisers, Zygote
using LuxCUDA # For CUDA support
# using AMDGPU, Metal, oneAPI # Other optional packages for GPU support
using Printf # For pretty printing

dev = gpu_device()
(::CUDADevice{Nothing}) (generic function with 4 methods)
We will define a custom MLP using the @compact macro. The macro takes in a list of parameters, layers and states, and a function defining the forward pass of the neural network.
julia
n_in = 1
n_out = 1
nlayers = 3

model = @compact(w1=Dense(n_in => 32),
    w2=[Dense(32 => 32) for i in 1:nlayers],
    w3=Dense(32 => n_out),
    act=relu) do x
    embed = act(w1(x))
    for w in w2
        embed = act(w(embed))
    end
    out = w3(embed)
    @return out
end
LuxDL hosts various packages that provide additional functionality for Lux.jl. All packages mentioned in this documentation are available via the Julia General Registry.
You can install all those packages via import Pkg; Pkg.add("<package name>").
Why we wrote Lux?
Julia already has quite a few well-established neural network frameworks – Flux & KNet. However, certain design elements – coupled model and parameters & internal mutations – associated with these frameworks make them less compiler and user friendly. Making changes to address these problems in the respective frameworks would be too disruptive for users. This is where Lux comes in: a neural network framework built completely using pure functions to make it both compiler and autodiff friendly.
Layers must be immutable – cannot store any parameter/state but rather store the information to construct them
Layers are pure functions
Layers return a Tuple containing the result and the updated state
Given the same inputs, the outputs must be the same – yes, this must hold true even for stochastic functions. Randomness must be controlled using rngs passed in the state.
Easily extensible
Extensive Testing – All layers and features are tested across all supported AD backends across all supported hardware backends.
Neural Networks for SciML: For SciML applications (Neural ODEs, Deep Equilibrium Models) solvers typically expect a monolithic parameter vector. Flux enables this via its destructure mechanism, but destructure comes with various edge cases and limitations. Lux forces users to make an explicit distinction between state variables and parameter variables to avoid these issues. Also, it comes batteries-included for distributed training.
Sensible display of Custom Layers – Ever wanted to see PyTorch-like network printouts, or wondered how to extend the pretty printing of Flux's layers? Lux handles all of that by default.
Truly immutable models - No unexpected internal mutations since all layers are implemented as pure functions. All layers are also deterministic given the parameters and state: if a layer is supposed to be stochastic (say Dropout), the state must contain a seed which is then updated after the function call.
Easy Parameter Manipulation – By separating parameter data and layer structures, Lux makes implementing WeightNorm, SpectralNorm, etc. downright trivial. Without this separation, it is much harder to pass such parameters around without mutations which AD systems don't like.
Wider AD Support – Lux has extensive support for most AD systems in Julia, while Flux is mostly tied to Zygote (with some initial support for Enzyme).
Small Neural Networks on CPU – Lux is developed for training large neural networks. For smaller architectures, we recommend using SimpleChains.jl or even better use it in conjunction with Lux via ToSimpleChainsAdaptor.
Reliability – We have learned from the mistakes of the past with Flux and everything in our core framework is extensively tested, along with downstream CI to ensure that everything works as expected.
Revising Previous Recommendation about Large Models
Previously we recommended not using Lux for very large models. But we have been making a lot of headway with Reactant.jl, and it would be worthwhile to test larger models with Lux. See compiling Lux models for more information.
Resources to Get Started
Go through the examples sorted based on their complexity in the documentation.
Have More Questions?
For usage-related questions, please use GitHub Discussions, which allows questions and answers to be indexed. To report bugs, use GitHub Issues or, even better, send in a Pull Request.
Updating to Lux v1
Lux v1 is a major release, mostly to signify the stability of the API. On this page, we list a concrete set of changes that need to be made to your code to update to Lux v1. We also list some new exciting features that were added as part of this release.
AbstractExplicitLayer has been renamed to AbstractLuxLayer.
AbstractExplicitContainerLayer behaviour
This has been renamed to AbstractLuxContainerLayer.
Previously, AbstractExplicitContainerLayer{(:a,)} (i.e. singleton containers) would produce default initial parameters and states without wrapping them in a NamedTuple{(:a,)}. This was inconsistent with non-singleton containers and was a source of confusion. With v1 we return (; a = <parameters>) and (; a = <states>) by default. See AbstractLuxWrapperLayer for a replacement of this functionality.
inputsize has been removed since it was ambiguous and not used anywhere.
Changes to outputsize:
Single argument version has been removed. See LuxCore.jl Pull Request 43 for more details on the rationale behind this change.
Fallback implementation has been moved to Lux.jl (i.e. users using Lux shouldn't see a difference, but if Lux.jl isn't loaded, this function will throw an error).
Internally this uses a NilArray that is able to compute sizes without actually running the computation.
Functors and Setfield have been made into optional dependencies. Certain LuxCore functionality that relies on these packages will throw an error if they are not loaded.
Introduction of AbstractLuxWrapperLayer. This behaves exactly like the old singleton container. For example, the old AbstractExplicitContainerLayer{(:a,)} is equivalent to AbstractLuxWrapperLayer{:a}.
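A minimal migration sketch (MyWrapper is a hypothetical layer, not part of Lux):

julia
using LuxCore

# Old (pre-v1): struct MyWrapper{L} <: AbstractExplicitContainerLayer{(:layer,)}
struct MyWrapper{L} <: LuxCore.AbstractLuxWrapperLayer{:layer}
    layer::L
end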
This was a major release to signify the stability of the API. There were no breaking changes. We do support a wider range of RNG types, see Supported RNG Types for more details.
This is the most aggressive change that was made. We renamed the LuxDeviceUtils.jl package to MLDataDevices.jl, to allow for non-Lux packages to use this shared device management abstraction.
Deprecation of LuxDeviceUtils.jl
This also marks the deprecation of the LuxDeviceUtils.jl package. We won't be making any updates to that package, including fixing any bugs. All users should switch to MLDataDevices.jl instead.
DeviceIterator provides a generalization of CUDA.CuIterator and works for all backends and more data types (using Functors.jl). MLUtils.DataLoader |> gdev now returns a DeviceIterator instead of being a no-op.
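A sketch of the new behavior (assuming MLUtils.jl is installed and a functional GPU backend is loaded; the array sizes are illustrative):

julia
using Lux, MLUtils

gdev = gpu_device()
loader = DataLoader((randn(Float32, 2, 64), randn(Float32, 1, 64)); batchsize=16)
for (x, y) in gdev(loader)   # a DeviceIterator: batches arrive already on the device
    # ... training step ...
end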
Direct reexport of NNlib has been removed. We reexport selected functionality from NNlib. Directly load NNlib if you need to use the other functions.
Flattening of Chain layers has been removed, and the corresponding disable_optimizations kwarg has been removed.
Some layers overloaded Base.keys, these have been removed. These were mostly un-documented and weren't supposed to be used outside of the Lux.jl package.
disable_stacktrace_truncation! has been removed. From Julia 1.9 onwards, stacktrace truncation is enabled by default.
Certain Experimental features were present outside the Lux.Experimental module. These have been removed; use them via Lux.Experimental instead. Run Julia with depwarn set to error on Lux v0.5 to see the deprecations.
Lux.Experimental.@layer_map is no longer needed and has been removed. The name of the variable prevents writing generic functions and is no longer prepended to the KeyPath. See the docstring of Lux.Experimental.layer_map for more details.
allow_fast_activation kwarg has been removed completely. Pass an anonymous function as the activation to prevent internal modifications to the activation function, as sketched below.
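A sketch of the replacement pattern:

julia
# Previously: Dense(2 => 3, tanh; allow_fast_activation=false)
# Now wrap the activation in an anonymous function so it is left untouched:
layer = Dense(2 => 3, x -> tanh(x))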
Conv and ConvTranspose use an initialization based on the activation function, taken from PyTorch. PyTorch assumes the activation function is leakyrelu to compute the gain; however, we compute the gain based on the activation function passed in to the layer.
Upsample now has an align_corners keyword argument, which defaults to false. Previously this was always true.
Dense and Bilinear have updated default initializations to align with the defaults from PyTorch. See the documentation for more details.
InstanceNorm now defaults to affine=false instead of affine=true.
Embedding now defaults to init_weight=rand32 instead of init_weight=randn32.
Recurrent Cells - RNNCell, LSTMCell, and GRUCell now have different default initializations. See the documentation for more details.
Automatic Differentiation
Lux is not an AD package, but it composes well with most of the AD packages available in the Julia ecosystem. This document lists the current level of support for various AD packages in Lux. Additionally, we provide some convenience functions for working with AD.
Use Zygote.jl for the best performance. This is the most reliable and fastest option for CPU for the time being. (We are working on faster Enzyme support for CPU.)
Use Enzyme.jl if there are mutations in the code and/or Zygote.jl fails.
Use Zygote.jl for the best performance. This is the most reliable and fastest option for GPU for the time being. We are working on supporting Enzyme.jl for GPU as well.
Tier I: These packages are fully supported and have been tested extensively. Often have special rules to enhance performance. Issues for these backends take the highest priority.
Tier II: These packages are supported and extensively tested but often don't have the best performance. Issues against these backends are less critical, but we fix them when possible. (Some specific edge cases, especially with AMDGPU, are known to fail here)
Tier III: We don't know if these packages currently work with Lux. We'd love to add tests for these backends, but currently these are not our priority.
Note that ChainRules.jl is not really an AD package, but we have first-class support for packages that use rrules.
This feature is supported downstream, but we don't extensively test it to ensure that it works with Lux.
Currently Enzyme outperforms other AD packages in terms of CPU performance. However, there are some edge cases where it might not work with Lux. We are working on improving the compatibility. Please report any issues you encounter.
Compiling Lux Models using Reactant.jl
Reactant takes a Julia function and compiles it into MLIR, runs fancy optimizations on top of it (including using EnzymeMLIR for automatic differentiation), and creates relevant executables for CPU/GPU/TPU via XLA. It presently operates as a tracing system: compiled functions will assume the same control flow pattern as was originally taken by the objects used at compile time, and control flow (e.g. if, for) as well as any type instabilities will be removed. The benefit of this approach is that it immediately makes all such code available for advanced optimization with little developer effort.
Experimental
Reactant compilation is a very new feature and is currently experimental. Certain models might not be compilable yet, but we are actively working on it. Open an issue if you encounter any problems.
julia
using Lux, Reactant, Enzyme, Random, Zygote
using Functors, Optimisers, Printf
To run it using XLA we need to compile the model. We can do this using the Reactant.@compile macro. Note that the inputs need to be moved to the device using xla_device first.
Now that we have seen the low-level API, let's see how to train the model without any of this boilerplate. Simply follow these steps:
Create a device using xla_device. Remember to load Reactant.jl before doing this.
Similar to other device functions, move the model, parameters, states, and data to the device. Note that you might want to use DeviceIterator to move the data loader to the device with an iterator, as sketched below.
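A compact sketch of these steps; the model and array sizes are illustrative:

julia
using Lux, Reactant, Random

xdev = xla_device()   # load Reactant.jl before calling this

model = Dense(2 => 4, gelu)
ps, st = Lux.setup(Random.default_rng(), model) |> xdev
x = randn(Float32, 2, 32) |> xdev

# Compile the forward pass using the Reactant.@compile macro mentioned above
forward = Reactant.@compile model(x, ps, Lux.testmode(st))
y, _ = forward(x, ps, Lux.testmode(st))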
Debugging Lux Models
Debugging DNNs can be very painful. Especially with the gigantic stacktraces for Lux, it is even harder to pinpoint which particular layer errored out. This page describes some useful tools that ship with Lux that can help you debug your models.
TL;DR
Simply wrap your model with Lux.Experimental.@debug_mode!!
Don't Forget
Remember to use the non-debug-mode model after you finish debugging. Debug-mode models are way slower.
Let us construct a model which has an obviously incorrect dimension. In this example, you will see how easy it is to pin-point the problematic layer.
Incorrect Model Specification: Dimension Mismatch Problems
julia
using Lux, Random

model = Chain(Dense(1 => 16, relu), Chain(Dense(16 => 3), Dense(1 => 1)), BatchNorm(1))

model_debug = Lux.Experimental.@debug_mode model
Note that we can use the parameters and states of model itself in model_debug; no changes are needed. If you ran the original model, this is the kind of error you would see:
Have you encountered those pesky little NaNs in your training? They are very hard to track down. We will artificially introduce NaNs into our model and see how we can track down the offending layer.
We can set nan_check to :forward, :backward or :both to check for NaNs in the debug model. (or even disable it by setting it to :none)
julia
model = Chain(Dense(1 => 16, relu), Chain(Dense(16 => 1), Dense(1 => 1)),
    BatchNorm(1))

ps, st = Lux.setup(rng, model)

model_debug = Lux.Experimental.@debug_mode model nan_check=:both
And we have figured it out! The first NaN occurred in the parameters of model.layers.layer_2.layers.layer_2! But what if the NaN occurs in the reverse pass? Let us define a custom layer and introduce a fake NaN in the backward pass.
julia
using ChainRulesCore, Zygote

const CRC = ChainRulesCore

offending_layer(x) = 2 .* x
Let us define a custom backward pass to introduce some NaNs:
julia
function CRC.rrule(::typeof(offending_layer), x)
    y = offending_layer(x)
    function ∇offending_layer(Δ)
        Δ[1] = NaN
        return NoTangent(), Δ
    end
    return y, ∇offending_layer
end
And there you go: our debug layer prints that the problem is in WrappedFunction(offending_layer) at location model.layers.layer_2.layers.layer_2! Once we fix the pullback of the layer, the NaNs will be fixed.
In this manual section, we have discussed tracking down errors in Lux models. We have covered tracking incorrect model specifications and NaNs in forward and backward passes. However, remember that this is an Experimental feature, and there might be edge cases that don't work correctly. If you find any such cases, please open an issue on GitHub!
Dispatching on Custom Input Types
Defining a dispatch on (::Layer)(x::MyInputType, ps, st::NamedTuple) is inconvenient, since it requires the user to define a new method for every layer type.
Consider Neural ODEs. In these models, we often want every iteration of the neural network to take the current time as input. Here, we won't go through implementing an entire Neural ODE model. Instead, we will define a time-dependent version of Chain, sketched below.
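A sketch of the idea; ArrayAndTime is an illustrative name, not part of Lux's API:

julia
using Lux

# Carry the current time alongside the array input
struct ArrayAndTime{A <: AbstractArray, T <: Real}
    array::A
    time::T
end

# One possible dispatch for a concrete layer type
function (l::Lux.Dense)(x::ArrayAndTime, ps, st::NamedTuple)
    y, st = l(x.array, ps, st)
    return ArrayAndTime(y, x.time), st
end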
Distributed Data Parallel Training
DDP Training using Lux.DistributedUtils is a spiritual successor to FluxMPI.jl, but has some key differences.
Guide to Integrating DistributedUtils into your code
Initialize the respective backend with DistributedUtils.initialize, by passing in a backend type. It is important that you pass in the type, i.e. NCCLBackend and not the object NCCLBackend().
It is important that you use this function instead of directly constructing the backend, since there are certain internal states that need to be synchronized.
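A sketch of this step, assuming the NCCL backend with MPI.jl and NCCL.jl loaded:

julia
using Lux, MPI, NCCL

DistributedUtils.initialize(NCCLBackend)   # pass the type, not NCCLBackend()
backend = DistributedUtils.get_distributed_backend(NCCLBackend)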
Next synchronize the parameters and states of the model. This is done by calling DistributedUtils.synchronize!! with the backend and the respective input.
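For example (a sketch; ps and st come from Lux.setup as usual):

julia
ps = DistributedUtils.synchronize!!(backend, ps)
st = DistributedUtils.synchronize!!(backend, st)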
To split the data uniformly across the processes use DistributedUtils.DistributedDataContainer. Alternatively, one can manually split the data. For the provided container to work MLUtils.jl must be installed and loaded.
julia
data = DistributedUtils.DistributedDataContainer(backend, data)
Wrap the optimizer in DistributedUtils.DistributedOptimizer to ensure that the optimizer is correctly synchronized across all processes before parameter updates. After initializing the state of the optimizer, synchronize the state across all processes.
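A sketch of this step, continuing with the backend and ps from above:

julia
using Optimisers

opt = DistributedUtils.DistributedOptimizer(backend, Optimisers.Adam(0.001f0))
opt_state = Optimisers.setup(opt, ps)
# Synchronize the freshly initialized optimizer state across all processes
opt_state = DistributedUtils.synchronize!!(backend, opt_state)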
Finally, change all logging and serialization code to trigger only when local_rank(backend) == 0. This ensures that only the master process logs and serializes the model.
We don't automatically determine if the MPI Implementation is CUDA or ROCM aware. See GPU-aware MPI for more information.
Older (now non-existent) Lux.gpu implementations used to "just work" with FluxMPI.jl. We expect gpu_device to continue working as expected; however, we recommend using gpu_device after calling DistributedUtils.initialize to avoid any mismatch between the device set via DistributedUtils and the device stored in CUDADevice or AMDGPUDevice.
Currently we don't run tests with CUDA or ROCM aware MPI, use those features at your own risk. We are working on adding tests for these features.
AMDGPU support is mostly experimental and causes deadlocks in certain situations, this is being investigated. If you have a minimal reproducer for this, please open an issue.
Freezing Model Parameters
To freeze a particular kind of layer – say Dense in the following example – we can use Lux.Experimental.layer_map and freeze layers if they are of type Dense.
When the function in layer_map is called, the 4th argument is the name of the layer. For example, if you want to freeze the 1st layer inside the inner Chain, the name for this would be layer_2.layer_1. A sketch of the pattern follows.
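A minimal sketch, assuming model, ps, and st were created beforehand via Lux.setup:

julia
# Freeze every Dense layer; leave everything else untouched
function freeze_dense(d::Lux.Dense, ps, st, name)
    return Lux.Experimental.freeze(d, ps, st, (:weight, :bias))
end
freeze_dense(l, ps, st, name) = (l, ps, st)

model_frozen, ps_frozen, st_frozen = Lux.Experimental.layer_map(freeze_dense, model, ps, st)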
GPU Management
Starting from v0.5, Lux has transitioned to a new GPU management system. The old system using cpu and gpu functions is still in place but will be removed in v1. Using the old functions might lead to performance regressions if used inside performance critical code.
Lux.jl can handle multiple GPU backends. Currently, the following backends are supported:
julia
# Important to load trigger packages
using Lux, LuxCUDA #, AMDGPU, Metal, oneAPI

supported_gpu_backends()
("CUDA", "AMDGPU", "Metal", "oneAPI")
Metal Support
Support for Metal GPUs should be considered extremely experimental at this point.
Automatic Backend Management is done by two simple functions: cpu_device and gpu_device.
cpu_device: This is a simple function that just returns a CPUDevice object. For example, cdev = cpu_device() constructs the device, and x_cpu = randn(Float32, 3, 2) creates an array that already lives on the CPU.
gpu_device: This function performs automatic GPU device selection and returns an object.
If no GPU is available, it returns a CPUDevice object.
If a LocalPreferences file is present, then the backend specified in the file is used. To set a backend, use Lux.gpu_backend!(<backend_name>). (a) If the trigger package corresponding to the device is not loaded, then a warning is displayed. (b) If no LocalPreferences file is present, then the first working GPU with loaded trigger package is used.
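For example, a sketch of pinning the backend via the preference mechanism described above (the Julia session must be restarted for the preference to take effect):

julia
using Lux

Lux.gpu_backend!("CUDA")   # writes the gpu_backend preference to LocalPreferences.toml
# Restart the Julia session, then:
dev = gpu_device()         # now selects the CUDA backend if it is functional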
Lux Interface
If you just want to define compatibility with Lux without actually using any of the other functionality provided by Lux (like layers), it is recommended to depend on LuxCore.jl instead of Lux.jl. LuxCore.jl is a significantly lighter dependency.
Following this interface provides the ability for frameworks built on top of Lux to be cross compatible. Additionally, any new functionality built into Lux, will just work for your framework.
@compact macro
While writing out a custom struct and defining dispatches manually is a good way to understand the interface, it is not the most concise way. We recommend using the Lux.@compact macro to define layers which makes handling the states and parameters downright trivial.
If the layer doesn't contain any other Lux layer, then it is a Singular Layer. This means it should optionally subtype Lux.AbstractLuxLayer but must define all the necessary functions mentioned in the docstrings. Consider a simplified version of Dense called Linear.
First, set up the architectural details for this layer. Note that the architecture doesn't contain any mutable structure like arrays. When in doubt, remember: once constructed, a model architecture cannot change.
Tip
For people coming from Flux.jl background, this might be weird. We recommend checking out the Flux to Lux migration guide first before proceeding.
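A minimal sketch of such an architecture, storing only the dimensions:

julia
using LuxCore, Random

struct Linear <: LuxCore.AbstractLuxLayer
    in_dims::Int
    out_dims::Int
end

l = Linear(2, 4)   # a concrete instance used in the snippets below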
Next, we need to implement functions which return the parameters and states for the layer. In the case of Linear, the parameters are weight and bias while the states are empty. States become important when defining layers like BatchNorm, WeightNorm, etc. The recommended data structure for returning parameters is a NamedTuple, though anything satisfying the Parameter Interface is valid.
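A sketch for the Linear defined above; the standard-normal initialization here is an arbitrary choice:

julia
function LuxCore.initialparameters(rng::Random.AbstractRNG, l::Linear)
    # weight is out×in and bias is out×1, matching ps.weight * x .+ ps.bias below
    return (weight=randn(rng, Float32, l.out_dims, l.in_dims),
        bias=randn(rng, Float32, l.out_dims, 1))
end

LuxCore.initialstates(::Random.AbstractRNG, ::Linear) = NamedTuple()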
You could also implement LuxCore.parameterlength and LuxCore.statelength to prevent wasteful reconstruction of the parameters and states.
julia
# This works
println("Parameter Length: ", LuxCore.parameterlength(l), "; State Length: ",
    LuxCore.statelength(l))

# But still recommended to define these
LuxCore.parameterlength(l::Linear) = l.out_dims * l.in_dims + l.out_dims

LuxCore.statelength(::Linear) = 0
Parameter Length: 12; State Length: 0
No RNG in initialparameters and initialstates
You might notice that we don't pass in an RNG for these functions. If your parameter length and/or state length depend on a random number generator, you should think really hard about what you are trying to do and why.
Now, we need to define how the layer works. For this you make your layer a function with exactly 3 arguments – x the input, ps the parameters, and st the states. This function must return two things – y the output, and st_new the updated state.
julia
function (l::Linear)(x::AbstractMatrix, ps, st::NamedTuple)
    y = ps.weight * x .+ ps.bias
    return y, st
end
Finally, let's run this layer. If you have made it this far into the documentation, we don't feel you need a refresher on that, but a quick sketch follows anyway.
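julia
rng = Random.default_rng()
ps, st = LuxCore.setup(rng, l)   # l = Linear(2, 4) from above

x = randn(rng, Float32, 2, 3)    # 2 features, batch of 3
y, st_new = l(x, ps, st)
size(y)                          # (4, 3)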
If your layer comprises other Lux layers, then it is a Container Layer. Note that you could treat it as a Singular Layer, and that is still fine. FWIW, if you cannot subtype your layer with LuxCore.AbstractLuxContainerLayer, then you should go down the Singular Layer route. But subtyping allows us to bypass some of these common definitions. Let us now define a layer which is basically a composition of two linear layers.
Wrapper Layer
If you are defining a layer that is a wrapper around another layer, then you should subtype LuxCore.AbstractLuxWrapperLayer instead of LuxCore.AbstractLuxContainerLayer. The only difference from a container layer is that it can wrap a single layer and the parameter/state structure is exactly the same as the wrapped layer.
julia
struct ComposedLinear{L1, L2} <: LuxCore.AbstractLuxContainerLayer{(:linear_1, :linear_2)}
    linear_1::L1
    linear_2::L2
end

function (cl::ComposedLinear)(x::AbstractMatrix, ps, st::NamedTuple)
    # To access the parameters and states for `linear_1` we do `ps.linear_1` and
    # `st.linear_1`. Similarly for `linear_2`
    y, st_l1 = cl.linear_1(x, ps.linear_1, st.linear_1)
    y, st_l2 = cl.linear_2(y, ps.linear_2, st.linear_2)
    # Finally, we need to return the new state, which has the exact structure of `st`
    return y, (linear_1 = st_l1, linear_2 = st_l2)
end
Here, you will notice we have passed (:linear_1, :linear_2) to the supertype. It essentially informs the type that <obj>.linear_1 and <obj>.linear_2 are Lux layers, and we need to construct parameters and states for those. Let's construct these and see:
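The construction code is not included in this extract; a minimal sketch, reusing the Linear layer from before with assumed dimensions:
julia
model = ComposedLinear(Linear(2, 4), Linear(4, 2))
ps, st = LuxCore.setup(rng, model)
println("Parameters: ", ps)
println("States: ", st)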
We accept any parameter type as long as we can fetch the parameters using getproperty(obj, :parameter_name). This allows us to simultaneously support NamedTuples and ComponentArrays. Let us go through a concrete example of what this means. Consider Dense, which expects two parameters named weight and bias.
Automatic Differentiation
If you are defining your own parameter type, it is your responsibility to make sure that it works with the AutoDiff System you are using.
julia
using Lux, Random

d = Dense(2, 3)
rng = Random.default_rng()
Random.seed!(rng, 0)

ps_default, st = LuxCore.setup(rng, d)

x = randn(rng, Float32, 2, 1)

println("Result with `NamedTuple` parameters: ", first(d(x, ps_default, st)))
Result with `NamedTuple` parameters: Float32[-0.08713347; -0.4851346; -0.8490221;;]
Let us define a custom parameter type with fields myweight and mybias, such that accessing weight returns myweight, and similarly for bias.
Beware!
This is for demonstrative purposes, don't try this at home!
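The manual's definition of this parameter type is not included in this extract. A sketch implementing exactly what the prose describes follows; note that since it merely aliases the default parameters, its output would match the NamedTuple result above, whereas the manual's own (unshown) definition evidently transforms the values to produce the different result printed below.
julia
struct DenseLayerParameters{W, B}
    myweight::W
    mybias::B
end

function Base.getproperty(ps::DenseLayerParameters, name::Symbol)
    # Forward `weight`/`bias` accesses to the actual fields
    name == :weight && return getfield(ps, :myweight)
    name == :bias && return getfield(ps, :mybias)
    return getfield(ps, name)
end

ps = DenseLayerParameters(ps_default.weight, ps_default.bias)
println("Result with `DenseLayerParameters` parameters: ", first(d(x, ps, st)))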
Result with `DenseLayerParameters` parameters: Float32[0.23710957; 0.1003911; -0.57671577;;]
The takeaway from this shouldn't be "let's define weird parameter types." Simply because you can do weird things like this doesn't mean you should, since it only leads to bugs. Instead, this shows the flexibility you have in how your parameters can be structured.
States are always type-constrained to be a NamedTuple. The structure of the input state must match that of the output state, i.e. keys(st_in) == keys(st_out). This doesn't imply that the types of the input and output states match. To generate efficient code, we often dispatch on the state; for example, Dropout, BatchNorm, etc.
Migrating from Flux to Lux
For the core library layers like Dense, Conv, etc. we have intentionally kept the API very similar to Flux. In most cases, replacing using Flux with using Lux should be enough to get you started. We cover the additional changes that you will have to make in the following example.
Flux and Lux operate under extremely different design philosophies regarding how layers should be implemented. A summary of the differences would be:
Flux stores everything in a single struct and relies on Functors.@functor and Flux.trainable to distinguish between trainable and non-trainable parameters.
Lux relies on the user to define Lux.initialparameters and Lux.initialstates to distinguish between trainable parameters (called "parameters") and non-trainable parameters (called "states"). Additionally, Lux layers define the model architecture; hence, device transfer utilities like gpu_device, cpu_device, etc. cannot be applied to Lux layers. Instead, they need to be applied to the parameters and states.
julia
using Lux, Random, NNlib, Zygote

struct LuxLinear <: Lux.AbstractLuxLayer
    init_A
    init_B
end

function LuxLinear(A::AbstractArray, B::AbstractArray)
    # Storing Arrays or any mutable structure inside a Lux layer is not recommended;
    # instead we convert them to functions to perform lazy initialization
    return LuxLinear(() -> copy(A), () -> copy(B))
end

# `B` is a parameter
Lux.initialparameters(::AbstractRNG, layer::LuxLinear) = (B=layer.init_B(),)

# `A` is a state
Lux.initialstates(::AbstractRNG, layer::LuxLinear) = (A=layer.init_A(),)

(l::LuxLinear)(x, ps, st) = st.A * ps.B * x, st
julia
using Flux, Random, NNlib, Zygote, Optimisers

struct FluxLinear
    A
    B
end

# `A` is not trainable
Optimisers.trainable(f::FluxLinear) = (B=f.B,)

# Needed so that both `A` and `B` can be transferred between devices
Flux.@functor FluxLinear

(l::FluxLinear)(x) = l.A * l.B * x
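A hedged usage sketch for the LuxLinear layer defined above; the array shapes are arbitrary assumptions:
julia
rng = Random.default_rng()
model = LuxLinear(randn(rng, 2, 4), randn(rng, 4, 2))
ps, st = Lux.setup(rng, model)  # ps = (B=...,), st = (A=...,)
x = randn(rng, 2, 1)
y, st_new = model(x, ps, st)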
Flux supports a mode called :auto which automatically decides if the user is training the model or running inference. This is the default mode for Flux.BatchNorm, Flux.GroupNorm, Flux.Dropout, etc. Lux doesn't support this mode (specifically to keep code simple and do exactly what the user wants), hence our default mode is training. This can be changed using Lux.testmode.
If you have Flux loaded in your code, you can use the function FromFluxAdaptor to automatically convert your model to Lux. Note that in case a native Lux counterpart isn't available, we fall back to using Optimisers.destructure.
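A minimal conversion sketch, assuming Flux is loaded alongside Lux (the model shown is an arbitrary example):
julia
using Flux, Lux, Random

flux_model = Flux.Chain(Flux.Dense(2 => 3, tanh), Flux.Dense(3 => 1))
lux_model = FromFluxAdaptor()(flux_model)
ps, st = Lux.setup(Random.default_rng(), lux_model)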
Nested Automatic Differentiation
This is a relatively new feature in Lux, so there might be some rough edges. If you encounter any issues, please let us know by opening an issue on the GitHub repository.
In this manual, we will explore how to use automatic differentiation (AD) inside your layers or loss functions and have Lux automatically switch the AD backend with a faster one when needed.
Tip
Don't want Lux to do this switching for you? You can disable it by setting the automatic_nested_ad_switching Preference to false.
Remember that if you are using ForwardDiff inside a Zygote call, it will drop gradients (with a warning message), so it is not recommended to use this combination.
Let's explore this using some questions that were posted on the Julia Discourse forum.
This problem comes from @facusapienza on Discourse. In this case, we want to add a regularization term to the neural DE based on first-order derivatives. The neural DE part is not important here and we can demonstrate this easily with a standard neural network.
julia
using Lux, ForwardDiff, Zygote, StableRNGs, LinearAlgebra  # imports assumed from earlier in the manual

function loss_function1(model, x, ps, st, y)
    # Make it a stateful layer
    smodel = StatefulLuxLayer{true}(model, ps, st)
    ŷ = smodel(x)
    loss_emp = sum(abs2, ŷ .- y)
    # You can use `Zygote.jacobian` as well, but ForwardDiff tends to be more efficient here
    J = ForwardDiff.jacobian(smodel, x)
    loss_reg = abs2(norm(J .* 0.01f0))
    return loss_emp + loss_reg
end

# Using BatchNorm to show that it is possible
model = Chain(Dense(2 => 4, tanh), BatchNorm(4), Dense(4 => 2))
ps, st = Lux.setup(StableRNG(0), model)
x = randn(StableRNG(0), Float32, 2, 10)
y = randn(StableRNG(11), Float32, 2, 10)

loss_function1(model, x, ps, st, y)
14.883664f0
So our loss function works, let's take the gradient (forward diff doesn't nest nicely here):
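The gradient call and the finite-difference comparison that produce the output below are not included in this extract; a hedged sketch of the Zygote call:
julia
_, ∂x, ∂ps, _, _ = Zygote.gradient(loss_function1, model, x, ps, st, y)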
┌ Warning: `training` is set to `Val{true}()` but is not being used within an autodiff call (gradient, jacobian, etc...). This will be slow. If you are using a `Lux.jl` model, set it to inference (test) mode using `LuxCore.testmode`. Reliance on this behavior is discouraged, and is not guaranteed by Semantic Versioning, and might be removed without a deprecation cycle. It is recommended to fix this issue in your code.
└ @ LuxLib.Utils /var/lib/buildkite-agent/builds/gpuci-5/julialang/lux-dot-jl/lib/LuxLib/src/utils.jl:309
┌ Warning: `training` is set to `Val{true}()` but is not being used within an autodiff call (gradient, jacobian, etc...). This will be slow. If you are using a `Lux.jl` model, set it to inference (test) mode using `LuxCore.testmode`. Reliance on this behavior is discouraged, and is not guaranteed by Semantic Versioning, and might be removed without a deprecation cycle. It is recommended to fix this issue in your code.
└ @ LuxLib.Utils /var/lib/buildkite-agent/builds/gpuci-5/julialang/lux-dot-jl/lib/LuxLib/src/utils.jl:309
∞-norm(∂x - ∂x_fd): 0.00046014786
∞-norm(∂ps - ∂ps_fd): 0.00068473816
That's pretty good; of course, you will have some error from the finite differences calculation.
Notice that in this example the Jacobian J consists of the full matrix of derivatives of smodel with respect to the different inputs in x. In many cases, we are interested in computing the Jacobian with respect to each input individually, avoiding the unnecessary calculation of zero entries of the Jacobian. This can be achieved with batched_jacobian, which performs the Jacobian calculation for each input separately. Using the same example from the previous section:
julia
model = Chain(Dense(2 => 4, tanh), Dense(4 => 2))
ps, st = Lux.setup(StableRNG(0), model)
x = randn(StableRNG(0), Float32, 2, 10)
y = randn(StableRNG(11), Float32, 2, 10)

function loss_function_batched(model, x, ps, st, y)
    # Make it a stateful layer
    smodel = StatefulLuxLayer{true}(model, ps, st)
    ŷ = smodel(x)
    loss_emp = sum(abs2, ŷ .- y)
    # You can use `AutoZygote()` as well, but `AutoForwardDiff()` tends to be more efficient here
    J = batched_jacobian(smodel, AutoForwardDiff(), x)
    loss_reg = abs2(norm(J .* 0.01f0))
    return loss_emp + loss_reg
end

loss_function_batched(model, x, ps, st, y)
11.380777f0
Notice that in this last example we removed BatchNorm() from the neural network. This is done so that outputs corresponding to different inputs don't have an algebraic dependency due to the batch normalization happening in the neural network. We can now verify again the value of the Jacobian:
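The verification code is not included in this extract; a hedged sketch, assuming the 2 => 4 => 2 model above so that the full Jacobian is block diagonal with one 2×2 block per input column:
julia
smodel = StatefulLuxLayer{true}(model, ps, st)
J_batched = batched_jacobian(smodel, AutoForwardDiff(), x)  # size (2, 2, 10)
J_full = ForwardDiff.jacobian(smodel, x)                    # size (20, 20)
# Each slice of `J_batched` should match the corresponding diagonal block of `J_full`
for k in axes(x, 2)
    idxs = (2k - 1):(2k)
    @assert J_batched[:, :, k] ≈ J_full[idxs, idxs]
end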
In this example, it is important to remark that now batched_jacobian returns a 3D array with the Jacobian calculation for each independent input value in x.
Ok, here I am going to cheat a bit. This comes from a discussion on nested AD for PINNs on Discourse. As per the consensus there, we shouldn't use nested AD for third- or higher-order differentiation. Note that in the example there, the user uses ForwardDiff.derivative, but we will use ForwardDiff.gradient instead, since we typically deal with array inputs and outputs.
Loss Function computing the Jacobian of the Parameters
The above examples show how to compute the gradient/Jacobian with respect to the inputs in the loss function. However, what if we want to compute the Jacobian with respect to the parameters? This problem has been taken from Issue 610.
We resolve these setups by using the Base.Fix1 wrapper around the stateful layer and fixing the input to the stateful layer.
julia
using ComponentArrays  # import assumed from earlier in the manual

function loss_function3(model, x, ps, st)
    smodel = StatefulLuxLayer{true}(model, ps, st)
    J = only(Zygote.jacobian(Base.Fix1(smodel, x), ps))  # Zygote returns a tuple
    return sum(abs2, J)
end

model = Chain(Dense(1 => 12, tanh), Dense(12 => 12, tanh), Dense(12 => 12, tanh),
    Dense(12 => 1))
ps, st = Lux.setup(StableRNG(0), model)
ps = ComponentArray(ps)  # needs to be an AbstractArray for most jacobian functions
x = rand(StableRNG(0), Float32, 1, 16)
Hutchinson Trace Estimation often shows up in the machine learning literature to provide a fast estimate of the trace of a Jacobian matrix. This is based off of Hutchinson 1990 (https://www.nowozin.net/sebastian/blog/thoughts-on-trace-estimation-in-deep-learning.html), which computes the estimated trace of a matrix $A \in \mathbb{R}^{D \times D}$ using random vectors $v \in \mathbb{R}^D$ s.t. $\mathbb{E}[v v^T] = I$:

$$\mathrm{Tr}(A) = \mathbb{E}\left[v^T A v\right] = \frac{1}{V} \sum_{i=1}^{V} v_i^T A v_i$$

We can use this to compute the trace of a Jacobian matrix $J \in \mathbb{R}^{D \times D}$ using the following algorithm:

$$\mathrm{Tr}(J) = \frac{1}{V} \sum_{i=1}^{V} v_i^T J v_i$$

Note that we can compute this using two methods:

1. Compute $v_i^T J$ using a Vector-Jacobian product and then do a matrix-vector product to get the trace.
2. Compute $J v_i$ using a Jacobian-Vector product and then do a matrix-vector product to get the trace.

For simplicity, we will use a single sample of $v_i$ to compute the trace. Additionally, we will fix the sample to ensure that our tests against the finite difference implementation are not affected by the randomness in the sample.
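The definitions of the three estimators used below are not included in this extract. A hedged sketch follows, using the vector_jacobian_product / jacobian_vector_product utilities exported by Lux; the fixed sample v (with the same shape as x) is an assumption:
julia
using LinearAlgebra

function hutchinson_trace_vjp(model, x, ps, st, v)
    smodel = StatefulLuxLayer{true}(model, ps, st)
    vjp = vector_jacobian_product(smodel, AutoZygote(), x, v)  # vᵀJ
    return dot(vjp, v)                                         # (vᵀJ)v
end

function hutchinson_trace_jvp(model, x, ps, st, v)
    smodel = StatefulLuxLayer{true}(model, ps, st)
    jvp = jacobian_vector_product(smodel, AutoForwardDiff(), x, v)  # Jv
    return dot(v, jvp)                                              # vᵀ(Jv)
end

function hutchinson_trace_full_jacobian(model, x, ps, st, v)
    smodel = StatefulLuxLayer{true}(model, ps, st)
    J = ForwardDiff.jacobian(smodel, x)
    return dot(vec(v), J * vec(v))
end

# Fix the random sample (shape assumption: same as `x`)
v = randn(StableRNG(12), Float32, size(x))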
julia
tr_vjp = hutchinson_trace_vjp(model, x, ps, st, v)
tr_jvp = hutchinson_trace_jvp(model, x, ps, st, v)
tr_full_jacobian = hutchinson_trace_full_jacobian(model, x, ps, st, v)
println("Tr(J) using vjp: ", tr_vjp)
println("Tr(J) using jvp: ", tr_jvp)
println("Tr(J) using full jacobian: ", tr_full_jacobian)
Tr(J) using vjp: 4.9127817
Tr(J) using jvp: 4.912782
Tr(J) using full jacobian: 4.912781
Now that we have verified that the results are the same, let's try to differentiate the trace estimate. This often shows up as a regularization term in neural networks.
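The differentiation code producing the gradients compared below is not included in this extract; a hedged sketch of how they could be computed:
julia
_, ∂x_vjp, ∂ps_vjp, _, _ = Zygote.gradient(hutchinson_trace_vjp, model, x, ps, st, v)
_, ∂x_jvp, ∂ps_jvp, _, _ = Zygote.gradient(hutchinson_trace_jvp, model, x, ps, st, v)
_, ∂x_full_jacobian, ∂ps_full_jacobian, _, _ = Zygote.gradient(
    hutchinson_trace_full_jacobian, model, x, ps, st, v)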
As a sanity check, let's verify that the gradients are the same:
julia
println("∞-norm(∂x using vjp): ", norm(∂x_vjp .- ∂x_jvp, Inf))
println("∞-norm(∂ps using vjp): ",
    norm(ComponentArray(∂ps_vjp) .- ComponentArray(∂ps_jvp), Inf))
println("∞-norm(∂x using full jacobian): ", norm(∂x_full_jacobian .- ∂x_vjp, Inf))
println("∞-norm(∂ps using full jacobian): ",
    norm(ComponentArray(∂ps_full_jacobian) .- ComponentArray(∂ps_vjp), Inf))
∞-norm(∂x using vjp): 0.0
∞-norm(∂ps using vjp): 0.0
∞-norm(∂x using full jacobian): 9.536743e-7
∞-norm(∂ps using full jacobian): 1.4305115e-6
Neural Networks Inside GPU Kernels
On this page, we will describe how to embed neural networks inside GPU kernels. We will use KernelAbstractions.jl to do this, making the approach compatible with multiple GPU backends.
Experimental Feature
This is a relatively new and experimental feature. Expect edge cases and open issues on GitHub if you find any.
Inference Only
Currently, this works only for inference. We will eventually test automatic differentiation using Enzyme.jl.
Batching
In most use cases, this form of batching via embedding the neural network inside a GPU kernel is not recommended and will lead to suboptimal performance. Instead, batch the input data and let Lux handle the batching internally.
julia
using Lux, LuxCUDA, Random
using KernelAbstractions, StaticArrays
The first thing to remember is that we can't use regular high-level operations inside the kernels; instead, we will use Static Arrays. Leveraging Julia's multiple dispatch, Lux will use specialized operations that are compatible with GPU kernels.
julia
@kernel function nn_eval_single_batch!(output, model, input, ps, st)
    i = @index(Global, Linear)
    y, st_ = Lux.apply(model, input[i], ps, st)
    output[i] = y
end
nn_eval_single_batch! (generic function with 4 methods)
We define and initialize the neural network as usual, but we need to additionally convert the Arrays into SArrays.
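A hedged sketch of this conversion; the exact network and the use of Lux.recursive_map are assumptions based on the surrounding text:
julia
nn = Chain(Dense(4 => 4, relu), Dense(4 => 4))
ps, st = Lux.setup(Xoshiro(0), nn)

# Convert every array leaf to a StaticArray so it can live inside the kernel
to_sarray(x) = SArray{Tuple{size(x)...}}(x)
ps_static = Lux.recursive_map(to_sarray, ps)
st_static = Lux.recursive_map(to_sarray, st)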
Performance Pitfalls & How to Catch Them
Lux by default uses Julia semantics for type promotions. While this means that we do the "correct" numerical thing, it can often come as a surprise to users coming from a more deep learning background. For example, consider the following code:
julia
using Lux, Random
+
+rng = Xoshiro(0)
+
+model = Dense(2 => 2, gelu)
+ps, st = Lux.setup(rng, model)
+Lux.recursive_eltype((ps, st))
Float32
As we can see, ps and st are structures with the highest precision being Float32. Now let's run the model using some random data:
julia
x = rand(rng, 2, 4)
+
+eltype(first(model(x, ps, st)))
Float64
Oops! Our output became Float64. This will be bad on CPUs, but an absolute performance disaster on GPUs. The reason this happened is that our input x was Float64. Instead, we should have used a Float32 input:
julia
x = rand(rng, Float32, 2, 4)
+
+eltype(first(model(x, ps, st)))
Float32
This was easy to fix for a small model. But certain layers might incorrectly promote objects to a higher precision, causing a regression in performance. One way to fix this or track it down is to control the global behavior of eltypes in Lux, allowing it to auto-correct the precision, using match_eltype and the eltype_mismatch_handling preference.
When running code on GPUs, it is recommended to disallow scalar indexing. Note that scalar indexing is disallowed by default, except in the REPL. You can disallow it even in REPL mode using:
julia
using GPUArraysCore
+GPUArraysCore.allowscalar(false)
Lux.jl is integrated with DispatchDoctor.jl to catch type instabilities. You can easily enable it by setting the instability_check preference, which will help you catch type instabilities in your code. For more information on how to set preferences, check out Lux.set_dispatch_doctor_preferences!.
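A minimal sketch of enabling it, assuming the keyword form of the helper (the exact keyword names and modes should be checked against its docstring); a Julia restart is required afterwards:
julia
using Lux
+Lux.set_dispatch_doctor_preferences!(; luxcore="error", luxlib="error")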
For faster performance on CPUs, load the following packages:
LoopVectorization.jl
Octavian.jl
If these are available, we automatically use optimized versions of the layers. Though there are cases where this might be an issue (see #980 and disabling loop vectorization).
A common pattern for loading data and transferring data to GPUs looks like this:
julia
dataloader = DataLoader(dataset; parallel=true, batchsize=12) # from MLUtils.jl
+gdev = gpu_device()
+
+for (X, y) in dataloader
+ X = X |> gdev
+ y = y |> gdev
+ # ...
+ # do some computation
+ # ...
+end
This is typically fast enough, but the data transfer to the device happens in the main process and doesn't exploit the parallelism of the dataloader. Instead, we can do this:
julia
dataloader = DataLoader(dataset; parallel=true, batchsize=12) # from MLUtils.jl
+gdev = gpu_device()
+
+for (X, y) in gdev(dataloader)
+ # ...
+ # do some computation
+ # ...
+end
Here, X and y are on the GPU device gdev, and the data transfer happens in the worker processes. Additionally, it behaves similarly to CuIterator from CUDA.jl and eagerly frees the data after every iteration (this is device agnostic and works on all supported GPU backends).
diff --git a/previews/PR1023/assets/manual_preferences.md.CiGO9gyg.js b/previews/PR1023/assets/manual_preferences.md.CiGO9gyg.js
new file mode 100644
index 0000000000..dff7287856
--- /dev/null
+++ b/previews/PR1023/assets/manual_preferences.md.CiGO9gyg.js
@@ -0,0 +1,3 @@
Preferences for Lux.jl
automatic_nested_ad_switching - Set this to false to disable automatic switching of backends for nested automatic differentiation. See the manual section on nested automatic differentiation for more details.
gpu_backend - Set this to bypass the automatic backend selection and use a specific GPU backend. Valid options are "cuda", "rocm", "metal", and "oneapi". This preference needs to be set for the MLDataDevices package. It is recommended to use MLDataDevices.gpu_backend! to set this preference.
eltype_mismatch_handling - Preference controlling what happens when layers get different eltypes as input. See the documentation on match_eltype for more details.
instability_check - Preference controlling the dispatch doctor. See the documentation on Lux.set_dispatch_doctor_preferences! for more details. The preference needs to be set for both the LuxCore and LuxLib packages. Both default to "disable".
Setting the LuxCore preference sets the check at the level of LuxCore.apply. This essentially activates the dispatch doctor for all Lux layers.
Setting the LuxLib preference sets the check at the level of the functional layers of Lux, for example, fused_dense_bias_activation. These functions are supposed to be type stable for common input types and can be used to guarantee type stability.
LoopVectorization.jl and Octavian.jl are optional dependencies that are used to accelerate certain CPU operations. However, these packages are tightly coupled with Julia and might not work with all Julia versions and systems. If these packages are loaded in any form, LuxLib will use the optimized versions of the functions. But it might be desirable to disable these packages and use the default implementations instead. This can be done by setting the disable_loop_vectorization preference to true for LuxLib.
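A minimal sketch of setting that preference via Preferences.jl (assuming LuxLib is a dependency of the active project; restart Julia afterwards for the change to take effect):
julia
using Preferences, LuxLib
+set_preferences!(LuxLib, "disable_loop_vectorization" => true; force=true)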
diff --git a/previews/PR1023/assets/manual_weight_initializers.md.kYS4Pm9l.js b/previews/PR1023/assets/manual_weight_initializers.md.kYS4Pm9l.js
new file mode 100644
index 0000000000..70d75a4c03
--- /dev/null
+++ b/previews/PR1023/assets/manual_weight_initializers.md.kYS4Pm9l.js
@@ -0,0 +1,30 @@
Initializing Weights
The package is meant to work with deep learning libraries such as (F)Lux. All the methods take as input the chosen rng type and the dimensions of the array.
julia
weights = init(rng, dims...)
The rng is optional; if not specified, a default one will be used.
julia
weights = init(dims...)
If you need to use keyword arguments, the methods can be called with just the rng (optionally) and the keywords; they then return a function behaving like the two examples above.
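For instance, a sketch using kaiming_normal and its gain keyword:
julia
using WeightInitializers, Random
+
+rng = Xoshiro(0)
+init_fn = kaiming_normal(rng; gain=1.0f0)  # returns a closure over rng and the keywords
+weights = init_fn(2, 5)                    # same as kaiming_normal(rng, 2, 5; gain=1.0f0)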
This section can be skipped. It defines functions to simulate the model; however, from a scientific machine learning perspective, it isn't super relevant.
`,7)),A("p",null,[s[6]||(s[6]=h("We need a very crude 2-body path. Assume the 1-body motion is a newtonian 2-body position vector ")),A("mjx-container",l,[(a(),i("svg",p,s[0]||(s[0]=[n('',1)]))),s[1]||(s[1]=A("mjx-assistive-mml",{unselectable:"on",display:"inline",style:{top:"0px",left:"0px",clip:"rect(1px, 1px, 1px, 1px)","-webkit-touch-callout":"none","-webkit-user-select":"none","-khtml-user-select":"none","-moz-user-select":"none","-ms-user-select":"none","user-select":"none",position:"absolute",padding:"1px 0px 0px 0px",border:"0px",display:"block",width:"auto",overflow:"hidden"}},[A("math",{xmlns:"http://www.w3.org/1998/Math/MathML"},[A("mi",null,"r"),A("mo",null,"="),A("msub",null,[A("mi",null,"r"),A("mn",null,"1")]),A("mo",null,"−"),A("msub",null,[A("mi",null,"r"),A("mn",null,"2")])])],-1))]),s[7]||(s[7]=h(" and use Newtonian formulas to get ")),A("mjx-container",k,[(a(),i("svg",E,s[2]||(s[2]=[n('',1)]))),s[3]||(s[3]=A("mjx-assistive-mml",{unselectable:"on",display:"inline",style:{top:"0px",left:"0px",clip:"rect(1px, 1px, 1px, 1px)","-webkit-touch-callout":"none","-webkit-user-select":"none","-khtml-user-select":"none","-moz-user-select":"none","-ms-user-select":"none","user-select":"none",position:"absolute",padding:"1px 0px 0px 0px",border:"0px",display:"block",width:"auto",overflow:"hidden"}},[A("math",{xmlns:"http://www.w3.org/1998/Math/MathML"},[A("msub",null,[A("mi",null,"r"),A("mn",null,"1")])])],-1))]),s[8]||(s[8]=h(", ")),A("mjx-container",r,[(a(),i("svg",d,s[4]||(s[4]=[n('',1)]))),s[5]||(s[5]=A("mjx-assistive-mml",{unselectable:"on",display:"inline",style:{top:"0px",left:"0px",clip:"rect(1px, 1px, 1px, 1px)","-webkit-touch-callout":"none","-webkit-user-select":"none","-khtml-user-select":"none","-moz-user-select":"none","-ms-user-select":"none","user-select":"none",position:"absolute",padding:"1px 0px 0px 0px",border:"0px",display:"block",width:"auto",overflow:"hidden"}},[A("math",{xmlns:"http://www.w3.org/1998/Math/MathML"},[A("msub",null,[A("mi",null,"r"),A("mn",null,"2")])])],-1))]),s[9]||(s[9]=h(" (e.g. Theoretical Mechanics of Particles and Continua 4.3)"))]),s[42]||(s[42]=n(`
julia
function one2two(path, m₁, m₂)
+ M = m₁ + m₂
+ r₁ = m₂ / M .* path
+ r₂ = -m₁ / M .* path
+ return r₁, r₂
+end
one2two (generic function with 1 method)
`,2)),A("p",null,[s[12]||(s[12]=h("Next we define a function to perform the change of variables: ")),A("mjx-container",Q,[(a(),i("svg",C,s[10]||(s[10]=[n('',1)]))),s[11]||(s[11]=A("mjx-assistive-mml",{unselectable:"on",display:"inline",style:{top:"0px",left:"0px",clip:"rect(1px, 1px, 1px, 1px)","-webkit-touch-callout":"none","-webkit-user-select":"none","-khtml-user-select":"none","-moz-user-select":"none","-ms-user-select":"none","user-select":"none",position:"absolute",padding:"1px 0px 0px 0px",border:"0px",display:"block",width:"auto",overflow:"hidden"}},[A("math",{xmlns:"http://www.w3.org/1998/Math/MathML"},[A("mo",{stretchy:"false"},"("),A("mi",null,"χ"),A("mo",{stretchy:"false"},"("),A("mi",null,"t"),A("mo",{stretchy:"false"},")"),A("mo",null,","),A("mi",null,"ϕ"),A("mo",{stretchy:"false"},"("),A("mi",null,"t"),A("mo",{stretchy:"false"},")"),A("mo",{stretchy:"false"},")"),A("mo",{stretchy:"false"},"↦"),A("mo",{stretchy:"false"},"("),A("mi",null,"x"),A("mo",{stretchy:"false"},"("),A("mi",null,"t"),A("mo",{stretchy:"false"},")"),A("mo",null,","),A("mi",null,"y"),A("mo",{stretchy:"false"},"("),A("mi",null,"t"),A("mo",{stretchy:"false"},")"),A("mo",{stretchy:"false"},")")])],-1))])]),s[43]||(s[43]=n(`
julia
@views function soln2orbit(soln, model_params=nothing)
+ @assert size(soln, 1) ∈ [2, 4] "size(soln,1) must be either 2 or 4"
+
+ if size(soln, 1) == 2
+ χ = soln[1, :]
+ ϕ = soln[2, :]
+
+        @assert length(model_params)==3 "model_params must have length 3 when size(soln, 1) == 2"
+ p, M, e = model_params
+ else
+ χ = soln[1, :]
+ ϕ = soln[2, :]
+ p = soln[3, :]
+ e = soln[4, :]
+ end
+
+ r = p ./ (1 .+ e .* cos.(χ))
+ x = r .* cos.(ϕ)
+ y = r .* sin.(ϕ)
+
+ orbit = vcat(x', y')
+ return orbit
+end
Next, we define the neural network model that takes one input (time) and has two outputs. We'll make a function ODE_model that takes the initial conditions, neural network parameters, and a time as inputs and returns the derivatives.
It is typically not recommended to use globals, but in case you do use them, make sure to mark them as const.
We will deviate from the standard neural network initialization and use WeightInitializers.jl; the globals referenced below are sketched right after this paragraph.
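The globals used by ODE_model below (nn_model and ode_model_params) are defined in a part of the tutorial not shown here; a minimal sketch with an illustrative architecture and parameter values:
julia
nn = Chain(Dense(1 => 32, cos; init_weight=truncated_normal(; std=1.0f-4)),
+           Dense(32 => 32, cos; init_weight=truncated_normal(; std=1.0f-4)),
+           Dense(32 => 2; init_weight=truncated_normal(; std=1.0f-4)))
+ps, st = Lux.setup(Xoshiro(0), nn)
+const nn_model = StatefulLuxLayer{true}(nn, nothing, st)
+const ode_model_params = [100.0, 1.0, 0.5]  # (p, M, e); illustrative values only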
julia
function ODE_model(u, nn_params, t)
+ χ, ϕ = u
+ p, M, e = ode_model_params
+
+    # In this example we know that \`st\` is an empty NamedTuple, hence we can safely ignore
+ # it, however, in general, we should use \`st\` to store the state of the neural network.
+ y = 1 .+ nn_model([first(u)], nn_params)
+
+ numer = (1 + e * cos(χ))^2
+ denom = M * (p^(3 / 2))
+
+ χ̇ = (numer / denom) * y[1]
+ ϕ̇ = (numer / denom) * y[2]
+
+ return [χ̇, ϕ̇]
+end
ODE_model (generic function with 1 method)
Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.
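A sketch of that simulation, assuming u0, tspan, and tsteps were defined earlier in the tutorial, and using the untrained ps from above:
julia
using OrdinaryDiffEq
+
+prob_nn = ODEProblem(ODE_model, u0, tspan, ps)
+soln_nn = solve(prob_nn, Tsit5(); saveat=tsteps)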
It introduces basic Julia programming, as well as Zygote, a source-to-source automatic differentiation (AD) framework in Julia. We'll use these tools to build a very simple neural network. Let's start by importing Lux.jl:
julia
using Lux, Random
Now let us control the randomness in our code using a proper Pseudo Random Number Generator (PRNG).
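For instance, with a seeded Xoshiro generator (seeding makes the results below reproducible):
julia
rng = Xoshiro(0)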
The starting point for all of our models is the Array (sometimes referred to as a Tensor in other frameworks). This is really just a list of numbers, which might be arranged into a shape like a square. Let's write down an array with three elements.
julia
x = [1, 2, 3]
3-element Vector{Int64}:
+ 1
+ 2
+ 3
Here's a matrix – a square array with four elements.
julia
x = [1 2; 3 4]
2×2 Matrix{Int64}:
+ 1 2
+ 3 4
We often work with arrays of thousands of elements, and don't usually write them down by hand. Here's how we can create an array of 5×3 = 15 elements, each a random number from zero to one.
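For instance:
julia
x = rand(5, 3)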
There are a few functions like this; try replacing rand with ones, zeros, or randn.
By default, Julia stores numbers in a high-precision format called Float64. In ML we often don't need all those digits, and we can ask Julia to work with Float32 instead. We can even ask for more digits using BigFloat.
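A quick sketch of both:
julia
x = rand(Float32, 5, 3)  # element type Float32 instead of the default Float64
+big.(x)                  # promote each element to BigFloat for extra digits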
CUDA functionality is provided separately by the CUDA.jl package. If you have a GPU and LuxCUDA is installed, Lux will provide CUDA capabilities. For additional details on backends see the manual section.
You can manually add CUDA. Once CUDA is loaded you can move any array to the GPU with the cu function (or the gpu function exported by Lux), and it supports all of the above operations with the same syntax.
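A sketch, assuming a functional CUDA GPU is available:
julia
using LuxCUDA
+
+x = rand(Float32, 5, 3)
+x_gpu = cu(x)  # a CuArray on the GPU; the same operations work with the same syntax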
Lux, as you might have read, is immutable by convention, which means that the core library is built without any form of mutation and all functions are pure. However, we don't enforce it in any form. We do strongly recommend that users extending this framework for their respective applications don't mutate their arrays.
Note that our current default AD engine (Zygote) is unable to differentiate through this mutation; however, for these specialized cases it is quite trivial to write custom backward passes. (This problem will be fixed once we move towards Enzyme.jl.)
If we call any function that relies on rng and uses it via randn, rand, etc., rng will be mutated. As we have already established, we care a lot about immutability, hence we should use Lux.replicate on PRNGs before using them.
First, let us run a random number generator 3 times with the replicated rng.
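A sketch of that experiment; each iteration prints the same numbers because replicate copies the PRNG state instead of mutating rng:
julia
for i in 1:3
+    println("Iteration ", i, ": ", rand(Lux.replicate(rng), 2))
+end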
Slight Detour: We have had several questions regarding whether we will consider any other AD system for the reverse-diff backend. For now we will stick to Zygote.jl; however, once we have tested Lux extensively with Enzyme.jl, we will make the switch.
Even though, theoretically, a VJP (Vector-Jacobian product, i.e. reverse-mode autodiff) and a JVP (Jacobian-Vector product, i.e. forward-mode autodiff) are similar (both compute a product of a Jacobian and a vector), they differ in the computational complexity of the operation. In short, when you have a large number of parameters (hence a wide Jacobian matrix), a JVP is computationally less efficient than a VJP; conversely, a JVP is more efficient when the Jacobian is a tall matrix.
`,90)),s("p",null,[a[4]||(a[4]=e("For our first example, consider a simple function computing ")),s("mjx-container",h,[(t(),i("svg",d,a[0]||(a[0]=[n('',1)]))),a[1]||(a[1]=s("mjx-assistive-mml",{unselectable:"on",display:"inline",style:{top:"0px",left:"0px",clip:"rect(1px, 1px, 1px, 1px)","-webkit-touch-callout":"none","-webkit-user-select":"none","-khtml-user-select":"none","-moz-user-select":"none","-ms-user-select":"none","user-select":"none",position:"absolute",padding:"1px 0px 0px 0px",border:"0px",display:"block",width:"auto",overflow:"hidden"}},[s("math",{xmlns:"http://www.w3.org/1998/Math/MathML"},[s("mi",null,"f"),s("mo",{stretchy:"false"},"("),s("mi",null,"x"),s("mo",{stretchy:"false"},")"),s("mo",null,"="),s("mfrac",null,[s("mn",null,"1"),s("mn",null,"2")]),s("msup",null,[s("mi",null,"x"),s("mi",null,"T")]),s("mi",null,"x")])],-1))]),a[5]||(a[5]=e(", where ")),s("mjx-container",r,[(t(),i("svg",o,a[2]||(a[2]=[n('',1)]))),a[3]||(a[3]=s("mjx-assistive-mml",{unselectable:"on",display:"inline",style:{top:"0px",left:"0px",clip:"rect(1px, 1px, 1px, 1px)","-webkit-touch-callout":"none","-webkit-user-select":"none","-khtml-user-select":"none","-moz-user-select":"none","-ms-user-select":"none","user-select":"none",position:"absolute",padding:"1px 0px 0px 0px",border:"0px",display:"block",width:"auto",overflow:"hidden"}},[s("math",{xmlns:"http://www.w3.org/1998/Math/MathML"},[s("mi",{mathvariant:"normal"},"∇"),s("mi",null,"f"),s("mo",{stretchy:"false"},"("),s("mi",null,"x"),s("mo",{stretchy:"false"},")"),s("mo",null,"="),s("mi",null,"x")])],-1))])]),a[22]||(a[22]=n(`
julia
f(x) = x' * x / 2
+∇f(x) = x # \`∇\` can be typed as \`\\nabla<TAB>\`
+v = randn(rng, Float32, 4)
While DifferentiationInterface provides these functions for a wider range of backends, we currently don't recommend using them with Lux models, since the functions presented here come with additional goodies like fast second-order derivatives.
Compute the JVP. AutoForwardDiff specifies that we want to use ForwardDiff.jl for the Jacobian-Vector product:
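A sketch using Lux's jacobian_vector_product helper, with v from above and an input x of matching length:
julia
x = randn(rng, Float32, 4)
+jvp = jacobian_vector_product(f, AutoForwardDiff(), x, v)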
julia
# dimensions and ground-truth parameters (values consistent with the shapes printed below)
x_dim, y_dim, n_samples = 10, 5, 20
+W = randn(rng, Float32, y_dim, x_dim); b = randn(rng, Float32, y_dim)
+x_samples = randn(rng, Float32, x_dim, n_samples)
+y_samples = W * x_samples .+ b .+ 0.01f0 .* randn(rng, Float32, y_dim, n_samples)
+println("x shape: ", size(x_samples), "; y shape: ", size(y_samples))
x shape: (10, 20); y shape: (5, 20)
For updating our parameters let's use Optimisers.jl. We will use Stochastic Gradient Descent (SGD) with a learning rate of 0.01.
julia
using Optimisers, Printf
Define the loss function
julia
lossfn = MSELoss()
+
+println("Loss Value with ground true parameters: ", lossfn(W * x_samples .+ b, y_samples))
Loss Value with ground true parameters: 9.3742405e-5
We will train the model using our training API.
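The model and its parameters were created in a part of the tutorial not shown here; a minimal sketch consistent with the shapes above:
julia
model = Dense(10 => 5)
+ps, st = Lux.setup(rng, model)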
julia
function train_model!(model, ps, st, opt, nepochs::Int)
+ tstate = Training.TrainState(model, ps, st, opt)
+ for i in 1:nepochs
+ grads, loss, _, tstate = Training.single_train_step!(
+ AutoZygote(), lossfn, (x_samples, y_samples), tstate)
+ if i % 1000 == 1 || i == nepochs
+ @printf "Loss Value after %6d iterations: %.8f\\n" i loss
+ end
+ end
+ return tstate.model, tstate.parameters, tstate.states
+end
+
+model, ps, st = train_model!(model, ps, st, Descent(0.01f0), 10000)
+
+println("Loss Value after training: ", lossfn(first(model(x_samples, ps, st)), y_samples))
Loss Value after 1 iterations: 7.80465555
+Loss Value after 1001 iterations: 0.12477568
+Loss Value after 2001 iterations: 0.02535537
+Loss Value after 3001 iterations: 0.00914141
+Loss Value after 4001 iterations: 0.00407581
+Loss Value after 5001 iterations: 0.00198415
+Loss Value after 6001 iterations: 0.00101147
+Loss Value after 7001 iterations: 0.00053332
+Loss Value after 8001 iterations: 0.00029203
+Loss Value after 9001 iterations: 0.00016878
+Loss Value after 10000 iterations: 0.00010551
+Loss Value after training: 0.00010546855
diff --git a/previews/PR1023/assets/tutorials_beginner_2_PolynomialFitting.md.sRzotb_h.js b/previews/PR1023/assets/tutorials_beginner_2_PolynomialFitting.md.sRzotb_h.js
new file mode 100644
index 0000000000..b98dd44765
--- /dev/null
+++ b/previews/PR1023/assets/tutorials_beginner_2_PolynomialFitting.md.sRzotb_h.js
@@ -0,0 +1,124 @@
Fitting a Polynomial using MLP
We will use the Training API, so we need to ensure that our loss function takes 4 inputs: model, parameters, states, and data. The function must return 3 values: loss, updated_state, and any computed statistics. This is already satisfied by the loss functions provided by Lux.
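A minimal sketch of a loss function satisfying this contract (the MSELoss choice is illustrative):
julia
function custom_loss(model, ps, st, (x, y))
+    ŷ, st_ = model(x, ps, st)
+    return MSELoss()(ŷ, y), st_, NamedTuple()  # loss, updated state, statistics
+end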
diff --git a/previews/PR1023/assets/tutorials_beginner_3_SimpleRNN.md.CxWJLH3E.js b/previews/PR1023/assets/tutorials_beginner_3_SimpleRNN.md.CxWJLH3E.js
new file mode 100644
index 0000000000..c94d2bb9ba
--- /dev/null
+++ b/previews/PR1023/assets/tutorials_beginner_3_SimpleRNN.md.CxWJLH3E.js
@@ -0,0 +1,611 @@
Training a Simple LSTM
In this tutorial we will go over using a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:
Create custom Lux models.
Become familiar with the Lux recurrent neural network API.
We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals. Using this data, we will create an MLUtils.DataLoader. Our dataloader will give us sequences of size 2 × seq_len × batch_size, and we need to predict a binary value indicating whether the sequence is clockwise or anticlockwise.
julia
function get_dataloaders(; dataset_size=1000, sequence_length=50)
+ # Create the spirals
+ data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
+ # Get the labels
+ labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
+ clockwise_spirals = [reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
+ for d in data[1:(dataset_size ÷ 2)]]
+ anticlockwise_spirals = [reshape(
+ d[1][:, (sequence_length + 1):end], :, sequence_length, 1)
+ for d in data[((dataset_size ÷ 2) + 1):end]]
+ x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
+ # Split the dataset
+ (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
+ # Create DataLoaders
+ return (
+ # Use DataLoader to automatically minibatch and shuffle the data
+ DataLoader(collect.((x_train, y_train)); batchsize=128, shuffle=true),
+ # Don't shuffle the validation data
+ DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false))
+end
We will be extending the Lux.AbstractLuxContainerLayer type for our custom model since it will contain an LSTM block and a classifier head.
We pass the fieldnames lstm_cell and classifier to the type to ensure that the parameters and states are automatically populated and we don't have to define Lux.initialparameters and Lux.initialstates.
To understand more about container layers, please look at Container Layer.
We can use default Lux blocks – Recurrence(LSTMCell(in_dims => hidden_dims)) – instead of defining the following. But let's still do it for the sake of it, as sketched below.
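The struct definition itself is not shown in this excerpt; a sketch consistent with the fieldnames above:
julia
struct SpiralClassifier{L, C} <: Lux.AbstractLuxContainerLayer{(:lstm_cell, :classifier)}
+    lstm_cell::L
+    classifier::C
+end
+
+function SpiralClassifier(in_dims, hidden_dims, out_dims)
+    return SpiralClassifier(
+        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid))
+end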
Now we need to define the behavior of the Classifier when it is invoked.
julia
function (s::SpiralClassifier)(
+ x::AbstractArray{T, 3}, ps::NamedTuple, st::NamedTuple) where {T}
+ # First we will have to run the sequence through the LSTM Cell
+ # The first call to LSTM Cell will create the initial hidden state
+ # See that the parameters and states are automatically populated into a field called
+ # \`lstm_cell\` We use \`eachslice\` to get the elements in the sequence without copying,
+ # and \`Iterators.peel\` to split out the first element for LSTM initialization.
+ x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
+ (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
+ # Now that we have the hidden state and memory in \`carry\` we will pass the input and
+ # \`carry\` jointly
+ for x in x_rest
+ (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
+ end
+ # After running through the sequence we will pass the output through the classifier
+ y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
+ # Finally remember to create the updated state
+ st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
+ return vec(y), st
+end
We can also define the model using the Lux.@compact API, which is a more concise way of defining models. This macro automatically handles the boilerplate code for you, and as such we recommend this way of defining custom layers.
julia
function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
+ lstm_cell = LSTMCell(in_dims => hidden_dims)
+ classifier = Dense(hidden_dims => out_dims, sigmoid)
+ return @compact(; lstm_cell, classifier) do x::AbstractArray{T, 3} where {T}
+ x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
+ y, carry = lstm_cell(x_init)
+ for x in x_rest
+ y, carry = lstm_cell((x, carry))
+ end
+ @return vec(classifier(y))
+ end
+end
SpiralClassifierCompact (generic function with 1 method)
Now let's define the binarycrossentropy loss. Typically it is recommended to use logitbinarycrossentropy since it is more numerically stable, but for the sake of simplicity we will use binarycrossentropy.
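The lossfn and accuracy used by main below are defined outside this excerpt; a minimal sketch, assuming Lux's built-in BinaryCrossEntropyLoss:
julia
const lossfn = BinaryCrossEntropyLoss()
+
+accuracy(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true) / length(y_pred)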
julia
function main(model_type)
+ dev = gpu_device()
+
+ # Get the dataloaders
+ train_loader, val_loader = get_dataloaders() .|> dev
+
+ # Create the model
+ model = model_type(2, 8, 1)
+ rng = Xoshiro(0)
+ ps, st = Lux.setup(rng, model) |> dev
+
+ train_state = Training.TrainState(model, ps, st, Adam(0.01f0))
+
+ for epoch in 1:25
+ # Train the model
+ for (x, y) in train_loader
+ (_, loss, _, train_state) = Training.single_train_step!(
+ AutoZygote(), lossfn, (x, y), train_state)
+
+ @printf "Epoch [%3d]: Loss %4.5f\\n" epoch loss
+ end
+
+ # Validate the model
+ st_ = Lux.testmode(train_state.states)
+ for (x, y) in val_loader
+ ŷ, st_ = model(x, train_state.parameters, st_)
+ loss = lossfn(ŷ, y)
+ acc = accuracy(ŷ, y)
+ @printf "Validation: Loss %4.5f Accuracy %4.5f\\n" loss acc
+ end
+ end
+
+ return (train_state.parameters, train_state.states) |> cpu_device()
+end
+
+ps_trained, st_trained = main(SpiralClassifier)
Epoch [ 1]: Loss 0.60926
+Epoch [ 1]: Loss 0.60205
+Epoch [ 1]: Loss 0.56447
+Epoch [ 1]: Loss 0.53935
+Epoch [ 1]: Loss 0.51961
+Epoch [ 1]: Loss 0.50630
+Epoch [ 1]: Loss 0.48399
+Validation: Loss 0.46956 Accuracy 1.00000
+Validation: Loss 0.47794 Accuracy 1.00000
+Epoch [ 2]: Loss 0.47301
+Epoch [ 2]: Loss 0.45405
+Epoch [ 2]: Loss 0.43968
+Epoch [ 2]: Loss 0.43054
+Epoch [ 2]: Loss 0.40202
+Epoch [ 2]: Loss 0.39666
+Epoch [ 2]: Loss 0.40138
+Validation: Loss 0.37273 Accuracy 1.00000
+Validation: Loss 0.38210 Accuracy 1.00000
+Epoch [ 3]: Loss 0.36731
+Epoch [ 3]: Loss 0.36875
+Epoch [ 3]: Loss 0.34892
+Epoch [ 3]: Loss 0.33812
+Epoch [ 3]: Loss 0.31629
+Epoch [ 3]: Loss 0.30792
+Epoch [ 3]: Loss 0.27809
+Validation: Loss 0.28817 Accuracy 1.00000
+Validation: Loss 0.29822 Accuracy 1.00000
+Epoch [ 4]: Loss 0.28662
+Epoch [ 4]: Loss 0.27989
+Epoch [ 4]: Loss 0.27278
+Epoch [ 4]: Loss 0.25235
+Epoch [ 4]: Loss 0.23497
+Epoch [ 4]: Loss 0.23847
+Epoch [ 4]: Loss 0.23192
+Validation: Loss 0.21844 Accuracy 1.00000
+Validation: Loss 0.22858 Accuracy 1.00000
+Epoch [ 5]: Loss 0.21529
+Epoch [ 5]: Loss 0.21660
+Epoch [ 5]: Loss 0.21147
+Epoch [ 5]: Loss 0.18347
+Epoch [ 5]: Loss 0.18387
+Epoch [ 5]: Loss 0.16418
+Epoch [ 5]: Loss 0.18488
+Validation: Loss 0.16251 Accuracy 1.00000
+Validation: Loss 0.17173 Accuracy 1.00000
+Epoch [ 6]: Loss 0.15106
+Epoch [ 6]: Loss 0.15557
+Epoch [ 6]: Loss 0.15604
+Epoch [ 6]: Loss 0.12610
+Epoch [ 6]: Loss 0.14466
+Epoch [ 6]: Loss 0.13525
+Epoch [ 6]: Loss 0.13401
+Validation: Loss 0.11923 Accuracy 1.00000
+Validation: Loss 0.12679 Accuracy 1.00000
+Epoch [ 7]: Loss 0.11300
+Epoch [ 7]: Loss 0.11270
+Epoch [ 7]: Loss 0.11182
+Epoch [ 7]: Loss 0.10579
+Epoch [ 7]: Loss 0.10077
+Epoch [ 7]: Loss 0.09092
+Epoch [ 7]: Loss 0.08957
+Validation: Loss 0.08530 Accuracy 1.00000
+Validation: Loss 0.09085 Accuracy 1.00000
+Epoch [ 8]: Loss 0.08321
+Epoch [ 8]: Loss 0.07613
+Epoch [ 8]: Loss 0.07561
+Epoch [ 8]: Loss 0.07250
+Epoch [ 8]: Loss 0.06895
+Epoch [ 8]: Loss 0.07155
+Epoch [ 8]: Loss 0.06246
+Validation: Loss 0.05935 Accuracy 1.00000
+Validation: Loss 0.06304 Accuracy 1.00000
+Epoch [ 9]: Loss 0.06135
+Epoch [ 9]: Loss 0.05983
+Epoch [ 9]: Loss 0.05429
+Epoch [ 9]: Loss 0.04415
+Epoch [ 9]: Loss 0.04965
+Epoch [ 9]: Loss 0.04801
+Epoch [ 9]: Loss 0.04264
+Validation: Loss 0.04389 Accuracy 1.00000
+Validation: Loss 0.04647 Accuracy 1.00000
+Epoch [ 10]: Loss 0.04243
+Epoch [ 10]: Loss 0.04109
+Epoch [ 10]: Loss 0.04136
+Epoch [ 10]: Loss 0.04201
+Epoch [ 10]: Loss 0.03979
+Epoch [ 10]: Loss 0.03471
+Epoch [ 10]: Loss 0.03760
+Validation: Loss 0.03546 Accuracy 1.00000
+Validation: Loss 0.03756 Accuracy 1.00000
+Epoch [ 11]: Loss 0.03545
+Epoch [ 11]: Loss 0.03571
+Epoch [ 11]: Loss 0.03202
+Epoch [ 11]: Loss 0.03209
+Epoch [ 11]: Loss 0.03134
+Epoch [ 11]: Loss 0.03114
+Epoch [ 11]: Loss 0.03593
+Validation: Loss 0.03006 Accuracy 1.00000
+Validation: Loss 0.03189 Accuracy 1.00000
+Epoch [ 12]: Loss 0.03210
+Epoch [ 12]: Loss 0.02768
+Epoch [ 12]: Loss 0.02955
+Epoch [ 12]: Loss 0.02631
+Epoch [ 12]: Loss 0.02720
+Epoch [ 12]: Loss 0.02667
+Epoch [ 12]: Loss 0.03031
+Validation: Loss 0.02612 Accuracy 1.00000
+Validation: Loss 0.02773 Accuracy 1.00000
+Epoch [ 13]: Loss 0.02589
+Epoch [ 13]: Loss 0.02454
+Epoch [ 13]: Loss 0.02716
+Epoch [ 13]: Loss 0.02579
+Epoch [ 13]: Loss 0.02323
+Epoch [ 13]: Loss 0.02301
+Epoch [ 13]: Loss 0.02099
+Validation: Loss 0.02307 Accuracy 1.00000
+Validation: Loss 0.02452 Accuracy 1.00000
+Epoch [ 14]: Loss 0.02105
+Epoch [ 14]: Loss 0.02170
+Epoch [ 14]: Loss 0.02234
+Epoch [ 14]: Loss 0.02238
+Epoch [ 14]: Loss 0.02259
+Epoch [ 14]: Loss 0.02282
+Epoch [ 14]: Loss 0.01795
+Validation: Loss 0.02066 Accuracy 1.00000
+Validation: Loss 0.02199 Accuracy 1.00000
+Epoch [ 15]: Loss 0.02140
+Epoch [ 15]: Loss 0.02017
+Epoch [ 15]: Loss 0.01932
+Epoch [ 15]: Loss 0.02011
+Epoch [ 15]: Loss 0.01752
+Epoch [ 15]: Loss 0.02006
+Epoch [ 15]: Loss 0.01963
+Validation: Loss 0.01866 Accuracy 1.00000
+Validation: Loss 0.01988 Accuracy 1.00000
+Epoch [ 16]: Loss 0.01796
+Epoch [ 16]: Loss 0.01636
+Epoch [ 16]: Loss 0.01900
+Epoch [ 16]: Loss 0.01740
+Epoch [ 16]: Loss 0.01782
+Epoch [ 16]: Loss 0.01824
+Epoch [ 16]: Loss 0.01976
+Validation: Loss 0.01696 Accuracy 1.00000
+Validation: Loss 0.01810 Accuracy 1.00000
+Epoch [ 17]: Loss 0.01745
+Epoch [ 17]: Loss 0.01579
+Epoch [ 17]: Loss 0.01777
+Epoch [ 17]: Loss 0.01630
+Epoch [ 17]: Loss 0.01578
+Epoch [ 17]: Loss 0.01468
+Epoch [ 17]: Loss 0.01627
+Validation: Loss 0.01549 Accuracy 1.00000
+Validation: Loss 0.01656 Accuracy 1.00000
+Epoch [ 18]: Loss 0.01608
+Epoch [ 18]: Loss 0.01398
+Epoch [ 18]: Loss 0.01425
+Epoch [ 18]: Loss 0.01537
+Epoch [ 18]: Loss 0.01504
+Epoch [ 18]: Loss 0.01471
+Epoch [ 18]: Loss 0.01496
+Validation: Loss 0.01423 Accuracy 1.00000
+Validation: Loss 0.01523 Accuracy 1.00000
+Epoch [ 19]: Loss 0.01355
+Epoch [ 19]: Loss 0.01489
+Epoch [ 19]: Loss 0.01364
+Epoch [ 19]: Loss 0.01253
+Epoch [ 19]: Loss 0.01360
+Epoch [ 19]: Loss 0.01343
+Epoch [ 19]: Loss 0.01639
+Validation: Loss 0.01313 Accuracy 1.00000
+Validation: Loss 0.01405 Accuracy 1.00000
+Epoch [ 20]: Loss 0.01377
+Epoch [ 20]: Loss 0.01183
+Epoch [ 20]: Loss 0.01194
+Epoch [ 20]: Loss 0.01194
+Epoch [ 20]: Loss 0.01292
+Epoch [ 20]: Loss 0.01361
+Epoch [ 20]: Loss 0.01227
+Validation: Loss 0.01211 Accuracy 1.00000
+Validation: Loss 0.01297 Accuracy 1.00000
+Epoch [ 21]: Loss 0.01212
+Epoch [ 21]: Loss 0.01138
+Epoch [ 21]: Loss 0.01102
+Epoch [ 21]: Loss 0.01238
+Epoch [ 21]: Loss 0.01200
+Epoch [ 21]: Loss 0.01130
+Epoch [ 21]: Loss 0.01082
+Validation: Loss 0.01112 Accuracy 1.00000
+Validation: Loss 0.01190 Accuracy 1.00000
+Epoch [ 22]: Loss 0.01134
+Epoch [ 22]: Loss 0.01031
+Epoch [ 22]: Loss 0.01060
+Epoch [ 22]: Loss 0.01130
+Epoch [ 22]: Loss 0.01009
+Epoch [ 22]: Loss 0.01053
+Epoch [ 22]: Loss 0.00940
+Validation: Loss 0.01002 Accuracy 1.00000
+Validation: Loss 0.01071 Accuracy 1.00000
+Epoch [ 23]: Loss 0.00886
+Epoch [ 23]: Loss 0.01026
+Epoch [ 23]: Loss 0.01005
+Epoch [ 23]: Loss 0.00853
+Epoch [ 23]: Loss 0.01033
+Epoch [ 23]: Loss 0.00902
+Epoch [ 23]: Loss 0.00969
+Validation: Loss 0.00888 Accuracy 1.00000
+Validation: Loss 0.00947 Accuracy 1.00000
+Epoch [ 24]: Loss 0.00903
+Epoch [ 24]: Loss 0.00856
+Epoch [ 24]: Loss 0.00866
+Epoch [ 24]: Loss 0.00883
+Epoch [ 24]: Loss 0.00830
+Epoch [ 24]: Loss 0.00781
+Epoch [ 24]: Loss 0.00662
+Validation: Loss 0.00795 Accuracy 1.00000
+Validation: Loss 0.00846 Accuracy 1.00000
+Epoch [ 25]: Loss 0.00830
+Epoch [ 25]: Loss 0.00742
+Epoch [ 25]: Loss 0.00822
+Epoch [ 25]: Loss 0.00791
+Epoch [ 25]: Loss 0.00721
+Epoch [ 25]: Loss 0.00726
+Epoch [ 25]: Loss 0.00582
+Validation: Loss 0.00730 Accuracy 1.00000
+Validation: Loss 0.00775 Accuracy 1.00000
We can also train the compact model with the exact same code!
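For instance, mirroring the run above:
julia
ps_trained2, st_trained2 = main(SpiralClassifierCompact)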
Epoch [ 1]: Loss 0.62249
+Epoch [ 1]: Loss 0.58988
+Epoch [ 1]: Loss 0.57122
+Epoch [ 1]: Loss 0.54145
+Epoch [ 1]: Loss 0.51676
+Epoch [ 1]: Loss 0.49941
+Epoch [ 1]: Loss 0.48712
+Validation: Loss 0.46707 Accuracy 1.00000
+Validation: Loss 0.46650 Accuracy 1.00000
+Epoch [ 2]: Loss 0.46435
+Epoch [ 2]: Loss 0.45555
+Epoch [ 2]: Loss 0.45454
+Epoch [ 2]: Loss 0.42345
+Epoch [ 2]: Loss 0.41436
+Epoch [ 2]: Loss 0.38527
+Epoch [ 2]: Loss 0.37442
+Validation: Loss 0.36940 Accuracy 1.00000
+Validation: Loss 0.36858 Accuracy 1.00000
+Epoch [ 3]: Loss 0.36752
+Epoch [ 3]: Loss 0.36360
+Epoch [ 3]: Loss 0.34430
+Epoch [ 3]: Loss 0.32734
+Epoch [ 3]: Loss 0.31783
+Epoch [ 3]: Loss 0.31825
+Epoch [ 3]: Loss 0.28565
+Validation: Loss 0.28440 Accuracy 1.00000
+Validation: Loss 0.28337 Accuracy 1.00000
+Epoch [ 4]: Loss 0.28307
+Epoch [ 4]: Loss 0.27199
+Epoch [ 4]: Loss 0.26836
+Epoch [ 4]: Loss 0.26051
+Epoch [ 4]: Loss 0.24528
+Epoch [ 4]: Loss 0.23063
+Epoch [ 4]: Loss 0.22536
+Validation: Loss 0.21475 Accuracy 1.00000
+Validation: Loss 0.21368 Accuracy 1.00000
+Epoch [ 5]: Loss 0.21305
+Epoch [ 5]: Loss 0.21531
+Epoch [ 5]: Loss 0.19616
+Epoch [ 5]: Loss 0.18414
+Epoch [ 5]: Loss 0.18294
+Epoch [ 5]: Loss 0.17875
+Epoch [ 5]: Loss 0.17815
+Validation: Loss 0.15941 Accuracy 1.00000
+Validation: Loss 0.15850 Accuracy 1.00000
+Epoch [ 6]: Loss 0.16464
+Epoch [ 6]: Loss 0.14669
+Epoch [ 6]: Loss 0.14234
+Epoch [ 6]: Loss 0.14785
+Epoch [ 6]: Loss 0.13936
+Epoch [ 6]: Loss 0.13121
+Epoch [ 6]: Loss 0.11054
+Validation: Loss 0.11688 Accuracy 1.00000
+Validation: Loss 0.11621 Accuracy 1.00000
+Epoch [ 7]: Loss 0.11895
+Epoch [ 7]: Loss 0.11755
+Epoch [ 7]: Loss 0.11153
+Epoch [ 7]: Loss 0.10806
+Epoch [ 7]: Loss 0.08931
+Epoch [ 7]: Loss 0.08989
+Epoch [ 7]: Loss 0.08885
+Validation: Loss 0.08377 Accuracy 1.00000
+Validation: Loss 0.08332 Accuracy 1.00000
+Epoch [ 8]: Loss 0.08392
+Epoch [ 8]: Loss 0.07975
+Epoch [ 8]: Loss 0.07711
+Epoch [ 8]: Loss 0.07462
+Epoch [ 8]: Loss 0.06929
+Epoch [ 8]: Loss 0.06475
+Epoch [ 8]: Loss 0.06222
+Validation: Loss 0.05835 Accuracy 1.00000
+Validation: Loss 0.05808 Accuracy 1.00000
+Epoch [ 9]: Loss 0.05835
+Epoch [ 9]: Loss 0.05645
+Epoch [ 9]: Loss 0.05303
+Epoch [ 9]: Loss 0.04974
+Epoch [ 9]: Loss 0.04989
+Epoch [ 9]: Loss 0.04836
+Epoch [ 9]: Loss 0.04374
+Validation: Loss 0.04304 Accuracy 1.00000
+Validation: Loss 0.04283 Accuracy 1.00000
+Epoch [ 10]: Loss 0.04373
+Epoch [ 10]: Loss 0.03963
+Epoch [ 10]: Loss 0.04024
+Epoch [ 10]: Loss 0.03893
+Epoch [ 10]: Loss 0.04085
+Epoch [ 10]: Loss 0.03933
+Epoch [ 10]: Loss 0.02782
+Validation: Loss 0.03470 Accuracy 1.00000
+Validation: Loss 0.03451 Accuracy 1.00000
+Epoch [ 11]: Loss 0.03413
+Epoch [ 11]: Loss 0.03603
+Epoch [ 11]: Loss 0.03246
+Epoch [ 11]: Loss 0.03142
+Epoch [ 11]: Loss 0.03040
+Epoch [ 11]: Loss 0.03279
+Epoch [ 11]: Loss 0.03336
+Validation: Loss 0.02942 Accuracy 1.00000
+Validation: Loss 0.02924 Accuracy 1.00000
+Epoch [ 12]: Loss 0.03113
+Epoch [ 12]: Loss 0.02712
+Epoch [ 12]: Loss 0.02845
+Epoch [ 12]: Loss 0.02904
+Epoch [ 12]: Loss 0.02709
+Epoch [ 12]: Loss 0.02722
+Epoch [ 12]: Loss 0.02449
+Validation: Loss 0.02555 Accuracy 1.00000
+Validation: Loss 0.02540 Accuracy 1.00000
+Epoch [ 13]: Loss 0.02730
+Epoch [ 13]: Loss 0.02638
+Epoch [ 13]: Loss 0.02358
+Epoch [ 13]: Loss 0.02337
+Epoch [ 13]: Loss 0.02417
+Epoch [ 13]: Loss 0.02397
+Epoch [ 13]: Loss 0.02159
+Validation: Loss 0.02258 Accuracy 1.00000
+Validation: Loss 0.02243 Accuracy 1.00000
+Epoch [ 14]: Loss 0.02377
+Epoch [ 14]: Loss 0.02260
+Epoch [ 14]: Loss 0.02070
+Epoch [ 14]: Loss 0.02170
+Epoch [ 14]: Loss 0.02060
+Epoch [ 14]: Loss 0.02212
+Epoch [ 14]: Loss 0.02141
+Validation: Loss 0.02019 Accuracy 1.00000
+Validation: Loss 0.02006 Accuracy 1.00000
+Epoch [ 15]: Loss 0.02146
+Epoch [ 15]: Loss 0.01937
+Epoch [ 15]: Loss 0.02047
+Epoch [ 15]: Loss 0.01826
+Epoch [ 15]: Loss 0.01953
+Epoch [ 15]: Loss 0.01824
+Epoch [ 15]: Loss 0.02201
+Validation: Loss 0.01821 Accuracy 1.00000
+Validation: Loss 0.01809 Accuracy 1.00000
+Epoch [ 16]: Loss 0.01872
+Epoch [ 16]: Loss 0.01647
+Epoch [ 16]: Loss 0.01868
+Epoch [ 16]: Loss 0.01763
+Epoch [ 16]: Loss 0.01802
+Epoch [ 16]: Loss 0.01730
+Epoch [ 16]: Loss 0.01691
+Validation: Loss 0.01653 Accuracy 1.00000
+Validation: Loss 0.01642 Accuracy 1.00000
+Epoch [ 17]: Loss 0.01638
+Epoch [ 17]: Loss 0.01693
+Epoch [ 17]: Loss 0.01747
+Epoch [ 17]: Loss 0.01530
+Epoch [ 17]: Loss 0.01570
+Epoch [ 17]: Loss 0.01579
+Epoch [ 17]: Loss 0.01431
+Validation: Loss 0.01511 Accuracy 1.00000
+Validation: Loss 0.01501 Accuracy 1.00000
+Epoch [ 18]: Loss 0.01395
+Epoch [ 18]: Loss 0.01493
+Epoch [ 18]: Loss 0.01631
+Epoch [ 18]: Loss 0.01388
+Epoch [ 18]: Loss 0.01496
+Epoch [ 18]: Loss 0.01520
+Epoch [ 18]: Loss 0.01366
+Validation: Loss 0.01390 Accuracy 1.00000
+Validation: Loss 0.01381 Accuracy 1.00000
+Epoch [ 19]: Loss 0.01337
+Epoch [ 19]: Loss 0.01481
+Epoch [ 19]: Loss 0.01359
+Epoch [ 19]: Loss 0.01293
+Epoch [ 19]: Loss 0.01317
+Epoch [ 19]: Loss 0.01404
+Epoch [ 19]: Loss 0.01416
+Validation: Loss 0.01286 Accuracy 1.00000
+Validation: Loss 0.01277 Accuracy 1.00000
+Epoch [ 20]: Loss 0.01286
+Epoch [ 20]: Loss 0.01335
+Epoch [ 20]: Loss 0.01259
+Epoch [ 20]: Loss 0.01343
+Epoch [ 20]: Loss 0.01294
+Epoch [ 20]: Loss 0.01124
+Epoch [ 20]: Loss 0.01124
+Validation: Loss 0.01194 Accuracy 1.00000
+Validation: Loss 0.01186 Accuracy 1.00000
+Epoch [ 21]: Loss 0.01159
+Epoch [ 21]: Loss 0.01229
+Epoch [ 21]: Loss 0.01273
+Epoch [ 21]: Loss 0.01021
+Epoch [ 21]: Loss 0.01159
+Epoch [ 21]: Loss 0.01191
+Epoch [ 21]: Loss 0.01311
+Validation: Loss 0.01111 Accuracy 1.00000
+Validation: Loss 0.01104 Accuracy 1.00000
+Epoch [ 22]: Loss 0.01112
+Epoch [ 22]: Loss 0.01155
+Epoch [ 22]: Loss 0.01068
+Epoch [ 22]: Loss 0.01120
+Epoch [ 22]: Loss 0.00993
+Epoch [ 22]: Loss 0.01129
+Epoch [ 22]: Loss 0.01098
+Validation: Loss 0.01033 Accuracy 1.00000
+Validation: Loss 0.01026 Accuracy 1.00000
+Epoch [ 23]: Loss 0.00950
+Epoch [ 23]: Loss 0.01102
+Epoch [ 23]: Loss 0.01060
+Epoch [ 23]: Loss 0.01058
+Epoch [ 23]: Loss 0.00987
+Epoch [ 23]: Loss 0.01006
+Epoch [ 23]: Loss 0.00747
+Validation: Loss 0.00952 Accuracy 1.00000
+Validation: Loss 0.00945 Accuracy 1.00000
+Epoch [ 24]: Loss 0.00960
+Epoch [ 24]: Loss 0.00995
+Epoch [ 24]: Loss 0.00883
+Epoch [ 24]: Loss 0.00888
+Epoch [ 24]: Loss 0.00955
+Epoch [ 24]: Loss 0.00915
+Epoch [ 24]: Loss 0.00884
+Validation: Loss 0.00861 Accuracy 1.00000
+Validation: Loss 0.00856 Accuracy 1.00000
+Epoch [ 25]: Loss 0.00958
+Epoch [ 25]: Loss 0.00920
+Epoch [ 25]: Loss 0.00803
+Epoch [ 25]: Loss 0.00769
+Epoch [ 25]: Loss 0.00804
+Epoch [ 25]: Loss 0.00784
+Epoch [ 25]: Loss 0.00760
+Validation: Loss 0.00766 Accuracy 1.00000
+Validation: Loss 0.00762 Accuracy 1.00000
We can save the model using JLD2 (or any other serialization library of your choice). Note that we transfer the model to the CPU before saving. Additionally, we recommend that you don't save the model struct and only save the parameters and states.
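A minimal sketch of the save step (the file name is arbitrary; ps_trained and st_trained were already moved to the CPU by main above):
julia
using JLD2

# `main` returned the parameters and states already on the CPU (via `cpu_device()`)
@save "trained_model.jld2" ps_trained st_trained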
`,45)]))}const d=a(l,[["render",e]]);export{r as __pageData,d as default};
diff --git a/previews/PR1023/assets/tutorials_beginner_4_SimpleChains.md.BJAmU5gZ.js b/previews/PR1023/assets/tutorials_beginner_4_SimpleChains.md.BJAmU5gZ.js
new file mode 100644
index 0000000000..ed2337fcda
--- /dev/null
+++ b/previews/PR1023/assets/tutorials_beginner_4_SimpleChains.md.BJAmU5gZ.js
@@ -0,0 +1,143 @@
+import{_ as i,c as a,a2 as n,o as t}from"./chunks/framework.DFwXuivk.js";const g=JSON.parse('{"title":"MNIST Classification with SimpleChains","description":"","frontmatter":{},"headers":[],"relativePath":"tutorials/beginner/4_SimpleChains.md","filePath":"tutorials/beginner/4_SimpleChains.md","lastUpdated":null}'),p={name:"tutorials/beginner/4_SimpleChains.md"};function l(h,s,e,k,r,d){return t(),a("div",null,s[0]||(s[0]=[n(`
SimpleChains.jl is an excellent framework for training small neural networks. In this tutorial we will demonstrate how to use the same API as Lux.jl to train a model using SimpleChains.jl. We will use the tutorial from SimpleChains.jl as a reference.
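As a rough sketch of the workflow (the model definition below is hypothetical; ToSimpleChainsAdaptor and the static input dimensions are the pieces that switch the backend):
julia
using Lux, SimpleChains, Random
using Static: static

# Hypothetical small Lux model for 28×28×1 MNIST inputs
lux_model = Chain(FlattenLayer(), Dense(784 => 20, tanh), Dense(20 => 10))

# Convert it to use SimpleChains.jl as the backend while keeping the Lux API
adaptor = ToSimpleChainsAdaptor((static(28), static(28), static(1)))
simple_chains_model = adaptor(lux_model)
ps, st = Lux.setup(Random.default_rng(), simple_chains_model)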
[ 1/10] Time 106.58s Training Accuracy: 22.44% Test Accuracy: 19.50%
+[ 2/10] Time 106.61s Training Accuracy: 47.06% Test Accuracy: 45.50%
+[ 3/10] Time 112.24s Training Accuracy: 61.50% Test Accuracy: 61.00%
+[ 4/10] Time 115.93s Training Accuracy: 69.89% Test Accuracy: 65.00%
+[ 5/10] Time 118.22s Training Accuracy: 75.22% Test Accuracy: 74.00%
+[ 6/10] Time 112.80s Training Accuracy: 78.44% Test Accuracy: 77.50%
+[ 7/10] Time 108.41s Training Accuracy: 81.22% Test Accuracy: 81.00%
+[ 8/10] Time 112.49s Training Accuracy: 83.94% Test Accuracy: 80.50%
+[ 9/10] Time 113.54s Training Accuracy: 85.89% Test Accuracy: 84.50%
+[10/10] Time 113.99s Training Accuracy: 87.11% Test Accuracy: 84.50%
Now we will train the SimpleChains model:
julia
train(simple_chains_model)
[ 1/10] Time 18.70s Training Accuracy: 29.06% Test Accuracy: 23.50%
+[ 2/10] Time 17.64s Training Accuracy: 45.83% Test Accuracy: 43.00%
+[ 3/10] Time 17.64s Training Accuracy: 62.72% Test Accuracy: 57.50%
+[ 4/10] Time 17.64s Training Accuracy: 65.67% Test Accuracy: 61.50%
+[ 5/10] Time 17.65s Training Accuracy: 74.72% Test Accuracy: 68.50%
+[ 6/10] Time 17.63s Training Accuracy: 79.61% Test Accuracy: 77.00%
+[ 7/10] Time 17.64s Training Accuracy: 81.83% Test Accuracy: 77.00%
+[ 8/10] Time 17.63s Training Accuracy: 83.94% Test Accuracy: 79.50%
+[ 9/10] Time 17.65s Training Accuracy: 84.50% Test Accuracy: 84.50%
+[10/10] Time 17.63s Training Accuracy: 87.78% Test Accuracy: 83.50%
On my local machine we see a 3-4x speedup when using SimpleChains.jl. The server on which this documentation is built is not ideal for CPU benchmarking; hence, the speedup may not be as significant, and there might even be regressions.
`,32)]))}const c=i(p,[["render",l]]);export{g as __pageData,c as default};
diff --git a/previews/PR1023/assets/tutorials_beginner_5_OptimizationIntegration.md.B9JkcgjS.js b/previews/PR1023/assets/tutorials_beginner_5_OptimizationIntegration.md.B9JkcgjS.js
new file mode 100644
index 0000000000..2590724477
--- /dev/null
+++ b/previews/PR1023/assets/tutorials_beginner_5_OptimizationIntegration.md.B9JkcgjS.js
@@ -0,0 +1,181 @@
+import{_ as s,c as i,a2 as a,o as n}from"./chunks/framework.DFwXuivk.js";const d=JSON.parse('{"title":"Training Lux Models using Optimization.jl","description":"","frontmatter":{},"headers":[],"relativePath":"tutorials/beginner/5_OptimizationIntegration.md","filePath":"tutorials/beginner/5_OptimizationIntegration.md","lastUpdated":null}'),t={name:"tutorials/beginner/5_OptimizationIntegration.md"};function p(l,A,h,e,k,E){return n(),i("div",null,A[0]||(A[0]=[a(`
Lux's native Training.TrainState is a great API for gradient-based learning of neural networks; however, it is geared towards using Optimisers.jl as the backend. Oftentimes we want to train neural networks with other optimization methods such as BFGS or LBFGS. In this tutorial, we will show how to train Lux models with Optimization.jl, which provides a simple unified interface to various optimization methods.
We will base our tutorial on the minibatching tutorial from the official Optimization.jl docs.
Neural ODE
This tutorial uses a Neural ODE, however, we won't discuss that part in this tutorial. Please refer to the Neural ODE tutorial for more information.
We will define the DataLoader to batch over the data; additionally, we will pipe it through the gdev device to move the data to the GPU on each iteration.
By default, gdev will move all objects to the GPU. But we don't want to move the time vector to the GPU, so we will wrap it in a struct, as sketched below.
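A minimal sketch of this wrapper (the Base.length/Base.getindex overloads and the batchsize are assumptions; the field name t matches its use as t_batch.t in the training code below):
julia
struct TimeWrapper{T}
    t::T
end

# Assumed overloads so that `DataLoader` can still batch over the wrapped times
Base.length(t::TimeWrapper) = length(t.t)
Base.getindex(t::TimeWrapper, i) = TimeWrapper(t.t[i])

# Assumed construction: batch the data, then pipe the loader through `gdev`
dataloader = DataLoader((ode_data, TimeWrapper(t)); batchsize=8) |> gdev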
Here we are using different optimization methods for demonstration purposes; this problem is simple enough not to require them.
Optimization.jl requires the parameters to be an abstract array, hence we will construct a ComponentArray to store them.
Parameter Estimation vs State Estimation
Optimization.jl performs state estimation, which effectively means that for a function f(u, p), it tries to compute the optimal u for a given p. This terminology might be confusing to ML practitioners, since in the ML world we usually do parameter estimation. This effectively means that the u in Optimization.jl corresponds to our model parameters, which are being optimized.
julia
function train_model(dataloader)
+ model = Chain(Dense(2, 32, tanh), Dense(32, 32, tanh), Dense(32, 2))
+ ps, st = Lux.setup(Random.default_rng(), model)
+
+ ps_ca = ComponentArray(ps) |> gdev
+ st = st |> gdev
+
+ function callback(state, l)
+ state.iter % 25 == 1 && @printf "Iteration: %5d, Loss: %.6e\\n" state.iter l
+ return l < 1e-8 ## Terminate if loss is small
+ end
+
+ smodel = StatefulLuxLayer{true}(model, nothing, st)
+
+ function loss_adjoint(θ, (u_batch, t_batch))
+ t_batch = t_batch.t
+ u0 = u_batch[:, 1]
+ dudt(u, p, t) = smodel(u, p)
+ prob = ODEProblem(dudt, u0, (t_batch[1], t_batch[end]), θ)
+ pred = convert(AbstractArray, solve(prob, Tsit5(); saveat=t_batch))
+ return MSELoss()(pred, u_batch)
+ end
+
+ # Define the Optimization Function that takes in the optimization state (our parameters)
+ # and optimization parameters (nothing in our case) and data from the dataloader and
+ # returns the loss.
+ opt_func = OptimizationFunction(loss_adjoint, Optimization.AutoZygote())
+ opt_prob = OptimizationProblem(opt_func, ps_ca, dataloader)
+
+ epochs = 25
+ res_adam = solve(opt_prob, Optimisers.Adam(0.001); callback, epochs)
+
+ # Let's finetune a bit with L-BFGS
+ opt_prob = OptimizationProblem(opt_func, res_adam.u, (gdev(ode_data), TimeWrapper(t)))
+ res_lbfgs = solve(opt_prob, LBFGS(); callback, maxiters=epochs)
+
+ # Now that we have a good fit, let's train it on the entire dataset without
+ # Minibatching. We need to do this since ODE solves can lead to accumulated errors if
+ # the model was trained on individual parts (without a data-shooting approach).
+ opt_prob = remake(opt_prob; u0=res_lbfgs.u)
+ res = solve(opt_prob, Optimisers.Adam(0.005); maxiters=500, callback)
+
+ return StatefulLuxLayer{true}(model, res.u, smodel.st)
+end
+
+trained_model = train_model(dataloader)
`,29)]))}const g=s(t,[["render",p]]);export{d as __pageData,g as default};
diff --git a/previews/PR1023/assets/tutorials_index.md.amQ7-phS.js b/previews/PR1023/assets/tutorials_index.md.amQ7-phS.js
new file mode 100644
index 0000000000..f358cfd60e
--- /dev/null
+++ b/previews/PR1023/assets/tutorials_index.md.amQ7-phS.js
@@ -0,0 +1 @@
+import{d,o as r,c as n,j as e,k as f,g as b,t as p,_ as m,F as _,C as w,b as v,K as x,a,G as s}from"./chunks/framework.DFwXuivk.js";const y={class:"img-box"},N=["href"],D=["src"],L={class:"transparent-box1"},P={class:"caption"},T={class:"transparent-box2"},I={class:"subcaption"},k={class:"opacity-low"},C=d({__name:"GalleryImage",props:{href:{},src:{},caption:{},desc:{}},setup(u){return(i,l)=>(r(),n("div",y,[e("a",{href:i.href},[e("img",{src:f(b)(i.src),height:"150px",alt:""},null,8,D),e("div",L,[e("div",P,[e("h2",null,p(i.caption),1)])]),e("div",T,[e("div",I,[e("p",k,p(i.desc),1)])])],8,N)]))}}),j=m(C,[["__scopeId","data-v-06a0366f"]]),S={class:"gallery-image"},E=d({__name:"Gallery",props:{images:{}},setup(u){return(i,l)=>(r(),n("div",S,[(r(!0),n(_,null,w(i.images,c=>(r(),v(j,x({ref_for:!0},c),null,16))),256))]))}}),o=m(E,[["__scopeId","data-v-578d61bc"]]),F=JSON.parse('{"title":"Tutorials","description":"","frontmatter":{},"headers":[],"relativePath":"tutorials/index.md","filePath":"tutorials/index.md","lastUpdated":null}'),M={name:"tutorials/index.md"},O=d({...M,setup(u){const i=[{href:"beginner/1_Basics",src:"https://picsum.photos/350/250?image=444",caption:"Julia & Lux for the Uninitiated",desc:"How to get started with Julia and Lux for those who have never used Julia before."},{href:"beginner/2_PolynomialFitting",src:"../mlp.webp",caption:"Fitting a Polynomial using MLP",desc:"Learn the Basics of Lux by fitting a Multi-Layer Perceptron to a Polynomial."},{href:"beginner/3_SimpleRNN",src:"../lstm-illustrative.webp",caption:"Training a Simple LSTM",desc:"Learn how to define custom layers and train an RNN on time-series data."},{href:"beginner/4_SimpleChains",src:"../blas_optimizations.jpg",caption:"Use SimpleChains.jl as a Backend",desc:"Learn how to train small neural networks really fast on CPU."},{href:"beginner/5_OptimizationIntegration",src:"../optimization_integration.png",caption:"Fitting with Optimization.jl",desc:"Learn how to use Optimization.jl with Lux (on GPUs)."},{href:"https://luxdl.github.io/Boltz.jl/stable/tutorials/1_GettingStarted",src:"https://production-media.paperswithcode.com/datasets/ImageNet-0000000008-f2e87edd_Y0fT5zg.jpg",caption:"Pre-Built Deep Learning Models",desc:"Use Boltz.jl to load pre-built DL and SciML models."}],l=[{href:"intermediate/1_NeuralODE",src:"../mnist.jpg",caption:"MNIST Classification using Neural ODE",desc:"Train a Neural Ordinary Differential Equations to classify MNIST Images."},{href:"intermediate/2_BayesianNN",src:"https://github.com/TuringLang.png",caption:"Bayesian Neural Networks",desc:"Figure out how to use Probabilistic Programming Frameworks like Turing with Lux."},{href:"intermediate/3_HyperNet",src:"../hypernet.jpg",caption:"Training a HyperNetwork",desc:"Train a hypernetwork to work on multiple datasets by predicting NN parameters."},{href:"intermediate/4_PINN2DPDE",src:"../pinn_nested_ad.gif",caption:"Training a PINN",desc:"Train a PINN to solve 2D PDEs (using Nested AD)."}],c=[{href:"advanced/1_GravitationalWaveForm",src:"../gravitational_waveform.png",caption:"Neural ODE to Model Gravitational Waveforms",desc:"Training a Neural ODE to fit simulated data of gravitational waveforms."},{href:"https://luxdl.github.io/Boltz.jl/stable/tutorials/2_SymbolicOptimalControl",src:"../symbolic_optimal_control.png",caption:"Optimal Control with Symbolic UDE",desc:"Train a UDE and replace a part of it with Symbolic 
Regression."}],h=[{href:"https://github.com/LuxDL/Lux.jl/tree/main/examples/ImageNet",src:"https://production-media.paperswithcode.com/datasets/ImageNet-0000000008-f2e87edd_Y0fT5zg.jpg",caption:"ImageNet Classification",desc:"Train Large Image Classifiers using Lux (on Distributed GPUs)."},{href:"https://github.com/LuxDL/Lux.jl/tree/main/examples/DDIM",src:"https://raw.githubusercontent.com/LuxDL/Lux.jl/main/examples/DDIM/assets/flowers_generated.png",caption:"Denoising Diffusion Implicit Model (DDIM)",desc:"Train a Diffusion Model to generate images from Gaussian noises."},{href:"https://github.com/LuxDL/Lux.jl/tree/main/examples/ConvMixer",src:"https://datasets.activeloop.ai/wp-content/uploads/2022/09/CIFAR-10-dataset-Activeloop-Platform-visualization-image-1.webp",caption:"ConvMixer on CIFAR-10",desc:"Train ConvMixer on CIFAR-10 to 90% accuracy within 10 minutes."}],g=[{href:"https://docs.sciml.ai/Overview/stable/showcase/pinngpu/",src:"../pinn.gif",caption:"GPU-Accelerated Physics-Informed Neural Networks",desc:"Use Machine Learning (PINNs) to solve the Heat Equation PDE on a GPU."},{href:"https://docs.sciml.ai/DiffEqFlux/stable/examples/neural_ode_weather_forecast/",src:"../weather-neural-ode.gif",caption:"Weather Forecasting with Neural ODEs",desc:"Train a neural ODEs to a multidimensional weather dataset and use it for weather forecasting."},{href:"https://docs.sciml.ai/SciMLSensitivity/stable/examples/sde/SDE_control/",src:"../neural-sde.png",caption:"Controlling Stochastic Differential Equations",desc:"Control the time evolution of a continuously monitored qubit described by an SDE with multiplicative scalar noise."},{href:"https://github.com/Dale-Black/ComputerVisionTutorials.jl/",src:"https://raw.githubusercontent.com/Dale-Black/ComputerVisionTutorials.jl/main/assets/image-seg-green.jpeg",caption:"Medical Image Segmentation",desc:"Explore various aspects of deep learning for medical imaging and a comprehensive overview of Julia packages."},{href:"https://github.com/agdestein/NeuralClosureTutorials",src:"https://raw.githubusercontent.com/agdestein/NeuralClosureTutorials/main/assets/navier_stokes.gif",caption:"Neural PDE closures",desc:"Learn an unknown term in a PDE using convolutional neural networks and Fourier neural operators."}];return(B,t)=>(r(),n("div",null,[t[0]||(t[0]=e("h1",{id:"tutorials",tabindex:"-1"},[a("Tutorials "),e("a",{class:"header-anchor",href:"#tutorials","aria-label":'Permalink to "Tutorials"'},"")],-1)),t[1]||(t[1]=e("h2",{id:"beginner-tutorials",tabindex:"-1"},[a("Beginner Tutorials "),e("a",{class:"header-anchor",href:"#beginner-tutorials","aria-label":'Permalink to "Beginner Tutorials"'},"")],-1)),s(o,{images:i}),t[2]||(t[2]=e("h2",{id:"intermediate-tutorials",tabindex:"-1"},[a("Intermediate Tutorials "),e("a",{class:"header-anchor",href:"#intermediate-tutorials","aria-label":'Permalink to "Intermediate Tutorials"'},"")],-1)),s(o,{images:l}),t[3]||(t[3]=e("h2",{id:"advanced-tutorials",tabindex:"-1"},[a("Advanced Tutorials "),e("a",{class:"header-anchor",href:"#advanced-tutorials","aria-label":'Permalink to "Advanced Tutorials"'},"")],-1)),s(o,{images:c}),t[4]||(t[4]=e("h2",{id:"larger-models",tabindex:"-1"},[a("Larger Models "),e("a",{class:"header-anchor",href:"#larger-models","aria-label":'Permalink to "Larger Models"'},"")],-1)),t[5]||(t[5]=e("div",{class:"warning custom-block"},[e("p",{class:"custom-block-title"},"WARNING"),e("p",null,"These models are part of the Lux examples, however, these are larger model that cannot be run on CI and aren't 
frequently tested. If you find a bug in one of these models, please open an issue or PR to fix it.")],-1)),s(o,{images:h}),t[6]||(t[6]=e("h2",{id:"selected-3rd-party-tutorials",tabindex:"-1"},[a("Selected 3rd Party Tutorials "),e("a",{class:"header-anchor",href:"#selected-3rd-party-tutorials","aria-label":'Permalink to "Selected 3rd Party Tutorials"'},"")],-1)),t[7]||(t[7]=e("div",{class:"warning custom-block"},[e("p",{class:"custom-block-title"},"WARNING"),e("p",null,[a("These tutorials are developed by the community and may not be up-to-date with the latest version of "),e("code",null,"Lux.jl"),a(". Please refer to the official documentation for the most up-to-date information.")]),e("p",null,[a("Please open an issue (ideally both at "),e("code",null,"Lux.jl"),a(" and at the downstream linked package) if any of them are non-functional and we will try to get them updated.")])],-1)),s(o,{images:g}),t[8]||(t[8]=e("div",{class:"tip custom-block"},[e("p",{class:"custom-block-title"},"TIP"),e("p",null,[a("If you found an amazing tutorial showcasing "),e("code",null,"Lux.jl"),a(" online, or wrote one yourself, please open an issue or PR to add it to the list!")])],-1))]))}});export{F as __pageData,O as default};
diff --git a/previews/PR1023/assets/tutorials_intermediate_1_NeuralODE.md.C9riTwge.js b/previews/PR1023/assets/tutorials_intermediate_1_NeuralODE.md.C9riTwge.js
new file mode 100644
index 0000000000..9ac270633b
--- /dev/null
+++ b/previews/PR1023/assets/tutorials_intermediate_1_NeuralODE.md.C9riTwge.js
@@ -0,0 +1,272 @@
+import{_ as i,c as a,a2 as n,o as t}from"./chunks/framework.DFwXuivk.js";const E=JSON.parse('{"title":"MNIST Classification using Neural ODEs","description":"","frontmatter":{},"headers":[],"relativePath":"tutorials/intermediate/1_NeuralODE.md","filePath":"tutorials/intermediate/1_NeuralODE.md","lastUpdated":null}'),e={name:"tutorials/intermediate/1_NeuralODE.md"};function l(p,s,h,k,r,d){return t(),a("div",null,s[0]||(s[0]=[n(`
To understand Neural ODEs, users should look up these lecture notes. We recommend that users use DiffEqFlux.jl directly instead of implementing Neural ODEs from scratch.
First, we will use the @compact macro to define the Neural ODE layer.
julia
function NeuralODECompact(
        model::Lux.AbstractLuxLayer; solver=Tsit5(), tspan=(0.0f0, 1.0f0), kwargs...)
    return @compact(; model, solver, tspan, kwargs...) do x, p
        dudt(u, p, t) = vec(model(reshape(u, size(x)), p))
        # Note the `p.model` here
        prob = ODEProblem(ODEFunction{false}(dudt), vec(x), tspan, p.model)
        @return solve(prob, solver; kwargs...)
    end
end
NeuralODECompact (generic function with 1 method)
We recommend using the @compact macro for creating custom layers. The implementation below exists mostly for historical reasons, from when @compact was not part of the stable API. It also helps users understand how the layer interface of Lux works.
The NeuralODE is a ContainerLayer, which stores a model. The parameters and states of the NeuralODE are the same as those of the underlying model.
OrdinaryDiffEq.jl can deal with non-Vector inputs! However, certain discrete sensitivities like ReverseDiffAdjoint can't handle non-Vector inputs. Hence, we need to convert the input and output of the ODE solver to a Vector, as sketched below.
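The non-@compact implementation itself is not preserved in this page's source, but a minimal sketch of such a container layer could look like the following (assuming Lux's AbstractLuxWrapperLayer and a StatefulLuxLayer to carry the state; the field names are assumptions):
julia
# Sketch of a non-@compact Neural ODE container layer (field names assumed)
struct NeuralODE{M <: Lux.AbstractLuxLayer, So, T, K} <:
       Lux.AbstractLuxWrapperLayer{:model}
    model::M      # underlying Lux model
    solver::So    # ODE solver, e.g. Tsit5()
    tspan::T      # integration time span
    kwargs::K     # forwarded to `solve` (e.g. sensealg)
end

function NeuralODE(
        model::Lux.AbstractLuxLayer; solver=Tsit5(), tspan=(0.0f0, 1.0f0), kwargs...)
    return NeuralODE(model, solver, tspan, kwargs)
end

function (n::NeuralODE)(x, ps, st)
    model = StatefulLuxLayer{true}(n.model, nothing, st)
    # Work with vectors since some discrete sensitivities can't handle matrices
    dudt(u, p, t) = vec(model(reshape(u, size(x)), p))
    prob = ODEProblem(ODEFunction{false}(dudt), vec(x), n.tspan, ps)
    return solve(prob, n.solver; n.kwargs...), model.st
end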
julia
function train(model_function; cpu::Bool=false, kwargs...)
    dev = cpu ? cpu_device() : gpu_device()
    model, ps, st = create_model(model_function; dev, kwargs...)

    # Training
    train_dataloader, test_dataloader = loadmnist(128, 0.9) |> dev

    tstate = Training.TrainState(model, ps, st, Adam(0.001f0))

    ### Let's train the model
    nepochs = 9
    for epoch in 1:nepochs
        stime = time()
        for (x, y) in train_dataloader
            _, _, _, tstate = Training.single_train_step!(
                AutoZygote(), logitcrossentropy, (x, y), tstate)
        end
        ttime = time() - stime

        tr_acc = accuracy(model, tstate.parameters, tstate.states, train_dataloader) * 100
        te_acc = accuracy(model, tstate.parameters, tstate.states, test_dataloader) * 100
        @printf "[%d/%d]\tTime %.4fs\tTraining Accuracy: %.5f%%\tTest \
            Accuracy: %.5f%%\n" epoch nepochs ttime tr_acc te_acc
    end
end

train(NeuralODECompact)
[1/9] Time 119.4158s Training Accuracy: 37.48148% Test Accuracy: 40.00000%
[2/9] Time 0.4958s Training Accuracy: 58.22222% Test Accuracy: 57.33333%
[3/9] Time 0.6961s Training Accuracy: 67.85185% Test Accuracy: 70.66667%
[4/9] Time 0.4869s Training Accuracy: 74.29630% Test Accuracy: 74.66667%
[5/9] Time 0.5064s Training Accuracy: 76.29630% Test Accuracy: 76.00000%
[6/9] Time 0.7482s Training Accuracy: 78.74074% Test Accuracy: 80.00000%
[7/9] Time 0.4736s Training Accuracy: 82.22222% Test Accuracy: 81.33333%
[8/9] Time 0.4883s Training Accuracy: 83.62963% Test Accuracy: 83.33333%
[9/9] Time 0.7453s Training Accuracy: 85.18519% Test Accuracy: 82.66667%
julia
train(NeuralODE)
[1/9] Time 36.4249s Training Accuracy: 37.48148% Test Accuracy: 40.00000%
[2/9] Time 0.4793s Training Accuracy: 57.18519% Test Accuracy: 57.33333%
[3/9] Time 0.6545s Training Accuracy: 68.37037% Test Accuracy: 68.00000%
[4/9] Time 0.4797s Training Accuracy: 73.77778% Test Accuracy: 75.33333%
[5/9] Time 0.4833s Training Accuracy: 76.14815% Test Accuracy: 77.33333%
[6/9] Time 0.7233s Training Accuracy: 79.48148% Test Accuracy: 80.66667%
[7/9] Time 0.4913s Training Accuracy: 81.25926% Test Accuracy: 80.66667%
[8/9] Time 0.4843s Training Accuracy: 83.40741% Test Accuracy: 82.66667%
[9/9] Time 0.7256s Training Accuracy: 84.81481% Test Accuracy: 82.00000%
We can also change the sensealg and train the model! GaussAdjoint allows you to use an arbitrary parameter structure, not just a flat vector (ComponentArray).
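A sketch of how that run might be invoked (the sensealg keyword is forwarded to solve through the layer's stored kwargs; the exact GaussAdjoint configuration is an assumption):
julia
train(NeuralODE; sensealg=GaussAdjoint(; autojacvec=ZygoteVJP()))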
[1/9] Time 42.6019s Training Accuracy: 37.48148% Test Accuracy: 40.00000%
[2/9] Time 0.5487s Training Accuracy: 57.55556% Test Accuracy: 54.00000%
[3/9] Time 0.4660s Training Accuracy: 69.85185% Test Accuracy: 69.33333%
[4/9] Time 0.4833s Training Accuracy: 72.51852% Test Accuracy: 74.00000%
[5/9] Time 0.4743s Training Accuracy: 75.33333% Test Accuracy: 76.00000%
[6/9] Time 0.4944s Training Accuracy: 78.88889% Test Accuracy: 79.33333%
[7/9] Time 0.6809s Training Accuracy: 81.03704% Test Accuracy: 80.00000%
[8/9] Time 0.4987s Training Accuracy: 83.77778% Test Accuracy: 81.33333%
[9/9] Time 0.5045s Training Accuracy: 85.25926% Test Accuracy: 82.66667%
But remember, some AD backends like ReverseDiff are not GPU compatible. For a model of this size, you will notice that training time is significantly lower on CPU than on GPU.
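A sketch of that CPU run (the specific sensealg configuration is an assumption; cpu=true uses the keyword defined in the train function above):
julia
train(NeuralODE; sensealg=InterpolatingAdjoint(; autojacvec=ReverseDiffVJP()), cpu=true)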
[1/9] Time 96.0630s Training Accuracy: 37.48148% Test Accuracy: 40.00000%
[2/9] Time 14.0172s Training Accuracy: 58.74074% Test Accuracy: 56.66667%
[3/9] Time 13.5410s Training Accuracy: 69.92593% Test Accuracy: 71.33333%
[4/9] Time 13.6407s Training Accuracy: 72.81481% Test Accuracy: 74.00000%
[5/9] Time 13.4329s Training Accuracy: 76.37037% Test Accuracy: 78.66667%
[6/9] Time 12.0878s Training Accuracy: 79.03704% Test Accuracy: 80.66667%
[7/9] Time 14.5981s Training Accuracy: 81.62963% Test Accuracy: 80.66667%
[8/9] Time 13.6945s Training Accuracy: 83.33333% Test Accuracy: 80.00000%
[9/9] Time 10.3098s Training Accuracy: 85.40741% Test Accuracy: 82.00000%
For completeness, let's also test out discrete sensitivities!
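A sketch of such a run (the choice of ReverseDiffAdjoint on CPU is an assumption consistent with the slower per-epoch times below):
julia
train(NeuralODE; sensealg=ReverseDiffAdjoint(), cpu=true)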
[1/9] Time 49.7652s Training Accuracy: 37.48148% Test Accuracy: 40.00000%
[2/9] Time 21.6687s Training Accuracy: 58.66667% Test Accuracy: 57.33333%
[3/9] Time 21.5681s Training Accuracy: 69.70370% Test Accuracy: 71.33333%
[4/9] Time 21.3427s Training Accuracy: 72.74074% Test Accuracy: 74.00000%
[5/9] Time 23.9941s Training Accuracy: 76.14815% Test Accuracy: 78.66667%
[6/9] Time 22.0233s Training Accuracy: 79.03704% Test Accuracy: 80.66667%
[7/9] Time 22.4246s Training Accuracy: 81.55556% Test Accuracy: 80.66667%
[8/9] Time 23.1968s Training Accuracy: 83.40741% Test Accuracy: 80.00000%
[9/9] Time 24.0997s Training Accuracy: 85.25926% Test Accuracy: 81.33333%
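The second discrete-sensitivity run below presumably corresponds to an invocation like the following (the specific choice of TrackerAdjoint, which unlike ReverseDiff works on GPU, is an assumption consistent with the fast per-epoch times):
julia
train(NeuralODE; sensealg=TrackerAdjoint())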
[1/9] Time 38.2440s Training Accuracy: 37.48148% Test Accuracy: 40.00000%
[2/9] Time 0.4759s Training Accuracy: 58.22222% Test Accuracy: 55.33333%
[3/9] Time 0.4745s Training Accuracy: 68.29630% Test Accuracy: 68.66667%
[4/9] Time 0.4670s Training Accuracy: 73.11111% Test Accuracy: 76.00000%
[5/9] Time 0.5117s Training Accuracy: 75.92593% Test Accuracy: 76.66667%
[6/9] Time 0.4779s Training Accuracy: 78.96296% Test Accuracy: 80.66667%
[7/9] Time 0.4705s Training Accuracy: 80.81481% Test Accuracy: 81.33333%
[8/9] Time 0.4590s Training Accuracy: 83.25926% Test Accuracy: 82.66667%
[9/9] Time 0.4555s Training Accuracy: 84.59259% Test Accuracy: 82.00000%
We might not see a significant difference in the training time, but let us investigate the type stabilities of the layers.
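One way to carry out that investigation (a sketch; it assumes model, ps, and st come from create_model as above and x is a sample batch from the data loader):
julia
@code_warntype model(x, ps, st)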
Bayesian Neural Network
We borrow this tutorial from the official Turing Docs. We will show how the explicit parameterization of Lux enables first-class composability with packages which expect flattened out parameter vectors.
Note: The tutorial in the official Turing docs is now using Lux instead of Flux.
We will use Turing.jl with Lux.jl to implement a classification algorithm. Let's start by importing the relevant libraries.
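The exact import block isn't preserved here; the sketch below is reconstructed from what the code in this tutorial uses, and the precise package set is an assumption:
julia
using Lux, Turing, CairoMakie, Random, Functors, LinearAlgebra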
Our goal here is to use a Bayesian neural network to classify points in an artificial dataset. The code below generates data points arranged in a box-like pattern and displays a graph of the dataset we'll be working with.
julia
# Number of points to generate
N = 80
M = round(Int, N / 4)
rng = Random.default_rng()
Random.seed!(rng, 1234)

# Generate artificial data
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
xt1s = Array([[x1s[i] + 0.5f0; x2s[i] + 0.5f0] for i in 1:M])
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
append!(xt1s, Array([[x1s[i] - 5.0f0; x2s[i] - 5.0f0] for i in 1:M]))

x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
xt0s = Array([[x1s[i] + 0.5f0; x2s[i] - 5.0f0] for i in 1:M])
x1s = rand(rng, Float32, M) * 4.5f0;
x2s = rand(rng, Float32, M) * 4.5f0;
append!(xt0s, Array([[x1s[i] - 5.0f0; x2s[i] + 0.5f0] for i in 1:M]))

# Store all the data for later
xs = [xt1s; xt0s]
ts = [ones(2 * M); zeros(2 * M)]

# Plot data points

function plot_data()
    x1 = first.(xt1s)
    y1 = last.(xt1s)
    x2 = first.(xt0s)
    y2 = last.(xt0s)

    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")

    scatter!(ax, x1, y1; markersize=16, color=:red, strokecolor=:black, strokewidth=2)
    scatter!(ax, x2, y2; markersize=16, color=:blue, strokecolor=:black, strokewidth=2)

    return fig
end

plot_data()
The next step is to define a feedforward neural network where we express our parameters as distributions, rather than as single points as in traditional neural networks. For this we will use Dense to define linear layers and compose them via Chain, both of which are neural network primitives from Lux. The network nn we create will have two hidden layers with tanh activations and one output layer with a sigmoid activation, as shown below.
The nn is an instance that acts as a function: it takes data, parameters, and the current state as inputs and outputs predictions. We will define distributions on the neural network parameters.
julia
# Construct a neural network using Lux
nn = Chain(Dense(2 => 3, tanh), Dense(3 => 2, tanh), Dense(2 => 1, sigmoid))

# Initialize the model weights and state
ps, st = Lux.setup(rng, nn)

Lux.parameterlength(nn) # number of parameters in NN
20
The probabilistic model specification below creates a parameters variable, which has IID normal entries. parameters represents all parameters of our neural net (weights and biases).
julia
# Create a regularization term and a Gaussian prior variance term.
alpha = 0.09
sig = sqrt(1.0 / alpha)
3.3333333333333335
Next we construct a named tuple from a sampled parameter vector. We could also use ComponentArrays here and simply broadcast to avoid doing this, but let's do it this way to avoid extra dependencies.
julia
function vector_to_parameters(ps_new::AbstractVector, ps::NamedTuple)
    @assert length(ps_new) == Lux.parameterlength(ps)
    i = 1
    function get_ps(x)
        z = reshape(view(ps_new, i:(i + length(x) - 1)), size(x))
        i += length(x)
        return z
    end
    return fmap(get_ps, ps)
end
vector_to_parameters (generic function with 1 method)
To interface with external libraries it is often desirable to use the StatefulLuxLayer to automatically handle the neural network states.
julia
const model = StatefulLuxLayer{true}(nn, nothing, st)

# Specify the probabilistic model.
@model function bayes_nn(xs, ts)
    # Sample the parameters
    nparameters = Lux.parameterlength(nn)
    parameters ~ MvNormal(zeros(nparameters), Diagonal(abs2.(sig .* ones(nparameters))))

    # Forward NN to make predictions
    preds = Lux.apply(model, xs, vector_to_parameters(parameters, ps))

    # Observe each prediction.
    for i in eachindex(ts)
        ts[i] ~ Bernoulli(preds[i])
    end
end
bayes_nn (generic function with 2 methods)
Inference can now be performed by calling sample. We use the HMC sampler here.
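A sketch of that call (the step size, number of leapfrog steps, and chain length are assumptions for illustration; reduce(hcat, xs) packs the data vectors into a 2×N matrix):
julia
# Draw 5000 posterior samples with Hamiltonian Monte Carlo
N_samples = 5000
ch = sample(bayes_nn(reduce(hcat, xs), ts), HMC(0.05, 4), N_samples)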
Now we extract the parameter samples from the sampled chain as θ (this is of size 5000 x 20 where 5000 is the number of iterations and 20 is the number of parameters). We'll use these primarily to determine how good our model's classifier is.
julia
# Extract all weight and bias parameters.
θ = MCMCChains.group(ch, :parameters).value;

# A helper to run the nn through data `x` using parameters `θ`
nn_forward(x, θ) = model(x, vector_to_parameters(θ, ps))

# Plot the data we have.
fig = plot_data()

# Find the index that provided the highest log posterior in the chain.
_, i = findmax(ch[:lp])

# Extract the max row value from i.
i = i.I[1]

# Plot the posterior distribution with a contour plot
x1_range = collect(range(-6; stop=6, length=25))
x2_range = collect(range(-6; stop=6, length=25))
Z = [nn_forward([x1, x2], θ[i, :])[1] for x1 in x1_range, x2 in x2_range]
contour!(x1_range, x2_range, Z; linewidth=3, colormap=:seaborn_bright)
fig
The contour plot above shows that the MAP method is not too bad at classifying our data. Now we can visualize our predictions.
The nn_predict function takes the average predicted value from a network parameterized by weights drawn from the MCMC chain.
julia
# Return the average predicted value across multiple weights.
nn_predict(x, θ, num) = mean([first(nn_forward(x, view(θ, i, :))) for i in 1:10:num])
nn_predict (generic function with 1 method)
Next, we use the nn_predict function to predict the value at a sample of points where the x1 and x2 coordinates range between -6 and 6. As we can see below, we still have a satisfactory fit to our data, and more importantly, we can now much more easily see where the neural network is uncertain about its predictions: those regions between cluster boundaries.
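A sketch of that evaluation (reusing plot_data, x1_range, x2_range, and θ from above; averaging over the full chain is an assumption):
julia
fig = plot_data()
# Mean posterior prediction over the grid
Z = [nn_predict([x1, x2], θ, 5000) for x1 in x1_range, x2 in x2_range]
contour!(x1_range, x2_range, Z; linewidth=3, colormap=:seaborn_bright)
fig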
Suppose we are interested in how the predictive power of our Bayesian neural network evolved between samples. In that case, the following graph displays an animation of the contour plot generated from the network weights in samples 1 to 5,000.
julia
fig = plot_data()
Z = [first(nn_forward([x1, x2], θ[1, :])) for x1 in x1_range, x2 in x2_range]
c = contour!(x1_range, x2_range, Z; linewidth=3, colormap=:seaborn_bright)
record(fig, "results.gif", 1:250:size(θ, 1)) do i
    fig.current_axis[].title = "Iteration: $i"
    Z = [first(nn_forward([x1, x2], θ[i, :])) for x1 in x1_range, x2 in x2_range]
    c[3] = Z
    return fig
end
Training a HyperNetwork on MNIST and FashionMNIST
Defining functions on the CompactLuxLayer requires some understanding of how the layer is structured; as such, we don't recommend doing it unless you are familiar with the internals. In this case, we simply write it to ignore the initialization of the core_network parameters.
julia
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
    return (; weight_generator=Lux.initialparameters(rng, hn.layers.weight_generator),)
end
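The HyperNet constructor itself is not preserved in this page's source. Below is a sketch of how it could be assembled with @compact; the ComponentArrays-based reparameterization and the dispatch=:HyperNet name are assumptions consistent with the override above:
julia
function HyperNet(
        weight_generator::Lux.AbstractLuxLayer, core_network::Lux.AbstractLuxLayer)
    # Axes of the core network's parameters, used to reshape the generated weights
    ca_axes = getaxes(ComponentArray(
        Lux.initialparameters(Random.default_rng(), core_network)))
    return @compact(; ca_axes, weight_generator, core_network,
        dispatch=:HyperNet) do (x, y)
        # Generate the core network's parameters from the dataset index `y`
        ps_new = ComponentArray(vec(weight_generator(y)), ca_axes)
        @return core_network(x, ps_new)
    end
end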
julia
function create_model()
    # Doesn't need to be an MLP; can be any Lux layer
    core_network = Chain(FlattenLayer(), Dense(784, 256, relu), Dense(256, 10))
    weight_generator = Chain(Embedding(2 => 32), Dense(32, 64, relu),
        Dense(64, Lux.parameterlength(core_network)))

    model = HyperNet(weight_generator, core_network)
    return model
end
[ 1/ 50] MNIST Time 70.85048s Training Accuracy: 61.23% Test Accuracy: 56.25%
[ 1/ 50] FashionMNIST Time 0.02819s Training Accuracy: 43.26% Test Accuracy: 37.50%
[ 2/ 50] MNIST Time 0.02797s Training Accuracy: 71.19% Test Accuracy: 68.75%
[ 2/ 50] FashionMNIST Time 0.02918s Training Accuracy: 56.25% Test Accuracy: 46.88%
[ 3/ 50] MNIST Time 0.02907s Training Accuracy: 79.39% Test Accuracy: 71.88%
[ 3/ 50] FashionMNIST Time 0.02807s Training Accuracy: 59.67% Test Accuracy: 53.12%
[ 4/ 50] MNIST Time 0.02442s Training Accuracy: 78.71% Test Accuracy: 68.75%
[ 4/ 50] FashionMNIST Time 0.02106s Training Accuracy: 68.36% Test Accuracy: 65.62%
[ 5/ 50] MNIST Time 0.02221s Training Accuracy: 83.79% Test Accuracy: 75.00%
[ 5/ 50] FashionMNIST Time 0.02173s Training Accuracy: 71.78% Test Accuracy: 62.50%
[ 6/ 50] MNIST Time 0.02186s Training Accuracy: 88.67% Test Accuracy: 75.00%
[ 6/ 50] FashionMNIST Time 0.02362s Training Accuracy: 72.95% Test Accuracy: 56.25%
[ 7/ 50] MNIST Time 0.02382s Training Accuracy: 90.92% Test Accuracy: 78.12%
[ 7/ 50] FashionMNIST Time 0.02322s Training Accuracy: 80.27% Test Accuracy: 68.75%
[ 8/ 50] MNIST Time 0.03652s Training Accuracy: 90.82% Test Accuracy: 78.12%
[ 8/ 50] FashionMNIST Time 0.02083s Training Accuracy: 76.46% Test Accuracy: 68.75%
[ 9/ 50] MNIST Time 0.02124s Training Accuracy: 94.63% Test Accuracy: 81.25%
[ 9/ 50] FashionMNIST Time 0.02080s Training Accuracy: 74.71% Test Accuracy: 65.62%
[ 10/ 50] MNIST Time 0.02075s Training Accuracy: 94.63% Test Accuracy: 81.25%
[ 10/ 50] FashionMNIST Time 0.02080s Training Accuracy: 77.34% Test Accuracy: 62.50%
[ 11/ 50] MNIST Time 0.02030s Training Accuracy: 96.29% Test Accuracy: 78.12%
[ 11/ 50] FashionMNIST Time 0.02048s Training Accuracy: 82.13% Test Accuracy: 78.12%
[ 12/ 50] MNIST Time 0.02080s Training Accuracy: 97.95% Test Accuracy: 78.12%
[ 12/ 50] FashionMNIST Time 0.02626s Training Accuracy: 81.84% Test Accuracy: 78.12%
[ 13/ 50] MNIST Time 0.02091s Training Accuracy: 98.44% Test Accuracy: 84.38%
[ 13/ 50] FashionMNIST Time 0.02084s Training Accuracy: 84.08% Test Accuracy: 71.88%
[ 14/ 50] MNIST Time 0.02098s Training Accuracy: 98.93% Test Accuracy: 81.25%
[ 14/ 50] FashionMNIST Time 0.02068s Training Accuracy: 85.55% Test Accuracy: 65.62%
[ 15/ 50] MNIST Time 0.02067s Training Accuracy: 99.22% Test Accuracy: 84.38%
[ 15/ 50] FashionMNIST Time 0.02068s Training Accuracy: 86.13% Test Accuracy: 68.75%
[ 16/ 50] MNIST Time 0.02060s Training Accuracy: 99.51% Test Accuracy: 81.25%
[ 16/ 50] FashionMNIST Time 0.02051s Training Accuracy: 86.13% Test Accuracy: 65.62%
[ 17/ 50] MNIST Time 0.02531s Training Accuracy: 99.61% Test Accuracy: 81.25%
[ 17/ 50] FashionMNIST Time 0.02054s Training Accuracy: 87.11% Test Accuracy: 71.88%
[ 18/ 50] MNIST Time 0.02092s Training Accuracy: 99.80% Test Accuracy: 81.25%
[ 18/ 50] FashionMNIST Time 0.02098s Training Accuracy: 88.28% Test Accuracy: 75.00%
[ 19/ 50] MNIST Time 0.02228s Training Accuracy: 99.80% Test Accuracy: 81.25%
[ 19/ 50] FashionMNIST Time 0.02067s Training Accuracy: 89.16% Test Accuracy: 71.88%
[ 20/ 50] MNIST Time 0.02038s Training Accuracy: 99.90% Test Accuracy: 81.25%
[ 20/ 50] FashionMNIST Time 0.02079s Training Accuracy: 89.26% Test Accuracy: 75.00%
[ 21/ 50] MNIST Time 0.02039s Training Accuracy: 99.90% Test Accuracy: 81.25%
[ 21/ 50] FashionMNIST Time 0.02023s Training Accuracy: 89.65% Test Accuracy: 75.00%
[ 22/ 50] MNIST Time 0.02084s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 22/ 50] FashionMNIST Time 0.02039s Training Accuracy: 89.94% Test Accuracy: 75.00%
[ 23/ 50] MNIST Time 0.02139s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 23/ 50] FashionMNIST Time 0.02072s Training Accuracy: 90.43% Test Accuracy: 71.88%
[ 24/ 50] MNIST Time 0.02055s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 24/ 50] FashionMNIST Time 0.02085s Training Accuracy: 90.72% Test Accuracy: 71.88%
[ 25/ 50] MNIST Time 0.02080s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 25/ 50] FashionMNIST Time 0.02870s Training Accuracy: 92.29% Test Accuracy: 75.00%
[ 26/ 50] MNIST Time 0.02078s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 26/ 50] FashionMNIST Time 0.02083s Training Accuracy: 92.38% Test Accuracy: 71.88%
[ 27/ 50] MNIST Time 0.02093s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 27/ 50] FashionMNIST Time 0.02037s Training Accuracy: 91.80% Test Accuracy: 75.00%
[ 28/ 50] MNIST Time 0.02083s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 28/ 50] FashionMNIST Time 0.02035s Training Accuracy: 92.97% Test Accuracy: 68.75%
[ 29/ 50] MNIST Time 0.02075s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 29/ 50] FashionMNIST Time 0.02075s Training Accuracy: 93.16% Test Accuracy: 71.88%
[ 30/ 50] MNIST Time 0.02654s Training Accuracy: 100.00% Test Accuracy: 81.25%
[ 30/ 50] FashionMNIST Time 0.02034s Training Accuracy: 92.09% Test Accuracy: 71.88%
[ 31/ 50] MNIST Time 0.02107s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 31/ 50] FashionMNIST Time 0.02075s Training Accuracy: 94.24% Test Accuracy: 71.88%
[ 32/ 50] MNIST Time 0.02297s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 32/ 50] FashionMNIST Time 0.02142s Training Accuracy: 93.65% Test Accuracy: 71.88%
[ 33/ 50] MNIST Time 0.02200s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 33/ 50] FashionMNIST Time 0.02105s Training Accuracy: 94.34% Test Accuracy: 75.00%
[ 34/ 50] MNIST Time 0.02155s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 34/ 50] FashionMNIST Time 0.02781s Training Accuracy: 93.65% Test Accuracy: 68.75%
[ 35/ 50] MNIST Time 0.02128s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 35/ 50] FashionMNIST Time 0.02310s Training Accuracy: 95.12% Test Accuracy: 71.88%
[ 36/ 50] MNIST Time 0.02250s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 36/ 50] FashionMNIST Time 0.02097s Training Accuracy: 95.90% Test Accuracy: 71.88%
[ 37/ 50] MNIST Time 0.02084s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 37/ 50] FashionMNIST Time 0.02062s Training Accuracy: 95.80% Test Accuracy: 75.00%
[ 38/ 50] MNIST Time 0.02122s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 38/ 50] FashionMNIST Time 0.02084s Training Accuracy: 95.70% Test Accuracy: 71.88%
[ 39/ 50] MNIST Time 0.01987s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 39/ 50] FashionMNIST Time 0.02035s Training Accuracy: 96.88% Test Accuracy: 71.88%
[ 40/ 50] MNIST Time 0.02083s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 40/ 50] FashionMNIST Time 0.02133s Training Accuracy: 96.68% Test Accuracy: 71.88%
[ 41/ 50] MNIST Time 0.02054s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 41/ 50] FashionMNIST Time 0.02079s Training Accuracy: 97.07% Test Accuracy: 71.88%
[ 42/ 50] MNIST Time 0.02094s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 42/ 50] FashionMNIST Time 0.02084s Training Accuracy: 97.36% Test Accuracy: 71.88%
[ 43/ 50] MNIST Time 0.02632s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 43/ 50] FashionMNIST Time 0.02029s Training Accuracy: 97.36% Test Accuracy: 71.88%
[ 44/ 50] MNIST Time 0.02053s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 44/ 50] FashionMNIST Time 0.02080s Training Accuracy: 97.75% Test Accuracy: 71.88%
[ 45/ 50] MNIST Time 0.02082s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 45/ 50] FashionMNIST Time 0.02060s Training Accuracy: 97.85% Test Accuracy: 75.00%
[ 46/ 50] MNIST Time 0.02029s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 46/ 50] FashionMNIST Time 0.02048s Training Accuracy: 97.75% Test Accuracy: 71.88%
[ 47/ 50] MNIST Time 0.02093s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 47/ 50] FashionMNIST Time 0.02595s Training Accuracy: 97.66% Test Accuracy: 75.00%
[ 48/ 50] MNIST Time 0.02109s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 48/ 50] FashionMNIST Time 0.02037s Training Accuracy: 96.97% Test Accuracy: 68.75%
[ 49/ 50] MNIST Time 0.02034s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 49/ 50] FashionMNIST Time 0.02065s Training Accuracy: 97.36% Test Accuracy: 75.00%
[ 50/ 50] MNIST Time 0.02088s Training Accuracy: 100.00% Test Accuracy: 84.38%
[ 50/ 50] FashionMNIST Time 0.02084s Training Accuracy: 97.75% Test Accuracy: 68.75%

[FINAL] MNIST Training Accuracy: 100.00% Test Accuracy: 84.38%
[FINAL] FashionMNIST Training Accuracy: 97.75% Test Accuracy: 68.75%
Training a PINN on 2D PDE
In this tutorial we will go over using a PINN to solve 2D PDEs. We will be using the system from the NeuralPDE Tutorials. However, we will be using a custom loss function and the nested AD capabilities of Lux.jl.
This is a demonstration of Lux.jl. For serious use cases of PINNs, please refer to the package NeuralPDE.jl.
Since Lux supports efficient nested AD up to 2nd order, we will rewrite the problem with first-order derivatives so that we can compute the gradients of the loss using 2nd-order AD.
All the networks take 3 input variables and output a scalar value. Here, we will define a wrapper over the 3 networks, so that we can train them using Training.TrainState. A sketch of such a wrapper is given below.
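A sketch of the wrapper (the field names u, v, w and the MLP architecture are assumptions):
julia
struct PINN{U, V, W} <: Lux.AbstractLuxContainerLayer{(:u, :v, :w)}
    u::U
    v::V
    w::W
end

# Each sub-network maps (x, y, t) to a scalar value
function create_mlp(act, hidden_dims)
    return Chain(Dense(3 => hidden_dims, act), Dense(hidden_dims => hidden_dims, act),
        Dense(hidden_dims => hidden_dims, act), Dense(hidden_dims => 1))
end

PINN(; hidden_dims::Int=32) = PINN(create_mlp(tanh, hidden_dims),
    create_mlp(tanh, hidden_dims), create_mlp(tanh, hidden_dims))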
We will generate some random data to train the model on. We will take data on a square spatial and temporal domain x ∈ [0, 2], y ∈ [0, 2], and t ∈ [0, 2]. Typically, you want to be smarter about the sampling process, but for the sake of simplicity, we will skip that.
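A sketch of the sampling (the array layout and sample count are assumptions; each column is one (x, y, t) point drawn uniformly from the cube):
julia
rng = Random.default_rng()
xyt = 2.0f0 .* rand(rng, Float32, 3, 8000)  # 3 × N matrix of (x, y, t) samples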
`,12))])}const B=p(k,[["render",o]]);export{m as __pageData,B as default};
diff --git a/previews/PR1023/assets/tutorials_intermediate_4_PINN2DPDE.md.lubVAoSP.lean.js b/previews/PR1023/assets/tutorials_intermediate_4_PINN2DPDE.md.lubVAoSP.lean.js
new file mode 100644
index 0000000000..c6cdda4e4d
--- /dev/null
+++ b/previews/PR1023/assets/tutorials_intermediate_4_PINN2DPDE.md.lubVAoSP.lean.js
@@ -0,0 +1,343 @@
+import{_ as p,c as n,a2 as a,j as s,a as t,o as l}from"./chunks/framework.DFwXuivk.js";const h="/previews/PR1023/assets/pinn_nested_ad.B__JnolW.gif",m=JSON.parse('{"title":"Training a PINN on 2D PDE","description":"","frontmatter":{},"headers":[],"relativePath":"tutorials/intermediate/4_PINN2DPDE.md","filePath":"tutorials/intermediate/4_PINN2DPDE.md","lastUpdated":null}'),k={name:"tutorials/intermediate/4_PINN2DPDE.md"},e={class:"MathJax",jax:"SVG",style:{direction:"ltr",position:"relative"}},r={style:{overflow:"visible","min-height":"1px","min-width":"1px","vertical-align":"-0.566ex"},xmlns:"http://www.w3.org/2000/svg",width:"8.586ex",height:"2.262ex",role:"img",focusable:"false",viewBox:"0 -750 3795.2 1000","aria-hidden":"true"},E={class:"MathJax",jax:"SVG",style:{direction:"ltr",position:"relative"}},d={style:{overflow:"visible","min-height":"1px","min-width":"1px","vertical-align":"-0.566ex"},xmlns:"http://www.w3.org/2000/svg",width:"8.401ex",height:"2.262ex",role:"img",focusable:"false",viewBox:"0 -750 3713.2 1000","aria-hidden":"true"},g={class:"MathJax",jax:"SVG",style:{direction:"ltr",position:"relative"}},y={style:{overflow:"visible","min-height":"1px","min-width":"1px","vertical-align":"-0.566ex"},xmlns:"http://www.w3.org/2000/svg",width:"8.109ex",height:"2.262ex",role:"img",focusable:"false",viewBox:"0 -750 3584.2 1000","aria-hidden":"true"};function o(c,i,F,C,D,L){return l(),n("div",null,[i[10]||(i[10]=a(`
In this tutorial we will go over using a PINN to solve 2D PDEs. We will be using the system from the NeuralPDE Tutorials. However, we will be using our custom loss function and the nested AD capabilities of Lux.jl.
This is a demonstration of Lux.jl. For serious use cases of PINNs, please refer to the package: NeuralPDE.jl.
Since Lux supports efficient nested AD up to 2nd order, we will rewrite the problem with first-order derivatives, so that we can compute the gradients of the loss using 2nd order AD.
All the networks take 3 input variables and output a scalar value. Here, we will define a wrapper over the 3 networks, so that we can train them using Training.TrainState.
We will generate some random data to train the model on. We will take data on a square spatial and temporal domain x ∈ [0, 2], y ∈ [0, 2], and t ∈ [0, 2]. Typically, you want to be smarter about the sampling process, but for the sake of simplicity, we will skip that.
Universal neural differential equations with O(1) backprop, GPUs, and stiff+non-stiff DE solvers, demonstrating scientific machine learning (SciML) and physics-informed machine learning methods
SciMLSensitivity.jl
A component of the DiffEq ecosystem for enabling sensitivity analysis for scientific machine learning (SciML). Optimize-then-discretize, discretize-then-optimize, adjoint methods, and more for ODEs, SDEs, DDEs, DAEs, etc.
NeuralPDE.jl
Physics-Informed Neural Networks (PINN) and Deep BSDE Solvers of Differential Equations for Scientific Machine Learning (SciML) accelerated simulation
NeuralLyapunov.jl
A library for searching for neural Lyapunov functions in Julia
DeepEquilibriumNetworks.jl
Implicit Layer Machine Learning via Deep Equilibrium Networks, O(1) backpropagation with accelerated convergence
AbstractCosmologicalEmulators.jl
Repository containing the abstract interface to the emulators used in the CosmologicalEmulators organization
ContinuousNormalizingFlows.jl
Implementations of Infinitesimal Continuous Normalizing Flows Algorithms in Julia
Sophon.jl
Efficient, Accurate, and Streamlined Training of Physics-Informed Neural Networks
DataDrivenDiffEq.jl
Data driven modeling and automated discovery of dynamical systems for the SciML Scientific Machine Learning organization
NeuralGraphPDE.jl
Integrating Neural Ordinary Differential Equations, the Method of Lines, and Graph Neural Networks
Solaris.jl
Lightweight module for fusing physical and neural models
Boltz.jl
Accelerate your ML research using pre-built Deep Learning Models with Lux
GeometricMachineLearning.jl
Structure Preserving Machine Learning Models in Julia
It's easy to install Lux.jl. Since Lux.jl is registered in the Julia General registry, you can simply run the following command in the Julia REPL:
julia
julia> using Pkg
+julia> Pkg.add("Lux")
If you want to use the latest unreleased version of Lux.jl, you can run the following command: (in most cases the released version will be the same as the version on GitHub)
julia
julia> using Pkg
+julia> Pkg.add(url="https://github.com/LuxDL/Lux.jl")
If you found this library to be useful in academic work, then please cite:
bibtex
@software{pal2023lux,
+ author = {Pal, Avik},
+ title = {{Lux: Explicit Parameterization of Deep Neural Networks in Julia}},
+ month = {April},
+ year = 2023,
+ note = {If you use this software, please cite it as below.},
+ publisher = {Zenodo},
+ version = {v0.5.0},
+ doi = {10.5281/zenodo.7808904},
+ url = {https://doi.org/10.5281/zenodo.7808904}
+}
bibtex
@thesis{pal2023efficient,
+ title = {{On Efficient Training \& Inference of Neural Differential Equations}},
+ author = {Pal, Avik},
+ year = {2023},
+ school = {Massachusetts Institute of Technology}
+}
\ No newline at end of file
diff --git a/previews/PR1023/introduction/index.html b/previews/PR1023/introduction/index.html
new file mode 100644
index 0000000000..0edc04bca5
--- /dev/null
+++ b/previews/PR1023/introduction/index.html
@@ -0,0 +1,144 @@
+ Getting Started | Lux.jl Docs
Install Julia v1.10 or above. Lux.jl is available through the Julia package manager. You can enter it by pressing ] in the REPL and then typing add Lux. Alternatively, you can also do
julia
import Pkg
+Pkg.add("Lux")
Update to v1
If you are using a pre-v1 version of Lux.jl, please see the Updating to v1 section for instructions on how to update.
Models don't hold parameters and states, so those need to be initialized separately. From there on, we can just use our standard AD and Optimisers API. However, here we will show how to use Lux's Training API that provides a uniform API over all supported AD systems.
julia
# Get the device determined by Lux
+dev = gpu_device()
+
+# Parameter and State Variables
+ps, st = Lux.setup(rng, model) |> dev
+
+# Dummy Input
+x = rand(rng, Float32, 128, 2) |> dev
+
+# Run the model
+y, st = Lux.apply(model, x, ps, st)
+
+# Gradients
+## First construct a TrainState
+train_state = Lux.Training.TrainState(model, ps, st, Adam(0.0001f0))
+
+## We can compute the gradients using Training.compute_gradients
+gs, loss, stats, train_state = Lux.Training.compute_gradients(AutoZygote(), MSELoss(),
+ (x, dev(rand(rng, Float32, 10, 2))), train_state)
+
+## Optimization
+train_state = Training.apply_gradients!(train_state, gs) # or Training.apply_gradients (no `!` at the end)
+
+# Both these steps can be combined into a single call
+gs, loss, stats, train_state = Training.single_train_step!(AutoZygote(), MSELoss(),
+ (x, dev(rand(rng, Float32, 10, 2))), train_state)
using Lux, Random, Optimisers, Zygote
+using LuxCUDA # For CUDA support
+# using AMDGPU, Metal, oneAPI # Other optional packages for GPU support
+using Printf # For pretty printing
+
+dev = gpu_device()
(::CUDADevice{Nothing}) (generic function with 4 methods)
We will define a custom MLP using the @compact macro. The macro takes in a list of parameters, layers and states, and a function defining the forward pass of the neural network.
julia
n_in = 1
+n_out = 1
+nlayers = 3
+
+model = @compact(w1=Dense(n_in => 32),
+ w2=[Dense(32 => 32) for i in 1:nlayers],
+ w3=Dense(32 => n_out),
+ act=relu) do x
+ embed = act(w1(x))
+ for w in w2
+ embed = act(w(embed))
+ end
+ out = w3(embed)
+ @return out
+end
LuxDL hosts various packages that provide additional functionality for Lux.jl. All packages mentioned in this documentation are available via the Julia General Registry.
You can install all those packages via import Pkg; Pkg.add("<package name>").
Julia already has quite a few well-established Neural Network Frameworks – Flux & KNet. However, certain design elements – Coupled Model and Parameters & Internal Mutations – associated with these frameworks make them less compiler and user friendly. Making changes to address these problems in the respective frameworks would be too disruptive for users. This is where Lux comes in: a neural network framework built completely using pure functions to make it both compiler and autodiff friendly.
Layers must be immutable – cannot store any parameter/state but rather store the information to construct them
Layers are pure functions
Layers return a Tuple containing the result and the updated state
Given the same inputs, the outputs must be the same – yes, this must hold true even for stochastic functions. Randomness must be controlled using RNGs passed in the state.
Easily extensible
Extensive Testing – All layers and features are tested across all supported AD backends across all supported hardware backends.
Neural Networks for SciML: For SciML Applications (Neural ODEs, Deep Equilibrium Models) solvers typically expect a monolithic parameter vector. Flux enables this via its destructure mechanism, but destructure comes with various edge cases and limitations. Lux forces users to make an explicit distinction between state variables and parameter variables to avoid these issues. Also, it comes batteries-included for distributed training.
Sensible display of Custom Layers – Ever wanted to see PyTorch-like network printouts or wondered how to extend the pretty printing of Flux's layers? Lux handles all of that by default.
Truly immutable models - No unexpected internal mutations since all layers are implemented as pure functions. All layers are also deterministic given the parameters and state: if a layer is supposed to be stochastic (say Dropout), the state must contain a seed which is then updated after the function call.
Easy Parameter Manipulation – By separating parameter data and layer structures, Lux makes implementing WeightNorm, SpectralNorm, etc. downright trivial. Without this separation, it is much harder to pass such parameters around without mutations which AD systems don't like.
Wider AD Support – Lux has extensive support for most AD systems in Julia, while Flux is mostly tied to Zygote (with some initial support for Enzyme).
Small Neural Networks on CPU – Lux is developed for training large neural networks. For smaller architectures, we recommend using SimpleChains.jl or even better use it in conjunction with Lux via ToSimpleChainsAdaptor.
Reliability – We have learned from the mistakes of the past with Flux and everything in our core framework is extensively tested, along with downstream CI to ensure that everything works as expected.
Revising Previous Recommendation about Large Models
Previously we recommended not using Lux for very large models. But we have been making a lot of headway with Reactant.jl and it would be worthwhile to test larger models with Lux. See compiling Lux models for more information.
\ No newline at end of file
diff --git a/previews/PR1023/introduction/resources.html b/previews/PR1023/introduction/resources.html
new file mode 100644
index 0000000000..8bbfd66fd1
--- /dev/null
+++ b/previews/PR1023/introduction/resources.html
@@ -0,0 +1,34 @@
+ Resources to Get Started | Lux.jl Docs
Go through the examples sorted based on their complexity in the documentation.
Have More Questions?
For usage related questions, please use GitHub Discussions, which allows questions and answers to be indexed. To report bugs use GitHub Issues or, even better, send in a Pull Request.
\ No newline at end of file
diff --git a/previews/PR1023/introduction/updating_to_v1.html b/previews/PR1023/introduction/updating_to_v1.html
new file mode 100644
index 0000000000..a6ef228763
--- /dev/null
+++ b/previews/PR1023/introduction/updating_to_v1.html
@@ -0,0 +1,34 @@
+ Updating to Lux v1 | Lux.jl Docs
Lux v1 is a Major Release, mostly to signify the stability of the API. In this page, we list out a concrete set of changes that need to be made to your code to update to Lux v1. We also list out some new exciting features that were added as part of this release.
AbstractExplicitLayer has been renamed to AbstractLuxLayer.
AbstractExplicitContainerLayer behaviour
This has been renamed to AbstractLuxContainerLayer.
Previously, AbstractExplicitContainerLayer{(:a,)} (i.e. singleton containers) would produce default initial parameters and states without wrapping them in a NamedTuple{(:a,)}. This was inconsistent with non-singleton containers, and was a source of confusion. With v1 we return (; a = <parameters>) and (; a = <states>) by default. See AbstractLuxWrapperLayer for a replacement of this functionality.
inputsize has been removed since it was ambiguous and not used anywhere.
Changes to outputsize:
Single argument version has been removed. See LuxCore.jl Pull Request 43 for more details on the rationale behind this change.
Fallback implementation has been moved to Lux.jl. (i.e. users using Lux shouldn't see a difference, but if Lux.jl isn't loaded, this function will throw an error.)
Internally this uses a NilArray that is able to compute sizes without actually running the computation.
Functors and Setfield have been made into optional dependencies. Certain LuxCore functionality that relies on these packages will throw an error if they are not loaded.
Introduction of AbstractLuxWrapperLayer. This behaves exactly like the old singleton container. For example, the old AbstractExplicitContainerLayer{(:a,)} is equivalent to AbstractLuxWrapperLayer{:a}.
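For example, a rough sketch of porting a singleton container to the new supertype (the MyWrapper name here is purely illustrative):
julia
using LuxCore
+
+# Old (pre-v1) style: struct MyWrapper{L} <: AbstractExplicitContainerLayer{(:layer,)}
+# New (v1) style: the parameter/state structure is exactly that of the wrapped layer
+struct MyWrapper{L} <: LuxCore.AbstractLuxWrapperLayer{:layer}
+    layer::L
+end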
This was a major release to signify the stability of the API. There were no breaking changes. We do support a wider range of RNG types, see Supported RNG Types for more details.
This is the most aggressive change that was made. We renamed the LuxDeviceUtils.jl package to MLDataDevices.jl, to allow for non-Lux packages to use this shared device management abstraction.
Deprecation of LuxDeviceUtils.jl
This also marks the deprecation of the LuxDeviceUtils.jl package. We won't be making any updates to that package, including fixing any bugs. All users should switch to MLDataDevices.jl instead.
DeviceIterator provides a generalization of CUDA.CuIterator and works for all backends and more data types (using Functors.jl). MLUtils.DataLoader |> gdev now returns a DeviceIterator instead of being a no-op.
Direct reexport of NNlib has been removed. We reexport selected functionality from NNlib. Directly load NNlib if you need to use the other functions.
Flattening of Chain layers has been removed, and the corresponding disable_optimizations kwarg has been removed.
Some layers overloaded Base.keys, these have been removed. These were mostly un-documented and weren't supposed to be used outside of the Lux.jl package.
disable_stacktrace_truncation! has been removed. From Julia 1.9 onwards, stacktrace truncation is enabled by default.
Certain Experimental features were present outside the Lux.Experimental module. These have been removed; use them via Lux.Experimental instead. Run Julia with depwarn set to error on Lux v0.5 to see the deprecations.
Lux.Experimental.@layer_map is no longer needed and has been removed. The name of the variable prevents writing generic functions and is no longer pre-pended to the KeyPath. See the docstring of Lux.Experimental.layer_map for more details.
allow_fast_activation kwarg has been removed completely. Pass an anonymous function as the activation to prevent internal modifications to the activation function.
Conv and ConvTranspose use an initialization based on the activation function, taken from Pytorch. Pytorch assumes the activation function is leakyrelu to compute the gain; however, we compute the gain based on the activation function passed into the layer.
Upsample now has an align_corners keyword argument, which defaults to false. Previously this was always true.
Dense and Bilinear have updated default initializations to align with the defaults from Pytorch. See the documentation for more details.
InstanceNorm now defaults to affine=false instead of affine=true.
Embedding now defaults to init_weight=rand32 instead of init_weight=randn32.
Recurrent Cells - RNNCell, LSTMCell, and GRUCell now have different default initializations. See the documentation for more details.
Lux is not an AD package, but it composes well with most of the AD packages available in the Julia ecosystem. This document lists the current level of support for various AD packages in Lux. Additionally, we provide some convenience functions for working with AD.
Use Zygote.jl for the best performance. This is the most reliable and fastest option for CPU for the time-being. (We are working on faster Enzyme support for CPU)
Use Enzyme.jl, if there are mutations in the code and/or Zygote.jl fails.
Use Zygote.jl for the best performance. This is the most reliable and fastest option for GPU for the time-being. We are working on supporting Enzyme.jl for GPU as well.
Tier I: These packages are fully supported and have been tested extensively. Often have special rules to enhance performance. Issues for these backends take the highest priority.
Tier II: These packages are supported and extensively tested but often don't have the best performance. Issues against these backends are less critical, but we fix them when possible. (Some specific edge cases, especially with AMDGPU, are known to fail here)
Tier III: We don't know if these packages currently work with Lux. We'd love to add tests for these backends, but currently these are not our priority.
Note that ChainRules.jl is not really an AD package, but we have first-class support for packages that use rrules.
This feature is supported downstream, but we don't extensively test it to ensure that it works with Lux.
Currently Enzyme outperforms other AD packages in terms of CPU performance. However, there are some edge cases where it might not work with Lux. We are working on improving the compatibility. Please report any issues you encounter.
\ No newline at end of file
diff --git a/previews/PR1023/manual/compiling_lux_models.html b/previews/PR1023/manual/compiling_lux_models.html
new file mode 100644
index 0000000000..c9ca6ac424
--- /dev/null
+++ b/previews/PR1023/manual/compiling_lux_models.html
@@ -0,0 +1,108 @@
+ Compiling Lux Models using Reactant.jl | Lux.jl Docs
Reactant takes a Julia function and compiles it into MLIR, runs fancy optimizations on top of it (including using EnzymeMLIR for automatic differentiation), and creates relevant executables for CPU/GPU/TPU via XLA. It presently operates as a tracing system: compiled functions will assume the same control flow pattern as was originally taken by the objects used at compile time, and control flow (e.g. if, for) as well as any type instabilities will be removed. The benefit of this approach is that it immediately makes all such code available for advanced optimization with little developer effort.
Experimental
Reactant compilation is a very new feature and is currently experimental. Certain models might not be compilable yet, but we are actively working on it. Open an issue if you encounter any problems.
julia
using Lux, Reactant, Enzyme, Random, Zygote
+using Functors, Optimisers, Printf
To run it using XLA we need to compile the model. We can do this using the Reactant.@compile macro. Note that the inputs need to be moved to the device using xla_device first.
Now that we have seen the low-level API, let's see how to train the model without any of this boilerplate. Simply follow these steps:
Create a device using xla_device. Remember to load Reactant.jl before doing this.
Similar to other device functions move the model, parameters, states and data to the device. Note that you might want to use DeviceIterator to move the data loader to the device with an iterator.
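A minimal sketch of these steps, assuming a trivial Dense model (the exact compile invocation may differ slightly across Reactant versions):
julia
using Lux, Reactant, Random
+
+xdev = xla_device()
+
+model = Dense(2 => 4, tanh)
+ps, st = Lux.setup(Random.default_rng(), model) |> xdev
+x = randn(Float32, 2, 8) |> xdev
+
+# Compile once, then call the compiled function with the same argument types
+model_compiled = @compile model(x, ps, Lux.testmode(st))
+y, _ = model_compiled(x, ps, Lux.testmode(st))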
Debugging DNNs can be very painful. Especially with the gigantic stacktraces in Lux, it is even harder to pin-point which particular layer errored out. This page describes some useful tools that ship with Lux that can help you debug your models.
TL;DR
Simply wrap your model with Lux.Experimental.@debug_mode!!
Don't Forget
Remember to use the non Debug mode model after you finish debugging. Debug mode models are way slower.
Let us construct a model which has an obviously incorrect dimension. In this example, you will see how easy it is to pin-point the problematic layer.
Incorrect Model Specification: Dimension Mismatch Problems
julia
using Lux, Random
+
+model = Chain(Dense(1 => 16, relu), Chain(Dense(16 => 3), Dense(1 => 1)), BatchNorm(1))
+
+model_debug = Lux.Experimental.@debug_mode model
Note that we can use the parameters and states for model itself in model_debug, no need to make any changes. If you ran the original model this is the kind of error you would see:
Have you encountered those pesky little NaNs in your training? They are very hard to track down. We will artificially simulate NaNs in our model and see how we can track down the offending layer.
We can set nan_check to :forward, :backward or :both to check for NaNs in the debug model. (or even disable it by setting it to :none)
julia
model = Chain(Dense(1 => 16, relu), Chain(Dense(16 => 1), Dense(1 => 1)),
+ BatchNorm(1))
+
+ps, st = Lux.setup(rng, model)
+
+model_debug = Lux.Experimental.@debug_mode model nan_check=:both
And we have figured it out! The first NaN occurred in the parameters of model.layers.layer_2.layers.layer_2! But what if the NaN occurs in the reverse pass? Let us define a custom layer and introduce a fake NaN in the backward pass.
julia
using ChainRulesCore, Zygote
+
+const CRC = ChainRulesCore
+
+offending_layer(x) = 2 .* x
Let us define a custom backward pass to introduce some NaNs:
julia
function CRC.rrule(::typeof(offending_layer), x)
+ y = offending_layer(x)
+ function ∇offending_layer(Δ)
+ Δ[1] = NaN
+ return NoTangent(), Δ
+ end
+ return y, ∇offending_layer
+end
And there you go: our debug layer prints that the problem is in WrappedFunction(offending_layer) at location model.layers.layer_2.layers.layer_2! Once we fix the pullback of the layer, we will fix the NaNs.
In this manual section, we have discussed tracking down errors in Lux models. We have covered tracking incorrect model specifications and NaNs in forward and backward passes. However, remember that this is an Experimental feature, and there might be edge cases that don't work correctly. If you find any such cases, please open an issue on GitHub!
\ No newline at end of file
diff --git a/previews/PR1023/manual/dispatch_custom_input.html b/previews/PR1023/manual/dispatch_custom_input.html
new file mode 100644
index 0000000000..e6c974aa5b
--- /dev/null
+++ b/previews/PR1023/manual/dispatch_custom_input.html
@@ -0,0 +1,94 @@
+ Dispatching on Custom Input Types | Lux.jl Docs
Defining a dispatch on (::Layer)(x::MyInputType, ps, st::NamedTuple) is inconvenient, since it requires the user to define a new method for every layer type.
Consider Neural ODEs. In these models, we often want every iteration of the neural network to take the current time as input. Here, we won't go through implementing an entire Neural ODE model. Instead we will define a time dependent version of Chain.
DDP Training using Lux.DistributedUtils is a spiritual successor to FluxMPI.jl, but has some key differences.
Guide to Integrating DistributedUtils into your code
Initialize the respective backend with DistributedUtils.initialize, by passing in a backend type. It is important that you pass in the type, i.e. NCCLBackend and not the object NCCLBackend().
It is important that you use this function instead of directly constructing the backend, since there are certain internal states that need to be synchronized.
Next synchronize the parameters and states of the model. This is done by calling DistributedUtils.synchronize!! with the backend and the respective input.
To split the data uniformly across the processes use DistributedUtils.DistributedDataContainer. Alternatively, one can manually split the data. For the provided container to work MLUtils.jl must be installed and loaded.
julia
data = DistributedUtils.DistributedDataContainer(backend, data)
Wrap the optimizer in DistributedUtils.DistributedOptimizer to ensure that the optimizer is correctly synchronized across all processes before parameter updates. After initializing the state of the optimizer, synchronize the state across all processes.
Finally change all logging and serialization code to trigger on local_rank(backend) == 0. This ensures that only the master process logs and serializes the model.
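Putting these steps together, a minimal sketch, assuming the NCCL backend and a simple Dense model (the data pipeline and training loop are elided):
julia
using Lux, MPI, NCCL, Optimisers, Random
+
+DistributedUtils.initialize(NCCLBackend)
+backend = DistributedUtils.get_distributed_backend(NCCLBackend)
+
+model = Dense(4 => 2)
+ps, st = Lux.setup(Random.default_rng(), model)
+
+# Synchronize parameters and states across all processes
+ps = DistributedUtils.synchronize!!(backend, ps)
+st = DistributedUtils.synchronize!!(backend, st)
+
+# Wrap the optimizer and synchronize its state
+opt = DistributedUtils.DistributedOptimizer(backend, Adam(0.001f0))
+opt_state = Optimisers.setup(opt, ps)
+opt_state = DistributedUtils.synchronize!!(backend, opt_state)
+
+# Only the master process logs and serializes
+should_log = DistributedUtils.local_rank(backend) == 0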
We don't automatically determine if the MPI Implementation is CUDA or ROCM aware. See GPU-aware MPI for more information.
Older (now non-existent) Lux.gpu implementations used to "just work" with FluxMPI.jl. We expect gpu_device to continue working as expected; however, we recommend using gpu_device after calling DistributedUtils.initialize to avoid any mismatch between the device set via DistributedUtils and the device stored in CUDADevice or AMDGPUDevice.
Currently we don't run tests with CUDA or ROCM aware MPI, use those features at your own risk. We are working on adding tests for these features.
AMDGPU support is mostly experimental and causes deadlocks in certain situations, this is being investigated. If you have a minimal reproducer for this, please open an issue.
\ No newline at end of file
diff --git a/previews/PR1023/manual/freezing_model_parameters.html b/previews/PR1023/manual/freezing_model_parameters.html
new file mode 100644
index 0000000000..34ae92985c
--- /dev/null
+++ b/previews/PR1023/manual/freezing_model_parameters.html
@@ -0,0 +1,84 @@
+ Freezing Model Parameters | Lux.jl Docs
To freeze a particular kind of layer, let's say Dense in the following example. We can use Lux.Experimental.layer_map and freeze layers if they are of type Dense.
When the function in layer_map is called, the 4th argument is the name of the layer. For example, the name of the 1st layer inside the inner Chain would be layer_2.layer_1.
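A rough sketch of this pattern, assuming the 4-argument Lux.Experimental.freeze form (the model here is illustrative):
julia
using Lux, Random
+
+model = Chain(Dense(3 => 4, relu), Chain(Dense(4 => 4, relu), Dense(4 => 1)))
+ps, st = Lux.setup(Random.default_rng(), model)
+
+# Freeze every `Dense` layer; leave everything else untouched
+freeze_dense(d::Dense, ps, st, name) = Lux.Experimental.freeze(d, ps, st, (:weight, :bias))
+freeze_dense(l, ps, st, name) = (l, ps, st)
+
+model_frozen, ps_frozen, st_frozen = Lux.Experimental.layer_map(freeze_dense, model, ps, st)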
Starting from v0.5, Lux has transitioned to a new GPU management system. The old system using cpu and gpu functions is still in place but will be removed in v1. Using the old functions might lead to performance regressions if used inside performance critical code.
Lux.jl can handle multiple GPU backends. Currently, the following backends are supported:
julia
# Important to load trigger packages
+using Lux, LuxCUDA #, AMDGPU, Metal, oneAPI
+
+supported_gpu_backends()
("CUDA", "AMDGPU", "Metal", "oneAPI")
Metal Support
Support for Metal GPUs should be considered extremely experimental at this point.
Automatic Backend Management is done by two simple functions: cpu_device and gpu_device.
cpu_device: This is a simple function and just returns a CPUDevice object. For example: cdev = cpu_device() and then x_cpu = randn(Float32, 3, 2).
gpu_device: This function performs automatic GPU device selection and returns an object.
If no GPU is available, it returns a CPUDevice object.
If a LocalPreferences file is present, then the backend specified in the file is used. To set a backend, use Lux.gpu_backend!(<backend_name>). (a) If the trigger package corresponding to the device is not loaded, then a warning is displayed. (b) If no LocalPreferences file is present, then the first working GPU with loaded trigger package is used.
If you just want to define compatibility with Lux without actually using any of the other functionality provided by Lux (like layers), it is recommended to depend on LuxCore.jl instead of Lux.jl. LuxCore.jl is a significantly lighter dependency.
Following this interface provides the ability for frameworks built on top of Lux to be cross compatible. Additionally, any new functionality built into Lux will just work for your framework.
@compact macro
While writing out a custom struct and defining dispatches manually is a good way to understand the interface, it is not the most concise way. We recommend using the Lux.@compact macro to define layers which makes handling the states and parameters downright trivial.
If the layer doesn't contain any other Lux layer, then it is a Singular Layer. This means it should optionally subtype Lux.AbstractLuxLayer but mandatorily define all the necessary functions mentioned in the docstrings. Consider a simplified version of Dense called Linear.
First, set up the architectural details for this layer. Note that the architecture doesn't contain any mutable structures like arrays. When in doubt, remember: once constructed, a model architecture cannot change.
Tip
For people coming from Flux.jl background, this might be weird. We recommend checking out the Flux to Lux migration guide first before proceeding.
Next, we need to implement functions which return the parameters and states for the layer. In case of Linear, the parameters are weight and bias while the states are empty. States become important when defining layers like BatchNorm, WeightNorm, etc. The recommended data structure for returning parameters is a NamedTuple, though anything satisfying the Parameter Interface is valid.
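For instance, a sketch of these functions for the Linear layer, assuming the struct stores in_dims, out_dims, and an init_weight function:
julia
using LuxCore, Random
+
+function LuxCore.initialparameters(rng::AbstractRNG, l::Linear)
+    return (weight=l.init_weight(rng, l.out_dims, l.in_dims),
+        bias=zeros(Float32, l.out_dims))
+end
+
+# `Linear` has no state
+LuxCore.initialstates(::AbstractRNG, ::Linear) = NamedTuple()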
You could also implement LuxCore.parameterlength and LuxCore.statelength to prevent wasteful reconstruction of the parameters and states.
julia
# This works
+println("Parameter Length: ", LuxCore.parameterlength(l), "; State Length: ",
+ LuxCore.statelength(l))
+
+# But still recommended to define these
+LuxCore.parameterlength(l::Linear) = l.out_dims * l.in_dims + l.out_dims
+
+LuxCore.statelength(::Linear) = 0
Parameter Length: 12; State Length: 0
No RNG in initialparameters and initialstates
You might notice that we don't pass in a RNG for these functions. If your parameter length and/or state length depend on a random number generator, you should think really hard about what you are trying to do and why.
Now, we need to define how the layer works. For this you make your layer a function with exactly 3 arguments – x the input, ps the parameters, and st the states. This function must return two things – y the output, and st_new the updated state.
julia
function (l::Linear)(x::AbstractMatrix, ps, st::NamedTuple)
+ y = ps.weight * x .+ ps.bias
+ return y, st
+end
Finally, let's run this layer. If you have made it this far into the documentation, we don't feel you need a refresher on that.
If your layer comprises other Lux layers, then it is a Container Layer. Note that you could treat it as a Singular Layer, and that is still fine. FWIW, if you cannot subtype your layer with LuxCore.AbstractLuxContainerLayer then you should go down the Singular Layer route. But subtyping allows us to bypass some of these common definitions. Let us now define a layer which is basically a composition of two linear layers.
Wrapper Layer
If you are defining a layer that is a wrapper around another layer, then you should subtype LuxCore.AbstractLuxWrapperLayer instead of LuxCore.AbstractLuxContainerLayer. The only difference from a container layer is that it can wrap a single layer and the parameter/state structure is exactly the same as the wrapped layer.
julia
struct ComposedLinear{L1, L2} <: LuxCore.AbstractLuxContainerLayer{(:linear_1, :linear_2)}
+ linear_1::L1
+ linear_2::L2
+end
+
+function (cl::ComposedLinear)(x::AbstractMatrix, ps, st::NamedTuple)
+ # To access the parameters and states for `linear_1` we do `ps.linear_1` and
+ # `st.linear_1`. Similarly for `linear_2`
+ y, st_l1 = cl.linear_1(x, ps.linear_1, st.linear_1)
+ y, st_l2 = cl.linear_2(y, ps.linear_2, st.linear_2)
+ # Finally, we need to return the new state which has the exact structure as `st`
+ return y, (linear_1 = st_l1, linear_2 = st_l2)
+end
Here, you will notice we have passed (:linear_1, :linear_2) to the supertype. It essentially informs the type that <obj>.linear_1 and <obj>.linear_2 are Lux layers and we need to construct parameters and states for those. Let's construct these and see:
We accept any parameter type as long as we can fetch the parameters using getproperty(obj, :parameter_name). This allows us to simultaneously support NamedTuples and ComponentArrays. Let us go through a concrete example of what it means. Consider Dense which expects two parameters named weight and bias.
Automatic Differentiation
If you are defining your own parameter type, it is your responsibility to make sure that it works with the AutoDiff System you are using.
julia
using Lux, Random
+
+d = Dense(2, 3)
+rng = Random.default_rng()
+Random.seed!(rng, 0)
+
+ps_default, st = LuxCore.setup(rng, d)
+
+x = randn(rng, Float32, 2, 1)
+
+println("Result with `NamedTuple` parameters: ", first(d(x, ps_default, st)))
Result with `NamedTuple` parameters: Float32[-0.08713347; -0.4851346; -0.8490221;;]
Let us define a custom parameter type with fields myweight and mybias such that if we try to access weight we get back myweight, and similarly for bias.
Beware!
This is for demonstrative purposes, don't try this at home!
Result with `DenseLayerParameters` parameters: Float32[0.23710957; 0.1003911; -0.57671577;;]
The takeaway from this shouldn't be – let's define weird parameter types. Simply because you can do weird things like this doesn't mean you should, since it only leads to bugs.
Instead this shows the flexibility you have for how your parameters can be structured.
States are always type constrained to be NamedTuple. The structure of the input state must match that of the output state, i.e. keys(st_in) == keys(st_out). This doesn't imply that the types of the input and output state match. To generate efficient code, we often dispatch on the state, for example, Dropout, BatchNorm, etc.
\ No newline at end of file
diff --git a/previews/PR1023/manual/migrate_from_flux.html b/previews/PR1023/manual/migrate_from_flux.html
new file mode 100644
index 0000000000..ccd9652ed8
--- /dev/null
+++ b/previews/PR1023/manual/migrate_from_flux.html
@@ -0,0 +1,108 @@
+ Migrating from Flux to Lux | Lux.jl Docs
For the core library layers like Dense, Conv, etc. we have intentionally kept the API very similar to Flux. In most cases, replacing using Flux with using Lux should be enough to get you started. We cover the additional changes that you will have to make in the following example.
Flux and Lux operate under extremely different design philosophies regarding how layers should be implemented. A summary of the differences would be:
Flux stores everything in a single struct and relies on Functors.@functor and Flux.trainable to distinguish between trainable and non-trainable parameters.
Lux relies on the user to define Lux.initialparameters and Lux.initialstates to distinguish between trainable parameters (called "parameters") and non-trainable parameters (called "states"). Additionally, Lux layers define the model architecture, hence device transfer utilities like gpu_device, cpu_device, etc. cannot be applied on Lux layers, instead they need to be applied on the parameters and states.
Let's work through a concrete example to demonstrate this. We will implement a very simple layer that computes A × B × x where A is not trainable and B is trainable.
julia
using Lux, Random, NNlib, Zygote
+
+struct LuxLinear <: Lux.AbstractLuxLayer
+ init_A
+ init_B
+end
+
+function LuxLinear(A::AbstractArray, B::AbstractArray)
+ # Storing Arrays or any mutable structure inside a Lux Layer is not recommended
+ # instead we will convert this to a function to perform lazy initialization
+ return LuxLinear(() -> copy(A), () -> copy(B))
+end
+
+# `B` is a parameter
+Lux.initialparameters(::AbstractRNG, layer::LuxLinear) = (B=layer.init_B(),)
+
+# `A` is a state
+Lux.initialstates(::AbstractRNG, layer::LuxLinear) = (A=layer.init_A(),)
+
+(l::LuxLinear)(x, ps, st) = st.A * ps.B * x, st
julia
using Flux, Random, NNlib, Zygote, Optimisers
+
+struct FluxLinear
+ A
+ B
+end
+
+
+# `A` is not trainable
+Optimisers.trainable(f::FluxLinear) = (B=f.B,)
+
+# Needed so that both `A` and `B` can be transferred between devices
+Flux.@functor FluxLinear
+
+(l::FluxLinear)(x) = l.A * l.B * x
Flux supports a mode called :auto which automatically decides if the user is training the model or running inference. This is the default mode for Flux.BatchNorm, Flux.GroupNorm, Flux.Dropout, etc. Lux doesn't support this mode (specifically to keep code simple and do exactly what the user wants), hence our default mode is training. This can be changed using Lux.testmode.
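For example, given states st from Lux.setup, switching modes is just:
julia
st_test = Lux.testmode(st)   # inference mode
+st_train = Lux.trainmode(st) # back to training mode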
If you have Flux loaded in your code, you can use the function FromFluxAdaptor to automatically convert your model to Lux. Note that in case a native Lux counterpart isn't available, we fallback to using Optimisers.destructure.
This is a relatively new feature in Lux, so there might be some rough edges. If you encounter any issues, please let us know by opening an issue on the GitHub repository.
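A small sketch of the conversion (the model here is illustrative):
julia
using Lux, Flux, Random
+
+flux_model = Flux.Chain(Flux.Dense(2 => 3, relu), Flux.Dense(3 => 1))
+lux_model = FromFluxAdaptor()(flux_model)
+
+# `lux_model` is now a Lux layer, so set it up as usual
+ps, st = Lux.setup(Random.default_rng(), lux_model)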
In this manual, we will explore how to use automatic differentiation (AD) inside your layers or loss functions and have Lux automatically switch the AD backend with a faster one when needed.
Tip
Don't want Lux to do this switching for you? You can disable it by setting the automatic_nested_ad_switching Preference to false.
Remember that if you are using ForwardDiff inside a Zygote call, it will drop gradients (with a warning message), so it is not recommended to use this combination.
Let's explore this using some questions that were posted on the Julia Discourse forum.
This problem comes from @facusapienza on Discourse. In this case, we want to add a regularization term to the neural DE based on first-order derivatives. The neural DE part is not important here and we can demonstrate this easily with a standard neural network.
julia
function loss_function1(model, x, ps, st, y)
+ # Make it a stateful layer
+ smodel = StatefulLuxLayer{true}(model, ps, st)
+ ŷ = smodel(x)
+ loss_emp = sum(abs2, ŷ .- y)
+ # You can use `Zygote.jacobian` as well but ForwardDiff tends to be more efficient here
+ J = ForwardDiff.jacobian(smodel, x)
+ loss_reg = abs2(norm(J .* 0.01f0))
+ return loss_emp + loss_reg
+end
+
+# Using Batchnorm to show that it is possible
+model = Chain(Dense(2 => 4, tanh), BatchNorm(4), Dense(4 => 2))
+ps, st = Lux.setup(StableRNG(0), model)
+x = randn(StableRNG(0), Float32, 2, 10)
+y = randn(StableRNG(11), Float32, 2, 10)
+
+loss_function1(model, x, ps, st, y)
14.883664f0
So our loss function works, let's take the gradient (forward diff doesn't nest nicely here):
┌ Warning: `training` is set to `Val{true}()` but is not being used within an autodiff call (gradient, jacobian, etc...). This will be slow. If you are using a `Lux.jl` model, set it to inference (test) mode using `LuxCore.testmode`. Reliance on this behavior is discouraged, and is not guaranteed by Semantic Versioning, and might be removed without a deprecation cycle. It is recommended to fix this issue in your code.
+└ @ LuxLib.Utils /var/lib/buildkite-agent/builds/gpuci-5/julialang/lux-dot-jl/lib/LuxLib/src/utils.jl:309
+┌ Warning: `training` is set to `Val{true}()` but is not being used within an autodiff call (gradient, jacobian, etc...). This will be slow. If you are using a `Lux.jl` model, set it to inference (test) mode using `LuxCore.testmode`. Reliance on this behavior is discouraged, and is not guaranteed by Semantic Versioning, and might be removed without a deprecation cycle. It is recommended to fix this issue in your code.
+└ @ LuxLib.Utils /var/lib/buildkite-agent/builds/gpuci-5/julialang/lux-dot-jl/lib/LuxLib/src/utils.jl:309
+∞-norm(∂x - ∂x_fd): 0.00046014786
+∞-norm(∂ps - ∂ps_fd): 0.00068473816
That's pretty good, of course you will have some error from the finite differences calculation.
Notice that in this example the Jacobian J consists of the full matrix of derivatives of smodel with respect to the different inputs in x. In many cases, we are interested in computing the Jacobian with respect to each input individually, avoiding the unnecessary calculation of zero entries of the Jacobian. This can be achieved with batched_jacobian, which computes the Jacobian for each single input separately. Using the same example from the previous section:
julia
model = Chain(Dense(2 => 4, tanh), Dense(4 => 2))
+ps, st = Lux.setup(StableRNG(0), model)
+x = randn(StableRNG(0), Float32, 2, 10)
+y = randn(StableRNG(11), Float32, 2, 10)
+
+function loss_function_batched(model, x, ps, st, y)
+ # Make it a stateful layer
+ smodel = StatefulLuxLayer{true}(model, ps, st)
+ ŷ = smodel(x)
+ loss_emp = sum(abs2, ŷ .- y)
+ # You can use `AutoZygote()` as well but `AutoForwardDiff()` tends to be more efficient here
+ J = batched_jacobian(smodel, AutoForwardDiff(), x)
+ loss_reg = abs2(norm(J .* 0.01f0))
+ return loss_emp + loss_reg
+end
+
+loss_function_batched(model, x, ps, st, y)
11.380777f0
Notice that in this last example we removed BatchNorm() from the neural network. This is done so that outputs corresponding to different inputs don't have an algebraic dependency due to the batch normalization happening in the neural network. We can now verify again the value of the Jacobian:
In this example, it is important to remark that now batched_jacobian returns a 3D array with the Jacobian calculation for each independent input value in x.
Ok here I am going to cheat a bit. This comes from a discussion on nested AD for PINNs on Discourse. As per the consensus there, we shouldn't use nested AD for 3rd or higher order differentiation. Note that in the example there, the user uses ForwardDiff.derivative but we will use ForwardDiff.gradient instead, as we typically deal with array inputs and outputs.
Loss Function computing the Jacobian of the Parameters
The above example shows how to compute the gradient/jacobian wrt the inputs in the loss function. However, what if we want to compute the jacobian wrt the parameters? This problem has been taken from Issue 610.
We resolve these setups by using the Base.Fix1 wrapper around the stateful layer and fixing the input to the stateful layer.
julia
function loss_function3(model, x, ps, st)
+ smodel = StatefulLuxLayer{true}(model, ps, st)
+ J = only(Zygote.jacobian(Base.Fix1(smodel, x), ps)) # Zygote returns a tuple
+ return sum(abs2, J)
+end
+
+model = Chain(Dense(1 => 12,tanh), Dense(12 => 12,tanh), Dense(12 => 12,tanh),
+ Dense(12 => 1))
+ps, st = Lux.setup(StableRNG(0), model)
+ps = ComponentArray(ps) # needs to be an AbstractArray for most jacobian functions
+x = rand(StableRNG(0), Float32, 1, 16)
Hutchinson Trace Estimation often shows up in machine learning literature to provide a fast estimate of the trace of a Jacobian Matrix. This is based off of Hutchinson 1990, which computes the estimated trace of a matrix A using random vectors v s.t. E[v vᵀ] = I, giving Tr(A) = E[vᵀ A v].
We can use this to compute the trace of a Jacobian Matrix using the following algorithm:
Note that we can compute this using two methods:
Compute using a Vector-Jacobian product and then do a matrix-vector product to get the trace.
Compute using a Jacobian-Vector product and then do a matrix-vector product to get the trace.
For simplicity, we will use a single sample of v to compute the trace. Additionally, we will fix the sample to ensure that our tests against the finite difference implementation are not affected by the randomness in the sample.
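The hutchinson_trace_* helpers called below are defined earlier in the full tutorial; as an illustration, a single-sample estimator along the JVP route could look like this (hutchinson_trace_estimate is a hypothetical name):
julia
using Lux, Random
+using ADTypes: AutoForwardDiff
+
+function hutchinson_trace_estimate(model, x, ps, st, v)
+    smodel = StatefulLuxLayer{true}(model, ps, st)
+    Jv = jacobian_vector_product(smodel, AutoForwardDiff(), x, v)
+    return sum(v .* Jv)  # vᵀ(Jv) estimates Tr(J) since E[v vᵀ] = I
+end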
tr_vjp = hutchinson_trace_vjp(model, x, ps, st, v)
+tr_jvp = hutchinson_trace_jvp(model, x, ps, st, v)
+tr_full_jacobian = hutchinson_trace_full_jacobian(model, x, ps, st, v)
+println("Tr(J) using vjp: ", tr_vjp)
+println("Tr(J) using jvp: ", tr_jvp)
+println("Tr(J) using full jacobian: ", tr_full_jacobian)
Tr(J) using vjp: 4.9127817
+Tr(J) using jvp: 4.912782
+Tr(J) using full jacobian: 4.912781
Now that we have verified that the results are the same, let's try to differentiate the trace estimate. This often shows up as a regularization term in neural networks.
In this page, we will describe how to embed neural networks inside GPU kernels. We will use KernelAbstractions.jl to do this, making it compatible with multiple GPU backends.
Experimental Feature
This is a relatively new and experimental feature. Expect edge cases and open issues on GitHub if you find any.
Inference Only
Currently this works only for inference. We will eventually test automatic differentiation using Enzyme.jl
Batching
In most use cases, this form of batching via embedding the neural network inside a GPU kernel is not recommended and will lead to suboptimal performance. Instead, batch the input data and let Lux handle the batching internally.
julia
using Lux, LuxCUDA, Random
+using KernelAbstractions, StaticArrays
The first thing to remember is that we can't use regular high-level operations inside the kernels; instead we will use Static Arrays. Leveraging Julia's multiple dispatch, Lux will use specialized operations that are compatible with GPU kernels.
julia
@kernel function nn_eval_single_batch!(output, model, input, ps, st)
+ i = @index(Global, Linear)
+ y, st_ = Lux.apply(model, input[i], ps, st)
+ output[i] = y
+end
nn_eval_single_batch! (generic function with 4 methods)
We define and initialize the neural network as usual, but we need to additionally convert the Arrays into SArrays.
Lux by default uses Julia semantics for type promotion. While this means that we do the "correct" numerical thing, it can often come as a surprise to users coming from a deep learning background. For example, consider the following code:
julia
using Lux, Random
+
+rng = Xoshiro(0)
+
+model = Dense(2 => 2, gelu)
+ps, st = Lux.setup(rng, model)
+Lux.recursive_eltype((ps, st))
Float32
As we can see, ps and st are structures with the highest precision being Float32. Now let's run the model using some random data:
julia
x = rand(rng, 2, 4)
+
+eltype(first(model(x, ps, st)))
Float64
Oops, our output became Float64. This will be bad on CPUs, but an absolute performance disaster on GPUs. The reason this happened is that our input x was Float64. Instead, we should have used a Float32 input:
julia
x = rand(rng, Float32, 2, 4)
+
+eltype(first(model(x, ps, st)))
Float32
This was easy to fix for a small model. But certain layers might incorrectly promote objects to a higher precision. This will cause a regression in performance. There are 2 recommendations to fix this or track them down:
Alternatively, to control the global behavior of eltypes in Lux and allow it to auto-correct the precision, use match_eltype and the eltype_mismatch_handling preference.
When running code on GPUs, it is recommended to disallow scalar indexing. Note that this is disabled by default except in REPL. You can disable it even in REPL mode using:
julia
using GPUArraysCore
+GPUArraysCore.allowscalar(false)
Lux.jl is integrated with DispatchDoctor.jl to catch type instabilities. You can easily enable it by setting the instability_check preference. This will help you catch type instabilities in your code. For more information on how to set preferences, check out Lux.set_dispatch_doctor_preferences!.
For faster performance on CPUs load the following packages:
LoopVectorization.jl
Octavian.jl
If these are available, we automatically use optimized versions of the layers. Though there are cases where this might be an issue (see #980 and disabling loop vectorization).
A common pattern for loading data and transferring data to GPUs looks like this:
julia
dataloader = DataLoader(dataset; parallel=true, batchsize=12) # from MLUtils.jl
+gdev = gpu_device()
+
+for (X, y) in dataloader
+ X = X |> gdev
+ y = y |> gdev
+ # ...
+ # do some computation
+ # ...
+end
This is typically fast enough, but the data transfer to the device happens in the main process, not exploiting the parallelism in the dataloader. Instead, we can do this:
julia
dataloader = DataLoader(dataset; parallel=true, batchsize=12) # from MLUtils.jl
+gdev = gpu_device()
+
+for (X, y) in gdev(dataloader)
+ # ...
+ # do some computation
+ # ...
+end
Here, X and y are on the gpu device gdev and the data transfer happens in the worker processes. Additionally, it behaves similarly to CuIterator from CUDA.jl and eagerly frees the data after every iteration (this is device agnostic and works on all supported GPU backends).
\ No newline at end of file
diff --git a/previews/PR1023/manual/preferences.html b/previews/PR1023/manual/preferences.html
new file mode 100644
index 0000000000..92dbf8a0ca
--- /dev/null
+++ b/previews/PR1023/manual/preferences.html
@@ -0,0 +1,36 @@
+ Preferences for Lux.jl | Lux.jl Docs
automatic_nested_ad_switching - Set this to false to disable automatic switching of backends for nested automatic differentiation. See the manual section on nested automatic differentiation for more details.
gpu_backend - Set this to bypass the automatic backend selection and use a specific gpu backend. Valid options are "cuda", "rocm", "metal", and "oneapi". This preference needs to be set for MLDataDevices package. It is recommended to use MLDataDevices.gpu_backend! to set this preference.
eltype_mismatch_handling - Preference controlling what happens when layers get different eltypes as input. See the documentation on match_eltype for more details.
instability_check - Preference controlling the dispatch doctor. See the documentation on Lux.set_dispatch_doctor_preferences! for more details. The preferences need to be set for LuxCore and LuxLib packages. Both of them default to disable.
Setting the LuxCore preference sets the check at the level of LuxCore.apply. This essentially activates the dispatch doctor for all Lux layers.
Setting the LuxLib preference sets the check at the level of functional layer of Lux, for example, fused_dense_bias_activation. These functions are supposed to be type stable for common input types and can be used to guarantee type stability.
LoopVectorization.jl and Octavian.jl are optional dependencies that are used to accelerate certain CPU operations. However, these packages are tightly coupled with Julia and might not work with all Julia versions and systems. If these packages are loaded in any form, LuxLib will use the optimized versions of the functions. But it might be desirable to disable these packages and use the default implementations instead. This can be done by setting the disable_loop_vectorization preference to true for LuxLib.
The package is meant to work with deep learning libraries such as (F)Lux. All the methods take as input the chosen rng type and the dimensions for the array.
julia
weights = init(rng, dims...)
The rng is optional, if not specified a default one will be used.
julia
weights = init(dims...)
If there is the need to use keyword arguments, the methods can be called with just the rng (optionally) and the keywords, to get in return a function behaving like the two examples above.
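A short sketch of these three calling conventions:
julia
using WeightInitializers, Random
+
+rng = Random.default_rng()
+
+w1 = kaiming_normal(rng, 3, 4)          # explicit rng
+w2 = glorot_uniform(4, 5)               # default rng
+init_fn = kaiming_normal(; gain=1.5f0)  # returns a function behaving like the above
+w3 = init_fn(rng, 3, 4)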
This section can be skipped. It defines functions to simulate the model; however, from a scientific machine learning perspective, it isn't super relevant.
We need a very crude 2-body path. Assume the 1-body motion is a Newtonian 2-body position vector r = r₁ - r₂ and use Newtonian formulas to get r₁ and r₂ (e.g. Theoretical Mechanics of Particles and Continua 4.3)
julia
function one2two(path, m₁, m₂)
+ M = m₁ + m₂
+ r₁ = m₂ / M .* path
+ r₂ = -m₁ / M .* path
+ return r₁, r₂
+end
one2two (generic function with 1 method)
Next we define a function to perform the change of variables:
julia
@views function soln2orbit(soln, model_params=nothing)
+ @assert size(soln, 1) ∈ [2, 4] "size(soln,1) must be either 2 or 4"
+
+ if size(soln, 1) == 2
+ χ = soln[1, :]
+ ϕ = soln[2, :]
+
+ @assert length(model_params)==3 "model_params must have length 3 when size(soln,2) = 2"
+ p, M, e = model_params
+ else
+ χ = soln[1, :]
+ ϕ = soln[2, :]
+ p = soln[3, :]
+ e = soln[4, :]
+ end
+
+ r = p ./ (1 .+ e .* cos.(χ))
+ x = r .* cos.(ϕ)
+ y = r .* sin.(ϕ)
+
+ orbit = vcat(x', y')
+ return orbit
+end
Next, we define the neural network model that takes 1 input (time) and has two outputs. We'll make a function ODE_model that takes the initial conditions, neural network parameters and a time as inputs and returns the derivatives.
It is typically never recommended to use globals, but in case you do use them, make sure to mark them as const.
We will deviate from the standard Neural Network initialization and use WeightInitializers.jl.
Now we define a system of ODEs which describes the motion of a point-like particle with Newtonian physics, using
χ̇ = (1 + e cos χ)² / (M p^(3/2)), ϕ̇ = (1 + e cos χ)² / (M p^(3/2))
where p, M, and e are constants
julia
function ODE_model(u, nn_params, t)
+ χ, ϕ = u
+ p, M, e = ode_model_params
+
+ # In this example we know that `st` is an empty NamedTuple hence we can safely ignore
+ # it, however, in general, we should use `st` to store the state of the neural network.
+ y = 1 .+ nn_model([first(u)], nn_params)
+
+ numer = (1 + e * cos(χ))^2
+ denom = M * (p^(3 / 2))
+
+ χ̇ = (numer / denom) * y[1]
+ ϕ̇ = (numer / denom) * y[2]
+
+ return [χ̇, ϕ̇]
+end
ODE_model (generic function with 1 method)
Let us now simulate the neural network model and plot the results. We'll use the untrained neural network parameters to simulate the model.
It introduces basic Julia programming, as well as Zygote, a source-to-source automatic differentiation (AD) framework in Julia. We'll use these tools to build a very simple neural network. Let's start with importing Lux.jl
julia
using Lux, Random
Now let us control the randomness in our code using a proper Pseudo Random Number Generator (PRNG).
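For example, an explicitly seeded PRNG keeps runs reproducible:
julia
using Random
+
+rng = Xoshiro(0)  # explicitly seeded; same results on every run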
The starting point for all of our models is the Array (sometimes referred to as a Tensor in other frameworks). This is really just a list of numbers, which might be arranged into a shape like a square. Let's write down an array with three elements.
julia
x = [1, 2, 3]
3-element Vector{Int64}:
+ 1
+ 2
+ 3
Here's a matrix – a square array with four elements.
julia
x = [1 2; 3 4]
2×2 Matrix{Int64}:
+ 1 2
+ 3 4
We often work with arrays of thousands of elements, and don't usually write them down by hand. Here's how we can create an array of 5×3 = 15 elements, each a random number from zero to one.
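For example:
julia
x = rand(5, 3)  # 5×3 matrix of random numbers in [0, 1)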
There are a few functions like this; try replacing rand with ones, zeros, or randn.
By default, Julia stores numbers in a high-precision format called Float64. In ML we often don't need all those digits, and can ask Julia to work with Float32 instead. We can even ask for more digits using BigFloat.
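For example, sampling directly at the desired precision, or converting an existing array x:
julia
x32 = rand(Float32, 5, 3)  # sample directly in Float32
+xf32 = Float32.(x)         # or convert an existing array
+xbig = big.(x)             # BigFloat, for more digits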
CUDA functionality is provided separately by the CUDA.jl package. If you have a GPU and LuxCUDA is installed, Lux will provide CUDA capabilities. For additional details on backends see the manual section.
You can manually add CUDA. Once CUDA is loaded you can move any array to the GPU with the cu function (or the gpu function exported by Lux), and it supports all of the above operations with the same syntax.
Lux, as you might have read, is immutable by convention, which means that the core library is built without any form of mutation and all functions are pure. However, we don't enforce it in any form. We do strongly recommend that users extending this framework for their respective applications don't mutate their arrays.
Note that our current default AD engine (Zygote) is unable to differentiate through this mutation, however, for these specialized cases it is quite trivial to write custom backward passes. (This problem will be fixed once we move towards Enzyme.jl)
If we call any function that relies on rng and uses it via randn, rand, etc., rng will be mutated. As we have already established, we care a lot about immutability; hence we should use Lux.replicate on PRNGs before using them.
First, let us run a random number generator 3 times with the replicated rng.
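A small sketch of that experiment:
julia
using Lux, Random
+
+rng = Xoshiro(0)
+for i in 1:3
+    # `replicate` copies the PRNG, so the original `rng` is never mutated
+    println("Iteration $i: ", rand(Lux.replicate(rng), 2))
+end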
Slight Detour: We have had several questions regarding if we will be considering any other AD system for the reverse-diff backend. For now we will stick to Zygote.jl, however once we have tested Lux extensively with Enzyme.jl, we will make the switch.
Even though, theoretically, a VJP (Vector-Jacobian product - reverse autodiff) and a JVP (Jacobian-Vector product - forward-mode autodiff) are similar—they compute a product of a Jacobian and a vector—they differ by the computational complexity of the operation. In short, when you have a large number of parameters (hence a wide matrix), a JVP is less efficient computationally than a VJP, and, conversely, a JVP is more efficient when the Jacobian matrix is a tall matrix.
While DifferentiationInterface provides these functions for a wider range of backends, we currently don't recommend using them with Lux models, since the functions presented here come with additional goodies like fast second-order derivatives.
Compute the JVP. AutoForwardDiff specifies that we want to use ForwardDiff.jl for the Jacobian-Vector Product.
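A sketch of such a call, assuming the jacobian_vector_product helper provided by Lux and the rng defined earlier:
julia
f(x) = x .* x ./ 2
+x = randn(rng, Float32, 5)
+v = ones(Float32, 5)
+jvp = Lux.jacobian_vector_product(f, AutoForwardDiff(), x, v)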
Finally, now let us consider a linear regression problem. From a set of data points $\{(x_i, y_i)\}_{i=1}^{k}$, with $x_i \in \mathbb{R}^n$ and $y_i \in \mathbb{R}^m$, we try to find a set of parameters $W \in \mathbb{R}^{m \times n}$ and $b \in \mathbb{R}^m$, s.t. $f_{W,b}(x) = Wx + b$, which minimizes the mean squared error:

$$L(W, b) = \sum_{i=1}^{k} \frac{1}{2} \left\| y_i - f_{W,b}(x_i) \right\|_2^2$$
We can write f from scratch, but to demonstrate Lux, let us use the Dense layer.
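The snippet below also needs the model and the ground-truth parameters; here is a minimal sketch consistent with the names used next (rng is assumed from earlier):
julia
model = Dense(10 => 5)
+ps, st = Lux.setup(rng, model)
+
+n_samples = 20
+x_dim = 10
+y_dim = 5
+
+# Ground-truth parameters we will try to recover
+W = randn(rng, Float32, y_dim, x_dim)
+b = randn(rng, Float32, y_dim)
With these in place, we can generate noisy samples: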
x_samples = randn(rng, Float32, x_dim, n_samples)
+y_samples = W * x_samples .+ b .+ 0.01f0 .* randn(rng, Float32, y_dim, n_samples)
+println("x shape: ", size(x_samples), "; y shape: ", size(y_samples))
x shape: (10, 20); y shape: (5, 20)
For updating our parameters let's use Optimisers.jl. We will use Stochastic Gradient Descent (SGD) with a learning rate of 0.01.
julia
using Optimisers, Printf
Define the loss function
julia
lossfn = MSELoss()
+
+println("Loss Value with ground true parameters: ", lossfn(W * x_samples .+ b, y_samples))
Loss Value with ground true parameters: 9.3742405e-5
We will train the model using our training API.
julia
function train_model!(model, ps, st, opt, nepochs::Int)
+ tstate = Training.TrainState(model, ps, st, opt)
+ for i in 1:nepochs
+ grads, loss, _, tstate = Training.single_train_step!(
+ AutoZygote(), lossfn, (x_samples, y_samples), tstate)
+ if i % 1000 == 1 || i == nepochs
+ @printf "Loss Value after %6d iterations: %.8f\n" i loss
+ end
+ end
+ return tstate.model, tstate.parameters, tstate.states
+end
+
+model, ps, st = train_model!(model, ps, st, Descent(0.01f0), 10000)
+
+println("Loss Value after training: ", lossfn(first(model(x_samples, ps, st)), y_samples))
Loss Value after 1 iterations: 7.80465555
+Loss Value after 1001 iterations: 0.12477568
+Loss Value after 2001 iterations: 0.02535537
+Loss Value after 3001 iterations: 0.00914141
+Loss Value after 4001 iterations: 0.00407581
+Loss Value after 5001 iterations: 0.00198415
+Loss Value after 6001 iterations: 0.00101147
+Loss Value after 7001 iterations: 0.00053332
+Loss Value after 8001 iterations: 0.00029203
+Loss Value after 9001 iterations: 0.00016878
+Loss Value after 10000 iterations: 0.00010551
+Loss Value after training: 0.00010546855
We will use the Training API so we need to ensure that our loss function takes 4 inputs – model, parameters, states and data. The function must return 3 values – loss, updated_state, and any computed statistics. This is already satisfied by the loss functions provided by Lux.
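For illustration, a custom loss with that signature might look like the following sketch (the empty named tuple carries no extra statistics):
julia
function custom_loss(model, ps, st, (x, y))
+    ŷ, st_ = model(x, ps, st)
+    return MSELoss()(ŷ, y), st_, (;)
+end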
In this tutorial we will go over using a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:
Create custom Lux models.
Become familiar with the Lux recurrent neural network API.
We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals. Using this data, we will create a MLUtils.DataLoader. Our dataloader will give us sequences of size 2 × seq_len × batch_size, and we need to predict a binary value indicating whether the sequence is clockwise or anticlockwise.
julia
function get_dataloaders(; dataset_size=1000, sequence_length=50)
+ # Create the spirals
+ data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
+ # Get the labels
+ labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
+ clockwise_spirals = [reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
+ for d in data[1:(dataset_size ÷ 2)]]
+ anticlockwise_spirals = [reshape(
+ d[1][:, (sequence_length + 1):end], :, sequence_length, 1)
+ for d in data[((dataset_size ÷ 2) + 1):end]]
+ x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
+ # Split the dataset
+ (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
+ # Create DataLoaders
+ return (
+ # Use DataLoader to automatically minibatch and shuffle the data
+ DataLoader(collect.((x_train, y_train)); batchsize=128, shuffle=true),
+ # Don't shuffle the validation data
+ DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false))
+end
We will be extending the Lux.AbstractLuxContainerLayer type for our custom model since it will contain an LSTM block and a classifier head.
We pass the fieldnames lstm_cell and classifier to the type to ensure that the parameters and states are automatically populated and we don't have to define Lux.initialparameters and Lux.initialstates.
To understand more about container layers, please look at Container Layer.
We can use default Lux blocks – Recurrence(LSTMCell(in_dims => hidden_dims)) – instead of defining the following. But let's still do it for the sake of illustration, as sketched below.
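A sketch of that definition, following the container-layer pattern described above (the fieldnames match the type parameter tuple):
julia
struct SpiralClassifier{L, C} <: Lux.AbstractLuxContainerLayer{(:lstm_cell, :classifier)}
+    lstm_cell::L
+    classifier::C
+end
+
+function SpiralClassifier(in_dims, hidden_dims, out_dims)
+    return SpiralClassifier(
+        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid))
+end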
Now we need to define the behavior of the Classifier when it is invoked.
julia
function (s::SpiralClassifier)(
+ x::AbstractArray{T, 3}, ps::NamedTuple, st::NamedTuple) where {T}
+ # First we will have to run the sequence through the LSTM Cell
+ # The first call to LSTM Cell will create the initial hidden state
+ # See that the parameters and states are automatically populated into a field called
+ # `lstm_cell`. We use `eachslice` to get the elements in the sequence without copying,
+ # and `Iterators.peel` to split out the first element for LSTM initialization.
+ x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
+ (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
+ # Now that we have the hidden state and memory in `carry` we will pass the input and
+ # `carry` jointly
+ for x in x_rest
+ (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
+ end
+ # After running through the sequence we will pass the output through the classifier
+ y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
+ # Finally remember to create the updated state
+ st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
+ return vec(y), st
+end
We can also define the model using the Lux.@compact API, which is a more concise way of defining models. This macro automatically handles the boilerplate code for you, and as such we recommend this way of defining custom layers.
julia
function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
+ lstm_cell = LSTMCell(in_dims => hidden_dims)
+ classifier = Dense(hidden_dims => out_dims, sigmoid)
+ return @compact(; lstm_cell, classifier) do x::AbstractArray{T, 3} where {T}
+ x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
+ y, carry = lstm_cell(x_init)
+ for x in x_rest
+ y, carry = lstm_cell((x, carry))
+ end
+ @return vec(classifier(y))
+ end
+end
SpiralClassifierCompact (generic function with 1 method)
Now let's define the binarycrossentropy loss. Typically it is recommended to use logitbinarycrossentropy since it is more numerically stable, but for the sake of simplicity we will use binarycrossentropy.
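A sketch of the loss and the accuracy helpers referenced in the training loop below (the original tutorial's definitions may differ slightly):
julia
const lossfn = BinaryCrossEntropyLoss()
+
+matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
+accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
Finally, the main training loop: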
function main(model_type)
+ dev = gpu_device()
+
+ # Get the dataloaders
+ train_loader, val_loader = get_dataloaders() .|> dev
+
+ # Create the model
+ model = model_type(2, 8, 1)
+ rng = Xoshiro(0)
+ ps, st = Lux.setup(rng, model) |> dev
+
+ train_state = Training.TrainState(model, ps, st, Adam(0.01f0))
+
+ for epoch in 1:25
+ # Train the model
+ for (x, y) in train_loader
+ (_, loss, _, train_state) = Training.single_train_step!(
+ AutoZygote(), lossfn, (x, y), train_state)
+
+ @printf "Epoch [%3d]: Loss %4.5f\n" epoch loss
+ end
+
+ # Validate the model
+ st_ = Lux.testmode(train_state.states)
+ for (x, y) in val_loader
+ ŷ, st_ = model(x, train_state.parameters, st_)
+ loss = lossfn(ŷ, y)
+ acc = accuracy(ŷ, y)
+ @printf "Validation: Loss %4.5f Accuracy %4.5f\n" loss acc
+ end
+ end
+
+ return (train_state.parameters, train_state.states) |> cpu_device()
+end
+
+ps_trained, st_trained = main(SpiralClassifier)
Epoch [ 1]: Loss 0.60926
+Epoch [ 1]: Loss 0.60205
+Epoch [ 1]: Loss 0.56447
+Epoch [ 1]: Loss 0.53935
+Epoch [ 1]: Loss 0.51961
+Epoch [ 1]: Loss 0.50630
+Epoch [ 1]: Loss 0.48399
+Validation: Loss 0.46956 Accuracy 1.00000
+Validation: Loss 0.47794 Accuracy 1.00000
+Epoch [ 2]: Loss 0.47301
+Epoch [ 2]: Loss 0.45405
+Epoch [ 2]: Loss 0.43968
+Epoch [ 2]: Loss 0.43054
+Epoch [ 2]: Loss 0.40202
+Epoch [ 2]: Loss 0.39666
+Epoch [ 2]: Loss 0.40138
+Validation: Loss 0.37273 Accuracy 1.00000
+Validation: Loss 0.38210 Accuracy 1.00000
+Epoch [ 3]: Loss 0.36731
+Epoch [ 3]: Loss 0.36875
+Epoch [ 3]: Loss 0.34892
+Epoch [ 3]: Loss 0.33812
+Epoch [ 3]: Loss 0.31629
+Epoch [ 3]: Loss 0.30792
+Epoch [ 3]: Loss 0.27809
+Validation: Loss 0.28817 Accuracy 1.00000
+Validation: Loss 0.29822 Accuracy 1.00000
+Epoch [ 4]: Loss 0.28662
+Epoch [ 4]: Loss 0.27989
+Epoch [ 4]: Loss 0.27278
+Epoch [ 4]: Loss 0.25235
+Epoch [ 4]: Loss 0.23497
+Epoch [ 4]: Loss 0.23847
+Epoch [ 4]: Loss 0.23192
+Validation: Loss 0.21844 Accuracy 1.00000
+Validation: Loss 0.22858 Accuracy 1.00000
+Epoch [ 5]: Loss 0.21529
+Epoch [ 5]: Loss 0.21660
+Epoch [ 5]: Loss 0.21147
+Epoch [ 5]: Loss 0.18347
+Epoch [ 5]: Loss 0.18387
+Epoch [ 5]: Loss 0.16418
+Epoch [ 5]: Loss 0.18488
+Validation: Loss 0.16251 Accuracy 1.00000
+Validation: Loss 0.17173 Accuracy 1.00000
+Epoch [ 6]: Loss 0.15106
+Epoch [ 6]: Loss 0.15557
+Epoch [ 6]: Loss 0.15604
+Epoch [ 6]: Loss 0.12610
+Epoch [ 6]: Loss 0.14466
+Epoch [ 6]: Loss 0.13525
+Epoch [ 6]: Loss 0.13401
+Validation: Loss 0.11923 Accuracy 1.00000
+Validation: Loss 0.12679 Accuracy 1.00000
+Epoch [ 7]: Loss 0.11300
+Epoch [ 7]: Loss 0.11270
+Epoch [ 7]: Loss 0.11182
+Epoch [ 7]: Loss 0.10579
+Epoch [ 7]: Loss 0.10077
+Epoch [ 7]: Loss 0.09092
+Epoch [ 7]: Loss 0.08957
+Validation: Loss 0.08530 Accuracy 1.00000
+Validation: Loss 0.09085 Accuracy 1.00000
+Epoch [ 8]: Loss 0.08321
+Epoch [ 8]: Loss 0.07613
+Epoch [ 8]: Loss 0.07561
+Epoch [ 8]: Loss 0.07250
+Epoch [ 8]: Loss 0.06895
+Epoch [ 8]: Loss 0.07155
+Epoch [ 8]: Loss 0.06246
+Validation: Loss 0.05935 Accuracy 1.00000
+Validation: Loss 0.06304 Accuracy 1.00000
+Epoch [ 9]: Loss 0.06135
+Epoch [ 9]: Loss 0.05983
+Epoch [ 9]: Loss 0.05429
+Epoch [ 9]: Loss 0.04415
+Epoch [ 9]: Loss 0.04965
+Epoch [ 9]: Loss 0.04801
+Epoch [ 9]: Loss 0.04264
+Validation: Loss 0.04389 Accuracy 1.00000
+Validation: Loss 0.04647 Accuracy 1.00000
+Epoch [ 10]: Loss 0.04243
+Epoch [ 10]: Loss 0.04109
+Epoch [ 10]: Loss 0.04136
+Epoch [ 10]: Loss 0.04201
+Epoch [ 10]: Loss 0.03979
+Epoch [ 10]: Loss 0.03471
+Epoch [ 10]: Loss 0.03760
+Validation: Loss 0.03546 Accuracy 1.00000
+Validation: Loss 0.03756 Accuracy 1.00000
+Epoch [ 11]: Loss 0.03545
+Epoch [ 11]: Loss 0.03571
+Epoch [ 11]: Loss 0.03202
+Epoch [ 11]: Loss 0.03209
+Epoch [ 11]: Loss 0.03134
+Epoch [ 11]: Loss 0.03114
+Epoch [ 11]: Loss 0.03593
+Validation: Loss 0.03006 Accuracy 1.00000
+Validation: Loss 0.03189 Accuracy 1.00000
+Epoch [ 12]: Loss 0.03210
+Epoch [ 12]: Loss 0.02768
+Epoch [ 12]: Loss 0.02955
+Epoch [ 12]: Loss 0.02631
+Epoch [ 12]: Loss 0.02720
+Epoch [ 12]: Loss 0.02667
+Epoch [ 12]: Loss 0.03031
+Validation: Loss 0.02612 Accuracy 1.00000
+Validation: Loss 0.02773 Accuracy 1.00000
+Epoch [ 13]: Loss 0.02589
+Epoch [ 13]: Loss 0.02454
+Epoch [ 13]: Loss 0.02716
+Epoch [ 13]: Loss 0.02579
+Epoch [ 13]: Loss 0.02323
+Epoch [ 13]: Loss 0.02301
+Epoch [ 13]: Loss 0.02099
+Validation: Loss 0.02307 Accuracy 1.00000
+Validation: Loss 0.02452 Accuracy 1.00000
+Epoch [ 14]: Loss 0.02105
+Epoch [ 14]: Loss 0.02170
+Epoch [ 14]: Loss 0.02234
+Epoch [ 14]: Loss 0.02238
+Epoch [ 14]: Loss 0.02259
+Epoch [ 14]: Loss 0.02282
+Epoch [ 14]: Loss 0.01795
+Validation: Loss 0.02066 Accuracy 1.00000
+Validation: Loss 0.02199 Accuracy 1.00000
+Epoch [ 15]: Loss 0.02140
+Epoch [ 15]: Loss 0.02017
+Epoch [ 15]: Loss 0.01932
+Epoch [ 15]: Loss 0.02011
+Epoch [ 15]: Loss 0.01752
+Epoch [ 15]: Loss 0.02006
+Epoch [ 15]: Loss 0.01963
+Validation: Loss 0.01866 Accuracy 1.00000
+Validation: Loss 0.01988 Accuracy 1.00000
+Epoch [ 16]: Loss 0.01796
+Epoch [ 16]: Loss 0.01636
+Epoch [ 16]: Loss 0.01900
+Epoch [ 16]: Loss 0.01740
+Epoch [ 16]: Loss 0.01782
+Epoch [ 16]: Loss 0.01824
+Epoch [ 16]: Loss 0.01976
+Validation: Loss 0.01696 Accuracy 1.00000
+Validation: Loss 0.01810 Accuracy 1.00000
+Epoch [ 17]: Loss 0.01745
+Epoch [ 17]: Loss 0.01579
+Epoch [ 17]: Loss 0.01777
+Epoch [ 17]: Loss 0.01630
+Epoch [ 17]: Loss 0.01578
+Epoch [ 17]: Loss 0.01468
+Epoch [ 17]: Loss 0.01627
+Validation: Loss 0.01549 Accuracy 1.00000
+Validation: Loss 0.01656 Accuracy 1.00000
+Epoch [ 18]: Loss 0.01608
+Epoch [ 18]: Loss 0.01398
+Epoch [ 18]: Loss 0.01425
+Epoch [ 18]: Loss 0.01537
+Epoch [ 18]: Loss 0.01504
+Epoch [ 18]: Loss 0.01471
+Epoch [ 18]: Loss 0.01496
+Validation: Loss 0.01423 Accuracy 1.00000
+Validation: Loss 0.01523 Accuracy 1.00000
+Epoch [ 19]: Loss 0.01355
+Epoch [ 19]: Loss 0.01489
+Epoch [ 19]: Loss 0.01364
+Epoch [ 19]: Loss 0.01253
+Epoch [ 19]: Loss 0.01360
+Epoch [ 19]: Loss 0.01343
+Epoch [ 19]: Loss 0.01639
+Validation: Loss 0.01313 Accuracy 1.00000
+Validation: Loss 0.01405 Accuracy 1.00000
+Epoch [ 20]: Loss 0.01377
+Epoch [ 20]: Loss 0.01183
+Epoch [ 20]: Loss 0.01194
+Epoch [ 20]: Loss 0.01194
+Epoch [ 20]: Loss 0.01292
+Epoch [ 20]: Loss 0.01361
+Epoch [ 20]: Loss 0.01227
+Validation: Loss 0.01211 Accuracy 1.00000
+Validation: Loss 0.01297 Accuracy 1.00000
+Epoch [ 21]: Loss 0.01212
+Epoch [ 21]: Loss 0.01138
+Epoch [ 21]: Loss 0.01102
+Epoch [ 21]: Loss 0.01238
+Epoch [ 21]: Loss 0.01200
+Epoch [ 21]: Loss 0.01130
+Epoch [ 21]: Loss 0.01082
+Validation: Loss 0.01112 Accuracy 1.00000
+Validation: Loss 0.01190 Accuracy 1.00000
+Epoch [ 22]: Loss 0.01134
+Epoch [ 22]: Loss 0.01031
+Epoch [ 22]: Loss 0.01060
+Epoch [ 22]: Loss 0.01130
+Epoch [ 22]: Loss 0.01009
+Epoch [ 22]: Loss 0.01053
+Epoch [ 22]: Loss 0.00940
+Validation: Loss 0.01002 Accuracy 1.00000
+Validation: Loss 0.01071 Accuracy 1.00000
+Epoch [ 23]: Loss 0.00886
+Epoch [ 23]: Loss 0.01026
+Epoch [ 23]: Loss 0.01005
+Epoch [ 23]: Loss 0.00853
+Epoch [ 23]: Loss 0.01033
+Epoch [ 23]: Loss 0.00902
+Epoch [ 23]: Loss 0.00969
+Validation: Loss 0.00888 Accuracy 1.00000
+Validation: Loss 0.00947 Accuracy 1.00000
+Epoch [ 24]: Loss 0.00903
+Epoch [ 24]: Loss 0.00856
+Epoch [ 24]: Loss 0.00866
+Epoch [ 24]: Loss 0.00883
+Epoch [ 24]: Loss 0.00830
+Epoch [ 24]: Loss 0.00781
+Epoch [ 24]: Loss 0.00662
+Validation: Loss 0.00795 Accuracy 1.00000
+Validation: Loss 0.00846 Accuracy 1.00000
+Epoch [ 25]: Loss 0.00830
+Epoch [ 25]: Loss 0.00742
+Epoch [ 25]: Loss 0.00822
+Epoch [ 25]: Loss 0.00791
+Epoch [ 25]: Loss 0.00721
+Epoch [ 25]: Loss 0.00726
+Epoch [ 25]: Loss 0.00582
+Validation: Loss 0.00730 Accuracy 1.00000
+Validation: Loss 0.00775 Accuracy 1.00000
We can also train the compact model with the exact same code!
Epoch [ 1]: Loss 0.62249
+Epoch [ 1]: Loss 0.58988
+Epoch [ 1]: Loss 0.57122
+Epoch [ 1]: Loss 0.54145
+Epoch [ 1]: Loss 0.51676
+Epoch [ 1]: Loss 0.49941
+Epoch [ 1]: Loss 0.48712
+Validation: Loss 0.46707 Accuracy 1.00000
+Validation: Loss 0.46650 Accuracy 1.00000
+Epoch [ 2]: Loss 0.46435
+Epoch [ 2]: Loss 0.45555
+Epoch [ 2]: Loss 0.45454
+Epoch [ 2]: Loss 0.42345
+Epoch [ 2]: Loss 0.41436
+Epoch [ 2]: Loss 0.38527
+Epoch [ 2]: Loss 0.37442
+Validation: Loss 0.36940 Accuracy 1.00000
+Validation: Loss 0.36858 Accuracy 1.00000
+Epoch [ 3]: Loss 0.36752
+Epoch [ 3]: Loss 0.36360
+Epoch [ 3]: Loss 0.34430
+Epoch [ 3]: Loss 0.32734
+Epoch [ 3]: Loss 0.31783
+Epoch [ 3]: Loss 0.31825
+Epoch [ 3]: Loss 0.28565
+Validation: Loss 0.28440 Accuracy 1.00000
+Validation: Loss 0.28337 Accuracy 1.00000
+Epoch [ 4]: Loss 0.28307
+Epoch [ 4]: Loss 0.27199
+Epoch [ 4]: Loss 0.26836
+Epoch [ 4]: Loss 0.26051
+Epoch [ 4]: Loss 0.24528
+Epoch [ 4]: Loss 0.23063
+Epoch [ 4]: Loss 0.22536
+Validation: Loss 0.21475 Accuracy 1.00000
+Validation: Loss 0.21368 Accuracy 1.00000
+Epoch [ 5]: Loss 0.21305
+Epoch [ 5]: Loss 0.21531
+Epoch [ 5]: Loss 0.19616
+Epoch [ 5]: Loss 0.18414
+Epoch [ 5]: Loss 0.18294
+Epoch [ 5]: Loss 0.17875
+Epoch [ 5]: Loss 0.17815
+Validation: Loss 0.15941 Accuracy 1.00000
+Validation: Loss 0.15850 Accuracy 1.00000
+Epoch [ 6]: Loss 0.16464
+Epoch [ 6]: Loss 0.14669
+Epoch [ 6]: Loss 0.14234
+Epoch [ 6]: Loss 0.14785
+Epoch [ 6]: Loss 0.13936
+Epoch [ 6]: Loss 0.13121
+Epoch [ 6]: Loss 0.11054
+Validation: Loss 0.11688 Accuracy 1.00000
+Validation: Loss 0.11621 Accuracy 1.00000
+Epoch [ 7]: Loss 0.11895
+Epoch [ 7]: Loss 0.11755
+Epoch [ 7]: Loss 0.11153
+Epoch [ 7]: Loss 0.10806
+Epoch [ 7]: Loss 0.08931
+Epoch [ 7]: Loss 0.08989
+Epoch [ 7]: Loss 0.08885
+Validation: Loss 0.08377 Accuracy 1.00000
+Validation: Loss 0.08332 Accuracy 1.00000
+Epoch [ 8]: Loss 0.08392
+Epoch [ 8]: Loss 0.07975
+Epoch [ 8]: Loss 0.07711
+Epoch [ 8]: Loss 0.07462
+Epoch [ 8]: Loss 0.06929
+Epoch [ 8]: Loss 0.06475
+Epoch [ 8]: Loss 0.06222
+Validation: Loss 0.05835 Accuracy 1.00000
+Validation: Loss 0.05808 Accuracy 1.00000
+Epoch [ 9]: Loss 0.05835
+Epoch [ 9]: Loss 0.05645
+Epoch [ 9]: Loss 0.05303
+Epoch [ 9]: Loss 0.04974
+Epoch [ 9]: Loss 0.04989
+Epoch [ 9]: Loss 0.04836
+Epoch [ 9]: Loss 0.04374
+Validation: Loss 0.04304 Accuracy 1.00000
+Validation: Loss 0.04283 Accuracy 1.00000
+Epoch [ 10]: Loss 0.04373
+Epoch [ 10]: Loss 0.03963
+Epoch [ 10]: Loss 0.04024
+Epoch [ 10]: Loss 0.03893
+Epoch [ 10]: Loss 0.04085
+Epoch [ 10]: Loss 0.03933
+Epoch [ 10]: Loss 0.02782
+Validation: Loss 0.03470 Accuracy 1.00000
+Validation: Loss 0.03451 Accuracy 1.00000
+Epoch [ 11]: Loss 0.03413
+Epoch [ 11]: Loss 0.03603
+Epoch [ 11]: Loss 0.03246
+Epoch [ 11]: Loss 0.03142
+Epoch [ 11]: Loss 0.03040
+Epoch [ 11]: Loss 0.03279
+Epoch [ 11]: Loss 0.03336
+Validation: Loss 0.02942 Accuracy 1.00000
+Validation: Loss 0.02924 Accuracy 1.00000
+Epoch [ 12]: Loss 0.03113
+Epoch [ 12]: Loss 0.02712
+Epoch [ 12]: Loss 0.02845
+Epoch [ 12]: Loss 0.02904
+Epoch [ 12]: Loss 0.02709
+Epoch [ 12]: Loss 0.02722
+Epoch [ 12]: Loss 0.02449
+Validation: Loss 0.02555 Accuracy 1.00000
+Validation: Loss 0.02540 Accuracy 1.00000
+Epoch [ 13]: Loss 0.02730
+Epoch [ 13]: Loss 0.02638
+Epoch [ 13]: Loss 0.02358
+Epoch [ 13]: Loss 0.02337
+Epoch [ 13]: Loss 0.02417
+Epoch [ 13]: Loss 0.02397
+Epoch [ 13]: Loss 0.02159
+Validation: Loss 0.02258 Accuracy 1.00000
+Validation: Loss 0.02243 Accuracy 1.00000
+Epoch [ 14]: Loss 0.02377
+Epoch [ 14]: Loss 0.02260
+Epoch [ 14]: Loss 0.02070
+Epoch [ 14]: Loss 0.02170
+Epoch [ 14]: Loss 0.02060
+Epoch [ 14]: Loss 0.02212
+Epoch [ 14]: Loss 0.02141
+Validation: Loss 0.02019 Accuracy 1.00000
+Validation: Loss 0.02006 Accuracy 1.00000
+Epoch [ 15]: Loss 0.02146
+Epoch [ 15]: Loss 0.01937
+Epoch [ 15]: Loss 0.02047
+Epoch [ 15]: Loss 0.01826
+Epoch [ 15]: Loss 0.01953
+Epoch [ 15]: Loss 0.01824
+Epoch [ 15]: Loss 0.02201
+Validation: Loss 0.01821 Accuracy 1.00000
+Validation: Loss 0.01809 Accuracy 1.00000
+Epoch [ 16]: Loss 0.01872
+Epoch [ 16]: Loss 0.01647
+Epoch [ 16]: Loss 0.01868
+Epoch [ 16]: Loss 0.01763
+Epoch [ 16]: Loss 0.01802
+Epoch [ 16]: Loss 0.01730
+Epoch [ 16]: Loss 0.01691
+Validation: Loss 0.01653 Accuracy 1.00000
+Validation: Loss 0.01642 Accuracy 1.00000
+Epoch [ 17]: Loss 0.01638
+Epoch [ 17]: Loss 0.01693
+Epoch [ 17]: Loss 0.01747
+Epoch [ 17]: Loss 0.01530
+Epoch [ 17]: Loss 0.01570
+Epoch [ 17]: Loss 0.01579
+Epoch [ 17]: Loss 0.01431
+Validation: Loss 0.01511 Accuracy 1.00000
+Validation: Loss 0.01501 Accuracy 1.00000
+Epoch [ 18]: Loss 0.01395
+Epoch [ 18]: Loss 0.01493
+Epoch [ 18]: Loss 0.01631
+Epoch [ 18]: Loss 0.01388
+Epoch [ 18]: Loss 0.01496
+Epoch [ 18]: Loss 0.01520
+Epoch [ 18]: Loss 0.01366
+Validation: Loss 0.01390 Accuracy 1.00000
+Validation: Loss 0.01381 Accuracy 1.00000
+Epoch [ 19]: Loss 0.01337
+Epoch [ 19]: Loss 0.01481
+Epoch [ 19]: Loss 0.01359
+Epoch [ 19]: Loss 0.01293
+Epoch [ 19]: Loss 0.01317
+Epoch [ 19]: Loss 0.01404
+Epoch [ 19]: Loss 0.01416
+Validation: Loss 0.01286 Accuracy 1.00000
+Validation: Loss 0.01277 Accuracy 1.00000
+Epoch [ 20]: Loss 0.01286
+Epoch [ 20]: Loss 0.01335
+Epoch [ 20]: Loss 0.01259
+Epoch [ 20]: Loss 0.01343
+Epoch [ 20]: Loss 0.01294
+Epoch [ 20]: Loss 0.01124
+Epoch [ 20]: Loss 0.01124
+Validation: Loss 0.01194 Accuracy 1.00000
+Validation: Loss 0.01186 Accuracy 1.00000
+Epoch [ 21]: Loss 0.01159
+Epoch [ 21]: Loss 0.01229
+Epoch [ 21]: Loss 0.01273
+Epoch [ 21]: Loss 0.01021
+Epoch [ 21]: Loss 0.01159
+Epoch [ 21]: Loss 0.01191
+Epoch [ 21]: Loss 0.01311
+Validation: Loss 0.01111 Accuracy 1.00000
+Validation: Loss 0.01104 Accuracy 1.00000
+Epoch [ 22]: Loss 0.01112
+Epoch [ 22]: Loss 0.01155
+Epoch [ 22]: Loss 0.01068
+Epoch [ 22]: Loss 0.01120
+Epoch [ 22]: Loss 0.00993
+Epoch [ 22]: Loss 0.01129
+Epoch [ 22]: Loss 0.01098
+Validation: Loss 0.01033 Accuracy 1.00000
+Validation: Loss 0.01026 Accuracy 1.00000
+Epoch [ 23]: Loss 0.00950
+Epoch [ 23]: Loss 0.01102
+Epoch [ 23]: Loss 0.01060
+Epoch [ 23]: Loss 0.01058
+Epoch [ 23]: Loss 0.00987
+Epoch [ 23]: Loss 0.01006
+Epoch [ 23]: Loss 0.00747
+Validation: Loss 0.00952 Accuracy 1.00000
+Validation: Loss 0.00945 Accuracy 1.00000
+Epoch [ 24]: Loss 0.00960
+Epoch [ 24]: Loss 0.00995
+Epoch [ 24]: Loss 0.00883
+Epoch [ 24]: Loss 0.00888
+Epoch [ 24]: Loss 0.00955
+Epoch [ 24]: Loss 0.00915
+Epoch [ 24]: Loss 0.00884
+Validation: Loss 0.00861 Accuracy 1.00000
+Validation: Loss 0.00856 Accuracy 1.00000
+Epoch [ 25]: Loss 0.00958
+Epoch [ 25]: Loss 0.00920
+Epoch [ 25]: Loss 0.00803
+Epoch [ 25]: Loss 0.00769
+Epoch [ 25]: Loss 0.00804
+Epoch [ 25]: Loss 0.00784
+Epoch [ 25]: Loss 0.00760
+Validation: Loss 0.00766 Accuracy 1.00000
+Validation: Loss 0.00762 Accuracy 1.00000
We can save the model using JLD2 (or any other serialization library of your choice). Note that we transfer the model to CPU before saving. Additionally, we recommend that you don't save the model struct and only save the parameters and states.
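For example (assuming the JLD2 package is available; ps_trained and st_trained are already on the CPU here):
julia
using JLD2
+
+@save "trained_model.jld2" ps_trained st_trained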
SimpleChains.jl is an excellent framework for training small neural networks. In this tutorial we will demonstrate how to use the same API as Lux.jl to train a model using SimpleChains.jl. We will use the tutorial from SimpleChains.jl as a reference.
[ 1/10] Time 106.58s Training Accuracy: 22.44% Test Accuracy: 19.50%
+[ 2/10] Time 106.61s Training Accuracy: 47.06% Test Accuracy: 45.50%
+[ 3/10] Time 112.24s Training Accuracy: 61.50% Test Accuracy: 61.00%
+[ 4/10] Time 115.93s Training Accuracy: 69.89% Test Accuracy: 65.00%
+[ 5/10] Time 118.22s Training Accuracy: 75.22% Test Accuracy: 74.00%
+[ 6/10] Time 112.80s Training Accuracy: 78.44% Test Accuracy: 77.50%
+[ 7/10] Time 108.41s Training Accuracy: 81.22% Test Accuracy: 81.00%
+[ 8/10] Time 112.49s Training Accuracy: 83.94% Test Accuracy: 80.50%
+[ 9/10] Time 113.54s Training Accuracy: 85.89% Test Accuracy: 84.50%
+[10/10] Time 113.99s Training Accuracy: 87.11% Test Accuracy: 84.50%
Now we will train the SimpleChains model
julia
train(simple_chains_model)
[ 1/10] Time 18.70s Training Accuracy: 29.06% Test Accuracy: 23.50%
+[ 2/10] Time 17.64s Training Accuracy: 45.83% Test Accuracy: 43.00%
+[ 3/10] Time 17.64s Training Accuracy: 62.72% Test Accuracy: 57.50%
+[ 4/10] Time 17.64s Training Accuracy: 65.67% Test Accuracy: 61.50%
+[ 5/10] Time 17.65s Training Accuracy: 74.72% Test Accuracy: 68.50%
+[ 6/10] Time 17.63s Training Accuracy: 79.61% Test Accuracy: 77.00%
+[ 7/10] Time 17.64s Training Accuracy: 81.83% Test Accuracy: 77.00%
+[ 8/10] Time 17.63s Training Accuracy: 83.94% Test Accuracy: 79.50%
+[ 9/10] Time 17.65s Training Accuracy: 84.50% Test Accuracy: 84.50%
+[10/10] Time 17.63s Training Accuracy: 87.78% Test Accuracy: 83.50%
On my local machine we see a 3-4x speedup when using SimpleChains.jl. The server on which this documentation is built is not ideal for CPU benchmarking, so the speedup may not be as significant here, and there might even be regressions.
Lux's native Training.TrainState is a great API for gradient-based learning of neural networks; however, it is geared towards using Optimisers.jl as the backend. Often we want to train neural networks with other optimization methods like BFGS, LBFGS, etc. In this tutorial, we will show how to train Lux models with Optimization.jl, which provides a simple unified interface to various optimization methods.
We will base our tutorial on the minibatching tutorial from the official Optimization.jl docs.
Neural ODE
This tutorial uses a Neural ODE, however, we won't discuss that part in this tutorial. Please refer to the Neural ODE tutorial for more information.
We will define the DataLoader to batch over the data; additionally, we will pipe it through the gdev device to move the data to the GPU on each iteration.
By default, gdev will move all objects to the GPU. But we don't want to move the time vector to the GPU, so we will wrap it in a struct, as sketched below.
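A sketch of such a wrapper, matching the t_batch.t access and the TimeWrapper(t) call in the training code below:
julia
struct TimeWrapper{T}
+    t::T
+end
+
+Base.length(t::TimeWrapper) = length(t.t)
+Base.getindex(t::TimeWrapper, i) = TimeWrapper(t.t[i])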
Here we are using different optimization methods for demonstration purposes. This problem is trivial enough to not require this.
Optimization.jl requires an abstract array as the parameters, hence we will construct a ComponentArray to store the parameters.
Parameter Estimation vs State Estimation
Optimization.jl performs state estimation, which effectively means that for a function f(u, p), it is trying to compute the optimal u for a given p. This terminology might be confusing to ML practitioners, since in the ML world, we usually do parameter estimation. This effectively means that the u in Optimization.jl corresponds to our model parameters that are being optimized.
julia
function train_model(dataloader)
+ model = Chain(Dense(2, 32, tanh), Dense(32, 32, tanh), Dense(32, 2))
+ ps, st = Lux.setup(Random.default_rng(), model)
+
+ ps_ca = ComponentArray(ps) |> gdev
+ st = st |> gdev
+
+ function callback(state, l)
+ state.iter % 25 == 1 && @printf "Iteration: %5d, Loss: %.6e\n" state.iter l
+ return l < 1e-8 ## Terminate if loss is small
+ end
+
+ smodel = StatefulLuxLayer{true}(model, nothing, st)
+
+ function loss_adjoint(θ, (u_batch, t_batch))
+ t_batch = t_batch.t
+ u0 = u_batch[:, 1]
+ dudt(u, p, t) = smodel(u, p)
+ prob = ODEProblem(dudt, u0, (t_batch[1], t_batch[end]), θ)
+ pred = convert(AbstractArray, solve(prob, Tsit5(); saveat=t_batch))
+ return MSELoss()(pred, u_batch)
+ end
+
+ # Define the Optimization Function that takes in the optimization state (our parameters)
+ # and optimization parameters (nothing in our case) and data from the dataloader and
+ # returns the loss.
+ opt_func = OptimizationFunction(loss_adjoint, Optimization.AutoZygote())
+ opt_prob = OptimizationProblem(opt_func, ps_ca, dataloader)
+
+ epochs = 25
+ res_adam = solve(opt_prob, Optimisers.Adam(0.001); callback, epochs)
+
+ # Let's finetune a bit with L-BFGS
+ opt_prob = OptimizationProblem(opt_func, res_adam.u, (gdev(ode_data), TimeWrapper(t)))
+ res_lbfgs = solve(opt_prob, LBFGS(); callback, maxiters=epochs)
+
+ # Now that we have a good fit, let's train it on the entire dataset without
+ # Minibatching. We need to do this since ODE solves can lead to accumulated errors if
+ # the model was trained on individual parts (without a data-shooting approach).
+ opt_prob = remake(opt_prob; u0=res_lbfgs.u)
+ res = solve(opt_prob, Optimisers.Adam(0.005); maxiters=500, callback)
+
+ return StatefulLuxLayer{true}(model, res.u, smodel.st)
+end
+
+trained_model = train_model(dataloader)
These models are part of the Lux examples; however, these are larger models that cannot be run on CI and aren't frequently tested. If you find a bug in one of these models, please open an issue or PR to fix it.
These tutorials are developed by the community and may not be up-to-date with the latest version of Lux.jl. Please refer to the official documentation for the most up-to-date information.
Please open an issue (ideally both at Lux.jl and at the downstream linked package) if any of them are non-functional and we will try to get them updated.
To understand Neural ODEs, users should look up these lecture notes. We recommend that users use DiffEqFlux.jl directly, instead of implementing Neural ODEs from scratch.
First we will use the @compact macro to define the Neural ODE Layer.
julia
function NeuralODECompact(
+ model::Lux.AbstractLuxLayer; solver=Tsit5(), tspan=(0.0f0, 1.0f0), kwargs...)
+ return @compact(; model, solver, tspan, kwargs...) do x, p
+ dudt(u, p, t) = vec(model(reshape(u, size(x)), p))
+ # Note the `p.model` here
+ prob = ODEProblem(ODEFunction{false}(dudt), vec(x), tspan, p.model)
+ @return solve(prob, solver; kwargs...)
+ end
+end
NeuralODECompact (generic function with 1 method)
We recommend using the compact macro for creating custom layers. The implementation below exists mostly for historical reasons, from when @compact was not part of the stable API. It also helps users understand how the layer interface of Lux works.
The NeuralODE is a ContainerLayer, which stores a model. The parameters and states of the NeuralODE are the same as those of the underlying model.
OrdinaryDiffEq.jl can deal with non-Vector Inputs! However, certain discrete sensitivities like ReverseDiffAdjoint can't handle non-Vector inputs. Hence, we need to convert the input and output of the ODE solver to a Vector.
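A sketch of the forward pass with that Vector round-trip, assuming the NeuralODE struct stores model, solver, tspan, and kwargs fields:
julia
function (n::NeuralODE)(x, ps, st)
+    # Solve in flattened (Vector) space, reshaping back for each model call
+    dudt(u, p, t) = vec(first(n.model(reshape(u, size(x)), p, st)))
+    prob = ODEProblem{false}(ODEFunction{false}(dudt), vec(x), n.tspan, ps)
+    return solve(prob, n.solver; n.kwargs...), st
+end
With the layer defined, the training loop below works unchanged.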
function train(model_function; cpu::Bool=false, kwargs...)
+ dev = cpu ? cpu_device() : gpu_device()
+ model, ps, st = create_model(model_function; dev, kwargs...)
+
+ # Training
+ train_dataloader, test_dataloader = loadmnist(128, 0.9) |> dev
+
+ tstate = Training.TrainState(model, ps, st, Adam(0.001f0))
+
+ ### Lets train the model
+ nepochs = 9
+ for epoch in 1:nepochs
+ stime = time()
+ for (x, y) in train_dataloader
+ _, _, _, tstate = Training.single_train_step!(
+ AutoZygote(), logitcrossentropy, (x, y), tstate)
+ end
+ ttime = time() - stime
+
+ tr_acc = accuracy(model, tstate.parameters, tstate.states, train_dataloader) * 100
+ te_acc = accuracy(model, tstate.parameters, tstate.states, test_dataloader) * 100
+ @printf "[%d/%d]\tTime %.4fs\tTraining Accuracy: %.5f%%\tTest \
+ Accuracy: %.5f%%\n" epoch nepochs ttime tr_acc te_acc
+ end
+end
+
+train(NeuralODECompact)
[1/9] Time 119.4158s Training Accuracy: 37.48148% Test Accuracy: 40.00000%
+[2/9] Time 0.4958s Training Accuracy: 58.22222% Test Accuracy: 57.33333%
+[3/9] Time 0.6961s Training Accuracy: 67.85185% Test Accuracy: 70.66667%
+[4/9] Time 0.4869s Training Accuracy: 74.29630% Test Accuracy: 74.66667%
+[5/9] Time 0.5064s Training Accuracy: 76.29630% Test Accuracy: 76.00000%
+[6/9] Time 0.7482s Training Accuracy: 78.74074% Test Accuracy: 80.00000%
+[7/9] Time 0.4736s Training Accuracy: 82.22222% Test Accuracy: 81.33333%
+[8/9] Time 0.4883s Training Accuracy: 83.62963% Test Accuracy: 83.33333%
+[9/9] Time 0.7453s Training Accuracy: 85.18519% Test Accuracy: 82.66667%
julia
train(NeuralODE)
[1/9] Time 36.4249s Training Accuracy: 37.48148% Test Accuracy: 40.00000%
+[2/9] Time 0.4793s Training Accuracy: 57.18519% Test Accuracy: 57.33333%
+[3/9] Time 0.6545s Training Accuracy: 68.37037% Test Accuracy: 68.00000%
+[4/9] Time 0.4797s Training Accuracy: 73.77778% Test Accuracy: 75.33333%
+[5/9] Time 0.4833s Training Accuracy: 76.14815% Test Accuracy: 77.33333%
+[6/9] Time 0.7233s Training Accuracy: 79.48148% Test Accuracy: 80.66667%
+[7/9] Time 0.4913s Training Accuracy: 81.25926% Test Accuracy: 80.66667%
+[8/9] Time 0.4843s Training Accuracy: 83.40741% Test Accuracy: 82.66667%
+[9/9] Time 0.7256s Training Accuracy: 84.81481% Test Accuracy: 82.00000%
We can also change the sensealg and train the model! GaussAdjoint allows you to use any arbitrary parameter structure and not just a flat vector (ComponentArray).
[1/9] Time 42.6019s Training Accuracy: 37.48148% Test Accuracy: 40.00000%
+[2/9] Time 0.5487s Training Accuracy: 57.55556% Test Accuracy: 54.00000%
+[3/9] Time 0.4660s Training Accuracy: 69.85185% Test Accuracy: 69.33333%
+[4/9] Time 0.4833s Training Accuracy: 72.51852% Test Accuracy: 74.00000%
+[5/9] Time 0.4743s Training Accuracy: 75.33333% Test Accuracy: 76.00000%
+[6/9] Time 0.4944s Training Accuracy: 78.88889% Test Accuracy: 79.33333%
+[7/9] Time 0.6809s Training Accuracy: 81.03704% Test Accuracy: 80.00000%
+[8/9] Time 0.4987s Training Accuracy: 83.77778% Test Accuracy: 81.33333%
+[9/9] Time 0.5045s Training Accuracy: 85.25926% Test Accuracy: 82.66667%
But remember, some AD backends like ReverseDiff are not GPU compatible. For a model of this size, you will notice that training time is significantly lower on the CPU than on the GPU.
[1/9] Time 96.0630s Training Accuracy: 37.48148% Test Accuracy: 40.00000%
+[2/9] Time 14.0172s Training Accuracy: 58.74074% Test Accuracy: 56.66667%
+[3/9] Time 13.5410s Training Accuracy: 69.92593% Test Accuracy: 71.33333%
+[4/9] Time 13.6407s Training Accuracy: 72.81481% Test Accuracy: 74.00000%
+[5/9] Time 13.4329s Training Accuracy: 76.37037% Test Accuracy: 78.66667%
+[6/9] Time 12.0878s Training Accuracy: 79.03704% Test Accuracy: 80.66667%
+[7/9] Time 14.5981s Training Accuracy: 81.62963% Test Accuracy: 80.66667%
+[8/9] Time 13.6945s Training Accuracy: 83.33333% Test Accuracy: 80.00000%
+[9/9] Time 10.3098s Training Accuracy: 85.40741% Test Accuracy: 82.00000%
For completeness, let's also test out discrete sensitivities!
[1/9] Time 49.7652s Training Accuracy: 37.48148% Test Accuracy: 40.00000%
+[2/9] Time 21.6687s Training Accuracy: 58.66667% Test Accuracy: 57.33333%
+[3/9] Time 21.5681s Training Accuracy: 69.70370% Test Accuracy: 71.33333%
+[4/9] Time 21.3427s Training Accuracy: 72.74074% Test Accuracy: 74.00000%
+[5/9] Time 23.9941s Training Accuracy: 76.14815% Test Accuracy: 78.66667%
+[6/9] Time 22.0233s Training Accuracy: 79.03704% Test Accuracy: 80.66667%
+[7/9] Time 22.4246s Training Accuracy: 81.55556% Test Accuracy: 80.66667%
+[8/9] Time 23.1968s Training Accuracy: 83.40741% Test Accuracy: 80.00000%
+[9/9] Time 24.0997s Training Accuracy: 85.25926% Test Accuracy: 81.33333%
[1/9] Time 38.2440s Training Accuracy: 37.48148% Test Accuracy: 40.00000%
+[2/9] Time 0.4759s Training Accuracy: 58.22222% Test Accuracy: 55.33333%
+[3/9] Time 0.4745s Training Accuracy: 68.29630% Test Accuracy: 68.66667%
+[4/9] Time 0.4670s Training Accuracy: 73.11111% Test Accuracy: 76.00000%
+[5/9] Time 0.5117s Training Accuracy: 75.92593% Test Accuracy: 76.66667%
+[6/9] Time 0.4779s Training Accuracy: 78.96296% Test Accuracy: 80.66667%
+[7/9] Time 0.4705s Training Accuracy: 80.81481% Test Accuracy: 81.33333%
+[8/9] Time 0.4590s Training Accuracy: 83.25926% Test Accuracy: 82.66667%
+[9/9] Time 0.4555s Training Accuracy: 84.59259% Test Accuracy: 82.00000%
We might not see a significant difference in the training time, but let us investigate the type stabilities of the layers.
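One way to inspect this is with @code_warntype (a sketch; x, ps, and st are assumed to be in scope from the previous runs):
julia
using InteractiveUtils
+
+@code_warntype model(x, ps, st)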
We borrow this tutorial from the official Turing Docs. We will show how the explicit parameterization of Lux enables first-class composability with packages that expect flattened-out parameter vectors.
Note: The tutorial in the official Turing docs is now using Lux instead of Flux.
We will use Turing.jl with Lux.jl to implement a classification algorithm. Let's start by importing the relevant libraries.
Our goal here is to use a Bayesian neural network to classify points in an artificial dataset. The code below generates data points arranged in a box-like pattern and displays a graph of the dataset we'll be working with.
julia
# Number of points to generate
+N = 80
+M = round(Int, N / 4)
+rng = Random.default_rng()
+Random.seed!(rng, 1234)
+
+# Generate artificial data
+x1s = rand(rng, Float32, M) * 4.5f0;
+x2s = rand(rng, Float32, M) * 4.5f0;
+xt1s = Array([[x1s[i] + 0.5f0; x2s[i] + 0.5f0] for i in 1:M])
+x1s = rand(rng, Float32, M) * 4.5f0;
+x2s = rand(rng, Float32, M) * 4.5f0;
+append!(xt1s, Array([[x1s[i] - 5.0f0; x2s[i] - 5.0f0] for i in 1:M]))
+
+x1s = rand(rng, Float32, M) * 4.5f0;
+x2s = rand(rng, Float32, M) * 4.5f0;
+xt0s = Array([[x1s[i] + 0.5f0; x2s[i] - 5.0f0] for i in 1:M])
+x1s = rand(rng, Float32, M) * 4.5f0;
+x2s = rand(rng, Float32, M) * 4.5f0;
+append!(xt0s, Array([[x1s[i] - 5.0f0; x2s[i] + 0.5f0] for i in 1:M]))
+
+# Store all the data for later
+xs = [xt1s; xt0s]
+ts = [ones(2 * M); zeros(2 * M)]
+
+# Plot data points
+
+function plot_data()
+ x1 = first.(xt1s)
+ y1 = last.(xt1s)
+ x2 = first.(xt0s)
+ y2 = last.(xt0s)
+
+ fig = Figure()
+ ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")
+
+ scatter!(ax, x1, y1; markersize=16, color=:red, strokecolor=:black, strokewidth=2)
+ scatter!(ax, x2, y2; markersize=16, color=:blue, strokecolor=:black, strokewidth=2)
+
+ return fig
+end
+
+plot_data()
The next step is to define a feedforward neural network where we express our parameters as distributions, rather than single points as in traditional neural networks. For this we will use Dense to define linear layers and compose them via Chain; both are neural network primitives from Lux. The network nn we will create will have two hidden layers with tanh activations and one output layer with sigmoid activation, as shown below.
The nn is an instance that acts as a function and can take data, parameters and current state as inputs and output predictions. We will define distributions on the neural network parameters.
julia
# Construct a neural network using Lux
+nn = Chain(Dense(2 => 3, tanh), Dense(3 => 2, tanh), Dense(2 => 1, sigmoid))
+
+# Initialize the model weights and state
+ps, st = Lux.setup(rng, nn)
+
+Lux.parameterlength(nn) # number of parameters in NN
20
The probabilistic model specification below creates a parameters variable, which holds IID normal variables. parameters represents all the parameters of our neural net (weights and biases).
julia
# Create a regularization term and a Gaussian prior variance term.
+alpha = 0.09
+sig = sqrt(1.0 / alpha)
3.3333333333333335
Construct a named tuple from a sampled parameter vector. We could also use ComponentArrays here and simply broadcast to avoid doing this, but let's do it this way to avoid extra dependencies.
julia
function vector_to_parameters(ps_new::AbstractVector, ps::NamedTuple)
+ @assert length(ps_new) == Lux.parameterlength(ps)
+ i = 1
+ function get_ps(x)
+ z = reshape(view(ps_new, i:(i + length(x) - 1)), size(x))
+ i += length(x)
+ return z
+ end
+ return fmap(get_ps, ps)
+end
vector_to_parameters (generic function with 1 method)
To interface with external libraries it is often desirable to use the StatefulLuxLayer to automatically handle the neural network states.
julia
const model = StatefulLuxLayer{true}(nn, nothing, st)
+
+# Specify the probabilistic model.
+@model function bayes_nn(xs, ts)
+ # Sample the parameters
+ nparameters = Lux.parameterlength(nn)
+ parameters ~ MvNormal(zeros(nparameters), Diagonal(abs2.(sig .* ones(nparameters))))
+
+ # Forward NN to make predictions
+ preds = Lux.apply(model, xs, vector_to_parameters(parameters, ps))
+
+ # Observe each prediction.
+ for i in eachindex(ts)
+ ts[i] ~ Bernoulli(preds[i])
+ end
+end
bayes_nn (generic function with 2 methods)
Inference can now be performed by calling sample. We use the HMC sampler here.
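A sketch of that call; the step size and number of leapfrog steps here are illustrative, while the 5000 samples match the chain size described next:
julia
ch = sample(bayes_nn(reduce(hcat, xs), ts), HMC(0.05, 4), 5000)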
Now we extract the parameter samples from the sampled chain as θ (this is of size 5000 x 20 where 5000 is the number of iterations and 20 is the number of parameters). We'll use these primarily to determine how good our model's classifier is.
julia
# Extract all weight and bias parameters.
+θ = MCMCChains.group(ch, :parameters).value;
# A helper to run the nn through data `x` using parameters `θ`
+nn_forward(x, θ) = model(x, vector_to_parameters(θ, ps))
+
+# Plot the data we have.
+fig = plot_data()
+
+# Find the index that provided the highest log posterior in the chain.
+_, i = findmax(ch[:lp])
+
+# Extract the max row value from i.
+i = i.I[1]
+
+# Plot the posterior distribution with a contour plot
+x1_range = collect(range(-6; stop=6, length=25))
+x2_range = collect(range(-6; stop=6, length=25))
+Z = [nn_forward([x1, x2], θ[i, :])[1] for x1 in x1_range, x2 in x2_range]
+contour!(x1_range, x2_range, Z; linewidth=3, colormap=:seaborn_bright)
+fig
The contour plot above shows that the MAP method is not too bad at classifying our data. Now we can visualize our predictions.
The nn_predict function takes the average predicted value from a network parameterized by weights drawn from the MCMC chain.
julia
# Return the average predicted value across multiple weights.
+nn_predict(x, θ, num) = mean([first(nn_forward(x, view(θ, i, :))) for i in 1:10:num])
nn_predict (generic function with 1 method)
Next, we use the nn_predict function to predict the value at a sample of points where the x1 and x2 coordinates range between -6 and 6. As we can see below, we still have a satisfactory fit to our data, and more importantly, we can now much more easily see where the neural network is uncertain about its predictions: the regions between the cluster boundaries.
Suppose we are interested in how the predictive power of our Bayesian neural network evolved between samples. In that case, the following graph displays an animation of the contour plot generated from the network weights in samples 1 to 5,000.
julia
fig = plot_data()
+Z = [first(nn_forward([x1, x2], θ[1, :])) for x1 in x1_range, x2 in x2_range]
+c = contour!(x1_range, x2_range, Z; linewidth=3, colormap=:seaborn_bright)
+record(fig, "results.gif", 1:250:size(θ, 1)) do i
+ fig.current_axis[].title = "Iteration: $i"
+ Z = [first(nn_forward([x1, x2], θ[i, :])) for x1 in x1_range, x2 in x2_range]
+ c[3] = Z
+ return fig
+end
Defining functions on the CompactLuxLayer requires some understanding of how the layer is structured, as such we don't recommend doing it unless you are familiar with the internals. In this case, we simply write it to ignore the initialization of the core_network parameters.
julia
function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
+ return (; weight_generator=Lux.initialparameters(rng, hn.layers.weight_generator),)
+end
julia
function create_model()
+ # Doesn't need to be a MLP can have any Lux Layer
+ core_network = Chain(FlattenLayer(), Dense(784, 256, relu), Dense(256, 10))
+ weight_generator = Chain(Embedding(2 => 32), Dense(32, 64, relu),
+ Dense(64, Lux.parameterlength(core_network)))
+
+ model = HyperNet(weight_generator, core_network)
+ return model
+end
[ 1/ 50] MNIST Time 70.85048s Training Accuracy: 61.23% Test Accuracy: 56.25%
+[ 1/ 50] FashionMNIST Time 0.02819s Training Accuracy: 43.26% Test Accuracy: 37.50%
+[ 2/ 50] MNIST Time 0.02797s Training Accuracy: 71.19% Test Accuracy: 68.75%
+[ 2/ 50] FashionMNIST Time 0.02918s Training Accuracy: 56.25% Test Accuracy: 46.88%
+[ 3/ 50] MNIST Time 0.02907s Training Accuracy: 79.39% Test Accuracy: 71.88%
+[ 3/ 50] FashionMNIST Time 0.02807s Training Accuracy: 59.67% Test Accuracy: 53.12%
+[ 4/ 50] MNIST Time 0.02442s Training Accuracy: 78.71% Test Accuracy: 68.75%
+[ 4/ 50] FashionMNIST Time 0.02106s Training Accuracy: 68.36% Test Accuracy: 65.62%
+[ 5/ 50] MNIST Time 0.02221s Training Accuracy: 83.79% Test Accuracy: 75.00%
+[ 5/ 50] FashionMNIST Time 0.02173s Training Accuracy: 71.78% Test Accuracy: 62.50%
+[ 6/ 50] MNIST Time 0.02186s Training Accuracy: 88.67% Test Accuracy: 75.00%
+[ 6/ 50] FashionMNIST Time 0.02362s Training Accuracy: 72.95% Test Accuracy: 56.25%
+[ 7/ 50] MNIST Time 0.02382s Training Accuracy: 90.92% Test Accuracy: 78.12%
+[ 7/ 50] FashionMNIST Time 0.02322s Training Accuracy: 80.27% Test Accuracy: 68.75%
+[ 8/ 50] MNIST Time 0.03652s Training Accuracy: 90.82% Test Accuracy: 78.12%
+[ 8/ 50] FashionMNIST Time 0.02083s Training Accuracy: 76.46% Test Accuracy: 68.75%
+[ 9/ 50] MNIST Time 0.02124s Training Accuracy: 94.63% Test Accuracy: 81.25%
+[ 9/ 50] FashionMNIST Time 0.02080s Training Accuracy: 74.71% Test Accuracy: 65.62%
+[ 10/ 50] MNIST Time 0.02075s Training Accuracy: 94.63% Test Accuracy: 81.25%
+[ 10/ 50] FashionMNIST Time 0.02080s Training Accuracy: 77.34% Test Accuracy: 62.50%
+[ 11/ 50] MNIST Time 0.02030s Training Accuracy: 96.29% Test Accuracy: 78.12%
+[ 11/ 50] FashionMNIST Time 0.02048s Training Accuracy: 82.13% Test Accuracy: 78.12%
+[ 12/ 50] MNIST Time 0.02080s Training Accuracy: 97.95% Test Accuracy: 78.12%
+[ 12/ 50] FashionMNIST Time 0.02626s Training Accuracy: 81.84% Test Accuracy: 78.12%
+[ 13/ 50] MNIST Time 0.02091s Training Accuracy: 98.44% Test Accuracy: 84.38%
+[ 13/ 50] FashionMNIST Time 0.02084s Training Accuracy: 84.08% Test Accuracy: 71.88%
+[ 14/ 50] MNIST Time 0.02098s Training Accuracy: 98.93% Test Accuracy: 81.25%
+[ 14/ 50] FashionMNIST Time 0.02068s Training Accuracy: 85.55% Test Accuracy: 65.62%
+[ 15/ 50] MNIST Time 0.02067s Training Accuracy: 99.22% Test Accuracy: 84.38%
+[ 15/ 50] FashionMNIST Time 0.02068s Training Accuracy: 86.13% Test Accuracy: 68.75%
+[ 16/ 50] MNIST Time 0.02060s Training Accuracy: 99.51% Test Accuracy: 81.25%
+[ 16/ 50] FashionMNIST Time 0.02051s Training Accuracy: 86.13% Test Accuracy: 65.62%
+[ 17/ 50] MNIST Time 0.02531s Training Accuracy: 99.61% Test Accuracy: 81.25%
+[ 17/ 50] FashionMNIST Time 0.02054s Training Accuracy: 87.11% Test Accuracy: 71.88%
+[ 18/ 50] MNIST Time 0.02092s Training Accuracy: 99.80% Test Accuracy: 81.25%
+[ 18/ 50] FashionMNIST Time 0.02098s Training Accuracy: 88.28% Test Accuracy: 75.00%
+[ 19/ 50] MNIST Time 0.02228s Training Accuracy: 99.80% Test Accuracy: 81.25%
+[ 19/ 50] FashionMNIST Time 0.02067s Training Accuracy: 89.16% Test Accuracy: 71.88%
+[ 20/ 50] MNIST Time 0.02038s Training Accuracy: 99.90% Test Accuracy: 81.25%
+[ 20/ 50] FashionMNIST Time 0.02079s Training Accuracy: 89.26% Test Accuracy: 75.00%
+[ 21/ 50] MNIST Time 0.02039s Training Accuracy: 99.90% Test Accuracy: 81.25%
+[ 21/ 50] FashionMNIST Time 0.02023s Training Accuracy: 89.65% Test Accuracy: 75.00%
+[ 22/ 50] MNIST Time 0.02084s Training Accuracy: 100.00% Test Accuracy: 81.25%
+[ 22/ 50] FashionMNIST Time 0.02039s Training Accuracy: 89.94% Test Accuracy: 75.00%
+[ 23/ 50] MNIST Time 0.02139s Training Accuracy: 100.00% Test Accuracy: 81.25%
+[ 23/ 50] FashionMNIST Time 0.02072s Training Accuracy: 90.43% Test Accuracy: 71.88%
+[ 24/ 50] MNIST Time 0.02055s Training Accuracy: 100.00% Test Accuracy: 81.25%
+[ 24/ 50] FashionMNIST Time 0.02085s Training Accuracy: 90.72% Test Accuracy: 71.88%
+[ 25/ 50] MNIST Time 0.02080s Training Accuracy: 100.00% Test Accuracy: 84.38%
+[ 25/ 50] FashionMNIST Time 0.02870s Training Accuracy: 92.29% Test Accuracy: 75.00%
+[ 26/ 50] MNIST Time 0.02078s Training Accuracy: 100.00% Test Accuracy: 84.38%
+[ 26/ 50] FashionMNIST Time 0.02083s Training Accuracy: 92.38% Test Accuracy: 71.88%
+[ 27/ 50] MNIST Time 0.02093s Training Accuracy: 100.00% Test Accuracy: 84.38%
+[ 27/ 50] FashionMNIST Time 0.02037s Training Accuracy: 91.80% Test Accuracy: 75.00%
+[ 28/ 50] MNIST Time 0.02083s Training Accuracy: 100.00% Test Accuracy: 84.38%
+[ 28/ 50] FashionMNIST Time 0.02035s Training Accuracy: 92.97% Test Accuracy: 68.75%
+[ 29/ 50] MNIST Time 0.02075s Training Accuracy: 100.00% Test Accuracy: 84.38%
+[ 29/ 50] FashionMNIST Time 0.02075s Training Accuracy: 93.16% Test Accuracy: 71.88%
+[ 30/ 50] MNIST Time 0.02654s Training Accuracy: 100.00% Test Accuracy: 81.25%
+[ 30/ 50] FashionMNIST Time 0.02034s Training Accuracy: 92.09% Test Accuracy: 71.88%
+[ 31/ 50] MNIST Time 0.02107s Training Accuracy: 100.00% Test Accuracy: 84.38%
+[ 31/ 50] FashionMNIST Time 0.02075s Training Accuracy: 94.24% Test Accuracy: 71.88%
+[ 32/ 50] MNIST Time 0.02297s Training Accuracy: 100.00% Test Accuracy: 84.38%
+[ 32/ 50] FashionMNIST Time 0.02142s Training Accuracy: 93.65% Test Accuracy: 71.88%
+[ 33/ 50] MNIST Time 0.02200s Training Accuracy: 100.00% Test Accuracy: 84.38%
+[ 33/ 50] FashionMNIST Time 0.02105s Training Accuracy: 94.34% Test Accuracy: 75.00%
+[ 34/ 50] MNIST Time 0.02155s Training Accuracy: 100.00% Test Accuracy: 84.38%
+[ 34/ 50] FashionMNIST Time 0.02781s Training Accuracy: 93.65% Test Accuracy: 68.75%
+[ 35/ 50] MNIST Time 0.02128s Training Accuracy: 100.00% Test Accuracy: 84.38%
+[ 35/ 50] FashionMNIST Time 0.02310s Training Accuracy: 95.12% Test Accuracy: 71.88%
+[ 36/ 50] MNIST Time 0.02250s Training Accuracy: 100.00% Test Accuracy: 84.38%
+[ 36/ 50] FashionMNIST Time 0.02097s Training Accuracy: 95.90% Test Accuracy: 71.88%
+[ 37/ 50] MNIST Time 0.02084s Training Accuracy: 100.00% Test Accuracy: 84.38%
+[ 37/ 50] FashionMNIST Time 0.02062s Training Accuracy: 95.80% Test Accuracy: 75.00%
+[ 38/ 50] MNIST Time 0.02122s Training Accuracy: 100.00% Test Accuracy: 84.38%
+[ 38/ 50] FashionMNIST Time 0.02084s Training Accuracy: 95.70% Test Accuracy: 71.88%
+[ 39/ 50] MNIST Time 0.01987s Training Accuracy: 100.00% Test Accuracy: 84.38%
+[ 39/ 50] FashionMNIST Time 0.02035s Training Accuracy: 96.88% Test Accuracy: 71.88%
+[ 40/ 50] MNIST Time 0.02083s Training Accuracy: 100.00% Test Accuracy: 84.38%
+[ 40/ 50] FashionMNIST Time 0.02133s Training Accuracy: 96.68% Test Accuracy: 71.88%
+[ 41/ 50] MNIST Time 0.02054s Training Accuracy: 100.00% Test Accuracy: 84.38%
+[ 41/ 50] FashionMNIST Time 0.02079s Training Accuracy: 97.07% Test Accuracy: 71.88%
+[ 42/ 50] MNIST Time 0.02094s Training Accuracy: 100.00% Test Accuracy: 84.38%
+[ 42/ 50] FashionMNIST Time 0.02084s Training Accuracy: 97.36% Test Accuracy: 71.88%
+[ 43/ 50] MNIST Time 0.02632s Training Accuracy: 100.00% Test Accuracy: 84.38%
+[ 43/ 50] FashionMNIST Time 0.02029s Training Accuracy: 97.36% Test Accuracy: 71.88%
+[ 44/ 50] MNIST Time 0.02053s Training Accuracy: 100.00% Test Accuracy: 84.38%
+[ 44/ 50] FashionMNIST Time 0.02080s Training Accuracy: 97.75% Test Accuracy: 71.88%
+[ 45/ 50] MNIST Time 0.02082s Training Accuracy: 100.00% Test Accuracy: 84.38%
+[ 45/ 50] FashionMNIST Time 0.02060s Training Accuracy: 97.85% Test Accuracy: 75.00%
+[ 46/ 50] MNIST Time 0.02029s Training Accuracy: 100.00% Test Accuracy: 84.38%
+[ 46/ 50] FashionMNIST Time 0.02048s Training Accuracy: 97.75% Test Accuracy: 71.88%
+[ 47/ 50] MNIST Time 0.02093s Training Accuracy: 100.00% Test Accuracy: 84.38%
+[ 47/ 50] FashionMNIST Time 0.02595s Training Accuracy: 97.66% Test Accuracy: 75.00%
+[ 48/ 50] MNIST Time 0.02109s Training Accuracy: 100.00% Test Accuracy: 84.38%
+[ 48/ 50] FashionMNIST Time 0.02037s Training Accuracy: 96.97% Test Accuracy: 68.75%
+[ 49/ 50] MNIST Time 0.02034s Training Accuracy: 100.00% Test Accuracy: 84.38%
+[ 49/ 50] FashionMNIST Time 0.02065s Training Accuracy: 97.36% Test Accuracy: 75.00%
+[ 50/ 50] MNIST Time 0.02088s Training Accuracy: 100.00% Test Accuracy: 84.38%
+[ 50/ 50] FashionMNIST Time 0.02084s Training Accuracy: 97.75% Test Accuracy: 68.75%
+
+[FINAL] MNIST Training Accuracy: 100.00% Test Accuracy: 84.38%
+[FINAL] FashionMNIST Training Accuracy: 97.75% Test Accuracy: 68.75%
In this tutorial we will go over using a PINN to solve 2D PDEs. We will be using the system from the NeuralPDE Tutorials. However, we will be using a custom loss function and the nested AD capabilities of Lux.jl.
This is a demonstration of Lux.jl. For serious use cases of PINNs, please refer to the NeuralPDE.jl package.
Since Lux supports efficient nested AD up to 2nd order, we will rewrite the problem with first-order derivatives, so that we can compute the gradients of the loss using 2nd order AD.
All the networks take 3 input variables and output a scalar value. Here, we will define a wrapper over the 3 networks, so that we can train them using Training.TrainState.
We will generate some random data to train the model on. We will take data on a square spatial and temporal domain. Typically, you want to be smarter about the sampling process, but for the sake of simplicity, we will skip that.