Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

functor by default #51

Merged
merged 24 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,7 @@
Manifest.toml
build
.vscode
benchmarks*.json
results*.json
*.tmp

12 changes: 9 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,23 @@ authors = ["Mike J Innes <[email protected]>"]
version = "0.4.12"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[compat]
Documenter = "1"
Compat = "4.16"
ConstructionBase = "1.4"
Measurements = "2"
OrderedCollections = "1.6"
julia = "1.6"

[extras]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test", "Documenter", "StaticArrays", "Zygote"]
test = ["Test", "OrderedCollections", "StaticArrays", "Zygote", "Measurements"]
24 changes: 15 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
[action-img]: https://github.com/FluxML/Functors.jl/workflows/CI/badge.svg
[action-url]: https://github.com/FluxML/Functors.jl/actions

Functors.jl provides tools to express a powerful design pattern for dealing with large/ nested structures, as in machine learning and optimisation. For large machine learning models it can be cumbersome or inefficient to work with parameters as one big, flat vector, and structs help manage complexity; but it is also desirable to easily operate over all parameters at once, e.g. for changing precision or applying an optimiser update step.
Functors.jl provides tools to express a powerful design pattern for dealing with large / nested structures, as in machine learning and optimisation. For large machine learning models it can be cumbersome or inefficient to work with parameters as one big, flat vector, and structs help manage complexity; but it is also desirable to easily operate over all parameters at once, e.g. for changing precision or applying an optimiser update step.

## Basic Usage

Functors.jl provides `fmap` to make those things easy, acting as a 'map over parameters':

Expand All @@ -25,8 +27,6 @@ julia> struct Foo
y
end

julia> @functor Foo

julia> model = Foo(1, [1, 2, 3])
Foo(1, [1, 2, 3])

Expand All @@ -41,26 +41,32 @@ julia> struct Bar
x
end

julia> @functor Bar

julia> model = Bar(Foo(1, [1, 2, 3]))
Bar(Foo(1, [1, 2, 3]))

julia> fmap(float, model)
Bar(Foo(1.0, [1.0, 2.0, 3.0]))
```

> [!NOTE]
> Up to to v0.4, Functors.jl's functionality had to be opted in on custom types via the `@functor Foo` macro call.
> With v0.5 instead, this is no longer necessary: by default any type is recursively traversed up to the leaves
> and `ConstructionBase.constructorof` is used to reconstruct it.
> In order to opt-out of this behaviour and make a type non traversable you can use `@leaf Foo`.

## Further Details

The workhorse of `fmap` is actually a lower level function, `functor`:

```julia
julia> xs, re = functor(Foo(1, [1, 2, 3]))
((x = 1, y = [1, 2, 3]), var"#21#22"())
julia> children, reconstruct = Functors.functor(Foo(1, [1, 2, 3]))
((x = 1, y = [1, 2, 3]), Functors.var"#3#6"{DataType}(Foo))

