diff --git a/Project.toml b/Project.toml index 28e472172..333477155 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Mooncake" uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" authors = ["Will Tebbutt, Hong Ge, and contributors"] -version = "0.4.74" +version = "0.4.75" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/interpreter/contexts.jl b/src/interpreter/contexts.jl index 5b82e47c7..c150cc897 100644 --- a/src/interpreter/contexts.jl +++ b/src/interpreter/contexts.jl @@ -51,5 +51,5 @@ is_primitive(::Type{MinimalCtx}, ::Type{<:Tuple{typeof(foo), Float64}}) = true You should implemented more complicated method of `is_primitive` in the usual way. """ macro is_primitive(Tctx, sig) - return esc(:(Mooncake.is_primitive(::Type{$Tctx}, ::Type{<:$sig}) = true)) + return :(Mooncake.is_primitive(::Type{$(esc(Tctx))}, ::Type{<:$(esc(sig))}) = true) end diff --git a/src/tools_for_rules.jl b/src/tools_for_rules.jl index 8561735ee..d52a34ef7 100644 --- a/src/tools_for_rules.jl +++ b/src/tools_for_rules.jl @@ -6,12 +6,12 @@ function parse_signature_expr(sig::Expr) # Different parsing is required for `Tuple{...}` vs `Tuple{...} where ...`. if sig.head == :curly @assert sig.args[1] == :Tuple - arg_type_symbols = sig.args[2:end] + arg_type_symbols = map(esc, sig.args[2:end]) where_params = nothing elseif sig.head == :where @assert sig.args[1].args[1] == :Tuple - arg_type_symbols = sig.args[1].args[2:end] - where_params = sig.args[2:end] + arg_type_symbols = map(esc, sig.args[1].args[2:end]) + where_params = map(esc, sig.args[2:end]) else throw(ArgumentError("Expected either a `Tuple{...}` or `Tuple{...} where {...}")) end @@ -96,8 +96,12 @@ julia> Mooncake.value_and_gradient!!(rule, scale, 5.0) """ macro mooncake_overlay(method_expr) def = splitdef(method_expr) - def[:name] = Expr(:overlay, :(Mooncake.mooncake_method_table), def[:name]) - return esc(combinedef(def)) + __mooncake_method_table = gensym("mooncake_method_table") + def[:name] = Expr(:overlay, __mooncake_method_table, def[:name]) + return quote + $(esc(__mooncake_method_table)) = Mooncake.mooncake_method_table + $(esc(combinedef(def))) + end end # @@ -200,7 +204,7 @@ macro zero_adjoint(ctx, sig) # then the last argument requires special treatment. arg_type_symbols, where_params = parse_signature_expr(sig) arg_names = map(n -> Symbol("x_$n"), eachindex(arg_type_symbols)) - is_vararg = arg_type_symbols[end] === :Vararg + is_vararg = arg_type_symbols[end] == Expr(:escape, :Vararg) if is_vararg arg_types = vcat( map(t -> :(Mooncake.CoDual{<:$t}), arg_type_symbols[1:(end - 1)]), @@ -215,10 +219,10 @@ macro zero_adjoint(ctx, sig) # Return code to create a method of is_primitive and a rule. ex = quote - Mooncake.is_primitive(::Type{$ctx}, ::Type{<:$sig}) = true + Mooncake.is_primitive(::Type{$(esc(ctx))}, ::Type{<:$(esc(sig))}) = true $(construct_def(arg_names, arg_types, where_params, body)) end - return esc(ex) + return ex end # @@ -469,10 +473,10 @@ macro from_rrule(ctx, sig::Expr, has_kwargs::Bool=false) end ex = quote - Mooncake.is_primitive(::Type{$ctx}, ::Type{<:$sig}) = true + Mooncake.is_primitive(::Type{$(esc(ctx))}, ::Type{<:($(esc(sig)))}) = true $rule_expr $kw_is_primitive $kwargs_rule_expr end - return esc(ex) + return ex end diff --git a/test/interpreter/contexts.jl b/test/interpreter/contexts.jl index 43fab68e9..01f6b69c4 100644 --- a/test/interpreter/contexts.jl +++ b/test/interpreter/contexts.jl @@ -1 +1,14 @@ -@testset "contexts" begin end +module ContextsTestModule + +using Mooncake: @is_primitive, DefaultCtx + +foo(x) = x + +@is_primitive DefaultCtx Tuple{typeof(foo),Float64} + +end + +@testset "contexts" begin + @test Mooncake.is_primitive(DefaultCtx, Tuple{typeof(ContextsTestModule.foo),Float64}) + @test !Mooncake.is_primitive(DefaultCtx, Tuple{typeof(ContextsTestModule.foo),Real}) +end diff --git a/test/tools_for_rules.jl b/test/tools_for_rules.jl index 2b7a2af9b..9cd066efe 100644 --- a/test/tools_for_rules.jl +++ b/test/tools_for_rules.jl @@ -1,11 +1,14 @@ module ToolsForRulesResources -using ChainRulesCore, LinearAlgebra, Mooncake +# Note: do not `using Mooncake` in this module to ensure that all of the macros work +# correctly if `Mooncake` is not in scope. +using ChainRulesCore, LinearAlgebra using Base: IEEEFloat using Mooncake: @mooncake_overlay, @zero_adjoint, @from_rrule, MinimalCtx, DefaultCtx +local_function(x) = 3x overlay_tester(x) = 2x -@mooncake_overlay overlay_tester(x) = 3x +@mooncake_overlay overlay_tester(x) = local_function(x) zero_tester(x) = 0 @zero_adjoint MinimalCtx Tuple{typeof(zero_tester),Float64}