-
Notifications
You must be signed in to change notification settings - Fork 63
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #575 from LuxDL/ap/ad_benchmarks
More Continuous Benchmarks
- Loading branch information
Showing
10 changed files
with
269 additions
and
71 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
name = "Lux" | ||
uuid = "b2108857-7c20-44ae-9111-449ecde12c47" | ||
authors = ["Avik Pal <[email protected]> and contributors"] | ||
version = "0.5.32" | ||
version = "0.5.33" | ||
|
||
[deps] | ||
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" | ||
|
@@ -70,7 +70,7 @@ LuxAMDGPU = "0.2.2" | |
LuxCUDA = "0.3.2" | ||
LuxCore = "0.1.12" | ||
LuxDeviceUtils = "0.1.16" | ||
LuxLib = "0.3.10" | ||
LuxLib = "0.3.11" | ||
LuxTestUtils = "0.1.15" | ||
MacroTools = "0.5.13" | ||
Markdown = "1.10" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,18 @@ | ||
[deps] | ||
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" | ||
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" | ||
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" | ||
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" | ||
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" | ||
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" | ||
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" | ||
Lux = "b2108857-7c20-44ae-9111-449ecde12c47" | ||
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" | ||
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" | ||
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" | ||
SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5" | ||
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" | ||
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" | ||
Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b" | ||
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" | ||
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,19 +1,131 @@ | ||
# TODO: Special Handling for GPU Arrays with @sync | ||
# Register CPU forward-pass benchmarks for `model` in the global `SUITE`, stored under
# `SUITE[tag]["cpu"]["forward"][<variant>][end_tag]`.
#
# Arguments:
#   - `tag`, `end_tag`: outermost and innermost keys of the benchmark entry.
#   - `model`: model passed to `Lux.apply`.
#   - `x_dims`: dimensions of the `Float32` input created by `general_setup`.
#
# Keyword arguments:
#   - `simple_chains`: when not `nothing`, it is called on `model` and the result is
#     benchmarked as the "SimpleChains" variant.
#   - `flux_model`: when not `nothing`, a zero-argument callable; the model it returns is
#     benchmarked as the "Flux" variant.
#
# Inputs, parameters and states are built inside each benchmark's `setup` instead of being
# interpolated, so this function only needs the input dimensions.
function benchmark_forward_pass(
        tag::String, end_tag::String, model, x_dims; simple_chains=nothing,
        flux_model=nothing)
    SUITE[tag]["cpu"]["forward"]["NamedTuple"][end_tag] = @benchmarkable Lux.apply(
        $model, x, ps_nt, st) setup=((x, ps_nt, st) = general_setup($model, $x_dims))

    # Same model, but with the parameters wrapped in a ComponentArray during setup.
    SUITE[tag]["cpu"]["forward"]["ComponentArray"][end_tag] = @benchmarkable Lux.apply(
        $model, x, ps_ca, st) setup=begin
        (x, ps_nt, st) = general_setup($model, $x_dims)
        ps_ca = ComponentArray(ps_nt)
    end

    if simple_chains !== nothing
        simple_chains_model = simple_chains(model)
        SUITE[tag]["cpu"]["forward"]["SimpleChains"][end_tag] = @benchmarkable Lux.apply(
            $simple_chains_model, x, ps_simple_chains, st_simple_chains) setup=begin
            (x, ps_simple_chains, st_simple_chains) = general_setup(
                $simple_chains_model, $x_dims)
        end
    end

    if flux_model !== nothing
        # `flux_model()` is evaluated once, here, and the resulting model is interpolated.
        SUITE[tag]["cpu"]["forward"]["Flux"][end_tag] = @benchmarkable fmodel(x) setup=begin
            x = randn(StableRNG(0), Float32, $x_dims)
            fmodel = $(flux_model())
        end
    end

    return
