Skip to content

Commit

Permalink
Refactor ref-counting code and fix ref counted releasing before aquiring
Browse files Browse the repository at this point in the history
  • Loading branch information
rune-scape committed Sep 21, 2024
1 parent 40b378e commit 324ae43
Show file tree
Hide file tree
Showing 7 changed files with 152 additions and 187 deletions.
102 changes: 36 additions & 66 deletions core/object/ref_counted.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,24 +57,30 @@ template <typename T>
class Ref {
T *reference = nullptr;

void ref(const Ref &p_from) {
if (p_from.reference == reference) {
_FORCE_INLINE_ void ref(const Ref &p_from) {
ref_pointer<false>(p_from.reference);
}

template <bool Init>
_FORCE_INLINE_ void ref_pointer(T *p_refcounted) {
if (p_refcounted == reference) {
return;
}

unref();

reference = p_from.reference;
// This will go out of scope and get unref'd.
Ref cleanup_ref;
cleanup_ref.reference = reference;
reference = p_refcounted;
if (reference) {
reference->reference();
}
}

void ref_pointer(T *p_ref) {
ERR_FAIL_NULL(p_ref);

if (p_ref->init_ref()) {
reference = p_ref;
if constexpr (Init) {
if (!reference->init_ref()) {
reference = nullptr;
}
} else {
if (!reference->reference()) {
reference = nullptr;
}
}
}
}

Expand Down Expand Up @@ -119,15 +125,11 @@ class Ref {

template <typename T_Other>
void operator=(const Ref<T_Other> &p_from) {
RefCounted *refb = const_cast<RefCounted *>(static_cast<const RefCounted *>(p_from.ptr()));
if (!refb) {
unref();
return;
}
Ref r;
r.reference = Object::cast_to<T>(refb);
ref(r);
r.reference = nullptr;
ref_pointer<false>(Object::cast_to<T>(p_from.ptr()));
}

void operator=(T *p_from) {
ref_pointer<true>(p_from);
}

void operator=(const Variant &p_variant) {
Expand All @@ -137,65 +139,33 @@ class Ref {
return;
}

unref();

if (!object) {
return;
}

T *r = Object::cast_to<T>(object);
if (r && r->reference()) {
reference = r;
}
ref_pointer<false>(Object::cast_to<T>(object));
}

template <typename T_Other>
void reference_ptr(T_Other *p_ptr) {
if (reference == p_ptr) {
return;
}
unref();

T *r = Object::cast_to<T>(p_ptr);
if (r) {
ref_pointer(r);
}
ref_pointer<true>(Object::cast_to<T>(p_ptr));
}

Ref(const Ref &p_from) {
ref(p_from);
this->operator=(p_from);
}

template <typename T_Other>
Ref(const Ref<T_Other> &p_from) {
RefCounted *refb = const_cast<RefCounted *>(static_cast<const RefCounted *>(p_from.ptr()));
if (!refb) {
unref();
return;
}
Ref r;
r.reference = Object::cast_to<T>(refb);
ref(r);
r.reference = nullptr;
this->operator=(p_from);
}

Ref(T *p_reference) {
if (p_reference) {
ref_pointer(p_reference);
}
Ref(T *p_from) {
this->operator=(p_from);
}

Ref(const Variant &p_variant) {
Object *object = p_variant.get_validated_object();

if (!object) {
return;
}

T *r = Object::cast_to<T>(object);
if (r && r->reference()) {
reference = r;
}
Ref(const Variant &p_from) {
this->operator=(p_from);
}

inline bool is_valid() const { return reference != nullptr; }
Expand All @@ -217,7 +187,7 @@ class Ref {
ref(memnew(T(p_params...)));
}

Ref() {}
Ref() = default;

~Ref() {
unref();
Expand Down Expand Up @@ -294,13 +264,13 @@ struct GetTypeInfo<const Ref<T> &> {
template <typename T>
struct VariantInternalAccessor<Ref<T>> {
static _FORCE_INLINE_ Ref<T> get(const Variant *v) { return Ref<T>(*VariantInternal::get_object(v)); }
static _FORCE_INLINE_ void set(Variant *v, const Ref<T> &p_ref) { VariantInternal::refcounted_object_assign(v, p_ref.ptr()); }
static _FORCE_INLINE_ void set(Variant *v, const Ref<T> &p_ref) { VariantInternal::object_assign(v, p_ref); }
};

template <typename T>
struct VariantInternalAccessor<const Ref<T> &> {
static _FORCE_INLINE_ Ref<T> get(const Variant *v) { return Ref<T>(*VariantInternal::get_object(v)); }
static _FORCE_INLINE_ void set(Variant *v, const Ref<T> &p_ref) { VariantInternal::refcounted_object_assign(v, p_ref.ptr()); }
static _FORCE_INLINE_ void set(Variant *v, const Ref<T> &p_ref) { VariantInternal::object_assign(v, p_ref); }
};

#endif // REF_COUNTED_H
19 changes: 10 additions & 9 deletions core/variant/callable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -315,31 +315,32 @@ bool Callable::operator<(const Callable &p_callable) const {
}

void Callable::operator=(const Callable &p_callable) {
CallableCustom *cleanup_ref = nullptr;
if (is_custom()) {
if (p_callable.is_custom()) {
if (custom == p_callable.custom) {
return;
}
}

if (custom->ref_count.unref()) {
memdelete(custom);
custom = nullptr;
}
cleanup_ref = custom;
custom = nullptr;
}

if (p_callable.is_custom()) {
method = StringName();
if (!p_callable.custom->ref_count.ref()) {
object = 0;
} else {
object = 0;
object = 0;
if (p_callable.custom->ref_count.ref()) {
custom = p_callable.custom;
}
} else {
method = p_callable.method;
object = p_callable.object;
}

if (cleanup_ref != nullptr && cleanup_ref->ref_count.unref()) {
memdelete(cleanup_ref);
}
cleanup_ref = nullptr;
}

Callable::operator String() const {
Expand Down
132 changes: 65 additions & 67 deletions core/variant/variant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1072,17 +1072,69 @@ bool Variant::is_null() const {
}
}

void Variant::ObjData::ref(const ObjData &p_from) {
// Mirrors Ref::ref in refcounted.h
if (p_from.id == id) {
return;
}

ObjData cleanup_ref = *this;

*this = p_from;
if (id.is_ref_counted()) {
RefCounted *reference = static_cast<RefCounted *>(obj);
// Assuming reference is not null because id.is_ref_counted() was true.
if (!reference->reference()) {
*this = ObjData();
}
}

cleanup_ref.unref();
}

void Variant::ObjData::ref_pointer(Object *p_object) {
// Mirrors Ref::ref_pointer in refcounted.h
if (p_object == obj) {
return;
}

ObjData cleanup_ref = *this;

if (p_object) {
*this = ObjData{ p_object->get_instance_id(), p_object };
if (p_object->is_ref_counted()) {
RefCounted *reference = static_cast<RefCounted *>(p_object);
if (!reference->init_ref()) {
*this = ObjData();
}
}
} else {
*this = ObjData();
}

cleanup_ref.unref();
}

void Variant::ObjData::unref() {
// Mirrors Ref::unref in refcounted.h
if (id.is_ref_counted()) {
RefCounted *reference = static_cast<RefCounted *>(obj);
// Assuming reference is not null because id.is_ref_counted() was true.
if (reference->unreference()) {
memdelete(reference);
}
}
*this = ObjData();
}

void Variant::reference(const Variant &p_variant) {
switch (type) {
case NIL:
case BOOL:
case INT:
case FLOAT:
break;
default:
clear();
if (type == OBJECT && p_variant.type == OBJECT) {
_get_obj().ref(p_variant._get_obj());
return;
}

clear();

type = p_variant.type;

switch (p_variant.type) {
Expand Down Expand Up @@ -1165,18 +1217,7 @@ void Variant::reference(const Variant &p_variant) {
} break;
case OBJECT: {
memnew_placement(_data._mem, ObjData);

if (p_variant._get_obj().obj && p_variant._get_obj().id.is_ref_counted()) {
RefCounted *ref_counted = static_cast<RefCounted *>(p_variant._get_obj().obj);
if (!ref_counted->reference()) {
_get_obj().obj = nullptr;
_get_obj().id = ObjectID();
break;
}
}

_get_obj().obj = const_cast<Object *>(p_variant._get_obj().obj);
_get_obj().id = p_variant._get_obj().id;
_get_obj().ref(p_variant._get_obj());
} break;
case CALLABLE: {
memnew_placement(_data._mem, Callable(*reinterpret_cast<const Callable *>(p_variant._data._mem)));
Expand Down Expand Up @@ -1375,15 +1416,7 @@ void Variant::_clear_internal() {
reinterpret_cast<NodePath *>(_data._mem)->~NodePath();
} break;
case OBJECT: {
if (_get_obj().id.is_ref_counted()) {
// We are safe that there is a reference here.
RefCounted *ref_counted = static_cast<RefCounted *>(_get_obj().obj);
if (ref_counted->unreference()) {
memdelete(ref_counted);
}
}
_get_obj().obj = nullptr;
_get_obj().id = ObjectID();
_get_obj().unref();
} break;
case RID: {
// Not much need probably.
Expand Down Expand Up @@ -2589,24 +2622,8 @@ Variant::Variant(const ::RID &p_rid) :

Variant::Variant(const Object *p_object) :
type(OBJECT) {
memnew_placement(_data._mem, ObjData);

if (p_object) {
if (p_object->is_ref_counted()) {
RefCounted *ref_counted = const_cast<RefCounted *>(static_cast<const RefCounted *>(p_object));
if (!ref_counted->init_ref()) {
_get_obj().obj = nullptr;
_get_obj().id = ObjectID();
return;
}
}

_get_obj().obj = const_cast<Object *>(p_object);
_get_obj().id = p_object->get_instance_id();
} else {
_get_obj().obj = nullptr;
_get_obj().id = ObjectID();
}
_get_obj() = ObjData();
_get_obj().ref_pointer(const_cast<Object *>(p_object));
}

Variant::Variant(const Callable &p_callable) :
Expand Down Expand Up @@ -2828,26 +2845,7 @@ void Variant::operator=(const Variant &p_variant) {
*reinterpret_cast<::RID *>(_data._mem) = *reinterpret_cast<const ::RID *>(p_variant._data._mem);
} break;
case OBJECT: {
if (_get_obj().id.is_ref_counted()) {
//we are safe that there is a reference here
RefCounted *ref_counted = static_cast<RefCounted *>(_get_obj().obj);
if (ref_counted->unreference()) {
memdelete(ref_counted);
}
}

if (p_variant._get_obj().obj && p_variant._get_obj().id.is_ref_counted()) {
RefCounted *ref_counted = static_cast<RefCounted *>(p_variant._get_obj().obj);
if (!ref_counted->reference()) {
_get_obj().obj = nullptr;
_get_obj().id = ObjectID();
break;
}
}

_get_obj().obj = const_cast<Object *>(p_variant._get_obj().obj);
_get_obj().id = p_variant._get_obj().id;

_get_obj().ref(p_variant._get_obj());
} break;
case CALLABLE: {
*reinterpret_cast<Callable *>(_data._mem) = *reinterpret_cast<const Callable *>(p_variant._data._mem);
Expand Down
Loading

0 comments on commit 324ae43

Please sign in to comment.