Skip to content

Commit

Permalink
Reorganize
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Apr 4, 2024
1 parent f7743b4 commit 88dd1d5
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 6 deletions.
7 changes: 4 additions & 3 deletions bench/helpers.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# TODO: Special Handling for GPU Arrays with @sync
function benchmark_forward_pass(tag::String, model, x, ps, st)
SUITE[tag]["cpu"]["forward"]["default"] = @benchmarkable Lux.apply($model, $x, $ps, $st)
function benchmark_forward_pass(tag::String, end_tag::String, model, x, ps, st)
SUITE[tag]["cpu"]["forward"]["default"][end_tag] = @benchmarkable Lux.apply(
$model, $x, $ps, $st)

ps_ca = ComponentArray(ps)
SUITE[tag]["cpu"]["forward"]["ComponentArray"] = @benchmarkable Lux.apply(
SUITE[tag]["cpu"]["forward"]["ComponentArray"][end_tag] = @benchmarkable Lux.apply(
$model, $x, $ps_ca, $st)

return
Expand Down
4 changes: 2 additions & 2 deletions bench/layers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ function add_dense_benchmarks!()
for n in (2, 20, 200, 2000)
layer = Dense(n => n)
x, ps, st = general_setup(layer, (n, 128))
benchmark_forward_pass("Dense($n => $n) -- ($n, 128)", layer, x, ps, st)
benchmark_forward_pass("Dense($n => $n)", "($n, 128)", layer, x, ps, st)
end

return
Expand All @@ -13,7 +13,7 @@ function add_conv_benchmarks!()
layer = Conv((3, 3), ch => ch)
x, ps, st = general_setup(layer, (64, 64, ch, 128))
benchmark_forward_pass(
"Conv((3, 3), $ch => $ch) -- (64, 64, $ch, 128)", layer, x, ps, st)
"Conv((3, 3), $ch => $ch)", "(64, 64, $ch, 128)", layer, x, ps, st)
end
end

Expand Down
2 changes: 1 addition & 1 deletion bench/vgg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ function add_vgg_benchmarks!()

for bsize in (1, 16, 64)
x, ps, st = general_setup(vgg16, (32, 32, 3, bsize))
benchmark_forward_pass("vgg16 -- (32, 32, 3, $bsize)", vgg16, x, ps, st)
benchmark_forward_pass("vgg16", "(32, 32, 3, $bsize)", vgg16, x, ps, st)
end

return
Expand Down

0 comments on commit 88dd1d5

Please sign in to comment.