diff --git a/src/functor.jl b/src/functor.jl index 267ad42..2956524 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -26,7 +26,7 @@ function makefunctor(m::Module, T, fs = fieldnames(T)) escargs = map(fieldnames(T)) do f f in fs ? :(y[$(yᵢ += 1)]) : :(x.$f) end - escfs = [:($f=x.$f) for f in fs] + escfs = [:($f = x.$f) for f in fs] @eval m begin $Functors.functor(::Type{<:$T}, x) = ($(escfs...),), y -> $T($(escargs...)) @@ -169,7 +169,54 @@ function _default_walk(f, x) func, re = functor(x) re(map(f, func)) end -_default_walk(f, ::Nothing, ::Nothing) = nothing +_default_walk(_, ::Nothing, ::Nothing) = nothing + +# Side effects only, saves a restructure +function _foreach_walk(f, x) + foreach(f, children(x)) + return x +end + +### WARNING: the following is unstable internal functionality. Use at your own risk! +# Wrapper over an IdDict which only saves values with a stable object identity +struct Cache{K,V} + inner::IdDict{K,V} +end +Cache() = Cache(IdDict()) + +usecache(x) = !isbits(x) && ismutable(x) +# Functionally immutable and observe value semantics, but still `ismutable` and not `isbits` +usecache(::Union{String,Symbol}) = false +# For varargs +usecache(xs::Tuple) = all(usecache, xs) +Base.get!(f, c::Cache, x) = usecache(x) ? get!(f, c.inner, x) : f() + +# Passthrough used to disable caching (e.g. when passing `cache=false`) +struct NoCache end +Base.get!(f, ::NoCache, _) = f() + +# Encapsulates the self-recursive part of a recursive tree reduction (fold). +# This allows calling functions to remove any self-calls or nested callback closures. +struct Fold{F,L,C,W} + fn::F + isleaf::L + cache::C + walk::W +end +(fld::Fold)(x) = get!(fld.cache, x) do + fld.fn(fld.isleaf(x) ? x : fld.walk(fld, x)) +end + +# Convenience function for working with `Fold` +function fold(f, x; isleaf = isleaf, cache = false, walk = _default_walk) + if cache === true + cache = Cache() + elseif cache === false + cache = NoCache() + end + return Fold(f, isleaf, cache, walk)(x) +end +### end of unstable internal functionality """ fmap(f, x; exclude = Functors.isleaf, walk = Functors._default_walk) @@ -253,11 +300,9 @@ Foo(Bar([1, 2, 3]), (40, 50, Bar(Foo(6, 7)))) ``` """ function fmap(f, x; exclude = isleaf, walk = _default_walk, cache = IdDict()) - haskey(cache, x) && return cache[x] - y = exclude(x) ? f(x) : walk(x -> fmap(f, x, exclude = exclude, walk = walk, cache = cache), x) - cache[x] = y - - return y + return fold(x; cache, walk, isleaf = exclude) do node + exclude(node) ? f(node) : node + end end """ @@ -296,8 +341,8 @@ fmapstructure(f, x; kwargs...) = fmap(f, x; walk = (f, x) -> map(f, children(x)) fcollect(x; exclude = v -> false) Traverse `x` by recursing each child of `x` as defined by [`functor`](@ref) -and collecting the results into a flat array, ordered by a breadth-first -traversal of `x`, respecting the iteration order of `children` calls. +and collecting the results into a flat array, ordered by a depth-first, +post-order traversal of `x` that respects the iteration order of `children` calls. Doesn't recurse inside branches rooted at nodes `v` for which `exclude(v) == true`. @@ -324,33 +369,27 @@ Foo(Bar([1, 2, 3]), NoChildren(:a, :b)) julia> fcollect(m) 4-element Vector{Any}: - Foo(Bar([1, 2, 3]), NoChildren(:a, :b)) - Bar([1, 2, 3]) [1, 2, 3] + Bar([1, 2, 3]) NoChildren(:a, :b) + Foo(Bar([1, 2, 3]), NoChildren(:a, :b)) julia> fcollect(m, exclude = v -> v isa Bar) 2-element Vector{Any}: - Foo(Bar([1, 2, 3]), NoChildren(:a, :b)) NoChildren(:a, :b) + Foo(Bar([1, 2, 3]), NoChildren(:a, :b)) julia> fcollect(m, exclude = v -> Functors.isleaf(v)) 2-element Vector{Any}: - Foo(Bar([1, 2, 3]), NoChildren(:a, :b)) Bar([1, 2, 3]) + Foo(Bar([1, 2, 3]), NoChildren(:a, :b)) ``` """ -function fcollect(x; output = [], cache = Base.IdSet(), exclude = v -> false) - # note: we don't have an `OrderedIdSet`, so we use an `IdSet` for the cache - # (to ensure we get exactly 1 copy of each distinct array), and a usual `Vector` - # for the results, to preserve traversal order (important downstream!). - x in cache && return output - if !exclude(x) - push!(cache, x) - push!(output, x) - foreach(y -> fcollect(y; cache=cache, output=output, exclude=exclude), children(x)) - end - return output +function fcollect(x; output = [], cache = Base.IdDict(), exclude = v -> false) + fold(x; cache, isleaf = exclude, walk = _foreach_walk) do node + exclude(node) || push!(output, node); # always return nothing + end + return output end # Allow gradients and other constructs that match the structure of the functor diff --git a/test/basics.jl b/test/basics.jl index 21aa445..ad4e06a 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -31,6 +31,26 @@ end end end +@testset "Folds" begin + arrays = ntuple(i -> [i], 3) + model = Foo( + Foo(arrays[1], arrays[2]), + Foo(arrays[3], arrays[1]) + ) + + total = Ref(0) + Functors.fmap(model, cache = true) do x + total[] += only(x) + end + @test total[] == 6 + + total = Ref(0) + Functors.fmap(model, cache = false) do x + total[] += only(x) + end + @test total[] == 7 +end + @testset "Nested" begin model = Bar(Foo(1, [1, 2, 3])) @@ -72,8 +92,8 @@ end m2 = 1 m3 = Foo(m1, m2) m4 = Bar(m3) - @test all(fcollect(m4) .=== [m4, m3, m1, m2]) - @test all(fcollect(m4, exclude = x -> x isa Array) .=== [m4, m3, m2]) + @test all(fcollect(m4) .=== [m1, m2, m3, m4]) + @test all(fcollect(m4, exclude = x -> x isa Array) .=== [m2, m3, m4]) @test all(fcollect(m4, exclude = x -> x isa Foo) .=== [m4]) m1 = [1, 2, 3] @@ -81,12 +101,12 @@ end m0 = NoChildren(:a, :b) m3 = Foo(m2, m0) m4 = Bar(m3) - @test all(fcollect(m4) .=== [m4, m3, m2, m1, m0]) + @test all(fcollect(m4) .=== [m1, m2, m0, m3, m4]) m1 = [1, 2, 3] m2 = [1, 2, 3] m3 = Foo(m1, m2) - @test all(fcollect(m3) .=== [m3, m1, m2]) + @test all(fcollect(m3) .=== [m1, m2, m3]) end struct FFoo @@ -143,8 +163,8 @@ end m2 = [1, 2, 3] m3 = FFoo(m1, m2, (:y, )) m4 = FBar(m3, (:x,)) - @test all(fcollect(m4) .=== [m4, m3, m2]) - @test all(fcollect(m4, exclude = x -> x isa Array) .=== [m4, m3]) + @test all(fcollect(m4) .=== [m2, m3, m4]) + @test all(fcollect(m4, exclude = x -> x isa Array) .=== [m3, m4]) @test all(fcollect(m4, exclude = x -> x isa FFoo) .=== [m4]) m0 = NoChildren(:a, :b) @@ -152,5 +172,5 @@ end m2 = FBar(m1, ()) m3 = FFoo(m2, m0, (:x, :y,)) m4 = FBar(m3, (:x,)) - @test all(fcollect(m4) .=== [m4, m3, m2, m0]) + @test all(fcollect(m4) .=== [m2, m0, m3, m4]) end