diff --git a/dpctl/sycl_core.pyx b/dpctl/sycl_core.pyx index 347a25d29d..7a10b9d6d7 100644 --- a/dpctl/sycl_core.pyx +++ b/dpctl/sycl_core.pyx @@ -111,6 +111,16 @@ cdef class SyclDevice: ''' return self._device_name.decode() + def get_device_type (self): + ''' Returns the type of the device as a `device_type` enum + ''' + if DPPLDevice_IsGPU(self._device_ref): + return device_type.gpu + elif DPPLDevice_IsCPU(self._device_ref): + return device_type.cpu + else: + raise ValueError("Unknown device type.") + def get_vendor_name (self): ''' Returns the device vendor name as a string ''' @@ -515,6 +525,14 @@ cdef class _SyclQueueManager: else: return False + def get_current_device_type (self): + ''' Returns current device type as `device_type` enum + ''' + if self.is_in_device_context(): + return self.get_current_queue().get_sycl_device().get_device_type() + else: + return None + # This private instance of the _SyclQueueManager should not be directly # accessed outside the module. @@ -523,6 +541,7 @@ _qmgr = _SyclQueueManager() # Global bound functions dump = _qmgr.dump get_current_queue = _qmgr.get_current_queue +get_current_device_type = _qmgr.get_current_device_type get_num_platforms = _qmgr.get_num_platforms get_num_activated_queues = _qmgr.get_num_activated_queues has_cpu_queues = _qmgr.has_cpu_queues diff --git a/dpctl/tests/test_sycl_queue_manager.py b/dpctl/tests/test_sycl_queue_manager.py index f8e39042c5..1eddef7fcb 100644 --- a/dpctl/tests/test_sycl_queue_manager.py +++ b/dpctl/tests/test_sycl_queue_manager.py @@ -25,6 +25,7 @@ import dpctl import unittest + class TestGetNumPlatforms (unittest.TestCase): @unittest.skipIf(not dpctl.has_sycl_platforms(), "No SYCL platforms available") @@ -32,6 +33,7 @@ def test_dpctl_get_num_platforms (self): if(dpctl.has_sycl_platforms): self.assertGreaterEqual(dpctl.get_num_platforms(), 1) + @unittest.skipIf(not dpctl.has_sycl_platforms(), "No SYCL platforms available") class TestDumpMethods (unittest.TestCase): def test_dpctl_dump (self): @@ -47,6 +49,7 @@ def test_dpctl_dump_device_info (self): except Exception: self.fail("Encountered an exception inside dump_device_info().") + @unittest.skipIf(not dpctl.has_sycl_platforms(), "No SYCL platforms available") class TestIsInDeviceContext (unittest.TestCase): @@ -65,6 +68,35 @@ def test_is_in_device_context_inside_nested_device_ctxt (self): self.assertTrue(dpctl.is_in_device_context()) self.assertFalse(dpctl.is_in_device_context()) + +@unittest.skipIf(not dpctl.has_sycl_platforms(), "No SYCL platforms available") +class TestIsInDeviceContext (unittest.TestCase): + + def test_get_current_device_type_outside_device_ctxt (self): + self.assertEqual(dpctl.get_current_device_type(), None) + + def test_get_current_device_type_inside_device_ctxt (self): + self.assertEqual(dpctl.get_current_device_type(), None) + + with dpctl.device_context(dpctl.device_type.gpu): + self.assertEqual(dpctl.get_current_device_type(), dpctl.device_type.gpu) + + self.assertEqual(dpctl.get_current_device_type(), None) + + @unittest.skipIf(not dpctl.has_cpu_queues(), "No CPU platforms available") + def test_get_current_device_type_inside_nested_device_ctxt (self): + self.assertEqual(dpctl.get_current_device_type(), None) + + with dpctl.device_context(dpctl.device_type.cpu): + self.assertEqual(dpctl.get_current_device_type(), dpctl.device_type.cpu) + + with dpctl.device_context(dpctl.device_type.gpu): + self.assertEqual(dpctl.get_current_device_type(), dpctl.device_type.gpu) + self.assertEqual(dpctl.get_current_device_type(), dpctl.device_type.cpu) + + self.assertEqual(dpctl.get_current_device_type(), None) + + @unittest.skipIf(not dpctl.has_sycl_platforms(), "No SYCL platforms available") class TestGetCurrentQueueInMultipleThreads (unittest.TestCase): @@ -96,5 +128,6 @@ def SessionThread (self): Session1.start() Session2.start() + if __name__ == '__main__': unittest.main()