Skip to content

Commit

Permalink
Move over testing infrastructure
Browse files Browse the repository at this point in the history
  • Loading branch information
willtebbutt committed Dec 4, 2023
1 parent be72030 commit 4e3a2eb
Show file tree
Hide file tree
Showing 22 changed files with 335 additions and 280 deletions.
3 changes: 3 additions & 0 deletions src/Taped.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,12 @@ import Umlaut: isprimitive, Frame, Tracer, __foreigncall__, __to_tuple__, __new_
using Base:
IEEEFloat, unsafe_convert, unsafe_pointer_to_objref, pointer_from_objref, arrayref,
arrayset
using Base.Iterators: product
using Core: Intrinsics, bitcast, SimpleVector, svec
using Core.Intrinsics: pointerref, pointerset
using FunctionWrappers: FunctionWrapper
using LinearAlgebra.BLAS: @blasfunc, BlasInt, trsm!
using LinearAlgebra.LAPACK: getrf!, getrs!, getri!, trtrs!, potrf!, potrs!

include("tracing.jl")
include("acceleration.jl")
Expand Down
10 changes: 9 additions & 1 deletion src/rrules/avoiding_non_differentiable_code.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ function rrule!!(::CoDual{typeof(Base.:(+))}, x::CoDual{<:Ptr}, y::CoDual{<:Inte
return CoDual(primal(x) + primal(y), tangent(x) + primal(y)), NoPullback()
end

function generate_hand_written_rrule!!_test_cases(::Val{:avoiding_non_differentiable_code})
function generate_hand_written_rrule!!_test_cases(
rng_ctor, ::Val{:avoiding_non_differentiable_code}
)
_x = Ref(5.0)
_dx = Ref(4.0)
test_cases = Any[
Expand All @@ -26,3 +28,9 @@ function generate_hand_written_rrule!!_test_cases(::Val{:avoiding_non_differenti
memory = Any[_x, _dx]
return test_cases, memory
end

function generate_derived_rrule!!_test_cases(
rng_ctor, ::Val{:avoiding_non_differentiable_code},
)
return Any[], Any[]
end
108 changes: 107 additions & 1 deletion src/rrules/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -442,4 +442,110 @@ for (trsm, elty) in ((:dtrsm_, :Float64), (:strsm_, :Float32))
end
end

generate_hand_written_rrule!!_test_cases(::Val{:blas}) = Any[], Any[]
generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:blas}) = Any[], Any[]

function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:blas})
t_flags = ['N', 'T', 'C']
aliased_gemm! = (tA, tB, a, b, A, C) -> BLAS.gemm!(tA, tB, a, A, A, b, C)

test_cases = vcat(

#
# BLAS LEVEL 1
#

[
Any[false, nothing, BLAS.dot, 3, randn(5), 1, randn(4), 1],
Any[false, nothing, BLAS.dot, 3, randn(6), 2, randn(4), 1],
Any[false, nothing, BLAS.dot, 3, randn(6), 1, randn(9), 3],
Any[false, nothing, BLAS.dot, 3, randn(12), 3, randn(9), 2],
Any[false, nothing, BLAS.scal!, 10, 2.4, randn(30), 2],
],

#
# BLAS LEVEL 2
#

# gemv!
vec(reduce(
vcat,
map(product(t_flags, [1, 3], [1, 2])) do (tA, M, N)
t = tA == 'N'
As = [
t ? randn(M, N) : randn(N, M),
view(randn(15, 15), t ? (3:M+2) : (2:N+1), t ? (2:N+1) : (3:M+2)),
]
xs = [randn(N), view(randn(15), 3:N+2), view(randn(30), 1:2:2N)]
ys = [randn(M), view(randn(15), 2:M+1), view(randn(30), 2:2:2M)]
return map(Iterators.product(As, xs, ys)) do (A, x, y)
Any[false, nothing, BLAS.gemv!, tA, randn(), A, x, randn(), y]
end
end,
)),

# trmv!
vec(reduce(
vcat,
map(product(['L', 'U'], t_flags, ['N', 'U'], [1, 3])) do (ul, tA, dA, N)
As = [randn(N, N), view(randn(15, 15), 3:N+2, 4:N+3)]
bs = [randn(N), view(randn(14), 4:N+3)]
return map(product(As, bs)) do (A, b)
Any[false, nothing, BLAS.trmv!, ul, tA, dA, A, b]
end
end,
)),

#
# BLAS LEVEL 3
#

# gemm!
vec(map(product(t_flags, t_flags)) do (tA, tB)
A = tA == 'N' ? randn(3, 4) : randn(4, 3)
B = tB == 'N' ? randn(4, 5) : randn(5, 4)
Any[false, nothing, BLAS.gemm!, tA, tB, randn(), A, B, randn(), randn(3, 5)]
end),

vec(map(product(t_flags, t_flags)) do (tA, tB)
A = randn(5, 5)
B = randn(5, 5)
Any[false, nothing, aliased_gemm!, tA, tB, randn(), randn(), A, B]
end),

# trmm!
vec(reduce(
vcat,
map(
product(['L', 'R'], ['U', 'L'], t_flags, ['N', 'U'], [1, 3], [1, 2]),
) do (side, ul, tA, dA, M, N)
t = tA == 'N'
R = side == 'L' ? M : N
As = [randn(R, R), view(randn(15, 15), 3:R+2, 4:R+3)]
Bs = [randn(M, N), view(randn(15, 15), 2:M+1, 5:N+4)]
return map(product(As, Bs)) do (A, B)
alpha = randn()
Any[false, nothing, BLAS.trmm!, side, ul, tA, dA, alpha, A, B]
end
end,
)),

# trmm!
vec(reduce(
vcat,
map(
product(['L', 'R'], ['U', 'L'], t_flags, ['N', 'U'], [1, 3], [1, 2]),
) do (side, ul, tA, dA, M, N)
t = tA == 'N'
R = side == 'L' ? M : N
As = [randn(R, R) + 5I, view(randn(15, 15), 3:R+2, 4:R+3) + 5I]
Bs = [randn(M, N), view(randn(15, 15), 2:M+1, 5:N+4)]
return map(product(As, Bs)) do (A, B)
alpha = randn()
Any[false, nothing, BLAS.trsm!, side, ul, tA, dA, alpha, A, B]
end
end,
)),
)
memory = Any[]
return test_cases, memory
end
26 changes: 25 additions & 1 deletion src/rrules/builtins.jl
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,7 @@ end

rrule!!(::CoDual{typeof(typeof)}, x) = CoDual(typeof(primal(x)), NoTangent()), NoPullback()

function generate_hand_written_rrule!!_test_cases(::Val{:builtins})
function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:builtins})

