Skip to content

Commit

Permalink
Adapt Constructionbase and add AlignedStyle for zipped tree traversal (
Browse files Browse the repository at this point in the history
…#4)

* adapt constructionbase; add alignedstyle

* default aligned style

* more test

* refine scan

* more test

* update readme
  • Loading branch information
chengchingwen authored Jul 20, 2022
1 parent cf5bf72 commit 764bd69
Show file tree
Hide file tree
Showing 7 changed files with 319 additions and 63 deletions.
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
name = "StructWalk"
uuid = "31cdf514-beb7-4750-89db-dda9d2eb8d3d"
authors = ["chengchingwen <[email protected]> and contributors"]
version = "0.1.0"
version = "0.2.0"

[deps]
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"

[compat]
ConstructionBase = "1"
julia = "1.6"

[extras]
Expand Down
16 changes: 8 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,22 +80,22 @@ import StructWalk: WalkStyle, walkstyle

struct FunctorStyle <: WalkStyle end

walkstyle(::FunctorStyle, x::AbstractArray) = identity, ()
StructWalk.children(::FunctorStyle, x::AbstractArray) = ()

struct Foo{X, Y}
x::X
y::Y
x::X
y::Y
end

struct Baz
x
y
x
y
end

walkstyle(::FunctorStyle, b::Baz) = x->Baz(x, b.y), (b.x,)

myfmap(f, x) = StructWalk.walk(f, identity, FunctorStyle(), x, x -> myfmap(f, x))
StructWalk.constructor(::FunctorStyle, b::Baz) = Base.Fix2(Baz, b.y)
StructWalk.children(::FunctorStyle, b::Baz) = (b.x,)

myfmap(f, x) = mapleaves(f, FunctorStyle(), x)

julia> foo = Foo(1, [1, 2, 3])
Foo{Int64, Vector{Int64}}(1, [1, 2, 3])
Expand Down
138 changes: 107 additions & 31 deletions src/StructWalk.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
module StructWalk

export prewalk, postwalk
import ConstructionBase
using ConstructionBase: constructorof

export prewalk, postwalk, mapleaves

"""
Abstract type `WalkStyle`
Expand All @@ -12,11 +15,11 @@ abstract type WalkStyle end
"""
walkstyle(::CustomWalkStyle, x::T) where {CumstomWalkStyle <: WalkStyle}
Should return a tuple of length 2-3 with:
Should return a tuple of length 3 with:
1. A proper constuctor for `T`, can be `identity` if `x` isa leaf.
2. Children of `x` in a tuple, or empty tuple `()` if `x` is a leaf.
3. [optional] a bool indicate whether element of 2. is the actual list of children. default to `false`.
1. [constructor](@ref): A proper constuctor for `T`, can be `identity` if `x` isa leaf.
2. [children](@ref): Children of `x` in a tuple, or empty tuple `()` if `x` is a leaf.
3. [iscontainer](@ref): A bool indicate whether element of 2. is the actual list of children. default to `false`.
For example, since `Array` has 0 `fieldcount`, we doesn't split the value into a tuple as children.
Instead, we return `(x,)` as children and the extra boolean `true`, so it will `walk`/`map` through `x`
Expand All @@ -28,20 +31,54 @@ function walkstyle end
walkstyle(x)
walkstyle(::Type{WalkStyle}, x::T) where T
return `T` and a tuple all field values of `x`.
Return `T` and a tuple all field values of `x`. The default behavior use
`ConstructionBase.constructorof` for the constructor and
`ConstructionBase.getfields` for the children.
"""
walkstyle(x) = walkstyle(WalkStyle, x)
walkstyle(s::WalkStyle, x) = walkstyle(WalkStyle, x)
function walkstyle(::Type{WalkStyle}, x::T) where T
n = fieldcount(T)
isleaf = iszero(n)
return T.name.wrapper, isleaf ? () : ntuple(i->getfield(x, i), n)
end
walkstyle(s::WalkStyle, x) = _walkstyle(s, x)
walkstyle(::Type{WalkStyle}, x) = _walkstyle(WalkStyle, x)
@inline _walkstyle(s, x) = constructor(s, x), children(s, x), iscontainer(s, x)

"""
constructor(s::WalkStyle, x)
Return the constructor for `x`, which would be applied to `children(s, x)`.
See also: [children](@ref), [iscontainer](@ref)
"""
constructor(x) = constructor(WalkStyle, x)
constructor(s::WalkStyle, x) = constructor(WalkStyle, x)
constructor(::Type{WalkStyle}, x) = iszero(fieldcount(typeof(x))) ? identity : ConstructionBase.constructorof(typeof(x))

"""
children(s::WalkStyle, x)
Return the children of `x`, which would be feeded to `constructor(s, x)`. If `x` is an container type like `Array`,
it can return a tuple of itself and set `iscontainer(s, x)` to `true`.
See also: [constructor](@ref), [iscontainer](@ref)
"""
children(x) = children(WalkStyle, x)
children(s::WalkStyle, x) = children(WalkStyle, x)
children(::Type{WalkStyle}, x) = iszero(fieldcount(typeof(x))) ? () : ConstructionBase.getfields(x)

"""
iscontainer(s::WalkStyle, x)
Return a `Bool` indicating whether `children(x)` return a tuple of itself or not.
See also: [constructor](@ref), [children](@ref)
"""
iscontainer(x) = iscontainer(WalkStyle, x)
iscontainer(s::WalkStyle, x) = iscontainer(WalkStyle, x)
iscontainer(::Type{WalkStyle}, x) = false

const WALKSTYLE = Union{WalkStyle, Type{WalkStyle}}

# default walkstyle for some types
include("./walkstyle.jl")


"""
LeafNode(x)
Expand All @@ -55,31 +92,31 @@ end

@nospecialize

walk(_, _, _, x::LeafNode, _) = x.x
walk(_, _, ::WALKSTYLE, _, x::LeafNode) = x.x

walk(f, style, x, inner_walk) = walk(f, f, style, x, inner_walk)
function walk(f, g, style, x, inner_walk)
S = walkstyle(style, x)
T, fields = S
walk(f, style::WALKSTYLE, inner_walk, x) = walk(f, f, style, inner_walk, x)
function walk(f, g, style::WALKSTYLE, inner_walk, x)
T, fields, iscontainer = walkstyle(style, x)
isleaf = isempty(fields)
isnontuple = length(S) <= 2 ? false : S[3]
if isleaf
return f(x)
else
h = isnontuple ? v->map(inner_walk, v) : inner_walk
return g(T(map(h, fields)...))
h = iscontainer ? Base.Fix1(map, inner_walk) : inner_walk
v = map(h, fields)
return g(T(v...))
end
end


"""
postwalk(f, [style = WalkStyle], x)
Applies `f` to each node in `x` and return the result.
`f` sees the leaves first and then the transformed node.
Apply `f` to each node in `x` and return the result.
`f` sees the leaves first and then the transformed node.
# Example
```julia
```julia-repl
julia> postwalk(x -> @show(x) isa Integer ? x + 1 : x, (a=2, b=(c=4, d=0)))
x = 2
x = 4
Expand All @@ -105,21 +142,20 @@ x = (3//2, 5//2)
See also: [`prewalk`](@ref)
"""
postwalk(f, x) = postwalk(f, WalkStyle, x)
postwalk(f, style, x) = walk(f, style, x, x -> postwalk(f, style, x))

postwalk(f, style::WALKSTYLE, x) = walk(f, style, x -> postwalk(f, style, x), x)

"""
prewalk(f, [style = WalkStyle], x)
Applies `f` to each node in `x` and return the result.
`f` sees the node first and then the transformed leaves.
Apply `f` to each node in `x` and return the result.
`f` sees the node first and then the transformed leaves.
*Notice* that it is possible it walk infinitely if you transform a node into non-leaf value.
Wrapping the non-leaf value with `LeafNode(y)` in `f` to prevent infinite walk.
# Example
```julia
```julia-repl
julia> prewalk(x -> @show(x) isa Integer ? x + 1 : x, (a=2, b=(c=4, d=0)))
x = (a = 2, b = (c = 4, d = 0))
x = 2
Expand All @@ -145,10 +181,50 @@ x = 6
See also: [`postwalk`](@ref), [`LeafNode`](@ref)
"""
prewalk(f, x) = prewalk(f, WalkStyle, x)
prewalk(f, style, x) = walk(identity, style, f(x), x -> prewalk(f, style, x))
prewalk(f, style::WALKSTYLE, x) = walk(identity, style, x -> prewalk(f, style, x), f(x))

"""
mapleaves(f, [style = WalkStyle], x)
@specialize
Apply `f` to each leaf nodes in `x` and return the result.
`f` only see leaf nodes.
# Example
```julia-repl
julia> mapleaves(x -> @show(x) isa Integer ? x + 1 : x, (a=2, b=(c=4, d=0)))
x = 2
x = 4
x = 0
(a = 3, b = (c = 5, d = 1))
```
"""
mapleaves(f, x) = mapleaves(f, WalkStyle, x)
mapleaves(f, style::WALKSTYLE, x) = walk(f, identity, style, x -> mapleaves(f, style, x), x)

"""
mapnonleaves(f, [style = WalkStyle], x)
Apply `f` to each non-leaf in `x` and return the result.
`f` only see non-leaf nodes.
# Example
```julia-repl
julia> StructWalk.mapnonleaves(x -> @show(x) isa Integer ? x + 1 : x, (a=2, b=(c=4, d=0)))
x = (c = 4, d = 0)
x = (a = 2, b = (c = 4, d = 0))
(a = 2, b = (c = 4, d = 0))
```
"""
mapnonleaves(f, x) = mapnonleaves(f, WalkStyle, x)
mapnonleaves(f, style::WALKSTYLE, x) = walk(identity, f, style, x -> mapnonleaves(f, style, x), x)

include("./aligned.jl")
include("./scan.jl")

@specialize

end
60 changes: 60 additions & 0 deletions src/aligned.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
abstract type AlignedStyle{W<:WalkStyle} end

struct DefaultAlignedStyle{W} <: AlignedStyle{W}
walkstyle::W
end

const ALIGNED = Union{AlignedStyle, Type{<:AlignedStyle}, DefaultAlignedStyle}

WalkStyle(style::DefaultAlignedStyle) = style.walkstyle
WalkStyle(::AlignedStyle{W}) where W = W()
WalkStyle(::Type{AlignedStyle}) = WalkStyle
WalkStyle(::Type{<:AlignedStyle{W}}) where W = W()

constructor(s::ALIGNED, x::T, y::T, z::T...) where T = T
constructor(s::ALIGNED, x::NamedTuple{name}, y::NamedTuple{name}, z::NamedTuple{name}...) where name = NamedTuple{name}
constructor(s::ALIGNED, x::Union{NamedTuple, Tuple}, y::Union{NamedTuple, Tuple}, z::Union{NamedTuple, Tuple}...) = Tuple
constructor(s::ALIGNED, x, y, z...) = Vector

function children(style::ALIGNED, x)
wstyle = WalkStyle(style)
xc = children(wstyle, x)
x_is_c = iscontainer(wstyle, x)
return x_is_c ? length(xc) == 1 ? xc[1] : Iterators.flatten(xc) : xc
end
children(style::ALIGNED, x, y) = (children(style, x), children(style, y))
children(style::ALIGNED, x, y, z, w...) = (children(style, x), children(style, y, z, w...)...)

alignedstyle(x, y, z...) = alignedstyle(AlignedStyle, x, y, z...)
function alignedstyle(style::ALIGNED, x, y, z...)
T = constructor(style, x, y, z...)
C = children(style, x, y, z...)
return T, zip(C...)
end

walk(f, style::ALIGNED, inner_walk, x, y, z...) = walk(f, f, style, inner_walk, x, y, z...)
function walk(f, g, style::ALIGNED, inner_walk, x, y, z...)
T, C = alignedstyle(style, x, y, z...)
isleaf = isempty(C)
if isleaf
return f((x, y, z...))
else
return g(T(map(inner_walk, C)))
end
end

postwalk(f, x, y, z...) = postwalk(f, AlignedStyle, x, y, z...)
postwalk(f, style::WalkStyle, x, y, z...) = postwalk(f, DefaultAlignedStyle(style), x, y, z...)
postwalk(f, style::ALIGNED, x, y, z...) = walk(f, style, x -> postwalk(f, style, x...), x, y, z...)

prewalk(f, x, y, z...) = prewalk(f, AlignedStyle, x, y, z...)
prewalk(f, style::WalkStyle, x, y, z...) = prewalk(f, DefaultAlignedStyle(style), x, y, z...)
prewalk(f, style::ALIGNED, x) = walk(identity, style, x -> prewalk(f, style, x...), f(x), f(y), map(f, z)...)

mapleaves(f, x, y, z...) = mapleaves(f, AlignedStyle, x, y, z...)
mapleaves(f, style::WalkStyle, x, y, z...) = mapleaves(f, DefaultAlignedStyle(style), x, y, z...)
mapleaves(f, style::ALIGNED, x, y, z...) = walk(f, identity, style, x -> mapleaves(f, style, x...), x, y, z...)

mapnonleaves(f, x, y, z...) = mapnonleaves(f, AlignedStyle, x, y, z...)
mapnonleaves(f, style::WalkStyle, x, y, z...) = mapnonleaves(f, DefaultAlignedStyle(style), x, y, z...)
mapnonleaves(f, style::ALIGNED, x, y, z...) = walk(identity, f, style, x -> mapnonleaves(f, style, x...), x, y, z...)
46 changes: 46 additions & 0 deletions src/scan.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
function walkby(f, g, h, style::WALKSTYLE, inner_scan, x)
_, fields, iscontainer = walkstyle(style, x)
isleaf = isempty(fields)
if isleaf
f(x)
else
g(x)
_h = iscontainer ? Base.Fix1(map, inner_scan) : inner_scan
foreach(_h, fields)
h(x)
end
return nothing
end

function walkby(f, g, h, style::ALIGNED, inner_scan, x, y, z...)
_, C = alignedstyle(style, x, y, z...)
isleaf = isempty(C)
X = (x, y, z...)
if isleaf
f(X)
else
g(X)
foreach(inner_scan, C)
h(X)
end
return nothing
end


"""
scan(f, [style = WalkStyle], x)
Walk through `x` without constructing anything.
"""
scan(f, x) = scan(f, WalkStyle, x)
scan(f, style::WALKSTYLE, x) = scan(f, f, style, x)
scan(f, g, style::WALKSTYLE, x) = scan(f, g, identity, style, x)
scan(f, g, h, style::WALKSTYLE, x) = walkby(f, g, h, style, x -> scan(f, g, h, style, x), x)

scan(f, x, y, z...) = scan(f, AlignedStyle, x, y, z...)
scan(f, style::WalkStyle, x, y, z...) = scan(f, DefaultAlignedStyle(style), x, y, z...)
scan(f, g, style::WalkStyle, x, y, z...) = scan(f, g, DefaultAlignedStyle(style), x, y, z...)
scan(f, g, h, style::WalkStyle, x, y, z...) = scan(f, g, h, DefaultAlignedStyle(style), x, y, z...)
scan(f, style::ALIGNED, x, y, z...) = scan(f, f, style, x, y, z...)
scan(f, g, style::ALIGNED, x, y, z...) = scan(f, g, identity, style, x, y, z...)
scan(f, g, h, style::ALIGNED, x, y, z...) = walkby(f, g, h, style, x -> scan(f, g, h, style, x...), x, y, z...)
24 changes: 19 additions & 5 deletions src/walkstyle.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,19 @@
walkstyle(::Type{WalkStyle}, x::T) where {T <: AbstractArray} = t->convert(AbstractArray, t), (x,), true
walkstyle(::Type{WalkStyle}, x::T) where {T <: Tuple} = Tuple, (x,), true
walkstyle(::Type{WalkStyle}, x::T) where {T <: NamedTuple} = let name=keys(x); x->NamedTuple{name}(x); end, (x,), true
walkstyle(::Type{WalkStyle}, x::Expr) = (head, args)->Expr(head, args...), (x.head, x.args)
walkstyle(::Type{WalkStyle}, x::T) where {T <: AbstractDict} = Dict, ((p for p in x),), true
expr_constructor(head, args) = Expr(head, args...)

constructor(::Type{WalkStyle}, x::Tuple) = Tuple
constructor(::Type{WalkStyle}, x::NamedTuple) = let name = keys(x); x->NamedTuple{name}(x); end
constructor(::Type{WalkStyle}, x::Expr) = expr_constructor

children(::Type{WalkStyle}, x::AbstractArray) = (x,)
children(::Type{WalkStyle}, x::Tuple) = (x,)
children(::Type{WalkStyle}, x::NamedTuple) = (x,)
children(::Type{WalkStyle}, x::AbstractDict) = ((p for p in x),)

for type in :(
AbstractArray,
AbstractDict,
Tuple,
NamedTuple,
).args
@eval iscontainer(::Type{WalkStyle}, x::$type) = true
end
Loading

2 comments on commit 764bd69

@chengchingwen
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/64585

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.0 -m "<description of version>" 764bd69ec550a1451a41f17bca36235b979ce4d3
git push origin v0.2.0

Please sign in to comment.