Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: More uint types for torch #244

Merged
merged 1 commit into from
Jan 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions array_api_compat/torch/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@
torch.int32,
torch.int64,
}
try:
# torch >=2.3
_int_dtypes |= {torch.uint16, torch.uint32, torch.uint64}
except AttributeError:
pass


_array_api_dtypes = {
torch.bool,
Expand Down
12 changes: 0 additions & 12 deletions torch-xfails.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,12 @@ array_api_tests/test_array_object.py::test_getitem
array_api_tests/test_array_object.py::test_setitem
# Masking doesn't suport 0 dimensions in the mask
array_api_tests/test_array_object.py::test_getitem_masking
# torch doesn't have uint dtypes other than uint8
array_api_tests/test_array_object.py::test_scalar_casting[__int__(uint16)]
array_api_tests/test_array_object.py::test_scalar_casting[__int__(uint32)]
array_api_tests/test_array_object.py::test_scalar_casting[__int__(uint64)]
array_api_tests/test_array_object.py::test_scalar_casting[__index__(uint16)]
array_api_tests/test_array_object.py::test_scalar_casting[__index__(uint32)]
array_api_tests/test_array_object.py::test_scalar_casting[__index__(uint64)]

# Overflow error from large inputs
array_api_tests/test_creation_functions.py::test_arange
# pytorch linspace bug (should be fixed in torch 2.0)
array_api_tests/test_creation_functions.py::test_linspace

# torch doesn't have higher uint dtypes
array_api_tests/test_data_type_functions.py::test_iinfo[uint16]
array_api_tests/test_data_type_functions.py::test_iinfo[uint32]
array_api_tests/test_data_type_functions.py::test_iinfo[uint64]

# We cannot wrap the tensor object
array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__]
array_api_tests/test_has_names.py::test_has_names[array_method-to_device]
Expand Down
Loading