Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extract common functionality into fold #32

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 63 additions & 24 deletions src/functor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ function makefunctor(m::Module, T, fs = fieldnames(T))
escargs = map(fieldnames(T)) do f
f in fs ? :(y[$(yᵢ += 1)]) : :(x.$f)
end
escfs = [:($f=x.$f) for f in fs]
escfs = [:($f = x.$f) for f in fs]

@eval m begin
$Functors.functor(::Type{<:$T}, x) = ($(escfs...),), y -> $T($(escargs...))
Expand Down Expand Up @@ -169,7 +169,54 @@ function _default_walk(f, x)
func, re = functor(x)
re(map(f, func))
end
_default_walk(f, ::Nothing, ::Nothing) = nothing
_default_walk(_, ::Nothing, ::Nothing) = nothing

# Side effects only, saves a restructure
function _foreach_walk(f, x)
foreach(f, children(x))
return x
end
mcabbott marked this conversation as resolved.
Show resolved Hide resolved

### WARNING: the following is unstable internal functionality. Use at your own risk!
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I considered _-prefixing everything but didn't for now because these will probably become part of the public API at some point. Happy to prefix them if that's desired though.

# Wrapper over an IdDict which only saves values with a stable object identity
struct Cache{K,V}
inner::IdDict{K,V}
end
Cache() = Cache(IdDict())

usecache(x) = !isbits(x) && ismutable(x)
# Functionally immutable and observe value semantics, but still `ismutable` and not `isbits`
usecache(::Union{String,Symbol}) = false
# For varargs
usecache(xs::Tuple) = all(usecache, xs)
Base.get!(f, c::Cache, x) = usecache(x) ? get!(f, c.inner, x) : f()

# Passthrough used to disable caching (e.g. when passing `cache=false`)
struct NoCache end
Base.get!(f, ::NoCache, _) = f()

# Encapsulates the self-recursive part of a recursive tree reduction (fold).
# This allows calling functions to remove any self-calls or nested callback closures.
struct Fold{F,L,C,W}
fn::F
isleaf::L
cache::C
walk::W
end
(fld::Fold)(x) = get!(fld.cache, x) do
fld.fn(fld.isleaf(x) ? x : fld.walk(fld, x))
end

# Convenience function for working with `Fold`
function fold(f, x; isleaf = isleaf, cache = false, walk = _default_walk)
if cache === true
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that fmap and fcollect are still passing through IdDict() by default. The hope is that new functions either don't allow customization of the cache or use a "safer" default.

cache = Cache()
elseif cache === false
cache = NoCache()
end
return Fold(f, isleaf, cache, walk)(x)
end
### end of unstable internal functionality

"""
fmap(f, x; exclude = Functors.isleaf, walk = Functors._default_walk)
Expand Down Expand Up @@ -253,11 +300,9 @@ Foo(Bar([1, 2, 3]), (40, 50, Bar(Foo(6, 7))))
```
"""
function fmap(f, x; exclude = isleaf, walk = _default_walk, cache = IdDict())
haskey(cache, x) && return cache[x]
y = exclude(x) ? f(x) : walk(x -> fmap(f, x, exclude = exclude, walk = walk, cache = cache), x)
cache[x] = y

return y
return fold(x; cache, walk, isleaf = exclude) do node
exclude(node) ? f(node) : node
end
end

