Skip to content

Commit 31cb97a

Browse files
committed
Rebase with FluxML#39
1 parent cb9a0f2 commit 31cb97a

File tree

3 files changed

+21
-79
lines changed

3 files changed

+21
-79
lines changed

src/functor.jl

-75
Original file line numberDiff line numberDiff line change
@@ -34,81 +34,6 @@ isleaf(x) = children(x) === ()
3434

3535
children(x) = functor(x)[1]
3636

37-
function _default_walk(f, x)
38-
func, re = functor(x)
39-
re(map(f, func))
40-
end
41-
42-
usecache(::AbstractDict, x) = isleaf(x) ? anymutable(x) : ismutable(x)
43-
usecache(::Nothing, x) = false
44-
45-
@generated function anymutable(x::T) where {T}
46-
ismutabletype(T) && return true
47-
subs = [:(anymutable(getfield(x, $f))) for f in QuoteNode.(fieldnames(T))]
48-
return Expr(:(||), subs...)
49-
end
50-
51-
struct NoKeyword end
52-
53-
function fmap(f, x; exclude = isleaf, walk = _default_walk, cache = anymutable(x) ? IdDict() : nothing, prune = NoKeyword())
54-
if usecache(cache, x) && haskey(cache, x)
55-
return prune isa NoKeyword ? cache[x] : prune
56-
end
57-
ret = if exclude(x)
58-
f(x)
59-
else
60-
walk(x -> fmap(f, x; exclude, walk, cache, prune), x)
61-
end
62-
if usecache(cache, x)
63-
cache[x] = ret
64-
end
65-
ret
66-
end
67-
68-
###
69-
### Extras
70-
###
71-
72-
fmapstructure(f, x; kwargs...) = fmap(f, x; walk = (f, x) -> map(f, children(x)), kwargs...)
73-
74-
function fcollect(x; output = [], cache = Base.IdSet(), exclude = v -> false)
75-
# note: we don't have an `OrderedIdSet`, so we use an `IdSet` for the cache
76-
# (to ensure we get exactly 1 copy of each distinct array), and a usual `Vector`
77-
# for the results, to preserve traversal order (important downstream!).
78-
x in cache && return output
79-
if !exclude(x)
80-
push!(cache, x)
81-
push!(output, x)
82-
foreach(y -> fcollect(y; cache=cache, output=output, exclude=exclude), children(x))
83-
end
84-
return output
85-
end
86-
87-
###
88-
### Vararg forms
89-
###
90-
91-
function fmap(f, x, ys...; exclude = isleaf, walk = _default_walk, cache = anymutable(x) ? IdDict() : nothing, prune = NoKeyword())
92-
if usecache(cache, x) && haskey(cache, x)
93-
return prune isa NoKeyword ? cache[x] : prune
94-
end
95-
ret = if exclude(x)
96-
f(x, ys...)
97-
else
98-
walk((xy...,) -> fmap(f, xy...; exclude, walk, cache, prune), x, ys...)
99-
end
100-
if usecache(cache, x)
101-
cache[x] = ret
102-
end
103-
ret
104-
end
105-
106-
function _default_walk(f, x, ys...)
107-
func, re = functor(x)
108-
yfuncs = map(y -> functor(typeof(x), y)[1], ys)
109-
re(map(f, func, yfuncs...))
110-
end
111-
11237
###
11338
### FlexibleFunctors.jl
11439
###

src/maps.jl

+5-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@ function fmap(f, x, ys...; exclude = isleaf,
44
walk = DefaultWalk(),
55
cache = IdDict(),
66
prune = NoKeyword())
7-
_walk = CachedWalk(ExcludeWalk(AnonymousWalk(walk), f, exclude), prune, cache)
7+
_walk = if isnothing(cache)
8+
ExcludeWalk(AnonymousWalk(walk), f, exclude)
9+
else
10+
CachedWalk(ExcludeWalk(AnonymousWalk(walk), f, exclude), prune, cache)
11+
end
812
fmap(_walk, f, x, ys...)
913
end
1014

src/walks.jl

+16-3
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,15 @@ end
8888

8989
struct NoKeyword end
9090

91+
usecache(::AbstractDict, x) = isleaf(x) ? anymutable(x) : ismutable(x)
92+
usecache(::Nothing, x) = false
93+
94+
@generated function anymutable(x::T) where {T}
95+
ismutabletype(T) && return true
96+
subs = [:(anymutable(getfield(x, $f))) for f in QuoteNode.(fieldnames(T))]
97+
return Expr(:(||), subs...)
98+
end
99+
91100
"""
92101
CachedWalk(walk[; prune])
93102
@@ -109,11 +118,15 @@ CachedWalk(walk; prune = NoKeyword(), cache = IdDict()) =
109118
CachedWalk(walk, prune, cache)
110119

111120
function (walk::CachedWalk)(recurse, x, ys...)
112-
if haskey(walk.cache, x)
121+
should_cache = usecache(walk.cache, x)
122+
if should_cache && haskey(walk.cache, x)
113123
return walk.prune isa NoKeyword ? walk.cache[x] : walk.prune
114124
else
115-
walk.cache[x] = walk.walk(recurse, x, ys...)
116-
return walk.cache[x]
125+
ret = walk.walk(recurse, x, ys...)
126+
if should_cache
127+
walk.cache[x] = ret
128+
end
129+
return ret
117130
end
118131
end
119132

0 commit comments

Comments
 (0)