end
|
||
# Register CPU reverse-pass benchmarks for `model`.
#
# Each entry of `backends` is dispatched to the matching `__benchmark_reverse_pass` method.
# When `simple_chains` is supplied it is called on `model`, and the result is benchmarked
# with `AutoZygote()`. When `flux_model` is supplied it is forwarded unchanged to
# `__benchmark_reverse_pass_flux`, also with `AutoZygote()`.
function benchmark_reverse_pass(
        tag::String, end_tag::String, backends, model, x_dims;
        simple_chains=nothing, flux_model=nothing)
    foreach(backends) do ad
        __benchmark_reverse_pass(tag, end_tag, ad, model, x_dims)
    end

    if !isnothing(simple_chains)
        sc_model = simple_chains(model)
        __benchmark_reverse_pass_simple_chains(
            tag, end_tag, AutoZygote(), sc_model, x_dims)
    end

    isnothing(flux_model) ||
        __benchmark_reverse_pass_flux(tag, end_tag, AutoZygote(), flux_model, x_dims)

    return
end
|
||
# Build benchmark inputs from a `StableRNG` seeded with 0.
#
# Returns `(ps, st)` from `Lux.setup` when `x_dims === nothing`; otherwise returns
# `(x, ps, st)`, where `x` is a `Float32` array of size `x_dims` drawn with `randn` from
# the same generator after the model has been set up.
function general_setup(model, x_dims)
    generator = StableRNG(0)
    params, states = Lux.setup(generator, model)
    if x_dims === nothing
        return params, states
    end
    input = randn(generator, Float32, x_dims)
    return input, params, states
end
|
||
# TODO: Remove these once DifferentiationInterface has been released | ||
# Enzyme backend placeholder: always throws, because the benchmark is not wired up yet.
function __benchmark_reverse_pass(
        tag::String, end_tag::String, ::AutoEnzyme, model, x_dims)
    # TODO: Enable this. But enzyme doesn't handle closures well it seems...
    # Intended registration, kept for reference:
    #   SUITE[tag]["cpu"]["reverse"]["Enzyme"][end_tag] = @benchmarkable Enzyme.gradient(
    #       $Enzyme.Reverse, $f, $x)
    error("Enzyme backend hasn't been implemented yet.")
end
# Register the Tapir reverse-pass benchmark for `model`.
#
# The input, ComponentArray parameters, loss closure (sum of squares of the first output
# of `Lux.apply`) and the Tapir rule are all built in `setup`; the benchmarked expression
# is the `Tapir.value_and_pullback!!` call alone.
function __benchmark_reverse_pass(
        tag::String, end_tag::String, ::AutoTapir, model, x_dims)
    group = SUITE[tag]["cpu"]["reverse"]["Tapir"]
    group[end_tag] = @benchmarkable Tapir.value_and_pullback!!(
        rule, 1.0f0, loss, theta) setup=begin
        (x, ps_nt, st) = general_setup($model, $x_dims)
        theta = ComponentArray(ps_nt)
        loss = @closure(q -> sum(abs2, first(Lux.apply($model, x, q, st))))
        rule = Tapir.build_rrule(loss, theta)
    end
    return
end
# Register the Tracker reverse-pass benchmark for `model`.
#
# `setup` creates the input, wraps the parameters in a ComponentArray and builds the loss
# closure (sum of squares of the first output of `Lux.apply`); the benchmarked expression
# is `Tracker.gradient` of that loss with respect to the parameters.
function __benchmark_reverse_pass(
        tag::String, end_tag::String, ::AutoTracker, model, x_dims)
    group = SUITE[tag]["cpu"]["reverse"]["Tracker"]
    group[end_tag] = @benchmarkable Tracker.gradient(loss, theta) setup=begin
        (x, ps_nt, st) = general_setup($model, $x_dims)
        theta = ComponentArray(ps_nt)
        loss = @closure(q -> sum(abs2, first(Lux.apply($model, x, q, st))))
    end
    return
