From 88dd1d5f3f2a3d2df4a085d96adb2ad5ad73e72e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 4 Apr 2024 12:28:41 -0400 Subject: [PATCH] Reorganize --- bench/helpers.jl | 7 ++++--- bench/layers.jl | 4 ++-- bench/vgg.jl | 2 +- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/bench/helpers.jl b/bench/helpers.jl index 728340700d..787a9a71d4 100644 --- a/bench/helpers.jl +++ b/bench/helpers.jl @@ -1,9 +1,10 @@ # TODO: Special Handling for GPU Arrays with @sync -function benchmark_forward_pass(tag::String, model, x, ps, st) - SUITE[tag]["cpu"]["forward"]["default"] = @benchmarkable Lux.apply($model, $x, $ps, $st) +function benchmark_forward_pass(tag::String, end_tag::String, model, x, ps, st) + SUITE[tag]["cpu"]["forward"]["default"][end_tag] = @benchmarkable Lux.apply( + $model, $x, $ps, $st) ps_ca = ComponentArray(ps) - SUITE[tag]["cpu"]["forward"]["ComponentArray"] = @benchmarkable Lux.apply( + SUITE[tag]["cpu"]["forward"]["ComponentArray"][end_tag] = @benchmarkable Lux.apply( $model, $x, $ps_ca, $st) return diff --git a/bench/layers.jl b/bench/layers.jl index a964faaec2..f9c51b1e3a 100644 --- a/bench/layers.jl +++ b/bench/layers.jl @@ -2,7 +2,7 @@ function add_dense_benchmarks!() for n in (2, 20, 200, 2000) layer = Dense(n => n) x, ps, st = general_setup(layer, (n, 128)) - benchmark_forward_pass("Dense($n => $n) -- ($n, 128)", layer, x, ps, st) + benchmark_forward_pass("Dense($n => $n)", "($n, 128)", layer, x, ps, st) end return @@ -13,7 +13,7 @@ function add_conv_benchmarks!() layer = Conv((3, 3), ch => ch) x, ps, st = general_setup(layer, (64, 64, ch, 128)) benchmark_forward_pass( - "Conv((3, 3), $ch => $ch) -- (64, 64, $ch, 128)", layer, x, ps, st) + "Conv((3, 3), $ch => $ch)", "(64, 64, $ch, 128)", layer, x, ps, st) end end diff --git a/bench/vgg.jl b/bench/vgg.jl index b4b19f146d..1464ffd8bf 100644 --- a/bench/vgg.jl +++ b/bench/vgg.jl @@ -19,7 +19,7 @@ function add_vgg_benchmarks!() for bsize in (1, 16, 64) x, ps, st = general_setup(vgg16, (32, 32, 3, bsize)) - benchmark_forward_pass("vgg16 -- (32, 32, 3, $bsize)", vgg16, x, ps, st) + benchmark_forward_pass("vgg16", "(32, 32, 3, $bsize)", vgg16, x, ps, st) end return