Skip to content

Commit

Permalink
init impl
Browse files Browse the repository at this point in the history
  • Loading branch information
chengchingwen committed Dec 22, 2021
1 parent 564702a commit 7031449
Show file tree
Hide file tree
Showing 3 changed files with 270 additions and 5 deletions.
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ uuid = "31cdf514-beb7-4750-89db-dda9d2eb8d3d"
authors = ["chengchingwen <[email protected]> and contributors"]
version = "0.1.0"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
julia = "1.6"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]
114 changes: 113 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,118 @@
# StructWalk
# StructWalk.jl

[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://chengchingwen.github.io/StructWalk.jl/stable)
[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://chengchingwen.github.io/StructWalk.jl/dev)
[![Build Status](https://github.com/chengchingwen/StructWalk.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/chengchingwen/StructWalk.jl/actions/workflows/CI.yml?query=branch%3Amain)
[![Coverage](https://codecov.io/gh/chengchingwen/StructWalk.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/chengchingwen/StructWalk.jl)

Transform functions for Julia struct. Can be viewed as a general version of `MacroTools`'s `prewalk`/`postwalk` or `Functors`'s `@functor`/`fmap`.

# Examples

## Basic usage

```julia
julia> postwalk(x -> @show(x) isa Integer ? x + 1 : x, (a=2, b=(c=4, d=0)))
x = 2
x = 4
x = 0
x = (c = 5, d = 1)
x = (a = 3, b = (c = 5, d = 1))
(a = 3, b = (c = 5, d = 1))

julia> postwalk(x -> @show(x) isa Integer ? x + 1 : x .+ 1, (3, 5))
x = 3
x = 5
x = (4, 6)
(5, 7)

julia> postwalk(x -> @show(x) isa Integer ? x // 2 : x isa Tuple ? =>(x .+ 1...) : x, (3, 5))
x = 3
x = 5
x = (3//2, 5//2)
5//2 => 7//2

julia> prewalk(x -> @show(x) isa Integer ? x + 1 : x, (a=2, b=(c=4, d=0)))
x = (a = 2, b = (c = 4, d = 0))
x = 2
x = (c = 4, d = 0)
x = 4
x = 0
(a = 3, b = (c = 5, d = 1))

julia> prewalk(x -> @show(x) isa Integer ? x + 1 : x .+ 1, (3, 5))
x = (3, 5)
x = 4
x = 6
(5, 7)

julia> prewalk(x -> @show(x) isa Integer ? StructWalk.LeafNode(x // 2) : x isa Tuple ? =>(x .+ 1...) : x, (3, 5))
x = (3, 5)
x = 4
x = 6
2 => 3

```


## Structural replace

```julia
julia> x = (a=3, b=(w=3, b=0))
(a = 3, b = (w = 3, b = 0))

julia> postwalk(x) do x
if x isa NamedTuple{(:w, :b)}
return x[1]=>x[2]
end
return x
end
(a = 3, b = 3 => 0)

```


## More example

```julia
using StructWalk
import StructWalk: WalkStyle, walkstyle

struct FunctorStyle <: WalkStyle end

walkstyle(::FunctorStyle, x::AbstractArray) = identity, ()

struct Foo{X, Y}
x::X
y::Y
end

struct Baz
x
y
end

walkstyle(::FunctorStyle, b::Baz) = x->Baz(x, b.y), (b.x,)

myfmap(f, x) = StructWalk.walk(f, identity, FunctorStyle(), x, x -> myfmap(f, x))


julia> foo = Foo(1, [1, 2, 3])
Foo{Int64, Vector{Int64}}(1, [1, 2, 3])

julia> postwalk(x-> x isa Integer ? float(x) : x, FunctorStyle(), foo)
Foo{Float64, Vector{Int64}}(1.0, [1, 2, 3])

julia> myfmap(float, foo)
Foo{Float64, Vector{Float64}}(1.0, [1.0, 2.0, 3.0])

julia> baz = Baz(1, 2)
Baz(1, 2)

julia> myfmap(float, baz)
Baz(1.0, 2)

julia> using CUDA; myfmap(CUDA.cu, foo)
Foo{Int64, CuArray{Int64, 1, CUDA.Mem.DeviceBuffer}}(1, [1, 2, 3])

```
155 changes: 154 additions & 1 deletion src/StructWalk.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,158 @@
module StructWalk

# Write your package code here.
export prewalk, postwalk

"""
Abstract type `WalkStyle`
Subtype `WalkStyle` and overload [`walkstyle`](@ref) to define custom walking behaviors (constructor / children /...).
"""
abstract type WalkStyle end

"""
walkstyle(::CustomWalkStyle, x::T) where {CumstomWalkStyle <: WalkStyle}
Should return a tuple of length 2-3 with:
1. A proper constuctor for `T`, can be `identity` if `x` isa leaf.
2. Children of `x` in a tuple, or empty tuple `()` if `x` is a leaf.
3. [optional] a bool indicate whether element of 2. is the actual list of children. default to `false`.
For example, since `Array` has 0 `fieldcount`, we doesn't split the value into a tuple as children.
Instead, we return `(x,)` as children and the extra boolean `true`, so it will `walk`/`map` through `x`
accordingly.
"""
function walkstyle end

"""
walkstyle(x)
walkstyle(::Type{WalkStyle}, x::T) where T
return `T` and a tuple all field values of `x`.
"""
walkstyle(x) = walkstyle(WalkStyle, x)
walkstyle(s::WalkStyle, x) = walkstyle(WalkStyle, x)
function walkstyle(::Type{WalkStyle}, x::T) where T
n = fieldcount(T)
isleaf = iszero(n)
return T.name.wrapper, isleaf ? () : ntuple(i->getfield(x, i), n)
end
walkstyle(::Type{WalkStyle}, x::T) where {T <: Array} = t->convert(AbstractArray, t), (x,), true
walkstyle(::Type{WalkStyle}, x::T) where {T <: Tuple} = Tuple, (x,), true
walkstyle(::Type{WalkStyle}, x::T) where {T <: NamedTuple} = let name=keys(x); x->NamedTuple{name}(x); end, (x,), true
walkstyle(::Type{WalkStyle}, x::Expr) = (head, args)->Expr(head, args...), (x.head, x.args)

"""
LeafNode(x)
special type for marking non-leaf value as leaf. Use with `prewalk`.
See also: [`prewalk`](@ref)
"""
struct LeafNode{T}
x::T
end

@nospecialize

walk(_, _, _, x::LeafNode, _) = x.x

walk(f, style, x, inner_walk) = walk(f, f, style, x, inner_walk)
function walk(f, g, style, x, inner_walk)
S = walkstyle(style, x)
T, fields = S
isleaf = isempty(fields)
isnontuple = length(S) <= 2 ? false : S[3]
if isleaf
return f(x)
else
h = isnontuple ? v->map(inner_walk, v) : inner_walk
return g(T(map(h, fields)...))
end
end

"""
postwalk(f, [style = WalkStyle], x)
Applies `f` to each node in `x` and return the result.
`f` sees the leaves first and then the transformed node.
# Example
```julia
julia> postwalk(x -> @show(x) isa Integer ? x + 1 : x, (a=2, b=(c=4, d=0)))
x = 2
x = 4
x = 0
x = (c = 5, d = 1)
x = (a = 3, b = (c = 5, d = 1))
(a = 3, b = (c = 5, d = 1))
julia> postwalk(x -> @show(x) isa Integer ? x + 1 : x .+ 1, (3, 5))
x = 3
x = 5
x = (4, 6)
(5, 7)
julia> postwalk(x -> @show(x) isa Integer ? x // 2 : x isa Tuple ? =>(x .+ 1...) : x, (3, 5))
x = 3
x = 5
x = (3//2, 5//2)
5//2 => 7//2
```
See also: [`prewalk`](@ref)
"""
postwalk(f, x) = postwalk(f, WalkStyle, x)
postwalk(f, style, x) = walk(f, style, x, x -> postwalk(f, style, x))


"""
prewalk(f, [style = WalkStyle], x)
Applies `f` to each node in `x` and return the result.
`f` sees the node first and then the transformed leaves.
*Notice* that it is possible it walk infinitely if you transform a node into non-leaf value.
Wrapping the non-leaf value with `LeafNode(y)` in `f` to prevent infinite walk.
# Example
```julia
julia> prewalk(x -> @show(x) isa Integer ? x + 1 : x, (a=2, b=(c=4, d=0)))
x = (a = 2, b = (c = 4, d = 0))
x = 2
x = (c = 4, d = 0)
x = 4
x = 0
(a = 3, b = (c = 5, d = 1))
julia> prewalk(x -> @show(x) isa Integer ? x + 1 : x .+ 1, (3, 5))
x = (3, 5)
x = 4
x = 6
(5, 7)
julia> prewalk(x -> @show(x) isa Integer ? StructWalk.LeafNode(x // 2) : x isa Tuple ? =>(x .+ 1...) : x, (3, 5))
x = (3, 5)
x = 4
x = 6
2 => 3
```
See also: [`postwalk`](@ref), [`LeafNode`](@ref)
"""
prewalk(f, x) = prewalk(f, WalkStyle, x)
function prewalk(f, style, x)
y = f(x)
y == x && return x
return walk(identity, style, y, x -> prewalk(f, style, x))
end


@specialize


end

2 comments on commit 7031449

@chengchingwen
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/51047

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.1.0 -m "<description of version>" 703144955d9f32b4f9a494eacb6a0083621939d2
git push origin v0.1.0

Please sign in to comment.