_x = Ref(5.0) # data used in tests which aren't protected by GC.
_dx = Ref(4.0)
Expand Down Expand Up @@ -821,3 +821,27 @@ function generate_hand_written_rrule!!_test_cases(::Val{:builtins})
memory = Any[_x, _dx, _a]
return test_cases, memory
end

function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:builtins})
test_cases = Any[
[
false,
nothing,
(
function (x)
rx = Ref(x)
pointerref(bitcast(Ptr{Float64}, pointer_from_objref(rx)), 1, 1)
end
),
5.0,
],
[false, nothing, (v, x) -> (pointerset(pointer(x), v, 2, 1); x), 3.0, randn(5)],
[false, nothing, x -> (pointerset(pointer(x), UInt8(3), 2, 1); x), rand(UInt8, 5)],
[false, nothing, getindex, randn(5), [1, 1]],
[false, nothing, getindex, randn(5), [1, 2, 2]],
[false, nothing, setindex!, randn(5), [4.0, 5.0], [1, 1]],
[false, nothing, setindex!, randn(5), [4.0, 5.0, 6.0], [1, 2, 2]],
]
memory = Any[]
return test_cases, memory
end
35 changes: 34 additions & 1 deletion src/rrules/foreigncall.jl
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ for name in [
end


function generate_hand_written_rrule!!_test_cases(::Val{:foreigncall})
function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:foreigncall})
_x = Ref(5.0)
_dx = randn_tangent(Xoshiro(123456), _x)

Expand Down Expand Up @@ -494,3 +494,36 @@ function generate_hand_written_rrule!!_test_cases(::Val{:foreigncall})
memory = Any[_x, _dx, _a, _da, _b, _db]
return test_cases, memory
end

function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:foreigncall})

_x = Ref(5.0)

function unsafe_copyto_tester(x::Vector{T}, y::Vector{T}, n::Int) where {T}
GC.@preserve x y unsafe_copyto!(pointer(x), pointer(y), n)
return x
end

test_cases = [
Any[false, nothing, reshape, randn(5, 4), (4, 5)],
Any[false, nothing, reshape, randn(5, 4), (2, 10)],
Any[false, nothing, reshape, randn(5, 4), (10, 2)],
Any[false, nothing, reshape, randn(5, 4), (5, 4, 1)],
Any[false, nothing, reshape, randn(5, 4), (2, 10, 1)],
Any[false, nothing, unsafe_copyto_tester, randn(5), randn(3), 2],
Any[false, nothing, unsafe_copyto_tester, randn(5), randn(6), 4],
[
false,
nothing,
unsafe_copyto_tester,
[randn(3) for _ in 1:5],
[randn(4) for _ in 1:6],
4,
],
Any[false, nothing, x -> unsafe_pointer_to_objref(pointer_from_objref(x)), _x],
Any[false, nothing, isassigned, randn(5), 4],
Any[false, nothing, x -> (Base._growbeg!(x, 2); x[1:2] .= 2.0), randn(5)],
]
memory = Any[_x]
return test_cases, memory
end
4 changes: 3 additions & 1 deletion src/rrules/iddict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ for name in [
end
end

function generate_hand_written_rrule!!_test_cases(::Val{:iddict})
function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:iddict})
test_cases = Any[
(false, :stability, nothing, Base.rehash!, IdDict(true => 5.0, false => 4.0), 10),
(false, :none, nothing, setindex!, IdDict(true => 5.0, false => 4.0), 3.0, false),
Expand All @@ -140,3 +140,5 @@ function generate_hand_written_rrule!!_test_cases(::Val{:iddict})
memory = Any[]
return test_cases, memory
end

generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:iddict}) = Any[], Any[]
88 changes: 87 additions & 1 deletion src/rrules/lapack.jl
Original file line number Diff line number Diff line change
Expand Up @@ -420,4 +420,90 @@ for (fname, elty) in ((:dpotrs_, :Float64), (:spotrs_, :Float32))
end
end