"""
Expand Down Expand Up @@ -296,8 +341,8 @@ fmapstructure(f, x; kwargs...) = fmap(f, x; walk = (f, x) -> map(f, children(x))
fcollect(x; exclude = v -> false)

Traverse `x` by recursing each child of `x` as defined by [`functor`](@ref)
and collecting the results into a flat array, ordered by a breadth-first
traversal of `x`, respecting the iteration order of `children` calls.
and collecting the results into a flat array, ordered by a depth-first,
post-order traversal of `x` that respects the iteration order of `children` calls.
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is technically breaking (as can be seen from the test changes), but it's unclear to me how many users were relying on the breadth-first iteration order. The one example I found on JuliaHub, AlphaZero, was not. It's also odd that the original worked breadth-first when pretty much all other traversals based on functors are depth-first. So RFC, and happy to back this out if we consider it too breaking.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's fine to have this breaking change.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that nothing now in Optimisers.jl now cares about the order taken by fmap.

Of course the two halves of destructure care a lot and must match. I haven't completely understood how #31 would have to change; it currently uses fmap for one half and something hand-written for the other.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The order in fmap has always been stable, things brings fcollect in line as it was the odd one out.


Doesn't recurse inside branches rooted at nodes `v`
for which `exclude(v) == true`.
Expand All @@ -324,33 +369,27 @@ Foo(Bar([1, 2, 3]), NoChildren(:a, :b))

julia> fcollect(m)
4-element Vector{Any}:
Foo(Bar([1, 2, 3]), NoChildren(:a, :b))
Bar([1, 2, 3])
[1, 2, 3]
Bar([1, 2, 3])
NoChildren(:a, :b)
Foo(Bar([1, 2, 3]), NoChildren(:a, :b))

julia> fcollect(m, exclude = v -> v isa Bar)
2-element Vector{Any}:
Foo(Bar([1, 2, 3]), NoChildren(:a, :b))
NoChildren(:a, :b)
Foo(Bar([1, 2, 3]), NoChildren(:a, :b))

julia> fcollect(m, exclude = v -> Functors.isleaf(v))
2-element Vector{Any}:
Foo(Bar([1, 2, 3]), NoChildren(:a, :b))
Bar([1, 2, 3])
Foo(Bar([1, 2, 3]), NoChildren(:a, :b))
```
"""
function fcollect(x; output = [], cache = Base.IdSet(), exclude = v -> false)
# note: we don't have an `OrderedIdSet`, so we use an `IdSet` for the cache
# (to ensure we get exactly 1 copy of each distinct array), and a usual `Vector`
# for the results, to preserve traversal order (important downstream!).
x in cache && return output
if !exclude(x)
push!(cache, x)
push!(output, x)
foreach(y -> fcollect(y; cache=cache, output=output, exclude=exclude), children(x))
end
return output
function fcollect(x; output = [], cache = Base.IdDict(), exclude = v -> false)
fold(x; cache, isleaf = exclude, walk = _foreach_walk) do node
exclude(node) || push!(output, node); # always return nothing
end
return output
end

# Allow gradients and other constructs that match the structure of the functor
Expand Down
34 changes: 27 additions & 7 deletions test/basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,26 @@ end
end
end

@testset "Folds" begin
arrays = ntuple(i -> [i], 3)
model = Foo(
Foo(arrays[1], arrays[2]),
Foo(arrays[3], arrays[1])
)

total = Ref(0)
Functors.fmap(model, cache = true) do x
total[] += only(x)
end
@test total[] == 6

total = Ref(0)
Functors.fmap(model, cache = false) do x
total[] += only(x)
end
@test total[] == 7
end

@testset "Nested" begin
model = Bar(Foo(1, [1, 2, 3]))

Expand Down Expand Up @@ -72,21 +92,21 @@ end
m2 = 1
m3 = Foo(m1, m2)
m4 = Bar(m3)
@test all(fcollect(m4) .=== [m4, m3, m1, m2])
@test all(fcollect(m4, exclude = x -> x isa Array) .=== [m4, m3, m2])
@test all(fcollect(m4) .=== [m1, m2, m3, m4])
@test all(fcollect(m4, exclude = x -> x isa Array) .=== [m2, m3, m4])
@test all(fcollect(m4, exclude = x -> x isa Foo) .=== [m4])

m1 = [1, 2, 3]
m2 = Bar(m1)
m0 = NoChildren(:a, :b)
m3 = Foo(m2, m0)
m4 = Bar(m3)
@test all(fcollect(m4) .=== [m4, m3, m2, m1, m0])
@test all(fcollect(m4) .=== [m1, m2, m0, m3, m4])

m1 = [1, 2, 3]
m2 = [1, 2, 3]
m3 = Foo(m1, m2)
@test all(fcollect(m3) .=== [m3, m1, m2])
@test all(fcollect(m3) .=== [m1, m2, m3])
end

struct FFoo
Expand Down Expand Up @@ -143,14 +163,14 @@ end
m2 = [1, 2, 3]
m3 = FFoo(m1, m2, (:y, ))
m4 = FBar(m3, (:x,))
@test all(fcollect(m4) .=== [m4, m3, m2])
@test all(fcollect(m4, exclude = x -> x isa Array) .=== [m4, m3])
@test all(fcollect(m4) .=== [m2, m3, m4])
@test all(fcollect(m4, exclude = x -> x isa Array) .=== [m3, m4])
@test all(fcollect(m4, exclude = x -> x isa FFoo) .=== [m4])

m0 = NoChildren(:a, :b)
m1 = [1, 2, 3]
m2 = FBar(m1, ())
m3 = FFoo(m2, m0, (:x, :y,))
m4 = FBar(m3, (:x,))
@test all(fcollect(m4) .=== [m4, m3, m2, m0])
@test all(fcollect(m4) .=== [m2, m0, m3, m4])
end