From 3713531dee3dc2be1448d2c2d23c97dc1368268e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 18 Oct 2024 11:39:14 -0400 Subject: [PATCH] docs: update deps for the NeuralODE tutorial --- examples/NeuralODE/Project.toml | 14 +++++++------- examples/NeuralODE/main.jl | 10 +++++----- test/qa_tests.jl | 2 +- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/examples/NeuralODE/Project.toml b/examples/NeuralODE/Project.toml index 6ee61e610..cc869ad93 100644 --- a/examples/NeuralODE/Project.toml +++ b/examples/NeuralODE/Project.toml @@ -7,7 +7,7 @@ MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" -OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" +OrdinaryDiffEqTsit5 = "b1df2697-797e-41e3-8120-5422d3b24e4a" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" @@ -17,12 +17,12 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] ComponentArrays = "0.15" Lux = "1" -LuxCUDA = "0.2, 0.3" -MLDatasets = "0.5, 0.7" -MLUtils = "0.2, 0.3, 0.4" -OneHotArrays = "0.1, 0.2" -Optimisers = "0.2, 0.3" -OrdinaryDiffEq = "6" +LuxCUDA = "0.3" +MLDatasets = "0.7" +MLUtils = "0.4" +OneHotArrays = "0.2" +Optimisers = "0.3" +OrdinaryDiffEqTsit5 = "1" SciMLSensitivity = "7.63" Statistics = "1" Zygote = "0.6" diff --git a/examples/NeuralODE/main.jl b/examples/NeuralODE/main.jl index ae63407c2..9640b6a2c 100644 --- a/examples/NeuralODE/main.jl +++ b/examples/NeuralODE/main.jl @@ -7,8 +7,8 @@ # ## Package Imports -using Lux, ComponentArrays, SciMLSensitivity, LuxCUDA, Optimisers, OrdinaryDiffEq, Random, - Statistics, Zygote, OneHotArrays, InteractiveUtils, Printf +using Lux, ComponentArrays, SciMLSensitivity, LuxCUDA, Optimisers, OrdinaryDiffEqTsit5, + Random, Statistics, Zygote, OneHotArrays, InteractiveUtils, Printf using MLDatasets: MNIST using MLUtils: DataLoader, splitobs @@ -139,9 +139,9 @@ function train(model_function; cpu::Bool=false, kwargs...) end ttime = time() - stime - tr_acc = accuracy(model, tstate.parameters, tstate.states, train_dataloader) - te_acc = accuracy(model, tstate.parameters, tstate.states, test_dataloader) - @printf "[%d/%d] \t Time %.2fs \t Training Accuracy: %.5f%% \t Test \ + tr_acc = accuracy(model, tstate.parameters, tstate.states, train_dataloader) * 100 + te_acc = accuracy(model, tstate.parameters, tstate.states, test_dataloader) * 100 + @printf "[%d/%d]\tTime %.4fs\tTraining Accuracy: %.5f%%\tTest \ Accuracy: %.5f%%\n" epoch nepochs ttime tr_acc te_acc end end diff --git a/test/qa_tests.jl b/test/qa_tests.jl index 074f464b0..49977d1f6 100644 --- a/test/qa_tests.jl +++ b/test/qa_tests.jl @@ -10,7 +10,7 @@ Aqua.test_piracies(Lux; treat_as_own=[Lux.outputsize]) end -@testitem "Explicit Imports: Quality Assurance" setup=[SharedTestSetup] tags=[:others] begin +@testitem "Explicit Imports: Quality Assurance" tags=[:others] begin # Load all trigger packages import Lux, ComponentArrays, ReverseDiff, SimpleChains, Tracker, Zygote, Enzyme using ExplicitImports