Skip to content

Commit

Permalink
Benchmarking derived rules (#48)
Browse files Browse the repository at this point in the history
* Move over testing infrastructure

* uncomment some tests

* Fix up tests

* Tweaks

* Fix boxing

* Fix tests

* Fix tests

* Run tests on unrolled functions

* Tweak

* Typo

* Tidy up CI naming

* Provide performance intervals

* Fix timings thresholds

* Tweak perf bounds

* Fix typo

* Add nightly cron job to test benchmarking robustness

* Increase number of benchmark cron jobs

* Uncommed some code

* Loosen bounds on perf
  • Loading branch information
willtebbutt authored Dec 4, 2023
1 parent be72030 commit 721a9a0
Show file tree
Hide file tree
Showing 26 changed files with 535 additions and 359 deletions.
8 changes: 7 additions & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,18 @@ on:
- main
tags: ['*']
pull_request:
schedule:
- cron: '28 0,6,12,18 * * *'
concurrency:
# Skip intermediate builds: always.
# Cancel intermediate builds: only if it is a pull request build.
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}
jobs:
test:
name: Julia ${{ matrix.version }} - ${{ matrix.test_group }} - ${{ github.event_name }}
name: ${{ matrix.test_group }}
runs-on: ubuntu-latest
if: github.event_name != 'schedule'
strategy:
fail-fast: false
matrix:
Expand Down Expand Up @@ -44,6 +47,7 @@ jobs:
matrix:
perf_group:
- 'hand_written'
- 'derived'
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v1
Expand All @@ -53,4 +57,6 @@ jobs:
- uses: julia-actions/cache@v1
- uses: julia-actions/julia-buildpkg@v1
- run: julia --project=benchmarking --eval 'include("benchmarking/run_benchmarks.jl")'
env:
PERF_GROUP: ${{ matrix.perf_group }}
shell: bash
181 changes: 135 additions & 46 deletions benchmarking/run_benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,89 +3,176 @@ Pkg.develop(path=joinpath(@__DIR__, ".."))

using BenchmarkTools, Random, Taped, Test

using Taped: CoDual, generate_hand_written_rrule!!_test_cases
using Taped:
CoDual,
generate_hand_written_rrule!!_test_cases,
generate_derived_rrule!!_test_cases

using Taped.TestUtils: _deepcopy

function generate_rrule!!_benchmarks(rng::AbstractRNG, args)

# Generate CoDuals etc.
primals = map(x -> x isa CoDual ? primal(x) : x, args)
dargs = map(x -> x isa CoDual ? tangent(x) : randn_tangent(rng, x), args)
codual_args = map(CoDual, primals, dargs)
cd_args = map(CoDual, primals, dargs)
suite = BenchmarkGroup()

# Benchmark primal.
suite["primal"] = @benchmarkable(
(a[1])((a[2:end])...);
setup=(a = deepcopy($primals)),
setup=(a = ($primals[1], _deepcopy($primals[2:end])...)),
evals=1,
)

# Benchmark forwards-pass.
suite["forwards"] = @benchmarkable(
Taped.rrule!!(ca...);
setup=(ca = deepcopy($(codual_args))),
setup=(ca = ($cd_args[1], _deepcopy($(cd_args)[2:end])...)),
evals=1,
)

# Benchmark pullback.
suite["pullback"] = @benchmarkable(
x[2]((tangent(x[1])), map(tangent, ca)...),
setup=(ca = deepcopy($codual_args); x = Taped.rrule!!(ca...)),
setup=(ca = ($cd_args[1], _deepcopy($cd_args[2:end])...); x = Taped.rrule!!(ca...)),
evals=1,
)

return suite
end

function benchmark_rrules!!(rng::AbstractRNG)
function generate_hand_written_cases(rng_ctor, v::Val)
test_cases, memory = generate_hand_written_rrule!!_test_cases(rng_ctor, v)
ranges = map(x -> x[3], test_cases)
return map(x -> x[4:end], test_cases), memory, ranges
end

default_hand_written_ratios() = (lb=1e-3, ub=10.0)

function benchmark_hand_written_rrules!!(rng_ctor)
rng = rng_ctor(123)

