From f8d6a00acffbacef081e012b321385cf4fd10b79 Mon Sep 17 00:00:00 2001 From: Joshua Oreman Date: Tue, 5 Nov 2024 11:58:26 -0700 Subject: [PATCH] Explain example_policy; add richer example of callback tracking --- docs/api_core.rst | 4 ++ tests/CMakeLists.txt | 2 + tests/test_callbacks.cpp | 126 +++++++++++++++++++++++++++++++++++++++ tests/test_callbacks.py | 58 ++++++++++++++++++ tests/test_functions.cpp | 11 ++++ 5 files changed, 201 insertions(+) create mode 100644 tests/test_callbacks.cpp create mode 100644 tests/test_callbacks.py diff --git a/docs/api_core.rst b/docs/api_core.rst index a21db805..83de639e 100644 --- a/docs/api_core.rst +++ b/docs/api_core.rst @@ -2002,6 +2002,10 @@ parameter of :cpp:func:`module_::def`, :cpp:func:`class_::def`, } }; + For a more complex example (binding an object that uses trivially-copyable + callbacks), see ``tests/test_callbacks.cpp`` in the nanobind source + distribution. + .. _class_binding_annotations: Class binding annotations diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 9e8fcb4e..c5cfef42 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -62,6 +62,7 @@ endif() set(TEST_NAMES functions + callbacks classes holders stl @@ -137,6 +138,7 @@ target_link_libraries(test_inter_module_2_ext PRIVATE inter_module) set(TEST_FILES common.py + test_callbacks.py test_classes.py test_eigen.py test_enum.py diff --git a/tests/test_callbacks.cpp b/tests/test_callbacks.cpp new file mode 100644 index 00000000..b2541218 --- /dev/null +++ b/tests/test_callbacks.cpp @@ -0,0 +1,126 @@ +// This is an example of using nb::call_policy to support binding an +// object that takes non-owning callbacks. Since the callbacks can't +// directly keep a Python object alive (they're trivially copyable), we +// maintain a sideband structure to manage the lifetimes. + +#include +#include +#include + +#include +#include + +namespace nb = nanobind; + +// The callback type accepted by the object, which we assume we can't change. +// It's trivially copyable, so it can't directly keep a Python object alive. +struct callback { + void *context; + void (*func)(void *context, int arg); + + void operator()(int arg) const { (*func)(context, arg); } + bool operator==(const callback& other) const { + return context == other.context && func == other.func; + } +}; + +// An object that uses these callbacks, which we want to write bindings for +class publisher { + public: + void subscribe(callback cb) { cbs.push_back(cb); } + void unsubscribe(callback cb) { + cbs.erase(std::remove(cbs.begin(), cbs.end(), cb), cbs.end()); + } + void emit(int arg) const { for (auto cb : cbs) cb(arg); } + private: + std::vector cbs; +}; + +template <> struct nanobind::detail::type_caster { + static void wrap_call(void *context, int arg) { + borrow((PyObject *) context)(arg); + } + bool from_python(handle src, uint8_t, cleanup_list*) noexcept { + if (!isinstance(src)) return false; + value = {(void *) src.ptr(), &wrap_call}; + return true; + } + static handle from_cpp(callback cb, rv_policy policy, cleanup_list*) noexcept { + if (cb.func == &wrap_call) + return handle((PyObject *) cb.context).inc_ref(); + if (policy == rv_policy::none) + return handle(); + return cpp_function(cb, policy).release(); + } + NB_TYPE_CASTER(callback, const_name("Callable[[int], None]")) +}; + +nb::dict cb_registry() { + return nb::cast( + nb::module_::import_("test_callbacks_ext").attr("registry")); +} + +struct callback_data { + struct py_hash { + size_t operator()(const nb::object& obj) const { return nb::hash(obj); } + }; + struct py_eq { + bool operator()(const nb::object& a, const nb::object& b) const { + return a.equal(b); + } + }; + std::unordered_set subscribers; +}; + +callback_data& callbacks_for(nb::handle publisher) { + auto registry = cb_registry(); + nb::weakref key(publisher, registry.attr("__delitem__")); + if (nb::handle value = PyDict_GetItem(registry.ptr(), key.ptr())) { + return nb::cast(value); + } + nb::object new_data = nb::cast(callback_data{}); + registry[key] = new_data; + return nb::cast(new_data); +} + +// to check at compile time that the subscribe/unsubscribe functions take +// two arguments: self, callback +using TwoArgs = std::integral_constant; + +struct subscribe_policy { + static void precall(PyObject **, TwoArgs, nb::detail::cleanup_list *) {} + static void postcall(PyObject **args, TwoArgs, nb::handle) { + nb::handle self = args[0], cb = args[1]; + callbacks_for(self).subscribers.insert(nb::borrow(cb)); + } +}; + +struct unsubscribe_policy { + static void precall(PyObject **args, TwoArgs, nb::detail::cleanup_list *) { + nb::handle self = args[0], cb = args[1]; + auto& cbs = callbacks_for(self); + auto it = cbs.subscribers.find(nb::borrow(cb)); + if (it != cbs.subscribers.end() && !it->is(cb)) { + // No callback identical to this one is subscribed. Substitute + // one that is Python-equal. + args[1] = it->ptr(); + } + } + static void postcall(PyObject **args, TwoArgs, nb::handle) { + nb::handle self = args[0], cb = args[1]; + callbacks_for(self).subscribers.erase(nb::borrow(cb)); + } +}; + +NB_MODULE(test_callbacks_ext, m) { + m.attr("registry") = nb::dict(); + nb::class_(m, "callback_data") + .def_ro("subscribers", &callback_data::subscribers); + nb::class_(m, "publisher", nb::is_weak_referenceable()) + .def(nb::init<>()) + .def("subscribe", &publisher::subscribe, + nb::call_policy()) + .def("unsubscribe", &publisher::unsubscribe, + nb::call_policy()) + .def("emit", &publisher::emit); +} diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py new file mode 100644 index 00000000..5f6796c3 --- /dev/null +++ b/tests/test_callbacks.py @@ -0,0 +1,58 @@ +import test_callbacks_ext as t +import gc + + +def test_callbacks(): + pub1 = t.publisher() + pub2 = t.publisher() + record = [] + + def sub1(x): + record.append(x + 10) + + def sub2(x): + record.append(x + 20) + + pub1.subscribe(sub1) + pub2.subscribe(sub2) + for pub in (pub1, pub2): + pub.subscribe(record.append) + + pub1.emit(1) + assert record == [11, 1] + del record[:] + + pub2.emit(2) + assert record == [22, 2] + del record[:] + + pub1_w, pub2_w = t.registry.keys() # weakrefs to pub1, pub2 + assert pub1_w() is pub1 + assert pub2_w() is pub2 + assert t.registry[pub1_w].subscribers == {sub1, record.append} + assert t.registry[pub2_w].subscribers == {sub2, record.append} + + # NB: this `record.append` is a different object than the one we subscribed + # above, so we're testing the normalization logic in unsubscribe_policy + pub1.unsubscribe(record.append) + assert t.registry[pub1_w].subscribers == {sub1} + pub1.emit(3) + assert record == [13] + del record[:] + + del pub, pub1 + gc.collect() + gc.collect() + assert pub1_w() is None + assert pub2_w() is pub2 + assert t.registry.keys() == {pub2_w} + + pub2.emit(4) + assert record == [24, 4] + del record[:] + + del pub2 + gc.collect() + gc.collect() + assert pub2_w() is None + assert not t.registry diff --git a/tests/test_functions.cpp b/tests/test_functions.cpp index 34e584e3..0bc645d1 100644 --- a/tests/test_functions.cpp +++ b/tests/test_functions.cpp @@ -15,6 +15,17 @@ struct my_call_guard { ~my_call_guard() { call_guard_value = 2; } }; +// Example call policy for use with nb::call_policy<>. Each call will add +// an entry to `calls` containing the arguments tuple and return value. +// The return value will be recorded as "" if the function +// did not return (still executing or threw an exception) and as +// "" if the function returned something that we +// couldn't convert to a Python object. +// Additional features to test particular interactions: +// - the precall hook will throw if any arguments are not strings +// - any argument equal to "swapfrom" will be replaced by a temporary +// string object equal to "swapto", which will be destroyed at end of call +// - the postcall hook will throw if any argument equals "postthrow" struct example_policy { static inline std::vector> calls; static void precall(PyObject **args, size_t nargs,