diff --git a/Project.toml b/Project.toml index 9daf15b..c21d0d2 100644 --- a/Project.toml +++ b/Project.toml @@ -4,17 +4,17 @@ authors = ["Mike J Innes "] version = "0.4.12" [deps] +ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [compat] -Documenter = "1" +ConstructionBase = "1.4" julia = "1.6" [extras] -Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Test", "Documenter", "StaticArrays", "Zygote"] +test = ["Test", "StaticArrays", "Zygote"] diff --git a/README.md b/README.md index f9aea80..bf9ddb0 100644 --- a/README.md +++ b/README.md @@ -25,8 +25,6 @@ julia> struct Foo y end -julia> @functor Foo - julia> model = Foo(1, [1, 2, 3]) Foo(1, [1, 2, 3]) @@ -41,8 +39,6 @@ julia> struct Bar x end -julia> @functor Bar - julia> model = Bar(Foo(1, [1, 2, 3])) Bar(Foo(1, [1, 2, 3])) diff --git a/docs/src/index.md b/docs/src/index.md index 276b3c7..8f0325f 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -8,9 +8,9 @@ For large models it can be cumbersome or inefficient to work with parameters as ## Basic Usage and Implementation -When one marks a structure as [`@functor`](@ref) it means that Functors.jl is allowed to look into the fields of the instances of the struct and modify them. This is achieved through [`Functors.fmap`](@ref). +By default, julia types are marked as [`@functor`](@ref)s, meaning that Functors.jl is allowed to look into the fields of the instances of the struct and modify them. This is achieved through [`Functors.fmap`](@ref). -The workhorse of fmap is actually a lower level function, functor: +The workhorse of `fmap` is actually a lower level function, functor: ```julia-repl julia> using Functors @@ -20,8 +20,6 @@ julia> struct Foo y end -julia> @functor Foo - julia> foo = Foo(1, [1, 2, 3]) # notice all the elements are integers julia> xs, re = Functors.functor(foo) @@ -50,12 +48,17 @@ julia> fmap(float, model) Baz(1.0, 2) ``` -Any field not in the list will be passed through as-is during reconstruction. This is done by invoking the default constructor, so structs that define custom inner constructors are expected to provide one that acts like the default. +Any field not in the list will be passed through as-is during reconstruction. This is done by invoking the default constructor accepting all fields as arguments, so structs that define custom inner constructors are expected to provide one that acts like the default. -## Appropriate Use +The use of `@functor` with no fields argument as in `@functor Baz` is equivalent to `@functor Baz fieldnames(Baz)` +and also equivalent to avoiding `@functor` altogether. + +Using [`@leaf`](@ref) instead of [`@functor`](@ref) will prevent the fields of a struct from being traversed. -!!! warning "Not everything should be a functor!" - Due to its generic nature it is very attractive to mark several structures as [`@functor`](@ref) when it may not be quite safe to do so. +!!! warning "Change to opt-out behaviour in v0.5" + Previous releases of functors, up to v0.4, used an opt-in behaviour where structs were not functors unless marked with `@functor`. This was changed in v0.5 to an opt-out behaviour where structs are functors unless marked with `@leaf`. + +## Appropriate Use Typically, since any function `f` is applied to the leaves of the tree, but it is possible for some functions to require dispatching on the specific type of the fields causing some methods to be missed entirely. diff --git a/src/Functors.jl b/src/Functors.jl index ffcbf8f..599ee72 100644 --- a/src/Functors.jl +++ b/src/Functors.jl @@ -3,6 +3,7 @@ module Functors export @functor, @flexiblefunctor, fmap, fmapstructure, fcollect, execute, fleaves, fmap_with_path, fmapstructure_with_path, KeyPath, getkeypath, haskeypath, setkeypath! +using ConstructionBase: constructorof include("functor.jl") include("keypath.jl") @@ -42,8 +43,6 @@ this can be restricted be restructed by providing a tuple of field names. ```jldoctest julia> struct Foo; x; y; end -julia> @functor Foo - julia> Functors.children(Foo(1,2)) (x = 1, y = 2) @@ -52,6 +51,8 @@ julia> _, re = Functors.functor(Foo(1,2)); julia> re((10, 20)) Foo(10, 20) +julia> @functor Foo # same as before, nothing changes + julia> struct TwoThirds a; b; c; end julia> @functor TwoThirds (a, c) diff --git a/src/base.jl b/src/base.jl index 913aaa7..2691e73 100644 --- a/src/base.jl +++ b/src/base.jl @@ -1,14 +1,5 @@ -@functor Base.RefValue - -@functor Base.Pair - -@functor Base.Generator # aka Iterators.map - -@functor Base.ComposedFunction -@functor Base.Fix1 -@functor Base.Fix2 -@functor Base.Broadcast.BroadcastFunction +functor(::Type{<:Base.ComposedFunction}, x) = (outer = x.outer, inner = x.inner), y -> Base.ComposedFunction(y.outer, y.inner) @static if VERSION >= v"1.9" @functor Base.Splat @@ -51,26 +42,3 @@ end _PermutedDimsArray(x, iperm) = PermutedDimsArray(x, iperm) _PermutedDimsArray(x::NamedTuple{(:parent,)}, iperm) = x.parent _PermutedDimsArray(bc::Broadcast.Broadcasted, iperm) = _PermutedDimsArray(Broadcast.materialize(bc), iperm) - -### -### Iterators -### - -@functor Iterators.Accumulate -# Count -@functor Iterators.Cycle -@functor Iterators.Drop -@functor Iterators.DropWhile -@functor Iterators.Enumerate -@functor Iterators.Filter -@functor Iterators.Flatten -# IterationCutShort -@functor Iterators.PartitionIterator -@functor Iterators.ProductIterator -@functor Iterators.Repeated -@functor Iterators.Rest -@functor Iterators.Reverse -# Stateful -@functor Iterators.Take -@functor Iterators.TakeWhile -@functor Iterators.Zip diff --git a/src/functor.jl b/src/functor.jl index 653a928..ca6b995 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -11,7 +11,19 @@ macro leaf(T) :($Functors.functor(::Type{<:$(esc(T))}, x) = ($Functors.NoChildren(), _ -> x)) end -@leaf Any # every type is a leaf by default +# @leaf Any # every type is a leaf by default + +# Default functor +function functor(T, x) + names = fieldnames(T) + if isempty(names) + return NoChildren(), _ -> x + end + S = constructorof(T) # remove parameters from parametric types and support anonymous functions + vals = ntuple(i -> getfield(x, names[i]), length(names)) + return NamedTuple{names}(vals), y -> S(y...) +end + functor(x) = functor(typeof(x), x) functor(::Type{<:Tuple}, x) = x, identity @@ -30,7 +42,7 @@ function makefunctor(m::Module, T, fs = fieldnames(T)) f in fs ? :(y[$(Meta.quot(f))]) : :(x.$f) end escfs = [:($f=x.$f) for f in fs] - + @eval m begin function $Functors.functor(::Type{<:$T}, x) reconstruct(y) = $T($(escargs...)) diff --git a/test/basics.jl b/test/basics.jl index 625e630..ced07a1 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -2,36 +2,38 @@ using Functors: functor, usecache struct Foo; x; y; end -@functor Foo Base.:(==)(x::Foo, y::Foo) = x.x == y.x && x.y == y.y struct Bar{T}; x::T; end -@functor Bar Base.:(==)(x::Bar, y::Bar) = x.x == y.x struct OneChild3; x; y; z; end @functor OneChild3 (y,) -struct NoChildren2; x; y; end +struct NoChild2; x; y; end +@functor NoChild2 () -struct NoChild{T}; x::T; end +struct NoChild1{T}; x::T; end +@functor NoChild1 () struct WrongOrder; x; y; z; end @functor WrongOrder (z, x) +struct LeafType{T}; x::T; end +@leaf LeafType ### ### Basic functionality ### -@testset "Children and Leaves" begin - no_children = NoChildren2(1, 2) +@testset "NoChild is not a leaf" begin + no_children = NoChild2(1, 2) has_children = Foo(1, 2) - @test Functors.isleaf(no_children) + @test !Functors.isleaf(no_children) @test !Functors.isleaf(has_children) - @test Functors.children(no_children) === Functors.NoChildren() + @test Functors.children(no_children) === (;) @test Functors.children(has_children) == (x=1, y=2) end @@ -108,8 +110,8 @@ end # Leaf types: @test usecache(d, [1,2]) @test !usecache(d, 4.0) - @test usecache(d, NoChild([1,2])) - @test !usecache(d, NoChild((3,4))) + @test usecache(d, LeafType([1,2])) + @test !usecache(d, LeafType((3,4))) # Not leaf: @test usecache(d, Ref(3)) # mutable container @@ -163,6 +165,17 @@ end @test_throws Exception functor(NamedTuple{(:x, :y)}, (z=33, x=1)) end +@testset "anonymous functions" begin + model = let W = rand(2,2), b = ones(2) + x -> tanh.(W*x .+ b) + end + newmodel = fmap(zero, model) + @test newmodel isa Function + @test newmodel([1,2]) == [0,0] + @test newmodel.W == [0 0; 0 0] + @test newmodel.b == [0, 0] +end + ### ### Extras ### @@ -185,7 +198,7 @@ end m1 = [1, 2, 3] m2 = Bar(m1) - m0 = NoChildren2(:a, :b) + m0 = NoChild2(:a, :b) m3 = Foo(m2, m0) m4 = Bar(m3) @test all(fcollect(m4) .=== [m4, m3, m2, m1, m0]) @@ -299,74 +312,13 @@ end @test m̂.b ≈ fill(-0.2f0, size(m.b)) end -### -### FlexibleFunctors.jl -### - -struct FFoo - x - y - p -end -@flexiblefunctor FFoo p - -struct FBar - x - p -end -@flexiblefunctor FBar p - -struct FOneChild4 - x - y - z - p -end -@flexiblefunctor FOneChild4 p - -@testset "Flexible Nested" begin - model = FBar(FFoo(1, [1, 2, 3], (:y, )), (:x,)) - - model′ = fmap(float, model) - - @test model.x.y == model′.x.y - @test model′.x.y isa Vector{Float64} -end - -@testset "Flexible Walk" begin - model = FFoo((0, FBar([1, 2, 3], (:x,))), [4, 5], (:x, :y)) - - model′ = fmapstructure(identity, model) - @test model′ == (; x=(0, (; x=[1, 2, 3])), y=[4, 5]) - - model2 = FFoo((0, FBar([1, 2, 3], (:x,))), [4, 5], (:x,)) - - model2′ = fmapstructure(identity, model2) - @test model2′ == (; x=(0, (; x=[1, 2, 3]))) -end - -@testset "Flexible Property list" begin - model = FOneChild4(1, 2, 3, (:x, :z)) - model′ = fmap(x -> 2x, model) - - @test (model′.x, model′.y, model′.z) == (2, 2, 6) -end - -@testset "Flexible fcollect" begin - m1 = 1 - m2 = [1, 2, 3] - m3 = FFoo(m1, m2, (:y, )) - m4 = FBar(m3, (:x,)) - @test all(fcollect(m4) .=== [m4, m3, m2]) - @test all(fcollect(m4, exclude = x -> x isa Array) .=== [m4, m3]) - @test all(fcollect(m4, exclude = x -> x isa FFoo) .=== [m4]) +@testset "parametric types" begin + struct A{T} + x::T + end - m0 = NoChildren2(:a, :b) - m1 = [1, 2, 3] - m2 = FBar(m1, ()) - m3 = FFoo(m2, m0, (:x, :y,)) - m4 = FBar(m3, (:x,)) - @test all(fcollect(m4) .=== [m4, m3, m2, m0]) + a = A(1) + @test fmap(x -> x/2, a) == A(0.5) end @testset "Dict" begin @@ -396,15 +348,13 @@ end end @testset "@leaf" begin - struct A; x; end - @functor A - a = A(1) - @test Functors.children(a) === (x = 1,) - struct B; x; end Functors.@leaf B b = B(1) children, re = Functors.functor(b) + + a = LeafType(1) + children, re = Functors.functor(a) @test children == Functors.NoChildren() @test re(children) === b end diff --git a/test/flexiblefunctors.jl b/test/flexiblefunctors.jl new file mode 100644 index 0000000..d34daf4 --- /dev/null +++ b/test/flexiblefunctors.jl @@ -0,0 +1,71 @@ + +### +### FlexibleFunctors.jl +### + +struct FFoo + x + y + p + end + @flexiblefunctor FFoo p + + struct FBar + x + p + end + @flexiblefunctor FBar p + + struct FOneChild4 + x + y + z + p + end + @flexiblefunctor FOneChild4 p + + @testset "Flexible Nested" begin + model = FBar(FFoo(1, [1, 2, 3], (:y, )), (:x,)) + + model′ = fmap(float, model) + + @test model.x.y == model′.x.y + @test model′.x.y isa Vector{Float64} + end + + @testset "Flexible Walk" begin + model = FFoo((0, FBar([1, 2, 3], (:x,))), [4, 5], (:x, :y)) + + model′ = fmapstructure(identity, model) + @test model′ == (; x=(0, (; x=[1, 2, 3])), y=[4, 5]) + + model2 = FFoo((0, FBar([1, 2, 3], (:x,))), [4, 5], (:x,)) + + model2′ = fmapstructure(identity, model2) + @test model2′ == (; x=(0, (; x=[1, 2, 3]))) + end + + @testset "Flexible Property list" begin + model = FOneChild4(1, 2, 3, (:x, :z)) + model′ = fmap(x -> 2x, model) + + @test (model′.x, model′.y, model′.z) == (2, 2, 6) + end + + @testset "Flexible fcollect" begin + m1 = 1 + m2 = [1, 2, 3] + m3 = FFoo(m1, m2, (:y, )) + m4 = FBar(m3, (:x,)) + @test all(fcollect(m4) .=== [m4, m3, m2]) + @test all(fcollect(m4, exclude = x -> x isa Array) .=== [m4, m3]) + @test all(fcollect(m4, exclude = x -> x isa FFoo) .=== [m4]) + + m0 = NoChild2(:a, :b) + m1 = [1, 2, 3] + m2 = FBar(m1, ()) + m3 = FFoo(m2, m0, (:x, :y,)) + m4 = FBar(m3, (:x,)) + @test all(fcollect(m4) .=== [m4, m3, m2, m0]) + end + \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 7db1e1b..18cc997 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,9 +4,8 @@ using LinearAlgebra using StaticArrays @testset "Functors.jl" begin - include("basics.jl") include("base.jl") include("keypath.jl") - + include("flexiblefunctors.jl") end