diff --git a/Project.toml b/Project.toml index 653c550..1dba91c 100644 --- a/Project.toml +++ b/Project.toml @@ -12,8 +12,9 @@ julia = "1.6" [extras] Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Test", "Documenter", "Zygote"] +test = ["Test", "Documenter", "StaticArrays", "Zygote"] diff --git a/docs/src/api.md b/docs/src/api.md index 59a639e..5d1c200 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -9,6 +9,16 @@ Functors.children Functors.isleaf ``` +```@docs +Functors.AbstractWalk +Functors.DefaultWalk +Functors.StructuralWalk +Functors.ExcludeWalk +Functors.CachedWalk +Functors.CollectWalk +Functors.AnonymousWalk +``` + ```@docs Functors.fmapstructure Functors.fcollect diff --git a/src/Functors.jl b/src/Functors.jl index fe9b399..5d357ca 100644 --- a/src/Functors.jl +++ b/src/Functors.jl @@ -3,6 +3,8 @@ module Functors export @functor, @flexiblefunctor, fmap, fmapstructure, fcollect include("functor.jl") +include("walks.jl") +include("maps.jl") include("base.jl") ### @@ -102,7 +104,8 @@ Equivalent to `functor(x)[1]`. children """ - fmap(f, x; exclude = Functors.isleaf, walk = Functors._default_walk) + fmap(f, x, ys...; exclude = Functors.isleaf, walk = Functors.DefaultWalk()[, prune]) + fmap(walk, f, x, ys...) A structure and type preserving `map`. @@ -176,12 +179,23 @@ Foo("Bar([1, 2, 3])", (4, 5, "Bar(Foo(6, 7))")) To recurse into custom types without reconstructing them afterwards, use [`fmapstructure`](@ref). -For advanced customization of the traversal behaviour, pass a custom `walk` function of the form `(f', xs) -> ...`. -This function walks (maps) over `xs` calling the continuation `f'` to continue traversal. +For advanced customization of the traversal behaviour, +pass a custom `walk` whose type subtypes [`Functors.AbstractWalk`](@ref). 
+The form `fmap(walk, f, x, ys...)` can be called for custom walks. +The simpler form `fmap(f, x, ys...; walk = mywalk)` will wrap `mywalk` in +[`ExcludeWalk`](@ref) then [`CachedWalk`](@ref). ```jldoctest withfoo -julia> fmap(x -> 10x, m, walk=(f, x) -> x isa Bar ? x : Functors._default_walk(f, x)) -Foo(Bar([1, 2, 3]), (40, 50, Bar(Foo(6, 7)))) +julia> struct MyWalk <: Functors.AbstractWalk end + +julia> (::MyWalk)(recurse, x) = x isa Bar ? "hello" : + Functors.DefaultWalk()(recurse, x) + +julia> fmap(x -> 10x, m; walk = MyWalk()) +Foo("hello", (40, 50, "hello")) + +julia> fmap(MyWalk(), x -> 10x, m) +Foo("hello", (4, 5, "hello")) ``` The behaviour when the same node appears twice can be altered by giving a value diff --git a/src/functor.jl b/src/functor.jl index c89fd03..9c8cf78 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -34,81 +34,6 @@ isleaf(x) = children(x) === () children(x) = functor(x)[1] -function _default_walk(f, x) - func, re = functor(x) - re(map(f, func)) -end - -usecache(::AbstractDict, x) = isleaf(x) ? anymutable(x) : ismutable(x) -usecache(::Nothing, x) = false - -@generated function anymutable(x::T) where {T} - ismutabletype(T) && return true - subs = [:(anymutable(getfield(x, $f))) for f in QuoteNode.(fieldnames(T))] - return Expr(:(||), subs...) -end - -struct NoKeyword end - -function fmap(f, x; exclude = isleaf, walk = _default_walk, cache = anymutable(x) ? IdDict() : nothing, prune = NoKeyword()) - if usecache(cache, x) && haskey(cache, x) - return prune isa NoKeyword ? cache[x] : prune - end - ret = if exclude(x) - f(x) - else - walk(x -> fmap(f, x; exclude, walk, cache, prune), x) - end - if usecache(cache, x) - cache[x] = ret - end - ret -end - -### -### Extras -### - -fmapstructure(f, x; kwargs...) = fmap(f, x; walk = (f, x) -> map(f, children(x)), kwargs...) 
- -function fcollect(x; output = [], cache = Base.IdSet(), exclude = v -> false) - # note: we don't have an `OrderedIdSet`, so we use an `IdSet` for the cache - # (to ensure we get exactly 1 copy of each distinct array), and a usual `Vector` - # for the results, to preserve traversal order (important downstream!). - x in cache && return output - if !exclude(x) - push!(cache, x) - push!(output, x) - foreach(y -> fcollect(y; cache=cache, output=output, exclude=exclude), children(x)) - end - return output -end - -### -### Vararg forms -### - -function fmap(f, x, ys...; exclude = isleaf, walk = _default_walk, cache = anymutable(x) ? IdDict() : nothing, prune = NoKeyword()) - if usecache(cache, x) && haskey(cache, x) - return prune isa NoKeyword ? cache[x] : prune - end - ret = if exclude(x) - f(x, ys...) - else - walk((xy...,) -> fmap(f, xy...; exclude, walk, cache, prune), x, ys...) - end - if usecache(cache, x) - cache[x] = ret - end - ret -end - -function _default_walk(f, x, ys...) - func, re = functor(x) - yfuncs = map(y -> functor(typeof(x), y)[1], ys) - re(map(f, func, yfuncs...)) -end - ### ### FlexibleFunctors.jl ### diff --git a/src/maps.jl b/src/maps.jl new file mode 100644 index 0000000..e5e89e4 --- /dev/null +++ b/src/maps.jl @@ -0,0 +1,17 @@ +fmap(walk::AbstractWalk, f, x, ys...) = walk((xs...) -> fmap(walk, f, xs...), x, ys...) + +function fmap(f, x, ys...; exclude = isleaf, + walk = DefaultWalk(), + cache = IdDict(), + prune = NoKeyword()) + _walk = ExcludeWalk(AnonymousWalk(walk), f, exclude) + if !isnothing(cache) + _walk = CachedWalk(_walk, prune, cache) + end + fmap(_walk, f, x, ys...) +end + +fmapstructure(f, x; kwargs...) = fmap(f, x; walk = StructuralWalk(), kwargs...) 
+ +fcollect(x; exclude = v -> false) = + fmap(ExcludeWalk(CollectWalk(), _ -> nothing, exclude), _ -> nothing, x) diff --git a/src/walks.jl b/src/walks.jl new file mode 100644 index 0000000..79429f9 --- /dev/null +++ b/src/walks.jl @@ -0,0 +1,162 @@ +""" + AbstractWalk + +Any walk for use with [`fmap`](@ref) should inherit from this type. +A walk subtyping `AbstractWalk` must satisfy the walk function interface: +```julia +struct MyWalk <: AbstractWalk end + +function (::MyWalk)(recurse, x, ys...) + # implement this +end +``` +The walk function is called on a node `x` in a Functors tree. +It may also be passed associated nodes `ys...` in other Functors trees. +The walk function recurses further into `(x, ys...)` by calling +`recurse` on the child nodes. +The choice of which nodes to recurse and in what order is custom to the walk. +""" +abstract type AbstractWalk end + +""" + AnonymousWalk(walk_fn) + +Wrap a `walk_fn` so that `AnonymousWalk(walk_fn) isa AbstractWalk`. +This type only exists for backwards compatibility and should not be directly used. +Attempting to wrap an existing `AbstractWalk` is a no-op (i.e. it is not wrapped). +""" +struct AnonymousWalk{F} <: AbstractWalk + walk::F + + function AnonymousWalk(walk::F) where F + Base.depwarn("Wrapping a custom walk function as an `AnonymousWalk`. Future versions will only support custom walks that explicitly subtype `AbstractWalk`.", :AnonymousWalk) + return new{F}(walk) + end +end +# do not wrap an AbstractWalk +AnonymousWalk(walk::AbstractWalk) = walk + +(walk::AnonymousWalk)(recurse, x, ys...) = walk.walk(recurse, x, ys...) + +""" + DefaultWalk() + +The default walk behavior for Functors.jl. +Walks all the [`Functors.children`](@ref) of trees `(x, ys...)` based on +the structure of `x`. +The resulting mapped child nodes are restructured into the type of `x`. + +See [`fmap`](@ref) for more information. +""" +struct DefaultWalk <: AbstractWalk end + +function (::DefaultWalk)(recurse, x, ys...) 
+ func, re = functor(x) + yfuncs = map(y -> functor(typeof(x), y)[1], ys) + re(map(recurse, func, yfuncs...)) +end + +""" + StructuralWalk() + +A structural variant of [`Functors.DefaultWalk`](@ref). +The recursion behavior is identical, but the mapped children are not restructured. + +See [`fmapstructure`](@ref) for more information. +""" +struct StructuralWalk <: AbstractWalk end + +(::StructuralWalk)(recurse, x) = map(recurse, children(x)) + +""" + ExcludeWalk(walk, fn, exclude) + +A walk that recurses nodes `(x, ys...)` according to `walk`, +except when `exclude(x)` is true. +Then, `fn(x, ys...)` is applied instead of recursing further. + +Typically wraps an existing `walk` for use with [`fmap`](@ref). +""" +struct ExcludeWalk{T, F, G} <: AbstractWalk + walk::T + fn::F + exclude::G +end + +(walk::ExcludeWalk)(recurse, x, ys...) = + walk.exclude(x) ? walk.fn(x, ys...) : walk.walk(recurse, x, ys...) + +struct NoKeyword end + +usecache(::Union{AbstractDict, AbstractSet}, x) = + isleaf(x) ? anymutable(x) : ismutable(x) +usecache(::Nothing, x) = false + +@generated function anymutable(x::T) where {T} + ismutabletype(T) && return true + subs = [:(anymutable(getfield(x, $f))) for f in QuoteNode.(fieldnames(T))] + return Expr(:(||), subs...) +end + +""" + CachedWalk(walk[; prune]) + +A walk that recurses nodes `(x, ys...)` according to `walk` and stores the +output of the recursion in a cache indexed by `x` (based on object ID). +Whenever the cache already contains `x`, either: +- `prune` is specified, then it is returned, or +- `prune` is unspecified, and the previously cached recursion of `(x, ys...)` + is returned. + +Typically wraps an existing `walk` for use with [`fmap`](@ref). +""" +struct CachedWalk{T, S} <: AbstractWalk + walk::T + prune::S + cache::IdDict{Any, Any} +end +CachedWalk(walk; prune = NoKeyword(), cache = IdDict()) = + CachedWalk(walk, prune, cache) + +function (walk::CachedWalk)(recurse, x, ys...) 
+ should_cache = usecache(walk.cache, x) + if should_cache && haskey(walk.cache, x) + return walk.prune isa NoKeyword ? walk.cache[x] : walk.prune + else + ret = walk.walk(recurse, x, ys...) + if should_cache + walk.cache[x] = ret + end + return ret + end +end + +""" + CollectWalk() + +A walk that recurses into a node `x` via [`Functors.children`](@ref), +storing the recursion history in a cache. +The resulting ordered recursion history is returned. + +See [`fcollect`](@ref) for more information. +""" +struct CollectWalk <: AbstractWalk + cache::Base.IdSet{Any} + output::Vector{Any} +end +CollectWalk() = CollectWalk(Base.IdSet(), Any[]) + +# note: we don't have an `OrderedIdSet`, so we use an `IdSet` for the cache +# (to ensure we get exactly 1 copy of each distinct array), and a usual `Vector` +# for the results, to preserve traversal order (important downstream!). +function (walk::CollectWalk)(recurse, x) + if usecache(walk.cache, x) && (x in walk.cache) + return walk.output + end + # to exclude, we wrap this walk in ExcludeWalk + usecache(walk.cache, x) && push!(walk.cache, x) + push!(walk.output, x) + map(recurse, children(x)) + + return walk.output +end diff --git a/test/basics.jl b/test/basics.jl index 625e673..5aeba77 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -85,9 +85,7 @@ end @test m5f.x === m5f.y @test m5f.x !== m5f.z - @testset "usecache" begin - d = IdDict() - + @testset "usecache ($d)" for d in [IdDict(), Base.IdSet()] # Leaf types: @test usecache(d, [1,2]) @test !usecache(d, 4.0) @@ -101,9 +99,9 @@ end @test !usecache(d, (5, [6.0]')) # contains mutable @test !usecache(d, (x = [1,2,3], y = 4)) - + usecache(d, OneChild3([1,2], 3, nothing)) # mutable isn't a child, do we care? 
- + # No dictionary: @test !usecache(nothing, [1,2]) @test !usecache(nothing, 3) @@ -173,6 +171,14 @@ end m2 = [1, 2, 3] m3 = Foo(m1, m2) @test all(fcollect(m3) .=== [m3, m1, m2]) + + m1 = [1, 2, 3] + m2 = SVector{length(m1)}(m1) + m2′ = SVector{length(m1)}(m1) + m3 = Foo(m1, m1) + m4 = Foo(m2, m2′) + @test all(fcollect(m3) .=== [m3, m1]) + @test all(fcollect(m4) .=== [m4, m2, m2′]) end ### diff --git a/test/runtests.jl b/test/runtests.jl index 7bc7c88..86f1e6b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,7 @@ using Functors, Test using Zygote using LinearAlgebra +using StaticArrays @testset "Functors.jl" begin