Skip to content

Commit

Permalink
use 5-arg copyto
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Apr 11, 2024
1 parent e38170d commit 058a25b
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions src/destructure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,14 +152,18 @@ end
function _rebuild!(x, off, flat::AbstractVector, len = length(flat); walk = _Trainable_biwalk(), kw...)
len == length(flat) || throw(DimensionMismatch("Rebuild expected a vector of length $len, got $(length(flat))"))
fmap(x, off; exclude = isnumeric, walk, kw...) do y, o
copyto!(y, _getat(y, o, flat, view))
# copyto!(y, _getat_view(y, o, flat))
copyto!(y, 1, flat, o+1, length(y))
end
x
end

_getat(y::Number, o::Int, flat::AbstractVector, _...) = ProjectTo(y)(flat[o + 1])
_getat(y::AbstractArray, o::Int, flat::AbstractVector, get=getindex) =
ProjectTo(y)(reshape(get(flat, o .+ (1:length(y))), axes(y))) # ProjectTo is just correcting eltypes
_getat(y::Number, o::Int, flat::AbstractVector) = ProjectTo(y)(flat[o + 1])
_getat(y::AbstractArray, o::Int, flat::AbstractVector) =
ProjectTo(y)(reshape(flat[o .+ (1:length(y))], axes(y))) # ProjectTo is just correcting eltypes

# _getat_view(y::AbstractArray, o::Int, flat::AbstractVector) =
# view(flat, o .+ (1:length(y)))

struct _Trainable_biwalk <: AbstractWalk end

Expand Down

0 comments on commit 058a25b

Please sign in to comment.