diff --git a/src/functor.jl b/src/functor.jl index 67f56ec..df9e31b 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -40,14 +40,24 @@ function _default_walk(f, x) end usecache(x) = !isbits(x) +usecache(x::Union{String, Symbol}) = false struct NoKeyword end function fmap(f, x; exclude = isleaf, walk = _default_walk, cache = usecache(x) ? IdDict() : nothing, prune = NoKeyword()) - usecache(x) && haskey(cache, x) && return prune isa NoKeyword ? cache[x] : prune - xnew = exclude(x) ? f(x) : walk(x -> fmap(f, x; exclude=exclude, walk=walk, cache=cache, prune=prune), x) - usecache(x) && setindex!(cache, xnew, x) - return xnew + if exclude(x) + if usecache(x) + if haskey(cache, x) + prune isa NoKeyword ? cache[x] : prune + else + cache[x] = f(x) + end + else + f(x) + end + else + walk(x -> fmap(f, x; exclude = exclude, walk = walk, cache = cache, prune = prune), x) + end end ### @@ -73,11 +83,20 @@ end ### Vararg forms ### -function fmap(f, x, ys...; exclude = isleaf, walk = _default_walk, cache = IdDict(), prune = NoKeyword()) - usecache(x) && haskey(cache, x) && return prune isa NoKeyword ? cache[x] : prune - xnew = exclude(x) ? f(x, ys...) : walk((xy...,) -> fmap(f, xy...; exclude=exclude, walk=walk, cache=cache, prune=prune), x, ys...) - usecache(x) && setindex!(cache, xnew, x) - return xnew +function fmap(f, x, ys...; exclude = isleaf, walk = _default_walk, cache = usecache(x) ? IdDict() : nothing, prune = NoKeyword()) + if exclude(x) + if usecache(x) + if haskey(cache, x) + prune isa NoKeyword ? cache[x] : prune + else + cache[x] = f(x, ys...) + end + else + f(x, ys...) + end + else + walk((xy...,) -> fmap(f, xy...; exclude = exclude, walk = walk, cache = cache, prune = prune), x, ys...) + end end function _default_walk(f, x, ys...) diff --git a/test/basics.jl b/test/basics.jl index 6ce8f94..c771098 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -14,6 +14,7 @@ struct NoChildren2; x; y; end struct NoChild{T}; x::T; end + ### ### Basic functionality ### @@ -56,13 +57,13 @@ end m1p = fmapstructure(identity, m1; prune = nothing) @test m1p == (x = [1, 2, 3], y = (x = [1, 2, 3], y = (x = nothing, y = [1, 2, 3]))) - # A non-leaf node can also be repeated: + # The cache applies only to leaf nodes, so that "4" is not shared: m2 = Foo(Foo(shared, 4), Foo(shared, 4)) @test m2.x === m2.y m2f = fmap(float, m2) @test m2f.x.x === m2f.y.x m2p = fmapstructure(identity, m2; prune = Bar(0)) - @test m2p == (x = (x = [1, 2, 3], y = 4), y = Bar(0)) + @test m2p == (x = (x = [1, 2, 3], y = 4), y = (x = Bar{Int64}(0), y = 4)) # Repeated isbits types should not automatically be regarded as shared: m3 = Foo(Foo(shared, 1:3), Foo(1:3, shared)) @@ -75,15 +76,18 @@ end @test_skip 0 == @allocated fmap(float, (x=1, y=(2, 3), z=4:5)) @testset "usecache" begin + # Leaf types: @test usecache([1,2]) - @test usecache(Ref(3)) - @test !usecache(4.0) + @test usecache(NoChild([1,2])) + @test !usecache(NoChild((3,4))) + + # Not leaf by default, but `exclude` can change that: + @test usecache(Ref(3)) @test !usecache((5, 6.0)) @test !usecache((a = 2pi, b = missing)) - @test usecache(Bar([1,2])) - @test !usecache(Bar((3,4))) + @test usecache((x = [1,2,3], y = 4)) end end