Skip to content

Commit

Permalink
functor by default
Browse files Browse the repository at this point in the history
strip type's parameters

factorize flexiblefunctors tests

ops

support closures

cleanup

rebase

use nochildren

update readme

docs
  • Loading branch information
CarloLucibello committed Oct 21, 2024
1 parent b597d47 commit 1b4ac0f
Show file tree
Hide file tree
Showing 9 changed files with 137 additions and 137 deletions.
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,17 @@ authors = ["Mike J Innes <[email protected]>"]
version = "0.4.12"

[deps]
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[compat]
Documenter = "1"
ConstructionBase = "1.4"
julia = "1.6"

[extras]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test", "Documenter", "StaticArrays", "Zygote"]
test = ["Test", "StaticArrays", "Zygote"]
4 changes: 0 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ julia> struct Foo
y
end

julia> @functor Foo

julia> model = Foo(1, [1, 2, 3])
Foo(1, [1, 2, 3])

Expand All @@ -41,8 +39,6 @@ julia> struct Bar
x
end

julia> @functor Bar

julia> model = Bar(Foo(1, [1, 2, 3]))
Bar(Foo(1, [1, 2, 3]))

Expand Down
19 changes: 11 additions & 8 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ For large models it can be cumbersome or inefficient to work with parameters as

## Basic Usage and Implementation

When one marks a structure as [`@functor`](@ref) it means that Functors.jl is allowed to look into the fields of the instances of the struct and modify them. This is achieved through [`Functors.fmap`](@ref).
By default, julia types are marked as [`@functor`](@ref)s, meaning that Functors.jl is allowed to look into the fields of the instances of the struct and modify them. This is achieved through [`Functors.fmap`](@ref).

The workhorse of fmap is actually a lower level function, functor:
The workhorse of `fmap` is actually a lower level function, functor:

```julia-repl
julia> using Functors
Expand All @@ -20,8 +20,6 @@ julia> struct Foo
y
end
julia> @functor Foo
julia> foo = Foo(1, [1, 2, 3]) # notice all the elements are integers
julia> xs, re = Functors.functor(foo)
Expand Down Expand Up @@ -50,12 +48,17 @@ julia> fmap(float, model)
Baz(1.0, 2)
```

Any field not in the list will be passed through as-is during reconstruction. This is done by invoking the default constructor, so structs that define custom inner constructors are expected to provide one that acts like the default.
Any field not in the list will be passed through as-is during reconstruction. This is done by invoking the default constructor accepting all fields as arguments, so structs that define custom inner constructors are expected to provide one that acts like the default.

## Appropriate Use
The use of `@functor` with no fields argument as in `@functor Baz` is equivalent to `@functor Baz fieldnames(Baz)`
and also equivalent to avoiding `@functor` altogether.

Using [`@leaf`](@ref) instead of [`@functor`](@ref) will prevent the fields of a struct from being traversed.

!!! warning "Not everything should be a functor!"
Due to its generic nature it is very attractive to mark several structures as [`@functor`](@ref) when it may not be quite safe to do so.
!!! warning "Change to opt-out behaviour in v0.5"
Previous releases of functors, up to v0.4, used an opt-in behaviour where structs were not functors unless marked with `@functor`. This was changed in v0.5 to an opt-out behaviour where structs are functors unless marked with `@leaf`.

## Appropriate Use

Typically, since any function `f` is applied to the leaves of the tree, but it is possible for some functions to require dispatching on the specific type of the fields causing some methods to be missed entirely.

Expand Down
5 changes: 3 additions & 2 deletions src/Functors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module Functors
export @functor, @flexiblefunctor, fmap, fmapstructure, fcollect, execute, fleaves,
fmap_with_path, fmapstructure_with_path,
KeyPath, getkeypath, haskeypath, setkeypath!
using ConstructionBase: constructorof