end
# Register the ReverseDiff reverse-pass benchmark for `model`.
#
# When `ad.compile` is true, the gradient tape is recorded and compiled in `setup`, and
# the benchmarked expression is the in-place `ReverseDiff.gradient!` call on that tape,
# stored under the "ReverseDiff (compiled)" key. Otherwise the benchmarked expression is
# `ReverseDiff.gradient` on the loss closure, stored under "ReverseDiff".
#
# The loss is the sum of squares of the first output of `Lux.apply`, differentiated with
# respect to the ComponentArray parameters.
function __benchmark_reverse_pass(
        tag::String, end_tag::String, ad::AutoReverseDiff, model, x_dims)
    if ad.compile
        SUITE[tag]["cpu"]["reverse"]["ReverseDiff (compiled)"][end_tag] = @benchmarkable ReverseDiff.gradient!(
            ∂ps, tape, ps_ca) setup=begin
            (x, ps, st) = general_setup($model, $x_dims)
            ps_ca = ComponentArray(ps)
            # Output buffer for the in-place gradient call.
            ∂ps = similar(ps_ca)
            f = @closure(p->sum(abs2, first(Lux.apply($model, x, p, st))))
            tape = ReverseDiff.compile(ReverseDiff.GradientTape(f, ps_ca))
        end
    else
        SUITE[tag]["cpu"]["reverse"]["ReverseDiff"][end_tag] = @benchmarkable ReverseDiff.gradient(
            f, ps_ca) setup=begin
            (x, ps, st) = general_setup($model, $x_dims)
            ps_ca = ComponentArray(ps)
            f = @closure(p->sum(abs2, first(Lux.apply($model, x, p, st))))
        end
    end
    # Return `nothing` like the other backend methods, rather than the benchmark object
    # produced by the last assignment.
    return
end
# Register the Zygote reverse-pass benchmark for `model`.
#
# `setup` creates the input, wraps the parameters in a ComponentArray and builds the loss
# closure (sum of squares of the first output of `Lux.apply`); the benchmarked expression
# is `Zygote.gradient` of that loss with respect to the parameters.
function __benchmark_reverse_pass(tag::String, end_tag::String, ::AutoZygote, model, x_dims)
    group = SUITE[tag]["cpu"]["reverse"]["Zygote"]
    group[end_tag] = @benchmarkable Zygote.gradient(loss, theta) setup=begin
        (x, ps_nt, st) = general_setup($model, $x_dims)
        theta = ComponentArray(ps_nt)
        loss = @closure(q -> sum(abs2, first(Lux.apply($model, x, q, st))))
    end
    return
end
# Register the "SimpleChains" reverse-pass benchmark.
#
# `model` is used directly with `Lux.apply`, and the parameters returned by
# `general_setup` are differentiated as-is (no ComponentArray wrapping) with
# `Zygote.gradient`. The loss is the sum of squares of the first output.
function __benchmark_reverse_pass_simple_chains(
        tag::String, end_tag::String, ::AutoZygote, model, x_dims)
    group = SUITE[tag]["cpu"]["reverse"]["SimpleChains"]
    group[end_tag] = @benchmarkable Zygote.gradient(loss, theta) setup=begin
        (x, theta, st) = general_setup($model, $x_dims)
        loss = @closure(q -> sum(abs2, first(Lux.apply($model, x, q, st))))
    end
    return
end
# Register the "Flux" reverse-pass benchmark.
#
# `model` is a zero-argument callable; it is invoked inside `setup` to build the network.
# The benchmarked expression is `Zygote.gradient` of the sum of squares of the network
# output, taken with respect to the network itself.
function __benchmark_reverse_pass_flux(
        tag::String, end_tag::String, ::AutoZygote, model, x_dims)
    group = SUITE[tag]["cpu"]["reverse"]["Flux"]
    group[end_tag] = @benchmarkable Zygote.gradient(loss, net) setup=begin
        x = randn(StableRNG(0), Float32, $x_dims)
        net = $(model)()
        loss = @closure(m -> sum(abs2, m(x)))
    end
    return
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
80e3475
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Benchmark Results

