Skip to content

Commit

Permalink
feat: support array variables in linear_expansion
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Oct 23, 2024
1 parent 10a61ee commit 2164cef
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 0 deletions.
18 changes: 18 additions & 0 deletions src/linear_algebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,13 @@ function _linear_expansion(t, x)
op, args = operation(t), arguments(t)
expansion_check(op)

if iscall(x) && operation(x) == getindex
arrx, idxsx... = arguments(x)
else
arrx = nothing
idxsx = nothing
end

if op === (+)
a₁ = b₁ = 0
islinear = true
Expand Down Expand Up @@ -318,8 +325,19 @@ function _linear_expansion(t, x)
a₁, b₁, islinear = linear_expansion(args[1], x)
# (a₁ x + b₁)/b₂
return islinear ? (a₁ / b₂, b₁ / b₂, islinear) : (0, 0, false)
elseif op === getindex
arrt, idxst... = arguments(t)
isequal(arrt, arrx) && return (0, t, true)

indexed_t = Symbolics.scalarize(arrt)[idxst...]
# when indexing a registered function/callable symbolic
# scalarizing and indexing leads to the same symbolic variable
# which causes a StackOverflowError without this
isequal(t, indexed_t) && return (0, t, true)
return linear_expansion(Symbolics.scalarize(arrt)[idxst...], x)
else
for (i, arg) in enumerate(args)
isequal(arg, arrx) && return (0, 0, false)
a, b, islinear = linear_expansion(arg, x)
(_iszero(a) && islinear) || return (0, 0, false)
end
Expand Down
12 changes: 12 additions & 0 deletions test/linear_solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,15 @@ a, b, islinear = Symbolics.linear_expansion(D(x) - x, x)
@test islinear
@test isequal(a, -1)
@test isequal(b, D(x))

@testset "linear_expansion with array variables" begin
@variables x[1:2] y[1:2] z(..)
@test !linear_expansion(z(x) + x[1], x[1])[3]
@test !linear_expansion(z(x[1]) + x[1], x[1])[3]
a, b, islin = linear_expansion(z(x[2]) + x[1], x[1])
@test islin && isequal(a, 1) && isequal(b, z(x[2]))
a, b, islin = linear_expansion((x + x)[1], x[1])
@test islin && isequal(a, 2) && isequal(b, 0)
a, b, islin = linear_expansion(y[1], x[1])
@test islin && isequal(a, 0) && isequal(b, y[1])
end

0 comments on commit 2164cef

Please sign in to comment.