include("functor.jl")
include("keypath.jl")
Expand Down Expand Up @@ -42,8 +43,6 @@ this can be restricted be restructed by providing a tuple of field names.
```jldoctest
julia> struct Foo; x; y; end
julia> @functor Foo
julia> Functors.children(Foo(1,2))
(x = 1, y = 2)
Expand All @@ -52,6 +51,8 @@ julia> _, re = Functors.functor(Foo(1,2));
julia> re((10, 20))
Foo(10, 20)
julia> @functor Foo # same as before, nothing changes
julia> struct TwoThirds a; b; c; end
julia> @functor TwoThirds (a, c)
Expand Down
34 changes: 1 addition & 33 deletions src/base.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,5 @@

@functor Base.RefValue

@functor Base.Pair

@functor Base.Generator # aka Iterators.map

@functor Base.ComposedFunction
@functor Base.Fix1
@functor Base.Fix2
@functor Base.Broadcast.BroadcastFunction
functor(::Type{<:Base.ComposedFunction}, x) = (outer = x.outer, inner = x.inner), y -> Base.ComposedFunction(y.outer, y.inner)

@static if VERSION >= v"1.9"
@functor Base.Splat
Expand Down Expand Up @@ -51,26 +42,3 @@ end
_PermutedDimsArray(x, iperm) = PermutedDimsArray(x, iperm)
_PermutedDimsArray(x::NamedTuple{(:parent,)}, iperm) = x.parent
_PermutedDimsArray(bc::Broadcast.Broadcasted, iperm) = _PermutedDimsArray(Broadcast.materialize(bc), iperm)

###
### Iterators
###

@functor Iterators.Accumulate
# Count
@functor Iterators.Cycle
@functor Iterators.Drop
@functor Iterators.DropWhile
@functor Iterators.Enumerate
@functor Iterators.Filter
@functor Iterators.Flatten
# IterationCutShort
@functor Iterators.PartitionIterator
@functor Iterators.ProductIterator
@functor Iterators.Repeated
@functor Iterators.Rest
@functor Iterators.Reverse
# Stateful
@functor Iterators.Take
@functor Iterators.TakeWhile
@functor Iterators.Zip
16 changes: 14 additions & 2 deletions src/functor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,19 @@ macro leaf(T)
:($Functors.functor(::Type{<:$(esc(T))}, x) = ($Functors.NoChildren(), _ -> x))
end

@leaf Any # every type is a leaf by default
# @leaf Any # every type is a leaf by default

# Default functor
function functor(T, x)
names = fieldnames(T)
if isempty(names)
return NoChildren(), _ -> x
end
S = constructorof(T) # remove parameters from parametric types and support anonymous functions
vals = ntuple(i -> getfield(x, names[i]), length(names))
return NamedTuple{names}(vals), y -> S(y...)
end

functor(x) = functor(typeof(x), x)

functor(::Type{<:Tuple}, x) = x, identity
Expand All @@ -30,7 +42,7 @@ function makefunctor(m::Module, T, fs = fieldnames(T))
f in fs ? :(y[$(Meta.quot(f))]) : :(x.$f)
end
escfs = [:($f=x.$f) for f in fs]

@eval m begin
function $Functors.functor(::Type{<:$T}, x)
reconstruct(y) = $T($(escargs...))
Expand Down
116 changes: 33 additions & 83 deletions test/basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,36 +2,38 @@
using Functors: functor, usecache

struct Foo; x; y; end
@functor Foo

Base.:(==)(x::Foo, y::Foo) = x.x == y.x && x.y == y.y

struct Bar{T}; x::T; end
@functor Bar

Base.:(==)(x::Bar, y::Bar) = x.x == y.x

struct OneChild3; x; y; z; end
@functor OneChild3 (y,)

struct NoChildren2; x; y; end
struct NoChild2; x; y; end
@functor NoChild2 ()

struct NoChild{T}; x::T; end
struct NoChild1{T}; x::T; end
@functor NoChild1 ()

struct WrongOrder; x; y; z; end
@functor WrongOrder (z, x)

struct LeafType{T}; x::T; end
@leaf LeafType

###
### Basic functionality
###

@testset "Children and Leaves" begin
no_children = NoChildren2(1, 2)
@testset "NoChild is not a leaf" begin
no_children = NoChild2(1, 2)
has_children = Foo(1, 2)
@test Functors.isleaf(no_children)
@test !Functors.isleaf(no_children)
@test !Functors.isleaf(has_children)
@test Functors.children(no_children) === Functors.NoChildren()
@test Functors.children(no_children) === (;)
@test Functors.children(has_children) == (x=1, y=2)
end

Expand Down Expand Up @@ -108,8 +110,8 @@ end
# Leaf types:
@test usecache(d, [1,2])
@test !usecache(d, 4.0)
@test usecache(d, NoChild([1,2]))
@test !usecache(d, NoChild((3,4)))
@test usecache(d, LeafType([1,2]))
@test !usecache(d, LeafType((3,4)))

# Not leaf:
@test usecache(d, Ref(3)) # mutable container
Expand Down Expand Up @@ -163,6 +165,17 @@ end
@test_throws Exception functor(NamedTuple{(:x, :y)}, (z=33, x=1))
end

@testset "anonymous functions" begin
model = let W = rand(2,2), b = ones(2)
x -> tanh.(W*x .+ b)
end
newmodel = fmap(zero, model)
@test newmodel isa Function
@test newmodel([1,2]) == [0,0]
@test newmodel.W == [0 0; 0 0]
@test newmodel.b == [0, 0]
end

###
### Extras
###
Expand All @@ -185,7 +198,7 @@ end

m1 = [1, 2, 3]
m2 = Bar(m1)
m0 = NoChildren2(:a, :b)
m0 = NoChild2(:a, :b)
m3 = Foo(m2, m0)
m4 = Bar(m3)
@test all(fcollect(m4) .=== [m4, m3, m2, m1, m0])
Expand Down Expand Up @@ -299,74 +312,13 @@ end
@test.b fill(-0.2f0, size(m.b))
end

###
### FlexibleFunctors.jl
###

struct FFoo
x
y
p
end
@flexiblefunctor FFoo p

struct FBar
x
p
end
@flexiblefunctor FBar p

struct FOneChild4
x
y
z
p
end
@flexiblefunctor FOneChild4 p

@testset "Flexible Nested" begin
model = FBar(FFoo(1, [1, 2, 3], (:y, )), (:x,))

model′ = fmap(float, model)

@test model.x.y == model′.x.y
@test model′.x.y isa Vector{Float64}
end

@testset "Flexible Walk" begin
model = FFoo((0, FBar([1, 2, 3], (:x,))), [4, 5], (:x, :y))

model′ = fmapstructure(identity, model)
@test model′ == (; x=(0, (; x=[1, 2, 3])), y=[4, 5])

model2 = FFoo((0, FBar([1, 2, 3], (:x,))), [4, 5], (:x,))

model2′ = fmapstructure(identity, model2)
@test model2′ == (; x=(0, (; x=[1, 2, 3])))
end

@testset "Flexible Property list" begin
model = FOneChild4(1, 2, 3, (:x, :z))
model′ = fmap(x -> 2x, model)

@test (model′.x, model′.y, model′.z) == (2, 2, 6)
end

@testset "Flexible fcollect" begin
m1 = 1
m2 = [1, 2, 3]
m3 = FFoo(m1, m2, (:y, ))
m4 = FBar(m3, (:x,))
@test all(fcollect(m4) .=== [m4, m3, m2])
@test all(fcollect(m4, exclude = x -> x isa Array) .=== [m4, m3])
@test all(fcollect(m4, exclude = x -> x isa FFoo) .=== [m4])
@testset "parametric types" begin
struct A{T}
x::T
end

m0 = NoChildren2(:a, :b)
m1 = [1, 2, 3]
m2 = FBar(m1, ())
m3 = FFoo(m2, m0, (:x, :y,))
m4 = FBar(m3, (:x,))
@test all(fcollect(m4) .=== [m4, m3, m2, m0])
a = A(1)
@test fmap(x -> x/2, a) == A(0.5)
end

@testset "Dict" begin
Expand Down Expand Up @@ -396,15 +348,13 @@ end
end

@testset "@leaf" begin
struct A; x; end
@functor A
a = A(1)
@test Functors.children(a) === (x = 1,)

struct B; x; end
Functors.@leaf B
b = B(1)
children, re = Functors.functor(b)

a = LeafType(1)
children, re = Functors.functor(a)
@test children == Functors.NoChildren()
@test re(children) === b
end
Expand Down
Loading

0 comments on commit 1b4ac0f

Please sign in to comment.