diff --git a/src/arrays.jl b/src/arrays.jl index 6df9fbffb..ac65fb616 100644 --- a/src/arrays.jl +++ b/src/arrays.jl @@ -63,6 +63,14 @@ end ConstructionBase.constructorof(s::Type{<:ArrayOp{T}}) where {T} = ArrayOp{T} function SymbolicUtils.maketerm(::Type{<:ArrayOp}, f, args, m) + args = map(args) do arg + if iscall(arg) && operation(arg) == Ref && symbolic_type(only(arguments(arg))) == NotSymbolic() + return Ref(only(arguments(arg))) + else + return arg + end + end + t = f(args...) t isa Symbolic && !isnothing(m) ? metadata(t, m) : t @@ -968,7 +976,7 @@ end ### Codegen function SymbolicUtils.Code.toexpr(x::ArrayOp, st) - haskey(st.symbolify, x) && return st.symbolify[x] + haskey(st.rewrites, x) && return st.rewrites[x] if iscall(x.term) toexpr(x.term, st) diff --git a/test/arrays.jl b/test/arrays.jl index 13fa73379..27041d846 100644 --- a/test/arrays.jl +++ b/test/arrays.jl @@ -82,6 +82,8 @@ end @test isequal(T, Symbolics.maketerm(typeof(T), operation(T), arguments(T), nothing)) T2 = unwrap(3B) @test isequal(T2, Symbolics.maketerm(typeof(T), operation(T), [*, 3, unwrap(B)], nothing)) + T3 = unwrap(A .^ 2) + @test isequal(T3, Symbolics.maketerm(typeof(T3), operation(T3), arguments(T3), nothing)) end getdef(v) = getmetadata(v, Symbolics.VariableDefaultValue) diff --git a/test/build_function.jl b/test/build_function.jl index 18e0a574b..5e7266402 100644 --- a/test/build_function.jl +++ b/test/build_function.jl @@ -1,7 +1,7 @@ using Symbolics, SparseArrays, LinearAlgebra, Test using ReferenceTests using Symbolics: value -using SymbolicUtils.Code: DestructuredArgs, Func +using SymbolicUtils.Code: DestructuredArgs, Func, NameState @variables a b c1 c2 c3 d e g oop, iip = Symbolics.build_function([sqrt(a), sin(b)], [a, b], nanmath = true) @test all(isnan, eval(oop)([-1, Inf])) @@ -275,3 +275,9 @@ let #658 k = eval(build_function(a * X1 + X2, X1, X2, a)[2]) @test k(ones(3), ones(3), 1.5) == [2.5, 2.5, 2.5] end + +@testset "ArrayOp codegen" begin + @variables x[1:2] + T = value(x .^ 2) + @test_nowarn toexpr(T, NameState()) +end