From e2b85a46772e3f7d3469c88dc8bf65b8dabb15f0 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Thu, 25 Apr 2024 19:19:19 -0400 Subject: [PATCH] Add inner_unwrap for array ops --- src/array-lib.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/array-lib.jl b/src/array-lib.jl index cc5923032..c16d3984c 100644 --- a/src/array-lib.jl +++ b/src/array-lib.jl @@ -282,7 +282,10 @@ function getindex(A::Arr, j::Symbolic{<:Integer}, i::Int) wrap(unwrap(A)[j, i]) end +inner_unwrap(x) = x isa AbstractArray ? unwrap.(x) : x function _matmul(A, B) + A = inner_unwrap(A) + B = inner_unwrap(B) @syms i::Int j::Int k::Int if isadjointvec(A) op = operation(A.term) @@ -295,6 +298,8 @@ end @wrapped (*)(A::AbstractVector, B::AbstractMatrix) = _matmul(A, B) function _matvec(A, b) + A = inner_unwrap(A) + b = inner_unwrap(b) @syms i::Int k::Int sym_res = @arrayop (i,) A[i, k] * b[k] term=(A*b) if isdot(A, b)