From e051e217c2a578c30951a251b36cac5df2e69f33 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Mon, 31 Oct 2022 11:19:57 -0500 Subject: [PATCH 1/7] Separate walks out from fmap --- src/Functors.jl | 2 ++ src/maps.jl | 14 ++++++++++++ src/walks.jl | 60 +++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 76 insertions(+) create mode 100644 src/maps.jl create mode 100644 src/walks.jl diff --git a/src/Functors.jl b/src/Functors.jl index fe9b399..c76e2ea 100644 --- a/src/Functors.jl +++ b/src/Functors.jl @@ -3,6 +3,8 @@ module Functors export @functor, @flexiblefunctor, fmap, fmapstructure, fcollect include("functor.jl") +include("walks.jl") +include("maps.jl") include("base.jl") ### diff --git a/src/maps.jl b/src/maps.jl new file mode 100644 index 0000000..63bf1e0 --- /dev/null +++ b/src/maps.jl @@ -0,0 +1,14 @@ +fmap(walk::AbstractWalk, f, x, ys...) = walk((xs...) -> fmap(walk, f, xs...), x, ys...) + +function fmap(f, x, ys...; exclude = isleaf, + walk = DefaultWalk(), + cache = IdDict(), + prune = NoKeyword()) + _walk = CachedWalk(ExcludeWalk(walk, f, exclude), prune, cache) + fmap(_walk, f, x, ys...) +end + +fmapstructure(f, x; kwargs...) = fmap(f, x; walk = StructuralWalk(), kwargs...) + +fcollect(x; exclude = v -> false) = + fmap(ExcludeWalk(CollectWalk(), _ -> nothing, exclude), _ -> nothing, x) diff --git a/src/walks.jl b/src/walks.jl new file mode 100644 index 0000000..575f6d8 --- /dev/null +++ b/src/walks.jl @@ -0,0 +1,60 @@ +abstract type AbstractWalk end + +struct DefaultWalk <: AbstractWalk end + +function (::DefaultWalk)(recurse, x, ys...) + func, re = functor(x) + yfuncs = map(y -> functor(typeof(x), y)[1], ys) + re(map(recurse, func, yfuncs...)) +end + +struct StructuralWalk <: AbstractWalk end + +(::StructuralWalk)(recurse, x) = map(recurse, children(x)) + +struct ExcludeWalk{T, F, G} <: AbstractWalk + walk::T + fn::F + exclude::G +end + +(walk::ExcludeWalk)(recurse, x, ys...) = + walk.exclude(x) ? walk.fn(x, ys...) 
: walk.walk(recurse, x, ys...) + +struct NoKeyword end + +struct CachedWalk{T, S} <: AbstractWalk + walk::T + prune::S + cache::IdDict{Any, Any} +end +CachedWalk(walk; prune = NoKeyword(), cache = IdDict()) = + CachedWalk(walk, prune, cache) + +function (walk::CachedWalk)(recurse, x, ys...) + if haskey(walk.cache, x) + return walk.prune isa NoKeyword ? walk.cache[x] : walk.prune + else + walk.cache[x] = walk.walk(recurse, x, ys...) + return walk.cache[x] + end +end + +struct CollectWalk <: AbstractWalk + cache::Base.IdSet{Any} + output::Vector{Any} +end +CollectWalk() = CollectWalk(Base.IdSet(), Any[]) + +# note: we don't have an `OrderedIdSet`, so we use an `IdSet` for the cache +# (to ensure we get exactly 1 copy of each distinct array), and a usual `Vector` +# for the results, to preserve traversal order (important downstream!). +function (walk::CollectWalk)(recurse, x) + x in walk.cache && return walk.output + # to exclude, we wrap this walk in ExcludeWalk + push!(walk.cache, x) + push!(walk.output, x) + map(recurse, children(x)) + + return walk.output +end From 3f11d5a2c257b4a1133b0b2b9d37b41c87c30f94 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Tue, 11 Oct 2022 13:45:41 -0400 Subject: [PATCH 2/7] Add AnonymousWalk and update docstrings --- docs/src/api.md | 9 ++++++ src/Functors.jl | 22 +++++++++++--- src/maps.jl | 2 +- src/walks.jl | 81 +++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 108 insertions(+), 6 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 59a639e..7b6f11a 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -9,6 +9,15 @@ Functors.children Functors.isleaf ``` +```@docs +Functors.AbstractWalk +Functors.DefaultWalk +Functors.StructuralWalk +Functors.ExcludeWalk +Functors.CachedWalk +Functors.CollectWalk +``` + ```@docs Functors.fmapstructure Functors.fcollect diff --git a/src/Functors.jl b/src/Functors.jl index c76e2ea..5d357ca 100644 --- a/src/Functors.jl +++ b/src/Functors.jl @@ -104,7 +104,8 @@ 
Equivalent to `functor(x)[1]`. children """ - fmap(f, x; exclude = Functors.isleaf, walk = Functors._default_walk) + fmap(f, x, ys...; exclude = Functors.isleaf, walk = Functors.DefaultWalk()[, prune]) + fmap(walk, f, x, ys...) A structure and type preserving `map`. @@ -178,12 +179,23 @@ Foo("Bar([1, 2, 3])", (4, 5, "Bar(Foo(6, 7))")) To recurse into custom types without reconstructing them afterwards, use [`fmapstructure`](@ref). -For advanced customization of the traversal behaviour, pass a custom `walk` function of the form `(f', xs) -> ...`. -This function walks (maps) over `xs` calling the continuation `f'` to continue traversal. +For advanced customization of the traversal behaviour, +pass a custom `walk` function that subtypes [`Functors.AbstractWalk`](@ref). +The form `fmap(walk, f, x, ys...)` can be called for custom walks. +The simpler form `fmap(f, x, ys...; walk = mywalk)` will wrap `mywalk` in +[`ExcludeWalk`](@ref) then [`CachedWalk`](@ref). ```jldoctest withfoo -julia> fmap(x -> 10x, m, walk=(f, x) -> x isa Bar ? x : Functors._default_walk(f, x)) -Foo(Bar([1, 2, 3]), (40, 50, Bar(Foo(6, 7)))) +julia> struct MyWalk <: Functors.AbstractWalk end + +julia> (::MyWalk)(recurse, x) = x isa Bar ? "hello" : + Functors.DefaultWalk()(recurse, x) + +julia> fmap(x -> 10x, m; walk = MyWalk()) +Foo("hello", (40, 50, "hello")) + +julia> fmap(MyWalk(), x -> 10x, m) +Foo("hello", (4, 5, "hello")) ``` The behaviour when the same node appears twice can be altered by giving a value diff --git a/src/maps.jl b/src/maps.jl index 63bf1e0..49d04a6 100644 --- a/src/maps.jl +++ b/src/maps.jl @@ -4,7 +4,7 @@ function fmap(f, x, ys...; exclude = isleaf, walk = DefaultWalk(), cache = IdDict(), prune = NoKeyword()) - _walk = CachedWalk(ExcludeWalk(walk, f, exclude), prune, cache) + _walk = CachedWalk(ExcludeWalk(AnonymousWalk(walk), f, exclude), prune, cache) fmap(_walk, f, x, ys...) 
end diff --git a/src/walks.jl b/src/walks.jl index 575f6d8..99f093f 100644 --- a/src/walks.jl +++ b/src/walks.jl @@ -1,5 +1,48 @@ +""" + AbstractWalk + +Any walk for use with [`fmap`](@ref) should inherit from this type. +A walk subtyping `AbstractWalk` must satisfy the walk function interface: +```julia +struct MyWalk <: AbstractWalk end + +function (::MyWalk)(recurse, x, ys...) + # implement this +end +``` +The walk function is called on a node `x` in a Functors tree. +It may also be passed associated nodes `ys...` in other Functors trees. +The walk function recurses further into `(x, ys...)` by calling +`recurse` on the child nodes. +The choice of which nodes to recurse and in what order is custom to the walk. +""" abstract type AbstractWalk end +""" + AnonymousWalk(walk_fn) + +Wrap a `walk_fn` so that `AnonymousWalk(walk_fn) isa AbstractWalk`. +This type only exists for backwards compatibility and should not be directly used. +Attempting to wrap an existing `AbstractWalk` is a no-op (i.e. it is not wrapped). +""" +struct AnonymousWalk{F} <: AbstractWalk + walk::F +end +# do not wrap an AbstractWalk +AnonymousWalk(walk::AbstractWalk) = walk + +(walk::AnonymousWalk)(recurse, x, ys...) = walk.walk(recurse, x, ys...) + +""" + DefaultWalk() + +The default walk behavior for Functors.jl. +Walks all the [`Functors.children`](@ref) of trees `(x, ys...)` based on +the structure of `x`. +The resulting mapped child nodes are restructured into the type of `x`. + +See [`fmap`](@ref) for more information. +""" struct DefaultWalk <: AbstractWalk end function (::DefaultWalk)(recurse, x, ys...) @@ -8,10 +51,27 @@ function (::DefaultWalk)(recurse, x, ys...) re(map(recurse, func, yfuncs...)) end +""" + StructuralWalk() + +A structural variant of [`Functors.DefaultWalk`](@ref). +The recursion behavior is identical, but the mapped children are not restructured. + +See [`fmapstructure`](@ref) for more information. 
+""" struct StructuralWalk <: AbstractWalk end (::StructuralWalk)(recurse, x) = map(recurse, children(x)) +""" + ExcludeWalk(walk, fn, exclude) + +A walk that recurses nodes `(x, ys...)` according to `walk`, +except when `exclude(x)` is true. +Then, `fn(x, ys...)` is applied instead of recursing further. + +Typically wraps an existing `walk` for use with [`fmap`](@ref). +""" struct ExcludeWalk{T, F, G} <: AbstractWalk walk::T fn::F @@ -23,6 +83,18 @@ end struct NoKeyword end +""" + CachedWalk(walk[; prune]) + +A walk that recurses nodes `(x, ys...)` according to `walk` and stores the +output of the recursion in a cache indexed by `x` (based on object ID). +Whenever the cache already contains `x`, either: +- `prune` is specified, then it is returned, or +- `prune` is unspecified, and the previously cached recursion of `(x, ys...)` + is returned. + +Typically wraps an existing `walk` for use with [`fmap`](@ref). +""" struct CachedWalk{T, S} <: AbstractWalk walk::T prune::S @@ -40,6 +112,15 @@ function (walk::CachedWalk)(recurse, x, ys...) end end +""" + CollectWalk() + +A walk that recurses into a node `x` via [`Functors.children`](@ref), +storing the recursion history in a cache. +The resulting ordered recursion history is returned. + +See [`fcollect`](@ref) for more information. +""" struct CollectWalk <: AbstractWalk cache::Base.IdSet{Any} output::Vector{Any} From 27a8f8c5527b91f98c63b72aece67c5f8f111150 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Tue, 11 Oct 2022 13:53:45 -0400 Subject: [PATCH 3/7] Add depwarn for AnonymousWalk --- src/walks.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/walks.jl b/src/walks.jl index 99f093f..d329e4b 100644 --- a/src/walks.jl +++ b/src/walks.jl @@ -27,6 +27,11 @@ Attempting to wrap an existing `AbstractWalk` is a no-op (i.e. it is not wrapped """ struct AnonymousWalk{F} <: AbstractWalk walk::F + + function AnonymousWalk(walk::F) where F + Base.depwarn("Wrapping a custom walk function as an `AnonymousWalk`. 
Future versions will only support custom walks that explicitly subtyle `AbstractWalk`.", :AnonymousWalk) + return new{F}(walk) + end end # do not wrap an AbstractWalk AnonymousWalk(walk::AbstractWalk) = walk From cb9a0f223c7e764774df6dbb925861c63c5d4fb7 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Tue, 11 Oct 2022 13:58:26 -0400 Subject: [PATCH 4/7] Add AnonymousWalk docstring to docs --- docs/src/api.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/src/api.md b/docs/src/api.md index 7b6f11a..5d1c200 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -16,6 +16,7 @@ Functors.StructuralWalk Functors.ExcludeWalk Functors.CachedWalk Functors.CollectWalk +Functors.AnonymousWalk ``` ```@docs From 31cb97a8c18508cd3259d6b2870ba2f8d7ea57df Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Mon, 31 Oct 2022 11:25:09 -0500 Subject: [PATCH 5/7] Rebase with #39 --- src/functor.jl | 75 -------------------------------------------------- src/maps.jl | 6 +++- src/walks.jl | 19 +++++++++++-- 3 files changed, 21 insertions(+), 79 deletions(-) diff --git a/src/functor.jl b/src/functor.jl index c89fd03..9c8cf78 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -34,81 +34,6 @@ isleaf(x) = children(x) === () children(x) = functor(x)[1] -function _default_walk(f, x) - func, re = functor(x) - re(map(f, func)) -end - -usecache(::AbstractDict, x) = isleaf(x) ? anymutable(x) : ismutable(x) -usecache(::Nothing, x) = false - -@generated function anymutable(x::T) where {T} - ismutabletype(T) && return true - subs = [:(anymutable(getfield(x, $f))) for f in QuoteNode.(fieldnames(T))] - return Expr(:(||), subs...) -end - -struct NoKeyword end - -function fmap(f, x; exclude = isleaf, walk = _default_walk, cache = anymutable(x) ? IdDict() : nothing, prune = NoKeyword()) - if usecache(cache, x) && haskey(cache, x) - return prune isa NoKeyword ? 
cache[x] : prune - end - ret = if exclude(x) - f(x) - else - walk(x -> fmap(f, x; exclude, walk, cache, prune), x) - end - if usecache(cache, x) - cache[x] = ret - end - ret -end - -### -### Extras -### - -fmapstructure(f, x; kwargs...) = fmap(f, x; walk = (f, x) -> map(f, children(x)), kwargs...) - -function fcollect(x; output = [], cache = Base.IdSet(), exclude = v -> false) - # note: we don't have an `OrderedIdSet`, so we use an `IdSet` for the cache - # (to ensure we get exactly 1 copy of each distinct array), and a usual `Vector` - # for the results, to preserve traversal order (important downstream!). - x in cache && return output - if !exclude(x) - push!(cache, x) - push!(output, x) - foreach(y -> fcollect(y; cache=cache, output=output, exclude=exclude), children(x)) - end - return output -end - -### -### Vararg forms -### - -function fmap(f, x, ys...; exclude = isleaf, walk = _default_walk, cache = anymutable(x) ? IdDict() : nothing, prune = NoKeyword()) - if usecache(cache, x) && haskey(cache, x) - return prune isa NoKeyword ? cache[x] : prune - end - ret = if exclude(x) - f(x, ys...) - else - walk((xy...,) -> fmap(f, xy...; exclude, walk, cache, prune), x, ys...) - end - if usecache(cache, x) - cache[x] = ret - end - ret -end - -function _default_walk(f, x, ys...) - func, re = functor(x) - yfuncs = map(y -> functor(typeof(x), y)[1], ys) - re(map(f, func, yfuncs...)) -end - ### ### FlexibleFunctors.jl ### diff --git a/src/maps.jl b/src/maps.jl index 49d04a6..6536928 100644 --- a/src/maps.jl +++ b/src/maps.jl @@ -4,7 +4,11 @@ function fmap(f, x, ys...; exclude = isleaf, walk = DefaultWalk(), cache = IdDict(), prune = NoKeyword()) - _walk = CachedWalk(ExcludeWalk(AnonymousWalk(walk), f, exclude), prune, cache) + _walk = if isnothing(cache) + ExcludeWalk(AnonymousWalk(walk), f, exclude) + else + CachedWalk(ExcludeWalk(AnonymousWalk(walk), f, exclude), prune, cache) + end fmap(_walk, f, x, ys...) 
end diff --git a/src/walks.jl b/src/walks.jl index d329e4b..63fbcbb 100644 --- a/src/walks.jl +++ b/src/walks.jl @@ -88,6 +88,15 @@ end struct NoKeyword end +usecache(::AbstractDict, x) = isleaf(x) ? anymutable(x) : ismutable(x) +usecache(::Nothing, x) = false + +@generated function anymutable(x::T) where {T} + ismutabletype(T) && return true + subs = [:(anymutable(getfield(x, $f))) for f in QuoteNode.(fieldnames(T))] + return Expr(:(||), subs...) +end + """ CachedWalk(walk[; prune]) @@ -109,11 +118,15 @@ CachedWalk(walk; prune = NoKeyword(), cache = IdDict()) = CachedWalk(walk, prune, cache) function (walk::CachedWalk)(recurse, x, ys...) - if haskey(walk.cache, x) + should_cache = usecache(walk.cache, x) + if should_cache && haskey(walk.cache, x) return walk.prune isa NoKeyword ? walk.cache[x] : walk.prune else - walk.cache[x] = walk.walk(recurse, x, ys...) - return walk.cache[x] + ret = walk.walk(recurse, x, ys...) + if should_cache + walk.cache[x] = ret + end + return ret end end From e402dad01c8242f63a24d834b7d7c0de96d5baee Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Thu, 3 Nov 2022 09:56:24 -0500 Subject: [PATCH 6/7] Apply suggestions from code review Co-authored-by: Brian Chen --- src/maps.jl | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/maps.jl b/src/maps.jl index 6536928..e5e89e4 100644 --- a/src/maps.jl +++ b/src/maps.jl @@ -4,10 +4,9 @@ function fmap(f, x, ys...; exclude = isleaf, walk = DefaultWalk(), cache = IdDict(), prune = NoKeyword()) - _walk = if isnothing(cache) - ExcludeWalk(AnonymousWalk(walk), f, exclude) - else - CachedWalk(ExcludeWalk(AnonymousWalk(walk), f, exclude), prune, cache) + _walk = ExcludeWalk(AnonymousWalk(walk), f, exclude) + if !isnothing(cache) + _walk = CachedWalk(_walk, prune, cache) end fmap(_walk, f, x, ys...) 
end From ffd6ae980c23b58b25586ed15d8f546011571c04 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Thu, 3 Nov 2022 10:26:23 -0500 Subject: [PATCH 7/7] Add new caching behavior to fcollect --- Project.toml | 3 ++- src/walks.jl | 9 ++++++--- test/basics.jl | 16 +++++++++++----- test/runtests.jl | 1 + 4 files changed, 20 insertions(+), 9 deletions(-) diff --git a/Project.toml b/Project.toml index 653c550..1dba91c 100644 --- a/Project.toml +++ b/Project.toml @@ -12,8 +12,9 @@ julia = "1.6" [extras] Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Test", "Documenter", "Zygote"] +test = ["Test", "Documenter", "StaticArrays", "Zygote"] diff --git a/src/walks.jl b/src/walks.jl index 63fbcbb..79429f9 100644 --- a/src/walks.jl +++ b/src/walks.jl @@ -88,7 +88,8 @@ end struct NoKeyword end -usecache(::AbstractDict, x) = isleaf(x) ? anymutable(x) : ismutable(x) +usecache(::Union{AbstractDict, AbstractSet}, x) = + isleaf(x) ? anymutable(x) : ismutable(x) usecache(::Nothing, x) = false @generated function anymutable(x::T) where {T} @@ -149,9 +150,11 @@ CollectWalk() = CollectWalk(Base.IdSet(), Any[]) # (to ensure we get exactly 1 copy of each distinct array), and a usual `Vector` # for the results, to preserve traversal order (important downstream!). 
function (walk::CollectWalk)(recurse, x) - x in walk.cache && return walk.output + if usecache(walk.cache, x) && (x in walk.cache) + return walk.output + end # to exclude, we wrap this walk in ExcludeWalk - push!(walk.cache, x) + usecache(walk.cache, x) && push!(walk.cache, x) push!(walk.output, x) map(recurse, children(x)) diff --git a/test/basics.jl b/test/basics.jl index 625e673..5aeba77 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -85,9 +85,7 @@ end @test m5f.x === m5f.y @test m5f.x !== m5f.z - @testset "usecache" begin - d = IdDict() - + @testset "usecache ($d)" for d in [IdDict(), Base.IdSet()] # Leaf types: @test usecache(d, [1,2]) @test !usecache(d, 4.0) @@ -101,9 +99,9 @@ end @test !usecache(d, (5, [6.0]')) # contains mutable @test !usecache(d, (x = [1,2,3], y = 4)) - + usecache(d, OneChild3([1,2], 3, nothing)) # mutable isn't a child, do we care? - + # No dictionary: @test !usecache(nothing, [1,2]) @test !usecache(nothing, 3) @@ -173,6 +171,14 @@ end m2 = [1, 2, 3] m3 = Foo(m1, m2) @test all(fcollect(m3) .=== [m3, m1, m2]) + + m1 = [1, 2, 3] + m2 = SVector{length(m1)}(m1) + m2′ = SVector{length(m1)}(m1) + m3 = Foo(m1, m1) + m4 = Foo(m2, m2′) + @test all(fcollect(m3) .=== [m3, m1]) + @test all(fcollect(m4) .=== [m4, m2, m2′]) end ### diff --git a/test/runtests.jl b/test/runtests.jl index 7bc7c88..86f1e6b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,7 @@ using Functors, Test using Zygote using LinearAlgebra +using StaticArrays @testset "Functors.jl" begin