Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Remove `Val` from `ntuple`s where constant propagation occurs

Co-authored-by: Brian Chen <[email protected]>
  • Loading branch information
lxvm and ToucheSir authored Jan 19, 2024
1 parent 734a6c4 commit 8c61928
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ struct StaticGetter{i} end
(::StaticGetter{i})(v) where {i} = v[i]
(::StaticGetter{i})(::Nothing) where {i} = nothing
function _unzip(tuples, ::Val{N}) where {N}
getters = ntuple(n -> StaticGetter{n}(), Val(N))
getters = ntuple(n -> StaticGetter{n}(), N)
map(g -> map(g, tuples), getters)
end
function unzip(tuples)
Expand Down Expand Up @@ -276,7 +276,7 @@ function productfunc(xs, dy)
@assert length(first(dy)) == length(xs)
ndim = map(Zygote._ndims, xs)
cdim = cumsum((1, ndim[begin:end-1]...))
getters = ntuple(n -> StaticGetter{n}(), Val(length(xs)))
getters = ntuple(n -> StaticGetter{n}(), length(xs))
map(first(dy), xs, cdim, getters) do dyn, x, cd, getter
dyn === nothing && return nothing
nd = _ndims(x)
Expand All @@ -300,7 +300,7 @@ end
end

function zipfunc(xs, dy)
getters = ntuple(n -> StaticGetter{n}(), Val(length(xs)))
getters = ntuple(n -> StaticGetter{n}(), length(xs))
map(xs, getters) do x, getter
dx = map(getter, dy)
_project(x, _restore(dx, _tryaxes(x)))
Expand Down

0 comments on commit 8c61928

Please sign in to comment.