From 100291ae60177b29e23949363d78e852e5a2d1f5 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Fri, 1 Nov 2024 09:49:00 +0100 Subject: [PATCH] functor by default (#51) --- .gitignore | 4 ++ Project.toml | 12 +++- README.md | 24 +++++--- benchmark/Project.toml | 9 +++ benchmark/benchmarks.jl | 57 ++++++++++++++++++ docs/src/index.md | 40 +++++++++---- src/Functors.jl | 59 ++++++++++--------- src/base.jl | 62 ++++++++------------ src/functor.jl | 23 ++++---- src/keypath.jl | 1 - src/walks.jl | 6 +- test/base.jl | 24 ++++++++ test/basics.jl | 122 ++++++++++++--------------------------- test/flexiblefunctors.jl | 71 +++++++++++++++++++++++ test/runtests.jl | 5 +- 15 files changed, 330 insertions(+), 189 deletions(-) create mode 100644 benchmark/Project.toml create mode 100644 benchmark/benchmarks.jl create mode 100644 test/flexiblefunctors.jl diff --git a/.gitignore b/.gitignore index 2cbd804..ed0d812 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,7 @@ Manifest.toml build .vscode +benchmarks*.json +results*.json +*.tmp + diff --git a/Project.toml b/Project.toml index 9daf15b..3d47588 100644 --- a/Project.toml +++ b/Project.toml @@ -4,17 +4,23 @@ authors = ["Mike J Innes "] version = "0.4.12" [deps] +Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" +ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [compat] -Documenter = "1" +Compat = "4.16" +ConstructionBase = "1.4" +Measurements = "2" +OrderedCollections = "1.6" julia = "1.6" [extras] -Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" +OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Test", "Documenter", "StaticArrays", "Zygote"] +test = ["Test", "OrderedCollections", "StaticArrays", "Zygote", "Measurements"] diff --git a/README.md b/README.md index f9aea80..ba7a687 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,9 @@ [action-img]: https://github.com/FluxML/Functors.jl/workflows/CI/badge.svg [action-url]: https://github.com/FluxML/Functors.jl/actions -Functors.jl provides tools to express a powerful design pattern for dealing with large/ nested structures, as in machine learning and optimisation. For large machine learning models it can be cumbersome or inefficient to work with parameters as one big, flat vector, and structs help manage complexity; but it is also desirable to easily operate over all parameters at once, e.g. for changing precision or applying an optimiser update step. +Functors.jl provides tools to express a powerful design pattern for dealing with large / nested structures, as in machine learning and optimisation. For large machine learning models it can be cumbersome or inefficient to work with parameters as one big, flat vector, and structs help manage complexity; but it is also desirable to easily operate over all parameters at once, e.g. for changing precision or applying an optimiser update step. + +## Basic Usage Functors.jl provides `fmap` to make those things easy, acting as a 'map over parameters': @@ -25,8 +27,6 @@ julia> struct Foo y end -julia> @functor Foo - julia> model = Foo(1, [1, 2, 3]) Foo(1, [1, 2, 3]) @@ -41,8 +41,6 @@ julia> struct Bar x end -julia> @functor Bar - julia> model = Bar(Foo(1, [1, 2, 3])) Bar(Foo(1, [1, 2, 3])) @@ -50,17 +48,25 @@ julia> fmap(float, model) Bar(Foo(1.0, [1.0, 2.0, 3.0])) ``` +> [!NOTE] +> Up to to v0.4, Functors.jl's functionality had to be opted in on custom types via the `@functor Foo` macro call. +> With v0.5 instead, this is no longer necessary: by default any type is recursively traversed up to the leaves +> and `ConstructionBase.constructorof` is used to reconstruct it. +> In order to opt-out of this behaviour and make a type non traversable you can use `@leaf Foo`. + +## Further Details + The workhorse of `fmap` is actually a lower level function, `functor`: ```julia -julia> xs, re = functor(Foo(1, [1, 2, 3])) -((x = 1, y = [1, 2, 3]), var"#21#22"()) +julia> children, reconstruct = Functors.functor(Foo(1, [1, 2, 3])) +((x = 1, y = [1, 2, 3]), Functors.var"#3#6"{DataType}(Foo)) -julia> re(map(float, xs)) +julia> reconstruct(map(float, children)) Foo(1.0, [1.0, 2.0, 3.0]) ``` -`functor` returns the parts of the object that can be inspected, as well as a `re` function that takes those values and restructures them back into an object of the original type. +`functor` returns the parts of the object that can be inspected, as well as a `reconstruct` function that takes those values and restructures them back into an object of the original type. To include only certain fields, pass a tuple of field names to `@functor`: diff --git a/benchmark/Project.toml b/benchmark/Project.toml new file mode 100644 index 0000000..cd75249 --- /dev/null +++ b/benchmark/Project.toml @@ -0,0 +1,9 @@ +[deps] +AirspeedVelocity = "1c8270ee-6884-45cc-9545-60fa71ec23e4" +BenchmarkPlots = "ab8c0f59-4072-4e0d-8f91-a91e1495eb26" +BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" +Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd" diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl new file mode 100644 index 0000000..1b57ce7 --- /dev/null +++ b/benchmark/benchmarks.jl @@ -0,0 +1,57 @@ +# We run the benchmarks using AirspeedVelocity.jl + +# To run benchmarks locally, first install AirspeedVelocity.jl: +# julia> using Pkg; Pkg.add("AirspeedVelocity"); Pkg.build("AirspeedVelocity") +# and make sure .julia/bin is in your PATH. + +# Then commit the changes and run: +# $ benchpkg Functors --rev=mybranch,master --bench-on=mybranch + + +using BenchmarkTools: BenchmarkTools, BenchmarkGroup, @benchmarkable, @btime, @benchmark, judge +using ConcreteStructs: @concrete +using Flux: Dense, Chain +using LinearAlgebra: BLAS +using Functors +using Statistics: median + +const SUITE = BenchmarkGroup() +const BENCHMARK_CPU_THREADS = Threads.nthreads() +BLAS.set_num_threads(BENCHMARK_CPU_THREADS) + + +@concrete struct A + w + b + σ +end + +struct B + w + b + σ +end + +function setup_fmap_bench!(suite) + a = A(rand(5,5), rand(5), tanh) + suite["fmap"]["concrete struct"] = @benchmarkable fmap(identity, $a) + + a = B(rand(5,5), rand(5), tanh) + suite["fmap"]["non-concrete struct"] = @benchmarkable fmap(identity, $a) + + a = Dense(5, 5, tanh) + suite["fmap"]["flux dense"] = @benchmarkable fmap(identity, $a) + + a = Chain(Dense(5, 5, tanh), Dense(5, 5, tanh)) + suite["fmap"]["flux dense chain"] = @benchmarkable fmap(identity, $a) + + nt = (layers=(w= rand(5,5), b=rand(5), σ=tanh),) + suite["fmap"]["named tuple"] = @benchmarkable fmap(identity, $nt) + + return suite +end + +setup_fmap_bench!(SUITE) + +## AirspeedVelocity.jl will automatically run the benchmarks and save the results +# results = BenchmarkTools.run(SUITE; verbose=true) diff --git a/docs/src/index.md b/docs/src/index.md index 276b3c7..805e9bf 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -4,13 +4,15 @@ Functors.jl provides a set of tools to represent [functors](https://en.wikipedia The most straightforward use is to traverse a complicated nested structure as a tree, and apply a function `f` to every field it encounters along the way. -For large models it can be cumbersome or inefficient to work with parameters as one big, flat vector, and structs help manage complexity; but it may be desirable to easily operate over all parameters at once, e.g. for changing precision or applying an optimiser update step. +For large machine learning models it can be cumbersome or inefficient to work with parameters as one big, flat vector, and structs help manage complexity; but it may be desirable to easily operate over all parameters at once, e.g. for changing precision or applying an optimiser update step. ## Basic Usage and Implementation -When one marks a structure as [`@functor`](@ref) it means that Functors.jl is allowed to look into the fields of the instances of the struct and modify them. This is achieved through [`Functors.fmap`](@ref). +By default, julia types are marked as [`@functor`](@ref Functors.functor)s, meaning that Functors.jl is allowed to look into the fields of the instances of the struct and modify them. This is achieved through [`fmap`](@ref). To opt-out of this behaviour, use [`@leaf`](@ref) on your custom type. -The workhorse of fmap is actually a lower level function, functor: +```julia-repl + +The workhorse of `fmap` is actually a lower level function, [`functor`](@ref Functors.functor): ```julia-repl julia> using Functors @@ -20,8 +22,6 @@ julia> struct Foo y end -julia> @functor Foo - julia> foo = Foo(1, [1, 2, 3]) # notice all the elements are integers julia> xs, re = Functors.functor(foo) @@ -50,13 +50,31 @@ julia> fmap(float, model) Baz(1.0, 2) ``` -Any field not in the list will be passed through as-is during reconstruction. This is done by invoking the default constructor, so structs that define custom inner constructors are expected to provide one that acts like the default. +Any field not in the list will be passed through as-is during reconstruction. This is done by invoking the default constructor accepting all fields as arguments, so structs that define custom inner constructors are expected to provide one that acts like the default. + +The use of `@functor` with no fields argument as in `@functor Baz` is equivalent to `@functor Baz fieldnames(Baz)` and also equivalent to avoiding `@functor` altogether. + +Using [`@leaf`](@ref) instead of [`@functor`](@ref) will prevent the fields of a struct from being traversed. + +!!! warning "Change to opt-out behaviour in v0.5" + Previous releases of functors, up to v0.4, used an opt-in behaviour where structs were leaves functors unless marked with `@functor`. This was changed in v0.5 to an opt-out behaviour where structs are functors unless marked with `@leaf`. + +## Which types are leaves? + +By default all composite types in are functors and can be traversed, unless marked with [`@leaf`](@ref). + +The following types instead are explicitly marked as leaves in Functors.jl: +- `Number`. +- `AbstractArray{<:Number}`, except for the wrappers `Transpose`, `Adjoint`, and `PermutedDimsArray`. +- `AbstractString`. -## Appropriate Use +This is because in typical application the internals of these are abstracted away and it is not desirable to traverse them. -!!! warning "Not everything should be a functor!" - Due to its generic nature it is very attractive to mark several structures as [`@functor`](@ref) when it may not be quite safe to do so. +## What if I get an error? -Typically, since any function `f` is applied to the leaves of the tree, but it is possible for some functions to require dispatching on the specific type of the fields causing some methods to be missed entirely. +Since by default Functors.jl tries to traverse most types e.g. when using [`fmap`](@ref), it is possible it fails in case the type has not an appropriate constructor. If use experience this issue, you have a few alternatives: +- Mark the type as a leaf using [`@leaf`](@ref) +- Use the `@functor` macro to specify which fields to traverse. +- Define an appropriate constructor for the type. -Examples of this include element types of arrays which typically have their own mathematical operations defined. Adding a [`@functor`](@ref) to such a type would end up missing methods such as `+(::MyElementType, ::MyElementType)`. Think `RGB` from Colors.jl. +If you are not able to traverse types in julia Base, please open an issue. diff --git a/src/Functors.jl b/src/Functors.jl index ffcbf8f..f99c665 100644 --- a/src/Functors.jl +++ b/src/Functors.jl @@ -1,9 +1,15 @@ module Functors +using Compat: @compat +using ConstructionBase: constructorof +using LinearAlgebra -export @functor, @flexiblefunctor, fmap, fmapstructure, fcollect, execute, fleaves, +export @leaf, @functor, @flexiblefunctor, + fmap, fmapstructure, fcollect, execute, fleaves, fmap_with_path, fmapstructure_with_path, KeyPath, getkeypath, haskeypath, setkeypath! +@compat(public, (isleaf, children, functor)) + include("functor.jl") include("keypath.jl") include("walks.jl") @@ -16,10 +22,11 @@ include("base.jl") """ - Functors.functor(x) = functor(typeof(x), x) + functor(x) + functor(typeof(x), x) Returns a tuple containing, first, a `NamedTuple` of the children of `x` -(typically its fields), and second, a reconstruction funciton. +(typically its fields), and second, a reconstruction function. This controls the behaviour of [`fmap`](@ref). Methods should be added to `functor(::Type{T}, x)` for custom types, @@ -31,7 +38,7 @@ functor @functor T @functor T (x,) -Adds methods to [`functor`](@ref) allowing recursion into objects of type `T`, +Adds methods to [`functor`](@ref Functors.functor) allowing recursion into objects of type `T`, and reconstruction. Assumes that `T` has a constructor accepting all of its fields, which is true unless you have provided an inner constructor which does not. @@ -42,8 +49,6 @@ this can be restricted be restructed by providing a tuple of field names. ```jldoctest julia> struct Foo; x; y; end -julia> @functor Foo - julia> Functors.children(Foo(1,2)) (x = 1, y = 2) @@ -52,6 +57,8 @@ julia> _, re = Functors.functor(Foo(1,2)); julia> re((10, 20)) Foo(10, 20) +julia> @functor Foo # same as before, nothing changes + julia> struct TwoThirds a; b; c; end julia> @functor TwoThirds (a, c) @@ -71,9 +78,9 @@ TwoThirds(Foo(10, 20), Foo(3, 4), 560) var"@functor" """ - Functors.isleaf(x) + isleaf(x) -Return true if `x` has no [`children`](@ref) according to [`functor`](@ref). +Return true if `x` has no [`children`](@ref Functors.children) according to [`functor`](@ref Functors.functor). # Examples ```jldoctest @@ -99,9 +106,9 @@ true isleaf """ - Functors.children(x) + children(x) -Return the children of `x` as defined by [`functor`](@ref). +Return the children of `x` as defined by [`functor`](@ref Functors.functor). Equivalent to `functor(x)[1]`. """ children @@ -111,8 +118,8 @@ children A structure and type preserving `map`. -By default it transforms every leaf node (identified by `exclude`, default [`isleaf`](@ref)) -by applying `f`, and otherwise traverses `x` recursively using [`functor`](@ref). +By default it transforms every leaf node (identified by `exclude`, default [`isleaf`](@ref Functors.isleaf)) +by applying `f`, and otherwise traverses `x` recursively using [`functor`](@ref Functors.functor). Optionally, it may also be associated with objects `ys` with the same tree structure. In that case, `f` is applied to the corresponding leaf nodes in `x` and `ys`. @@ -163,20 +170,16 @@ Thus the relationship `x.i === x.iv[1]` will be preserved. An immutable object which appears twice is not stored in the cache, thus `f(34)` will be called twice, and the results will agree only if `f` is pure. -By default, `Tuple`s, `NamedTuple`s, and some other container-like types in Base have -children to recurse into. Arrays of numbers do not. -To enable recursion into new types, you must provide a method of [`functor`](@ref), -which can be done using the macro [`@functor`](@ref): +By default, almost all container-like types have children to recurse into. +Arrays of numbers do not. + +To opt out of recursion for custom types use [`@leaf`](@ref) or pass a custom `exclude` function. ```jldoctest withfoo julia> struct Foo; x; y; end -julia> @functor Foo - julia> struct Bar; x; end -julia> @functor Bar - julia> m = Foo(Bar([1,2,3]), (4, 5, Bar(Foo(6, 7)))); julia> fmap(x -> 10x, m) @@ -241,8 +244,6 @@ See also [`fmap`](@ref) and [`fmapstructure_with_path`](@ref). ```jldoctest julia> struct Foo; x; y; end -julia> @functor Foo - julia> m = Foo([1,2,3], [4, (5, 6), Foo(7, 8)]); julia> fmapstructure(x -> 2x, m) @@ -263,7 +264,7 @@ fmapstructure """ fcollect(x; exclude = v -> false) -Traverse `x` by recursing each child of `x` as defined by [`functor`](@ref) +Traverse `x` by recursing each child of `x` as defined by [`functor`](@ref Functors.functor) and collecting the results into a flat array, ordered by a breadth-first traversal of `x`, respecting the iteration order of `children` calls. @@ -279,14 +280,12 @@ See also [`children`](@ref). ```jldoctest julia> struct Foo; x; y; end -julia> @functor Foo - julia> struct Bar; x; end -julia> @functor Bar - julia> struct TypeWithNoChildren; x; y; end +julia> @leaf TypeWithNoChildren + julia> m = Foo(Bar([1,2,3]), TypeWithNoChildren(:a, :b)) Foo(Bar([1, 2, 3]), TypeWithNoChildren(:a, :b)) @@ -356,7 +355,7 @@ fmapstructure_with_path """ fleaves(x; exclude = isleaf) -Traverse `x` by recursing each child of `x` as defined by [`functor`](@ref) +Traverse `x` by recursing each child of `x` as defined by [`functor`](@ref Functors.functor) and collecting the leaves into a flat array, ordered by a breadth-first traversal of `x`, respecting the iteration order of `children` calls. @@ -370,10 +369,10 @@ See also [`fcollect`](@ref) for a similar function that collects all nodes inste ```jldoctest julia> struct Bar; x; end -julia> @functor Bar - julia> struct TypeWithNoChildren; x; y; end +julia> @leaf TypeWithNoChildren + julia> m = (a = Bar([1,2,3]), b = TypeWithNoChildren(4, 5)); julia> fleaves(m) diff --git a/src/base.jl b/src/base.jl index 913aaa7..63d8ee5 100644 --- a/src/base.jl +++ b/src/base.jl @@ -1,28 +1,39 @@ +### +### Opt-Out +### -@functor Base.RefValue - -@functor Base.Pair +@leaf Number +@leaf AbstractArray{<:Number} +@leaf AbstractString -@functor Base.Generator # aka Iterators.map +### +### Fast Paths for common types +### -@functor Base.ComposedFunction -@functor Base.Fix1 -@functor Base.Fix2 -@functor Base.Broadcast.BroadcastFunction +functor(::Type{<:Tuple}, x) = x, identity +functor(::Type{<:NamedTuple{L}}, x) where L = NamedTuple{L}(map(s -> getproperty(x, s), L)), identity +functor(::Type{<:Dict}, x) = Dict(k => x[k] for k in keys(x)), identity +functor(::Type{<:AbstractArray}, x) = x, identity -@static if VERSION >= v"1.9" - @functor Base.Splat +# This may be a reasonable default for AbstractDict +# but is not guaranteed to be correct for all dict subtypes +function functor(::Type{D}, x) where {D<:AbstractDict} + return constructorof(D)([k => x[k] for k in keys(x)]...), identity end -@static if VERSION >= v"1.7" - @functor Base.Returns +### +### Base Types requiring special handling +### + +@static if VERSION >= v"1.12-DEV" + functor(::Type{<:Base.Fix{N}}, x) where N = (; x.f, x.x), y -> Base.Fix{N}(y.f, y.x) end + ### ### Array wrappers ### -using LinearAlgebra # The reason for these is to let W and W' be seen as tied weights in Flux models. # Can't treat ReshapedArray very well, as its type doesn't include enough details for reconstruction. @@ -45,32 +56,9 @@ function functor(::Type{<:PermutedDimsArray{T,N,perm,iperm}}, x) where {T,N,perm (parent = _PermutedDimsArray(x, iperm),), y -> PermutedDimsArray(only(y), perm) end function functor(::Type{<:PermutedDimsArray{T,N,perm,iperm}}, x::PermutedDimsArray{Tx,N,perm,iperm}) where {T,Tx,N,perm,iperm} - (parent = parent(x),), y -> PermutedDimsArray(only(y), perm) # most common case, avoid wrapping wrice. + (parent = parent(x),), y -> PermutedDimsArray(only(y), perm) # most common case, avoid wrapping twice. end _PermutedDimsArray(x, iperm) = PermutedDimsArray(x, iperm) _PermutedDimsArray(x::NamedTuple{(:parent,)}, iperm) = x.parent _PermutedDimsArray(bc::Broadcast.Broadcasted, iperm) = _PermutedDimsArray(Broadcast.materialize(bc), iperm) - -### -### Iterators -### - -@functor Iterators.Accumulate -# Count -@functor Iterators.Cycle -@functor Iterators.Drop -@functor Iterators.DropWhile -@functor Iterators.Enumerate -@functor Iterators.Filter -@functor Iterators.Flatten -# IterationCutShort -@functor Iterators.PartitionIterator -@functor Iterators.ProductIterator -@functor Iterators.Repeated -@functor Iterators.Rest -@functor Iterators.Reverse -# Stateful -@functor Iterators.Take -@functor Iterators.TakeWhile -@functor Iterators.Zip diff --git a/src/functor.jl b/src/functor.jl index 653a928..a770652 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -5,21 +5,24 @@ const NoChildren = Tuple{} """ @leaf T -Define [`functor`](@ref) for the type `T` so that `isleaf(x::T) == true`. +Define [`functor`](@ref Functors.functor) for the type `T` so that `isleaf(x::T) == true`. """ macro leaf(T) :($Functors.functor(::Type{<:$(esc(T))}, x) = ($Functors.NoChildren(), _ -> x)) end -@leaf Any # every type is a leaf by default -functor(x) = functor(typeof(x), x) - -functor(::Type{<:Tuple}, x) = x, identity -functor(::Type{<:NamedTuple{L}}, x) where L = NamedTuple{L}(map(s -> getproperty(x, s), L)), identity -functor(::Type{<:Dict}, x) = Dict(k => x[k] for k in keys(x)), identity +# Default functor +function functor(T, x) + names = fieldnames(T) + if isempty(names) + return NoChildren(), _ -> x + end + S = constructorof(T) # remove parameters from parametric types and support anonymous functions + vals = ntuple(i -> getfield(x, names[i]), length(names)) + return NamedTuple{names}(vals), y -> S(y...) +end -functor(::Type{<:AbstractArray}, x) = x, identity -@leaf AbstractArray{<:Number} +functor(x) = functor(typeof(x), x) function makefunctor(m::Module, T, fs = fieldnames(T)) fidx = Ref(0) @@ -30,7 +33,7 @@ function makefunctor(m::Module, T, fs = fieldnames(T)) f in fs ? :(y[$(Meta.quot(f))]) : :(x.$f) end escfs = [:($f=x.$f) for f in fs] - + @eval m begin function $Functors.functor(::Type{<:$T}, x) reconstruct(y) = $T($(escargs...)) diff --git a/src/keypath.jl b/src/keypath.jl index 280ef94..d6f0b17 100644 --- a/src/keypath.jl +++ b/src/keypath.jl @@ -55,7 +55,6 @@ struct KeyPath{T<:Tuple} keys::T end -@functor KeyPath isleaf(::KeyPath, @nospecialize(x)) = isleaf(x) function KeyPath(keys::Union{KeyT, KeyPath}...) diff --git a/src/walks.jl b/src/walks.jl index 334cce8..2426a88 100644 --- a/src/walks.jl +++ b/src/walks.jl @@ -9,11 +9,13 @@ function check_lenghts(x, ys...) end _map(f, x::Dict, ys...) = Dict(k => f(v, (y[k] for y in ys)...) for (k, v) in x) +_map(f, x::D, ys...) where {D<:AbstractDict} = + constructorof(D)([k => f(v, (y[k] for y in ys)...) for (k, v) in x]...) _values(x) = x -_values(x::Dict) = values(x) +_values(x::AbstractDict) = values(x) -_keys(x::Dict) = Dict(k => k for k in keys(x)) +_keys(x::D) where {D <: AbstractDict} = constructorof(D)(k => k for k in keys(x)) _keys(x::Tuple) = (keys(x)...,) _keys(x::AbstractArray) = collect(keys(x)) _keys(x::NamedTuple{Ks}) where Ks = NamedTuple{Ks}(Ks) diff --git a/test/base.jl b/test/base.jl index e06d415..78078e3 100644 --- a/test/base.jl +++ b/test/base.jl @@ -1,3 +1,10 @@ +@testset "Numbers are leaves" begin + @test Functors.isleaf(1) + @test Functors.isleaf(1.0) + @test Functors.isleaf(1im) + @test Functors.isleaf(1//2) + @test Functors.isleaf(1.0 + 2.0im) +end @testset "RefValue" begin @test fmap(sqrt, Ref(16))[] == 4.0 @@ -172,3 +179,20 @@ end @test x.is[1] isa Vector{<:Complex} @test collect(x) isa Vector{<:Tuple{Complex, Complex}} end + +@testset "AbstractString is leaf" begin + struct DummyString <: AbstractString + str::String + end + s = DummyString("hello") + @test Functors.isleaf(s) +end + +@testset "AbstractDict is functor" begin + od = OrderedDict(1 => 1, 2 => 2) + @test !Functors.isleaf(od) + od2 = fmap(x -> 2x, od) + @test od2 isa OrderedDict + @test od2[1] == 2 + @test od2[2] == 4 +end \ No newline at end of file diff --git a/test/basics.jl b/test/basics.jl index 625e630..36b0a80 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -2,36 +2,38 @@ using Functors: functor, usecache struct Foo; x; y; end -@functor Foo Base.:(==)(x::Foo, y::Foo) = x.x == y.x && x.y == y.y struct Bar{T}; x::T; end -@functor Bar Base.:(==)(x::Bar, y::Bar) = x.x == y.x struct OneChild3; x; y; z; end @functor OneChild3 (y,) -struct NoChildren2; x; y; end +struct NoChild2; x; y; end +@functor NoChild2 () -struct NoChild{T}; x::T; end +struct NoChild1{T}; x::T; end +@functor NoChild1 () struct WrongOrder; x; y; z; end @functor WrongOrder (z, x) +struct LeafType{T}; x::T; end +@leaf LeafType ### ### Basic functionality ### -@testset "Children and Leaves" begin - no_children = NoChildren2(1, 2) +@testset "NoChild is not a leaf" begin + no_children = NoChild2(1, 2) has_children = Foo(1, 2) - @test Functors.isleaf(no_children) + @test !Functors.isleaf(no_children) @test !Functors.isleaf(has_children) - @test Functors.children(no_children) === Functors.NoChildren() + @test Functors.children(no_children) === (;) @test Functors.children(has_children) == (x=1, y=2) end @@ -108,8 +110,8 @@ end # Leaf types: @test usecache(d, [1,2]) @test !usecache(d, 4.0) - @test usecache(d, NoChild([1,2])) - @test !usecache(d, NoChild((3,4))) + @test usecache(d, LeafType([1,2])) + @test !usecache(d, LeafType((3,4))) # Not leaf: @test usecache(d, Ref(3)) # mutable container @@ -128,7 +130,10 @@ end end @testset "Self-referencing types" begin + # https://github.com/FluxML/Functors.jl/pull/72/ @test fmap(identity, Base.ImmutableDict(:a => 42)) == Base.ImmutableDict(:a => 42) + nt = fmap(x -> 2x, (; a = 1 ± 0.1, b = 2 ± 0.2)) + @test nt == (; a = 2 ± 0.2, b = 4 ± 0.4) end @testset "functor(typeof(x), y) from @functor" begin @@ -163,6 +168,17 @@ end @test_throws Exception functor(NamedTuple{(:x, :y)}, (z=33, x=1)) end +@testset "anonymous functions" begin + model = let W = rand(2,2), b = ones(2) + x -> tanh.(W*x .+ b) + end + newmodel = fmap(zero, model) + @test newmodel isa Function + @test newmodel([1,2]) == [0,0] + @test newmodel.W == [0 0; 0 0] + @test newmodel.b == [0, 0] +end + ### ### Extras ### @@ -185,7 +201,7 @@ end m1 = [1, 2, 3] m2 = Bar(m1) - m0 = NoChildren2(:a, :b) + m0 = NoChild2(:a, :b) m3 = Foo(m2, m0) m4 = Bar(m3) @test all(fcollect(m4) .=== [m4, m3, m2, m1, m0]) @@ -299,74 +315,13 @@ end @test m̂.b ≈ fill(-0.2f0, size(m.b)) end -### -### FlexibleFunctors.jl -### - -struct FFoo - x - y - p -end -@flexiblefunctor FFoo p - -struct FBar - x - p -end -@flexiblefunctor FBar p - -struct FOneChild4 - x - y - z - p -end -@flexiblefunctor FOneChild4 p - -@testset "Flexible Nested" begin - model = FBar(FFoo(1, [1, 2, 3], (:y, )), (:x,)) - - model′ = fmap(float, model) - - @test model.x.y == model′.x.y - @test model′.x.y isa Vector{Float64} -end - -@testset "Flexible Walk" begin - model = FFoo((0, FBar([1, 2, 3], (:x,))), [4, 5], (:x, :y)) - - model′ = fmapstructure(identity, model) - @test model′ == (; x=(0, (; x=[1, 2, 3])), y=[4, 5]) - - model2 = FFoo((0, FBar([1, 2, 3], (:x,))), [4, 5], (:x,)) - - model2′ = fmapstructure(identity, model2) - @test model2′ == (; x=(0, (; x=[1, 2, 3]))) -end - -@testset "Flexible Property list" begin - model = FOneChild4(1, 2, 3, (:x, :z)) - model′ = fmap(x -> 2x, model) - - @test (model′.x, model′.y, model′.z) == (2, 2, 6) -end - -@testset "Flexible fcollect" begin - m1 = 1 - m2 = [1, 2, 3] - m3 = FFoo(m1, m2, (:y, )) - m4 = FBar(m3, (:x,)) - @test all(fcollect(m4) .=== [m4, m3, m2]) - @test all(fcollect(m4, exclude = x -> x isa Array) .=== [m4, m3]) - @test all(fcollect(m4, exclude = x -> x isa FFoo) .=== [m4]) +@testset "parametric types" begin + struct A{T} + x::T + end - m0 = NoChildren2(:a, :b) - m1 = [1, 2, 3] - m2 = FBar(m1, ()) - m3 = FFoo(m2, m0, (:x, :y,)) - m4 = FBar(m3, (:x,)) - @test all(fcollect(m4) .=== [m4, m3, m2, m0]) + a = A(1) + @test fmap(x -> x/2, a) == A(0.5) end @testset "Dict" begin @@ -396,17 +351,16 @@ end end @testset "@leaf" begin - struct A; x; end - @functor A - a = A(1) - @test Functors.children(a) === (x = 1,) - struct B; x; end Functors.@leaf B b = B(1) children, re = Functors.functor(b) - @test children == Functors.NoChildren() @test re(children) === b + + a = LeafType(1) + children, re = Functors.functor(a) + @test children == Functors.NoChildren() + @test re(children) === a end @testset "IterateWalk" begin diff --git a/test/flexiblefunctors.jl b/test/flexiblefunctors.jl new file mode 100644 index 0000000..d34daf4 --- /dev/null +++ b/test/flexiblefunctors.jl @@ -0,0 +1,71 @@ + +### +### FlexibleFunctors.jl +### + +struct FFoo + x + y + p + end + @flexiblefunctor FFoo p + + struct FBar + x + p + end + @flexiblefunctor FBar p + + struct FOneChild4 + x + y + z + p + end + @flexiblefunctor FOneChild4 p + + @testset "Flexible Nested" begin + model = FBar(FFoo(1, [1, 2, 3], (:y, )), (:x,)) + + model′ = fmap(float, model) + + @test model.x.y == model′.x.y + @test model′.x.y isa Vector{Float64} + end + + @testset "Flexible Walk" begin + model = FFoo((0, FBar([1, 2, 3], (:x,))), [4, 5], (:x, :y)) + + model′ = fmapstructure(identity, model) + @test model′ == (; x=(0, (; x=[1, 2, 3])), y=[4, 5]) + + model2 = FFoo((0, FBar([1, 2, 3], (:x,))), [4, 5], (:x,)) + + model2′ = fmapstructure(identity, model2) + @test model2′ == (; x=(0, (; x=[1, 2, 3]))) + end + + @testset "Flexible Property list" begin + model = FOneChild4(1, 2, 3, (:x, :z)) + model′ = fmap(x -> 2x, model) + + @test (model′.x, model′.y, model′.z) == (2, 2, 6) + end + + @testset "Flexible fcollect" begin + m1 = 1 + m2 = [1, 2, 3] + m3 = FFoo(m1, m2, (:y, )) + m4 = FBar(m3, (:x,)) + @test all(fcollect(m4) .=== [m4, m3, m2]) + @test all(fcollect(m4, exclude = x -> x isa Array) .=== [m4, m3]) + @test all(fcollect(m4, exclude = x -> x isa FFoo) .=== [m4]) + + m0 = NoChild2(:a, :b) + m1 = [1, 2, 3] + m2 = FBar(m1, ()) + m3 = FFoo(m2, m0, (:x, :y,)) + m4 = FBar(m3, (:x,)) + @test all(fcollect(m4) .=== [m4, m3, m2, m0]) + end + \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 7db1e1b..e0c5970 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,11 +2,12 @@ using Functors, Test using Zygote using LinearAlgebra using StaticArrays +using OrderedCollections: OrderedDict +using Measurements: ± @testset "Functors.jl" begin - include("basics.jl") include("base.jl") include("keypath.jl") - + include("flexiblefunctors.jl") end