diff --git a/src/array-lib.jl b/src/array-lib.jl index cc5923032..c16d3984c 100644 --- a/src/array-lib.jl +++ b/src/array-lib.jl @@ -282,7 +282,10 @@ function getindex(A::Arr, j::Symbolic{<:Integer}, i::Int) wrap(unwrap(A)[j, i]) end +inner_unwrap(x) = x isa AbstractArray ? unwrap.(x) : x function _matmul(A, B) + A = inner_unwrap(A) + B = inner_unwrap(B) @syms i::Int j::Int k::Int if isadjointvec(A) op = operation(A.term) @@ -295,6 +298,8 @@ end @wrapped (*)(A::AbstractVector, B::AbstractMatrix) = _matmul(A, B) function _matvec(A, b) + A = inner_unwrap(A) + b = inner_unwrap(b) @syms i::Int k::Int sym_res = @arrayop (i,) A[i, k] * b[k] term=(A*b) if isdot(A, b)