Skip to content

Commit

Permalink
Use walk(x, ys...) for recursion and remove requirement for recursive
Browse files Browse the repository at this point in the history
closure
  • Loading branch information
gaurav-arya committed Jan 19, 2023
1 parent 0830691 commit 2d18f1b
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 29 deletions.
4 changes: 0 additions & 4 deletions src/Functors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ children

"""
fmap(f, x, ys...; exclude = Functors.isleaf, walk = Functors.DefaultWalk()[, prune])
fmap(walk, f, x, ys...)
A structure and type preserving `map`.
Expand Down Expand Up @@ -202,9 +201,6 @@ julia> (::MyWalk)(recurse, x) = x isa Bar ? "hello" :
julia> fmap(x -> 10x, m; walk = MyWalk())
Foo("hello", (40, 50, "hello"))
julia> fmap(MyWalk(), x -> 10x, m)
Foo("hello", (4, 5, "hello"))
```
The behaviour when the same node appears twice can be altered by giving a value
Expand Down
17 changes: 6 additions & 11 deletions src/maps.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,4 @@
# Note that the argument f is not actually used in this method.
# See issue #62 for a discussion on how best to remove it.
function fmap(walk::AbstractWalk, f, x, ys...)
# This avoids a performance penalty for recursive constructs in an anonymous function.
# See Julia issue #47760 and Functors.jl issue #59.
recurse(xs...) = walk(var"#self#", xs...)
walk(recurse, x, ys...)
end
Base.@deprecate fmap(walk::AbstractWalk, f, x, ys...) walk(x, ys...)

function fmap(f, x, ys...; exclude = isleaf,
walk = DefaultWalk(),
Expand All @@ -15,10 +8,12 @@ function fmap(f, x, ys...; exclude = isleaf,
if !isnothing(cache)
_walk = CachedWalk(_walk, prune, cache)
end
fmap(_walk, f, x, ys...)
_walk(x, ys...)
end

fmapstructure(f, x; kwargs...) = fmap(f, x; walk = StructuralWalk(), kwargs...)

fcollect(x; exclude = v -> false) =
fmap(ExcludeWalk(CollectWalk(), _ -> nothing, exclude), _ -> nothing, x)
function fcollect(x; exclude = v -> false)
walk = ExcludeWalk(CollectWalk(), _ -> nothing, exclude)
walk(x)
end
33 changes: 19 additions & 14 deletions src/walks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,23 @@ A walk subtyping `AbstractWalk` must satisfy the walk function interface:
```julia
struct MyWalk <: AbstractWalk end
function (::MyWalk)(recurse, x, ys...)
function (::MyWalk)(outer_walk::AbstractWalk, x, ys...)
# implement this
end
```
The walk function is called on a node `x` in a Functors tree.
It may also be passed associated nodes `ys...` in other Functors trees.
The walk function recurses further into `(x, ys...)` by calling
`recurse` on the child nodes.
`outer_walk` on the child nodes.
The choice of which nodes to recurse and in what order is custom to the walk.
By default, `outer_walk` it set to the walk being called,
i.e. `(walk::AbstractWalk)(x, ys...) = walk(walk, x, ys...)`,
but in general it allows for greater flexibility (e.g. nesting walks in one another).
"""
abstract type AbstractWalk end

(walk::AbstractWalk)(x, ys...) = walk(walk, x, ys...)

