From d6321a6b5216bfaf1b2c5281d2bed935b47038a1 Mon Sep 17 00:00:00 2001 From: Jamie Gardner Date: Wed, 21 Aug 2024 14:54:27 +0100 Subject: [PATCH] Refactor to allow non-square matrix variables + test --- src/SumOfSquares/basis.py | 24 ++++++++++++++++++------ tests/test_basis.py | 9 +++++++++ 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/src/SumOfSquares/basis.py b/src/SumOfSquares/basis.py index 23e0881..9e2072b 100644 --- a/src/SumOfSquares/basis.py +++ b/src/SumOfSquares/basis.py @@ -4,7 +4,7 @@ import numpy as np import math from collections import defaultdict -from typing import Iterable, Tuple, List +from typing import Iterable, Tuple, List, Union from .util import * @@ -127,12 +127,24 @@ def poly_variable(name: str, variables: List[sp.Symbol], deg: int, return sum(coeff * prod(var**power for var, power in zip(variables, monom)) for monom, coeff in zip(basis, coeffs)) -def matrix_variable(name: str, variables: List[sp.Symbol], deg: int, dim: int, + +def matrix_variable(name: str, variables: List[sp.Symbol], deg: int, dim: Union[int, tuple], hom: bool=False, sym: bool=True) -> sp.Matrix: - '''Returns a (symmetric) matrix variable of size dim x dim''' - arr = [[None] * dim for _ in range(dim)] - for i in range(dim): - for j in range(dim): + ''' + Returns a matrix variable. + If the specified dimension is a tuple, then the matrix is of size dim_1 x dim_2, + otherwise it is of size dim x dim. + + Note: If the matrix is not square, it cannot be symmetric. + ''' + dim_1, dim_2 = (dim, dim) if isinstance(dim, int) else dim + + sym = sym and dim_1 == dim_2 + + arr = np.empty((dim_1, dim_2), dtype=object) + + for i in range(dim_1): + for j in range(dim_2): if j < i and sym: arr[i][j] = arr[j][i] else: diff --git a/tests/test_basis.py b/tests/test_basis.py index 5d90847..15b3b95 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -40,3 +40,12 @@ def test_sym_matrix_variable(self): for i in range(n): for j in range(n): self.assertEqual(M[i,j], M[j,i]) + + def test_non_square_matrix_var(self): + """A matrix variable may not be square. If so, it will not be symmetric.""" + x, y = sp.symbols('x y') + n, m = 3, 4 + deg = 2 + M = matrix_variable('M', [x, y], deg, (n, m), hom=False) + + self.assertEqual(M.shape, (n, m))