# Benchmark the performance of all benchmarks.
test_case_data = [
generate_hand_written_rrule!!_test_cases(Val(:avoiding_non_differentiable_code)),
generate_hand_written_rrule!!_test_cases(Val(:blas)),
generate_hand_written_rrule!!_test_cases(Val(:builtins)),
generate_hand_written_rrule!!_test_cases(Val(:foreigncall)),
generate_hand_written_rrule!!_test_cases(Val(:iddict)),
generate_hand_written_rrule!!_test_cases(Val(:lapack)),
generate_hand_written_rrule!!_test_cases(Val(:low_level_maths)),
generate_hand_written_rrule!!_test_cases(Val(:misc)),
generate_hand_written_rrule!!_test_cases(Val(:umlaut_internals_rules)),
generate_hand_written_rrule!!_test_cases(Val(:unrolled_function)),
]
test_case_data = map([
:avoiding_non_differentiable_code,
:blas,
:builtins,
:foreigncall,
:iddict,
:lapack,
:low_level_maths,
:misc,
:umlaut_internals_rules,
:unrolled_function
]) do s
generate_hand_written_cases(Xoshiro, Val(s))
end
test_cases = reduce(vcat, map(first, test_case_data))
memory = map(last, test_case_data)
memory = map(x -> x[2], test_case_data)
ranges = reduce(vcat, map(x -> x[3], test_case_data))

GC.@preserve memory begin
results = map(enumerate(test_cases)) do (n, x)
args = (x[4:end]..., )
@info "$n / $(length(test_cases))", args
suite = generate_rrule!!_benchmarks(rng, args)
@info "$n / $(length(test_cases))", x
suite = generate_rrule!!_benchmarks(rng, x)
return (x, BenchmarkTools.run(suite; verbose=true, seconds=3))
end
end

# Compute performance ratio for all cases.
ratios = map(results) do result
result_dict = result[2]
primal_time = time(minimum(result_dict["primal"]))
forwards_time = time(minimum(result_dict["forwards"]))
pullback_time = time(minimum(result_dict["pullback"]))
_range = result[1][3]
return (
tag=result[1],
forwards_range=_range === nothing ? default_hand_written_ratios() : _range,
pullback_range=_range === nothing ? default_hand_written_ratios() : _range,
primal_time=primal_time,
forwards_time=forwards_time,
pullback_time=pullback_time,
forwards_ratio=forwards_time / primal_time,
pullback_ratio=pullback_time / primal_time,
)
return combine_results.(results, ranges, Ref(default_hand_written_ratios()))
end

default_derived_ratios() = (lb=100, ub=100_000)

function benchmark_derived_rrules!!(rng_ctor)
rng = rng_ctor(123)

# We only run a subset of the test cases, because it will take far too long to run
# then all. Moreover, it seems unlikely that there is presently a need for all tests,
# and that the :unrolled_function tests will be sufficient.
test_case_data = map([
# :avoiding_non_differentiable_code,
# :blas,
# :builtins,
# :foreigncall,
# :iddict,
# :lapack,
# :low_level_maths,
# :misc,
# :umlaut_internals_rules,
:unrolled_function
]) do s
test_cases, memory = generate_derived_rrule!!_test_cases(rng_ctor, Val(s))
unrolled_test_cases = map(test_cases) do test_case
f, x... = test_case[3:end]
f_t = last(Taped.trace_recursive_tape!!(f, map(_deepcopy, x)...))
return Any[f_t, f, x...]
end
ranges = map(x -> x[2], test_cases)
return unrolled_test_cases, memory, ranges
end
test_cases = reduce(vcat, map(first, test_case_data))
memory = map(x -> x[2], test_case_data)
ranges = reduce(vcat, map(x -> x[3], test_case_data))

GC.@preserve memory begin
results = map(enumerate(test_cases)) do (n, args)
@info "$n / $(length(test_cases))", args

# Generate CoDuals etc.
primals = map(x -> x isa CoDual ? primal(x) : x, args[2:end])
unrolled_primals = map(x -> x isa CoDual ? primal(x) : x, args)
dargs = map(x -> x isa CoDual ? tangent(x) : randn_tangent(rng, x), args)
cd_args = map(CoDual, unrolled_primals, dargs)
suite = BenchmarkGroup()

