From e607ea99e855902ee6a2fa3789a25d381345683c Mon Sep 17 00:00:00 2001 From: serenity4 Date: Thu, 2 Nov 2023 15:35:38 +0100 Subject: [PATCH] Only reconstruct with non-tuple T at the end for flattened output --- src/macro.jl | 8 +++++++- test/macro.jl | 8 ++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/src/macro.jl b/src/macro.jl index b126f70..98dae73 100644 --- a/src/macro.jl +++ b/src/macro.jl @@ -554,7 +554,13 @@ function to_expr(cache, ex, flatten::Bool, T, variables, stop_early = false) end end isexpr(ex, FACTOR) && return to_final_expr(cache, ex[1], flatten, T, variables) - isexpr(ex, KVECTOR) && return :($construct($(reconstructed_type(T, ex.cache.sig, ex)), $(Expr(:tuple, to_final_expr.(cache, ex, flatten, Ref(T), Ref(variables))...)))) + if isexpr(ex, KVECTOR) + Tintermediate = flatten ? :Tuple : T + RT = reconstructed_type(Tintermediate, ex.cache.sig, ex) + components = Expr(:tuple) + append!(components.args, [to_final_expr(cache, x, flatten, Tintermediate, variables) for x in ex]) + return :($construct($RT, $components)) + end isexpr(ex, BLADE) && return 1 if isexpr(ex, GEOMETRIC_PRODUCT) @assert isweightedblade(ex) diff --git a/test/macro.jl b/test/macro.jl index 831856e..715b158 100644 --- a/test/macro.jl +++ b/test/macro.jl @@ -212,4 +212,12 @@ using SymbolicGA: extract_weights, input_expression, extract_expression, restruc @test (@ga 2 a::1 ⊢ b::1 ⊢ c::1 ⊢ d::1) isa KVector @test (@ga 2 a::1 ⊣ (a::1 ∧ b::1) ⊢ c::1 ⊣ d::1) isa KVector end + + @testset "Flattening" begin + a, b, c, d = rand(3), rand(3), rand(3), rand(3) + ex = @ga 3 (a::1 ⟑ b::1)::(0 + 2) + @test isa(ex, NTuple{4,Float64}) + ex = @ga 3 NTuple{4,Float32} (a::1 ⟑ b::1)::(0 + 2) + @test isa(ex, NTuple{4,Float32}) + end end;