From 408f5fe7f0a28aca81bffe1489ffd8f9f79bbe49 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 23 Oct 2024 12:53:23 +0530 Subject: [PATCH] feat: support array variables in `linear_expansion` --- src/linear_algebra.jl | 39 +++++++++++++++++++++++++++++++++++++++ test/linear_solver.jl | 15 +++++++++++++++ 2 files changed, 54 insertions(+) diff --git a/src/linear_algebra.jl b/src/linear_algebra.jl index 2b37a9b9a..177a3571a 100644 --- a/src/linear_algebra.jl +++ b/src/linear_algebra.jl @@ -276,6 +276,13 @@ function _linear_expansion(t, x) op, args = operation(t), arguments(t) expansion_check(op) + if iscall(x) && operation(x) == getindex + arrx, idxsx... = arguments(x) + else + arrx = nothing + idxsx = nothing + end + if op === (+) a₁ = b₁ = 0 islinear = true @@ -318,8 +325,24 @@ function _linear_expansion(t, x) a₁, b₁, islinear = linear_expansion(args[1], x) # (a₁ x + b₁)/b₂ return islinear ? (a₁ / b₂, b₁ / b₂, islinear) : (0, 0, false) + elseif op === getindex + arrt, idxst... = arguments(t) + isequal(arrt, arrx) && return (0, t, true) + + indexed_t = Symbolics.scalarize(arrt)[idxst...] + # when indexing a registered function/callable symbolic + # scalarizing and indexing leads to the same symbolic variable + # which causes a StackOverflowError without this + isequal(t, indexed_t) && return (0, t, true) + return linear_expansion(Symbolics.scalarize(arrt)[idxst...], x) else for (i, arg) in enumerate(args) + isequal(arg, arrx) && return (0, 0, false) + if symbolic_type(arg) == NotSymbolic() + arg isa AbstractArray || continue + _occursin_array(x, arrx, arg) && return (0, 0, false) + continue + end a, b, islinear = linear_expansion(arg, x) (_iszero(a) && islinear) || return (0, 0, false) end @@ -327,6 +350,22 @@ function _linear_expansion(t, x) end end +""" + _occursin_array(sym, arrsym, arr) + +Check if `sym` (or, if `sym` is an element of an array symbolic, the array symbolic +`arrsym`) occursin in the non-symbolic array `arr`. +""" +function _occursin_array(sym, arrsym, arr) + for el in arr + if symbolic_type(el) == NotSymbolic() + return el isa AbstractArray && _occursin_array(sym, arrsym, el) + else + return occursin(sym, el) || occursin(arrsym, el) + end + end +end + ### ### Utilities ### diff --git a/test/linear_solver.jl b/test/linear_solver.jl index 64be2db7b..f2df87e7f 100644 --- a/test/linear_solver.jl +++ b/test/linear_solver.jl @@ -59,3 +59,18 @@ a, b, islinear = Symbolics.linear_expansion(D(x) - x, x) @test islinear @test isequal(a, -1) @test isequal(b, D(x)) + +@testset "linear_expansion with array variables" begin + @variables x[1:2] y[1:2] z(..) + @test !Symbolics.linear_expansion(z(x) + x[1], x[1])[3] + @test !Symbolics.linear_expansion(z(x[1]) + x[1], x[1])[3] + a, b, islin = Symbolics.linear_expansion(z(x[2]) + x[1], x[1]) + @test islin && isequal(a, 1) && isequal(b, z(x[2])) + a, b, islin = Symbolics.linear_expansion((x + x)[1], x[1]) + @test islin && isequal(a, 2) && isequal(b, 0) + a, b, islin = Symbolics.linear_expansion(y[1], x[1]) + @test islin && isequal(a, 0) && isequal(b, y[1]) + @test !Symbolics.linear_expansion(z([x...]), x[1])[3] + @test !Symbolics.linear_expansion(z(collect(Symbolics.unwrap(x))), x[1])[3] + @test !Symbolics.linear_expansion(z([x, 2x]), x[1])[3] +end