import Pkg
Pkg.add("Lux")
Tip
If you are using a pre-v1 version of Lux.jl, please see the Updating to v1 section for instructions on how to update.
| Packages |
|---|
| 📦 Lux.jl |
| └ 📦 LuxLib.jl |
| └ 📦 LuxCore.jl |
| └ 📦 MLDataDevices.jl |
| └ 📦 WeightInitializers.jl |
| └ 📦 LuxTestUtils.jl |
| └ 📦 LuxCUDA.jl |
using Lux, Random, Optimisers, Zygote
# using LuxCUDA, AMDGPU, Metal, oneAPI # Optional packages for GPU support
# Seeding
rng = Random.default_rng()
Random.seed!(rng, 0)
# Construct the layer
model = Chain(Dense(128, 256, tanh), Chain(Dense(256, 1, tanh), Dense(1, 10)))
# Get the device determined by Lux
dev = gpu_device()
# Parameter and State Variables
ps, st = Lux.setup(rng, model) |> dev
# Dummy Input
x = rand(rng, Float32, 128, 2) |> dev
# Run the model
y, st = Lux.apply(model, x, ps, st)
# Gradients
## First construct a TrainState
train_state = Lux.Training.TrainState(model, ps, st, Adam(0.0001f0))
## We can compute the gradients using Training.compute_gradients
gs, loss, stats, train_state = Lux.Training.compute_gradients(AutoZygote(), MSELoss(),
    (x, dev(rand(rng, Float32, 10, 2))), train_state)
## Optimization
train_state = Training.apply_gradients!(train_state, gs) # or Training.apply_gradients (no `!` at the end)
# Both these steps can be combined into a single call
gs, loss, stats, train_state = Training.single_train_step!(AutoZygote(), MSELoss(),
    (x, dev(rand(rng, Float32, 10, 2))), train_state)
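The combined step above is usually called repeatedly. The following is a minimal sketch of such a loop, not the library's recommended training recipe: the helper name `train`, the small model, the random data, the learning rate, and the epoch count are all assumptions chosen for illustration, and it runs on the CPU to stay self-contained.

```julia
using Lux, Random, Optimisers, Zygote

rng = Random.default_rng()
Random.seed!(rng, 0)

# Small illustrative model and random data (assumed sizes, CPU only)
model = Chain(Dense(4 => 8, tanh), Dense(8 => 1))
ps, st = Lux.setup(rng, model)
x = rand(rng, Float32, 4, 16)
y = rand(rng, Float32, 1, 16)

# Hypothetical helper: repeat the combined gradient + update step
function train(model, ps, st, data; epochs=20)
    train_state = Training.TrainState(model, ps, st, Adam(0.01f0))
    losses = Float32[]
    for _ in 1:epochs
        # Returns (gradients, loss, stats, updated train state)
        _, loss, _, train_state = Training.single_train_step!(
            AutoZygote(), MSELoss(), data, train_state)
        push!(losses, loss)
    end
    return train_state, losses
end

train_state, losses = train(model, ps, st, (x, y))
```

Recording the loss on every step makes it easy to check afterwards that the values stay finite and to plot how they evolve.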
Look in the examples directory for self-contained usage examples. The documentation has examples sorted into proper categories.
For usage-related questions, please use GitHub Discussions, which allows questions and answers to be indexed. To report bugs, use GitHub Issues or, even better, send in a pull request.
If you found this library to be useful in academic work, then please cite:
@software{pal2023lux,
author = {Pal, Avik},
title = {{Lux: Explicit Parameterization of Deep Neural Networks in Julia}},
month = apr,
year = 2023,
note = {If you use this software, please cite it as below.},
publisher = {Zenodo},
version = {v1.4.2},
doi = {10.5281/zenodo.7808903},
url = {https://doi.org/10.5281/zenodo.7808903},
swhid = {swh:1:dir:1a304ec3243961314a1cc7c1481a31c4386c4a34;origin=https://doi.org/10.5281/zenodo.7808903;visit=swh:1:snp:e2bbe43b14bde47c4ddf7e637eb7fc7bd10db8c7;anchor=swh:1:rel:2c0c0ff927e7bfe8fc8bc43fd553ab392a6eb403;path=/}
}
@thesis{pal2023efficient,
title = {{On Efficient Training \& Inference of Neural Differential Equations}},
author = {Pal, Avik},
year = {2023},
school = {Massachusetts Institute of Technology}
}
Also consider starring our GitHub repo.
This section is somewhat incomplete. You can contribute by helping to finish it.
Running the full test suite of `Lux.jl` takes a long time. Here is how to test only a portion of the code.
Each `@testitem` has corresponding `tags`, for example:
@testitem "SkipConnection" setup=[SharedTestSetup] tags=[:core_layers]
For example, let's consider the tests for `SkipConnection`:
@testitem "SkipConnection" setup=[SharedTestSetup] tags=[:core_layers] begin
...
end
We can test the group to which `SkipConnection` belongs by testing `core_layers`.
To do so, set the `LUX_TEST_GROUP` environment variable, or rename the tag to further narrow the test scope:
export LUX_TEST_GROUP="core_layers"
Or directly modify the default test tag in `runtests.jl`:
# const LUX_TEST_GROUP = lowercase(get(ENV, "LUX_TEST_GROUP", "all"))
const LUX_TEST_GROUP = lowercase(get(ENV, "LUX_TEST_GROUP", "core_layers"))
But be sure to restore the default value "all" before submitting the code.
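To make the lookup concrete, the sketch below mirrors the `get(ENV, ...)` expression from `runtests.jl` and evaluates it with the variable unset and set. The helper name `resolve_test_group` is hypothetical and is not part of the Lux test suite. `withenv` is used so that the real environment is left untouched.

```julia
# Hypothetical helper mirroring the lookup used in runtests.jl
resolve_test_group() = lowercase(get(ENV, "LUX_TEST_GROUP", "all"))

# With the variable unset, the default group is selected
group_unset = withenv("LUX_TEST_GROUP" => nothing) do
    resolve_test_group()
end

# With the variable set, its lowercased value is selected
group_set = withenv("LUX_TEST_GROUP" => "Core_Layers") do
    resolve_test_group()
end

println(group_unset)  # prints "all"
println(group_set)    # prints "core_layers"
```

Because the value is lowercased, the capitalization of the environment variable's value does not matter.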
Furthermore, if you want to run a specific test based on the name of the testset, you can use TestEnv.jl as follows. Start by activating the Lux environment, then run the following:
using TestEnv; TestEnv.activate(); using ReTestItems;
# Assuming you are in the main directory of Lux
ReTestItems.runtests("test/"; name = "NAME OF THE TEST")
For the `SkipConnection` tests that would be:
ReTestItems.runtests("test/"; name = "SkipConnection")