generate_hand_written_rrule!!_test_cases(::Val{:lapack}) = Any[], Any[]
generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:lapack}) = Any[], Any[]

function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:lapack})
getrf_wrapper!(x, check) = getrf!(x; check)
test_cases = vcat(

# getrf!
[
Any[false, nothing, getrf_wrapper!, randn(5, 5), false],
Any[false, nothing, getrf_wrapper!, randn(5, 5), true],
Any[false, nothing, getrf_wrapper!, view(randn(10, 10), 1:5, 1:5), false],
Any[false, nothing, getrf_wrapper!, view(randn(10, 10), 1:5, 1:5), true],
Any[false, nothing, getrf_wrapper!, view(randn(10, 10), 2:7, 3:8), false],
Any[false, nothing, getrf_wrapper!, view(randn(10, 10), 3:8, 2:7), true],
],

# trtrs
vec(reduce(
vcat,
map(product(
['U', 'L'], ['N', 'T', 'C'], ['N', 'U'], [1, 3], [1, 2])
) do (ul, tA, diag, N, Nrhs)
As = [randn(N, N) + 10I, view(randn(15, 15) + 10I, 2:N+1, 2:N+1)]
Bs = [randn(N, Nrhs), view(randn(15, 15), 4:N+3, 3:N+2)]
return map(product(As, Bs)) do (A, B)
Any[false, nothing, trtrs!, ul, tA, diag, A, B]
end
end,
)),

# getrs
vec(reduce(
vcat,
map(product(['N', 'T'], [1, 9], [1, 2])) do (trans, N, Nrhs)
As = getrf!.([
randn(N, N) + 5I,
view(randn(15, 15) + 5I, 2:N+1, 2:N+1),
])
Bs = [randn(N, Nrhs), view(randn(15, 15), 4:N+3, 3:Nrhs+2)]
return map(product(As, Bs)) do ((A, ipiv), B)
Any[false, nothing, getrs!, trans, A, ipiv, B]
end
end,
)),

# getri
vec(reduce(
vcat,
map([1, 9]) do N
As = getrf!.([randn(N, N) + 5I, view(randn(15, 15) + I, 2:N+1, 2:N+1)])
As = getrf!.([randn(N, N) + 5I])
return map(As) do (A, ipiv)
Any[false, nothing, getri!, A, ipiv]
end
end,
)),

# potrf
vec(reduce(
vcat,
map([1, 3, 9]) do N
X = randn(N, N)
A = X * X' + I
return [
Any[false, nothing, potrf!, 'L', A],
Any[false, nothing, potrf!, 'U', A],
]
end,
)),

# potrs
vec(reduce(
vcat,
map(product([1, 3, 9], [1, 2])) do (N, Nrhs)
X = randn(N, N)
A = X * X' + I
B = randn(N, Nrhs)
return [
Any[false, nothing, potrs!, 'L', potrf!('L', copy(A))[1], copy(B)],
Any[false, nothing, potrs!, 'U', potrf!('U', copy(A))[1], copy(B)],
]
end,
)),
)
memory = Any[]
return test_cases, memory
end
4 changes: 3 additions & 1 deletion src/rrules/low_level_maths.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ rand_inputs(rng, ::typeof(acosd), _) = (2 * 0.9 * rand(rng) - 0.9, )
rand_inputs(rng, ::typeof(acos), _) = (2 * 0.9 * rand(rng) - 0.9, )
rand_inputs(rng, ::typeof(sqrt), _) = (rand(rng) + 1e-3, )

function generate_hand_written_rrule!!_test_cases(::Val{:low_level_maths})
function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:low_level_maths})
rng = Xoshiro(123)
test_cases = Any[]
foreach(DiffRules.diffrules(; filter_modules=nothing)) do (M, f, arity)
Expand All @@ -64,3 +64,5 @@ function generate_hand_written_rrule!!_test_cases(::Val{:low_level_maths})
memory = Any[]
return test_cases, memory
end

generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:low_level_maths}) = Any[], Any[]
4 changes: 3 additions & 1 deletion src/rrules/misc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ function rrule!!(x::CoDual{<:Type}, y::CoDual{<:TypeVar}, z::CoDual{<:Type})
return CoDual(primal(x)(primal(y), primal(z)), NoTangent()), NoPullback()
end

function generate_hand_written_rrule!!_test_cases(::Val{:misc})
function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:misc})

# Data which needs to not be GC'd.
_x = Ref(5.0)
Expand Down Expand Up @@ -109,3 +109,5 @@ function generate_hand_written_rrule!!_test_cases(::Val{:misc})
]
return test_cases, memory
end

generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:misc}) = Any[], Any[]
Loading

0 comments on commit 4e3a2eb

Please sign in to comment.