Skip to content

Commit

Permalink
Refactor to allow non-square matrix variables + test
Browse files Browse the repository at this point in the history
  • Loading branch information
thatgardnerone committed Aug 21, 2024
1 parent affe33f commit d6321a6
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 6 deletions.
24 changes: 18 additions & 6 deletions src/SumOfSquares/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
import math
from collections import defaultdict
from typing import Iterable, Tuple, List
from typing import Iterable, Tuple, List, Union

from .util import *

Expand Down Expand Up @@ -127,12 +127,24 @@ def poly_variable(name: str, variables: List[sp.Symbol], deg: int,
return sum(coeff * prod(var**power for var, power in zip(variables, monom))
for monom, coeff in zip(basis, coeffs))

def matrix_variable(name: str, variables: List[sp.Symbol], deg: int, dim: int,

def matrix_variable(name: str, variables: List[sp.Symbol], deg: int, dim: Union[int, tuple],
hom: bool=False, sym: bool=True) -> sp.Matrix:
'''Returns a (symmetric) matrix variable of size dim x dim'''
arr = [[None] * dim for _ in range(dim)]
for i in range(dim):
for j in range(dim):
'''
Returns a matrix variable.
If the specified dimension is a tuple, then the matrix is of size dim_1 x dim_2,
otherwise it is of size dim x dim.
Note: If the matrix is not square, it cannot be symmetric.
'''
dim_1, dim_2 = (dim, dim) if isinstance(dim, int) else dim

sym = sym and dim_1 == dim_2

arr = np.empty((dim_1, dim_2), dtype=object)

for i in range(dim_1):
for j in range(dim_2):
if j < i and sym:
arr[i][j] = arr[j][i]
else:
Expand Down
9 changes: 9 additions & 0 deletions tests/test_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,12 @@ def test_sym_matrix_variable(self):
for i in range(n):
for j in range(n):
self.assertEqual(M[i,j], M[j,i])

def test_non_square_matrix_var(self):
"""A matrix variable may not be square. If so, it will not be symmetric."""
x, y = sp.symbols('x y')
n, m = 3, 4
deg = 2
M = matrix_variable('M', [x, y], deg, (n, m), hom=False)

self.assertEqual(M.shape, (n, m))

0 comments on commit d6321a6

Please sign in to comment.