diff --git a/src/functor.jl b/src/functor.jl index c89fd03..9c8cf78 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -34,81 +34,6 @@ isleaf(x) = children(x) === () children(x) = functor(x)[1] -function _default_walk(f, x) - func, re = functor(x) - re(map(f, func)) -end - -usecache(::AbstractDict, x) = isleaf(x) ? anymutable(x) : ismutable(x) -usecache(::Nothing, x) = false - -@generated function anymutable(x::T) where {T} - ismutabletype(T) && return true - subs = [:(anymutable(getfield(x, $f))) for f in QuoteNode.(fieldnames(T))] - return Expr(:(||), subs...) -end - -struct NoKeyword end - -function fmap(f, x; exclude = isleaf, walk = _default_walk, cache = anymutable(x) ? IdDict() : nothing, prune = NoKeyword()) - if usecache(cache, x) && haskey(cache, x) - return prune isa NoKeyword ? cache[x] : prune - end - ret = if exclude(x) - f(x) - else - walk(x -> fmap(f, x; exclude, walk, cache, prune), x) - end - if usecache(cache, x) - cache[x] = ret - end - ret -end - -### -### Extras -### - -fmapstructure(f, x; kwargs...) = fmap(f, x; walk = (f, x) -> map(f, children(x)), kwargs...) - -function fcollect(x; output = [], cache = Base.IdSet(), exclude = v -> false) - # note: we don't have an `OrderedIdSet`, so we use an `IdSet` for the cache - # (to ensure we get exactly 1 copy of each distinct array), and a usual `Vector` - # for the results, to preserve traversal order (important downstream!). - x in cache && return output - if !exclude(x) - push!(cache, x) - push!(output, x) - foreach(y -> fcollect(y; cache=cache, output=output, exclude=exclude), children(x)) - end - return output -end - -### -### Vararg forms -### - -function fmap(f, x, ys...; exclude = isleaf, walk = _default_walk, cache = anymutable(x) ? IdDict() : nothing, prune = NoKeyword()) - if usecache(cache, x) && haskey(cache, x) - return prune isa NoKeyword ? cache[x] : prune - end - ret = if exclude(x) - f(x, ys...) - else - walk((xy...,) -> fmap(f, xy...; exclude, walk, cache, prune), x, ys...) - end - if usecache(cache, x) - cache[x] = ret - end - ret -end - -function _default_walk(f, x, ys...) - func, re = functor(x) - yfuncs = map(y -> functor(typeof(x), y)[1], ys) - re(map(f, func, yfuncs...)) -end - ### ### FlexibleFunctors.jl ### diff --git a/src/maps.jl b/src/maps.jl index 49d04a6..6536928 100644 --- a/src/maps.jl +++ b/src/maps.jl @@ -4,7 +4,11 @@ function fmap(f, x, ys...; exclude = isleaf, walk = DefaultWalk(), cache = IdDict(), prune = NoKeyword()) - _walk = CachedWalk(ExcludeWalk(AnonymousWalk(walk), f, exclude), prune, cache) + _walk = if isnothing(cache) + ExcludeWalk(AnonymousWalk(walk), f, exclude) + else + CachedWalk(ExcludeWalk(AnonymousWalk(walk), f, exclude), prune, cache) + end fmap(_walk, f, x, ys...) end diff --git a/src/walks.jl b/src/walks.jl index d329e4b..63fbcbb 100644 --- a/src/walks.jl +++ b/src/walks.jl @@ -88,6 +88,15 @@ end struct NoKeyword end +usecache(::AbstractDict, x) = isleaf(x) ? anymutable(x) : ismutable(x) +usecache(::Nothing, x) = false + +@generated function anymutable(x::T) where {T} + ismutabletype(T) && return true + subs = [:(anymutable(getfield(x, $f))) for f in QuoteNode.(fieldnames(T))] + return Expr(:(||), subs...) +end + """ CachedWalk(walk[; prune]) @@ -109,11 +118,15 @@ CachedWalk(walk; prune = NoKeyword(), cache = IdDict()) = CachedWalk(walk, prune, cache) function (walk::CachedWalk)(recurse, x, ys...) - if haskey(walk.cache, x) + should_cache = usecache(walk.cache, x) + if should_cache && haskey(walk.cache, x) return walk.prune isa NoKeyword ? walk.cache[x] : walk.prune else - walk.cache[x] = walk.walk(recurse, x, ys...) - return walk.cache[x] + ret = walk.walk(recurse, x, ys...) + if should_cache + walk.cache[x] = ret + end + return ret end end