Skip to content

Commit

Permalink
Benchmark Tapir
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Apr 6, 2024
1 parent 6c0efeb commit 948fbc2
Show file tree
Hide file tree
Showing 7 changed files with 21 additions and 18 deletions.
3 changes: 3 additions & 0 deletions bench/helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ function __benchmark_reverse_pass(
end
function __benchmark_reverse_pass(
tag::String, end_tag::String, ::AutoTapir, f::F, x; kwargs...) where {F}
SUITE[tag]["cpu"]["reverse"]["Tapir"][end_tag] = @benchmarkable Tapir.value_and_pullback!!(
trrule, 1.0f0, $f, $x) setup=(trrule = Tapir.build_rrule($f, $x))
return
end
function __benchmark_reverse_pass(
tag::String, end_tag::String, ::AutoTracker, f::F, x; kwargs...) where {F}
Expand Down
13 changes: 7 additions & 6 deletions bench/layers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@ function add_dense_benchmarks!()
for n in (2, 20, 200, 2000)
layer = Dense(n => n)
x, ps, st = general_setup(layer, (n, 128))
simple_chains = Lux.ToSimpleChainsAdaptor((static(n),))
simple_chains = n 200 ? Lux.ToSimpleChainsAdaptor((static(n),)) : nothing
benchmark_forward_pass(
"Dense($n => $n)", "($n, 128)", layer, x, ps, st; simple_chains)
benchmark_reverse_pass(
"Dense($n => $n)", "($n, 128)",
(AutoTapir(), AutoTracker(), AutoReverseDiff(),
AutoReverseDiff(true), AutoZygote()),
(AutoTracker(), AutoReverseDiff(), AutoReverseDiff(true), AutoZygote()),
layer, x, ps, st; simple_chains)
end

Expand All @@ -19,12 +18,14 @@ function add_conv_benchmarks!()
for ch in (1, 3, 16, 64)
layer = Conv((3, 3), ch => ch)
x, ps, st = general_setup(layer, (64, 64, ch, 128))
simple_chains = Lux.ToSimpleChainsAdaptor((static(64), static(64), static(ch)))
simple_chains = ch 16 ?
Lux.ToSimpleChainsAdaptor((static(64), static(64), static(ch))) :
nothing
benchmark_forward_pass("Conv((3, 3), $ch => $ch)", "(64, 64, $ch, 128)",
layer, x, ps, st; simple_chains)
benchmark_reverse_pass("Conv((3, 3), $ch => $ch)", "(64, 64, $ch, 128)",
(AutoTapir(), AutoTracker(), AutoReverseDiff(),
AutoReverseDiff(true), AutoZygote()), layer, x, ps, st; simple_chains)
(AutoTracker(), AutoReverseDiff(), AutoReverseDiff(true), AutoZygote()),
layer, x, ps, st; simple_chains)
end
end

Expand Down
2 changes: 1 addition & 1 deletion bench/runbenchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ struct AutoTapir <: ADTypes.AbstractReverseMode end
const SUITE = BenchmarkGroup()

include("helpers.jl")
# include("vgg.jl")
include("vgg.jl")
include("layers.jl")

BenchmarkTools.tune!(SUITE; verbose=true)
Expand Down
3 changes: 1 addition & 2 deletions bench/vgg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ function add_vgg_benchmarks!()
x, ps, st = general_setup(vgg16, (32, 32, 3, bsize))
benchmark_forward_pass("vgg16", "(32, 32, 3, $bsize)", vgg16, x, ps, st)
benchmark_reverse_pass(
"vgg16", "(32, 32, 3, $bsize)",
(AutoTapir(), AutoTracker(), AutoZygote()), vgg16, x, ps, st)
"vgg16", "(32, 32, 3, $bsize)", (AutoTracker(), AutoZygote()), vgg16, x, ps, st)
end

return
Expand Down
12 changes: 6 additions & 6 deletions docs/src/.vitepress/config.mts
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ export default defineConfig({
},
nav: [
{ text: 'Home', link: '/' },
{ text: 'Getting Started', link: '/introduction/index' },
{ text: 'Getting Started', link: '/introduction' },
{ text: 'Benchmarks', link: 'https://lux.csail.mit.edu/benchmarks/' },
{ text: 'Tutorials', link: '/tutorials/index' },
{ text: 'Tutorials', link: '/tutorials' },
{ text: 'Manual', link: '/manual/interface' },
{
text: 'API', items: [
Expand Down Expand Up @@ -104,22 +104,22 @@ export default defineConfig({
},
{
text: 'Versions', items: [
{ text: 'Stable', link: 'https://lux.csail.mit.edu/stable/' },
{ text: 'Dev', link: 'https://lux.csail.mit.edu/dev/' }
{ text: 'Stable', link: 'https://lux.csail.mit.edu/stable' },
{ text: 'Dev', link: 'https://lux.csail.mit.edu/dev' }
]
}
],
sidebar: {
"/introduction/": {
text: 'Getting Started', collapsed: false, items: [
{ text: 'Introduction', link: '/introduction/index' },
{ text: 'Introduction', link: '/introduction' },
{ text: 'Overview', link: '/introduction/overview' },
{ text: 'Resources', link: '/introduction/resources' },
{ text: 'Citation', link: '/introduction/citation' }]
},
"/tutorials/": {
text: 'Tutorials', collapsed: false, items: [
{ text: 'Overview', link: '/tutorials/index' },
{ text: 'Overview', link: '/tutorials' },
{
text: 'Beginner', collapsed: false, items: [
{ text: 'Julia & Lux for the Uninitiated', link: '/tutorials/beginner/1_Basics' },
Expand Down
4 changes: 2 additions & 2 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ hero:
actions:
- theme: brand
text: Tutorials
link: /tutorials/
link: /tutorials
- theme: alt
text: Ecosystem
link: /ecosystem
Expand All @@ -28,7 +28,7 @@ features:
- icon: 🚀
title: Fast & Extendible
details: Lux.jl is written in Julia itself, making it extremely extendible. CUDA and AMDGPU are supported first-class, with experimental support for Metal Hardware.
link: /introduction/
link: /introduction
- icon: 🧑‍🔬
title: SciML ❤️ Lux
Expand Down
2 changes: 1 addition & 1 deletion test/layers/conv_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@
@test size(layer(x, ps, st)[1])[1:(end - 2)] == cld.(size(x)[1:(end - 2)], k)
@jet layer(x, ps, st)
__f = (x, ps) -> sum(first(layer(x, ps, st)))
@eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu
@eval @test_gradients $__f $x $ps atol=1.0f-3 rtol=1.0f-3 gpu_testing=$ongpu skip_finite_differences=true
end
end
end
Expand Down

0 comments on commit 948fbc2

Please sign in to comment.