diff --git a/src/nb_type.cpp b/src/nb_type.cpp index 1fc5c152..9ded2bf3 100644 --- a/src/nb_type.cpp +++ b/src/nb_type.cpp @@ -99,6 +99,8 @@ PyObject *inst_new_int(PyTypeObject *tp, PyObject * /* args */, self->clear_keep_alive = 0; self->intrusive = intrusive; self->unused = 0; + + // Make the object compatible with nb_try_inc_ref (free-threaded builds only) nb_enable_try_inc_ref((PyObject *) self); // Update hash table that maps from C++ to Python instance @@ -165,8 +167,15 @@ PyObject *inst_new_ext(PyTypeObject *tp, void *value) { self->clear_keep_alive = 0; self->intrusive = intrusive; self->unused = 0; + + // Make the object compatible with nb_try_inc_ref (free-threaded builds only) nb_enable_try_inc_ref((PyObject *) self); + return (PyObject *) self; +} + +/// Register the object constructed by 'inst_new_ext()' in the internal data structures +static void inst_ext_register(PyObject *self, void *value) { nb_shard &shard = internals->shard(value); lock_shard guard(shard); @@ -188,8 +197,7 @@ PyObject *inst_new_ext(PyTypeObject *tp, void *value) { nb_inst_seq *seq = nb_get_seq(entry); while (true) { - check((nb_inst *) seq->inst != self, - "nanobind::detail::inst_new_ext(): duplicate instance!"); + check(seq->inst != self, "nanobind::detail::inst_new_ext(): duplicate instance!"); if (!seq->next) break; seq = seq->next; @@ -203,10 +211,9 @@ PyObject *inst_new_ext(PyTypeObject *tp, void *value) { next->next = nullptr; seq->next = next; } - - return (PyObject *) self; } + static void inst_dealloc(PyObject *self) { PyTypeObject *tp = Py_TYPE(self); const type_data *t = nb_type_data(tp); @@ -1737,6 +1744,9 @@ static PyObject *nb_type_put_common(void *value, type_data *t, rv_policy rvp, if (intrusive) t->set_self_py(new_value, (PyObject *) inst); + if (!create_new) + inst_ext_register((PyObject *) inst, value); + return (PyObject *) inst; } @@ -1763,7 +1773,7 @@ PyObject *nb_type_put(const std::type_info *cpp_type, return true; }; - if (rvp != rv_policy::copy) { + if (rvp != rv_policy::copy && rvp != rv_policy::move) { nb_shard &shard = internals_->shard(value); lock_shard guard(shard); @@ -1847,7 +1857,7 @@ PyObject *nb_type_put_p(const std::type_info *cpp_type, return true; }; - if (rvp != rv_policy::copy) { + if (rvp != rv_policy::copy && rvp != rv_policy::move) { nb_shard &shard = internals_->shard(value); lock_shard guard(shard); @@ -2082,6 +2092,7 @@ PyObject *nb_inst_reference(PyTypeObject *t, void *ptr, PyObject *parent) { nbi->state = nb_inst::state_ready; if (parent) keep_alive(result, parent); + inst_ext_register(result, ptr); return result; } @@ -2092,6 +2103,7 @@ PyObject *nb_inst_take_ownership(PyTypeObject *t, void *ptr) { nb_inst *nbi = (nb_inst *) result; nbi->destruct = nbi->cpp_delete = true; nbi->state = nb_inst::state_ready; + inst_ext_register(result, ptr); return result; } diff --git a/tests/test_thread.cpp b/tests/test_thread.cpp index 54181b7b..97e82960 100644 --- a/tests/test_thread.cpp +++ b/tests/test_thread.cpp @@ -16,6 +16,23 @@ struct GlobalData {} global_data; nb::ft_mutex mutex; +struct ClassWithProperty { +public: + ClassWithProperty(int value): value_(value) {} + int get_prop() const { return value_; } +private: + int value_; +}; + +class ClassWithClassProperty { +public: + ClassWithClassProperty(ClassWithProperty value) : value_(std::move(value)) {}; + const ClassWithProperty& get_prop() const { return value_; } +private: + ClassWithProperty value_; +}; + + NB_MODULE(test_thread_ext, m) { nb::class_(m, "Counter") .def(nb::init<>()) @@ -39,4 +56,16 @@ NB_MODULE(test_thread_ext, m) { nb::class_(m, "GlobalData") .def_static("get", [] { return &global_data; }, nb::rv_policy::reference); + + nb::class_(m, "ClassWithProperty") + .def(nb::init(), nb::arg("value")) + .def_prop_ro("prop2", &ClassWithProperty::get_prop); + + nb::class_(m, "ClassWithClassProperty") + .def( + "__init__", + [](ClassWithClassProperty* self, ClassWithProperty value) { + new (self) ClassWithClassProperty(std::move(value)); + }, nb::arg("value")) + .def_prop_ro("prop1", &ClassWithClassProperty::get_prop); } diff --git a/tests/test_thread.py b/tests/test_thread.py index 1dd05af9..2a179862 100644 --- a/tests/test_thread.py +++ b/tests/test_thread.py @@ -1,5 +1,5 @@ import test_thread_ext as t -from test_thread_ext import Counter, GlobalData +from test_thread_ext import Counter, GlobalData, ClassWithProperty, ClassWithClassProperty from common import parallelize def test01_object_creation(n_threads=8): @@ -88,3 +88,15 @@ def f(): GlobalData.get() parallelize(f, n_threads=n_threads) + + +def test07_access_attributes(n_threads=8): + n = 1000 + c1 = ClassWithProperty(123) + c2 = ClassWithClassProperty(c1) + + def f(): + for i in range(n): + _ = c2.prop1.prop2 + + parallelize(f, n_threads=n_threads)