diff --git a/test/lazy/test_lazy_tensor_deprecation.py b/test/lazy/test_lazy_tensor_deprecation.py index 0b2c8fc6a..c2c69362a 100644 --- a/test/lazy/test_lazy_tensor_deprecation.py +++ b/test/lazy/test_lazy_tensor_deprecation.py @@ -70,8 +70,8 @@ def test_deprecated_methods(self): with warnings.catch_warnings(record=True) as ws: self.assertAllClose(nlt.diag(), dlo.diag()) - self.assertTrue(len(ws) == 1) - self.assertTrue(issubclass(ws[0].category, DeprecationWarning)) + self.assertTrue(len(ws) == 2) + self.assertTrue(all(issubclass(w.category, DeprecationWarning) for w in ws)) with warnings.catch_warnings(record=True) as ws: self.assertAllClose(nlt.evaluate(), dlo.to_dense()) @@ -85,8 +85,10 @@ def test_deprecated_methods(self): self.assertTrue(issubclass(ws[0].category, DeprecationWarning)) with warnings.catch_warnings(record=True) as ws: - self.assertAllClose(nlt.symeig(eigenvectors=True)[0], torch.linalg.eigh(dlo)[0]) - self.assertAllClose(nlt.symeig(eigenvectors=True)[1].to_dense(), torch.linalg.eigh(dlo)[1].to_dense()) + evals, evecs = nlt.symeig(eigenvectors=True) + evals_tfn, evecs_tfn = torch.linalg.eigh(dlo) + self.assertAllClose(evals, evals_tfn) + self.assertAllClose(evecs.to_dense(), evecs_tfn.to_dense()) self.assertTrue(len(ws) == 1) self.assertTrue(issubclass(ws[0].category, DeprecationWarning))