Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add get_devices method to dpctl.SyclPlatform class #1992

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions dpctl/_backend.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,8 @@ cdef extern from "syclinterface/dpctl_sycl_platform_interface.h":
cdef DPCTLPlatformVectorRef DPCTLPlatform_GetPlatforms()
cdef DPCTLSyclContextRef DPCTLPlatform_GetDefaultContext(
const DPCTLSyclPlatformRef)
cdef DPCTLDeviceVectorRef DPCTLPlatform_GetDevices(
const DPCTLSyclPlatformRef PRef, _device_type DTy)


cdef extern from "syclinterface/dpctl_sycl_context_interface.h":
Expand Down
4 changes: 2 additions & 2 deletions dpctl/_sycl_device_factory.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ cpdef list get_devices(backend=backend_type.all, device_type=device_type_t.all):
device_type (optional):
A :class:`dpctl.device_type` enum value or a string that
specifies a SYCL device type. Currently, accepted values are:
"gpu", "cpu", "accelerator", "host", or "all".
"gpu", "cpu", "accelerator", or "all".
Default: ``dpctl.device_type.all``.
Returns:
list:
Expand Down Expand Up @@ -218,7 +218,7 @@ cpdef int get_num_devices(
device_type (optional):
A :class:`dpctl.device_type` enum value or a string that
specifies a SYCL device type. Currently, accepted values are:
"gpu", "cpu", "accelerator", "host", or "all".
"gpu", "cpu", "accelerator", or "all".
Default: ``dpctl.device_type.all``.
Returns:
int:
Expand Down
81 changes: 81 additions & 0 deletions dpctl/_sycl_platform.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ from libcpp cimport bool
from ._backend cimport ( # noqa: E211
DPCTLCString_Delete,
DPCTLDeviceSelector_Delete,
DPCTLDeviceVector_Delete,
DPCTLDeviceVector_GetAt,
DPCTLDeviceVector_Size,
DPCTLDeviceVectorRef,
DPCTLFilterSelector_Create,
DPCTLPlatform_AreEq,
DPCTLPlatform_Copy,
Expand All @@ -34,6 +38,7 @@ from ._backend cimport ( # noqa: E211
DPCTLPlatform_Delete,
DPCTLPlatform_GetBackend,
DPCTLPlatform_GetDefaultContext,
DPCTLPlatform_GetDevices,
DPCTLPlatform_GetName,
DPCTLPlatform_GetPlatforms,
DPCTLPlatform_GetVendor,
Expand All @@ -46,17 +51,21 @@ from ._backend cimport ( # noqa: E211
DPCTLPlatformVector_Size,
DPCTLPlatformVectorRef,
DPCTLSyclContextRef,
DPCTLSyclDeviceRef,
DPCTLSyclDeviceSelectorRef,
DPCTLSyclPlatformRef,
_backend_type,
_device_type,
)

import warnings

from ._sycl_context import SyclContextCreationError
from .enum_types import backend_type
from .enum_types import device_type as device_type_t

from ._sycl_context cimport SyclContext
from ._sycl_device cimport SyclDevice

__all__ = [
"get_platforms",
Expand Down Expand Up @@ -366,6 +375,78 @@ cdef class SyclPlatform(_SyclPlatform):
"""
return DPCTLPlatform_Hash(self._platform_ref)

def get_devices(self, device_type=device_type_t.all):
"""
Returns the list of :class:`dpctl.SyclDevice` objects associated with
:class:`dpctl.SyclPlatform` instance selected based on
the given :class:`dpctl.device_type`.

Args:
device_type (optional):
A :class:`dpctl.device_type` enum value or a string that
specifies a SYCL device type. Currently, accepted values are:
"gpu", "cpu", "accelerator", or "all".
Default: ``dpctl.device_type.all``.

Returns:
list:
A :obj:`list` of :class:`dpctl.SyclDevice` objects
that belong to this platform.

Raises:
TypeError:
If `device_type` is not a str or :class:`dpctl.device_type`
enum.
ValueError:
If the ``DPCTLPlatform_GetDevices`` call returned
``NULL`` instead of a ``DPCTLDeviceVectorRef`` object.
"""
cdef _device_type DTy = _device_type._ALL_DEVICES
cdef DPCTLDeviceVectorRef DVRef = NULL
cdef size_t num_devs
cdef size_t i
cdef DPCTLSyclDeviceRef DRef

if isinstance(device_type, str):
dty_str = device_type.strip().lower()
if dty_str == "accelerator":
DTy = _device_type._ACCELERATOR
elif dty_str == "all":
DTy = _device_type._ALL_DEVICES
elif dty_str == "cpu":
DTy = _device_type._CPU
elif dty_str == "gpu":
DTy = _device_type._GPU
else:
DTy = _device_type._UNKNOWN_DEVICE
elif isinstance(device_type, device_type_t):
if device_type == device_type_t.all:
DTy = _device_type._ALL_DEVICES
elif device_type == device_type_t.accelerator:
DTy = _device_type._ACCELERATOR
elif device_type == device_type_t.cpu:
DTy = _device_type._CPU
elif device_type == device_type_t.gpu:
DTy = _device_type._GPU
else:
DTy = _device_type._UNKNOWN_DEVICE
else:
raise TypeError(
"device type should be specified as a str or an "
"``enum_types.device_type``."
)
DVRef = DPCTLPlatform_GetDevices(self.get_platform_ref(), DTy)
if (DVRef is NULL):
raise ValueError("Internal error: NULL device vector encountered")
num_devs = DPCTLDeviceVector_Size(DVRef)
devices = []
for i in range(num_devs):
DRef = DPCTLDeviceVector_GetAt(DVRef, i)
devices.append(SyclDevice._create(DRef))
DPCTLDeviceVector_Delete(DVRef)

return devices


def lsplatform(verbosity=0):
"""
Expand Down
4 changes: 0 additions & 4 deletions dpctl/tests/test_sycl_device_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,6 @@ def string_to_device_type(dty_str):
return dty.accelerator
elif dty_str == "cpu":
return dty.cpu
elif dty_str == "host":
return dty.host
elif dty_str == "gpu":
return dty.gpu

Expand All @@ -62,8 +60,6 @@ def string_to_backend_type(bty_str):
return bty.cuda
elif bty_str == "hip":
return bty.hip
elif bty_str == "host":
return bty.host
elif bty_str == "level_zero":
return bty.level_zero
elif bty_str == "opencl":
Expand Down
47 changes: 47 additions & 0 deletions dpctl/tests/test_sycl_platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import pytest

import dpctl
from dpctl import device_type

from .helper import has_sycl_platforms

Expand Down Expand Up @@ -212,3 +213,49 @@ def test_get_platforms():
assert has_sycl_platforms()
except Exception:
pytest.fail("Encountered an exception inside get_platforms().")


def test_platform_get_devices():
platforms = dpctl.get_platforms()
if platforms:
for p in platforms:
assert len(p.get_devices())
else:
pytest.skip("No platforms available")


def _str_device_type_to_enum(dty):
if dty == "accelerator":
return device_type.accelerator
elif dty == "cpu":
return device_type.cpu
elif dty == "gpu":
return device_type.gpu


def test_platform_get_devices_str_device_type():
platforms = dpctl.get_platforms()
dtys = ["accelerator", "all", "cpu", "gpu"]
if platforms:
for p in platforms:
for dty in dtys:
devices = p.get_devices(device_type=dty)
if len(devices):
dty_enum = _str_device_type_to_enum(dty)
assert (d.device_type == dty_enum for d in devices)


def test_platform_get_devices_enum_device_type():
platforms = dpctl.get_platforms()
dtys = [
device_type.accelerator,
device_type.all,
device_type.cpu,
device_type.gpu,
]
if platforms:
for p in platforms:
for dty in dtys:
devices = p.get_devices(device_type=dty)
if len(devices):
assert (d.device_type == dty for d in devices)
3 changes: 0 additions & 3 deletions libsyclinterface/helper/source/dpctl_utils_helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,6 @@ info::device_type DPCTL_StrToDeviceType(const std::string &devTyStr)
else if (devTyStr == "custom") {
devTy = info::device_type::custom;
}
else if (devTyStr == "host") {
devTy = info::device_type::host;
}
else {
// \todo handle the error
throw std::runtime_error("Unknown device type.");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ DPCTL_API
size_t DPCTLDeviceMgr_GetNumDevices(int device_identifier);

/*!
* @brief Prints out the info::deivice attributes for the device that are
* @brief Prints out the info::device attributes for the device that are
* currently supported by dpctl.
*
* @param DRef A #DPCTLSyclDeviceRef opaque pointer.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "Support/ExternC.h"
#include "Support/MemOwnershipAttrs.h"
#include "dpctl_data_types.h"
#include "dpctl_sycl_device_manager.h"
#include "dpctl_sycl_enum_types.h"
#include "dpctl_sycl_platform_manager.h"
#include "dpctl_sycl_types.h"
Expand Down Expand Up @@ -176,6 +177,20 @@ DPCTLPlatform_GetDefaultContext(__dpctl_keep const DPCTLSyclPlatformRef PRef);
* @ingroup PlatformInterface
*/
DPCTL_API
size_t DPCTLPlatform_Hash(__dpctl_keep DPCTLSyclPlatformRef PRef);
size_t DPCTLPlatform_Hash(__dpctl_keep const DPCTLSyclPlatformRef PRef);

/*!
* @brief Returns a vector of devices associated with sycl::platform referenced
* by DPCTLSyclPlatformRef object.
*
* @param PRef The DPCTLSyclPlatformRef pointer.
* @param DTy A DPCTLSyclDeviceType enum value.
* @return A DPCTLDeviceVectorRef with devices associated with given PRef.
* @ingroup PlatformInterface
*/
DPCTL_API
__dpctl_give DPCTLDeviceVectorRef
DPCTLPlatform_GetDevices(__dpctl_keep const DPCTLSyclPlatformRef PRef,
DPCTLSyclDeviceType DTy);

DPCTL_C_EXTERN_C_END
44 changes: 44 additions & 0 deletions libsyclinterface/source/dpctl_sycl_platform_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "dpctl_device_selection.hpp"
#include "dpctl_error_handlers.h"
#include "dpctl_string_utils.hpp"
#include "dpctl_sycl_enum_types.h"
#include "dpctl_sycl_type_casters.hpp"
#include "dpctl_utils_helper.h"
#include <iomanip>
Expand Down Expand Up @@ -269,3 +270,46 @@ size_t DPCTLPlatform_Hash(__dpctl_keep const DPCTLSyclPlatformRef PRef)
return 0;
}
}

__dpctl_give DPCTLDeviceVectorRef
DPCTLPlatform_GetDevices(__dpctl_keep const DPCTLSyclPlatformRef PRef,
DPCTLSyclDeviceType DTy)
{
auto P = unwrap<platform>(PRef);
if (!P) {
error_handler("Cannot retrieve devices from DPCTLSyclPlatformRef as "
"input is a nullptr.",
__FILE__, __func__, __LINE__);
return nullptr;
}

using vecTy = std::vector<DPCTLSyclDeviceRef>;
vecTy *DevicesVectorPtr = nullptr;
try {
DevicesVectorPtr = new vecTy();
} catch (std::exception const &e) {
delete DevicesVectorPtr;
error_handler(e, __FILE__, __func__, __LINE__);
return nullptr;
}

// handle unknown device
if (DTy == DPCTLSyclDeviceType::DPCTL_UNKNOWN_DEVICE) {
return wrap<vecTy>(DevicesVectorPtr);
}

try {
auto SyclDTy = DPCTL_DPCTLDeviceTypeToSyclDeviceType(DTy);
auto Devices = P->get_devices(SyclDTy);
DevicesVectorPtr->reserve(Devices.size());
for (const auto &Dev : Devices) {
DevicesVectorPtr->emplace_back(
wrap<device>(new device(std::move(Dev))));
}
return wrap<vecTy>(DevicesVectorPtr);
} catch (std::exception const &e) {
delete DevicesVectorPtr;
error_handler(e, __FILE__, __func__, __LINE__);
return nullptr;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,7 @@ INSTANTIATE_TEST_SUITE_P(FilterSelectorCreation,
"gpu:0",
"gpu:1",
"1",
"0",
"host"));
"0"));

INSTANTIATE_TEST_SUITE_P(NegativeFilterSelectorCreation,
TestUnsupportedFilters,
Expand Down
26 changes: 26 additions & 0 deletions libsyclinterface/tests/test_sycl_platform_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
//===----------------------------------------------------------------------===//

#include "dpctl_sycl_context_interface.h"
#include "dpctl_sycl_device_interface.h"
#include "dpctl_sycl_device_selector_interface.h"
#include "dpctl_sycl_platform_interface.h"
#include "dpctl_sycl_platform_manager.h"
Expand Down Expand Up @@ -92,6 +93,26 @@ void check_platform_default_context(
EXPECT_NO_FATAL_FAILURE(DPCTLContext_Delete(CRef));
}

void check_platform_get_devices(__dpctl_keep const DPCTLSyclPlatformRef PRef)
{
DPCTLDeviceVectorRef DVRef = nullptr;
size_t nDevices = 0;

DPCTLSyclDeviceType defDTy = DPCTLSyclDeviceType::DPCTL_ALL;
EXPECT_NO_FATAL_FAILURE(DVRef = DPCTLPlatform_GetDevices(PRef, defDTy));
EXPECT_TRUE(DVRef != nullptr);
EXPECT_NO_FATAL_FAILURE(nDevices = DPCTLDeviceVector_Size(DVRef));
for (auto i = 0ul; i < nDevices; ++i) {
DPCTLSyclDeviceRef DRef = nullptr;
EXPECT_NO_FATAL_FAILURE(DRef = DPCTLDeviceVector_GetAt(DVRef, i));
ASSERT_TRUE(DRef != nullptr);
EXPECT_NO_FATAL_FAILURE(DPCTLDevice_Delete(DRef));
}

EXPECT_NO_FATAL_FAILURE(DPCTLDeviceVector_Clear(DVRef));
EXPECT_NO_FATAL_FAILURE(DPCTLDeviceVector_Delete(DVRef));
}

} // namespace

struct TestDPCTLSyclPlatformInterface
Expand Down Expand Up @@ -282,6 +303,11 @@ TEST_P(TestDPCTLSyclPlatformInterface, ChkAreEqNullArg)
ASSERT_TRUE(DPCTLPlatform_Hash(Null_PRef) == 0);
}

TEST_P(TestDPCTLSyclPlatformInterface, ChkGetDevices)
{
check_platform_get_devices(PRef);
}

TEST_F(TestDPCTLSyclDefaultPlatform, ChkGetName) { check_platform_name(PRef); }

TEST_F(TestDPCTLSyclDefaultPlatform, ChkGetVendor)
Expand Down
Loading