"""
AnonymousWalk(walk_fn)
Expand All @@ -42,7 +47,7 @@ end
# do not wrap an AbstractWalk
AnonymousWalk(walk::AbstractWalk) = walk

(walk::AnonymousWalk)(recurse, x, ys...) = walk.walk(recurse, x, ys...)
(walk::AnonymousWalk)(outer_walk::AbstractWalk, x, ys...) = walk.walk(outer_walk, x, ys...)

"""
DefaultWalk()
Expand All @@ -56,10 +61,10 @@ See [`fmap`](@ref) for more information.
"""
struct DefaultWalk <: AbstractWalk end

function (::DefaultWalk)(recurse, x, ys...)
function (::DefaultWalk)(outer_walk::AbstractWalk, x, ys...)
func, re = functor(x)
yfuncs = map(y -> functor(typeof(x), y)[1], ys)
re(_map(recurse, func, yfuncs...))
re(_map(outer_walk, func, yfuncs...))
end

"""
Expand All @@ -72,7 +77,7 @@ See [`fmapstructure`](@ref) for more information.
"""
struct StructuralWalk <: AbstractWalk end

(::StructuralWalk)(recurse, x) = _map(recurse, children(x))
(::StructuralWalk)(outer_walk::AbstractWalk, x) = _map(outer_walk, children(x))

"""
ExcludeWalk(walk, fn, exclude)
Expand All @@ -89,8 +94,8 @@ struct ExcludeWalk{T, F, G} <: AbstractWalk
exclude::G
end

(walk::ExcludeWalk)(recurse, x, ys...) =
walk.exclude(x) ? walk.fn(x, ys...) : walk.walk(recurse, x, ys...)
(walk::ExcludeWalk)(outer_walk::AbstractWalk, x, ys...) =
walk.exclude(x) ? walk.fn(x, ys...) : walk.walk(outer_walk, x, ys...)

struct NoKeyword end

Expand Down Expand Up @@ -124,12 +129,12 @@ end
CachedWalk(walk; prune = NoKeyword(), cache = IdDict()) =
CachedWalk(walk, prune, cache)

function (walk::CachedWalk)(recurse, x, ys...)
function (walk::CachedWalk)(outer_walk::AbstractWalk, x, ys...)
should_cache = usecache(walk.cache, x)
if should_cache && haskey(walk.cache, x)
return walk.prune isa NoKeyword ? walk.cache[x] : walk.prune
else
ret = walk.walk(recurse, x, ys...)
ret = walk.walk(outer_walk, x, ys...)
if should_cache
walk.cache[x] = ret
end
Expand All @@ -155,14 +160,14 @@ CollectWalk() = CollectWalk(Base.IdSet(), Any[])
# note: we don't have an `OrderedIdSet`, so we use an `IdSet` for the cache
# (to ensure we get exactly 1 copy of each distinct array), and a usual `Vector`
# for the results, to preserve traversal order (important downstream!).
function (walk::CollectWalk)(recurse, x)
function (walk::CollectWalk)(outer_walk::AbstractWalk, x)
if usecache(walk.cache, x) && (x in walk.cache)
return walk.output
end
# to exclude, we wrap this walk in ExcludeWalk
usecache(walk.cache, x) && push!(walk.cache, x)
push!(walk.output, x)
_map(recurse, children(x))
_map(outer_walk, children(x))

return walk.output
end
Expand Down Expand Up @@ -218,8 +223,8 @@ julia> collect(zipped_iter)
"""
struct IterateWalk <: AbstractWalk end

function (walk::IterateWalk)(recurse, x, ys...)
function (walk::IterateWalk)(outer_walk::AbstractWalk, x, ys...)
func, _ = functor(x)
yfuncs = map(y -> functor(typeof(x), y)[1], ys)
return Iterators.flatten(_map(recurse, func, yfuncs...))
return Iterators.flatten(_map(outer_walk, func, yfuncs...))
end
4 changes: 4 additions & 0 deletions test/basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -378,3 +378,7 @@ end
@test zipped_iter isa Iterators.Flatten
@test collect(zipped_iter) == collect(Iterators.zip([1, 2, 3, 4, 5, 6, 7, 8].^2, [8, 7, 6, 5, 4, 3, 2, 1].^2))
end

@testset "Deprecated first-arg walk API to fmap" begin
@test fmap(Functors.DefaultWalk(), nothing, (1, 2, 3)) == (1, 2, 3)
end

0 comments on commit 2d18f1b

Please sign in to comment.