| Benchmark suite | Current | Previous | Ratio |
|---|---|---|---|
| Dense(2 => 2)/cpu/reverse/ReverseDiff (compiled)/(2, 128) | 3267.375 ns | | |
| Dense(2 => 2)/cpu/reverse/Zygote/(2, 128) | 9518 ns | | |
| Dense(2 => 2)/cpu/reverse/Tracker/(2, 128) | 16922 ns | | |
| Dense(2 => 2)/cpu/reverse/ReverseDiff/(2, 128) | 4762.5 ns | | |
| Dense(2 => 2)/cpu/reverse/Flux/(2, 128) | 7057.2 ns | | |
| Dense(2 => 2)/cpu/reverse/SimpleChains/(2, 128) | 1712.7 ns | | |
| Dense(2 => 2)/cpu/forward/NamedTuple/(2, 128) | 1091 ns | 2001.7 ns | 0.55 |
| Dense(2 => 2)/cpu/forward/ComponentArray/(2, 128) | 701.922077922078 ns | 1590.4709677419355 ns | 0.44 |
| Dense(2 => 2)/cpu/forward/Flux/(2, 128) | 1308.4 ns | | |
| Dense(2 => 2)/cpu/forward/SimpleChains/(2, 128) | 179.44851904090268 ns | | |
| Dense(20 => 20)/cpu/reverse/ReverseDiff (compiled)/(20, 128) | 17643 ns | | |
| Dense(20 => 20)/cpu/reverse/Zygote/(20, 128) | 25061.5 ns | | |
| Dense(20 => 20)/cpu/reverse/Tracker/(20, 128) | 37245 ns | | |
| Dense(20 => 20)/cpu/reverse/ReverseDiff/(20, 128) | 23073 ns | | |
| Dense(20 => 20)/cpu/reverse/Flux/(20, 128) | 21485 ns | | |
| Dense(20 => 20)/cpu/reverse/SimpleChains/(20, 128) | 13360 ns | | |
| Dense(20 => 20)/cpu/forward/NamedTuple/(20, 128) | 5101.714285714286 ns | 4869.142857142857 ns | 1.05 |
| Dense(20 => 20)/cpu/forward/ComponentArray/(20, 128) | 5107.428571428572 ns | 4718.857142857143 ns | 1.08 |
| Dense(20 => 20)/cpu/forward/Flux/(20, 128) | 5200.428571428571 ns | | |
| Dense(20 => 20)/cpu/forward/SimpleChains/(20, 128) | 1682.65 ns | | |
| Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 3, 128) | 39284500.5 ns | | |
| Conv((3, 3), 3 => 3)/cpu/reverse/Zygote/(64, 64, 3, 128) | 91429963.5 ns | | |
| Conv((3, 3), 3 => 3)/cpu/reverse/Tracker/(64, 64, 3, 128) | 92389439.5 ns | | |
| Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff/(64, 64, 3, 128) | 39691452 ns | | |
| Conv((3, 3), 3 => 3)/cpu/reverse/Flux/(64, 64, 3, 128) | 89786916.5 ns | | |
| Conv((3, 3), 3 => 3)/cpu/reverse/SimpleChains/(64, 64, 3, 128) | 12259593 ns | | |
| Conv((3, 3), 3 => 3)/cpu/forward/NamedTuple/(64, 64, 3, 128) | 10225995 ns | 10357033 ns | 0.99 |
| Conv((3, 3), 3 => 3)/cpu/forward/ComponentArray/(64, 64, 3, 128) | 10142638 ns | 10383988.5 ns | 0.98 |
| Conv((3, 3), 3 => 3)/cpu/forward/Flux/(64, 64, 3, 128) | 10137133.5 ns | | |
| Conv((3, 3), 3 => 3)/cpu/forward/SimpleChains/(64, 64, 3, 128) | 6423977 ns | | |
| vgg16/cpu/reverse/Zygote/(32, 32, 3, 1) | 4365894701 ns | | |
| vgg16/cpu/reverse/Zygote/(32, 32, 3, 16) | 5118048782 ns | | |
| vgg16/cpu/reverse/Zygote/(32, 32, 3, 64) | 15719820981 ns | | |
| vgg16/cpu/reverse/Tracker/(32, 32, 3, 1) | 1394532184 ns | | |
| vgg16/cpu/reverse/Tracker/(32, 32, 3, 16) | 2737271241 ns | | |
| vgg16/cpu/reverse/Tracker/(32, 32, 3, 64) | 5392134495 ns | | |
| vgg16/cpu/reverse/Flux/(32, 32, 3, 1) | 86035309 ns | | |
| vgg16/cpu/reverse/Flux/(32, 32, 3, 16) | 750690104 ns | | |
| vgg16/cpu/reverse/Flux/(32, 32, 3, 64) | 3106551216 ns | | |
| vgg16/cpu/forward/NamedTuple/(32, 32, 3, 1) | 23365059.5 ns | 22706894 ns | 1.03 |
| vgg16/cpu/forward/NamedTuple/(32, 32, 3, 16) | 248196871 ns | 251415835 ns | 0.99 |
| vgg16/cpu/forward/NamedTuple/(32, 32, 3, 64) | 1003838362 ns | 990621987 ns | 1.01 |
| vgg16/cpu/forward/ComponentArray/(32, 32, 3, 1) | 23649034 ns | 22766979.5 ns | 1.04 |
| vgg16/cpu/forward/ComponentArray/(32, 32, 3, 16) | 248954800 ns | 249637390 ns | 1.00 |
| vgg16/cpu/forward/ComponentArray/(32, 32, 3, 64) | 1074000658 ns | 991675056.5 ns | 1.08 |
| vgg16/cpu/forward/Flux/(32, 32, 3, 1) | 22160375 ns | | |
| vgg16/cpu/forward/Flux/(32, 32, 3, 16) | 247762136 ns | | |
| vgg16/cpu/forward/Flux/(32, 32, 3, 64) | 895437036 ns | | |
| Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 64, 128) | 1052351659 ns | | |
| Conv((3, 3), 64 => 64)/cpu/reverse/Zygote/(64, 64, 64, 128) | 2300391794 ns | | |
| Conv((3, 3), 64 => 64)/cpu/reverse/Tracker/(64, 64, 64, 128) | 2730677981 ns | | |
| Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff/(64, 64, 64, 128) | 1401783862 ns | | |
| Conv((3, 3), 64 => 64)/cpu/reverse/Flux/(64, 64, 64, 128) | 2328019747 ns | | |
| Conv((3, 3), 64 => 64)/cpu/forward/NamedTuple/(64, 64, 64, 128) | 375448343 ns | 371776615 ns | 1.01 |
| Conv((3, 3), 64 => 64)/cpu/forward/ComponentArray/(64, 64, 64, 128) | 374614390 ns | 372349398 ns | 1.01 |
| Conv((3, 3), 64 => 64)/cpu/forward/Flux/(64, 64, 64, 128) | 375734589 ns | | |
| Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 1, 128) | 13597474 ns | | |
| Conv((3, 3), 1 => 1)/cpu/reverse/Zygote/(64, 64, 1, 128) | 30506392 ns | | |
| Conv((3, 3), 1 => 1)/cpu/reverse/Tracker/(64, 64, 1, 128) | 31258806 ns | | |
| Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff/(64, 64, 1, 128) | 13877995 ns | | |
| Conv((3, 3), 1 => 1)/cpu/reverse/Flux/(64, 64, 1, 128) | 30495395 ns | | |
| Conv((3, 3), 1 => 1)/cpu/reverse/SimpleChains/(64, 64, 1, 128) | 1149994 ns | | |
| Conv((3, 3), 1 => 1)/cpu/forward/NamedTuple/(64, 64, 1, 128) | 3880661 ns | 3900400 ns | 0.99 |
| Conv((3, 3), 1 => 1)/cpu/forward/ComponentArray/(64, 64, 1, 128) | 3880529 ns | 3899295 ns | 1.00 |
| Conv((3, 3), 1 => 1)/cpu/forward/Flux/(64, 64, 1, 128) | 3876882 ns | | |
| Conv((3, 3), 1 => 1)/cpu/forward/SimpleChains/(64, 64, 1, 128) | 195006 ns | | |
| Dense(200 => 200)/cpu/reverse/ReverseDiff (compiled)/(200, 128) | 297678 ns | | |
| Dense(200 => 200)/cpu/reverse/Zygote/(200, 128) | 315185.5 ns | | |
| Dense(200 => 200)/cpu/reverse/Tracker/(200, 128) | 397685 ns | | |
| Dense(200 => 200)/cpu/reverse/ReverseDiff/(200, 128) | 344595.5 ns | | |
| Dense(200 => 200)/cpu/reverse/Flux/(200, 128) | 471604 ns | | |
| Dense(200 => 200)/cpu/reverse/SimpleChains/(200, 128) | 395982 ns | | |
| Dense(200 => 200)/cpu/forward/NamedTuple/(200, 128) | 91626.5 ns | 87314 ns | 1.05 |
| Dense(200 => 200)/cpu/forward/ComponentArray/(200, 128) | 95404 ns | 87424 ns | 1.09 |
| Dense(200 => 200)/cpu/forward/Flux/(200, 128) | 87344 ns | | |
| Dense(200 => 200)/cpu/forward/SimpleChains/(200, 128) | 104776.5 ns | | |
| Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 16, 128) | 191759064 ns | | |
| Conv((3, 3), 16 => 16)/cpu/reverse/Zygote/(64, 64, 16, 128) | 511408215 ns | | |
| Conv((3, 3), 16 => 16)/cpu/reverse/Tracker/(64, 64, 16, 128) | 523965339 ns | | |
| Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff/(64, 64, 16, 128) | 212766749 ns | | |
| Conv((3, 3), 16 => 16)/cpu/reverse/Flux/(64, 64, 16, 128) | 486845887 ns | | |
| Conv((3, 3), 16 => 16)/cpu/reverse/SimpleChains/(64, 64, 16, 128) | 320740323 ns | | |
| Conv((3, 3), 16 => 16)/cpu/forward/NamedTuple/(64, 64, 16, 128) | 64118671 ns | 56912462 ns | 1.13 |
| Conv((3, 3), 16 => 16)/cpu/forward/ComponentArray/(64, 64, 16, 128) | 64797793 ns | 56953662 ns | 1.14 |
| Conv((3, 3), 16 => 16)/cpu/forward/Flux/(64, 64, 16, 128) | 64228406 ns | | |
| Conv((3, 3), 16 => 16)/cpu/forward/SimpleChains/(64, 64, 16, 128) | 29516170.5 ns | | |
| Dense(2000 => 2000)/cpu/reverse/ReverseDiff (compiled)/(2000, 128) | 19828112 ns | | |
| Dense(2000 => 2000)/cpu/reverse/Zygote/(2000, 128) | 25262569 ns | | |
| Dense(2000 => 2000)/cpu/reverse/Tracker/(2000, 128) | 34145184 ns | | |
| Dense(2000 => 2000)/cpu/reverse/ReverseDiff/(2000, 128) | 22848976.5 ns | | |
| Dense(2000 => 2000)/cpu/reverse/Flux/(2000, 128) | 19974956 ns | | |
| Dense(2000 => 2000)/cpu/forward/NamedTuple/(2000, 128) | 6640955 ns | 6515785 ns | 1.02 |
| Dense(2000 => 2000)/cpu/forward/ComponentArray/(2000, 128) | 6637819 ns | 6560088 ns | 1.01 |
| Dense(2000 => 2000)/cpu/forward/Flux/(2000, 128) | 6568209 ns | | |

This comment was automatically generated by workflow using github-action-benchmark.