julia> re(map(float, xs))
julia> reconstruct(map(float, children))
Foo(1.0, [1.0, 2.0, 3.0])
```

`functor` returns the parts of the object that can be inspected, as well as a `re` function that takes those values and restructures them back into an object of the original type.
`functor` returns the parts of the object that can be inspected, as well as a `reconstruct` function that takes those values and restructures them back into an object of the original type.

To include only certain fields, pass a tuple of field names to `@functor`:

Expand Down
9 changes: 9 additions & 0 deletions benchmark/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
[deps]
AirspeedVelocity = "1c8270ee-6884-45cc-9545-60fa71ec23e4"
BenchmarkPlots = "ab8c0f59-4072-4e0d-8f91-a91e1495eb26"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"
57 changes: 57 additions & 0 deletions benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# We run the benchmarks using AirspeedVelocity.jl

# To run benchmarks locally, first install AirspeedVelocity.jl:
# julia> using Pkg; Pkg.add("AirspeedVelocity"); Pkg.build("AirspeedVelocity")
# and make sure .julia/bin is in your PATH.

# Then commit the changes and run:
# $ benchpkg Functors --rev=mybranch,master --bench-on=mybranch


using BenchmarkTools: BenchmarkTools, BenchmarkGroup, @benchmarkable, @btime, @benchmark, judge
using ConcreteStructs: @concrete
using Flux: Dense, Chain
using LinearAlgebra: BLAS
using Functors
using Statistics: median

const SUITE = BenchmarkGroup()
const BENCHMARK_CPU_THREADS = Threads.nthreads()
BLAS.set_num_threads(BENCHMARK_CPU_THREADS)


@concrete struct A
w
b
σ
end

struct B
w
b
σ
end

function setup_fmap_bench!(suite)
a = A(rand(5,5), rand(5), tanh)
suite["fmap"]["concrete struct"] = @benchmarkable fmap(identity, $a)

a = B(rand(5,5), rand(5), tanh)
suite["fmap"]["non-concrete struct"] = @benchmarkable fmap(identity, $a)

a = Dense(5, 5, tanh)
suite["fmap"]["flux dense"] = @benchmarkable fmap(identity, $a)

a = Chain(Dense(5, 5, tanh), Dense(5, 5, tanh))
suite["fmap"]["flux dense chain"] = @benchmarkable fmap(identity, $a)

nt = (layers=(w= rand(5,5), b=rand(5), σ=tanh),)
suite["fmap"]["named tuple"] = @benchmarkable fmap(identity, $nt)

return suite
end

setup_fmap_bench!(SUITE)

## AirspeedVelocity.jl will automatically run the benchmarks and save the results
# results = BenchmarkTools.run(SUITE; verbose=true)
40 changes: 29 additions & 11 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@ Functors.jl provides a set of tools to represent [functors](https://en.wikipedia

The most straightforward use is to traverse a complicated nested structure as a tree, and apply a function `f` to every field it encounters along the way.

For large models it can be cumbersome or inefficient to work with parameters as one big, flat vector, and structs help manage complexity; but it may be desirable to easily operate over all parameters at once, e.g. for changing precision or applying an optimiser update step.
For large machine learning models it can be cumbersome or inefficient to work with parameters as one big, flat vector, and structs help manage complexity; but it may be desirable to easily operate over all parameters at once, e.g. for changing precision or applying an optimiser update step.

## Basic Usage and Implementation

When one marks a structure as [`@functor`](@ref) it means that Functors.jl is allowed to look into the fields of the instances of the struct and modify them. This is achieved through [`Functors.fmap`](@ref).
By default, julia types are marked as [`@functor`](@ref Functors.functor)s, meaning that Functors.jl is allowed to look into the fields of the instances of the struct and modify them. This is achieved through [`fmap`](@ref). To opt-out of this behaviour, use [`@leaf`](@ref) on your custom type.

The workhorse of fmap is actually a lower level function, functor:
```julia-repl

The workhorse of `fmap` is actually a lower level function, [`functor`](@ref Functors.functor):

```julia-repl
julia> using Functors
Expand All @@ -20,8 +22,6 @@ julia> struct Foo
y
end

julia> @functor Foo

julia> foo = Foo(1, [1, 2, 3]) # notice all the elements are integers

julia> xs, re = Functors.functor(foo)
Expand Down Expand Up @@ -50,13 +50,31 @@ julia> fmap(float, model)
Baz(1.0, 2)
```

Any field not in the list will be passed through as-is during reconstruction. This is done by invoking the default constructor, so structs that define custom inner constructors are expected to provide one that acts like the default.
Any field not in the list will be passed through as-is during reconstruction. This is done by invoking the default constructor accepting all fields as arguments, so structs that define custom inner constructors are expected to provide one that acts like the default.

The use of `@functor` with no fields argument as in `@functor Baz` is equivalent to `@functor Baz fieldnames(Baz)` and also equivalent to avoiding `@functor` altogether.

Using [`@leaf`](@ref) instead of [`@functor`](@ref) will prevent the fields of a struct from being traversed.

!!! warning "Change to opt-out behaviour in v0.5"
Previous releases of functors, up to v0.4, used an opt-in behaviour where structs were leaves functors unless marked with `@functor`. This was changed in v0.5 to an opt-out behaviour where structs are functors unless marked with `@leaf`.

## Which types are leaves?

By default all composite types in are functors and can be traversed, unless marked with [`@leaf`](@ref).

The following types instead are explicitly marked as leaves in Functors.jl:
- `Number`.
- `AbstractArray{<:Number}`, except for the wrappers `Transpose`, `Adjoint`, and `PermutedDimsArray`.
- `AbstractString`.

## Appropriate Use
This is because in typical application the internals of these are abstracted away and it is not desirable to traverse them.

!!! warning "Not everything should be a functor!"
Due to its generic nature it is very attractive to mark several structures as [`@functor`](@ref) when it may not be quite safe to do so.
## What if I get an error?

Typically, since any function `f` is applied to the leaves of the tree, but it is possible for some functions to require dispatching on the specific type of the fields causing some methods to be missed entirely.
Since by default Functors.jl tries to traverse most types e.g. when using [`fmap`](@ref), it is possible it fails in case the type has not an appropriate constructor. If use experience this issue, you have a few alternatives:
- Mark the type as a leaf using [`@leaf`](@ref)
- Use the `@functor` macro to specify which fields to traverse.
- Define an appropriate constructor for the type.

Examples of this include element types of arrays which typically have their own mathematical operations defined. Adding a [`@functor`](@ref) to such a type would end up missing methods such as `+(::MyElementType, ::MyElementType)`. Think `RGB` from Colors.jl.
If you are not able to traverse types in julia Base, please open an issue.
Loading
Loading