diff --git a/bench/Project.toml b/bench/Project.toml index 9a067a3265..32a5421326 100644 --- a/bench/Project.toml +++ b/bench/Project.toml @@ -4,6 +4,7 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" +Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" diff --git a/bench/helpers.jl b/bench/helpers.jl index 5aa79988c5..efe3521fed 100644 --- a/bench/helpers.jl +++ b/bench/helpers.jl @@ -1,41 +1,44 @@ # TODO: Special Handling for GPU Arrays with @sync -function benchmark_forward_pass(tag::String, end_tag::String, model, x, ps_nt::NamedTuple, - st; simple_chains=nothing) +function benchmark_forward_pass( + tag::String, end_tag::String, model, x_dims; simple_chains=nothing, + flux_model=nothing) SUITE[tag]["cpu"]["forward"]["NamedTuple"][end_tag] = @benchmarkable Lux.apply( - $model, $x, $ps_nt, $st) + $model, x, ps_nt, st) setup=((x, ps_nt, st) = general_setup($model, $x_dims)) - ps_ca = ComponentArray(ps_nt) SUITE[tag]["cpu"]["forward"]["ComponentArray"][end_tag] = @benchmarkable Lux.apply( - $model, $x, $ps_ca, $st) + $model, x, ps_ca, st) setup=((x, ps_nt, st) = general_setup($model, $x_dims); ps_ca = ComponentArray(ps_nt)) if simple_chains !== nothing simple_chains_model = simple_chains(model) - ps_simple_chains, st_simple_chains = general_setup(simple_chains_model, nothing) SUITE[tag]["cpu"]["forward"]["SimpleChains"][end_tag] = @benchmarkable Lux.apply( - $simple_chains_model, $x, $ps_simple_chains, $st_simple_chains) + $simple_chains_model, x, ps_simple_chains, st_simple_chains) setup=((x, ps_simple_chains, st_simple_chains) = general_setup( + $simple_chains_model, $x_dims)) + end + + if flux_model !== nothing + SUITE[tag]["cpu"]["forward"]["Flux"][end_tag] = @benchmarkable fmodel(x) setup=(x = randn( + StableRNG(0), Float32, $x_dims); + fmodel = $(flux_model())) end return end function benchmark_reverse_pass( - tag::String, end_tag::String, backends, model, x, ps_nt::NamedTuple, st; - simple_chains=nothing) - # Not everyone can handle NamedTuples so convert to ComponentArray - __f = @closure ps -> sum(abs2, first(Lux.apply(model, x, ps, st))) - ps_ca = ComponentArray(ps_nt) - + tag::String, end_tag::String, backends, model, x_dims; + simple_chains=nothing, flux_model=nothing) for backend in backends - __benchmark_reverse_pass(tag, end_tag, backend, __f, ps_ca) + __benchmark_reverse_pass(tag, end_tag, backend, model, x_dims) end if simple_chains !== nothing simple_chains_model = simple_chains(model) - ps_simple_chains, st_simple_chains = general_setup(simple_chains_model, nothing) - __f = @closure ps -> sum( - abs2, first(Lux.apply(simple_chains_model, x, ps, st_simple_chains))) __benchmark_reverse_pass_simple_chains( - tag, end_tag, AutoZygote(), __f, ps_simple_chains) + tag, end_tag, AutoZygote(), simple_chains_model, x_dims) + end + + if flux_model !== nothing + __benchmark_reverse_pass_flux(tag, end_tag, AutoZygote(), flux_model, x_dims) end return @@ -51,41 +54,78 @@ end # TODO: Remove these once DifferentiationInterface has been released function __benchmark_reverse_pass( - tag::String, end_tag::String, ::AutoEnzyme, f::F, x; kwargs...) where {F} + tag::String, end_tag::String, ::AutoEnzyme, model, x_dims) # TODO: Enable this. But enzyme doesn't handle closures well it seems... # SUITE[tag]["cpu"]["reverse"]["Enzyme"][end_tag] = @benchmarkable Enzyme.gradient( # $Enzyme.Reverse, $f, $x) return error("Enzyme backend hasn't been implemented yet.") end function __benchmark_reverse_pass( - tag::String, end_tag::String, ::AutoTapir, f::F, x; kwargs...) where {F} + tag::String, end_tag::String, ::AutoTapir, model, x_dims) + SUITE[tag]["cpu"]["reverse"]["Tapir"][end_tag] = @benchmarkable Tapir.value_and_pullback!!( + trrule, 1.0f0, f, ps_ca) setup=begin + (x, ps, st) = general_setup($model, $x_dims) + ps_ca = ComponentArray(ps) + f = @closure(p->sum(abs2, first(Lux.apply($model, x, p, st)))) + trrule = Tapir.build_rrule(f, ps_ca) + end + return end function __benchmark_reverse_pass( - tag::String, end_tag::String, ::AutoTracker, f::F, x; kwargs...) where {F} + tag::String, end_tag::String, ::AutoTracker, model, x_dims) SUITE[tag]["cpu"]["reverse"]["Tracker"][end_tag] = @benchmarkable Tracker.gradient( - $f, $x) + f, ps_ca) setup=begin + (x, ps, st) = general_setup($model, $x_dims) + ps_ca = ComponentArray(ps) + f = @closure(p->sum(abs2, first(Lux.apply($model, x, p, st)))) + end return end function __benchmark_reverse_pass( - tag::String, end_tag::String, ad::AutoReverseDiff, f::F, x; kwargs...) where {F} + tag::String, end_tag::String, ad::AutoReverseDiff, model, x_dims) if ad.compile SUITE[tag]["cpu"]["reverse"]["ReverseDiff (compiled)"][end_tag] = @benchmarkable ReverseDiff.gradient!( - ∂x, tape, $x) setup=(∂x = similar($x); - tape = ReverseDiff.compile(ReverseDiff.GradientTape($f, $x))) + ∂x, tape, ps_ca) setup=begin + (x, ps, st) = general_setup($model, $x_dims) + ∂x = similar(x) + ps_ca = ComponentArray(ps) + f = @closure(p->sum(abs2, first(Lux.apply($model, x, p, st)))) + tape = ReverseDiff.compile(ReverseDiff.GradientTape(f, ps_ca)) + end else SUITE[tag]["cpu"]["reverse"]["ReverseDiff"][end_tag] = @benchmarkable ReverseDiff.gradient( - $f, $x) + f, ps_ca) setup=begin + (x, ps, st) = general_setup($model, $x_dims) + ps_ca = ComponentArray(ps) + f = @closure(p->sum(abs2, first(Lux.apply($model, x, p, st)))) + end end end -function __benchmark_reverse_pass( - tag::String, end_tag::String, ::AutoZygote, f::F, x; kwargs...) where {F} +function __benchmark_reverse_pass(tag::String, end_tag::String, ::AutoZygote, model, x_dims) SUITE[tag]["cpu"]["reverse"]["Zygote"][end_tag] = @benchmarkable Zygote.gradient( - $f, $x) + f, ps_ca) setup=begin + (x, ps, st) = general_setup($model, $x_dims) + ps_ca = ComponentArray(ps) + f = @closure(p->sum(abs2, first(Lux.apply($model, x, p, st)))) + end return end function __benchmark_reverse_pass_simple_chains( - tag::String, end_tag::String, ::AutoZygote, f::F, x; kwargs...) where {F} + tag::String, end_tag::String, ::AutoZygote, model, x_dims) SUITE[tag]["cpu"]["reverse"]["SimpleChains"][end_tag] = @benchmarkable Zygote.gradient( - $f, $x) + f, ps) setup=begin + (x, ps, st) = general_setup($model, $x_dims) + f = @closure(p->sum(abs2, first(Lux.apply($model, x, p, st)))) + end + return +end +function __benchmark_reverse_pass_flux( + tag::String, end_tag::String, ::AutoZygote, model, x_dims) + SUITE[tag]["cpu"]["reverse"]["Flux"][end_tag] = @benchmarkable Zygote.gradient( + f, m) setup=begin + x = randn(StableRNG(0), Float32, $x_dims) + m = $(model)() + f = @closure(m->sum(abs2, m(x))) + end return end diff --git a/bench/layers.jl b/bench/layers.jl index c1d54c8df8..9dee9b40b2 100644 --- a/bench/layers.jl +++ b/bench/layers.jl @@ -1,15 +1,14 @@ function add_dense_benchmarks!() for n in (2, 20, 200, 2000) layer = Dense(n => n) - x, ps, st = general_setup(layer, (n, 128)) - simple_chains = Lux.ToSimpleChainsAdaptor((static(n),)) + simple_chains = n ≤ 200 ? Lux.ToSimpleChainsAdaptor((static(n),)) : nothing + flux_model = () -> Flux.Dense(n => n) benchmark_forward_pass( - "Dense($n => $n)", "($n, 128)", layer, x, ps, st; simple_chains) + "Dense($n => $n)", "($n, 128)", layer, (n, 128); simple_chains, flux_model) benchmark_reverse_pass( "Dense($n => $n)", "($n, 128)", - (AutoTapir(), AutoTracker(), AutoReverseDiff(), - AutoReverseDiff(true), AutoZygote()), - layer, x, ps, st; simple_chains) + (AutoTracker(), AutoReverseDiff(), AutoReverseDiff(true), AutoZygote()), + layer, (n, 128); simple_chains, flux_model) end return @@ -18,13 +17,15 @@ end function add_conv_benchmarks!() for ch in (1, 3, 16, 64) layer = Conv((3, 3), ch => ch) - x, ps, st = general_setup(layer, (64, 64, ch, 128)) - simple_chains = Lux.ToSimpleChainsAdaptor((static(64), static(64), static(ch))) + simple_chains = ch ≤ 16 ? + Lux.ToSimpleChainsAdaptor((static(64), static(64), static(ch))) : + nothing + flux_model = () -> Flux.Conv((3, 3), ch => ch) benchmark_forward_pass("Conv((3, 3), $ch => $ch)", "(64, 64, $ch, 128)", - layer, x, ps, st; simple_chains) + layer, (64, 64, ch, 128); simple_chains, flux_model) benchmark_reverse_pass("Conv((3, 3), $ch => $ch)", "(64, 64, $ch, 128)", - (AutoTapir(), AutoTracker(), AutoReverseDiff(), - AutoReverseDiff(true), AutoZygote()), layer, x, ps, st; simple_chains) + (AutoTracker(), AutoReverseDiff(), AutoReverseDiff(true), AutoZygote()), + layer, (64, 64, ch, 128); simple_chains, flux_model) end end diff --git a/bench/runbenchmarks.jl b/bench/runbenchmarks.jl index e6f8ae5a21..af45a24bc2 100644 --- a/bench/runbenchmarks.jl +++ b/bench/runbenchmarks.jl @@ -3,6 +3,7 @@ using BenchmarkTools: BenchmarkTools, BenchmarkGroup, @btime, @benchmarkable using ComponentArrays: ComponentArray using InteractiveUtils: versioninfo using FastClosures: @closure +using Flux: Flux using Lux: Lux, BatchNorm, Chain, Conv, Dense, Dropout, FlattenLayer, MaxPool using NNlib: relu using SimpleChains: SimpleChains, static @@ -27,10 +28,11 @@ struct AutoTapir <: ADTypes.AbstractReverseMode end const SUITE = BenchmarkGroup() include("helpers.jl") -# include("vgg.jl") +include("vgg.jl") include("layers.jl") BenchmarkTools.tune!(SUITE; verbose=true) results = BenchmarkTools.run(SUITE; verbose=true) +display(median(results)) BenchmarkTools.save(joinpath(@__DIR__, "benchmark_results.json"), median(results)) diff --git a/bench/vgg.jl b/bench/vgg.jl index eaabffc1ee..80393de753 100644 --- a/bench/vgg.jl +++ b/bench/vgg.jl @@ -17,12 +17,33 @@ function add_vgg_benchmarks!() BatchNorm(512), MaxPool((2, 2)), FlattenLayer(), Dense(512, 4096, relu), Dropout(0.5), Dense(4096, 4096, relu), Dropout(0.5), Dense(4096, 10)) + flux_model = () -> Flux.Chain( + Flux.Conv((3, 3), 3 => 64, relu; pad=(1, 1), stride=(1, 1)), + Flux.BatchNorm(64), Flux.Conv((3, 3), 64 => 64, relu; pad=(1, 1), stride=(1, 1)), + Flux.BatchNorm(64), Flux.MaxPool((2, 2)), + Flux.Conv((3, 3), 64 => 128, relu; pad=(1, 1), stride=(1, 1)), Flux.BatchNorm(128), + Flux.Conv((3, 3), 128 => 128, relu; pad=(1, 1), stride=(1, 1)), Flux.BatchNorm(128), + Flux.MaxPool((2, 2)), + Flux.Conv((3, 3), 128 => 256, relu; pad=(1, 1), stride=(1, 1)), Flux.BatchNorm(256), + Flux.Conv((3, 3), 256 => 256, relu; pad=(1, 1), stride=(1, 1)), Flux.BatchNorm(256), + Flux.Conv((3, 3), 256 => 256, relu; pad=(1, 1), stride=(1, 1)), Flux.BatchNorm(256), + Flux.MaxPool((2, 2)), + Flux.Conv((3, 3), 256 => 512, relu; pad=(1, 1), stride=(1, 1)), Flux.BatchNorm(512), + Flux.Conv((3, 3), 512 => 512, relu; pad=(1, 1), stride=(1, 1)), Flux.BatchNorm(512), + Flux.Conv((3, 3), 512 => 512, relu; pad=(1, 1), stride=(1, 1)), Flux.BatchNorm(512), + Flux.MaxPool((2, 2)), + Flux.Conv((3, 3), 512 => 512, relu; pad=(1, 1), stride=(1, 1)), Flux.BatchNorm(512), + Flux.Conv((3, 3), 512 => 512, relu; pad=(1, 1), stride=(1, 1)), Flux.BatchNorm(512), + Flux.Conv((3, 3), 512 => 512, relu; pad=(1, 1), stride=(1, 1)), Flux.BatchNorm(512), + Flux.MaxPool((2, 2)), Flux.flatten, Flux.Dense(512, 4096, relu), Flux.Dropout(0.5), + Flux.Dense(4096, 4096, relu), Flux.Dropout(0.5), Flux.Dense(4096, 10)) + for bsize in (1, 16, 64) - x, ps, st = general_setup(vgg16, (32, 32, 3, bsize)) - benchmark_forward_pass("vgg16", "(32, 32, 3, $bsize)", vgg16, x, ps, st) + benchmark_forward_pass( + "vgg16", "(32, 32, 3, $bsize)", vgg16, (32, 32, 3, bsize); flux_model) benchmark_reverse_pass( - "vgg16", "(32, 32, 3, $bsize)", - (AutoTapir(), AutoTracker(), AutoZygote()), vgg16, x, ps, st) + "vgg16", "(32, 32, 3, $bsize)", (AutoTracker(), AutoZygote()), + vgg16, (32, 32, 3, bsize); flux_model) end return diff --git a/docs/src/.vitepress/config.mts b/docs/src/.vitepress/config.mts index 17784060b3..45ca04fed5 100644 --- a/docs/src/.vitepress/config.mts +++ b/docs/src/.vitepress/config.mts @@ -62,9 +62,9 @@ export default defineConfig({ }, nav: [ { text: 'Home', link: '/' }, - { text: 'Getting Started', link: '/introduction/index' }, + { text: 'Getting Started', link: '/introduction' }, { text: 'Benchmarks', link: 'https://lux.csail.mit.edu/benchmarks/' }, - { text: 'Tutorials', link: '/tutorials/index' }, + { text: 'Tutorials', link: '/tutorials' }, { text: 'Manual', link: '/manual/interface' }, { text: 'API', items: [ @@ -104,22 +104,22 @@ export default defineConfig({ }, { text: 'Versions', items: [ - { text: 'Stable', link: 'https://lux.csail.mit.edu/stable/' }, - { text: 'Dev', link: 'https://lux.csail.mit.edu/dev/' } + { text: 'Stable', link: 'https://lux.csail.mit.edu/stable' }, + { text: 'Dev', link: 'https://lux.csail.mit.edu/dev' } ] } ], sidebar: { "/introduction/": { text: 'Getting Started', collapsed: false, items: [ - { text: 'Introduction', link: '/introduction/index' }, + { text: 'Introduction', link: '/introduction' }, { text: 'Overview', link: '/introduction/overview' }, { text: 'Resources', link: '/introduction/resources' }, { text: 'Citation', link: '/introduction/citation' }] }, "/tutorials/": { text: 'Tutorials', collapsed: false, items: [ - { text: 'Overview', link: '/tutorials/index' }, + { text: 'Overview', link: '/tutorials' }, { text: 'Beginner', collapsed: false, items: [ { text: 'Julia & Lux for the Uninitiated', link: '/tutorials/beginner/1_Basics' }, diff --git a/docs/src/index.md b/docs/src/index.md index 6403b2fe58..ce2a2a339a 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -10,7 +10,7 @@ hero: actions: - theme: brand text: Tutorials - link: /tutorials/ + link: /tutorials - theme: alt text: Ecosystem link: /ecosystem @@ -28,7 +28,7 @@ features: - icon: 🚀 title: Fast & Extendible details: Lux.jl is written in Julia itself, making it extremely extendible. CUDA and AMDGPU are supported first-class, with experimental support for Metal Hardware. - link: /introduction/ + link: /introduction - icon: 🧑‍🔬 title: SciML ❤️ Lux diff --git a/test/layers/conv_tests.jl b/test/layers/conv_tests.jl index 4bd07fe172..5ee530a230 100644 --- a/test/layers/conv_tests.jl +++ b/test/layers/conv_tests.jl @@ -89,7 +89,7 @@ @test size(layer(x, ps, st)[1])[1:(end - 2)] == cld.(size(x)[1:(end - 2)], k) @jet layer(x, ps, st) __f = (x, ps) -> sum(first(layer(x, ps, st))) - @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu + @eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_finite_differences=true end end end