Skip to content

Commit

Permalink
Rebase with FluxML#39
Browse files Browse the repository at this point in the history
  • Loading branch information
darsnack committed Oct 31, 2022
1 parent cb9a0f2 commit f788256
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 78 deletions.
75 changes: 0 additions & 75 deletions src/functor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,81 +34,6 @@ isleaf(x) = children(x) === ()

children(x) = functor(x)[1]

function _default_walk(f, x)
func, re = functor(x)
re(map(f, func))
end

usecache(::AbstractDict, x) = isleaf(x) ? anymutable(x) : ismutable(x)
usecache(::Nothing, x) = false

@generated function anymutable(x::T) where {T}
ismutabletype(T) && return true
subs = [:(anymutable(getfield(x, $f))) for f in QuoteNode.(fieldnames(T))]
return Expr(:(||), subs...)
end

struct NoKeyword end

function fmap(f, x; exclude = isleaf, walk = _default_walk, cache = anymutable(x) ? IdDict() : nothing, prune = NoKeyword())
if usecache(cache, x) && haskey(cache, x)
return prune isa NoKeyword ? cache[x] : prune
end
ret = if exclude(x)
f(x)
else
walk(x -> fmap(f, x; exclude, walk, cache, prune), x)
end
if usecache(cache, x)
cache[x] = ret
end
ret
end

###
### Extras
###

fmapstructure(f, x; kwargs...) = fmap(f, x; walk = (f, x) -> map(f, children(x)), kwargs...)

function fcollect(x; output = [], cache = Base.IdSet(), exclude = v -> false)
# note: we don't have an `OrderedIdSet`, so we use an `IdSet` for the cache
# (to ensure we get exactly 1 copy of each distinct array), and a usual `Vector`
# for the results, to preserve traversal order (important downstream!).
x in cache && return output
if !exclude(x)
push!(cache, x)
push!(output, x)
foreach(y -> fcollect(y; cache=cache, output=output, exclude=exclude), children(x))
end
return output
end

###
### Vararg forms
###

function fmap(f, x, ys...; exclude = isleaf, walk = _default_walk, cache = anymutable(x) ? IdDict() : nothing, prune = NoKeyword())
if usecache(cache, x) && haskey(cache, x)
return prune isa NoKeyword ? cache[x] : prune
end
ret = if exclude(x)
f(x, ys...)
else
walk((xy...,) -> fmap(f, xy...; exclude, walk, cache, prune), x, ys...)
end
if usecache(cache, x)
cache[x] = ret
end
ret
end

function _default_walk(f, x, ys...)
func, re = functor(x)
yfuncs = map(y -> functor(typeof(x), y)[1], ys)
re(map(f, func, yfuncs...))
end

###
### FlexibleFunctors.jl
###
Expand Down
19 changes: 16 additions & 3 deletions src/walks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,15 @@ end

struct NoKeyword end

usecache(::AbstractDict, x) = isleaf(x) ? anymutable(x) : ismutable(x)
usecache(::Nothing, x) = false

@generated function anymutable(x::T) where {T}
ismutabletype(T) && return true
subs = [:(anymutable(getfield(x, $f))) for f in QuoteNode.(fieldnames(T))]
return Expr(:(||), subs...)
end

"""
CachedWalk(walk[; prune])
Expand All @@ -109,11 +118,15 @@ CachedWalk(walk; prune = NoKeyword(), cache = IdDict()) =
CachedWalk(walk, prune, cache)

function (walk::CachedWalk)(recurse, x, ys...)
if haskey(walk.cache, x)
should_cache = usecache(walk.cache, x)
if should_cache && haskey(walk.cache, x)
return walk.prune isa NoKeyword ? walk.cache[x] : walk.prune
else
walk.cache[x] = walk.walk(recurse, x, ys...)
return walk.cache[x]
ret = walk.walk(recurse, x, ys...)
if should_cache
walk.cache[x] = ret
end
return ret
end
end

Expand Down

0 comments on commit f788256

Please sign in to comment.