From c5d4b96a4aadbc86c6b63052e900adbc88037d38 Mon Sep 17 00:00:00 2001 From: Dave Liddell Date: Fri, 14 Jun 2024 10:53:31 -0600 Subject: [PATCH] Allow flags to be set with greater flexibility (#17659) Changes to the python binding to allow iree.runtime.flags.parse_flags to take effect at times other than before the first time a driver is created. Also includes fixes for bugs exposed during the development of this feature. - Added "internal" API functions `create_hal_driver()` and `clear_hal_driver_cache()` to create a driver object independent of the cache, and to clear the cache, respectively - Added `HalDriver` class implementation functions for the above new API functions. Refactored class to share as much common code as possible. - Factored out driver URI processing into its own nested class for easier handling of URI components - Fixed dangling pointer bug. In the C layer flags are being kept by reference as string views, requiring the caller to keep the original flag strings (argc, argv) around for as long as the flags are being used. However, the python binding was using a local variable for those strings, letting them go out of scope and causing garbage values later on. The fix is to move the strings to a file scope variable. Flag handling does not appear to be getting used in a multi-threaded environment, as other aspects of flag handling use static variables with no mutex guarding that I could find. - Fixed runtime assert in Windows debug build for the improper use of std::vector<>::front() on an empty vector. The code never used the value of front(), as it was guarded by a check for the vector's size, but the assert prevents the debug build from running. --------- Signed-off-by: Dave Liddell Signed-off-by: daveliddell --- runtime/bindings/python/hal.cc | 48 +++++++++++++------ runtime/bindings/python/hal.h | 20 ++++++++ runtime/bindings/python/initialize_module.cc | 27 ++++++++++- .../python/iree/runtime/system_setup.py | 7 ++- .../python/tests/system_setup_test.py | 24 ++++++++++ 5 files changed, 109 insertions(+), 17 deletions(-) diff --git a/runtime/bindings/python/hal.cc b/runtime/bindings/python/hal.cc index 18771af9e0be..8799e4884fbd 100644 --- a/runtime/bindings/python/hal.cc +++ b/runtime/bindings/python/hal.cc @@ -905,29 +905,40 @@ std::vector HalDriver::Query() { return driver_names; } -py::object HalDriver::Create(const std::string& device_uri, - py::dict& driver_cache) { - iree_string_view_t driver_name, device_path, params_str; +HalDriver::DeviceUri::DeviceUri(const std::string& device_uri) { iree_string_view_t device_uri_sv{ device_uri.data(), static_cast(device_uri.size())}; iree_uri_split(device_uri_sv, &driver_name, &device_path, ¶ms_str); +} - // Check cache. - py::str cache_key(driver_name.data, driver_name.size); - py::object cached = driver_cache.attr("get")(cache_key); - if (!cached.is_none()) { - return cached; - } - - // Create. +py::object HalDriver::Create(const DeviceUri& device_uri) { iree_hal_driver_t* driver; CheckApiStatus(iree_hal_driver_registry_try_create( - iree_hal_driver_registry_default(), driver_name, + iree_hal_driver_registry_default(), device_uri.driver_name, iree_allocator_system(), &driver), "Error creating driver"); - // Cache. py::object driver_obj = py::cast(HalDriver::StealFromRawPtr(driver)); + return driver_obj; +} + +py::object HalDriver::Create(const std::string& device_uri) { + DeviceUri parsed_uri(device_uri); + return HalDriver::Create(parsed_uri); +} + +py::object HalDriver::Create(const std::string& device_uri, + py::dict& driver_cache) { + // Look up the driver by driver name in the cache, and return it if found. + DeviceUri parsed_uri(device_uri); + py::str cache_key(parsed_uri.driver_name.data, parsed_uri.driver_name.size); + py::object cached = driver_cache.attr("get")(cache_key); + if (!cached.is_none()) { + return cached; + } + + // Create a new driver and put it in the cache. + py::object driver_obj = HalDriver::Create(parsed_uri); driver_cache[cache_key] = driver_obj; return driver_obj; } @@ -1026,7 +1037,8 @@ HalDevice HalDriver::CreateDevice(iree_hal_device_id_t device_id, std::vector params; iree_hal_device_t* device; CheckApiStatus(iree_hal_driver_create_device_by_id( - raw_ptr(), device_id, params.size(), ¶ms.front(), + raw_ptr(), device_id, params.size(), + (params.empty() ? nullptr : ¶ms.front()), iree_allocator_system(), &device), "Error creating default device"); CheckApiStatus(ConfigureDevice(device, allocators), @@ -1289,6 +1301,14 @@ void SetupHalBindings(nanobind::module_ m) { }, py::arg("device_uri")); + m.def( + "create_hal_driver", + [](std::string device_uri) { return HalDriver::Create(device_uri); }, + py::arg("device_uri")); + + m.def("clear_hal_driver_cache", + [driver_cache]() { const_cast(driver_cache).clear(); }); + py::class_(m, "HalAllocator") .def("trim", [](HalAllocator& self) { diff --git a/runtime/bindings/python/hal.h b/runtime/bindings/python/hal.h index 29d02334e959..7dbc108917c3 100644 --- a/runtime/bindings/python/hal.h +++ b/runtime/bindings/python/hal.h @@ -12,6 +12,7 @@ #include "./binding.h" #include "./status_utils.h" #include "./vm.h" +#include "iree/base/string_view.h" #include "iree/hal/api.h" namespace iree { @@ -142,8 +143,27 @@ class HalDevice : public ApiRefCounted { }; class HalDriver : public ApiRefCounted { + // Object that holds the components of a device URI string. + struct DeviceUri { + iree_string_view_t driver_name; + iree_string_view_t device_path; + iree_string_view_t params_str; + + DeviceUri(const std::string& device_uri); + }; + + // Create a stand-alone driver (not residing in a cache) given the name, + // path, and params components of a device URI. + static py::object Create(const DeviceUri& device_uri); + public: static std::vector Query(); + + // Create a stand-alone driver (not residing in a cache) given a device URI. + static py::object Create(const std::string& device_uri); + + // Returns a driver from the given cache, creating it and placing it in + // the cache if not already found there. static py::object Create(const std::string& device_uri, py::dict& driver_cache); diff --git a/runtime/bindings/python/initialize_module.cc b/runtime/bindings/python/initialize_module.cc index 7eb9254cf3b0..c79da46353fc 100644 --- a/runtime/bindings/python/initialize_module.cc +++ b/runtime/bindings/python/initialize_module.cc @@ -4,6 +4,8 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +#include + #include "./binding.h" #include "./hal.h" #include "./invoke.h" @@ -16,6 +18,14 @@ #include "iree/base/internal/flags.h" #include "iree/hal/drivers/init.h" +namespace { +// Stable storage for flag processing. Flag handling uses string views, +// expecting the caller to keep the original strings around for as long +// as the flags are in use. This object holds one set of flag strings +// for each invocation of parse_flags. +std::vector>> alloced_flag_cache; +} // namespace + namespace iree { namespace python { @@ -34,19 +44,32 @@ NB_MODULE(_runtime, m) { SetupPyModuleBindings(m); SetupVmBindings(m); + // Adds the given set of strings to the global flags. These new flags + // take effect upon the next creation of a driver. They do not affect + // drivers already created. m.def("parse_flags", [](py::args py_flags) { - std::vector alloced_flags; + // Make a new set of strings at the back of the cache + alloced_flag_cache.emplace_back( + std::make_unique>(std::vector())); + auto &alloced_flags = *alloced_flag_cache.back(); + + // Add the given python strings to the std::string set. alloced_flags.push_back("python"); for (py::handle py_flag : py_flags) { alloced_flags.push_back(py::cast(py_flag)); } - // Must build pointer vector after filling so pointers are stable. + // As the flags-processing mechanism of the C API requires long-lived + // char * strings, create a set of char * strings from the std::strings, + // with the std::strings responsible for maintaining the storage. + // Must build pointer vector after filling std::strings so pointers are + // stable. std::vector flag_ptrs; for (auto &alloced_flag : alloced_flags) { flag_ptrs.push_back(const_cast(alloced_flag.c_str())); } + // Send the flags to the C API char **argv = &flag_ptrs[0]; int argc = flag_ptrs.size(); CheckApiStatus(iree_flags_parse(IREE_FLAGS_PARSE_MODE_CONTINUE_AFTER_HELP, diff --git a/runtime/bindings/python/iree/runtime/system_setup.py b/runtime/bindings/python/iree/runtime/system_setup.py index 0560003d5f1d..8cd117d56acc 100644 --- a/runtime/bindings/python/iree/runtime/system_setup.py +++ b/runtime/bindings/python/iree/runtime/system_setup.py @@ -26,7 +26,12 @@ def query_available_drivers() -> Collection[str]: def get_driver(device_uri: str) -> HalDriver: - """Returns a HAL driver by device_uri (or driver name).""" + """Returns a HAL driver by device_uri (or driver name). + + Args: + device_uri: The URI of the device, either just a driver name for the + default or a fully qualified "driver://path?params". + """ return get_cached_hal_driver(device_uri) diff --git a/runtime/bindings/python/tests/system_setup_test.py b/runtime/bindings/python/tests/system_setup_test.py index 2d0ddf9ea2cb..c55dc466598f 100644 --- a/runtime/bindings/python/tests/system_setup_test.py +++ b/runtime/bindings/python/tests/system_setup_test.py @@ -8,6 +8,7 @@ import unittest from iree.runtime import system_setup as ss +from iree.runtime._binding import create_hal_driver, clear_hal_driver_cache class DeviceSetupTest(unittest.TestCase): @@ -65,6 +66,29 @@ def testCreateDeviceWithAllocators(self): infos[0]["device_id"], allocators=["caching", "debug"] ) + def testDriverCacheInternals(self): + # Two drivers created with the same URI using the caching get_driver + # should return the same driver + driver1 = ss.get_driver("local-sync") + driver2 = ss.get_driver("local-sync") + self.assertIs(driver1, driver2) + + # A driver created using the non-caching create_hal_driver should be + # unique from cached drivers of the same URI + driver3 = create_hal_driver("local-sync") + self.assertIsNot(driver3, driver1) + + # Drivers created with create_hal_driver should all be unique from + # one another + driver4 = create_hal_driver("local-sync") + self.assertIsNot(driver4, driver3) + + # Clearing the cache should make any new driver unique from previously + # cached ones + clear_hal_driver_cache() + driver5 = ss.get_driver("local-sync") + self.assertIsNot(driver5, driver1) + if __name__ == "__main__": logging.basicConfig(level=logging.INFO)