From d272593a3d1208a3f9be81e2ea0aac0b3effa22a Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 21 Oct 2024 15:02:05 +0530 Subject: [PATCH] fix: fix maketerm for ArrayOp involving broadcasted symbolics --- src/arrays.jl | 9 +++++++-- test/arrays.jl | 4 +++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/arrays.jl b/src/arrays.jl index ac65fb616..0518bb48e 100644 --- a/src/arrays.jl +++ b/src/arrays.jl @@ -64,8 +64,13 @@ ConstructionBase.constructorof(s::Type{<:ArrayOp{T}}) where {T} = ArrayOp{T} function SymbolicUtils.maketerm(::Type{<:ArrayOp}, f, args, m) args = map(args) do arg - if iscall(arg) && operation(arg) == Ref && symbolic_type(only(arguments(arg))) == NotSymbolic() - return Ref(only(arguments(arg))) + if iscall(arg) && operation(arg) == Ref + inner = only(arguments(arg)) + if symbolic_type(inner) == NotSymbolic() + return Ref(inner) + else + return inner + end else return arg end diff --git a/test/arrays.jl b/test/arrays.jl index 27041d846..11ce88441 100644 --- a/test/arrays.jl +++ b/test/arrays.jl @@ -76,7 +76,7 @@ end end @testset "maketerm" begin - @variables A[1:5, 1:5] B[1:5, 1:5] + @variables A[1:5, 1:5] B[1:5, 1:5] C T = unwrap(3A) @test isequal(T, Symbolics.maketerm(typeof(T), operation(T), arguments(T), nothing)) @@ -84,6 +84,8 @@ end @test isequal(T2, Symbolics.maketerm(typeof(T), operation(T), [*, 3, unwrap(B)], nothing)) T3 = unwrap(A .^ 2) @test isequal(T3, Symbolics.maketerm(typeof(T3), operation(T3), arguments(T3), nothing)) + T4 = unwrap(A .* C) + @test isequal(T4, Symbolics.maketerm(typeof(T4), operation(T4), arguments(T4), nothing)) end getdef(v) = getmetadata(v, Symbolics.VariableDefaultValue)