From 7efc6d7140632c86c10fa3584a288ba7caa434ff Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Fri, 16 Aug 2024 16:14:54 +0100 Subject: [PATCH] Various fixes and extensions (#229) * Add rule for jl_type_union ccall * Tidy up Project toml * Bump project * Test Categorical distribution * Generalise lazy_zero_rdata to handle undefined type parameter * Support args in pi nodes * Zero-like rdata for TypeVars * Make _id_count threadsafe * Fix ID when not seeded --- Project.toml | 8 ++++---- src/fwds_rvs_data.jl | 5 ++++- src/interpreter/bbcode.jl | 10 ++++++---- src/interpreter/s2s_reverse_mode_ad.jl | 9 +++++---- src/interpreter/zero_like_rdata.jl | 4 ++++ src/rrules/foreigncall.jl | 19 +++++++++++++++++++ src/test_utils.jl | 3 +++ test/fwds_rvs_data.jl | 2 ++ test/integration_testing/distributions.jl | 1 + test/interpreter/zero_like_rdata.jl | 1 + 10 files changed, 49 insertions(+), 13 deletions(-) diff --git a/Project.toml b/Project.toml index 362506999..76c431d92 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Tapir" uuid = "07d77754-e150-4737-8c94-cd238a1fb45b" authors = ["Will Tebbutt, Hong Ge, and contributors"] -version = "0.2.36" +version = "0.2.37" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -19,15 +19,15 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [weakdeps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" -JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" [extensions] TapirCUDAExt = "CUDA" +TapirJETExt = "JET" TapirLogDensityProblemsADExt = "LogDensityProblemsAD" TapirSpecialFunctionsExt = "SpecialFunctions" -TapirJETExt = "JET" [compat] ADTypes = "1.2" @@ -60,8 +60,8 @@ DiffTests = "de460e47-3fe3-5279-bb4a-814414816d5d" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" -KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" +KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392" LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" diff --git a/src/fwds_rvs_data.jl b/src/fwds_rvs_data.jl index 328bc648f..a0cbc62ec 100644 --- a/src/fwds_rvs_data.jl +++ b/src/fwds_rvs_data.jl @@ -745,7 +745,10 @@ end @inline lazy_zero_rdata(p::P) where {P} = lazy_zero_rdata(lazy_zero_rdata_type(P), p) # Ensure proper specialisation on types. -@inline lazy_zero_rdata(::Type{P}) where {P} = LazyZeroRData{Type{P}, Nothing}(nothing) +@inline function lazy_zero_rdata(p::Type{P}) where {P} + Rtype = @isdefined(P) ? Type{P} : _typeof(p) + return LazyZeroRData{Rtype, Nothing}(nothing) +end @inline instantiate(::LazyZeroRData{P, Nothing}) where {P} = zero_rdata_from_type(P) @inline instantiate(r::LazyZeroRData) = r.data diff --git a/src/interpreter/bbcode.jl b/src/interpreter/bbcode.jl index 383d79906..e8aef9cdb 100644 --- a/src/interpreter/bbcode.jl +++ b/src/interpreter/bbcode.jl @@ -1,6 +1,6 @@ # See the docstring for `BBCode` for some context on this file. -_id_count::Int32 = 0 +const _id_count::Dict{Int, Int32} = Dict{Int, Int32}() """ ID() @@ -14,8 +14,10 @@ produced, in the same way that seed for random number generators can be set. struct ID id::Int32 function ID() - global _id_count += 1 - return new(_id_count) + current_thread_id = Threads.threadid() + id_count = get(_id_count, current_thread_id, Int32(0)) + _id_count[current_thread_id] = id_count + Int32(1) + return new(id_count) end end @@ -30,7 +32,7 @@ ensure determinism between two runs of the same function which makes use of `ID` This is akin to setting the random seed associated to a random number generator globally. """ function seed_id!() - global _id_count = 0 + global _id_count[Threads.threadid()] = 0 end """ diff --git a/src/interpreter/s2s_reverse_mode_ad.jl b/src/interpreter/s2s_reverse_mode_ad.jl index 88c7569d1..c4f41f91d 100644 --- a/src/interpreter/s2s_reverse_mode_ad.jl +++ b/src/interpreter/s2s_reverse_mode_ad.jl @@ -213,9 +213,10 @@ end __ref(P) = new_inst(Expr(:call, __make_ref, P)) # Helper for reverse_data_ref_stmts. -@inline @generated function __make_ref(::Type{P}) where {P} - R = zero_like_rdata_type(P) - return :(Ref{$R}(Tapir.zero_like_rdata_from_type(P))) +@inline @generated function __make_ref(p::Type{P}) where {P} + _P = @isdefined(P) ? P : _typeof(p) + R = zero_like_rdata_type(_P) + return :(Ref{$R}(Tapir.zero_like_rdata_from_type($_P))) end @inline __make_ref(::Type{Union{}}) = nothing @@ -368,7 +369,7 @@ function make_ad_stmts!(stmt::PiNode, line::ID, info::ADInfo) # Assemble the above lines and construct reverse-pass. return ad_stmt_info( line, - PiNode(stmt.val, fcodual_type(_type(stmt.typ))), + PiNode(__inc(stmt.val), fcodual_type(_type(stmt.typ))), Expr(:call, __pi_rvs!, P, val_rdata_ref_id, output_rdata_ref_id), ) end diff --git a/src/interpreter/zero_like_rdata.jl b/src/interpreter/zero_like_rdata.jl index 38a990cf2..425a71ba1 100644 --- a/src/interpreter/zero_like_rdata.jl +++ b/src/interpreter/zero_like_rdata.jl @@ -24,6 +24,8 @@ function zero_like_rdata_type(::Type{P}) where {P} return can_produce_zero_rdata_from_type(P) ? R : Union{R, ZeroRData} end +zero_like_rdata_type(::TypeVar) = NoRData + """ zero_like_rdata_from_type(::Type{P}) where {P} @@ -35,3 +37,5 @@ It is always valid to return a `ZeroRData`, function zero_like_rdata_from_type(::Type{P}) where {P} return can_produce_zero_rdata_from_type(P) ? zero_rdata_from_type(P) : ZeroRData() end + +zero_like_rdata_from_type(::TypeVar) = NoRData() diff --git a/src/rrules/foreigncall.jl b/src/rrules/foreigncall.jl index 6cf20828e..58ec8c9f1 100644 --- a/src/rrules/foreigncall.jl +++ b/src/rrules/foreigncall.jl @@ -454,6 +454,21 @@ function rrule!!( return zero_fcodual(y), NoPullback(ntuple(_ -> NoRData(), length(args) + 8)) end +function rrule!!( + ::CoDual{typeof(_foreigncall_)}, + ::CoDual{Val{:jl_type_unionall}}, + ::CoDual{Val{Any}}, # return type + ::CoDual{Tuple{Val{Any}, Val{Any}}}, # arg types + ::CoDual{Val{0}}, # number of required args + ::CoDual{Val{:ccall}}, + a::CoDual, + b::CoDual, + args... +) + y = ccall(:jl_type_unionall, Any, (Any, Any), primal(a), primal(b)) + return zero_fcodual(y), NoPullback(ntuple(_ -> NoRData(), length(args) + 8)) +end + @is_primitive MinimalCtx Tuple{typeof(deepcopy), Any} function rrule!!(::CoDual{typeof(deepcopy)}, x::CoDual) fdx = tangent(x) @@ -639,6 +654,10 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:foreigncall}) ), (false, :none, nothing, isassigned, randn(5), 4), (false, :none, nothing, x -> (Base._growbeg!(x, 2); x[1:2] .= 2.0), randn(5)), + ( + false, :none, nothing, + (t, v) -> ccall(:jl_type_unionall, Any, (Any, Any), t, v), TypeVar(:a), Real, + ), ] memory = Any[_x] return test_cases, memory diff --git a/src/test_utils.jl b/src/test_utils.jl index 188395b50..74d9b82bd 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -1197,6 +1197,8 @@ function pi_node_tester(y::Ref{Any}) return isa(x, Int) ? sin(x) : x end +Base.@nospecializeinfer arg_in_pi_node(@nospecialize(x)) = x isa Bool ? x : false + function avoid_throwing_path_tester(x) if x < 0 Base.throw_boundserror(1:5, 6) @@ -1496,6 +1498,7 @@ function generate_test_functions() ), (false, :none, (lb=1, ub=1_000), pi_node_tester, Ref{Any}(5.0)), (false, :none, (lb=1, ub=1_000), pi_node_tester, Ref{Any}(5)), + (false, :none, nothing, arg_in_pi_node, false), (false, :allocs, nothing, intrinsic_tester, 5.0), (false, :allocs, nothing, goto_tester, 5.0), (false, :allocs, nothing, new_tester, 5.0, :hello), diff --git a/test/fwds_rvs_data.jl b/test/fwds_rvs_data.jl index c25735ce1..2434b8e06 100644 --- a/test/fwds_rvs_data.jl +++ b/test/fwds_rvs_data.jl @@ -32,6 +32,7 @@ end (Type{Tapir.TestResources.StableFoo}, Tapir.TestResources.StableFoo, true), (Tuple{Float64, Float64}, (5.0, 4.0), true), (Tuple{Float64, Vararg{Float64}}, (5.0, 4.0, 3.0), false), + (Type{Type{Tuple{T}} where {T}}, Type{Tuple{T}} where {T}, true), ] L = Tapir.lazy_zero_rdata_type(P) @test fully_lazy == Base.issingletontype(typeof(lazy_zero_rdata(L, p))) @@ -39,6 +40,7 @@ end @inferred Tapir.instantiate(lazy_zero_rdata(L, p)) end @test typeof(lazy_zero_rdata(L, p)) == Tapir.lazy_zero_rdata_type(P) + @test lazy_zero_rdata(p) isa LazyZeroRData{_typeof(p)} end @test isa( lazy_zero_rdata(Tapir.TestResources.StableFoo), diff --git a/test/integration_testing/distributions.jl b/test/integration_testing/distributions.jl index 1e3a8452b..57f975689 100644 --- a/test/integration_testing/distributions.jl +++ b/test/integration_testing/distributions.jl @@ -224,6 +224,7 @@ _pdmat(A) = PDMat(_sym(A) + 5I) @testset "$name" for (interface_only, name, f, x) in Any[ (false, "InverseGamma", (a, b, x) -> logpdf(InverseGamma(a, b), x), (1.5, 1.4, 0.4)), (false, "NormalCanon", (m, s, x) -> logpdf(NormalCanon(m, s), x), (0.1, 1.0, -0.5)), + (false, "Categorical", x -> logpdf(Categorical(x, 1 - x), 1), 0.3), ( false, "MvLogitNormal", diff --git a/test/interpreter/zero_like_rdata.jl b/test/interpreter/zero_like_rdata.jl index 98d5f4c80..f841d568c 100644 --- a/test/interpreter/zero_like_rdata.jl +++ b/test/interpreter/zero_like_rdata.jl @@ -6,6 +6,7 @@ Float64, Int, Vector{Float64}, + TypeVar(:a), ] @test Tapir.zero_like_rdata_from_type(P) isa Tapir.zero_like_rdata_type(P) end