diff --git a/src/cache.jl b/src/cache.jl index 1519a07..a4af472 100644 --- a/src/cache.jl +++ b/src/cache.jl @@ -11,23 +11,48 @@ Base.iterate(cache::WalkCache, state...) = iterate(cache.cache, state...) Base.setindex!(cache::WalkCache, value, key) = setindex!(cache.cache, value, key) Base.getindex(cache::WalkCache, x) = cache.cache[x] -function __cacheget_generator__(world, source, self, cache, x, args #= for `return_type` only =#) - # :(return cache.cache[x]::(return_type(cache.walk, typeof(args)))) - walk = cache.parameters[3] - RT = Core.Compiler.return_type(Tuple{walk, args...}, world) - body = Expr(:call, GlobalRef(Base, :getindex), Expr(:., :cache, QuoteNode(:cache)), :x) - if RT != Any - body = Expr(:(::), body, RT) - end - expr = Expr(:lambda, [Symbol("#self#"), :cache, :x, :args], - Expr(Symbol("scope-block"), Expr(:block, Expr(:meta, :inline), Expr(:return, body)))) - ci = ccall(:jl_expand, Any, (Any, Any), expr, @__MODULE__) - ci.inlineable = true - return ci -end -@eval function cacheget(cache::WalkCache, x, args...) - $(Expr(:meta, :generated, __cacheget_generator__)) - $(Expr(:meta, :generated_only)) +@static if VERSION >= v"1.10.0-DEV.609" + function __cacheget_generator__(world, source, self, cache, x, args #= for `return_type` only =#) + # :(return cache.cache[x]::(return_type(cache.walk, typeof(args)))) + walk = cache.parameters[3] + RT = Core.Compiler.return_type(Tuple{walk, args...}, world) + body = Expr(:call, GlobalRef(Base, :getindex), Expr(:., :cache, QuoteNode(:cache)), :x) + if RT != Any + body = Expr(:(::), body, RT) + end + expr = Expr(:lambda, [Symbol("#self#"), :cache, :x, :args], + Expr(Symbol("scope-block"), Expr(:block, Expr(:meta, :inline), Expr(:return, body)))) + ci = ccall(:jl_expand, Any, (Any, Any), expr, @__MODULE__) + ci.inlineable = true + return ci + end + @eval function cacheget(cache::WalkCache, x, args...) + $(Expr(:meta, :generated, __cacheget_generator__)) + $(Expr(:meta, :generated_only)) + end +else + @generated function cacheget(cache::WalkCache, x, args...) + walk = cache.parameters[3] + world = typemax(UInt) + @static if VERSION >= v"1.8" + RT = Core.Compiler.return_type(Tuple{walk, args...}, world) + else + if isdefined(walk, :instance) + RT = Core.Compiler.return_type(walk.instance, Tuple{args...}, world) + else + RT = Any + end + end + body = Expr(:call, GlobalRef(Base, :getindex), Expr(:., :cache, QuoteNode(:cache)), :x) + if RT != Any + body = Expr(:(::), body, RT) + end + expr = Expr(:lambda, [Symbol("#self#"), :cache, :x, :args], + Expr(Symbol("scope-block"), Expr(:block, Expr(:meta, :inline), Expr(:return, body)))) + ci = ccall(:jl_expand, Any, (Any, Any), expr, @__MODULE__) + ci.inlineable = true + return ci + end end # fallback behavior that only lookup for `x` -cacheget(cache::AbstractDict, x, args...) = cache[x] +@inline cacheget(cache::AbstractDict, x, args...) = cache[x]