From 292b663d4b3bd5c8a2e82f1de8f7e6a1bd99a5ad Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Mon, 25 Mar 2024 02:46:40 +0000 Subject: [PATCH] pre-commit: running and fixing... --- thunder/clang/__init__.py | 2 +- thunder/tests/opinfos.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/thunder/clang/__init__.py b/thunder/clang/__init__.py index 0c3cbdc9dc..2b3e76b418 100644 --- a/thunder/clang/__init__.py +++ b/thunder/clang/__init__.py @@ -1169,6 +1169,7 @@ def mT_scalar_warning(): UserWarning, ) + @clangop(method_name="mT") def matrix_transpose(a: TensorProxy) -> TensorProxy: """Transposes the last two dimensions of a tensor. @@ -1191,7 +1192,6 @@ def matrix_transpose(a: TensorProxy) -> TensorProxy: [3, 6]]) """ - if a.ndim == 0: mT_scalar_warning() return a diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py index bab8e58b1a..0fb119d3aa 100644 --- a/thunder/tests/opinfos.py +++ b/thunder/tests/opinfos.py @@ -3956,13 +3956,12 @@ def matrix_transpose_sample_generator(op, device, dtype, requires_grad, **kwargs for shape in cases: yield SampleInput(make(shape)) + def matrix_transpose_error_generator(op, device, dtype=torch.float32, **kwargs): make = partial(make_tensor, device=device, dtype=dtype) # shape, error type, error message - cases = ( - ((3), RuntimeError, "tensor.mT is only supported on matrices or batches of matrices. Got 1-D tensor."), - ) + cases = (((3), RuntimeError, "tensor.mT is only supported on matrices or batches of matrices. Got 1-D tensor."),) for shape, err_type, err_msg in cases: yield SampleInput(make(shape)), err_type, err_msg