Skip to content

Commit

Permalink
Merge pull request #1979 from IntelPython/fix-device-keyword-in-array…
Browse files Browse the repository at this point in the history
…-api-inspection

Fix array API inspection behavior with `device` keyword
  • Loading branch information
ndgrigorian authored Jan 24, 2025
2 parents c5cbb08 + 06f266c commit 7bc3d80
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 32 deletions.
49 changes: 29 additions & 20 deletions dpctl/tensor/_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,21 @@ def _isdtype_impl(dtype, kind):
elif isinstance(kind, tuple):
return any(_isdtype_impl(dtype, k) for k in kind)
else:
raise TypeError(f"Unsupported data type kind: {kind}")
raise TypeError(f"Unsupported type for dtype kind: {type(kind)}")


def _get_device_impl(d):
if d is None:
return dpctl.select_default_device()
elif isinstance(d, dpctl.SyclDevice):
return d
elif isinstance(d, (dpt.Device, dpctl.SyclQueue)):
return d.sycl_device
else:
try:
return dpctl.SyclDevice(d)
except TypeError:
raise TypeError(f"Unsupported type for device argument: {type(d)}")


__array_api_version__ = "2023.12"
Expand Down Expand Up @@ -117,13 +131,13 @@ def default_dtypes(self, *, device=None):
Returns a dictionary of default data types for ``device``.
Args:
device (Optional[:class:`dpctl.SyclDevice`, :class:`dpctl.SyclQueue`, :class:`dpctl.tensor.Device`]):
device (Optional[:class:`dpctl.SyclDevice`, :class:`dpctl.SyclQueue`, :class:`dpctl.tensor.Device`, str]):
array API concept of device used in getting default data types.
``device`` can be ``None`` (in which case the default device
is used), an instance of :class:`dpctl.SyclDevice` corresponding
to a non-partitioned SYCL device, an instance of
:class:`dpctl.SyclQueue`, or a :class:`dpctl.tensor.Device`
object returned by :attr:`dpctl.tensor.usm_ndarray.device`.
is used), an instance of :class:`dpctl.SyclDevice`, an instance
of :class:`dpctl.SyclQueue`, a :class:`dpctl.tensor.Device`
object returned by :attr:`dpctl.tensor.usm_ndarray.device`, or
a filter selector string.
Default: ``None``.
Returns:
Expand All @@ -135,10 +149,7 @@ def default_dtypes(self, *, device=None):
- ``"integral"``: dtype
- ``"indexing"``: dtype
"""
if device is None:
device = dpctl.select_default_device()
elif isinstance(device, dpt.Device):
device = device.sycl_device
device = _get_device_impl(device)
return {
"real floating": dpt.dtype(default_device_fp_type(device)),
"complex floating": dpt.dtype(default_device_complex_type(device)),
Expand All @@ -161,10 +172,10 @@ def dtypes(self, *, device=None, kind=None):
device (Optional[:class:`dpctl.SyclDevice`, :class:`dpctl.SyclQueue`, :class:`dpctl.tensor.Device`, str]):
array API concept of device used in getting default data types.
``device`` can be ``None`` (in which case the default device is
used), an instance of :class:`dpctl.SyclDevice` corresponding
to a non-partitioned SYCL device, an instance of
:class:`dpctl.SyclQueue`, or a :class:`dpctl.tensor.Device`
object returned by :attr:`dpctl.tensor.usm_ndarray.device`.
used), an instance of :class:`dpctl.SyclDevice`, an instance of
:class:`dpctl.SyclQueue`, a :class:`dpctl.tensor.Device`
object returned by :attr:`dpctl.tensor.usm_ndarray.device`, or
a filter selector string.
Default: ``None``.
kind (Optional[str, Tuple[str, ...]]):
Expand Down Expand Up @@ -196,22 +207,20 @@ def dtypes(self, *, device=None, kind=None):
a dictionary of the supported data types of the specified
``kind``
"""
if device is None:
device = dpctl.select_default_device()
elif isinstance(device, dpt.Device):
device = device.sycl_device
device = _get_device_impl(device)
_fp64 = device.has_aspect_fp64
if kind is None:
return {
key: val
for key, val in self._all_dtypes.items()
if (key != "float64" or _fp64)
if _fp64 or (key != "float64" and key != "complex128")
}
else:
return {
key: val
for key, val in self._all_dtypes.items()
if (key != "float64" or _fp64) and _isdtype_impl(val, kind)
if (_fp64 or (key != "float64" and key != "complex128"))
and _isdtype_impl(val, kind)
}

