[MPS] fix float32 error on mps, in linalg.matrix_rank and linalg.pinv (pytorch#114771)

Fixes pytorch#114285

However, the following `NotImplementedError` is still raised:

```
NotImplementedError: The operator 'aten::_linalg_svd.U' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on pytorch#77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
```

Pull Request resolved: pytorch#114771
Approved by: https://github.com/lezcano
watarungurunnn authored and pytorchmergebot committed Feb 5, 2024
1 parent a72190f commit d444a3b
Showing 3 changed files with 161 additions and 8 deletions.
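For context, a minimal, hedged repro sketch of the user-facing calls this commit targets (not part of the diff; it assumes a macOS machine with an available MPS device, and `PYTORCH_ENABLE_MPS_FALLBACK=1` may still be needed while `aten::_linalg_svd.U` lacks a native MPS kernel, as noted in the commit message):

```python
# Sketch only, not from the commit. Before this fix, the atol/rtol helper tensors inside
# torch.linalg.matrix_rank / torch.linalg.pinv were created as float64, which the MPS
# backend does not support; the fix creates them as float32 on MPS/Metal devices.
import torch

assert torch.backends.mps.is_available(), "requires a PyTorch build with MPS support"

a = torch.randn(3, 5, dtype=torch.float32, device="mps")
print(torch.linalg.pinv(a).shape)    # torch.Size([5, 3])
print(torch.linalg.matrix_rank(a))   # rank via SVD; may use the CPU fallback noted above
```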
1 change: 1 addition & 0 deletions .gitignore
@@ -126,6 +126,7 @@ env
.circleci/scripts/COMMIT_MSG
scripts/release_notes/*.json
sccache-stats*.json
lint.json

# These files get copied over on invoking setup.py
torchgen/packaged/*
23 changes: 19 additions & 4 deletions aten/src/ATen/native/LinearAlgebra.cpp
@@ -444,7 +444,12 @@ std::tuple<Tensor, Tensor> get_atol_rtol(
const optional<Tensor>& atol_opt,
const optional<Tensor>& rtol_opt,
const c10::string_view function_name) {
auto options = input.options().dtype(ScalarType::Double);
auto options = input.options();
if (input.device().type() == kMetal || input.device().type() == kMPS) {
options = options.dtype(ScalarType::Float);
} else {
options = options.dtype(ScalarType::Double);
}
auto atol = atol_opt.has_value() ? atol_opt.value() : at::zeros({}, options);
checkNotComplexTolerance(atol, function_name, "atol");
Tensor rtol;
@@ -465,7 +470,7 @@ std::tuple<Tensor, Tensor> get_atol_rtol(
const Tensor& input,
optional<double> atol_opt,
optional<double> rtol_opt) {
double atol = atol_opt.has_value() ? atol_opt.value() : 0.0;
auto atol = atol_opt.has_value() ? atol_opt.value() : 0.0;
c10::SymFloat rtol;
if (rtol_opt.has_value()) {
rtol = rtol_opt.value();
@@ -476,7 +481,12 @@ std::tuple<Tensor, Tensor> get_atol_rtol(
? 0.0
: default_rtol;
}
auto options = input.options().dtype(ScalarType::Double);
auto options = input.options();
if (input.device().type() == kMetal || input.device().type() == kMPS) {
options = options.dtype(ScalarType::Float);
} else {
options = options.dtype(ScalarType::Double);
}
auto atol_tensor = at::full({}, atol, options);
auto rtol_tensor = at::full({}, rtol, options);
return std::make_tuple(atol_tensor, rtol_tensor);
@@ -545,7 +555,12 @@ Tensor linalg_pinv(const Tensor& input, optional<double> atol, optional<double>
Tensor linalg_pinv(const Tensor& input, const Tensor& rcond, bool hermitian) {
// For NumPy compatibility the rcond argument is used as relative tolerance
checkNotComplexTolerance(rcond, "torch.linalg.pinv", "rcond");
auto options = input.options().dtype(ScalarType::Double);
auto options = input.options();
if (input.device().type() == kMetal || input.device().type() == kMPS) {
options = options.dtype(ScalarType::Float);
} else {
options = options.dtype(ScalarType::Double);
}
return at::linalg_pinv(input, at::zeros({}, options), rcond, hermitian);
}

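All three `LinearAlgebra.cpp` hunks above apply the same rule: tolerance tensors default to double precision, but on Metal/MPS devices, which lack float64 support, they are created as float32. A rough Python sketch of that device-based dtype choice (the helper name is hypothetical; the real logic lives in the C++ shown above):

```python
# Illustrative only: mirrors the dtype selection added to get_atol_rtol / linalg_pinv.
import torch

def _tolerance_dtype(t: torch.Tensor) -> torch.dtype:
    # The C++ checks both kMetal and kMPS; neither backend supports float64.
    return torch.float32 if t.device.type in ("mps", "metal") else torch.float64

print(_tolerance_dtype(torch.empty(0)))                    # torch.float64 (CPU tensor)
# print(_tolerance_dtype(torch.empty(0, device="mps")))    # torch.float32 on an MPS machine
```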
145 changes: 141 additions & 4 deletions test/test_mps.py
@@ -189,6 +189,12 @@ def mps_ops_grad_modifier(ops):
'msort': [torch.float16],
}

ON_MPS_XFAILLIST = {
# Failures due to lack of implementation of downstream functions on the MPS backend
# TODO: remove these once the downstream function 'aten::_linalg_svd.U' has been implemented
'linalg.matrix_rank': None,
}

def addDecorator(op, d) -> None:
op.decorators = list(op.decorators) if op.decorators is not None else []
op.decorators.append(d)
@@ -205,6 +211,11 @@ def addDecorator(op, d) -> None:
unittest.skip,
dtypes=SKIPLIST_GRAD[key]))

if key in ON_MPS_XFAILLIST:
addDecorator(op, DecorateInfo(
unittest.expectedFailure,
dtypes=ON_MPS_XFAILLIST[key]))

if key in MACOS_12_3_XFAILLIST_GRAD and (not torch.backends.mps.is_macos13_or_newer()):
addDecorator(op, DecorateInfo(
unittest.expectedFailure,
@@ -722,7 +733,6 @@ def mps_ops_modifier(ops):
'nn.functional.norm': None,
'ormqr': None,
'pca_lowrank': None,
'pinverse': None,
'qr': None,
'quantile': None,
'rsub': None,
@@ -792,9 +802,7 @@ def mps_ops_modifier(ops):
'softmaxwith_dtype': None,
'float_power': None,
'full_like': None,
'linalg.matrix_rank': None,
'linalg.matrix_rankhermitian': None,
'linalg.pinv': None,
'linalg.pinvhermitian': None,
'nonzero_static': None,

@@ -918,6 +926,12 @@ def mps_ops_modifier(ops):
'logit': [torch.float16],
}

ON_MPS_XFAILLIST = {
# Failures due to lack of implementation of downstream functions on the MPS backend
# TODO: remove these once the downstream function 'aten::_linalg_svd.U' has been implemented
'linalg.matrix_rank': None,
}

EMPTY_OPS_SKIPLIST = {
# Fill tensors with uninitialized data, causing mismatch with CPU.
# They occasionally match, thus skipping them.
@@ -954,7 +968,7 @@ def addDecorator(op, d) -> None:
dtypes=EMPTY_OPS_SKIPLIST[key]))
if key in SKIPLIST:
addDecorator(op, DecorateInfo(unittest.skip("Skipped!"), dtypes=SKIPLIST[key]))
for xfaillist in [UNIMPLEMENTED_XFAILLIST, UNDEFINED_XFAILLIST]:
for xfaillist in [UNIMPLEMENTED_XFAILLIST, UNDEFINED_XFAILLIST, ON_MPS_XFAILLIST]:
if key in xfaillist:
addDecorator(op, DecorateInfo(
unittest.expectedFailure,
@@ -8729,6 +8743,129 @@ def test_addr(self, device="mps", dtype=torch.float32):
m2 = torch.randn(25, device=device).to(dtype)
self._test_addr(torch.addr, M, m1, m2, beta=0)

def test_matrix_rank(self, device="mps", dtype=torch.float32):
matrix_rank = torch.linalg.matrix_rank

def run_test(shape0, shape1, batch):
a = torch.randn(*batch, shape0, shape1, dtype=dtype, device=device)
rank_a = matrix_rank(a)

self.assertEqual(rank_a, matrix_rank(a.mH))
aaH = torch.matmul(a, a.mH)
rank_aaH = matrix_rank(aaH)
rank_aaH_hermitian = matrix_rank(aaH, hermitian=True)
self.assertEqual(rank_aaH, rank_aaH_hermitian)
aHa = torch.matmul(a.mH, a)
self.assertEqual(matrix_rank(aHa), matrix_rank(aHa, hermitian=True))

# check against NumPy
self.assertEqual(rank_a, np.linalg.matrix_rank(a.cpu().numpy()))
self.assertEqual(matrix_rank(a, 0.01), np.linalg.matrix_rank(a.cpu().numpy(), 0.01))

self.assertEqual(rank_aaH, np.linalg.matrix_rank(aaH.cpu().numpy()))
self.assertEqual(matrix_rank(aaH, 0.01), np.linalg.matrix_rank(aaH.cpu().numpy(), 0.01))

# hermitian flag for NumPy was added in 1.14.0
if np.lib.NumpyVersion(np.__version__) >= '1.14.0':
self.assertEqual(rank_aaH_hermitian,
np.linalg.matrix_rank(aaH.cpu().numpy(), hermitian=True))
self.assertEqual(matrix_rank(aaH, 0.01, True),
np.linalg.matrix_rank(aaH.cpu().numpy(), 0.01, True))

# check out= variant
out = torch.empty(a.shape[:-2], dtype=torch.int64, device=device)
ans = matrix_rank(a, out=out)
self.assertEqual(ans, out)
self.assertEqual(ans, rank_a)

shapes = (3, 13)
batches = ((), (0, ), (4, ), (3, 5, ))
for (shape0, shape1), batch in zip(itertools.product(shapes, reversed(shapes)), batches):
# tolerate only the expected NotImplementedError raised by the unimplemented downstream function
# TODO: remove this once the required function is implemented
try:
run_test(shape0, shape1, batch)
except NotImplementedError as e:
with self.assertRaisesRegex(
NotImplementedError,
"The operator 'aten::_linalg_svd.U' is not currently implemented for the MPS device."):
raise e

def test_pinv(self, device="mps", dtype=torch.float32, precision=1e-4):
from torch.testing._internal.common_utils import random_hermitian_pd_matrix

def run_test_main(A, hermitian):
# Testing against definition for pseudo-inverses
A_pinv = torch.linalg.pinv(A, hermitian=hermitian)
np_A = A.cpu().numpy()
np_A_pinv = A_pinv.cpu().numpy()
if A.numel() > 0:
self.assertEqual(A, np_A @ np_A_pinv @ np_A, atol=precision, rtol=precision)
self.assertEqual(A_pinv, np_A_pinv @ np_A @ np_A_pinv, atol=precision, rtol=precision)
self.assertEqual(np_A @ np_A_pinv, (np_A @ np_A_pinv).conj().swapaxes(-2, -1), atol=precision, rtol=precision)
self.assertEqual(np_A_pinv @ np_A, (np_A_pinv @ np_A).conj().swapaxes(-2, -1), atol=precision, rtol=precision)
else:
self.assertEqual(A.shape, A_pinv.shape[:-2] + (A_pinv.shape[-1], A_pinv.shape[-2]))

# Check out= variant
out = torch.empty_like(A_pinv)
ans = torch.linalg.pinv(A, hermitian=hermitian, out=out)
self.assertEqual(ans, out)
self.assertEqual(ans, A_pinv)

def run_test_numpy(A, hermitian):
# Check against NumPy output
# Test float rcond, and specific value for each matrix
rconds = [float(torch.rand(1)), ]
# Test different types of rcond tensor
for rcond_type in MPS_DTYPES:
rconds.append(torch.rand(A.shape[:-2], dtype=torch.float32, device=device).to(rcond_type))
# Test broadcasting of rcond
if A.ndim > 2:
rconds.append(torch.rand(A.shape[-3], device=device))
for rcond in rconds:
actual = torch.linalg.pinv(A, rcond=rcond, hermitian=hermitian)
torch_rtol = torch.linalg.pinv(A, rtol=rcond, hermitian=hermitian)
self.assertEqual(actual, torch_rtol, atol=precision, rtol=precision)
numpy_rcond = rcond if isinstance(rcond, float) else rcond.cpu().numpy()
expected = np.linalg.pinv(A.cpu().numpy(), rcond=numpy_rcond, hermitian=hermitian)
self.assertEqual(actual, expected, atol=precision, rtol=precision)

for sizes in [(5, 5), (3, 5, 5), (3, 2, 5, 5), # square matrices
(3, 2), (5, 3, 2), (2, 5, 3, 2), # fat matrices
(2, 3), (5, 2, 3), (2, 5, 2, 3), # thin matrices
(0, 0), (0, 2), (2, 0), (3, 0, 0), (0, 3, 0), (0, 0, 3)]: # zero numel matrices
A = torch.randn(*sizes, dtype=dtype, device=device)
hermitian = False
run_test_main(A, hermitian)
run_test_numpy(A, hermitian)

# Check hermitian = True
for sizes in [(5, 5), (3, 5, 5), (3, 2, 5, 5), # square matrices
(0, 0), (3, 0, 0), ]: # zero numel square matrices
A = random_hermitian_pd_matrix(sizes[-1], *sizes[:-2], dtype=dtype, device=device)
hermitian = True
# tolerate only the expected NotImplementedError raised by the unimplemented downstream function
# TODO: remove this once the required function is implemented
try:
run_test_main(A, hermitian)
except NotImplementedError as e:
with self.assertRaisesRegex(
NotImplementedError,
"The operator 'aten::_linalg_eigh.eigenvalues' is not currently implemented for the MPS device."):
raise e
try:
run_test_numpy(A, hermitian)
except NotImplementedError as e:
with self.assertRaisesRegex(
NotImplementedError,
"The operator 'aten::_linalg_eigh.eigenvalues' is not currently implemented for the MPS device."):
raise e





class TestGatherScatter(TestCaseMPS):
def test_slicing_with_step(self):
# Slicing with step
