From 58b90667b673d0d6ee96fc9a0fb4c0135a9f6db3 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 18 Mar 2024 19:39:27 +0530 Subject: [PATCH] feat: support similarterm for arrayop --- src/arrays.jl | 28 ++++++++++++++++++++++++++++ test/arrays.jl | 9 +++++++++ 2 files changed, 37 insertions(+) diff --git a/src/arrays.jl b/src/arrays.jl index bd4a4699c..3b2f069f2 100644 --- a/src/arrays.jl +++ b/src/arrays.jl @@ -62,6 +62,34 @@ end ConstructionBase.constructorof(s::Type{<:ArrayOp{T}}) where {T} = ArrayOp{T} +function SymbolicUtils.similarterm(t::ArrayOp, f, args, _symtype = nothing; metadata = nothing) + oldargs = arguments(t) + if _symtype === nothing + _symtype = symtype(t) + end + + if !all(isequal.(args, oldargs)) || !isequal(f, operation(t)) + term = similarterm(t.term, f, args) + subs = Dict() + for (orig, new) in zip(oldargs, args) + isequal(orig, new) && continue + subs[orig] = new + end + if !isequal(f, operation(t)) + subs[operation(t)] = f + end + expr = substitute(t.expr, subs) + expr = SymbolicUtils.term(operation(expr), arguments(expr)...) + else + term = t.term + expr = t.expr + end + if _symtype === nothing + _symtype = symtype(t) + end + return ArrayOp{_symtype}(t.output_idx, expr, t.reduce, term, t.shape, t.ranges, metadata) +end + shape(aop::ArrayOp) = aop.shape const show_arrayop = Ref{Bool}(false) diff --git a/test/arrays.jl b/test/arrays.jl index 42a383cc5..50c71ff5e 100644 --- a/test/arrays.jl +++ b/test/arrays.jl @@ -73,6 +73,15 @@ end @test getmetadata(unwrap(v[1]), TestMetaT) == 4 end +@testset "similarterm" begin + @variables A[1:5, 1:5] B[1:5, 1:5] + + T = unwrap(3A) + @test isequal(T, similarterm(T, operation(T), arguments(T))) + T2 = unwrap(3B) + @test isequal(T2, similarterm(T, operation(T), [*, 3, unwrap(B)])) +end + getdef(v) = getmetadata(v, Symbolics.VariableDefaultValue) @testset "broadcast & scalarize" begin @variables A[1:5,1:3]=42 b[1:3]=[2, 3, 5] t x(t)[1:4] u[1:1]