From e672c6e67ab4e8dcca6ba08af501a79d59198b83 Mon Sep 17 00:00:00 2001 From: chengchingwen Date: Thu, 9 May 2024 09:18:55 +0800 Subject: [PATCH 1/3] improve type stability of cached walks --- src/Functors.jl | 1 + src/cache.jl | 33 +++++++++++++++++++++++++++++++++ src/maps.jl | 4 ++-- src/walks.jl | 12 ++++++------ 4 files changed, 42 insertions(+), 8 deletions(-) create mode 100644 src/cache.jl diff --git a/src/Functors.jl b/src/Functors.jl index 71e0c8b..89bf7ef 100644 --- a/src/Functors.jl +++ b/src/Functors.jl @@ -7,6 +7,7 @@ export @functor, @flexiblefunctor, fmap, fmapstructure, fcollect, execute, fleav include("functor.jl") include("keypath.jl") include("walks.jl") +include("cache.jl") include("maps.jl") include("base.jl") diff --git a/src/cache.jl b/src/cache.jl new file mode 100644 index 0000000..1519a07 --- /dev/null +++ b/src/cache.jl @@ -0,0 +1,33 @@ +struct WalkCache{K, V, W <: AbstractWalk, C <: AbstractDict{K, V}} <: AbstractDict{K, V} + walk::W + cache::C + WalkCache(walk, cache::AbstractDict{K, V} = IdDict()) where {K, V} = new{K, V, typeof(walk), typeof(cache)}(walk, cache) +end +Base.length(cache::WalkCache) = length(cache.cache) +Base.empty!(cache::WalkCache) = empty!(cache.cache) +Base.haskey(cache::WalkCache, x) = haskey(cache.cache, x) +Base.get(cache::WalkCache, x, default) = haskey(cache.cache, x) ? cache[x] : default +Base.iterate(cache::WalkCache, state...) = iterate(cache.cache, state...) +Base.setindex!(cache::WalkCache, value, key) = setindex!(cache.cache, value, key) +Base.getindex(cache::WalkCache, x) = cache.cache[x] + +function __cacheget_generator__(world, source, self, cache, x, args #= for `return_type` only =#) + # :(return cache.cache[x]::(return_type(cache.walk, typeof(args)))) + walk = cache.parameters[3] + RT = Core.Compiler.return_type(Tuple{walk, args...}, world) + body = Expr(:call, GlobalRef(Base, :getindex), Expr(:., :cache, QuoteNode(:cache)), :x) + if RT != Any + body = Expr(:(::), body, RT) + end + expr = Expr(:lambda, [Symbol("#self#"), :cache, :x, :args], + Expr(Symbol("scope-block"), Expr(:block, Expr(:meta, :inline), Expr(:return, body)))) + ci = ccall(:jl_expand, Any, (Any, Any), expr, @__MODULE__) + ci.inlineable = true + return ci +end +@eval function cacheget(cache::WalkCache, x, args...) + $(Expr(:meta, :generated, __cacheget_generator__)) + $(Expr(:meta, :generated_only)) +end +# fallback behavior that only lookup for `x` +cacheget(cache::AbstractDict, x, args...) = cache[x] diff --git a/src/maps.jl b/src/maps.jl index 50986db..5cc145c 100644 --- a/src/maps.jl +++ b/src/maps.jl @@ -6,7 +6,7 @@ function fmap(f, x, ys...; exclude = isleaf, prune = NoKeyword()) _walk = ExcludeWalk(AnonymousWalk(walk), f, exclude) if !isnothing(cache) - _walk = CachedWalk(_walk, prune, cache) + _walk = CachedWalk(_walk, prune, WalkCache(_walk, cache)) end execute(_walk, x, ys...) end @@ -18,7 +18,7 @@ function fmap_with_path(f, x, ys...; exclude = isleaf, _walk = ExcludeWalkWithKeyPath(walk, f, exclude) if !isnothing(cache) - _walk = CachedWalkWithPath(_walk, prune, cache) + _walk = CachedWalkWithPath(_walk, prune, WalkCache(_walk, cache)) end return execute(_walk, KeyPath(), x, ys...) end diff --git a/src/walks.jl b/src/walks.jl index e5db2a3..77008af 100644 --- a/src/walks.jl +++ b/src/walks.jl @@ -179,10 +179,10 @@ Whenever the cache already contains `x`, either: Typically wraps an existing `walk` for use with [`fmap`](@ref). """ -struct CachedWalk{T, S} <: AbstractWalk +struct CachedWalk{T, S, C <: AbstractDict} <: AbstractWalk walk::T prune::S - cache::IdDict{Any, Any} + cache::C end CachedWalk(walk; prune = NoKeyword(), cache = IdDict()) = CachedWalk(walk, prune, cache) @@ -190,7 +190,7 @@ CachedWalk(walk; prune = NoKeyword(), cache = IdDict()) = function (walk::CachedWalk)(recurse, x, ys...) should_cache = usecache(walk.cache, x) if should_cache && haskey(walk.cache, x) - return walk.prune isa NoKeyword ? walk.cache[x] : walk.prune + return walk.prune isa NoKeyword ? cacheget(walk.cache, x, recurse, x, ys...) : walk.prune else ret = walk.walk(recurse, x, ys...) if should_cache @@ -200,10 +200,10 @@ function (walk::CachedWalk)(recurse, x, ys...) end end -struct CachedWalkWithPath{T, S} <: AbstractWalk +struct CachedWalkWithPath{T, S, C <: AbstractDict} <: AbstractWalk walk::T prune::S - cache::IdDict{Any, Any} + cache::C end CachedWalkWithPath(walk; prune = NoKeyword(), cache = IdDict()) = @@ -212,7 +212,7 @@ CachedWalkWithPath(walk; prune = NoKeyword(), cache = IdDict()) = function (walk::CachedWalkWithPath)(recurse, kp::KeyPath, x, ys...) should_cache = usecache(walk.cache, x) if should_cache && haskey(walk.cache, x) - return walk.prune isa NoKeyword ? walk.cache[x] : walk.prune + return walk.prune isa NoKeyword ? cacheget(walk.cache, x, recurse, kp, x, ys...) : walk.prune else ret = walk.walk(recurse, kp, x, ys...) if should_cache From 8b91724982fff72750b3feb2a85d3e3dffd5332f Mon Sep 17 00:00:00 2001 From: chengchingwen Date: Thu, 9 May 2024 09:19:07 +0800 Subject: [PATCH 2/3] fix doctest format --- src/keypath.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/keypath.jl b/src/keypath.jl index f64be1d..acca6f5 100644 --- a/src/keypath.jl +++ b/src/keypath.jl @@ -131,9 +131,9 @@ See also [`KeyPath`](@ref) and [`getkeypath`](@ref). # Examples ```jldoctest julia> x = Dict(:a => 3, :b => Dict(:c => 4, "d" => [5, 6, 7])) -Dict{Any,Any} with 2 entries: +Dict{Symbol, Any} with 2 entries: :a => 3 - :b => Dict{Any,Any}(:c=>4,"d"=>[5, 6, 7]) + :b => Dict{Any, Any}(:c=>4, "d"=>[5, 6, 7]) julia> haskeypath(x, KeyPath(:a)) true From 992b228af1a539ee1813e278bc5f2793e4251007 Mon Sep 17 00:00:00 2001 From: chengchingwen Date: Thu, 9 May 2024 10:32:07 +0800 Subject: [PATCH 3/3] handle old julia version --- src/cache.jl | 61 ++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 43 insertions(+), 18 deletions(-) diff --git a/src/cache.jl b/src/cache.jl index 1519a07..a4af472 100644 --- a/src/cache.jl +++ b/src/cache.jl @@ -11,23 +11,48 @@ Base.iterate(cache::WalkCache, state...) = iterate(cache.cache, state...) Base.setindex!(cache::WalkCache, value, key) = setindex!(cache.cache, value, key) Base.getindex(cache::WalkCache, x) = cache.cache[x] -function __cacheget_generator__(world, source, self, cache, x, args #= for `return_type` only =#) - # :(return cache.cache[x]::(return_type(cache.walk, typeof(args)))) - walk = cache.parameters[3] - RT = Core.Compiler.return_type(Tuple{walk, args...}, world) - body = Expr(:call, GlobalRef(Base, :getindex), Expr(:., :cache, QuoteNode(:cache)), :x) - if RT != Any - body = Expr(:(::), body, RT) - end - expr = Expr(:lambda, [Symbol("#self#"), :cache, :x, :args], - Expr(Symbol("scope-block"), Expr(:block, Expr(:meta, :inline), Expr(:return, body)))) - ci = ccall(:jl_expand, Any, (Any, Any), expr, @__MODULE__) - ci.inlineable = true - return ci -end -@eval function cacheget(cache::WalkCache, x, args...) - $(Expr(:meta, :generated, __cacheget_generator__)) - $(Expr(:meta, :generated_only)) +@static if VERSION >= v"1.10.0-DEV.609" + function __cacheget_generator__(world, source, self, cache, x, args #= for `return_type` only =#) + # :(return cache.cache[x]::(return_type(cache.walk, typeof(args)))) + walk = cache.parameters[3] + RT = Core.Compiler.return_type(Tuple{walk, args...}, world) + body = Expr(:call, GlobalRef(Base, :getindex), Expr(:., :cache, QuoteNode(:cache)), :x) + if RT != Any + body = Expr(:(::), body, RT) + end + expr = Expr(:lambda, [Symbol("#self#"), :cache, :x, :args], + Expr(Symbol("scope-block"), Expr(:block, Expr(:meta, :inline), Expr(:return, body)))) + ci = ccall(:jl_expand, Any, (Any, Any), expr, @__MODULE__) + ci.inlineable = true + return ci + end + @eval function cacheget(cache::WalkCache, x, args...) + $(Expr(:meta, :generated, __cacheget_generator__)) + $(Expr(:meta, :generated_only)) + end +else + @generated function cacheget(cache::WalkCache, x, args...) + walk = cache.parameters[3] + world = typemax(UInt) + @static if VERSION >= v"1.8" + RT = Core.Compiler.return_type(Tuple{walk, args...}, world) + else + if isdefined(walk, :instance) + RT = Core.Compiler.return_type(walk.instance, Tuple{args...}, world) + else + RT = Any + end + end + body = Expr(:call, GlobalRef(Base, :getindex), Expr(:., :cache, QuoteNode(:cache)), :x) + if RT != Any + body = Expr(:(::), body, RT) + end + expr = Expr(:lambda, [Symbol("#self#"), :cache, :x, :args], + Expr(Symbol("scope-block"), Expr(:block, Expr(:meta, :inline), Expr(:return, body)))) + ci = ccall(:jl_expand, Any, (Any, Any), expr, @__MODULE__) + ci.inlineable = true + return ci + end end # fallback behavior that only lookup for `x` -cacheget(cache::AbstractDict, x, args...) = cache[x] +@inline cacheget(cache::AbstractDict, x, args...) = cache[x]