Skip to content

Commit

Permalink
Merge pull request #1289 from AayushSabharwal/as/arrayop-fixes
Browse files Browse the repository at this point in the history
fix: fix `maketerm` and `toexpr` for `ArrayOp`
  • Loading branch information
ChrisRackauckas authored Sep 27, 2024
2 parents 113a5a5 + 2fca7f4 commit adedbf9
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 2 deletions.
10 changes: 9 additions & 1 deletion src/arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,14 @@ end
ConstructionBase.constructorof(s::Type{<:ArrayOp{T}}) where {T} = ArrayOp{T}

function SymbolicUtils.maketerm(::Type{<:ArrayOp}, f, args, m)
args = map(args) do arg
if iscall(arg) && operation(arg) == Ref && symbolic_type(only(arguments(arg))) == NotSymbolic()
return Ref(only(arguments(arg)))
else
return arg
end
end

t = f(args...)
t isa Symbolic && !isnothing(m) ?
metadata(t, m) : t
Expand Down Expand Up @@ -968,7 +976,7 @@ end
### Codegen

function SymbolicUtils.Code.toexpr(x::ArrayOp, st)
haskey(st.symbolify, x) && return st.symbolify[x]
haskey(st.rewrites, x) && return st.rewrites[x]

if iscall(x.term)
toexpr(x.term, st)
Expand Down
2 changes: 2 additions & 0 deletions test/arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ end
@test isequal(T, Symbolics.maketerm(typeof(T), operation(T), arguments(T), nothing))
T2 = unwrap(3B)
@test isequal(T2, Symbolics.maketerm(typeof(T), operation(T), [*, 3, unwrap(B)], nothing))
T3 = unwrap(A .^ 2)
@test isequal(T3, Symbolics.maketerm(typeof(T3), operation(T3), arguments(T3), nothing))
end

getdef(v) = getmetadata(v, Symbolics.VariableDefaultValue)
Expand Down
8 changes: 7 additions & 1 deletion test/build_function.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using Symbolics, SparseArrays, LinearAlgebra, Test
using ReferenceTests
using Symbolics: value
using SymbolicUtils.Code: DestructuredArgs, Func
using SymbolicUtils.Code: DestructuredArgs, Func, NameState
@variables a b c1 c2 c3 d e g
oop, iip = Symbolics.build_function([sqrt(a), sin(b)], [a, b], nanmath = true)
@test all(isnan, eval(oop)([-1, Inf]))
Expand Down Expand Up @@ -275,3 +275,9 @@ let #658
k = eval(build_function(a * X1 + X2, X1, X2, a)[2])
@test k(ones(3), ones(3), 1.5) == [2.5, 2.5, 2.5]
end

@testset "ArrayOp codegen" begin
@variables x[1:2]
T = value(x .^ 2)
@test_nowarn toexpr(T, NameState())
end

0 comments on commit adedbf9

Please sign in to comment.