Improve type stability of cached walks #82
Conversation
This adds some complexity to the code and some fragility as well, since it seems it could break with newer Julia versions.
Not a benchmark, but without this PR:

```julia
julia> @code_warntype gpu(Chain(Dense(3, 5), Dense(5, 2)))
MethodInstance for Flux.gpu(::Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}})
  from gpu(x) @ Flux ~/.julia/packages/Flux/Wz6D4/src/functor.jl:248
Arguments
  #self#::Core.Const(Flux.gpu)
  x::Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}
Body::Chain{T} where T<:Tuple{Any, Any}
1 ─ %1 = Flux.FluxCUDAAdaptor()::Core.Const(Flux.FluxCUDAAdaptor(nothing))
│   %2 = Flux.gpu(%1, x)::Chain{T} where T<:Tuple{Any, Any}
└── return %2
```

vs. with this PR:

```julia
julia> @code_warntype gpu(Chain(Dense(3, 5), Dense(5, 2)))
MethodInstance for Flux.gpu(::Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}})
  from gpu(x) @ Flux ~/.julia/packages/Flux/Wz6D4/src/functor.jl:248
Arguments
  #self#::Core.Const(Flux.gpu)
  x::Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}
Body::Union{Chain{Tuple{Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}
1 ─ %1 = Flux.FluxCUDAAdaptor()::Core.Const(Flux.FluxCUDAAdaptor(nothing))
│   %2 = Flux.gpu(%1, x)::Union{Chain{Tuple{Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}
└── return %2
```
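The difference above comes down to inference through an untyped cache. A minimal, self-contained illustration (not Flux code; the function names here are invented for the example): reading from an `IdDict{Any, Any}` infers as `Any`, while a typeassert on the lookup restores a concrete inferred type.

```julia
# A value read from an IdDict{Any,Any} infers as Any; asserting the
# expected type on the lookup makes inference concrete again.
cache = IdDict{Any,Any}(1 => 2)

lookup_untyped(c, k)  = c[k]        # compiler infers: Any
lookup_asserted(c, k) = c[k]::Int   # compiler infers: Int

# Base.return_types reports what the compiler infers for each call.
Ts_untyped  = Base.return_types(lookup_untyped,  (IdDict{Any,Any}, Int))
Ts_asserted = Base.return_types(lookup_asserted, (IdDict{Any,Any}, Int))
```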
@darsnack @ToucheSir what do you think? I'm unfamiliar with expression manipulations.
I am also concerned about fragility. The implementation itself is sensible, but as written it seems like it will need to be updated often for internal changes. The core idea is to use the return type of the walk to force the type when accessing the cache, right? That seems like a very straightforward generated function to write with the call to

Pulling back, is there a use-case where we lack a function barrier between the call to
Yes, essentially the whole generated function is just to generate
if you need to handle data movement during the forward/backward pass.
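The idea discussed above — forcing the type of cached values with the return type of the un-cached walk — can be sketched without a generated function, e.g. via `Core.Compiler.return_type`. This is a hypothetical illustration, not the PR's implementation; `cached_walk` is an invented name:

```julia
# Hypothetical sketch: assert cache reads with the inferred return type
# of the un-cached walk, so lookups from the untyped IdDict stay inferrable.
function cached_walk(walk, cache::IdDict{Any,Any}, x)
    T = Core.Compiler.return_type(walk, Tuple{typeof(x)})
    haskey(cache, x) && return cache[x]::T  # typeassert keeps inference concrete
    y = walk(x)
    cache[x] = y
    return y
end
```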
Given the concerns expressed in LuxDL/Lux.jl#1017, I think we should do this.
@CarloLucibello Since Julia v1.10 is the new LTS, do you think we could drop v1.6 support so that we can remove that |
Yes, we should do that.
This PR adds a special cache type that allows the compiler to use the signature of the un-cached walk to generate a corresponding type assertion for the untyped cache (`IdDict{Any, Any}`). This improves the type stability of `fmap` and friends. It also loosens the constraint on the cache type, so functionality outside `fmap` remains the same.
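The cache-wrapper idea described here can be sketched as follows. This is a hypothetical illustration, not Functors.jl's actual implementation — `TypedCache` is an invented name: storage stays a loose `IdDict{Any, Any}`, while reads are asserted to the type the un-cached walk would return.

```julia
# Hypothetical sketch: wrap the untyped IdDict and assert the value type
# on every read, derived from the walk's inferred return type.
struct TypedCache{F}
    walk::F
    store::IdDict{Any,Any}
end
TypedCache(walk) = TypedCache(walk, IdDict{Any,Any}())

function Base.get!(c::TypedCache, x)
    T = Core.Compiler.return_type(c.walk, Tuple{typeof(x)})
    y = get!(() -> c.walk(x), c.store, x)  # compute and store on first access
    return y::T                            # typeassert recovers inference
end
```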