From 1b4ac0f1afc84286b8f976065a8f71fc10e9a198 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Fri, 25 Nov 2022 08:01:22 +0100 Subject: [PATCH 01/24] functor by default strip type's parameters factorize flexiblefunctors tests ops support closures cleanup rebase use nochildren update readme docs --- Project.toml | 6 +- README.md | 4 -- docs/src/index.md | 19 ++++--- src/Functors.jl | 5 +- src/base.jl | 34 +----------- src/functor.jl | 16 +++++- test/basics.jl | 116 +++++++++++---------------------------- test/flexiblefunctors.jl | 71 ++++++++++++++++++++++++ test/runtests.jl | 3 +- 9 files changed, 137 insertions(+), 137 deletions(-) create mode 100644 test/flexiblefunctors.jl diff --git a/Project.toml b/Project.toml index 9daf15b..c21d0d2 100644 --- a/Project.toml +++ b/Project.toml @@ -4,17 +4,17 @@ authors = ["Mike J Innes "] version = "0.4.12" [deps] +ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [compat] -Documenter = "1" +ConstructionBase = "1.4" julia = "1.6" [extras] -Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Test", "Documenter", "StaticArrays", "Zygote"] +test = ["Test", "StaticArrays", "Zygote"] diff --git a/README.md b/README.md index f9aea80..bf9ddb0 100644 --- a/README.md +++ b/README.md @@ -25,8 +25,6 @@ julia> struct Foo y end -julia> @functor Foo - julia> model = Foo(1, [1, 2, 3]) Foo(1, [1, 2, 3]) @@ -41,8 +39,6 @@ julia> struct Bar x end -julia> @functor Bar - julia> model = Bar(Foo(1, [1, 2, 3])) Bar(Foo(1, [1, 2, 3])) diff --git a/docs/src/index.md b/docs/src/index.md index 276b3c7..8f0325f 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -8,9 +8,9 @@ For large models it can be cumbersome or inefficient to work with parameters as ## Basic Usage and Implementation -When one marks a structure as [`@functor`](@ref) it means that Functors.jl is allowed to look into the fields of the instances of the struct and modify them. This is achieved through [`Functors.fmap`](@ref). +By default, julia types are marked as [`@functor`](@ref)s, meaning that Functors.jl is allowed to look into the fields of the instances of the struct and modify them. This is achieved through [`Functors.fmap`](@ref). -The workhorse of fmap is actually a lower level function, functor: +The workhorse of `fmap` is actually a lower level function, functor: ```julia-repl julia> using Functors @@ -20,8 +20,6 @@ julia> struct Foo y end -julia> @functor Foo - julia> foo = Foo(1, [1, 2, 3]) # notice all the elements are integers julia> xs, re = Functors.functor(foo) @@ -50,12 +48,17 @@ julia> fmap(float, model) Baz(1.0, 2) ``` -Any field not in the list will be passed through as-is during reconstruction. This is done by invoking the default constructor, so structs that define custom inner constructors are expected to provide one that acts like the default. +Any field not in the list will be passed through as-is during reconstruction. This is done by invoking the default constructor accepting all fields as arguments, so structs that define custom inner constructors are expected to provide one that acts like the default. -## Appropriate Use +The use of `@functor` with no fields argument as in `@functor Baz` is equivalent to `@functor Baz fieldnames(Baz)` +and also equivalent to avoiding `@functor` altogether. + +Using [`@leaf`](@ref) instead of [`@functor`](@ref) will prevent the fields of a struct from being traversed. -!!! warning "Not everything should be a functor!" - Due to its generic nature it is very attractive to mark several structures as [`@functor`](@ref) when it may not be quite safe to do so. +!!! warning "Change to opt-out behaviour in v0.5" + Previous releases of functors, up to v0.4, used an opt-in behaviour where structs were not functors unless marked with `@functor`. This was changed in v0.5 to an opt-out behaviour where structs are functors unless marked with `@leaf`. + +## Appropriate Use Typically, since any function `f` is applied to the leaves of the tree, but it is possible for some functions to require dispatching on the specific type of the fields causing some methods to be missed entirely. diff --git a/src/Functors.jl b/src/Functors.jl index ffcbf8f..599ee72 100644 --- a/src/Functors.jl +++ b/src/Functors.jl @@ -3,6 +3,7 @@ module Functors export @functor, @flexiblefunctor, fmap, fmapstructure, fcollect, execute, fleaves, fmap_with_path, fmapstructure_with_path, KeyPath, getkeypath, haskeypath, setkeypath! +using ConstructionBase: constructorof include("functor.jl") include("keypath.jl") @@ -42,8 +43,6 @@ this can be restricted be restructed by providing a tuple of field names. ```jldoctest julia> struct Foo; x; y; end -julia> @functor Foo - julia> Functors.children(Foo(1,2)) (x = 1, y = 2) @@ -52,6 +51,8 @@ julia> _, re = Functors.functor(Foo(1,2)); julia> re((10, 20)) Foo(10, 20) +julia> @functor Foo # same as before, nothing changes + julia> struct TwoThirds a; b; c; end julia> @functor TwoThirds (a, c) diff --git a/src/base.jl b/src/base.jl index 913aaa7..2691e73 100644 --- a/src/base.jl +++ b/src/base.jl @@ -1,14 +1,5 @@ -@functor Base.RefValue - -@functor Base.Pair - -@functor Base.Generator # aka Iterators.map - -@functor Base.ComposedFunction -@functor Base.Fix1 -@functor Base.Fix2 -@functor Base.Broadcast.BroadcastFunction +functor(::Type{<:Base.ComposedFunction}, x) = (outer = x.outer, inner = x.inner), y -> Base.ComposedFunction(y.outer, y.inner) @static if VERSION >= v"1.9" @functor Base.Splat @@ -51,26 +42,3 @@ end _PermutedDimsArray(x, iperm) = PermutedDimsArray(x, iperm) _PermutedDimsArray(x::NamedTuple{(:parent,)}, iperm) = x.parent _PermutedDimsArray(bc::Broadcast.Broadcasted, iperm) = _PermutedDimsArray(Broadcast.materialize(bc), iperm) - -### -### Iterators -### - -@functor Iterators.Accumulate -# Count -@functor Iterators.Cycle -@functor Iterators.Drop -@functor Iterators.DropWhile -@functor Iterators.Enumerate -@functor Iterators.Filter -@functor Iterators.Flatten -# IterationCutShort -@functor Iterators.PartitionIterator -@functor Iterators.ProductIterator -@functor Iterators.Repeated -@functor Iterators.Rest -@functor Iterators.Reverse -# Stateful -@functor Iterators.Take -@functor Iterators.TakeWhile -@functor Iterators.Zip diff --git a/src/functor.jl b/src/functor.jl index 653a928..ca6b995 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -11,7 +11,19 @@ macro leaf(T) :($Functors.functor(::Type{<:$(esc(T))}, x) = ($Functors.NoChildren(), _ -> x)) end -@leaf Any # every type is a leaf by default +# @leaf Any # every type is a leaf by default + +# Default functor +function functor(T, x) + names = fieldnames(T) + if isempty(names) + return NoChildren(), _ -> x + end + S = constructorof(T) # remove parameters from parametric types and support anonymous functions + vals = ntuple(i -> getfield(x, names[i]), length(names)) + return NamedTuple{names}(vals), y -> S(y...) +end + functor(x) = functor(typeof(x), x) functor(::Type{<:Tuple}, x) = x, identity @@ -30,7 +42,7 @@ function makefunctor(m::Module, T, fs = fieldnames(T)) f in fs ? :(y[$(Meta.quot(f))]) : :(x.$f) end escfs = [:($f=x.$f) for f in fs] - + @eval m begin function $Functors.functor(::Type{<:$T}, x) reconstruct(y) = $T($(escargs...)) diff --git a/test/basics.jl b/test/basics.jl index 625e630..ced07a1 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -2,36 +2,38 @@ using Functors: functor, usecache struct Foo; x; y; end -@functor Foo Base.:(==)(x::Foo, y::Foo) = x.x == y.x && x.y == y.y struct Bar{T}; x::T; end -@functor Bar Base.:(==)(x::Bar, y::Bar) = x.x == y.x struct OneChild3; x; y; z; end @functor OneChild3 (y,) -struct NoChildren2; x; y; end +struct NoChild2; x; y; end +@functor NoChild2 () -struct NoChild{T}; x::T; end +struct NoChild1{T}; x::T; end +@functor NoChild1 () struct WrongOrder; x; y; z; end @functor WrongOrder (z, x) +struct LeafType{T}; x::T; end +@leaf LeafType ### ### Basic functionality ### -@testset "Children and Leaves" begin - no_children = NoChildren2(1, 2) +@testset "NoChild is not a leaf" begin + no_children = NoChild2(1, 2) has_children = Foo(1, 2) - @test Functors.isleaf(no_children) + @test !Functors.isleaf(no_children) @test !Functors.isleaf(has_children) - @test Functors.children(no_children) === Functors.NoChildren() + @test Functors.children(no_children) === (;) @test Functors.children(has_children) == (x=1, y=2) end @@ -108,8 +110,8 @@ end # Leaf types: @test usecache(d, [1,2]) @test !usecache(d, 4.0) - @test usecache(d, NoChild([1,2])) - @test !usecache(d, NoChild((3,4))) + @test usecache(d, LeafType([1,2])) + @test !usecache(d, LeafType((3,4))) # Not leaf: @test usecache(d, Ref(3)) # mutable container @@ -163,6 +165,17 @@ end @test_throws Exception functor(NamedTuple{(:x, :y)}, (z=33, x=1)) end +@testset "anonymous functions" begin + model = let W = rand(2,2), b = ones(2) + x -> tanh.(W*x .+ b) + end + newmodel = fmap(zero, model) + @test newmodel isa Function + @test newmodel([1,2]) == [0,0] + @test newmodel.W == [0 0; 0 0] + @test newmodel.b == [0, 0] +end + ### ### Extras ### @@ -185,7 +198,7 @@ end m1 = [1, 2, 3] m2 = Bar(m1) - m0 = NoChildren2(:a, :b) + m0 = NoChild2(:a, :b) m3 = Foo(m2, m0) m4 = Bar(m3) @test all(fcollect(m4) .=== [m4, m3, m2, m1, m0]) @@ -299,74 +312,13 @@ end @test m̂.b ≈ fill(-0.2f0, size(m.b)) end -### -### FlexibleFunctors.jl -### - -struct FFoo - x - y - p -end -@flexiblefunctor FFoo p - -struct FBar - x - p -end -@flexiblefunctor FBar p - -struct FOneChild4 - x - y - z - p -end -@flexiblefunctor FOneChild4 p - -@testset "Flexible Nested" begin - model = FBar(FFoo(1, [1, 2, 3], (:y, )), (:x,)) - - model′ = fmap(float, model) - - @test model.x.y == model′.x.y - @test model′.x.y isa Vector{Float64} -end - -@testset "Flexible Walk" begin - model = FFoo((0, FBar([1, 2, 3], (:x,))), [4, 5], (:x, :y)) - - model′ = fmapstructure(identity, model) - @test model′ == (; x=(0, (; x=[1, 2, 3])), y=[4, 5]) - - model2 = FFoo((0, FBar([1, 2, 3], (:x,))), [4, 5], (:x,)) - - model2′ = fmapstructure(identity, model2) - @test model2′ == (; x=(0, (; x=[1, 2, 3]))) -end - -@testset "Flexible Property list" begin - model = FOneChild4(1, 2, 3, (:x, :z)) - model′ = fmap(x -> 2x, model) - - @test (model′.x, model′.y, model′.z) == (2, 2, 6) -end - -@testset "Flexible fcollect" begin - m1 = 1 - m2 = [1, 2, 3] - m3 = FFoo(m1, m2, (:y, )) - m4 = FBar(m3, (:x,)) - @test all(fcollect(m4) .=== [m4, m3, m2]) - @test all(fcollect(m4, exclude = x -> x isa Array) .=== [m4, m3]) - @test all(fcollect(m4, exclude = x -> x isa FFoo) .=== [m4]) +@testset "parametric types" begin + struct A{T} + x::T + end - m0 = NoChildren2(:a, :b) - m1 = [1, 2, 3] - m2 = FBar(m1, ()) - m3 = FFoo(m2, m0, (:x, :y,)) - m4 = FBar(m3, (:x,)) - @test all(fcollect(m4) .=== [m4, m3, m2, m0]) + a = A(1) + @test fmap(x -> x/2, a) == A(0.5) end @testset "Dict" begin @@ -396,15 +348,13 @@ end end @testset "@leaf" begin - struct A; x; end - @functor A - a = A(1) - @test Functors.children(a) === (x = 1,) - struct B; x; end Functors.@leaf B b = B(1) children, re = Functors.functor(b) + + a = LeafType(1) + children, re = Functors.functor(a) @test children == Functors.NoChildren() @test re(children) === b end diff --git a/test/flexiblefunctors.jl b/test/flexiblefunctors.jl new file mode 100644 index 0000000..d34daf4 --- /dev/null +++ b/test/flexiblefunctors.jl @@ -0,0 +1,71 @@ + +### +### FlexibleFunctors.jl +### + +struct FFoo + x + y + p + end + @flexiblefunctor FFoo p + + struct FBar + x + p + end + @flexiblefunctor FBar p + + struct FOneChild4 + x + y + z + p + end + @flexiblefunctor FOneChild4 p + + @testset "Flexible Nested" begin + model = FBar(FFoo(1, [1, 2, 3], (:y, )), (:x,)) + + model′ = fmap(float, model) + + @test model.x.y == model′.x.y + @test model′.x.y isa Vector{Float64} + end + + @testset "Flexible Walk" begin + model = FFoo((0, FBar([1, 2, 3], (:x,))), [4, 5], (:x, :y)) + + model′ = fmapstructure(identity, model) + @test model′ == (; x=(0, (; x=[1, 2, 3])), y=[4, 5]) + + model2 = FFoo((0, FBar([1, 2, 3], (:x,))), [4, 5], (:x,)) + + model2′ = fmapstructure(identity, model2) + @test model2′ == (; x=(0, (; x=[1, 2, 3]))) + end + + @testset "Flexible Property list" begin + model = FOneChild4(1, 2, 3, (:x, :z)) + model′ = fmap(x -> 2x, model) + + @test (model′.x, model′.y, model′.z) == (2, 2, 6) + end + + @testset "Flexible fcollect" begin + m1 = 1 + m2 = [1, 2, 3] + m3 = FFoo(m1, m2, (:y, )) + m4 = FBar(m3, (:x,)) + @test all(fcollect(m4) .=== [m4, m3, m2]) + @test all(fcollect(m4, exclude = x -> x isa Array) .=== [m4, m3]) + @test all(fcollect(m4, exclude = x -> x isa FFoo) .=== [m4]) + + m0 = NoChild2(:a, :b) + m1 = [1, 2, 3] + m2 = FBar(m1, ()) + m3 = FFoo(m2, m0, (:x, :y,)) + m4 = FBar(m3, (:x,)) + @test all(fcollect(m4) .=== [m4, m3, m2, m0]) + end + \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 7db1e1b..18cc997 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,9 +4,8 @@ using LinearAlgebra using StaticArrays @testset "Functors.jl" begin - include("basics.jl") include("base.jl") include("keypath.jl") - + include("flexiblefunctors.jl") end From 66c3893bcabe9d5d2f8877bacdd346c20a3f99c2 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Mon, 21 Oct 2024 09:22:53 +0200 Subject: [PATCH 02/24] rebase --- src/Functors.jl | 7 +++++-- src/functor.jl | 2 -- test/basics.jl | 8 +++++--- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/Functors.jl b/src/Functors.jl index 599ee72..e332d73 100644 --- a/src/Functors.jl +++ b/src/Functors.jl @@ -1,9 +1,12 @@ module Functors -export @functor, @flexiblefunctor, fmap, fmapstructure, fcollect, execute, fleaves, +using ConstructionBase: constructorof + +export @leaf, @functor, @flexiblefunctor, + fmap, fmapstructure, fcollect, execute, fleaves, fmap_with_path, fmapstructure_with_path, KeyPath, getkeypath, haskeypath, setkeypath! -using ConstructionBase: constructorof + include("functor.jl") include("keypath.jl") diff --git a/src/functor.jl b/src/functor.jl index ca6b995..ca426d7 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -11,8 +11,6 @@ macro leaf(T) :($Functors.functor(::Type{<:$(esc(T))}, x) = ($Functors.NoChildren(), _ -> x)) end -# @leaf Any # every type is a leaf by default - # Default functor function functor(T, x) names = fieldnames(T) diff --git a/test/basics.jl b/test/basics.jl index ced07a1..e8d326f 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -130,7 +130,8 @@ end end @testset "Self-referencing types" begin - @test fmap(identity, Base.ImmutableDict(:a => 42)) == Base.ImmutableDict(:a => 42) + # https://github.com/FluxML/Functors.jl/pull/72/ + @test_broken fmap(identity, Base.ImmutableDict(:a => 42)) == Base.ImmutableDict(:a => 42) end @testset "functor(typeof(x), y) from @functor" begin @@ -352,11 +353,12 @@ end Functors.@leaf B b = B(1) children, re = Functors.functor(b) + @test re(children) === b a = LeafType(1) children, re = Functors.functor(a) - @test children == Functors.NoChildren() - @test re(children) === b + @test children == Functors.NoChildren() + @test re(children) === a end @testset "IterateWalk" begin From 28854e53d7c1df34683c267fbc97af1f509c17d4 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Mon, 21 Oct 2024 09:39:46 +0200 Subject: [PATCH 03/24] docs --- Project.toml | 4 +++- docs/src/index.md | 11 +++++------ src/Functors.jl | 12 +++++++----- src/base.jl | 8 -------- src/keypath.jl | 3 --- 5 files changed, 15 insertions(+), 23 deletions(-) diff --git a/Project.toml b/Project.toml index c21d0d2..d5ede63 100644 --- a/Project.toml +++ b/Project.toml @@ -1,13 +1,15 @@ name = "Functors" uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" authors = ["Mike J Innes "] -version = "0.4.12" +version = "0.5.0" [deps] +Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [compat] +Compat = "4.16" ConstructionBase = "1.4" julia = "1.6" diff --git a/docs/src/index.md b/docs/src/index.md index 8f0325f..f0da7d9 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -4,13 +4,13 @@ Functors.jl provides a set of tools to represent [functors](https://en.wikipedia The most straightforward use is to traverse a complicated nested structure as a tree, and apply a function `f` to every field it encounters along the way. -For large models it can be cumbersome or inefficient to work with parameters as one big, flat vector, and structs help manage complexity; but it may be desirable to easily operate over all parameters at once, e.g. for changing precision or applying an optimiser update step. +For large machine learning models it can be cumbersome or inefficient to work with parameters as one big, flat vector, and structs help manage complexity; but it may be desirable to easily operate over all parameters at once, e.g. for changing precision or applying an optimiser update step. ## Basic Usage and Implementation -By default, julia types are marked as [`@functor`](@ref)s, meaning that Functors.jl is allowed to look into the fields of the instances of the struct and modify them. This is achieved through [`Functors.fmap`](@ref). +By default, julia types are marked as [`@functor`](@ref)s, meaning that Functors.jl is allowed to look into the fields of the instances of the struct and modify them. This is achieved through [`fmap`](@ref). -The workhorse of `fmap` is actually a lower level function, functor: +The workhorse of `fmap` is actually a lower level function, [`functor`](@ref): ```julia-repl julia> using Functors @@ -50,13 +50,12 @@ Baz(1.0, 2) Any field not in the list will be passed through as-is during reconstruction. This is done by invoking the default constructor accepting all fields as arguments, so structs that define custom inner constructors are expected to provide one that acts like the default. -The use of `@functor` with no fields argument as in `@functor Baz` is equivalent to `@functor Baz fieldnames(Baz)` -and also equivalent to avoiding `@functor` altogether. +The use of `@functor` with no fields argument as in `@functor Baz` is equivalent to `@functor Baz fieldnames(Baz)` and also equivalent to avoiding `@functor` altogether. Using [`@leaf`](@ref) instead of [`@functor`](@ref) will prevent the fields of a struct from being traversed. !!! warning "Change to opt-out behaviour in v0.5" - Previous releases of functors, up to v0.4, used an opt-in behaviour where structs were not functors unless marked with `@functor`. This was changed in v0.5 to an opt-out behaviour where structs are functors unless marked with `@leaf`. + Previous releases of functors, up to v0.4, used an opt-in behaviour where structs were leaves unless marked with `@functor`. This was changed in v0.5 to an opt-out behaviour where structs are functors unless marked with `@leaf`. ## Appropriate Use diff --git a/src/Functors.jl b/src/Functors.jl index e332d73..b4b5833 100644 --- a/src/Functors.jl +++ b/src/Functors.jl @@ -1,12 +1,13 @@ module Functors - +using Compat: @compat using ConstructionBase: constructorof -export @leaf, @functor, @flexiblefunctor, +export @leaf, @functor, @flexiblefunctor, fmap, fmapstructure, fcollect, execute, fleaves, fmap_with_path, fmapstructure_with_path, KeyPath, getkeypath, haskeypath, setkeypath! +@compat(public, (isleaf, children, functor)) include("functor.jl") include("keypath.jl") @@ -20,7 +21,8 @@ include("base.jl") """ - Functors.functor(x) = functor(typeof(x), x) + functor(x) + functor(typeof(x), x) Returns a tuple containing, first, a `NamedTuple` of the children of `x` (typically its fields), and second, a reconstruction funciton. @@ -75,7 +77,7 @@ TwoThirds(Foo(10, 20), Foo(3, 4), 560) var"@functor" """ - Functors.isleaf(x) + isleaf(x) Return true if `x` has no [`children`](@ref) according to [`functor`](@ref). @@ -103,7 +105,7 @@ true isleaf """ - Functors.children(x) + children(x) Return the children of `x` as defined by [`functor`](@ref). Equivalent to `functor(x)[1]`. diff --git a/src/base.jl b/src/base.jl index 2691e73..cc877a4 100644 --- a/src/base.jl +++ b/src/base.jl @@ -1,14 +1,6 @@ functor(::Type{<:Base.ComposedFunction}, x) = (outer = x.outer, inner = x.inner), y -> Base.ComposedFunction(y.outer, y.inner) -@static if VERSION >= v"1.9" - @functor Base.Splat -end - -@static if VERSION >= v"1.7" - @functor Base.Returns -end - ### ### Array wrappers ### diff --git a/src/keypath.jl b/src/keypath.jl index 280ef94..f953ebe 100644 --- a/src/keypath.jl +++ b/src/keypath.jl @@ -55,9 +55,6 @@ struct KeyPath{T<:Tuple} keys::T end -@functor KeyPath -isleaf(::KeyPath, @nospecialize(x)) = isleaf(x) - function KeyPath(keys::Union{KeyT, KeyPath}...) ks = (k isa KeyPath ? (k.keys...,) : (k,) for k in keys) return KeyPath(((ks...)...,)) From 430844d88d6e0be223193f8ad7137d3f2d74bb38 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Mon, 21 Oct 2024 09:44:40 +0200 Subject: [PATCH 04/24] fix --- src/keypath.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/keypath.jl b/src/keypath.jl index f953ebe..d6f0b17 100644 --- a/src/keypath.jl +++ b/src/keypath.jl @@ -55,6 +55,8 @@ struct KeyPath{T<:Tuple} keys::T end +isleaf(::KeyPath, @nospecialize(x)) = isleaf(x) + function KeyPath(keys::Union{KeyT, KeyPath}...) ks = (k isa KeyPath ? (k.keys...,) : (k,) for k in keys) return KeyPath(((ks...)...,)) From 3c76a03ff38f81d4871b8fca7d6d12c7641c6b2c Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Mon, 21 Oct 2024 14:36:41 +0200 Subject: [PATCH 05/24] numbers are leaves --- src/base.jl | 10 +++++++++- src/functor.jl | 7 ------- test/base.jl | 7 +++++++ 3 files changed, 16 insertions(+), 8 deletions(-) diff --git a/src/base.jl b/src/base.jl index cc877a4..33a02d8 100644 --- a/src/base.jl +++ b/src/base.jl @@ -1,5 +1,13 @@ -functor(::Type{<:Base.ComposedFunction}, x) = (outer = x.outer, inner = x.inner), y -> Base.ComposedFunction(y.outer, y.inner) +@leaf Number + +functor(::Type{<:Tuple}, x) = x, identity +functor(::Type{<:NamedTuple{L}}, x) where L = NamedTuple{L}(map(s -> getproperty(x, s), L)), identity +functor(::Type{<:Dict}, x) = Dict(k => x[k] for k in keys(x)), identity + +functor(::Type{<:AbstractArray}, x) = x, identity +@leaf AbstractArray{<:Number} + ### ### Array wrappers diff --git a/src/functor.jl b/src/functor.jl index ca426d7..de9318e 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -24,13 +24,6 @@ end functor(x) = functor(typeof(x), x) -functor(::Type{<:Tuple}, x) = x, identity -functor(::Type{<:NamedTuple{L}}, x) where L = NamedTuple{L}(map(s -> getproperty(x, s), L)), identity -functor(::Type{<:Dict}, x) = Dict(k => x[k] for k in keys(x)), identity - -functor(::Type{<:AbstractArray}, x) = x, identity -@leaf AbstractArray{<:Number} - function makefunctor(m::Module, T, fs = fieldnames(T)) fidx = Ref(0) escargs = map(fieldnames(T)) do f diff --git a/test/base.jl b/test/base.jl index e06d415..d41e710 100644 --- a/test/base.jl +++ b/test/base.jl @@ -1,3 +1,10 @@ +@testset "Numbers are leaves" begin + @test Functors.isleaf(1) + @test Functors.isleaf(1.0) + @test Functors.isleaf(1im) + @test Functors.isleaf(1//2) + @test Functors.isleaf(1.0 + 2.0im) +end @testset "RefValue" begin @test fmap(sqrt, Ref(16))[] == 4.0 From 0ca8b1cb8c1ab96a3e0fbaff5286cc6d8912bdad Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Mon, 21 Oct 2024 14:47:14 +0200 Subject: [PATCH 06/24] docs --- src/Functors.jl | 23 +++++++---------------- src/base.jl | 1 - 2 files changed, 7 insertions(+), 17 deletions(-) diff --git a/src/Functors.jl b/src/Functors.jl index b4b5833..6b13373 100644 --- a/src/Functors.jl +++ b/src/Functors.jl @@ -1,6 +1,7 @@ module Functors using Compat: @compat using ConstructionBase: constructorof +using LinearAlgebra export @leaf, @functor, @flexiblefunctor, fmap, fmapstructure, fcollect, execute, fleaves, @@ -169,20 +170,16 @@ Thus the relationship `x.i === x.iv[1]` will be preserved. An immutable object which appears twice is not stored in the cache, thus `f(34)` will be called twice, and the results will agree only if `f` is pure. -By default, `Tuple`s, `NamedTuple`s, and some other container-like types in Base have -children to recurse into. Arrays of numbers do not. -To enable recursion into new types, you must provide a method of [`functor`](@ref), -which can be done using the macro [`@functor`](@ref): +By default, almost all container-like types have children to recurse into. +Arrays of numbers do not. + +To opt out of recursion for custom types use [`@leaf`](@ref) or pass a custom `exclude` function. ```jldoctest withfoo julia> struct Foo; x; y; end -julia> @functor Foo - julia> struct Bar; x; end -julia> @functor Bar - julia> m = Foo(Bar([1,2,3]), (4, 5, Bar(Foo(6, 7)))); julia> fmap(x -> 10x, m) @@ -247,8 +244,6 @@ See also [`fmap`](@ref) and [`fmapstructure_with_path`](@ref). ```jldoctest julia> struct Foo; x; y; end -julia> @functor Foo - julia> m = Foo([1,2,3], [4, (5, 6), Foo(7, 8)]); julia> fmapstructure(x -> 2x, m) @@ -285,14 +280,12 @@ See also [`children`](@ref). ```jldoctest julia> struct Foo; x; y; end -julia> @functor Foo - julia> struct Bar; x; end -julia> @functor Bar - julia> struct TypeWithNoChildren; x; y; end +julia> @leaf TypeWithNoChildren + julia> m = Foo(Bar([1,2,3]), TypeWithNoChildren(:a, :b)) Foo(Bar([1, 2, 3]), TypeWithNoChildren(:a, :b)) @@ -376,8 +369,6 @@ See also [`fcollect`](@ref) for a similar function that collects all nodes inste ```jldoctest julia> struct Bar; x; end -julia> @functor Bar - julia> struct TypeWithNoChildren; x; y; end julia> m = (a = Bar([1,2,3]), b = TypeWithNoChildren(4, 5)); diff --git a/src/base.jl b/src/base.jl index 33a02d8..247288e 100644 --- a/src/base.jl +++ b/src/base.jl @@ -13,7 +13,6 @@ functor(::Type{<:AbstractArray}, x) = x, identity ### Array wrappers ### -using LinearAlgebra # The reason for these is to let W and W' be seen as tied weights in Flux models. # Can't treat ReshapedArray very well, as its type doesn't include enough details for reconstruction. From db961706e7f5484ab6394172bea44f2690e6ae22 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Mon, 21 Oct 2024 14:53:17 +0200 Subject: [PATCH 07/24] docs --- docs/src/index.md | 2 +- src/Functors.jl | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/src/index.md b/docs/src/index.md index f0da7d9..ac941c5 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -55,7 +55,7 @@ The use of `@functor` with no fields argument as in `@functor Baz` is equivalent Using [`@leaf`](@ref) instead of [`@functor`](@ref) will prevent the fields of a struct from being traversed. !!! warning "Change to opt-out behaviour in v0.5" - Previous releases of functors, up to v0.4, used an opt-in behaviour where structs were leaves unless marked with `@functor`. This was changed in v0.5 to an opt-out behaviour where structs are functors unless marked with `@leaf`. + Previous releases of functors, up to v0.4, used an opt-in behaviour where structs were leaves functors unless marked with `@functor`. This was changed in v0.5 to an opt-out behaviour where structs are functors unless marked with `@leaf`. ## Appropriate Use diff --git a/src/Functors.jl b/src/Functors.jl index 6b13373..db3c731 100644 --- a/src/Functors.jl +++ b/src/Functors.jl @@ -371,6 +371,8 @@ julia> struct Bar; x; end julia> struct TypeWithNoChildren; x; y; end +julia> @leaf TypeWithNoChildren + julia> m = (a = Bar([1,2,3]), b = TypeWithNoChildren(4, 5)); julia> fleaves(m) From 88ffc38380a11fe0e709a43631092dfb99a83bc3 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Mon, 21 Oct 2024 15:00:31 +0200 Subject: [PATCH 08/24] docs --- docs/src/index.md | 6 ++++-- src/Functors.jl | 14 +++++++------- src/functor.jl | 2 +- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/docs/src/index.md b/docs/src/index.md index ac941c5..d93066f 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -8,9 +8,11 @@ For large machine learning models it can be cumbersome or inefficient to work wi ## Basic Usage and Implementation -By default, julia types are marked as [`@functor`](@ref)s, meaning that Functors.jl is allowed to look into the fields of the instances of the struct and modify them. This is achieved through [`fmap`](@ref). +By default, julia types are marked as [`@functor`](@ref Functors.functor)s, meaning that Functors.jl is allowed to look into the fields of the instances of the struct and modify them. This is achieved through [`fmap`](@ref). To opt-out of this behaviour, use [`@leaf`](@ref) on your custom type. -The workhorse of `fmap` is actually a lower level function, [`functor`](@ref): +```julia-repl + +The workhorse of `fmap` is actually a lower level function, [`functor`](@ref Functors.functor): ```julia-repl julia> using Functors diff --git a/src/Functors.jl b/src/Functors.jl index db3c731..be39349 100644 --- a/src/Functors.jl +++ b/src/Functors.jl @@ -38,7 +38,7 @@ functor @functor T @functor T (x,) -Adds methods to [`functor`](@ref) allowing recursion into objects of type `T`, +Adds methods to [`functor`](@ref Functors.functor) allowing recursion into objects of type `T`, and reconstruction. Assumes that `T` has a constructor accepting all of its fields, which is true unless you have provided an inner constructor which does not. @@ -80,7 +80,7 @@ var"@functor" """ isleaf(x) -Return true if `x` has no [`children`](@ref) according to [`functor`](@ref). +Return true if `x` has no [`children`](@ref Functors.children) according to [`functor`](@ref Functors.functor). # Examples ```jldoctest @@ -108,7 +108,7 @@ isleaf """ children(x) -Return the children of `x` as defined by [`functor`](@ref). +Return the children of `x` as defined by [`functor`](@ref Functors.functor). Equivalent to `functor(x)[1]`. """ children @@ -118,8 +118,8 @@ children A structure and type preserving `map`. -By default it transforms every leaf node (identified by `exclude`, default [`isleaf`](@ref)) -by applying `f`, and otherwise traverses `x` recursively using [`functor`](@ref). +By default it transforms every leaf node (identified by `exclude`, default [`isleaf`](@ref Functors.isleaf)) +by applying `f`, and otherwise traverses `x` recursively using [`functor`](@ref Functors.functor). Optionally, it may also be associated with objects `ys` with the same tree structure. In that case, `f` is applied to the corresponding leaf nodes in `x` and `ys`. @@ -264,7 +264,7 @@ fmapstructure """ fcollect(x; exclude = v -> false) -Traverse `x` by recursing each child of `x` as defined by [`functor`](@ref) +Traverse `x` by recursing each child of `x` as defined by [`functor`](@ref Functors.functor) and collecting the results into a flat array, ordered by a breadth-first traversal of `x`, respecting the iteration order of `children` calls. @@ -355,7 +355,7 @@ fmapstructure_with_path """ fleaves(x; exclude = isleaf) -Traverse `x` by recursing each child of `x` as defined by [`functor`](@ref) +Traverse `x` by recursing each child of `x` as defined by [`functor`](@ref Functors.functor) and collecting the leaves into a flat array, ordered by a breadth-first traversal of `x`, respecting the iteration order of `children` calls. diff --git a/src/functor.jl b/src/functor.jl index de9318e..a770652 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -5,7 +5,7 @@ const NoChildren = Tuple{} """ @leaf T -Define [`functor`](@ref) for the type `T` so that `isleaf(x::T) == true`. +Define [`functor`](@ref Functors.functor) for the type `T` so that `isleaf(x::T) == true`. """ macro leaf(T) :($Functors.functor(::Type{<:$(esc(T))}, x) = ($Functors.NoChildren(), _ -> x)) From a98f9c6d9da4d2cd5538a7f3e2a3acb2b2317f32 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Mon, 21 Oct 2024 15:40:45 +0200 Subject: [PATCH 09/24] opt out of abstractdicts --- src/base.jl | 12 ++++++++++-- test/base.jl | 5 +++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/src/base.jl b/src/base.jl index 247288e..0a8d094 100644 --- a/src/base.jl +++ b/src/base.jl @@ -1,12 +1,20 @@ +## Opt-Out @leaf Number +@leaf AbstractArray{<:Number} +@leaf AbstractDict # We are conservative here + # since most probably default functor does the wrong thing +## Fast Paths for common types functor(::Type{<:Tuple}, x) = x, identity functor(::Type{<:NamedTuple{L}}, x) where L = NamedTuple{L}(map(s -> getproperty(x, s), L)), identity functor(::Type{<:Dict}, x) = Dict(k => x[k] for k in keys(x)), identity - functor(::Type{<:AbstractArray}, x) = x, identity -@leaf AbstractArray{<:Number} + +# TODO: evaluate if this is a good idea +# function functor(::Type{<:D}, x) where {D<:AbstractDict} +# return constructorof(D)(k => x[k] for k in keys(x)), identity +# end ### diff --git a/test/base.jl b/test/base.jl index d41e710..be90f3d 100644 --- a/test/base.jl +++ b/test/base.jl @@ -179,3 +179,8 @@ end @test x.is[1] isa Vector{<:Complex} @test collect(x) isa Vector{<:Tuple{Complex, Complex}} end + +@testset "AbstractDict is leaf" begin + struct DummyDict <: AbstractDict end + @test Functors.isleaf(DummyDict()) +end \ No newline at end of file From 11c1ae151e71b63eb0c2de6bfd4528eda189672e Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Mon, 21 Oct 2024 15:46:01 +0200 Subject: [PATCH 10/24] fix --- test/base.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/base.jl b/test/base.jl index be90f3d..1191a90 100644 --- a/test/base.jl +++ b/test/base.jl @@ -181,6 +181,6 @@ end end @testset "AbstractDict is leaf" begin - struct DummyDict <: AbstractDict end - @test Functors.isleaf(DummyDict()) + struct DummyDict{K,V} <: AbstractDict{K,V} end + @test Functors.isleaf(DummyDict{Int,Int}()) end \ No newline at end of file From 7d60887b02469d20ab98043922e07c0dbc8af56c Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Mon, 21 Oct 2024 16:41:11 +0200 Subject: [PATCH 11/24] abstract dictor --- docs/src/index.md | 9 +++++++++ src/base.jl | 12 ++++++------ test/base.jl | 8 ++++---- 3 files changed, 19 insertions(+), 10 deletions(-) diff --git a/docs/src/index.md b/docs/src/index.md index d93066f..318cd11 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -59,6 +59,15 @@ Using [`@leaf`](@ref) instead of [`@functor`](@ref) will prevent the fields of a !!! warning "Change to opt-out behaviour in v0.5" Previous releases of functors, up to v0.4, used an opt-in behaviour where structs were leaves functors unless marked with `@functor`. This was changed in v0.5 to an opt-out behaviour where structs are functors unless marked with `@leaf`. +## Which types are leaves? + +By default all composite types in are functors and can be traversed, unless marked with [`@leaf`](@ref). + +The following types instead are explicitly marked as leaves in Functors.jl: +- `Number` +- `AbstractArray{<:Number}` +- `AbstractString` + ## Appropriate Use Typically, since any function `f` is applied to the leaves of the tree, but it is possible for some functions to require dispatching on the specific type of the fields causing some methods to be missed entirely. diff --git a/src/base.jl b/src/base.jl index 0a8d094..bc78f8e 100644 --- a/src/base.jl +++ b/src/base.jl @@ -2,8 +2,7 @@ @leaf Number @leaf AbstractArray{<:Number} -@leaf AbstractDict # We are conservative here - # since most probably default functor does the wrong thing +@leaf AbstractString ## Fast Paths for common types functor(::Type{<:Tuple}, x) = x, identity @@ -11,10 +10,11 @@ functor(::Type{<:NamedTuple{L}}, x) where L = NamedTuple{L}(map(s -> getproperty functor(::Type{<:Dict}, x) = Dict(k => x[k] for k in keys(x)), identity functor(::Type{<:AbstractArray}, x) = x, identity -# TODO: evaluate if this is a good idea -# function functor(::Type{<:D}, x) where {D<:AbstractDict} -# return constructorof(D)(k => x[k] for k in keys(x)), identity -# end +## This may be a reasonable default for AbstractDict +## but is not guaranteed to be correct for all dict subtypes +function functor(::Type{D}, x) where {D<:AbstractDict} + return constructorof(D)(k => x[k] for k in keys(x)), identity +end ### diff --git a/test/base.jl b/test/base.jl index 1191a90..698b825 100644 --- a/test/base.jl +++ b/test/base.jl @@ -180,7 +180,7 @@ end @test collect(x) isa Vector{<:Tuple{Complex, Complex}} end -@testset "AbstractDict is leaf" begin - struct DummyDict{K,V} <: AbstractDict{K,V} end - @test Functors.isleaf(DummyDict{Int,Int}()) -end \ No newline at end of file +# @testset "AbstractDict is leaf" begin +# struct DummyDict{K,V} <: AbstractDict{K,V} end +# @test Functors.isleaf(DummyDict{Int,Int}()) +# end \ No newline at end of file From 628550707d27596d99980e25ba6e06d70eca095f Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Mon, 21 Oct 2024 17:42:46 +0200 Subject: [PATCH 12/24] docs and tests --- Project.toml | 4 +++- docs/src/index.md | 11 ++++++++--- src/walks.jl | 5 +++-- test/base.jl | 20 ++++++++++++++++---- test/runtests.jl | 1 + 5 files changed, 31 insertions(+), 10 deletions(-) diff --git a/Project.toml b/Project.toml index d5ede63..ff52cb7 100644 --- a/Project.toml +++ b/Project.toml @@ -11,12 +11,14 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [compat] Compat = "4.16" ConstructionBase = "1.4" +OrderedCollections = "1.6" julia = "1.6" [extras] +OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Test", "StaticArrays", "Zygote"] +test = ["Test", "OrderedCollections", "StaticArrays", "Zygote"] diff --git a/docs/src/index.md b/docs/src/index.md index 318cd11..4ed2a40 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -68,8 +68,13 @@ The following types instead are explicitly marked as leaves in Functors.jl: - `AbstractArray{<:Number}` - `AbstractString` -## Appropriate Use +This is because in typical application the internals of these are abstracted away and it is not desirable to traverse them. -Typically, since any function `f` is applied to the leaves of the tree, but it is possible for some functions to require dispatching on the specific type of the fields causing some methods to be missed entirely. +## What if I get an error? -Examples of this include element types of arrays which typically have their own mathematical operations defined. Adding a [`@functor`](@ref) to such a type would end up missing methods such as `+(::MyElementType, ::MyElementType)`. Think `RGB` from Colors.jl. +Since by default Funcotrs.jl tries to traverse most types e.g. when using [`fmap`](@ref), it is possible it fails in case the type has not an appropriate constructor. If use experience this issue, you have a few alternatives: +- Mark the type as a leaf using [`@leaf`](@ref) +- Use the `@functor` macro to specify which fields to traverse. +- Define an appropriate constructor for the type. + +If you are not able to traverse types in julia Base, please open an issue. diff --git a/src/walks.jl b/src/walks.jl index 334cce8..928cc3f 100644 --- a/src/walks.jl +++ b/src/walks.jl @@ -9,11 +9,12 @@ function check_lenghts(x, ys...) end _map(f, x::Dict, ys...) = Dict(k => f(v, (y[k] for y in ys)...) for (k, v) in x) +_map(f, x::D, ys...) where {D<:AbstractDict} = constructorof(D)(k => f(v, (y[k] for y in ys)...) for (k, v) in x) _values(x) = x -_values(x::Dict) = values(x) +_values(x::AbstractDict) = values(x) -_keys(x::Dict) = Dict(k => k for k in keys(x)) +_keys(x::D) where {D <: AbstractDict} = constructorof(D)(k => k for k in keys(x)) _keys(x::Tuple) = (keys(x)...,) _keys(x::AbstractArray) = collect(keys(x)) _keys(x::NamedTuple{Ks}) where Ks = NamedTuple{Ks}(Ks) diff --git a/test/base.jl b/test/base.jl index 698b825..78078e3 100644 --- a/test/base.jl +++ b/test/base.jl @@ -180,7 +180,19 @@ end @test collect(x) isa Vector{<:Tuple{Complex, Complex}} end -# @testset "AbstractDict is leaf" begin -# struct DummyDict{K,V} <: AbstractDict{K,V} end -# @test Functors.isleaf(DummyDict{Int,Int}()) -# end \ No newline at end of file +@testset "AbstractString is leaf" begin + struct DummyString <: AbstractString + str::String + end + s = DummyString("hello") + @test Functors.isleaf(s) +end + +@testset "AbstractDict is functor" begin + od = OrderedDict(1 => 1, 2 => 2) + @test !Functors.isleaf(od) + od2 = fmap(x -> 2x, od) + @test od2 isa OrderedDict + @test od2[1] == 2 + @test od2[2] == 4 +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 18cc997..65658f1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,6 +2,7 @@ using Functors, Test using Zygote using LinearAlgebra using StaticArrays +using OrderedCollections: OrderedDict @testset "Functors.jl" begin include("basics.jl") From c43f2ea8c24c8be4913721aac30b7fe1502c1126 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Tue, 22 Oct 2024 00:09:02 +0200 Subject: [PATCH 13/24] docs --- docs/src/index.md | 8 ++++---- src/Functors.jl | 2 +- src/base.jl | 2 +- src/walks.jl | 3 ++- 4 files changed, 8 insertions(+), 7 deletions(-) diff --git a/docs/src/index.md b/docs/src/index.md index 4ed2a40..805e9bf 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -64,15 +64,15 @@ Using [`@leaf`](@ref) instead of [`@functor`](@ref) will prevent the fields of a By default all composite types in are functors and can be traversed, unless marked with [`@leaf`](@ref). The following types instead are explicitly marked as leaves in Functors.jl: -- `Number` -- `AbstractArray{<:Number}` -- `AbstractString` +- `Number`. +- `AbstractArray{<:Number}`, except for the wrappers `Transpose`, `Adjoint`, and `PermutedDimsArray`. +- `AbstractString`. This is because in typical application the internals of these are abstracted away and it is not desirable to traverse them. ## What if I get an error? -Since by default Funcotrs.jl tries to traverse most types e.g. when using [`fmap`](@ref), it is possible it fails in case the type has not an appropriate constructor. If use experience this issue, you have a few alternatives: +Since by default Functors.jl tries to traverse most types e.g. when using [`fmap`](@ref), it is possible it fails in case the type has not an appropriate constructor. If use experience this issue, you have a few alternatives: - Mark the type as a leaf using [`@leaf`](@ref) - Use the `@functor` macro to specify which fields to traverse. - Define an appropriate constructor for the type. diff --git a/src/Functors.jl b/src/Functors.jl index be39349..f99c665 100644 --- a/src/Functors.jl +++ b/src/Functors.jl @@ -26,7 +26,7 @@ include("base.jl") functor(typeof(x), x) Returns a tuple containing, first, a `NamedTuple` of the children of `x` -(typically its fields), and second, a reconstruction funciton. +(typically its fields), and second, a reconstruction function. This controls the behaviour of [`fmap`](@ref). Methods should be added to `functor(::Type{T}, x)` for custom types, diff --git a/src/base.jl b/src/base.jl index bc78f8e..c94916d 100644 --- a/src/base.jl +++ b/src/base.jl @@ -43,7 +43,7 @@ function functor(::Type{<:PermutedDimsArray{T,N,perm,iperm}}, x) where {T,N,perm (parent = _PermutedDimsArray(x, iperm),), y -> PermutedDimsArray(only(y), perm) end function functor(::Type{<:PermutedDimsArray{T,N,perm,iperm}}, x::PermutedDimsArray{Tx,N,perm,iperm}) where {T,Tx,N,perm,iperm} - (parent = parent(x),), y -> PermutedDimsArray(only(y), perm) # most common case, avoid wrapping wrice. + (parent = parent(x),), y -> PermutedDimsArray(only(y), perm) # most common case, avoid wrapping twice. end _PermutedDimsArray(x, iperm) = PermutedDimsArray(x, iperm) diff --git a/src/walks.jl b/src/walks.jl index 928cc3f..856b622 100644 --- a/src/walks.jl +++ b/src/walks.jl @@ -9,7 +9,8 @@ function check_lenghts(x, ys...) end _map(f, x::Dict, ys...) = Dict(k => f(v, (y[k] for y in ys)...) for (k, v) in x) -_map(f, x::D, ys...) where {D<:AbstractDict} = constructorof(D)(k => f(v, (y[k] for y in ys)...) for (k, v) in x) +_map(f, x::D, ys...) where {D<:AbstractDict} = + constructorof(D)(k => f(v, (y[k] for y in ys)...) for (k, v) in x) _values(x) = x _values(x::AbstractDict) = values(x) From be05d81f7ee094003c47aaea10956f41b08b96e2 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Sat, 26 Oct 2024 11:50:38 +0200 Subject: [PATCH 14/24] don't bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index ff52cb7..21fa53d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Functors" uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" authors = ["Mike J Innes "] -version = "0.5.0" +version = "0.4.12" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" From 5cd237d6c959ebb28f41e816c1239ae5503c27dd Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Sat, 26 Oct 2024 12:49:02 +0200 Subject: [PATCH 15/24] selfref --- Project.toml | 4 +++- src/base.jl | 3 +-- src/walks.jl | 2 +- test/basics.jl | 4 +++- test/runtests.jl | 1 + 5 files changed, 9 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index 21fa53d..3d47588 100644 --- a/Project.toml +++ b/Project.toml @@ -11,14 +11,16 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [compat] Compat = "4.16" ConstructionBase = "1.4" +Measurements = "2" OrderedCollections = "1.6" julia = "1.6" [extras] +Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Test", "OrderedCollections", "StaticArrays", "Zygote"] +test = ["Test", "OrderedCollections", "StaticArrays", "Zygote", "Measurements"] diff --git a/src/base.jl b/src/base.jl index c94916d..07d87fb 100644 --- a/src/base.jl +++ b/src/base.jl @@ -13,10 +13,9 @@ functor(::Type{<:AbstractArray}, x) = x, identity ## This may be a reasonable default for AbstractDict ## but is not guaranteed to be correct for all dict subtypes function functor(::Type{D}, x) where {D<:AbstractDict} - return constructorof(D)(k => x[k] for k in keys(x)), identity + return constructorof(D)([k => x[k] for k in keys(x)]...), identity end - ### ### Array wrappers ### diff --git a/src/walks.jl b/src/walks.jl index 856b622..2426a88 100644 --- a/src/walks.jl +++ b/src/walks.jl @@ -10,7 +10,7 @@ end _map(f, x::Dict, ys...) = Dict(k => f(v, (y[k] for y in ys)...) for (k, v) in x) _map(f, x::D, ys...) where {D<:AbstractDict} = - constructorof(D)(k => f(v, (y[k] for y in ys)...) for (k, v) in x) + constructorof(D)([k => f(v, (y[k] for y in ys)...) for (k, v) in x]...) _values(x) = x _values(x::AbstractDict) = values(x) diff --git a/test/basics.jl b/test/basics.jl index e8d326f..36b0a80 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -131,7 +131,9 @@ end @testset "Self-referencing types" begin # https://github.com/FluxML/Functors.jl/pull/72/ - @test_broken fmap(identity, Base.ImmutableDict(:a => 42)) == Base.ImmutableDict(:a => 42) + @test fmap(identity, Base.ImmutableDict(:a => 42)) == Base.ImmutableDict(:a => 42) + nt = fmap(x -> 2x, (; a = 1 ± 0.1, b = 2 ± 0.2)) + @test nt == (; a = 2 ± 0.2, b = 4 ± 0.4) end @testset "functor(typeof(x), y) from @functor" begin diff --git a/test/runtests.jl b/test/runtests.jl index 65658f1..151ebc4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,6 +3,7 @@ using Zygote using LinearAlgebra using StaticArrays using OrderedCollections: OrderedDict +import Measurements # for ± @testset "Functors.jl" begin include("basics.jl") From 6e9e93d4179c46bd822461a2ffab41d296bad52d Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Mon, 28 Oct 2024 07:55:10 +0100 Subject: [PATCH 16/24] add benchmarks --- benchmarks/Project.toml | 9 +++++++ benchmarks/benchmarks.jl | 55 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+) create mode 100644 benchmarks/Project.toml create mode 100644 benchmarks/benchmarks.jl diff --git a/benchmarks/Project.toml b/benchmarks/Project.toml new file mode 100644 index 0000000..cd75249 --- /dev/null +++ b/benchmarks/Project.toml @@ -0,0 +1,9 @@ +[deps] +AirspeedVelocity = "1c8270ee-6884-45cc-9545-60fa71ec23e4" +BenchmarkPlots = "ab8c0f59-4072-4e0d-8f91-a91e1495eb26" +BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" +Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd" diff --git a/benchmarks/benchmarks.jl b/benchmarks/benchmarks.jl new file mode 100644 index 0000000..9ec3d97 --- /dev/null +++ b/benchmarks/benchmarks.jl @@ -0,0 +1,55 @@ +using BenchmarkTools: BenchmarkTools, BenchmarkGroup, @benchmarkable, @btime, @benchmark, judge +using ConcreteStructs: @concrete +using Flux: Dense, Chain +using LinearAlgebra: BLAS +using Functors +using Statistics: median + +const SUITE = BenchmarkGroup() +const BENCHMARK_CPU_THREADS = Threads.nthreads() +BLAS.set_num_threads(BENCHMARK_CPU_THREADS) + + +@concrete struct A + w + b + σ +end + +struct B + w + b + σ +end + +function setup_fmap_bench!(suite) + a = A(rand(5,5), rand(5), tanh) + suite["fmap"]["concrete struct"] = @benchmarkable fmap(identity, $a) + + a = B(rand(5,5), rand(5), tanh) + suite["fmap"]["non-concrete struct"] = @benchmarkable fmap(identity, $a) + + a = Dense(5, 5, tanh) + suite["fmap"]["flux dense"] = @benchmarkable fmap(identity, $a) + + a = Chain(Dense(5, 5, tanh), Dense(5, 5, tanh)) + suite["fmap"]["flux dense chain"] = @benchmarkable fmap(identity, $a) + + return suite +end + +setup_fmap_bench!(SUITE) + +# results = BenchmarkTools.run(SUITE; verbose=true) + +# filename = joinpath(@__DIR__, "benchmarks_old.json") +# BenchmarkTools.save(filename, median(results)) + + +# # Plot +# using StatsPlots, BenchmarkTools +# plot(results["fmap"], yaxis=:log10, st=:violin) + +# # Compare +# old_results = BenchmarkTools.load("benchmarks.json")[1] +# judge(median(results), old_results) From 5ed2f00afa2430e0458f9229a841035599dfb524 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Mon, 28 Oct 2024 07:57:49 +0100 Subject: [PATCH 17/24] benchamark --- benchmark/Project.toml | 9 +++++++ benchmark/benchmarks.jl | 55 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+) create mode 100644 benchmark/Project.toml create mode 100644 benchmark/benchmarks.jl diff --git a/benchmark/Project.toml b/benchmark/Project.toml new file mode 100644 index 0000000..cd75249 --- /dev/null +++ b/benchmark/Project.toml @@ -0,0 +1,9 @@ +[deps] +AirspeedVelocity = "1c8270ee-6884-45cc-9545-60fa71ec23e4" +BenchmarkPlots = "ab8c0f59-4072-4e0d-8f91-a91e1495eb26" +BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" +Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd" diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl new file mode 100644 index 0000000..9ec3d97 --- /dev/null +++ b/benchmark/benchmarks.jl @@ -0,0 +1,55 @@ +using BenchmarkTools: BenchmarkTools, BenchmarkGroup, @benchmarkable, @btime, @benchmark, judge +using ConcreteStructs: @concrete +using Flux: Dense, Chain +using LinearAlgebra: BLAS +using Functors +using Statistics: median + +const SUITE = BenchmarkGroup() +const BENCHMARK_CPU_THREADS = Threads.nthreads() +BLAS.set_num_threads(BENCHMARK_CPU_THREADS) + + +@concrete struct A + w + b + σ +end + +struct B + w + b + σ +end + +function setup_fmap_bench!(suite) + a = A(rand(5,5), rand(5), tanh) + suite["fmap"]["concrete struct"] = @benchmarkable fmap(identity, $a) + + a = B(rand(5,5), rand(5), tanh) + suite["fmap"]["non-concrete struct"] = @benchmarkable fmap(identity, $a) + + a = Dense(5, 5, tanh) + suite["fmap"]["flux dense"] = @benchmarkable fmap(identity, $a) + + a = Chain(Dense(5, 5, tanh), Dense(5, 5, tanh)) + suite["fmap"]["flux dense chain"] = @benchmarkable fmap(identity, $a) + + return suite +end + +setup_fmap_bench!(SUITE) + +# results = BenchmarkTools.run(SUITE; verbose=true) + +# filename = joinpath(@__DIR__, "benchmarks_old.json") +# BenchmarkTools.save(filename, median(results)) + + +# # Plot +# using StatsPlots, BenchmarkTools +# plot(results["fmap"], yaxis=:log10, st=:violin) + +# # Compare +# old_results = BenchmarkTools.load("benchmarks.json")[1] +# judge(median(results), old_results) From c1faa24e082249263abf1eade855897ca144780d Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Mon, 28 Oct 2024 08:08:24 +0100 Subject: [PATCH 18/24] cleanup benchmark --- benchmark/benchmarks.jl | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl index 9ec3d97..b6c59f6 100644 --- a/benchmark/benchmarks.jl +++ b/benchmark/benchmarks.jl @@ -1,3 +1,13 @@ +# We run the benchmarks using AirspeedVelocity.jl + +# To run benchmarks locally, first install AirspeedVelocity.jl: +# julia> using Pkg; Pkg.add("AirspeedVelocity"); Pkg.build("AirspeedVelocity") +# and make sure .julia/bin is in your PATH. + +# Then commit the changes and run: +# $ benchpkg Functors --rev=master,mybranch --bench-on=mybranch + + using BenchmarkTools: BenchmarkTools, BenchmarkGroup, @benchmarkable, @btime, @benchmark, judge using ConcreteStructs: @concrete using Flux: Dense, Chain @@ -40,16 +50,5 @@ end setup_fmap_bench!(SUITE) +## AirspeedVelocity.jl will automatically run the benchmarks and save the results # results = BenchmarkTools.run(SUITE; verbose=true) - -# filename = joinpath(@__DIR__, "benchmarks_old.json") -# BenchmarkTools.save(filename, median(results)) - - -# # Plot -# using StatsPlots, BenchmarkTools -# plot(results["fmap"], yaxis=:log10, st=:violin) - -# # Compare -# old_results = BenchmarkTools.load("benchmarks.json")[1] -# judge(median(results), old_results) From 2c9ee309e98c21c774ef96465269fc7099c6c433 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Mon, 28 Oct 2024 08:08:52 +0100 Subject: [PATCH 19/24] gitignore --- .gitignore | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.gitignore b/.gitignore index 2cbd804..ed0d812 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,7 @@ Manifest.toml build .vscode +benchmarks*.json +results*.json +*.tmp + From 8296b4de2fe558c21bee5e4d7119fa4cdf02240a Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Mon, 28 Oct 2024 08:18:24 +0100 Subject: [PATCH 20/24] fix tests --- benchmarks/Project.toml | 9 ------- benchmarks/benchmarks.jl | 55 ---------------------------------------- test/runtests.jl | 2 +- 3 files changed, 1 insertion(+), 65 deletions(-) delete mode 100644 benchmarks/Project.toml delete mode 100644 benchmarks/benchmarks.jl diff --git a/benchmarks/Project.toml b/benchmarks/Project.toml deleted file mode 100644 index cd75249..0000000 --- a/benchmarks/Project.toml +++ /dev/null @@ -1,9 +0,0 @@ -[deps] -AirspeedVelocity = "1c8270ee-6884-45cc-9545-60fa71ec23e4" -BenchmarkPlots = "ab8c0f59-4072-4e0d-8f91-a91e1495eb26" -BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" -ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" -Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" -Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" -Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd" diff --git a/benchmarks/benchmarks.jl b/benchmarks/benchmarks.jl deleted file mode 100644 index 9ec3d97..0000000 --- a/benchmarks/benchmarks.jl +++ /dev/null @@ -1,55 +0,0 @@ -using BenchmarkTools: BenchmarkTools, BenchmarkGroup, @benchmarkable, @btime, @benchmark, judge -using ConcreteStructs: @concrete -using Flux: Dense, Chain -using LinearAlgebra: BLAS -using Functors -using Statistics: median - -const SUITE = BenchmarkGroup() -const BENCHMARK_CPU_THREADS = Threads.nthreads() -BLAS.set_num_threads(BENCHMARK_CPU_THREADS) - - -@concrete struct A - w - b - σ -end - -struct B - w - b - σ -end - -function setup_fmap_bench!(suite) - a = A(rand(5,5), rand(5), tanh) - suite["fmap"]["concrete struct"] = @benchmarkable fmap(identity, $a) - - a = B(rand(5,5), rand(5), tanh) - suite["fmap"]["non-concrete struct"] = @benchmarkable fmap(identity, $a) - - a = Dense(5, 5, tanh) - suite["fmap"]["flux dense"] = @benchmarkable fmap(identity, $a) - - a = Chain(Dense(5, 5, tanh), Dense(5, 5, tanh)) - suite["fmap"]["flux dense chain"] = @benchmarkable fmap(identity, $a) - - return suite -end - -setup_fmap_bench!(SUITE) - -# results = BenchmarkTools.run(SUITE; verbose=true) - -# filename = joinpath(@__DIR__, "benchmarks_old.json") -# BenchmarkTools.save(filename, median(results)) - - -# # Plot -# using StatsPlots, BenchmarkTools -# plot(results["fmap"], yaxis=:log10, st=:violin) - -# # Compare -# old_results = BenchmarkTools.load("benchmarks.json")[1] -# judge(median(results), old_results) diff --git a/test/runtests.jl b/test/runtests.jl index 151ebc4..e0c5970 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,7 +3,7 @@ using Zygote using LinearAlgebra using StaticArrays using OrderedCollections: OrderedDict -import Measurements # for ± +using Measurements: ± @testset "Functors.jl" begin include("basics.jl") From 57aee9420723b9a63424a3b11f9d54ebdf62f782 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Mon, 28 Oct 2024 10:06:34 +0100 Subject: [PATCH 21/24] add nt bench --- benchmark/benchmarks.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl index b6c59f6..1b57ce7 100644 --- a/benchmark/benchmarks.jl +++ b/benchmark/benchmarks.jl @@ -5,7 +5,7 @@ # and make sure .julia/bin is in your PATH. # Then commit the changes and run: -# $ benchpkg Functors --rev=master,mybranch --bench-on=mybranch +# $ benchpkg Functors --rev=mybranch,master --bench-on=mybranch using BenchmarkTools: BenchmarkTools, BenchmarkGroup, @benchmarkable, @btime, @benchmark, judge @@ -45,6 +45,9 @@ function setup_fmap_bench!(suite) a = Chain(Dense(5, 5, tanh), Dense(5, 5, tanh)) suite["fmap"]["flux dense chain"] = @benchmarkable fmap(identity, $a) + nt = (layers=(w= rand(5,5), b=rand(5), σ=tanh),) + suite["fmap"]["named tuple"] = @benchmarkable fmap(identity, $nt) + return suite end From 5381ed286bb85760bf7a848621e47e0ee9d9609a Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Mon, 28 Oct 2024 19:37:10 +0100 Subject: [PATCH 22/24] update readme --- README.md | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index bf9ddb0..bae2e86 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,9 @@ [action-img]: https://github.com/FluxML/Functors.jl/workflows/CI/badge.svg [action-url]: https://github.com/FluxML/Functors.jl/actions -Functors.jl provides tools to express a powerful design pattern for dealing with large/ nested structures, as in machine learning and optimisation. For large machine learning models it can be cumbersome or inefficient to work with parameters as one big, flat vector, and structs help manage complexity; but it is also desirable to easily operate over all parameters at once, e.g. for changing precision or applying an optimiser update step. +Functors.jl provides tools to express a powerful design pattern for dealing with large / nested structures, as in machine learning and optimisation. For large machine learning models it can be cumbersome or inefficient to work with parameters as one big, flat vector, and structs help manage complexity; but it is also desirable to easily operate over all parameters at once, e.g. for changing precision or applying an optimiser update step. + +## Basic Usage Functors.jl provides `fmap` to make those things easy, acting as a 'map over parameters': @@ -46,17 +48,25 @@ julia> fmap(float, model) Bar(Foo(1.0, [1.0, 2.0, 3.0])) ``` +> [!NOTE] +> Up to to v0.4 Functors.jl's functionality on custom type had to be opted in via the `@functor Foo` macro call. +> With v0.5, this no longer necessary: by default any type is recursively traversed up to the leaves +> and `ConstructionBase.constructorof` is used to reconstruct the type. +> In order to opt-out of this behaviour and make a type non traversable you can use `@leaf Foo`. + +## Further Details + The workhorse of `fmap` is actually a lower level function, `functor`: ```julia -julia> xs, re = functor(Foo(1, [1, 2, 3])) -((x = 1, y = [1, 2, 3]), var"#21#22"()) +julia> children, reconstruct = Functors.functor(Foo(1, [1, 2, 3])) +((x = 1, y = [1, 2, 3]), Functors.var"#3#6"{DataType}(Foo)) -julia> re(map(float, xs)) +julia> reconstruct(map(float, children)) Foo(1.0, [1.0, 2.0, 3.0]) ``` -`functor` returns the parts of the object that can be inspected, as well as a `re` function that takes those values and restructures them back into an object of the original type. +`functor` returns the parts of the object that can be inspected, as well as a `reconstruct` function that takes those values and restructures them back into an object of the original type. To include only certain fields, pass a tuple of field names to `@functor`: From 40810437b33f3086ff9019b9be47c1ef388c28ff Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Mon, 28 Oct 2024 19:53:30 +0100 Subject: [PATCH 23/24] cleanup readme --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index bae2e86..ba7a687 100644 --- a/README.md +++ b/README.md @@ -49,9 +49,9 @@ Bar(Foo(1.0, [1.0, 2.0, 3.0])) ``` > [!NOTE] -> Up to to v0.4 Functors.jl's functionality on custom type had to be opted in via the `@functor Foo` macro call. -> With v0.5, this no longer necessary: by default any type is recursively traversed up to the leaves -> and `ConstructionBase.constructorof` is used to reconstruct the type. +> Up to to v0.4, Functors.jl's functionality had to be opted in on custom types via the `@functor Foo` macro call. +> With v0.5 instead, this is no longer necessary: by default any type is recursively traversed up to the leaves +> and `ConstructionBase.constructorof` is used to reconstruct it. > In order to opt-out of this behaviour and make a type non traversable you can use `@leaf Foo`. ## Further Details From 229691ad724e00cfcf2f819ea9c847c7818fabbf Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Wed, 30 Oct 2024 09:54:55 +0100 Subject: [PATCH 24/24] fix Base.Fix on nightly --- src/base.jl | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/src/base.jl b/src/base.jl index 07d87fb..63d8ee5 100644 --- a/src/base.jl +++ b/src/base.jl @@ -1,21 +1,35 @@ -## Opt-Out +### +### Opt-Out +### @leaf Number @leaf AbstractArray{<:Number} @leaf AbstractString -## Fast Paths for common types +### +### Fast Paths for common types +### + functor(::Type{<:Tuple}, x) = x, identity functor(::Type{<:NamedTuple{L}}, x) where L = NamedTuple{L}(map(s -> getproperty(x, s), L)), identity functor(::Type{<:Dict}, x) = Dict(k => x[k] for k in keys(x)), identity functor(::Type{<:AbstractArray}, x) = x, identity -## This may be a reasonable default for AbstractDict -## but is not guaranteed to be correct for all dict subtypes +# This may be a reasonable default for AbstractDict +# but is not guaranteed to be correct for all dict subtypes function functor(::Type{D}, x) where {D<:AbstractDict} return constructorof(D)([k => x[k] for k in keys(x)]...), identity end +### +### Base Types requiring special handling +### + +@static if VERSION >= v"1.12-DEV" + functor(::Type{<:Base.Fix{N}}, x) where N = (; x.f, x.x), y -> Base.Fix{N}(y.f, y.x) +end + + ### ### Array wrappers ###