diff --git a/src/linear_algebra.jl b/src/linear_algebra.jl index f13bc8854..43906f34c 100644 --- a/src/linear_algebra.jl +++ b/src/linear_algebra.jl @@ -98,7 +98,22 @@ function solve_for(eq, var; simplify=false, check=true) # scalar case # the cases. a, b, islinear = linear_expansion(eq, var) check && @assert islinear - islinear || return nothing + + if eq isa AbstractArray + for eqᵢ in eq + try + islinear &= Symbolics.isaffine(eqᵢ.lhs-eqᵢ.rhs, var) + catch e + end + end + else + try + islinear &= Symbolics.isaffine(eq.lhs-eq.rhs, [var]) + catch e + end + end + + if !islinear return nothing end # a * x + b = 0 if eq isa AbstractArray && var isa AbstractArray x = _solve(a, -b, simplify)