Skip to content

Commit

Permalink
Improve type stability of cached walks (FluxML#82)
Browse files Browse the repository at this point in the history
* improve type stability of cached walks

* fix doctest format

* handle old julia version
  • Loading branch information
chengchingwen authored Nov 4, 2024
1 parent c0936a5 commit 2945731
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 8 deletions.
1 change: 1 addition & 0 deletions src/Functors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ export @leaf, @functor, @flexiblefunctor,
include("functor.jl")
include("keypath.jl")
include("walks.jl")
include("cache.jl")
include("maps.jl")
include("base.jl")

Expand Down
58 changes: 58 additions & 0 deletions src/cache.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
struct WalkCache{K, V, W <: AbstractWalk, C <: AbstractDict{K, V}} <: AbstractDict{K, V}
walk::W
cache::C
WalkCache(walk, cache::AbstractDict{K, V} = IdDict()) where {K, V} = new{K, V, typeof(walk), typeof(cache)}(walk, cache)
end
Base.length(cache::WalkCache) = length(cache.cache)
Base.empty!(cache::WalkCache) = empty!(cache.cache)
Base.haskey(cache::WalkCache, x) = haskey(cache.cache, x)
Base.get(cache::WalkCache, x, default) = haskey(cache.cache, x) ? cache[x] : default
Base.iterate(cache::WalkCache, state...) = iterate(cache.cache, state...)
Base.setindex!(cache::WalkCache, value, key) = setindex!(cache.cache, value, key)
Base.getindex(cache::WalkCache, x) = cache.cache[x]

@static if VERSION >= v"1.10.0-DEV.609"
function __cacheget_generator__(world, source, self, cache, x, args #= for `return_type` only =#)
# :(return cache.cache[x]::(return_type(cache.walk, typeof(args))))
walk = cache.parameters[3]
RT = Core.Compiler.return_type(Tuple{walk, args...}, world)
body = Expr(:call, GlobalRef(Base, :getindex), Expr(:., :cache, QuoteNode(:cache)), :x)
if RT != Any
body = Expr(:(::), body, RT)
end
expr = Expr(:lambda, [Symbol("#self#"), :cache, :x, :args],
Expr(Symbol("scope-block"), Expr(:block, Expr(:meta, :inline), Expr(:return, body))))
ci = ccall(:jl_expand, Any, (Any, Any), expr, @__MODULE__)
ci.inlineable = true
return ci
end
@eval function cacheget(cache::WalkCache, x, args...)
$(Expr(:meta, :generated, __cacheget_generator__))
$(Expr(:meta, :generated_only))
end
else
@generated function cacheget(cache::WalkCache, x, args...)
walk = cache.parameters[3]
world = typemax(UInt)
@static if VERSION >= v"1.8"
RT = Core.Compiler.return_type(Tuple{walk, args...}, world)
else
if isdefined(walk, :instance)
RT = Core.Compiler.return_type(walk.instance, Tuple{args...}, world)
else
RT = Any
end
end
body = Expr(:call, GlobalRef(Base, :getindex), Expr(:., :cache, QuoteNode(:cache)), :x)
if RT != Any
body = Expr(:(::), body, RT)
end
expr = Expr(:lambda, [Symbol("#self#"), :cache, :x, :args],
Expr(Symbol("scope-block"), Expr(:block, Expr(:meta, :inline), Expr(:return, body))))
ci = ccall(:jl_expand, Any, (Any, Any), expr, @__MODULE__)
ci.inlineable = true
return ci
end
end
# fallback behavior that only lookup for `x`
@inline cacheget(cache::AbstractDict, x, args...) = cache[x]
4 changes: 2 additions & 2 deletions src/maps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ function fmap(f, x, ys...; exclude = isleaf,
prune = NoKeyword())
_walk = ExcludeWalk(AnonymousWalk(walk), f, exclude)
if !isnothing(cache)
_walk = CachedWalk(_walk, prune, cache)
_walk = CachedWalk(_walk, prune, WalkCache(_walk, cache))
end
execute(_walk, x, ys...)
end
Expand All @@ -18,7 +18,7 @@ function fmap_with_path(f, x, ys...; exclude = isleaf,

_walk = ExcludeWalkWithKeyPath(walk, f, exclude)
if !isnothing(cache)
_walk = CachedWalkWithPath(_walk, prune, cache)
_walk = CachedWalkWithPath(_walk, prune, WalkCache(_walk, cache))
end
return execute(_walk, KeyPath(), x, ys...)
end
Expand Down
12 changes: 6 additions & 6 deletions src/walks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -181,18 +181,18 @@ Whenever the cache already contains `x`, either:
Typically wraps an existing `walk` for use with [`fmap`](@ref).
"""
struct CachedWalk{T, S} <: AbstractWalk
struct CachedWalk{T, S, C <: AbstractDict} <: AbstractWalk
walk::T
prune::S
cache::IdDict{Any, Any}
cache::C
end
CachedWalk(walk; prune = NoKeyword(), cache = IdDict()) =
CachedWalk(walk, prune, cache)

function (walk::CachedWalk)(recurse, x, ys...)
should_cache = usecache(walk.cache, x)
if should_cache && haskey(walk.cache, x)
return walk.prune isa NoKeyword ? walk.cache[x] : walk.prune
return walk.prune isa NoKeyword ? cacheget(walk.cache, x, recurse, x, ys...) : walk.prune
else
ret = walk.walk(recurse, x, ys...)
if should_cache
Expand All @@ -202,10 +202,10 @@ function (walk::CachedWalk)(recurse, x, ys...)
end
end

struct CachedWalkWithPath{T, S} <: AbstractWalk
struct CachedWalkWithPath{T, S, C <: AbstractDict} <: AbstractWalk
walk::T
prune::S
cache::IdDict{Any, Any}
cache::C
end

CachedWalkWithPath(walk; prune = NoKeyword(), cache = IdDict()) =
Expand All @@ -214,7 +214,7 @@ CachedWalkWithPath(walk; prune = NoKeyword(), cache = IdDict()) =
function (walk::CachedWalkWithPath)(recurse, kp::KeyPath, x, ys...)
should_cache = usecache(walk.cache, x)
if should_cache && haskey(walk.cache, x)
return walk.prune isa NoKeyword ? walk.cache[x] : walk.prune
return walk.prune isa NoKeyword ? cacheget(walk.cache, x, recurse, kp, x, ys...) : walk.prune
else
ret = walk.walk(recurse, kp, x, ys...)
if should_cache
Expand Down

0 comments on commit 2945731

Please sign in to comment.