def devices(self):
Expand Down
69 changes: 57 additions & 12 deletions dpctl/tests/test_tensor_array_api_inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
"bool": dpt.bool,
"float32": dpt.float32,
"complex64": dpt.complex64,
"complex128": dpt.complex128,
"int8": dpt.int8,
"int16": dpt.int16,
"int32": dpt.int32,
Expand All @@ -41,12 +40,6 @@
}


class MockDevice:
def __init__(self, fp16: bool, fp64: bool):
self.has_aspect_fp16 = fp16
self.has_aspect_fp64 = fp64


def test_array_api_inspection_methods():
info = dpt.__array_namespace_info__()
assert info.capabilities()
Expand Down Expand Up @@ -125,17 +118,21 @@ def test_array_api_inspection_default_device_dtypes():
dtypes = _dtypes_no_fp16_fp64.copy()
if dev.has_aspect_fp64:
dtypes["float64"] = dpt.float64
dtypes["complex128"] = dpt.complex128

assert dtypes == dpt.__array_namespace_info__().dtypes()


@pytest.mark.parametrize("fp16", [True, False])
@pytest.mark.parametrize("fp64", [True, False])
def test_array_api_inspection_device_dtypes(fp16, fp64):
dev = MockDevice(fp16, fp64)
def test_array_api_inspection_device_dtypes():
info = dpt.__array_namespace_info__()
try:
dev = info.default_device()
except dpctl.SyclDeviceCreationError:
pytest.skip("No default device available")
dtypes = _dtypes_no_fp16_fp64.copy()
if fp64:
if dev.has_aspect_fp64:
dtypes["float64"] = dpt.float64
dtypes["complex128"] = dpt.complex128

assert dtypes == dpt.__array_namespace_info__().dtypes(device=dev)

Expand Down Expand Up @@ -179,3 +176,51 @@ def test_array_api_inspection_dtype_kind():
)
== info.dtypes()
)
assert info.dtypes(
kind=("integral", "real floating", "complex floating")
) == info.dtypes(kind="numeric")


def test_array_api_inspection_dtype_kind_errors():
info = dpt.__array_namespace_info__()
try:
info.default_device()
except dpctl.SyclDeviceCreationError:
pytest.skip("No default device available")

with pytest.raises(ValueError):
info.dtypes(kind="error")

with pytest.raises(TypeError):
info.dtypes(kind={0: "real floating"})


def test_array_api_inspection_device_types():
info = dpt.__array_namespace_info__()
try:
dev = info.default_device()
except dpctl.SyclDeviceCreationError:
pytest.skip("No default device available")

q = dpctl.SyclQueue(dev)
assert info.default_dtypes(device=q)
assert info.dtypes(device=q)

dev_dpt = dpt.Device.create_device(dev)
assert info.default_dtypes(device=dev_dpt)
assert info.dtypes(device=dev_dpt)

filter = dev.get_filter_string()
assert info.default_dtypes(device=filter)
assert info.dtypes(device=filter)


def test_array_api_inspection_device_errors():
info = dpt.__array_namespace_info__()

bad_dev = dict()
with pytest.raises(TypeError):
info.dtypes(device=bad_dev)

with pytest.raises(TypeError):
info.default_dtypes(device=bad_dev)

0 comments on commit 7bc3d80

Please sign in to comment.