From 2164cef6c64625fdb3c75cd1914ce3fd7a3e9a49 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 23 Oct 2024 12:53:23 +0530 Subject: [PATCH] feat: support array variables in `linear_expansion` --- src/linear_algebra.jl | 18 ++++++++++++++++++ test/linear_solver.jl | 12 ++++++++++++ 2 files changed, 30 insertions(+) diff --git a/src/linear_algebra.jl b/src/linear_algebra.jl index 2b37a9b9a..61fc8639a 100644 --- a/src/linear_algebra.jl +++ b/src/linear_algebra.jl @@ -276,6 +276,13 @@ function _linear_expansion(t, x) op, args = operation(t), arguments(t) expansion_check(op) + if iscall(x) && operation(x) == getindex + arrx, idxsx... = arguments(x) + else + arrx = nothing + idxsx = nothing + end + if op === (+) a₁ = b₁ = 0 islinear = true @@ -318,8 +325,19 @@ function _linear_expansion(t, x) a₁, b₁, islinear = linear_expansion(args[1], x) # (a₁ x + b₁)/b₂ return islinear ? (a₁ / b₂, b₁ / b₂, islinear) : (0, 0, false) + elseif op === getindex + arrt, idxst... = arguments(t) + isequal(arrt, arrx) && return (0, t, true) + + indexed_t = Symbolics.scalarize(arrt)[idxst...] + # when indexing a registered function/callable symbolic + # scalarizing and indexing leads to the same symbolic variable + # which causes a StackOverflowError without this + isequal(t, indexed_t) && return (0, t, true) + return linear_expansion(Symbolics.scalarize(arrt)[idxst...], x) else for (i, arg) in enumerate(args) + isequal(arg, arrx) && return (0, 0, false) a, b, islinear = linear_expansion(arg, x) (_iszero(a) && islinear) || return (0, 0, false) end diff --git a/test/linear_solver.jl b/test/linear_solver.jl index 64be2db7b..89b2a3371 100644 --- a/test/linear_solver.jl +++ b/test/linear_solver.jl @@ -59,3 +59,15 @@ a, b, islinear = Symbolics.linear_expansion(D(x) - x, x) @test islinear @test isequal(a, -1) @test isequal(b, D(x)) + +@testset "linear_expansion with array variables" begin + @variables x[1:2] y[1:2] z(..) + @test !linear_expansion(z(x) + x[1], x[1])[3] + @test !linear_expansion(z(x[1]) + x[1], x[1])[3] + a, b, islin = linear_expansion(z(x[2]) + x[1], x[1]) + @test islin && isequal(a, 1) && isequal(b, z(x[2])) + a, b, islin = linear_expansion((x + x)[1], x[1]) + @test islin && isequal(a, 2) && isequal(b, 0) + a, b, islin = linear_expansion(y[1], x[1]) + @test islin && isequal(a, 0) && isequal(b, y[1]) +end