Skip to content

Commit

Permalink
improve setall() inference (#120)
Browse files Browse the repository at this point in the history
Simplifying code at the same time:
- no @eval @generated anymore
- less Val usage
  • Loading branch information
aplavin authored Jan 5, 2024
1 parent 7e200c9 commit 2b0c6c2
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 38 deletions.
86 changes: 54 additions & 32 deletions src/getsetall.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,12 @@ getall(obj, optic::ComposedFunction) = _getall(obj, decompose(optic))
function setall(obj, optic::ComposedFunction, vs)
optics = decompose(optic)
N = length(optics)
vss = to_nested_shape(vs, Val(getall_lengths(obj, optics)), Val(N))
lengths = getall_lengths(obj, optics)

total_length = _val(nestedsum(lengths))
length(vs) == total_length || throw(DimensionMismatch("tried to assign $(length(vs)) elements to $total_length destinations"))

vss = to_nested_shape(vs, lengths, Val(N))
_setall(obj, optics, vss)
end

Expand Down Expand Up @@ -144,46 +149,63 @@ _staticlength(x::AbstractVector) = length(x)

getall_lengths(obj, optics::Tuple{Any}) = _staticlength(getall(obj, only(optics)))
for N in [2:10; :(<: Any)]
@eval function getall_lengths(obj, optics::NTuple{$N,Any})
# convert to Tuple: vectors cannot be put into Val
map(getall(obj, last(optics)) |> Tuple) do o
@eval getall_lengths(obj, optics::NTuple{$N,Any}) =
map(getall(obj, last(optics))) do o
getall_lengths(o, Base.front(optics))
end
end
end

_val(N::Int) = N
_val(::Val{N}) where {N} = N

nestedsum(ls::Union{Int,Val}) = _val(ls)
nestedsum(ls::Tuple) = sum(nestedsum, ls; init=0)

# to_nested_shape() definition uses both @eval and @generated
#
# @eval is needed because the code for different recursion depths should be different for inference,
# not the same method with different parameters.
#
# @generated is used to unpack target lengths from the second argument at compile time to make to_nested_shape() as cheap as possible.
#
# Note: to_nested_shape() only operates on plain Julia types and won't be affected by user lens definition, unlike setall for example.
# That's why it's safe to make it @generated.
to_nested_shape(vs, ::Val{LS}, ::Val{1}) where {LS} = (@assert length(vs) == _val(LS); vs)
_valadd(::Val{N}, ::Val{M}) where {N,M} = Val(N+M)
_valadd(n, m) = _val(n) + _val(m)

# nestedsum(): compute the sum of all values in a nested tuple/vector of int/val(int)
nestedsum(ls::Union{Int,Val}) = ls
nestedsum(ls::Tuple) = _valadd(nestedsum(first(ls)), nestedsum(Base.tail(ls)))
nestedsum(ls::Tuple{}) = Val(0)
nestedsum(ls::Vector) = sum(_val nestedsum, ls)

# splitelems() - split values provided to setall() into two parts: the first N elements, and the rest
# should always be type-stable
# if more collections should be supported, maybe add a fallback method that materializes to vectors; but is it actually needed?
splitelems(vs::NTuple{M,Any}, ::Val{N}) where {N,M} =
ntuple(j -> vs[j], Val(N)), ntuple(j -> vs[N+j], Val(M-N))
splitelems(vs::Tuple, N) =
map(i -> vs[i], 1:N), map(i -> vs[i], N+1:length(vs))
# staticarrays can be sliced into compile-time length slices for further efficiency, but this is still type-stable
splitelems(vs::AbstractVector, N) =
(@view vs[1:_val(N)]), (@view vs[_val(N)+1:end])

_sliceview(v::AbstractVector, i::AbstractVector) = view(v, i)
_sliceview(v::Tuple, i::AbstractVector) = collect(Iterators.map(i -> v[i], i)) # should be regular map(), but it exceed the recursion depth heuristic

# to_nested_shape(): convert a flat tuple/vector of values (as provided to setall) into a nested structure of tuples/vectors following the shape (ls)
# shape is always a (nested) tuple or vector with int or val(int) values, it is generated by getall_lengths()
# values can be any collection passed to setall, here we support tuples and abstractvectors
to_nested_shape(vs, LS, ::Val{1}) = (@assert length(vs) == _val(LS); vs)

for i in 2:10
@eval @generated function to_nested_shape(vs, ls::Val{LS}, ::Val{$i}) where {LS}
vi = 1
subs = map(LS) do lss
n = nestedsum(lss)
elems = map(vi:vi+n-1) do j
:( vs[$j] )
end
res = :( to_nested_shape(($(elems...),), $(Val(lss)), $(Val($(i - 1)))) )
vi += n
@eval to_nested_shape(vs, ls::Tuple{}, ::Val{$i}) = ()

@eval function to_nested_shape(vs, ls::Tuple, ::Val{$i})
lss = first(ls)
n = nestedsum(lss)
elems, elemstail = splitelems(vs, n)
reshead = to_nested_shape(elems, lss, $(Val(i - 1)))
restail = to_nested_shape(elemstail, Base.tail(ls), $(Val(i)))
return (reshead, restail...)
end

@eval function to_nested_shape(vs, ls::Vector, ::Val{$i})
vi = Ref(1)
map(ls) do lss
n = nestedsum(lss) |> _val
elems = _sliceview(vs, vi[]:vi[]+n-1)
res = to_nested_shape(elems, lss, $(Val(i - 1)))
vi[] += n
res
end
total_n = nestedsum(LS)
quote
length(vs) == $total_n || throw(DimensionMismatch("tried to assign $(length(vs)) elements to $($total_n) destinations"))
($(subs...),)
end
end
end
24 changes: 18 additions & 6 deletions test/test_getsetall.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,20 +127,32 @@ end
@test setall(obj.c[1], Elements(), (5,)) === SVector(5)
@test setall(obj.c[1], Elements(), [5, 6]) === SVector(5, 6)
@test setall(obj.c[1], Elements(), [5]) === SVector(5)
@testset for o in (
@testset for (i,o) in (
(@optic _.c |> Elements() |> Elements()),
(@optic _.c |> Elements() |> Elements() |> _ + 1),
)
@test setall(obj, o, getall(obj, o)) === obj
) |> enumerate
@test (@inferred setall(obj, o, getall(obj, o))) === obj
@test setall(obj, o, collect(getall(obj, o))) === obj
if VERSION v"1.10" || i == 2
@test (@inferred setall(obj, o, Vector{Float64}(collect(getall(obj, o))))) == obj
@test (@inferred setall(obj, o, SVector(getall(obj, o)))) == obj
else
@test setall(obj, o, Vector{Float64}(collect(getall(obj, o)))) == obj
@test setall(obj, o, SVector(getall(obj, o))) == obj
end
end

obj = ([1, 2], 3:5, (6,))
@test obj == setall(obj, @optic(_ |> Elements() |> Elements()), 1:6)
@test ([2, 3], 4:6, (7,)) == setall(obj, @optic(_ |> Elements() |> Elements() |> _ - 1), 1:6)
# can this infer?..
@test_broken obj == @inferred setall(obj, @optic(_ |> Elements() |> Elements()), 1:6)
@test_broken ([2, 3], 4:6, (7,)) == @inferred setall(obj, @optic(_ |> Elements() |> Elements() |> _ - 1), 1:6)

@test obj == @inferred setall(obj, @optic(_ |> Elements() |> Elements()), 1:6)
@test ([2, 3], 4:6, (7,)) == @inferred setall(obj, @optic(_ |> Elements() |> Elements() |> _ - 1), 1:6)
@test obj == @inferred setall(obj, @optic(_ |> Elements() |> Elements()), ntuple(identity, 6))
@test obj == @inferred setall(obj, @optic(_ |> identity |> Elements() |> Elements()), ntuple(identity, 6))
@test obj[1] == @inferred setall(obj[1], @optic(_ |> Elements() |> _ + 1), (2, 3))
# impossible to infer:
@test_broken ([1, 2], [3.0, 4.0, 5.0], ("6",)) == @inferred setall(obj, @optic(_ |> Elements() |> Elements()), (1, 2, 3., 4., 5., "6"))
end

@testset "getall/setall consistency" begin
Expand Down

0 comments on commit 2b0c6c2

Please sign in to comment.