# Benchmark primal.
suite["primal"] = @benchmarkable(
(a[1])((a[2:end])...);
setup=(a = ($primals[1], _deepcopy($primals[2:end])...)),
evals=1,
)

# Benchmark forwards-pass.
suite["forwards"] = @benchmarkable(
Taped.rrule!!(ca...);
setup=(ca = ($cd_args[1], _deepcopy($(cd_args)[2:end])...)),
evals=1,
)

# Benchmark pullback.
suite["pullback"] = @benchmarkable(
x[2]((tangent(x[1])), map(tangent, ca)...),
setup=(ca = ($cd_args[1], _deepcopy($cd_args[2:end])...); x = Taped.rrule!!(ca...)),
evals=1,
)

return (args, BenchmarkTools.run(suite; verbose=true, seconds=3))
end
end
return ratios

# Compute performance ratio for all cases.
return combine_results.(results, ranges, Ref(default_derived_ratios()))
end

default_hand_written_ratios() = (lb=1e-3, ub=10.0)
function combine_results(result, _range, default_range)
result_dict = result[2]
primal_time = time(minimum(result_dict["primal"]))
forwards_time = time(minimum(result_dict["forwards"]))
pullback_time = time(minimum(result_dict["pullback"]))
return (
tag=result[1],
forwards_range=_range === nothing ? default_range : _range,
pullback_range=_range === nothing ? default_range : _range,
primal_time=primal_time,
forwards_time=forwards_time,
pullback_time=pullback_time,
forwards_ratio=forwards_time / primal_time,
pullback_ratio=pullback_time / primal_time,
)
end

between(x, (lb, ub)) = lb < x && x < ub

Expand All @@ -98,10 +185,12 @@ function flag_concerning_performance(ratios)
end
end

function main()
ratios = benchmark_rrules!!(Xoshiro(123456))
flag_concerning_performance(ratios)
return nothing
end
const perf_group = get(ENV, "PERF_GROUP", "derived")

main()
if perf_group == "hand_written"
flag_concerning_performance(benchmark_hand_written_rrules!!(Xoshiro))
elseif perf_group == "derived"
flag_concerning_performance(benchmark_derived_rrules!!(Xoshiro))
else
throw(error("perf_group=$(perf_group) is not recognised"))
end
3 changes: 3 additions & 0 deletions src/Taped.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,12 @@ import Umlaut: isprimitive, Frame, Tracer, __foreigncall__, __to_tuple__, __new_
using Base:
IEEEFloat, unsafe_convert, unsafe_pointer_to_objref, pointer_from_objref, arrayref,
arrayset
using Base.Iterators: product
using Core: Intrinsics, bitcast, SimpleVector, svec
using Core.Intrinsics: pointerref, pointerset
using FunctionWrappers: FunctionWrapper
using LinearAlgebra.BLAS: @blasfunc, BlasInt, trsm!
using LinearAlgebra.LAPACK: getrf!, getrs!, getri!, trtrs!, potrf!, potrs!

include("tracing.jl")
include("acceleration.jl")
Expand Down
10 changes: 9 additions & 1 deletion src/rrules/avoiding_non_differentiable_code.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ function rrule!!(::CoDual{typeof(Base.:(+))}, x::CoDual{<:Ptr}, y::CoDual{<:Inte
return CoDual(primal(x) + primal(y), tangent(x) + primal(y)), NoPullback()
end

function generate_hand_written_rrule!!_test_cases(::Val{:avoiding_non_differentiable_code})
function generate_hand_written_rrule!!_test_cases(
rng_ctor, ::Val{:avoiding_non_differentiable_code}
)
_x = Ref(5.0)
_dx = Ref(4.0)
test_cases = Any[
Expand All @@ -26,3 +28,9 @@ function generate_hand_written_rrule!!_test_cases(::Val{:avoiding_non_differenti
memory = Any[_x, _dx]
return test_cases, memory
end

function generate_derived_rrule!!_test_cases(
rng_ctor, ::Val{:avoiding_non_differentiable_code},
)
return Any[], Any[]
end
Loading

0 comments on commit 721a9a0

Please sign in to comment.