From 0c0e487fa3e0b0c0dde67f5cd02b284276af4b33 Mon Sep 17 00:00:00 2001 From: Brice Videau Date: Mon, 8 Jul 2024 15:06:16 -0500 Subject: [PATCH] Use explicit ref-counting for user defined object data. --- bindings/python/cconfigspace/tree_space.py | 38 ++++++++------ bindings/python/cconfigspace/tuner.py | 52 ++++++++++++-------- bindings/ruby/cconfigspace.gemspec | 2 +- bindings/ruby/lib/cconfigspace/tree_space.rb | 19 ++++--- bindings/ruby/lib/cconfigspace/tuner.rb | 23 +++++---- 5 files changed, 82 insertions(+), 52 deletions(-) diff --git a/bindings/python/cconfigspace/tree_space.py b/bindings/python/cconfigspace/tree_space.py index 1b43042f..193ad78a 100644 --- a/bindings/python/cconfigspace/tree_space.py +++ b/bindings/python/cconfigspace/tree_space.py @@ -181,8 +181,12 @@ def _wrap_user_defined_callbacks(delete, get_child, serialize, deserialize): def delete_wrapper(ts): try: ts = ct.cast(ts, ccs_tree_space) + o = Object.from_handle(ts) + tsdata = o.tree_space_data if delete is not None: - delete(Object.from_handle(ts)) + delete(o) + if tsdata is not None: + ct.pythonapi.Py_DecRef(ct.py_object(tsdata)) _unregister_vector(ts) return Result.SUCCESS except Exception as e: @@ -252,14 +256,15 @@ def __init__(self, handle = None, retain = False, auto_release = True, if get_child is None: raise Error(Result(Result.ERROR_INVALID_VALUE)) - (delete_wrapper, - get_child_wrapper, - serialize_wrapper, - deserialize_wrapper, + wrappers = _wrap_user_defined_callbacks(delete, get_child, serialize, deserialize) + (_, + _, + _, + _, delete_wrapper_func, get_child_wrapper_func, serialize_wrapper_func, - deserialize_wrapper_func) = _wrap_user_defined_callbacks(delete, get_child, serialize, deserialize) + deserialize_wrapper_func) = wrappers handle = ccs_tree_space() vec = DynamicTreeSpaceVector() vec.delete = delete_wrapper_func @@ -277,7 +282,9 @@ def __init__(self, handle = None, retain = False, auto_release = True, res = ccs_create_dynamic_tree_space(str.encode(name), tree.handle, feature_space, rng, ct.byref(vec), c_tree_space_data, ct.byref(handle)) Error.check(res) super().__init__(handle = handle, retain = False) - _register_vector(handle, [delete_wrapper, get_child_wrapper, serialize_wrapper, deserialize_wrapper, delete_wrapper_func, get_child_wrapper_func, serialize_wrapper_func, deserialize_wrapper_func, tree_space_data]) + _register_vector(handle, wrappers) + if c_tree_space_data is not None: + ct.pythonapi.Py_IncRef(c_tree_space_data) else: super().__init__(handle = handle, retain = retain, auto_release = auto_release) @@ -285,22 +292,25 @@ def __init__(self, handle = None, retain = False, auto_release = True, def deserialize(cls, delete, get_child, serialize = None, deserialize = None, tree_space_data = None, format = 'binary', handle_map = None, path = None, buffer = None, file_descriptor = None, callback = None, callback_data = None): if get_child is None: raise Error(Result(Result.ERROR_INVALID_VALUE)) - (delete_wrapper, - get_child_wrapper, - serialize_wrapper, - deserialize_wrapper, + + wrappers = _wrap_user_defined_callbacks(delete, get_child, serialize, deserialize) + (_, + _, + _, + _, delete_wrapper_func, get_child_wrapper_func, serialize_wrapper_func, - deserialize_wrapper_func) = _wrap_user_defined_callbacks(delete, get_child, serialize, deserialize) - handle = ccs_tree_space() + deserialize_wrapper_func) = wrappers vector = DynamicTreeSpaceVector() vector.delete = delete_wrapper_func vector.get_child = get_child_wrapper_func vector.serialize = serialize_wrapper_func vector.deserialize = deserialize_wrapper_func res = super().deserialize(format = format, handle_map = handle_map, vector = vector, data = tree_space_data, path = path, buffer = buffer, file_descriptor = file_descriptor, callback = callback, callback_data = callback_data) - _register_vector(res.handle, [delete_wrapper, get_child_wrapper, serialize_wrapper, deserialize_wrapper, delete_wrapper_func, get_child_wrapper_func, serialize_wrapper_func, deserialize_wrapper_func, tree_space_data]) + _register_vector(res.handle, wrappers) + if tree_space_data is not None: + ct.pythonapi.Py_IncRef(ct.py_object(tree_space_data)) return res @property diff --git a/bindings/python/cconfigspace/tuner.py b/bindings/python/cconfigspace/tuner.py index f4df3bdb..64e049b0 100644 --- a/bindings/python/cconfigspace/tuner.py +++ b/bindings/python/cconfigspace/tuner.py @@ -187,8 +187,12 @@ def _wrap_user_defined_tuner_callbacks(delete, ask, tell, get_optima, get_histor def delete_wrapper(tun): try: tun = ct.cast(tun, ccs_tuner) + o = Object.from_handle(tun) + tdata = o.tuner_data if delete is not None: - delete(Object.from_handle(tun)) + delete(o) + if tdata is not None: + ct.pythonapi.Py_DecRef(ct.py_object(tdata)) _unregister_vector(tun) return Result.SUCCESS except Exception as e: @@ -353,14 +357,15 @@ def __init__(self, handle = None, retain = False, auto_release = True, if ask is None or tell is None or get_optima is None or get_history is None: raise Error(Result(Result.ERROR_INVALID_VALUE)) - (delete_wrapper, - ask_wrapper, - tell_wrapper, - get_optima_wrapper, - get_history_wrapper, - suggest_wrapper, - serialize_wrapper, - deserialize_wrapper, + wrappers = _wrap_user_defined_tuner_callbacks(delete, ask, tell, get_optima, get_history, suggest, serialize, deserialize) + (_, + _, + _, + _, + _, + _, + _, + _, delete_wrapper_func, ask_wrapper_func, tell_wrapper_func, @@ -368,7 +373,7 @@ def __init__(self, handle = None, retain = False, auto_release = True, get_history_wrapper_func, suggest_wrapper_func, serialize_wrapper_func, - deserialize_wrapper_func) = _wrap_user_defined_tuner_callbacks(delete, ask, tell, get_optima, get_history, suggest, serialize, deserialize) + deserialize_wrapper_func) = wrappers handle = ccs_tuner() vec = UserDefinedTunerVector() vec.delete = delete_wrapper_func @@ -386,7 +391,9 @@ def __init__(self, handle = None, retain = False, auto_release = True, res = ccs_create_user_defined_tuner(str.encode(name), objective_space.handle, ct.byref(vec), c_tuner_data, ct.byref(handle)) Error.check(res) super().__init__(handle = handle, retain = False) - _register_vector(handle, [delete_wrapper, ask_wrapper, tell_wrapper, get_optima_wrapper, get_history_wrapper, suggest_wrapper, serialize_wrapper, deserialize_wrapper, delete_wrapper_func, ask_wrapper_func, tell_wrapper_func, get_optima_wrapper_func, get_history_wrapper_func, suggest_wrapper_func, serialize_wrapper_func, deserialize_wrapper_func, tuner_data]) + _register_vector(handle, wrappers) + if c_tuner_data is not None: + ct.pythonapi.Py_IncRef(c_tuner_data) else: super().__init__(handle = handle, retain = retain, auto_release = auto_release) @@ -394,14 +401,15 @@ def __init__(self, handle = None, retain = False, auto_release = True, def deserialize(cls, delete, ask, tell, get_optima, get_history, suggest = None, serialize = None, deserialize = None, tuner_data = None, format = 'binary', handle_map = None, path = None, buffer = None, file_descriptor = None, callback = None, callback_data = None): if ask is None or tell is None or get_optima is None or get_history is None: raise Error(Result(Result.ERROR_INVALID_VALUE)) - (delete_wrapper, - ask_wrapper, - tell_wrapper, - get_optima_wrapper, - get_history_wrapper, - suggest_wrapper, - serialize_wrapper, - deserialize_wrapper, + wrappers = _wrap_user_defined_tuner_callbacks(delete, ask, tell, get_optima, get_history, suggest, serialize, deserialize) + (_, + _, + _, + _, + _, + _, + _, + _, delete_wrapper_func, ask_wrapper_func, tell_wrapper_func, @@ -409,7 +417,7 @@ def deserialize(cls, delete, ask, tell, get_optima, get_history, suggest = None, get_history_wrapper_func, suggest_wrapper_func, serialize_wrapper_func, - deserialize_wrapper_func) = _wrap_user_defined_tuner_callbacks(delete, ask, tell, get_optima, get_history, suggest, serialize, deserialize) + deserialize_wrapper_func) = wrappers vector = UserDefinedTunerVector() vector.delete = delete_wrapper_func vector.ask = ask_wrapper_func @@ -420,7 +428,9 @@ def deserialize(cls, delete, ask, tell, get_optima, get_history, suggest = None, vector.serialize = serialize_wrapper_func vector.deserialize = deserialize_wrapper_func res = super().deserialize(format = format, handle_map = handle_map, vector = vector, data = tuner_data, path = path, buffer = buffer, file_descriptor = file_descriptor, callback = callback, callback_data = callback_data) - _register_vector(res.handle, [delete_wrapper, ask_wrapper, tell_wrapper, get_optima_wrapper, get_history_wrapper, suggest_wrapper, serialize_wrapper, deserialize_wrapper, delete_wrapper_func, ask_wrapper_func, tell_wrapper_func, get_optima_wrapper_func, get_history_wrapper_func, suggest_wrapper_func, serialize_wrapper_func, deserialize_wrapper_func, tuner_data]) + _register_vector(res.handle, wrappers) + if tuner_data is not None: + ct.pythonapi.Py_IncRef(ct.py_object(tuner_data)) return res @property diff --git a/bindings/ruby/cconfigspace.gemspec b/bindings/ruby/cconfigspace.gemspec index 2ed6bcc2..e880107c 100644 --- a/bindings/ruby/cconfigspace.gemspec +++ b/bindings/ruby/cconfigspace.gemspec @@ -10,6 +10,6 @@ Gem::Specification.new do |s| s.license = 'BSD-3-Clause' s.required_ruby_version = '>= 2.3.0' s.add_dependency 'ffi', '~> 1.13', '>=1.13.0' - s.add_dependency 'ffi-value', '~> 0.1', '>=0.1.1' + s.add_dependency 'ffi-value', '~> 0.1', '>=0.1.3' s.add_dependency 'whittle', '~> 0.0', '>=0.0.8' end diff --git a/bindings/ruby/lib/cconfigspace/tree_space.rb b/bindings/ruby/lib/cconfigspace/tree_space.rb index 75f7c919..76e70a12 100644 --- a/bindings/ruby/lib/cconfigspace/tree_space.rb +++ b/bindings/ruby/lib/cconfigspace/tree_space.rb @@ -133,7 +133,10 @@ class DynamicTreeSpaceVector < FFI::Struct def self.wrap_dynamic_tree_space_callbacks(del, get_child, serialize, deserialize) delwrapper = lambda { |ts| begin - del.call(CCS::Object.from_handle(ts)) if del + o = CCS::Object.from_handle(ts) + tsdata = o.tree_space_data + del.call(o) if del + FFI.dec_ref(tsdata) unless tsdata.nil? CCS.unregister_vector(ts) CCSError.to_native(:CCS_RESULT_SUCCESS) rescue => e @@ -196,8 +199,8 @@ def initialize(handle = nil, retain: false, auto_release: true, super(handle, retain: retain, auto_release: auto_release) else raise CCSError, :CCS_RESULT_ERROR_INVALID_VALUE if get_child.nil? - delwrapper, get_childwrapper, serializewrapper, deserializewrapper = - CCS.wrap_dynamic_tree_space_callbacks(del, get_child, serialize, deserialize) + wrappers = CCS.wrap_dynamic_tree_space_callbacks(del, get_child, serialize, deserialize) + delwrapper, get_childwrapper, serializewrapper, deserializewrapper = wrappers vector = DynamicTreeSpaceVector::new vector[:del] = delwrapper vector[:get_child] = get_childwrapper @@ -207,21 +210,23 @@ def initialize(handle = nil, retain: false, auto_release: true, CCS.error_check CCS.ccs_create_dynamic_tree_space(name, tree, feature_space, rng, vector, tree_space_data, ptr) h = ptr.read_ccs_tree_space_t super(h, retain: false) - CCS.register_vector(h, [delwrapper, get_childwrapper, serializewrapper, deserializewrapper, tree_space_data]) + CCS.register_vector(h, wrappers) + FFI.inc_ref(tree_space_data) unless tree_space_data.nil? end end def self.deserialize(del: nil, get_child: nil, serialize: nil, deserialize: nil, tree_space_data: nil, format: :binary, handle_map: nil, path: nil, buffer: nil, file_descriptor: nil, callback: nil, callback_data: nil) raise CCSError, :CCS_RESULT_ERROR_INVALID_VALUE if get_child.nil? - delwrapper, get_childwrapper, serializewrapper, deserializewrapper = - CCS.wrap_dynamic_tree_space_callbacks(del, get_child, serialize, deserialize) + wrappers = CCS.wrap_dynamic_tree_space_callbacks(del, get_child, serialize, deserialize) + delwrapper, get_childwrapper, serializewrapper, deserializewrapper = wrappers vector = DynamicTreeSpaceVector::new vector[:del] = delwrapper vector[:get_child] = get_childwrapper vector[:serialize] = serializewrapper vector[:deserialize] = deserializewrapper res = super(format: format, handle_map: handle_map, vector: vector.to_ptr, data: tree_space_data, path: path, buffer: buffer, file_descriptor: file_descriptor, callback: callback, callback_data: callback_data) - CCS.register_vector(res.handle, [delwrapper, get_childwrapper, serializewrapper, deserializewrapper, tree_space_data]) + CCS.register_vector(res.handle, wrappers) + FFI.inc_ref(tree_space_data) unless tree_space_data.nil? res end diff --git a/bindings/ruby/lib/cconfigspace/tuner.rb b/bindings/ruby/lib/cconfigspace/tuner.rb index c9dd8983..48b2f41d 100644 --- a/bindings/ruby/lib/cconfigspace/tuner.rb +++ b/bindings/ruby/lib/cconfigspace/tuner.rb @@ -32,7 +32,7 @@ def self.from_handle(handle, retain: true, auto_release: true) CCS.error_check CCS.ccs_tuner_get_type(handle, ptr) case ptr.read_ccs_tuner_type_t when :CCS_TUNER_TYPE_RANDOM - RandomTuner + RandomTuner when :CCS_TUNER_TYPE_USER_DEFINED UserDefinedTuner else @@ -139,7 +139,10 @@ class UserDefinedTunerVector < FFI::Struct def self.wrap_user_defined_tuner_callbacks(del, ask, tell, get_optima, get_history, suggest, serialize, deserialize) delwrapper = lambda { |tun| begin - del.call(CCS::Object.from_handle(tun)) if del + o = CCS::Object.from_handle(tun) + tdata = o.tuner_data + del.call(o) if del + FFI.dec_ref(tdata) unless tdata.nil? CCS.unregister_vector(tun) CCSError.to_native(:CCS_RESULT_SUCCESS) rescue => e @@ -269,9 +272,9 @@ def initialize(handle = nil, retain: false, auto_release: true, if handle super(handle, retain: retain, auto_release: auto_release) else - raise CCSError, :CCS_RESULT_ERROR_INVALID_VALUE if ask.nil? || tell.nil? || get_optima.nil? || get_history.nil? - delwrapper, askwrapper, tellwrapper, get_optimawrapper, get_historywrapper, suggestwrapper, serializewrapper, deserializewrapper = - CCS.wrap_user_defined_tuner_callbacks(del, ask, tell, get_optima, get_history, suggest, serialize, deserialize) + raise CCSError, :CCS_RESULT_ERROR_INVALID_VALUE if ask.nil? || tell.nil? || get_optima.nil? || get_history.nil? + wrappers = CCS.wrap_user_defined_tuner_callbacks(del, ask, tell, get_optima, get_history, suggest, serialize, deserialize) + delwrapper, askwrapper, tellwrapper, get_optimawrapper, get_historywrapper, suggestwrapper, serializewrapper, deserializewrapper = wrappers vector = UserDefinedTunerVector::new vector[:del] = delwrapper vector[:ask] = askwrapper @@ -285,14 +288,15 @@ def initialize(handle = nil, retain: false, auto_release: true, CCS.error_check CCS.ccs_create_user_defined_tuner(name, objective_space, vector, tuner_data, ptr) handle = ptr.read_ccs_tuner_t super(handle, retain: false) - CCS.register_vector(handle, [delwrapper, askwrapper, tellwrapper, get_optimawrapper, get_historywrapper, suggestwrapper, serializewrapper, deserializewrapper, tuner_data]) + CCS.register_vector(handle, wrappers) + FFI.inc_ref(tuner_data) unless tuner_data.nil? end end def self.deserialize(del: nil, ask: nil, tell: nil, get_optima: nil, get_history: nil, suggest: nil, serialize: nil, deserialize: nil, tuner_data: nil, format: :binary, handle_map: nil, path: nil, buffer: nil, file_descriptor: nil, callback: nil, callback_data: nil) raise CCSError, :CCS_RESULT_ERROR_INVALID_VALUE if ask.nil? || tell.nil? || get_optima.nil? || get_history.nil? - delwrapper, askwrapper, tellwrapper, get_optimawrapper, get_historywrapper, suggestwrapper, serializewrapper, deserializewrapper = - CCS.wrap_user_defined_tuner_callbacks(del, ask, tell, get_optima, get_history, suggest, serialize, deserialize) + wrappers = CCS.wrap_user_defined_tuner_callbacks(del, ask, tell, get_optima, get_history, suggest, serialize, deserialize) + delwrapper, askwrapper, tellwrapper, get_optimawrapper, get_historywrapper, suggestwrapper, serializewrapper, deserializewrapper = wrappers vector = UserDefinedTunerVector::new vector[:del] = delwrapper vector[:ask] = askwrapper @@ -303,7 +307,8 @@ def self.deserialize(del: nil, ask: nil, tell: nil, get_optima: nil, get_history vector[:serialize] = serializewrapper vector[:deserialize] = deserializewrapper res = super(format: format, handle_map: handle_map, vector: vector.to_ptr, data: tuner_data, path: path, buffer: buffer, file_descriptor: file_descriptor, callback: callback, callback_data: callback_data) - CCS.register_vector(res.handle, [delwrapper, askwrapper, tellwrapper, get_optimawrapper, get_historywrapper, suggestwrapper, serializewrapper, deserializewrapper, tuner_data]) + CCS.register_vector(res.handle, wrappers) + FFI.inc_ref(tuner_data) unless tuner_data.nil? res end end