Skip to content

Commit

Permalink
Adapt to Solve changes in Scipy 1.15
Browse files Browse the repository at this point in the history
1. Use actual Solve Op to infer output dtype as CholSolve outputs a different dtype than basic Solve in Scipy==1.15

2. Tweaked test related to #1152

3. Tweak tolerage
  • Loading branch information
ricardoV94 committed Jan 13, 2025
1 parent cff058c commit 581f65a
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 12 deletions.
7 changes: 4 additions & 3 deletions pytensor/tensor/slinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,9 +259,10 @@ def make_node(self, A, b):
raise ValueError(f"`b` must have {self.b_ndim} dims; got {b.type} instead.")

# Infer dtype by solving the most simple case with 1x1 matrices
o_dtype = scipy.linalg.solve(
np.eye(1).astype(A.dtype), np.eye(1).astype(b.dtype)
).dtype
inp_arr = [np.eye(1).astype(A.dtype), np.eye(1).astype(b.dtype)]
out_arr = [[None]]
self.perform(None, inp_arr, out_arr)
o_dtype = out_arr[0][0].dtype
x = tensor(dtype=o_dtype, shape=b.type.shape)
return Apply(self, [A, b], [x])

Expand Down
2 changes: 1 addition & 1 deletion tests/tensor/test_blockwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,7 @@ def core_scipy_fn(A, b):
A_val_copy, b_val_copy
)
np.testing.assert_allclose(
out, expected_out, atol=1e-5 if config.floatX == "float32" else 0
out, expected_out, atol=1e-4 if config.floatX == "float32" else 0
)

# Confirm input was destroyed
Expand Down
22 changes: 14 additions & 8 deletions tests/tensor/test_slinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,12 @@ def test_eigvalsh_grad():
)


class TestSolveBase(utt.InferShapeTester):
class TestSolveBase:
class SolveTest(SolveBase):
def perform(self, node, inputs, outputs):
A, b = inputs
outputs[0][0] = scipy.linalg.solve(A, b)

@pytest.mark.parametrize(
"A_func, b_func, error_message",
[
Expand All @@ -191,16 +196,16 @@ def test_make_node(self, A_func, b_func, error_message):
with pytest.raises(ValueError, match=error_message):
A = A_func()
b = b_func()
SolveBase(b_ndim=2)(A, b)
self.SolveTest(b_ndim=2)(A, b)

def test__repr__(self):
np.random.default_rng(utt.fetch_seed())
A = matrix()
b = matrix()
y = SolveBase(b_ndim=2)(A, b)
y = self.SolveTest(b_ndim=2)(A, b)
assert (
y.__repr__()
== "SolveBase{lower=False, check_finite=True, b_ndim=2, overwrite_a=False, overwrite_b=False}.0"
== "SolveTest{lower=False, check_finite=True, b_ndim=2, overwrite_a=False, overwrite_b=False}.0"
)


Expand Down Expand Up @@ -239,8 +244,9 @@ def test_correctness(self):
A_val = np.asarray(rng.random((5, 5)), dtype=config.floatX)
A_val = np.dot(A_val.transpose(), A_val)

assert np.allclose(
scipy.linalg.solve(A_val, b_val), gen_solve_func(A_val, b_val)
np.testing.assert_allclose(
scipy.linalg.solve(A_val, b_val, assume_a="gen"),
gen_solve_func(A_val, b_val),
)

A_undef = np.array(
Expand All @@ -253,7 +259,7 @@ def test_correctness(self):
],
dtype=config.floatX,
)
assert np.allclose(
np.testing.assert_allclose(
scipy.linalg.solve(A_undef, b_val), gen_solve_func(A_undef, b_val)
)

Expand Down Expand Up @@ -450,7 +456,7 @@ def test_solve_dtype(self):
fn = function([A, b], x)
x_result = fn(A_val.astype(A_dtype), b_val.astype(b_dtype))

assert x.dtype == x_result.dtype
assert x.dtype == x_result.dtype, (A_dtype, b_dtype)


def test_cho_solve():
Expand Down

0 comments on commit 581f65a

Please sign in to comment.