functor by default #51

Merged: 24 commits, Nov 1, 2024

@CarloLucibello CarloLucibello commented Nov 25, 2022

Makes everything a functor by default, avoiding the need to sprinkle

@functor T

everywhere in Flux's layers and similar use cases.
The types already decorated with @functor T or @functor T (a, b) won't be affected by the change.
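
As a rough illustration of the behaviour described above, here is a minimal sketch. It assumes a Functors.jl release that includes this PR; the `Affine` type is hypothetical and only serves as an example of a struct with no `@functor` annotation.

```julia
using Functors  # assumption: a version that includes this PR (the v0.5 milestone)

# Hypothetical layer type. Note that there is no `@functor Affine` line.
struct Affine{M,V}
    W::M
    b::V
end

m = Affine(ones(2, 2), zeros(2))

# With the default functor, the fields are exposed as children (a NamedTuple).
kids = Functors.children(m)

# `fmap` recurses into the fields and rebuilds an `Affine` from the results.
m2 = fmap(x -> 2 .* x, m)
```

Types that already define their own `@functor T (a, b)` should keep whatever children they declared there.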

I cannot currently estimate how much breakage and how many unintended consequences this PR could produce.

Fix #49

TODO:

  • docs
  • run Optimisers.jl tests on this PR
  • run Flux.jl tests on this PR

@CarloLucibello CarloLucibello added this to the v0.5 milestone Nov 25, 2022
@mcabbott mcabbott left a comment

The fact that there's still a method for AbstractArray{<:Number} means that it doesn't recurse into the reshape here, which is good I think:

julia> pr(x) = (@show typeof(x); x);

julia> fmap(pr, rand(3)');
typeof(x) = Vector{Float64}

julia> fmap(pr, reshape(rand(Int8, 4)',2,2))
typeof(x) = Base.ReshapedArray{Int8, 2, Adjoint{Int8, Vector{Int8}}, Tuple{}}
2×2 reshape(adjoint(::Vector{Int8}), 2, 2) with eltype Int8:
   53  -63
 -125  -58

The default functor doesn't seem able to reconstruct closures like:

julia> D = let W = rand(2,2), b = zeros(2)
             x -> tanh.(W*x .+ b)
           end
#11 (generic function with 1 method)

julia> fmap(pr, D)
typeof(x) = Matrix{Float64}
typeof(x) = Vector{Float64}
ERROR: MethodError: no method matching var"#11#12"(::Matrix{Float64}, ::Vector{Float64})
Stacktrace:
 [1] (::Functors.var"#3#6"{UnionAll})(y::NamedTuple{(:W, :b), Tuple{Matrix{Float64}, Vector{Float64}}})
   @ Functors ~/.julia/packages/Functors/1AaAn/src/functor.jl:8
 [2] (::Functors.DefaultWalk)(::Function, ::Function)
   @ Functors ~/.julia/packages/Functors/1AaAn/src/walks.jl:56
...

julia> fieldnames(var"#11#12")
(:W, :b)

julia> methods(var"#11#12")
# 0 methods for type constructor

In global scope, the example from https://fluxml.ai/Flux.jl/stable/models/basics/#Building-Simple-Models simply does nothing, unsurprisingly:

julia> W = rand(2, 5);
julia> b = rand(2);
julia> predict(x) = W*x .+ b;

julia> fmap(pr, predict)
typeof(x) = typeof(predict)
predict (generic function with 1 method)

julia> fieldnames(typeof(predict))
()

src/functor.jl (outdated)
S = T.name.wrapper # remove parameters from parametric types
vals = ntuple(i -> getfield(x, names[i]), length(names))
return NamedTuple{names}(vals), y -> S(y...)
end

Member Author

let's keep this in mind for a future optimization-oriented PR

Member Author

@CarloLucibello CarloLucibello Oct 21, 2024

@mcabbott I'm not familiar with generated functions. Can you show me the snippet we should use here? Otherwise it will be done in a future PR.

Member

My link above is a few lines off. I don't remember all the details, but I think my comments are saying that it works around Base.typename(T).wrapper taking 2μs every time. However, this PR doesn't do that, it calls S = constructorof(T). Maybe that removes the need for @generated. Might be worth timing these two?
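
For reference, the two ways of stripping type parameters discussed here can be compared with a small sketch. The helper `functor_sketch` and the `Pair2` type are hypothetical, written only for illustration; this is not the code merged in the PR.

```julia
using ConstructionBase

# Hypothetical helper: split a struct into (children, rebuild function),
# using `constructorof` to obtain a constructor without type parameters.
function functor_sketch(x::T) where {T}
    names = fieldnames(T)
    vals = ntuple(i -> getfield(x, names[i]), length(names))
    S = ConstructionBase.constructorof(T)
    return NamedTuple{names}(vals), y -> S(y...)
end

struct Pair2{A,B}
    x::A
    y::B
end

p = Pair2([1, 2, 3], 4.0)
kids, re = functor_sketch(p)

# Rebuild from modified children; the type parameters are re-inferred.
q = re(map(v -> v .* 2, kids))

# For a plain struct like this one, both approaches give the same constructor.
same = Base.typename(typeof(p)).wrapper === ConstructionBase.constructorof(typeof(p))
```

The two only differ for types that overload `constructorof`, or for closures, which `constructorof` can rebuild and a bare wrapper call cannot.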

Member

Trying to time this today...

julia> struct Mine1{T,S}
       x::T
       y::S
       end

julia> x = Mine1([1,2,3], [4,5,6]')
Mine1{Vector{Int64}, Adjoint{Int64, Vector{Int64}}}([1, 2, 3], adjoint([4, 5, 6]))

julia> @btime Base.typename(typeof($x)).wrapper
  1.458 ns (0 allocations: 0 bytes)
Mine1

julia> @btime ConstructionBase.constructorof(typeof($x))
  1.458 ns (0 allocations: 0 bytes)
Mine1

julia> VERSION
v"1.12.0-DEV.1375"

and on Julia 1.6 -- slower but far from 2μs, even if I have a faster computer than I did then...

julia> @btime Base.typename(typeof($x)).wrapper
  14.194 ns (0 allocations: 0 bytes)
Mine1

julia> @btime ConstructionBase.constructorof(typeof($x))
  0.001 ns (0 allocations: 0 bytes)
Mine1

julia> VERSION
v"1.6.0"

@mcabbott
Member

I think Kyle said he had a branch doing something like this, using ConstructionBase. That appears to be able to reconstruct closures which is neat:

julia> adder = let y = ones(1)
                 x -> x .+ y
               end
#38 (generic function with 1 method)

julia> adder.y
1-element Vector{Float64}:
 1.0

julia> adder(2)
1-element Vector{Float64}:
 3.0

julia> using ConstructionBase

julia> newadder = constructorof(typeof(adder))([4 5 6])
#38 (generic function with 1 method)

julia> newadder(2)
1×3 Matrix{Int64}:
 6  7  8

It is happy to re-build things like reshape(rand(Int8, 4)',2,2), not sure how that will interact with fmap(f, x, dx) for gradients etc.

@ToucheSir
Member

I learned recently that it's possible to de- and reconstruct closure types. See JuliaGPU/Adapt.jl#58 for an implementation of this. Even if we can't functor all user-defined types by default, maybe this is something we could do in a backwards-compatible way? I could see it as a pilot project of sorts too.

strip type's parameters

factorize flexiblefunctors tests

ops

support closures

cleanup

rebase

use nochildren

update readme

docs
@CarloLucibello
Member Author

With Flux v0.15 in preparation I think this is the right time to merge this.

@CarloLucibello
Member Author

CarloLucibello commented Oct 21, 2024

I marked the following as leaves:

  • Number
  • AbstractArray{<:Number}
  • AbstractString

Other suggestions?

@mcabbott
Member

Would be good for docs to note that not every AbstractArray of Numbers is a leaf, e.g. Transpose
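
A quick way to check the leaf rules is sketched below, assuming a Functors.jl version that includes this PR. The results noted in the comments follow from the leaf list and the remark about `Transpose` in this thread; they are not copied from the package docs.

```julia
using Functors, LinearAlgebra

# Types marked as leaves in this PR: fmap applies the function to them directly.
leaf_number = Functors.isleaf(1.0)
leaf_array  = Functors.isleaf(rand(3))
leaf_string = Functors.isleaf("abc")

# Wrapper arrays such as Transpose and Adjoint are not leaves, even with a
# numeric eltype: fmap recurses into the parent array and re-wraps the result.
leaf_transpose = Functors.isleaf(transpose(rand(2, 2)))
leaf_adjoint   = Functors.isleaf(rand(3)')
```

The first three values are expected to be `true` and the last two `false`, consistent with `fmap(pr, rand(3)')` reaching the underlying `Vector{Float64}` in the earlier review comment.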

@CarloLucibello
Member Author

If I can get an approval I would go on and merge this

@CarloLucibello
Member Author

CarloLucibello commented Oct 28, 2024

I added some benchmarks. We seem to have some regressions with respect to master, although not on Flux types (I don't know why). I'm not sure how much we should care about them, since the absolute timings seem reasonable. If we do care, we can try another implementation with @generated, cf. #51 (comment), but not in this PR.

| benchmark | cl/fun | master | cl/fun / master |
|---|---|---|---|
| fmap/concrete struct | 0.458 ± 0.001 μs | 0.125 ± 0.041 μs | 3.66 |
| fmap/flux dense | 0.5 ± 0 μs | 0.5 ± 0.042 μs | 1 |
| fmap/flux dense chain | 1.46 ± 0.042 μs | 1.42 ± 0.042 μs | 1.03 |
| fmap/named tuple | 0.292 ± 0.042 μs | 0.334 ± 0.042 μs | 0.874 |
| fmap/non-concrete struct | 0.542 ± 0.041 μs | 0.125 ± 0 μs | 4.34 |
| time_to_load | 0.059 ± 0.0021 s | 0.0518 ± 0.0054 s | 1.14 |

@mcabbott
Member

Can this update the readme too? I think that (besides deleting @functor Foo) it wants a big notice that Functors@0.5 is sort-of opt-out now, and defaults to calling ConstructionBase.

Member

@mcabbott mcabbott left a comment

Let's do it.

@CarloLucibello
Member Author

let's unleash havoc on the world!
ahah, hopefully not

@CarloLucibello CarloLucibello merged commit 100291a into master Nov 1, 2024
7 checks passed
Successfully merging this pull request may close these issues.

Does functor have the right semantics for Flux?