Skip to content

Commit

Permalink
Benchmark AD on some problems
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Apr 5, 2024
1 parent 5fdf7b6 commit c464d18
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 12 deletions.
46 changes: 36 additions & 10 deletions bench/helpers.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# TODO: Special Handling for GPU Arrays with @sync
function benchmark_forward_pass(tag::String, end_tag::String, model, x, ps_nt::NamedTuple,
st; simple_chains = nothing)
st; simple_chains=nothing)
SUITE[tag]["cpu"]["forward"]["NamedTuple"][end_tag] = @benchmarkable Lux.apply(
$model, $x, $ps_nt, $st)

Expand All @@ -19,10 +19,15 @@ function benchmark_forward_pass(tag::String, end_tag::String, model, x, ps_nt::N
end

function benchmark_reverse_pass(
tag::String, end_tag::String, backends::NTuple, model, x, ps_nt::NamedTuple, st)
tag::String, end_tag::String, backends, model, x, ps_nt::NamedTuple, st)
# Not everyone can handle NamedTuples so convert to ComponentArray
__f = @closure ps -> sum(abs2, first(Lux.apply(model, x, ps, st)))
ps_ca = ComponentArray(ps_nt)

for backend in backends
__benchmark_reverse_pass(tag, end_tag, backend, __f, ps_ca)
end

return
end

Expand All @@ -34,17 +39,38 @@ function general_setup(model, x_dims)
return x, ps, st
end

@inline __typein(::Type{T}, x) where {T} = any(Base.Fix2(isa, T), x)

# TODO: Remove these once DifferentiationInterface has been released
function __benchmark_tapir_reverse_pass(tag::String, end_tag::String, f::F, x) where {F}
function __benchmark_reverse_pass(
tag::String, end_tag::String, ::AutoEnzyme, f::F, x) where {F}
# TODO: Enable this. But enzyme doesn't handle closures well it seems...
# SUITE[tag]["cpu"]["reverse"]["Enzyme"][end_tag] = @benchmarkable Enzyme.gradient(
# $Enzyme.Reverse, $f, $x)
return
end
function __benchmark_tracker_reverse_pass(tag::String, end_tag::String, f::F, x) where {F}
function __benchmark_reverse_pass(
tag::String, end_tag::String, ::AutoTapir, f::F, x) where {F}
end
function __benchmark_enzyme_reverse_pass(tag::String, end_tag::String, f::F, x) where {F}
function __benchmark_reverse_pass(
tag::String, end_tag::String, ::AutoTracker, f::F, x) where {F}
SUITE[tag]["cpu"]["reverse"]["Tracker"][end_tag] = @benchmarkable Tracker.gradient(
$f, $x)
return
end
function __benchmark_reversediff_reverse_pass(
tag::String, end_tag::String, f::F, x) where {F}
function __benchmark_reverse_pass(
tag::String, end_tag::String, ad::AutoReverseDiff, f::F, x) where {F}
if ad.compile
tape = ReverseDiff.compile(ReverseDiff.GradientTape(f, x))
∂x = similar(x)
SUITE[tag]["cpu"]["reverse"]["ReverseDiff (compiled)"][end_tag] = @benchmarkable ReverseDiff.gradient!(
$∂x, $tape, $x)
else
SUITE[tag]["cpu"]["reverse"]["ReverseDiff"][end_tag] = @benchmarkable ReverseDiff.gradient(
$f, $x)
end
end
function __benchmark_zygote_reverse_pass(tag::String, end_tag::String, f::F, x) where {F}
function __benchmark_reverse_pass(
tag::String, end_tag::String, ::AutoZygote, f::F, x) where {F}
SUITE[tag]["cpu"]["reverse"]["Zygote"][end_tag] = @benchmarkable Zygote.gradient(
$f, $x)
return
end
10 changes: 10 additions & 0 deletions bench/layers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@ function add_dense_benchmarks!()
simple_chains = Lux.ToSimpleChainsAdaptor((static(n),))
benchmark_forward_pass(
"Dense($n => $n)", "($n, 128)", layer, x, ps, st; simple_chains)
benchmark_reverse_pass(
"Dense($n => $n)", "($n, 128)",
(AutoEnzyme(), AutoTapir(), AutoTracker(),
AutoReverseDiff(), AutoReverseDiff(true), AutoZygote()),
layer, x, ps, st)
end

return
Expand All @@ -17,6 +22,11 @@ function add_conv_benchmarks!()
simple_chains = Lux.ToSimpleChainsAdaptor((static(64), static(64), static(ch)))
benchmark_forward_pass("Conv((3, 3), $ch => $ch)", "(64, 64, $ch, 128)",
layer, x, ps, st; simple_chains)
benchmark_reverse_pass(
"Conv((3, 3), $ch => $ch)", "(64, 64, $ch, 128)",
(AutoEnzyme(), AutoTapir(), AutoTracker(),
AutoReverseDiff(), AutoReverseDiff(true), AutoZygote()),
layer, x, ps, st)
end
end

Expand Down
10 changes: 8 additions & 2 deletions bench/runbenchmarks.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using ADTypes: ADTypes
using ADTypes: ADTypes, AutoEnzyme, AutoTracker, AutoReverseDiff, AutoZygote
using BenchmarkTools: BenchmarkTools, BenchmarkGroup, @btime, @benchmarkable
using ComponentArrays: ComponentArray
using InteractiveUtils: versioninfo
Expand All @@ -16,6 +16,12 @@ using Tapir: Tapir
using Tracker: Tracker
using Zygote: Zygote

# BenchmarkTools Parameters
BenchmarkTools.DEFAULT_PARAMETERS.samples = 100
BenchmarkTools.DEFAULT_PARAMETERS.seconds = 0.25

struct AutoTapir <: ADTypes.AbstractReverseMode end

@info sprint(versioninfo)

const SUITE = BenchmarkGroup()
Expand All @@ -24,7 +30,7 @@ include("helpers.jl")
include("vgg.jl")
include("layers.jl")

BenchmarkTools.tune!(SUITE)
BenchmarkTools.tune!(SUITE; verbose=true)
results = BenchmarkTools.run(SUITE; verbose=true)

BenchmarkTools.save(joinpath(@__DIR__, "benchmark_results.json"), median(results))

0 comments on commit c464d18

Please sign in to comment.