diff --git a/src/getsetall.jl b/src/getsetall.jl index ef7a565d..739a7f49 100644 --- a/src/getsetall.jl +++ b/src/getsetall.jl @@ -98,7 +98,12 @@ getall(obj, optic::ComposedFunction) = _getall(obj, decompose(optic)) function setall(obj, optic::ComposedFunction, vs) optics = decompose(optic) N = length(optics) - vss = to_nested_shape(vs, Val(getall_lengths(obj, optics)), Val(N)) + lengths = getall_lengths(obj, optics) + + total_length = _val(nestedsum(lengths)) + length(vs) == total_length || throw(DimensionMismatch("tried to assign $(length(vs)) elements to $total_length destinations")) + + vss = to_nested_shape(vs, lengths, Val(N)) _setall(obj, optics, vss) end @@ -144,46 +149,63 @@ _staticlength(x::AbstractVector) = length(x) getall_lengths(obj, optics::Tuple{Any}) = _staticlength(getall(obj, only(optics))) for N in [2:10; :(<: Any)] - @eval function getall_lengths(obj, optics::NTuple{$N,Any}) - # convert to Tuple: vectors cannot be put into Val - map(getall(obj, last(optics)) |> Tuple) do o + @eval getall_lengths(obj, optics::NTuple{$N,Any}) = + map(getall(obj, last(optics))) do o getall_lengths(o, Base.front(optics)) end - end end _val(N::Int) = N _val(::Val{N}) where {N} = N -nestedsum(ls::Union{Int,Val}) = _val(ls) -nestedsum(ls::Tuple) = sum(nestedsum, ls; init=0) - -# to_nested_shape() definition uses both @eval and @generated -# -# @eval is needed because the code for different recursion depths should be different for inference, -# not the same method with different parameters. -# -# @generated is used to unpack target lengths from the second argument at compile time to make to_nested_shape() as cheap as possible. -# -# Note: to_nested_shape() only operates on plain Julia types and won't be affected by user lens definition, unlike setall for example. -# That's why it's safe to make it @generated. -to_nested_shape(vs, ::Val{LS}, ::Val{1}) where {LS} = (@assert length(vs) == _val(LS); vs) +_valadd(::Val{N}, ::Val{M}) where {N,M} = Val(N+M) +_valadd(n, m) = _val(n) + _val(m) + +# nestedsum(): compute the sum of all values in a nested tuple/vector of int/val(int) +nestedsum(ls::Union{Int,Val}) = ls +nestedsum(ls::Tuple) = _valadd(nestedsum(first(ls)), nestedsum(Base.tail(ls))) +nestedsum(ls::Tuple{}) = Val(0) +nestedsum(ls::Vector) = sum(_val ∘ nestedsum, ls) + +# splitelems() - split values provided to setall() into two parts: the first N elements, and the rest +# should always be type-stable +# if more collections should be supported, maybe add a fallback method that materializes to vectors; but is it actually needed? +splitelems(vs::NTuple{M,Any}, ::Val{N}) where {N,M} = + ntuple(j -> vs[j], Val(N)), ntuple(j -> vs[N+j], Val(M-N)) +splitelems(vs::Tuple, N) = + map(i -> vs[i], 1:N), map(i -> vs[i], N+1:length(vs)) +# staticarrays can be sliced into compile-time length slices for further efficiency, but this is still type-stable +splitelems(vs::AbstractVector, N) = + (@view vs[1:_val(N)]), (@view vs[_val(N)+1:end]) + +_sliceview(v::AbstractVector, i::AbstractVector) = view(v, i) +_sliceview(v::Tuple, i::AbstractVector) = collect(Iterators.map(i -> v[i], i)) # should be regular map(), but it exceed the recursion depth heuristic + +# to_nested_shape(): convert a flat tuple/vector of values (as provided to setall) into a nested structure of tuples/vectors following the shape (ls) +# shape is always a (nested) tuple or vector with int or val(int) values, it is generated by getall_lengths() +# values can be any collection passed to setall, here we support tuples and abstractvectors +to_nested_shape(vs, LS, ::Val{1}) = (@assert length(vs) == _val(LS); vs) + for i in 2:10 - @eval @generated function to_nested_shape(vs, ls::Val{LS}, ::Val{$i}) where {LS} - vi = 1 - subs = map(LS) do lss - n = nestedsum(lss) - elems = map(vi:vi+n-1) do j - :( vs[$j] ) - end - res = :( to_nested_shape(($(elems...),), $(Val(lss)), $(Val($(i - 1)))) ) - vi += n + @eval to_nested_shape(vs, ls::Tuple{}, ::Val{$i}) = () + + @eval function to_nested_shape(vs, ls::Tuple, ::Val{$i}) + lss = first(ls) + n = nestedsum(lss) + elems, elemstail = splitelems(vs, n) + reshead = to_nested_shape(elems, lss, $(Val(i - 1))) + restail = to_nested_shape(elemstail, Base.tail(ls), $(Val(i))) + return (reshead, restail...) + end + + @eval function to_nested_shape(vs, ls::Vector, ::Val{$i}) + vi = Ref(1) + map(ls) do lss + n = nestedsum(lss) |> _val + elems = _sliceview(vs, vi[]:vi[]+n-1) + res = to_nested_shape(elems, lss, $(Val(i - 1))) + vi[] += n res end - total_n = nestedsum(LS) - quote - length(vs) == $total_n || throw(DimensionMismatch("tried to assign $(length(vs)) elements to $($total_n) destinations")) - ($(subs...),) - end end end diff --git a/test/test_getsetall.jl b/test/test_getsetall.jl index ba99fad8..5424e668 100644 --- a/test/test_getsetall.jl +++ b/test/test_getsetall.jl @@ -127,20 +127,32 @@ end @test setall(obj.c[1], Elements(), (5,)) === SVector(5) @test setall(obj.c[1], Elements(), [5, 6]) === SVector(5, 6) @test setall(obj.c[1], Elements(), [5]) === SVector(5) - @testset for o in ( + @testset for (i,o) in ( (@optic _.c |> Elements() |> Elements()), (@optic _.c |> Elements() |> Elements() |> _ + 1), - ) - @test setall(obj, o, getall(obj, o)) === obj + ) |> enumerate + @test (@inferred setall(obj, o, getall(obj, o))) === obj @test setall(obj, o, collect(getall(obj, o))) === obj + if VERSION ≥ v"1.10" || i == 2 + @test (@inferred setall(obj, o, Vector{Float64}(collect(getall(obj, o))))) == obj + @test (@inferred setall(obj, o, SVector(getall(obj, o)))) == obj + else + @test setall(obj, o, Vector{Float64}(collect(getall(obj, o)))) == obj + @test setall(obj, o, SVector(getall(obj, o))) == obj + end end obj = ([1, 2], 3:5, (6,)) @test obj == setall(obj, @optic(_ |> Elements() |> Elements()), 1:6) @test ([2, 3], 4:6, (7,)) == setall(obj, @optic(_ |> Elements() |> Elements() |> _ - 1), 1:6) - # can this infer?.. - @test_broken obj == @inferred setall(obj, @optic(_ |> Elements() |> Elements()), 1:6) - @test_broken ([2, 3], 4:6, (7,)) == @inferred setall(obj, @optic(_ |> Elements() |> Elements() |> _ - 1), 1:6) + + @test obj == @inferred setall(obj, @optic(_ |> Elements() |> Elements()), 1:6) + @test ([2, 3], 4:6, (7,)) == @inferred setall(obj, @optic(_ |> Elements() |> Elements() |> _ - 1), 1:6) + @test obj == @inferred setall(obj, @optic(_ |> Elements() |> Elements()), ntuple(identity, 6)) + @test obj == @inferred setall(obj, @optic(_ |> identity |> Elements() |> Elements()), ntuple(identity, 6)) + @test obj[1] == @inferred setall(obj[1], @optic(_ |> Elements() |> _ + 1), (2, 3)) + # impossible to infer: + @test_broken ([1, 2], [3.0, 4.0, 5.0], ("6",)) == @inferred setall(obj, @optic(_ |> Elements() |> Elements()), (1, 2, 3., 4., 5., "6")) end @testset "getall/setall consistency" begin