-
-
Notifications
You must be signed in to change notification settings - Fork 16
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
functor by default #51
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The fact that there's still a method for AbstractArray{<:Number} means that it doesn't recurse into the reshape here, which is good I think:
julia> pr(x) = (@show typeof(x); x);
julia> fmap(pr, rand(3)');
typeof(x) = Vector{Float64}
julia> fmap(pr, reshape(rand(Int8, 4)',2,2))
typeof(x) = Base.ReshapedArray{Int8, 2, Adjoint{Int8, Vector{Int8}}, Tuple{}}
2×2 reshape(adjoint(::Vector{Int8}), 2, 2) with eltype Int8:
53 -63
-125 -58
The default functor doesn't seem able to reconstruct closures like:
julia> D = let W = rand(2,2), b = zeros(2)
x -> tanh.(W*x .+ b)
end
#11 (generic function with 1 method)
julia> fmap(pr, D)
typeof(x) = Matrix{Float64}
typeof(x) = Vector{Float64}
ERROR: MethodError: no method matching var"#11#12"(::Matrix{Float64}, ::Vector{Float64})
Stacktrace:
[1] (::Functors.var"#3#6"{UnionAll})(y::NamedTuple{(:W, :b), Tuple{Matrix{Float64}, Vector{Float64}}})
@ Functors ~/.julia/packages/Functors/1AaAn/src/functor.jl:8
[2] (::Functors.DefaultWalk)(::Function, ::Function)
@ Functors ~/.julia/packages/Functors/1AaAn/src/walks.jl:56
...
julia> fieldnames(var"#11#12")
(:W, :b)
julia> methods(var"#11#12")
# 0 methods for type constructor
In global scope, example from https://fluxml.ai/Flux.jl/stable/models/basics/#Building-Simple-Models just does nothing instead, unsurprisingly:
julia> W = rand(2, 5);
julia> b = rand(2);
julia> predict(x) = W*x .+ b;
julia> fmap(pr, predict)
typeof(x) = typeof(predict)
predict (generic function with 1 method)
julia> fieldnames(typeof(predict))
()
S = T.name.wrapper # remove parameters from parametric types | ||
vals = ntuple(i -> getfield(x, names[i]), length(names)) | ||
return NamedTuple{names}(vals), y -> S(y...) | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this will be slow. In FluxML/Flux.jl#1932 it needed a generated function to be as quick as before:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's keep this in mind for a future optimization-oriented PR
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@mcabbott I'm not familiar with generated functions. Can you show me the snippet we should use here? Otherwise it will be done in a future PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My link above is a few lines off. I don't remember all the details, but I think my comments are saying that it works around Base.typename(T).wrapper
taking 2μs every time. However, this PR doesn't do that, it calls S = constructorof(T)
. Maybe that removes the need for @generated
. Might be worth timing these two?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Trying to time this today...
julia> struct Mine1{T,S}
x::T
y::S
end
julia> x = Mine1([1,2,3], [4,5,6]')
Mine1{Vector{Int64}, Adjoint{Int64, Vector{Int64}}}([1, 2, 3], adjoint([4, 5, 6]))
julia> @btime Base.typename(typeof($x)).wrapper
1.458 ns (0 allocations: 0 bytes)
Mine1
julia> @btime ConstructionBase.constructorof(typeof($x))
1.458 ns (0 allocations: 0 bytes)
Mine1
julia> VERSION
v"1.12.0-DEV.1375"
and on Julia 1.6 -- slower but far from 2μs, even if I have a faster computer than I did then...
julia> @btime Base.typename(typeof($x)).wrapper
14.194 ns (0 allocations: 0 bytes)
Mine1
julia> @btime ConstructionBase.constructorof(typeof($x))
0.001 ns (0 allocations: 0 bytes)
Mine1
julia> VERSION
v"1.6.0"
I think Kyle said he had a branch doing something like this, using ConstructionBase. That appears to be able to reconstruct closures which is neat: julia> adder = let y = ones(1)
x -> x .+ y
end
#38 (generic function with 1 method)
julia> adder.y
1-element Vector{Float64}:
1.0
julia> adder(2)
1-element Vector{Float64}:
3.0
julia> using ConstructionBase
julia> newadder = constructorof(typeof(adder))([4 5 6])
#38 (generic function with 1 method)
julia> newadder(2)
1×3 Matrix{Int64}:
6 7 8 It is happy to re-build things like |
81017f2
to
6a2f79c
Compare
I learned recently that it's possible to de- and reconstruct closure types. See JuliaGPU/Adapt.jl#58 for an implementation of this. Even if we can't functor all user-defined types by default, maybe this is something we could in a backwards-compatible way? I could see it as a pilot project of sorts too. |
strip type's parameters factorize flexiblefunctors tests ops support closures cleanup rebase use nochildren update readme docs
eafa78f
to
66c3893
Compare
With Flux v0.15 in preparation I think this is the right time to merge this. |
I marked as leaves
Other suggestions? |
Would be good for docs to note that not every AbstractArray of Numbers is a leaf, e.g. Transpose |
If I can get an approval I would go on and merge this |
I added some benchmarks. It seems that we have some regressions with respect to master, although not on flux types (I don't know why). I'm not sure how much we should care about this performance regression since the absolute values seem reasonable. If we do care we can try another implementation with
|
Can this update the readme too? I think that (besides deleting |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's do it.
let's unleash havoc on the world! |
Makes everything a functor by default, avoiding the need to sprinkle
@functor T
everywhere in Flux's layers and similar use cases.
The types already decorated with
@functor T
or@functor T (a, b)
won't be affected by the change.The amount of breakage and unintended consequence this PR could produce is something I cannot estimate at the moment.
Fix #49
TODO: