start a tutorial
avik-pal committed Dec 18, 2023
1 parent 76b8f2d commit 807a623
Showing 4 changed files with 45 additions and 12 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -45,7 +45,7 @@ y = rand(rng, Float32, 2, 3) |> gdev

model(x, ps, st)

gs = only(Zygote.gradient(p -> sum(abs2, first(first(model(x, p, st))) .- y), ps))
gs = only(Zygote.gradient(p -> sum(abs2, first(model(x, p, st)) .- y), ps))
```

## Citation
2 changes: 1 addition & 1 deletion docs/Project.toml
@@ -13,4 +13,4 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
[compat]
DeepEquilibriumNetworks = "2"
Documenter = "1"
DocumenterCitations = "0.2, 1"
DocumenterCitations = "1"
45 changes: 43 additions & 2 deletions docs/src/tutorials/basic_mnist_deq.md
@@ -1,3 +1,44 @@
# Training a Simple MNIST Classifier with DEQ
# Training a Simple MNIST Classifier using Deep Equilibrium Models

This Tutorial is currently under preparation. Check back soon.
We will train a simple Deep Equilibrium Model on MNIST. First we load a few packages.

```@example basic_mnist_deq
using DeepEquilibriumNetworks, SciMLSensitivity, Lux, NonlinearSolve, OrdinaryDiffEq,
      Statistics, Random, Optimization, OptimizationOptimisers
using LuxCUDA
using MLDatasets: MNIST
using MLDataUtils: LabelEnc, convertlabel, stratifiedobs, batchview
CUDA.allowscalar(false)
ENV["DATADEPS_ALWAYS_ACCEPT"] = true
```

Set up the device functions from Lux. See
[GPU Management](https://lux.csail.mit.edu/dev/manual/gpu_management) for more details.

```@example basic_mnist_deq
const cdev = cpu_device()
const gdev = gpu_device()
```
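
As a rough illustration of how these device functions are typically used, the sketch below moves an array to the selected device and back. It assumes only that Lux is installed; the variable names `x`, `x_dev`, and `x_host` are made up for this example, and the CPU fallback described in the comment is an assumption about the default behaviour of `gpu_device()` when no GPU backend is available.

```julia
using Lux

# Device selectors, as in the tutorial (non-const here so the sketch can be re-run).
cdev = cpu_device()
gdev = gpu_device()  # assumption: returns the CPU device if no functional GPU is found

x = rand(Float32, 4, 3)  # created in host memory
x_dev = x |> gdev        # moved to the selected device
x_host = x_dev |> cdev   # brought back to host memory
```

Piping through a device object keeps the data-movement code identical on machines with and without a GPU.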

We can now construct our dataloader.

```@example basic_mnist_deq
function onehot(labels_raw)
    return convertlabel(LabelEnc.OneOfK, labels_raw, LabelEnc.NativeLabels(collect(0:9)))
end
function loadmnist(batchsize)
    # Load MNIST
    mnist = MNIST(; split = :train)
    imgs, labels_raw = mnist.features, mnist.targets
    # Process images into (H,W,C,BS) batches
    x_train = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3))) |>
              gdev
    x_train = batchview(x_train, batchsize)
    # Onehot and batch the labels
    y_train = onehot(labels_raw) |> gdev
    y_train = batchview(y_train, batchsize)
    return x_train, y_train
end
```
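
To make the array shapes in `loadmnist` concrete, the sketch below reproduces them in plain Julia with no packages and a handful of fake images. The helpers `onehot_sketch` and `batches` are hypothetical and written only for this example; they are not part of MLDataUtils, and `batches` only approximates `batchview`, which may treat a trailing partial batch differently.

```julia
# Hypothetical stand-in data: 5 fake 28x28 "images" with labels in 0:9.
imgs = rand(Float32, 28, 28, 5)
labels_raw = [3, 0, 9, 3, 7]

# Insert a singleton channel dimension: (H, W, N) -> (H, W, C = 1, N).
x = reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3))

# Hand-rolled one-hot encoding: one row per class 0:9, one column per sample.
onehot_sketch(labels) = Float32.(labels' .== collect(0:9))
y = onehot_sketch(labels_raw)

# Split along the last dimension into views of at most `batchsize` samples.
function batches(a, batchsize)
    n = size(a, ndims(a))
    return [selectdim(a, ndims(a), i:min(i + batchsize - 1, n)) for i in 1:batchsize:n]
end

xb = batches(x, 2)
yb = batches(y, 2)
```

With a batch size of 2, the five samples split into batches of 2, 2, and 1 along the last dimension, and the images and labels stay aligned batch by batch.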
8 changes: 0 additions & 8 deletions src/DeepEquilibriumNetworks.jl
@@ -18,14 +18,6 @@ const DEQs = DeepEquilibriumNetworks
include("layers.jl")
include("utils.jl")

## FIXME: Remove once Manifest is removed
using SciMLBase, SciMLSensitivity

@inline __default_sensealg(::SteadyStateProblem) = SteadyStateAdjoint(;
    autojacvec=ZygoteVJP(), linsolve_kwargs=(; maxiters=10, abstol=1e-3, reltol=1e-3))
@inline __default_sensealg(::ODEProblem) = GaussAdjoint(; autojacvec=ZygoteVJP())
## FIXME: Remove once Manifest is removed

# Exports
export DEQs, DeepEquilibriumSolution, DeepEquilibriumNetwork, SkipDeepEquilibriumNetwork,
MultiScaleDeepEquilibriumNetwork, MultiScaleSkipDeepEquilibriumNetwork,