From fb11a9c30178723830186f0b4be90d8ba5418e89 Mon Sep 17 00:00:00 2001 From: Frames White <me@oxinabox.net> Date: Mon, 16 Oct 2023 14:38:51 +0800 Subject: [PATCH 1/2] =?UTF-8?q?type=20of=20type=20erasure=20in=20=E2=88=82?= =?UTF-8?q?=E2=98=86new?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/stage1/recurse_fwd.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/stage1/recurse_fwd.jl b/src/stage1/recurse_fwd.jl index 2c561e73..9cb763f5 100644 --- a/src/stage1/recurse_fwd.jl +++ b/src/stage1/recurse_fwd.jl @@ -24,7 +24,8 @@ function (::∂☆new{1})(B::Type, xs::AbstractTangentBundle{1}...) tangent_nt = NamedTuple{names}(tangent_tup) Tangent{B, typeof(tangent_nt)}(tangent_nt) end - return TaylorBundle{1, B}(the_primal, (the_partial,)) + B2 = typeof(the_primal) # HACK: if the_primal actually has types in it then we want to make sure we get DataType not Type(...) + return TaylorBundle{1, B2}(the_primal, (the_partial,)) end function (::∂☆new{N})(B::Type, xs::AbstractTangentBundle{N}...) where {N} From 2a9a3f79967146336672470abeabfd12689efdd5 Mon Sep 17 00:00:00 2001 From: Frames White <me@oxinabox.net> Date: Tue, 17 Oct 2023 16:11:00 +0800 Subject: [PATCH 2/2] tests for type of type --- test/forward.jl | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/test/forward.jl b/test/forward.jl index 1e4b7142..4f2c6ae6 100644 --- a/test/forward.jl +++ b/test/forward.jl @@ -148,6 +148,18 @@ end end +@testset "types in tuples" begin + function foo(a) + tup = (a, 2a, Int) + return tup[2] + end + + let var"'" = Diffractor.PrimeDerivativeFwd + @test foo'(100.0) == 2.0 + end +end + + @testset "taylor_compatible" begin taylor_compatible = Diffractor.taylor_compatible