Skip to content

Commit

Permalink
prepare for v0.5 release (FluxML#91)
Browse files Browse the repository at this point in the history
* remove AnonymousWalk

* add inferred test

* revisit leaves
  • Loading branch information
CarloLucibello authored Nov 4, 2024
1 parent 2945731 commit def8bf3
Show file tree
Hide file tree
Showing 14 changed files with 59 additions and 80 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
fail-fast: false
matrix:
version:
- '1.6' # Replace this with the minimum Julia version that your package supports.
- '1.10' # Replace this with the minimum Julia version that your package supports.
- '1' # automatically expands to the latest stable 1.x release of Julia
- 'nightly'
os:
Expand Down
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
name = "Functors"
uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
authors = ["Mike J Innes <[email protected]>"]
version = "0.4.12"
version = "0.5.0"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[compat]
Compat = "4.16"
ConstructionBase = "1.4"
Measurements = "2"
OrderedCollections = "1.6"
julia = "1.6"
Random = "1"
julia = "1.10"

[extras]
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
Expand Down
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ Bar(Foo(1.0, [1.0, 2.0, 3.0]))
> With v0.5 instead, this is no longer necessary: by default any type is recursively traversed up to the leaves
> and `ConstructionBase.constructorof` is used to reconstruct it.
> In order to opt-out of this behaviour and make a type non traversable you can use `@leaf Foo`.
>
> Most users should be unaffected by the change and could remove `@functor` from their custom types.
## Further Details

Expand Down
1 change: 0 additions & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ Functors.StructuralWalk
Functors.ExcludeWalk
Functors.CachedWalk
Functors.CollectWalk
Functors.AnonymousWalk
Functors.IterateWalk
```

Expand Down
3 changes: 2 additions & 1 deletion docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ By default all composite types in are functors and can be traversed, unless mark
The following types instead are explicitly marked as leaves in Functors.jl:
- `Number`.
- `AbstractArray{<:Number}`, except for the wrappers `Transpose`, `Adjoint`, and `PermutedDimsArray`.
- `AbstractString`.
- `AbstractRNG`.
- `AbstractString`, `AbstractChar`, `AbstractPattern`, `AbstractMatch`.

This is because in typical application the internals of these are abstracted away and it is not desirable to traverse them.

Expand Down
1 change: 1 addition & 0 deletions src/Functors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module Functors
using Compat: @compat
using ConstructionBase: constructorof
using LinearAlgebra
using Random: AbstractRNG

export @leaf, @functor, @flexiblefunctor,
fmap, fmapstructure, fcollect, execute, fleaves,
Expand Down
6 changes: 5 additions & 1 deletion src/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@

@leaf Number
@leaf AbstractArray{<:Number}
@leaf AbstractString
@leaf AbstractString
@leaf AbstractChar
@leaf AbstractMatch
@leaf AbstractPattern
@leaf AbstractRNG

###
### Fast Paths for common types
Expand Down
59 changes: 18 additions & 41 deletions src/cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,48 +11,25 @@ Base.iterate(cache::WalkCache, state...) = iterate(cache.cache, state...)
Base.setindex!(cache::WalkCache, value, key) = setindex!(cache.cache, value, key)
Base.getindex(cache::WalkCache, x) = cache.cache[x]

@static if VERSION >= v"1.10.0-DEV.609"
function __cacheget_generator__(world, source, self, cache, x, args #= for `return_type` only =#)
# :(return cache.cache[x]::(return_type(cache.walk, typeof(args))))
walk = cache.parameters[3]
RT = Core.Compiler.return_type(Tuple{walk, args...}, world)
body = Expr(:call, GlobalRef(Base, :getindex), Expr(:., :cache, QuoteNode(:cache)), :x)
if RT != Any
body = Expr(:(::), body, RT)
end
expr = Expr(:lambda, [Symbol("#self#"), :cache, :x, :args],
Expr(Symbol("scope-block"), Expr(:block, Expr(:meta, :inline), Expr(:return, body))))
ci = ccall(:jl_expand, Any, (Any, Any), expr, @__MODULE__)
ci.inlineable = true
return ci
end
@eval function cacheget(cache::WalkCache, x, args...)
$(Expr(:meta, :generated, __cacheget_generator__))
$(Expr(:meta, :generated_only))
end
else
@generated function cacheget(cache::WalkCache, x, args...)
walk = cache.parameters[3]
world = typemax(UInt)
@static if VERSION >= v"1.8"
RT = Core.Compiler.return_type(Tuple{walk, args...}, world)
else
if isdefined(walk, :instance)
RT = Core.Compiler.return_type(walk.instance, Tuple{args...}, world)
else
RT = Any
end
end
body = Expr(:call, GlobalRef(Base, :getindex), Expr(:., :cache, QuoteNode(:cache)), :x)
if RT != Any
body = Expr(:(::), body, RT)
end
expr = Expr(:lambda, [Symbol("#self#"), :cache, :x, :args],
Expr(Symbol("scope-block"), Expr(:block, Expr(:meta, :inline), Expr(:return, body))))
ci = ccall(:jl_expand, Any, (Any, Any), expr, @__MODULE__)
ci.inlineable = true
return ci
function __cacheget_generator__(world, source, self, cache, x, args #= for `return_type` only =#)
# :(return cache.cache[x]::(return_type(cache.walk, typeof(args))))
walk = cache.parameters[3]
RT = Core.Compiler.return_type(Tuple{walk, args...}, world)
body = Expr(:call, GlobalRef(Base, :getindex), Expr(:., :cache, QuoteNode(:cache)), :x)
if RT != Any
body = Expr(:(::), body, RT)
end
expr = Expr(:lambda, [Symbol("#self#"), :cache, :x, :args],
Expr(Symbol("scope-block"), Expr(:block, Expr(:meta, :inline), Expr(:return, body))))
ci = ccall(:jl_expand, Any, (Any, Any), expr, @__MODULE__)
ci.inlineable = true
return ci
end

@eval function cacheget(cache::WalkCache, x, args...)
$(Expr(:meta, :generated, __cacheget_generator__))
$(Expr(:meta, :generated_only))
end

# fallback behavior that only lookup for `x`
@inline cacheget(cache::AbstractDict, x, args...) = cache[x]
10 changes: 0 additions & 10 deletions src/functor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,3 @@ end
macro flexiblefunctor(args...)
flexiblefunctorm(args...)
end

###
### Compat
###

if VERSION < v"1.7"
# Function in 1.7 checks t.name.flags & 0x2 == 0x2,
# but for 1.6 this seems to work instead:
ismutabletype(@nospecialize t) = t.mutable
end
2 changes: 1 addition & 1 deletion src/maps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ function fmap(f, x, ys...; exclude = isleaf,
walk = DefaultWalk(),
cache = IdDict(),
prune = NoKeyword())
_walk = ExcludeWalk(AnonymousWalk(walk), f, exclude)
_walk = ExcludeWalk(walk, f, exclude)
if !isnothing(cache)
_walk = CachedWalk(_walk, prune, WalkCache(_walk, cache))
end
Expand Down
20 changes: 0 additions & 20 deletions src/walks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,26 +55,6 @@ function execute(walk::AbstractWalk, x, ys...)
walk(recurse, x, ys...)
end

"""
AnonymousWalk(walk_fn)
Wrap a `walk_fn` so that `AnonymousWalk(walk_fn) isa AbstractWalk`.
This type only exists for backwards compatability and should not be directly used.
Attempting to wrap an existing `AbstractWalk` is a no-op (i.e. it is not wrapped).
"""
struct AnonymousWalk{F} <: AbstractWalk
walk::F

function AnonymousWalk(walk::F) where F
Base.depwarn("Wrapping a custom walk function as an `AnonymousWalk`. Future versions will only support custom walks that explicitly subtype `AbstractWalk`.", :AnonymousWalk)
return new{F}(walk)
end
end
# do not wrap an AbstractWalk
AnonymousWalk(walk::AbstractWalk) = walk

(walk::AnonymousWalk)(recurse, x, ys...) = walk.walk(recurse, x, ys...)

"""
DefaultWalk()
Expand Down
20 changes: 18 additions & 2 deletions test/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,13 @@ end
@test fmap(x -> x + 10, bf) == Base.Broadcast.BroadcastFunction(Bar(13.3))
end

VERSION >= v"1.7" && @testset "Returns" begin
@testset "Returns" begin
ret = Returns([0, pi, 2pi])
@test Functors.functor(ret)[1] == (value = [0, pi, 2pi],)
@test Functors.functor(ret)[2]((value = 1:3,)) === Returns(1:3)
end

VERSION >= v"1.9" && @testset "Splat" begin
@testset "Splat" begin
ret = Base.splat(Returns([0, pi, 2pi]))
@test Functors.functor(ret)[1].f.value == [0, pi, 2pi]
@test Functors.functor(ret)[2]((f = sin,)) === Base.splat(sin)
Expand Down Expand Up @@ -187,6 +187,22 @@ end
s = DummyString("hello")
@test Functors.isleaf(s)
end
@testset "AbstractPattern is leaf" begin
struct DummyPattern <: AbstractPattern
pat::Regex
end
p = DummyPattern(r"\d+")
@test Functors.isleaf(p)
@test Functors.isleaf(r"\d+")
end
@testset "AbstractChar is leaf" begin
struct DummyChar <: AbstractChar
ch::Char
end
c = DummyChar('a')
@test Functors.isleaf(c)
@test Functors.isleaf('a')
end

@testset "AbstractDict is functor" begin
od = OrderedDict(1 => 1, 2 => 2)
Expand Down
6 changes: 6 additions & 0 deletions test/cache.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
@testset "inferred" begin
r = [1,2]
x = (a = r, b = 3, c =(4, (d=5, e=r)))
y = @inferred(fmap(float, x))
@test y.a === y.c[2].e
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ using Measurements: ±
include("base.jl")
include("keypath.jl")
include("flexiblefunctors.jl")
include("cache.jl")
end

0 comments on commit def8bf3

Please sign in to comment.