Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added test case for issue #867 #886

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 18 additions & 6 deletions src/nb_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ PyObject *inst_new_int(PyTypeObject *tp, PyObject * /* args */,
self->clear_keep_alive = 0;
self->intrusive = intrusive;
self->unused = 0;

// Make the object compatible with nb_try_inc_ref (free-threaded builds only)
nb_enable_try_inc_ref((PyObject *) self);

// Update hash table that maps from C++ to Python instance
Expand Down Expand Up @@ -165,8 +167,15 @@ PyObject *inst_new_ext(PyTypeObject *tp, void *value) {
self->clear_keep_alive = 0;
self->intrusive = intrusive;
self->unused = 0;

// Make the object compatible with nb_try_inc_ref (free-threaded builds only)
nb_enable_try_inc_ref((PyObject *) self);

return (PyObject *) self;
}

/// Register the object constructed by 'inst_new_ext()' in the internal data structures
static void inst_ext_register(PyObject *self, void *value) {
nb_shard &shard = internals->shard(value);
lock_shard guard(shard);

Expand All @@ -188,8 +197,7 @@ PyObject *inst_new_ext(PyTypeObject *tp, void *value) {

nb_inst_seq *seq = nb_get_seq(entry);
while (true) {
check((nb_inst *) seq->inst != self,
"nanobind::detail::inst_new_ext(): duplicate instance!");
check(seq->inst != self, "nanobind::detail::inst_new_ext(): duplicate instance!");
if (!seq->next)
break;
seq = seq->next;
Expand All @@ -203,10 +211,9 @@ PyObject *inst_new_ext(PyTypeObject *tp, void *value) {
next->next = nullptr;
seq->next = next;
}

return (PyObject *) self;
}


static void inst_dealloc(PyObject *self) {
PyTypeObject *tp = Py_TYPE(self);
const type_data *t = nb_type_data(tp);
Expand Down Expand Up @@ -1737,6 +1744,9 @@ static PyObject *nb_type_put_common(void *value, type_data *t, rv_policy rvp,
if (intrusive)
t->set_self_py(new_value, (PyObject *) inst);

if (!create_new)
inst_ext_register((PyObject *) inst, value);

return (PyObject *) inst;
}

Expand All @@ -1763,7 +1773,7 @@ PyObject *nb_type_put(const std::type_info *cpp_type,
return true;
};

if (rvp != rv_policy::copy) {
if (rvp != rv_policy::copy && rvp != rv_policy::move) {
nb_shard &shard = internals_->shard(value);
lock_shard guard(shard);

Expand Down Expand Up @@ -1847,7 +1857,7 @@ PyObject *nb_type_put_p(const std::type_info *cpp_type,
return true;
};

if (rvp != rv_policy::copy) {
if (rvp != rv_policy::copy && rvp != rv_policy::move) {
nb_shard &shard = internals_->shard(value);
lock_shard guard(shard);

Expand Down Expand Up @@ -2082,6 +2092,7 @@ PyObject *nb_inst_reference(PyTypeObject *t, void *ptr, PyObject *parent) {
nbi->state = nb_inst::state_ready;
if (parent)
keep_alive(result, parent);
inst_ext_register(result, ptr);
return result;
}

Expand All @@ -2092,6 +2103,7 @@ PyObject *nb_inst_take_ownership(PyTypeObject *t, void *ptr) {
nb_inst *nbi = (nb_inst *) result;
nbi->destruct = nbi->cpp_delete = true;
nbi->state = nb_inst::state_ready;
inst_ext_register(result, ptr);
return result;
}

Expand Down
29 changes: 29 additions & 0 deletions tests/test_thread.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,23 @@ struct GlobalData {} global_data;

nb::ft_mutex mutex;

struct ClassWithProperty {
public:
ClassWithProperty(int value): value_(value) {}
int get_prop() const { return value_; }
private:
int value_;
};

class ClassWithClassProperty {
public:
ClassWithClassProperty(ClassWithProperty value) : value_(std::move(value)) {};
const ClassWithProperty& get_prop() const { return value_; }
private:
ClassWithProperty value_;
};


NB_MODULE(test_thread_ext, m) {
nb::class_<Counter>(m, "Counter")
.def(nb::init<>())
Expand All @@ -39,4 +56,16 @@ NB_MODULE(test_thread_ext, m) {

nb::class_<GlobalData>(m, "GlobalData")
.def_static("get", [] { return &global_data; }, nb::rv_policy::reference);

nb::class_<ClassWithProperty>(m, "ClassWithProperty")
.def(nb::init<int>(), nb::arg("value"))
.def_prop_ro("prop2", &ClassWithProperty::get_prop);

nb::class_<ClassWithClassProperty>(m, "ClassWithClassProperty")
.def(
"__init__",
[](ClassWithClassProperty* self, ClassWithProperty value) {
new (self) ClassWithClassProperty(std::move(value));
}, nb::arg("value"))
.def_prop_ro("prop1", &ClassWithClassProperty::get_prop);
}
14 changes: 13 additions & 1 deletion tests/test_thread.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import test_thread_ext as t
from test_thread_ext import Counter, GlobalData
from test_thread_ext import Counter, GlobalData, ClassWithProperty, ClassWithClassProperty
from common import parallelize

def test01_object_creation(n_threads=8):
Expand Down Expand Up @@ -88,3 +88,15 @@ def f():
GlobalData.get()

parallelize(f, n_threads=n_threads)


def test07_access_attributes(n_threads=8):
n = 1000
c1 = ClassWithProperty(123)
c2 = ClassWithClassProperty(c1)

def f():
for i in range(n):
_ = c2.prop1.prop2

parallelize(f, n_threads=n_threads)
Loading