Test all kinds of tangent types (Thunk, ZeroTangent, Tangent{T} etc) #159
Comments
Just to confirm: for One other point is that requiring new Do we want to do it in one go instead? I guess that means we solve arrays first? |
Also, the current implementation I have in mind is something like:
    function test_rrule(f, args...; output_tangent=Auto(), test_other_tangents=true, fkwargs=NamedTuple(), other_kwargs...)
        y = f(args...; fkwargs...)
        ȳ = output_tangent isa Auto ? rand_tangent(y) : output_tangent
        if test_other_tangents
            # re-run the test once per alternative tangent, with recursion switched off
            for ybar in _all_possible_tangents(ȳ)
                test_rrule(f, args...; output_tangent=ybar, test_other_tangents=false, fkwargs=fkwargs, other_kwargs...)
            end
        else
            do_the_current_thing()
        end
    end
It's kind of ugly, do we have other ideas? |
This is correct
Two parts sounds better. Potentially we can do something like what we do for inference and have its default controlled by a global? Possibly, rather than a bool, we want a typed object, for extensibility. |
Thanks for reviewing
Could you elaborate please? |
For example in ChainRulesTestUtils we have:
Then in ChainRules.jl's tests we write:
    ChainRulesTestUtils.DEFAULT_ALT_TANGENTS_TO_TEST[] = AbstractZeroTangent
Then we make those tests pass and we add thunks:
    ChainRulesTestUtils.DEFAULT_ALT_TANGENTS_TO_TEST[] = Union{AbstractZeroTangent, AbstractThunk}
But the whole time, packages that haven't opted into the new default behaviour by changing this global are unchanged. This is normally a pretty evil antipattern of globals, since it doesn't compose at all: you can't have your dependencies setting it to different values. |
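The opt-in global described above can be sketched in plain Julia. This is only an illustration under stated assumptions: `ToyAbstractZero`, `ToyAbstractThunk`, `ToyZero`, `ToyThunk`, `ALT_TANGENTS_TO_TEST` and `alt_tangents` are hypothetical stand-ins so the block runs without any packages. They are not the ChainRulesCore types or any ChainRulesTestUtils API.

```julia
# Toy stand-ins so the sketch runs without ChainRulesCore; not the real types.
abstract type ToyAbstractZero end
abstract type ToyAbstractThunk end
struct ToyZero <: ToyAbstractZero end
struct ToyThunk{F} <: ToyAbstractThunk
    f::F
end

# Hypothetical global. The empty Union{} means "test no alternative tangents",
# so packages that never touch it see no change in behaviour.
const ALT_TANGENTS_TO_TEST = Ref{Type}(Union{})

# Build every candidate, then keep only the kinds the global has opted into.
function alt_tangents(ȳ)
    candidates = Any[ToyZero(), ToyThunk(() -> ȳ)]
    return filter(t -> t isa ALT_TANGENTS_TO_TEST[], candidates)
end

n_default = length(alt_tangents(1.0))   # nothing opted in yet

ALT_TANGENTS_TO_TEST[] = ToyAbstractZero
n_zero = length(alt_tangents(1.0))      # only zero tangents

ALT_TANGENTS_TO_TEST[] = Union{ToyAbstractZero, ToyAbstractThunk}
n_both = length(alt_tangents(1.0))      # zero tangents and thunks
```

Using a type rather than a bool means a new tangent kind is opted into by widening the `Union`, without adding another flag.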
This is somewhat more involved than I hoped. The easiest thing to do would be:
    const DEFAULT_ALT_TANGENTS = Ref{Any}([])
    function test_rrule(f, args...; output_tangent=Auto(), other_tangents=DEFAULT_ALT_TANGENTS[], other_kwargs...)
        # test each alternative tangent once; the empty list stops the recursion
        for ybar in other_tangents
            test_rrule(f, args...; output_tangent=ybar, other_tangents=[], other_kwargs...)
        end
        do_the_current_thing()
    end
But that would prevent us from doing To do that, we need something like passing functions:
    const DEFAULT_ALT_TANGENTS = Ref{Vector{Function}}(Function[x -> canonicalize(x), x -> @thunk(x), x -> ZeroTangent()])
    function test_rrule(f, args...; output_tangent=Auto(), new_tangents=DEFAULT_ALT_TANGENTS[], fkwargs=NamedTuple(), other_kwargs...)
        y = f(args...; fkwargs...)
        ȳ = output_tangent isa Auto ? rand_tangent(y) : output_tangent
        # each function turns the plain tangent into one alternative representation
        for new_tangent in new_tangents
            test_rrule(f, args...; output_tangent=new_tangent(ȳ), new_tangents=[], fkwargs=fkwargs, other_kwargs...)
        end
        do_the_current_thing()
    end
Also it looks like we will need |
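A runnable toy version of the "pass functions" idea, written as a sketch rather than the actual implementation. `ToyThunk`, `ToyZero`, `ALT_TANGENT_MAKERS`, `CHECKED` and `toy_test_rrule` are all hypothetical names. Pushing to `CHECKED` stands in for whatever the real check does, so the block only shows the control flow: each alternative is tested once and the recursion stops after one level.

```julia
# Toy stand-ins for ChainRulesCore's Thunk and ZeroTangent (assumptions, not the real types).
struct ToyThunk{F}
    f::F
end
struct ToyZero end

# Hypothetical default: functions that turn a plain tangent into an alternative one.
const ALT_TANGENT_MAKERS = Ref{Vector{Function}}(Function[
    ȳ -> ToyThunk(() -> ȳ),
    ȳ -> ToyZero(),
])

# Records which tangents were checked, standing in for the real test body.
const CHECKED = Any[]

function toy_test_rrule(ȳ; makers=ALT_TANGENT_MAKERS[])
    for make in makers
        # Passing an empty list stops the recursion after one level.
        toy_test_rrule(make(ȳ); makers=Function[])
    end
    push!(CHECKED, ȳ)  # "do the current thing" for this tangent
    return nothing
end

toy_test_rrule(1.5)
```

Because the default is read from the `Ref` at call time, a test suite can swap the list of makers before running its tests, and callers that pass `makers` explicitly are unaffected.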
I am not sure how much we need to actually finite difference test all these. |
Yeah, that sounds good. I've added a PR that does FD because I already had it. But I worry it will take 4 times as long to test, which would be 2 hrs for ChainRules. We still need to pass the functions though, because of |
We are now happy enough with this, just testing Thunk. For trying structural vs natural differentials, we now have a fairly strong push towards never doing structural if there is a good natural. We might want more for forwards mode later, but that will be nonbreaking, since we will just be asserting things that had to be true in well-behaved code. So I am going to call this done. |
This is a generalization of #98 and is a key part of JuliaDiff/ChainRules.jl#408
Let's say we have a tangent dx. We should also test that in its place we can put @thunk(x̄), as well as ZeroTangent.
Further, following JuliaDiff/ChainRulesCore.jl#286:
if dx <: AbstractArray (other than Array) we should test canonicalize(Composite{typeof(x)}, dx),
and conversely, if dx <: Composite{P} where P<:AbstractArray we should test canonicalize(P, dx).
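To make "putting a thunk or a zero in its place" concrete, here is a minimal sketch of a pullback that accepts all three forms. It assumes toy stand-ins: `ToyThunk`, `ToyZero`, `toy_unthunk` and `pullback` are hypothetical and only mirror the roles that `Thunk`, `ZeroTangent` and `unthunk` play in ChainRulesCore.

```julia
# Toy stand-ins (assumptions): a delayed tangent and a structural zero.
struct ToyThunk{F}
    f::F
end
struct ToyZero end

# Unwrap a tangent into a plain value; plain values pass through unchanged.
toy_unthunk(t::ToyThunk) = t.f()
toy_unthunk(t) = t

# Pullback of y = 3x. It must accept a plain number, a thunk, or a zero.
pullback(ȳ) = 3 * toy_unthunk(ȳ)
pullback(::ToyZero) = ToyZero()   # a zero cotangent stays a structural zero

ȳ = 2.0
plain = pullback(ȳ)
via_thunk = pullback(ToyThunk(() -> ȳ))
via_zero = pullback(ToyZero())
```

The property being tested is that the thunked input gives the same result as the plain one, and that the zero input does not error.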