Skip to content

Commit

Permalink
Testing Improvements (#43)
Browse files Browse the repository at this point in the history
* Improve testing infrastructure

* Add test case for IdDict get rrule
  • Loading branch information
willtebbutt authored Nov 22, 2023
1 parent fc713d2 commit f0bfb0f
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 19 deletions.
62 changes: 47 additions & 15 deletions src/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@ function has_equal_data(x::T, y::T; equal_undefs=true) where {T<:Array}
end
return all(equality)
end
has_equal_data(x::T, y::T) where {T<:Union{Float16, Float32, Float64}} = isapprox(x, y)
has_equal_data(x::Float64, y::Float64; equal_undefs=true) = isapprox(x, y)
function has_equal_data(x::T, y::T; equal_undefs=true) where {T<:Core.SimpleVector}
return all(map((a, b) -> has_equal_data(a, b; equal_undefs), x, y))
end
function has_equal_data(x::T, y::T; equal_undefs=true) where {T}
isprimitivetype(T) && return isequal(x, y)
return all(map(
Expand Down Expand Up @@ -413,13 +416,26 @@ function test_rule_and_type_interactions(rng::AbstractRNG, x::P) where {P}
end
end

"""
test_tangent(rng::AbstractRNG, p::P, z_target::T, x::T, y::T) where {P, T}
Verify that primal `p` with tangents `z_target`, `x`, and `y`, satisfies the tangent
interface. If these tests pass, then it should be possible to write `rrule!!`s for primals
of type `P`, and to test them using `test_rrule!!`.
#
# Tests for tangents
#

As always, there are limits to the errors that these tests can identify -- they form
necessary but not sufficient conditions for the correctness of your code.
"""
function test_tangent(rng::AbstractRNG, p::P, z_target::T, x::T, y::T) where {P, T}
@nospecialize rng p z_target x y

# This basic functionality must run in order to be able to check everything else.
@test tangent_type(P) isa Type
@test tangent_type(P) == T
@test zero_tangent(p) isa T
@test randn_tangent(rng, p) isa T
test_equality_comparison(p)
test_equality_comparison(x)

# Verify that interface `tangent_type` runs.
Tt = tangent_type(P)
Expand Down Expand Up @@ -483,21 +499,37 @@ function test_tangent(rng::AbstractRNG, p::P, z_target::T, x::T, y::T) where {P,
@test set_to_zero!!(tc) === tc
end

# Check that we can get an address map.
populate_address_map(p, x)
end
z = zero_tangent(p)
r = randn_tangent(rng, p)

function test_numerical_testing_interface(p::P, t::T) where {P, T}
@assert tangent_type(P) == T
@test _scale(2.0, t) isa T
# Verify that operations required for finite difference testing to run, and produce the
# correct output type.
@test _add_to_primal(p, t) isa P
@test _diff(p, p) isa T
@test _dot(t, t) isa Float64
@test _scale(11.0, t) isa T
@test populate_address_map(p, t) isa AddressMap

# Run some basic numerical sanity checks on the output the functions required for finite
# difference testing. These are necessary but insufficient conditions.
@test has_equal_data(_add_to_primal(p, z), p)
if !has_equal_data(z, r)
@test !has_equal_data(_add_to_primal(p, r), p)
end
@test has_equal_data(_diff(p, p), zero_tangent(p))
@test _dot(t, t) >= 0.0
@test _dot(t, zero_tangent(p)) == 0.0
@test _dot(t, increment!!(deepcopy(t), t)) 2 * _dot(t, t)
@test _add_to_primal(p, t) isa P
@test has_equal_data(_add_to_primal(p, zero_tangent(p)), p)
@test _diff(p, p) isa T
@test has_equal_data(_diff(p, p), zero_tangent(p))
@test has_equal_data(_scale(1.0, t), t)
@test has_equal_data(_scale(2.0, t), increment!!(deepcopy(t), t))
end

function test_equality_comparison(x)
@nospecialize x
@test has_equal_data(x, x) isa Bool
@test has_equal_data_up_to_undefs(x, x) isa Bool
@test has_equal_data(x, x)
@test has_equal_data_up_to_undefs(x, x)
end

end
Expand Down
2 changes: 1 addition & 1 deletion test/rrules/iddict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
y = IdDict(true => 2.0, false => 1.0)
rng = Xoshiro(123456)
test_tangent(rng, p, z, x, y)
test_numerical_testing_interface(p, x)
end

@testset "$f, $(typeof(x))" for (interface_only, perf_flag, f, x...) in [
(false, :stability, Base.rehash!, IdDict(true => 5.0, false => 4.0), 10),
(false, :none, setindex!, IdDict(true => 5.0, false => 4.0), 3.0, false),
(false, :none, setindex!, IdDict(true => 5.0), 3.0, false),
(false, :none, get, IdDict(true => 5.0, false => 4.0), false, 2.0),
(false, :none, get, IdDict(true => 5.0), false, 2.0),
(false, :none, getindex, IdDict(true => 5.0, false => 4.0), true),
]
test_rrule!!(Xoshiro(123456), f, x...; interface_only, perf_flag)
Expand Down
3 changes: 1 addition & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,7 @@ using .TestUtils:
AddressMap,
populate_address_map!,
populate_address_map,
test_tangent,
test_numerical_testing_interface
test_tangent

using .TestResources:
TypeStableMutableStruct,
Expand Down
1 change: 0 additions & 1 deletion test/tangents.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,6 @@
)
rng = Xoshiro(123456)
test_tangent(rng, p, z, x, y)
test_numerical_testing_interface(p, x)
end

tangent(nt::NamedTuple) = Tangent(map(PossiblyUninitTangent, nt))
Expand Down

0 comments on commit f0bfb0f

Please sign in to comment.