From adfa73d3f58e1f1036c0351de67bdcdf150cca7b Mon Sep 17 00:00:00 2001 From: Brice Videau Date: Fri, 12 Jul 2024 18:52:29 -0500 Subject: [PATCH] WIP: use vector dispatcher callback for user defined objects. Working python bindings. --- bindings/python/cconfigspace/base.py | 53 ++- bindings/python/cconfigspace/tree_space.py | 191 +++++----- bindings/python/cconfigspace/tuner.py | 395 ++++++++++---------- bindings/python/test/test_features_tuner.py | 5 +- bindings/python/test/test_tree_space.py | 5 +- bindings/python/test/test_tree_tuner.py | 5 +- bindings/python/test/test_tuner.py | 9 +- include/cconfigspace/base.h | 53 ++- src/cconfigspace_deserialize.h | 21 +- src/cconfigspace_internal.h | 14 +- src/tree_space_deserialize.h | 35 +- src/tuner_deserialize.h | 21 +- tests/test_categorical_parameter.c | 2 +- tests/test_dynamic_tree_space.c | 29 +- tests/test_user_defined_features_tuner.c | 29 +- tests/test_user_defined_tree_tuner.c | 29 +- tests/test_user_defined_tuner.c | 29 +- 17 files changed, 527 insertions(+), 398 deletions(-) diff --git a/bindings/python/cconfigspace/base.py b/bindings/python/cconfigspace/base.py index d25593f7..b95e8423 100644 --- a/bindings/python/cconfigspace/base.py +++ b/bindings/python/cconfigspace/base.py @@ -395,10 +395,9 @@ class DeserializeOption(CEnumeration): _members_ = [ ('END', 0), 'HANDLE_MAP', - 'VECTOR', - 'DATA', + 'VECTOR_CALLBACK', 'NON_BLOCKING', - 'CALLBACK' + 'DATA_CALLBACK' ] def _ccs_get_function(method, argtypes = [], restype = Result): @@ -422,7 +421,8 @@ def _ccs_get_function(method, argtypes = [], restype = Result): ccs_object_get_user_data = _ccs_get_function("ccs_object_get_user_data", [ccs_object, ct.POINTER(ct.c_void_p)]) ccs_object_serialize_callback_type = ct.CFUNCTYPE(Result, ccs_object, ct.c_size_t, ct.c_void_p, ct.POINTER(ct.c_size_t), ct.c_void_p) ccs_object_set_serialize_callback = _ccs_get_function("ccs_object_set_serialize_callback", [ccs_object, ccs_object_serialize_callback_type, ct.c_void_p]) -ccs_object_deserialize_callback_type = ct.CFUNCTYPE(Result, ccs_object, ct.c_size_t, ct.c_void_p, ct.c_void_p) +ccs_object_deserialize_data_callback_type = ct.CFUNCTYPE(Result, ccs_object, ct.c_size_t, ct.c_void_p, ct.c_void_p) +ccs_object_deserialize_vector_callback_type = ct.CFUNCTYPE(Result, ct.c_int, ct.c_char_p, ct.c_void_p, ct.POINTER(ct.c_void_p), ct.POINTER(ct.py_object)) # Variadic methods ccs_object_serialize = getattr(libcconfigspace, "ccs_object_serialize") ccs_object_serialize.argtypes = ccs_object, SerializeFormat, SerializeOperation, @@ -556,7 +556,7 @@ def serialize(self, format = 'binary', path = None, file_descriptor = None, call return v @classmethod - def deserialize(cls, format = 'binary', handle_map = None, vector = None, data = None, path = None, buffer = None, file_descriptor = None, callback = None, callback_data = None): + def deserialize(cls, format = 'binary', handle_map = None, vector_callback = None, vector_callback_data = None, path = None, buffer = None, file_descriptor = None, callback = None, callback_data = None): if format != 'binary': raise Error(Result(Result.ERROR_INVALID_VALUE)) mode_count = 0; @@ -572,16 +572,16 @@ def deserialize(cls, format = 'binary', handle_map = None, vector = None, data = options = [DeserializeOption.END] if handle_map: options = [DeserializeOption.HANDLE_MAP, handle_map.handle] + options - if vector: - options = [DeserializeOption.VECTOR, ct.byref(vector)] + options - if data: - options = [DeserializeOption.DATA, ct.py_object(data)] + options + if vector_callback: + vector_cb_wrapper = _get_deserialize_vector_callback_wrapper(vector_callback) + vector_cb_wrapper_func = ccs_object_deserialize_vector_callback_type(vector_cb_wrapper) + options = [DeserializeOption.VECTOR_CALLBACK, vector_cb_wrapper_func, ct.py_object(vector_callback_data)] + options if callback: - cb_wrapper = _get_deserialize_callback_wrapper(callback) - cb_wrapper_func = ccs_object_deserialize_callback_type(cb_wrapper) - options = [DeserializeOption.CALLBACK, cb_wrapper_func, ct.py_object(callback_data)] + options + cb_wrapper = _get_deserialize_data_callback_wrapper(callback) + cb_wrapper_func = ccs_object_deserialize_data_callback_type(cb_wrapper) + options = [DeserializeOption.DATA_CALLBACK, cb_wrapper_func, ct.py_object(callback_data)] + options elif _default_user_data_deserializer: - options = [DeserializeOption.CALLBACK, _default_user_data_deserializer, ct.py_object()] + options + options = [DeserializeOption.DATA_CALLBACK, _default_user_data_deserializer, ct.py_object()] + options if buffer: s = ct.c_size_t(ct.sizeof(buffer)) res = ccs_object_deserialize(ct.byref(o), SerializeFormat.BINARY, SerializeOperation.MEMORY, s, buffer, *options) @@ -621,8 +621,8 @@ def serialize_callback_wrapper(obj, serialize_data_size, serialize_data, seriali return Error.set_error(e) return serialize_callback_wrapper -def _get_deserialize_callback_wrapper(callback): - def deserialize_callback_wrapper(obj, serialize_data_size, serialize_data, cb_data): +def _get_deserialize_data_callback_wrapper(callback): + def deserialize_data_callback_wrapper(obj, serialize_data_size, serialize_data, cb_data): try: p_sd = ct.cast(serialize_data, ct.c_void_p) cb_data = ct.cast(cb_data, ct.c_void_p) @@ -638,7 +638,24 @@ def deserialize_callback_wrapper(obj, serialize_data_size, serialize_data, cb_da return Result.SUCCESS except Exception as e: return Error.set_error(e) - return deserialize_callback_wrapper + return deserialize_data_callback_wrapper + +def _get_deserialize_vector_callback_wrapper(callback): + def deserialize_vector_callback_wrapper(obj_type, name, callback_user_data, vector_ret, data_ret): + try: + cb_data = ct.cast(callback_user_data, ct.py_object).value if callback_user_data else None + o_type = ObjectType(obj_type) + (vector, data) = callback(o_type, name, cb_data) + c_vector = ct.py_object(vector) + c_data = ct.py_object(data) + vector_ret[0] = ct.cast(ct.byref(vector), ct.c_void_p) + data_ret[0] = c_data + ct.pythonapi.Py_IncRef(c_vector) + ct.pythonapi.Py_IncRef(c_data) + return Result.SUCCESS + except Exception as e: + return Error.set_error(e) + return deserialize_vector_callback_wrapper def _json_user_data_serializer(obj, data, size): string = json.dumps(obj.user_data).encode("ascii") @@ -651,8 +668,8 @@ def _json_user_data_deserializer(obj, serialized, data): _json_user_data_serializer_wrap = _get_serialize_callback_wrapper(_json_user_data_serializer) _json_user_data_serializer_func = ccs_object_serialize_callback_type(_json_user_data_serializer_wrap) -_json_user_data_deserializer_wrap = _get_deserialize_callback_wrapper(_json_user_data_deserializer) -_json_user_data_deserializer_func = ccs_object_deserialize_callback_type(_json_user_data_deserializer_wrap) +_json_user_data_deserializer_wrap = _get_deserialize_data_callback_wrapper(_json_user_data_deserializer) +_json_user_data_deserializer_func = ccs_object_deserialize_data_callback_type(_json_user_data_deserializer_wrap) _default_user_data_serializer = _json_user_data_serializer_func _default_user_data_deserializer = _json_user_data_deserializer_func diff --git a/bindings/python/cconfigspace/tree_space.py b/bindings/python/cconfigspace/tree_space.py index 9d737d34..43297068 100644 --- a/bindings/python/cconfigspace/tree_space.py +++ b/bindings/python/cconfigspace/tree_space.py @@ -178,93 +178,6 @@ class DynamicTreeSpaceVector(ct.Structure): ccs_create_dynamic_tree_space = _ccs_get_function("ccs_create_dynamic_tree_space", [ct.c_char_p, ccs_tree, ccs_feature_space, ccs_rng, ct.POINTER(DynamicTreeSpaceVector), ct.py_object, ct.POINTER(ccs_tree_space)]) ccs_dynamic_tree_space_get_tree_space_data = _ccs_get_function("ccs_dynamic_tree_space_get_tree_space_data", [ccs_tree_space, ct.POINTER(ct.c_void_p)]) -def _wrap_user_defined_callbacks(delete, get_child, serialize, deserialize): - vec = DynamicTreeSpaceVector() - def delete_wrapper(ts): - try: - ts = ct.cast(ts, ccs_tree_space) - o = Object.from_handle(ts) - tsdata = o.tree_space_data - if delete is not None: - delete(o) - if tsdata is not None: - ct.pythonapi.Py_DecRef(ct.py_object(tsdata)) - ct.pythonapi.Py_DecRef(ct.py_object(vec)) - return Result.SUCCESS - except Exception as e: - return Error.set_error(e) - - def get_child_wrapper(ts, parent, index, p_child): - try: - ts = ct.cast(ts, ccs_tree_space) - parent = ct.cast(parent, ccs_tree) - child = get_child(TreeSpace.from_handle(ts), Tree.from_handle(parent), index) - res = ccs_retain_object(child.handle) - Error.check(res) - p_child[0] = child.handle.value - return Result.SUCCESS - except Exception as e: - return Error.set_error(e) - - if serialize is not None: - def serialize_wrapper(ts, state_size, p_state, p_state_size): - try: - ts = ct.cast(ts, ccs_tree_space) - p_s = ct.cast(p_state, ct.c_void_p) - p_sz = ct.cast(p_state_size, ct.c_void_p) - state = serialize(TreeSpace.from_handle(ts), True if state_size == 0 else False) - if p_s.value is not None and state_size < ct.sizeof(state): - raise Error(Result(Result.ERROR_INVALID_VALUE)) - if p_s.value is not None: - ct.memmove(p_s, ct.byref(state), ct.sizeof(state)) - if p_sz.value is not None: - p_state_size[0] = ct.sizeof(state) - return Result.SUCCESS - except Exception as e: - return Error.set_error(e) - else: - serialize_wrapper = 0 - - if deserialize is not None: - def deserialize_wrapper(tree, feature_space, state_size, p_state, p_tree_space_data): - try: - t = ct.cast(tree, ccs_tree) - p_s = ct.cast(p_state, ct.c_void_p) - p_t = ct.cast(p_tree_space_data, ct.c_void_p) - if p_s.value is None: - state = None - else: - state = ct.cast(p_s, POINTER(c_byte * state_size)) - tree_space_data = deserialize(Tree.from_handle(t), FeatureSpace.from_handle(feature_space) if feature_space else None, state) - c_tree_space_data = ct.py_object(tree_space_data) - p_t[0] = c_tree_space_data - ct.pythonapi.Py_IncRef(c_tree_space_data) - return Result.SUCCESS - except Exception as e: - return Error.set_error(e) - else: - deserialize_wrapper = 0 - - delete_wrapper_func = ccs_dynamic_tree_space_del_type(delete_wrapper) - get_child_wrapper_func = ccs_dynamic_tree_space_get_child_type(get_child_wrapper) - serialize_wrapper_func = ccs_dynamic_tree_space_serialize_type(serialize_wrapper) - deserialize_wrapper_func = ccs_dynamic_tree_space_deserialize_type(deserialize_wrapper) - vec.delete = delete_wrapper_func - vec.get_child = get_child_wrapper_func - vec.serialize = serialize_wrapper_func - vec.deserialize = deserialize_wrapper_func - - setattr(vec, '_wrappers', ( - delete_wrapper, - get_child_wrapper, - serialize_wrapper, - deserialize_wrapper, - delete_wrapper_func, - get_child_wrapper_func, - serialize_wrapper_func, - deserialize_wrapper_func)) - return vec - class DynamicTreeSpace(TreeSpace): def __init__(self, handle = None, retain = False, auto_release = True, @@ -273,7 +186,7 @@ def __init__(self, handle = None, retain = False, auto_release = True, if get_child is None: raise Error(Result(Result.ERROR_INVALID_VALUE)) - vec = _wrap_user_defined_callbacks(delete, get_child, serialize, deserialize) + vec = self.get_vector(delete, get_child, serialize, deserialize) if tree_space_data is not None: c_tree_space_data = ct.py_object(tree_space_data) else: @@ -292,18 +205,6 @@ def __init__(self, handle = None, retain = False, auto_release = True, else: super().__init__(handle = handle, retain = retain, auto_release = auto_release) - @classmethod - def deserialize(cls, delete, get_child, serialize = None, deserialize = None, tree_space_data = None, format = 'binary', handle_map = None, path = None, buffer = None, file_descriptor = None, callback = None, callback_data = None): - if get_child is None: - raise Error(Result(Result.ERROR_INVALID_VALUE)) - - vec = _wrap_user_defined_callbacks(delete, get_child, serialize, deserialize) - res = super().deserialize(format = format, handle_map = handle_map, vector = vec, data = tree_space_data, path = path, buffer = buffer, file_descriptor = file_descriptor, callback = callback, callback_data = callback_data) - ct.pythonapi.Py_IncRef(ct.py_object(vec)) - if tree_space_data is not None: - ct.pythonapi.Py_IncRef(ct.py_object(tree_space_data)) - return res - @property def tree_space_data(self): if hasattr(self, "_tree_space_data"): @@ -317,6 +218,96 @@ def tree_space_data(self): self._tree_space_data = None return self._tree_space_data + @classmethod + def get_vector(self, delete = None, get_child = None, serialize = None, deserialize = None): + if get_child is None: + raise Error(Result(Result.ERROR_INVALID_VALUE)) + vec = DynamicTreeSpaceVector() + def delete_wrapper(ts): + try: + ts = ct.cast(ts, ccs_tree_space) + o = Object.from_handle(ts) + tsdata = o.tree_space_data + if delete is not None: + delete(o) + if tsdata is not None: + ct.pythonapi.Py_DecRef(ct.py_object(tsdata)) + ct.pythonapi.Py_DecRef(ct.py_object(vec)) + return Result.SUCCESS + except Exception as e: + return Error.set_error(e) + + def get_child_wrapper(ts, parent, index, p_child): + try: + ts = ct.cast(ts, ccs_tree_space) + parent = ct.cast(parent, ccs_tree) + child = get_child(TreeSpace.from_handle(ts), Tree.from_handle(parent), index) + res = ccs_retain_object(child.handle) + Error.check(res) + p_child[0] = child.handle.value + return Result.SUCCESS + except Exception as e: + return Error.set_error(e) + + if serialize is not None: + def serialize_wrapper(ts, state_size, p_state, p_state_size): + try: + ts = ct.cast(ts, ccs_tree_space) + p_s = ct.cast(p_state, ct.c_void_p) + p_sz = ct.cast(p_state_size, ct.c_void_p) + state = serialize(TreeSpace.from_handle(ts), True if state_size == 0 else False) + if p_s.value is not None and state_size < ct.sizeof(state): + raise Error(Result(Result.ERROR_INVALID_VALUE)) + if p_s.value is not None: + ct.memmove(p_s, ct.byref(state), ct.sizeof(state)) + if p_sz.value is not None: + p_state_size[0] = ct.sizeof(state) + return Result.SUCCESS + except Exception as e: + return Error.set_error(e) + else: + serialize_wrapper = 0 + + if deserialize is not None: + def deserialize_wrapper(tree, feature_space, state_size, p_state, p_tree_space_data): + try: + t = ct.cast(tree, ccs_tree) + p_s = ct.cast(p_state, ct.c_void_p) + p_t = ct.cast(p_tree_space_data, ct.c_void_p) + if p_s.value is None: + state = None + else: + state = ct.cast(p_s, POINTER(c_byte * state_size)) + tree_space_data = deserialize(Tree.from_handle(t), FeatureSpace.from_handle(feature_space) if feature_space else None, state) + c_tree_space_data = ct.py_object(tree_space_data) + p_t[0] = c_tree_space_data + ct.pythonapi.Py_IncRef(c_tree_space_data) + return Result.SUCCESS + except Exception as e: + return Error.set_error(e) + else: + deserialize_wrapper = 0 + + delete_wrapper_func = ccs_dynamic_tree_space_del_type(delete_wrapper) + get_child_wrapper_func = ccs_dynamic_tree_space_get_child_type(get_child_wrapper) + serialize_wrapper_func = ccs_dynamic_tree_space_serialize_type(serialize_wrapper) + deserialize_wrapper_func = ccs_dynamic_tree_space_deserialize_type(deserialize_wrapper) + vec.delete = delete_wrapper_func + vec.get_child = get_child_wrapper_func + vec.serialize = serialize_wrapper_func + vec.deserialize = deserialize_wrapper_func + + setattr(vec, '_wrappers', ( + delete_wrapper, + get_child_wrapper, + serialize_wrapper, + deserialize_wrapper, + delete_wrapper_func, + get_child_wrapper_func, + serialize_wrapper_func, + deserialize_wrapper_func)) + return vec + TreeSpace.Dynamic = DynamicTreeSpace from .tree_configuration import TreeConfiguration diff --git a/bindings/python/cconfigspace/tuner.py b/bindings/python/cconfigspace/tuner.py index 55881bb4..fe12930a 100644 --- a/bindings/python/cconfigspace/tuner.py +++ b/bindings/python/cconfigspace/tuner.py @@ -184,197 +184,6 @@ class UserDefinedTunerVector(ct.Structure): ccs_create_user_defined_tuner = _ccs_get_function("ccs_create_user_defined_tuner", [ct.c_char_p, ccs_objective_space, ct.POINTER(UserDefinedTunerVector), ct.py_object, ct.POINTER(ccs_tuner)]) ccs_user_defined_tuner_get_tuner_data = _ccs_get_function("ccs_user_defined_tuner_get_tuner_data", [ccs_tuner, ct.POINTER(ct.c_void_p)]) -def _wrap_user_defined_tuner_callbacks(delete, ask, tell, get_optima, get_history, suggest, serialize, deserialize): - vec = UserDefinedTunerVector() - def delete_wrapper(tun): - try: - tun = ct.cast(tun, ccs_tuner) - o = Object.from_handle(tun) - tdata = o.tuner_data - if delete is not None: - delete(o) - if tdata is not None: - ct.pythonapi.Py_DecRef(ct.py_object(tdata)) - ct.pythonapi.Py_DecRef(ct.py_object(vec)) - return Result.SUCCESS - except Exception as e: - return Error.set_error(e) - - def ask_wrapper(tun, features, count, p_configurations, p_count): - try: - tun = ct.cast(tun, ccs_tuner) - p_confs = ct.cast(p_configurations, ct.c_void_p) - p_c = ct.cast(p_count, ct.c_void_p) - (configurations, count_ret) = ask(Tuner.from_handle(tun), Features.from_handle(features) if features else None, count if p_confs.value else None) - if p_confs.value is not None and count < count_ret: - raise Error(Result(Result.ERROR_INVALID_VALUE)) - if p_confs.value is not None: - for i in range(len(configurations)): - res = ccs_retain_object(configurations[i].handle) - Error.check(res) - p_configurations[i] = configurations[i].handle.value - for i in range(len(configurations), count): - p_configurations[i] = None - if p_c.value is not None: - p_count[0] = count_ret - return Result.SUCCESS - except Exception as e: - return Error.set_error(e) - - def tell_wrapper(tun, count, p_evaluations): - try: - tun = ct.cast(tun, ccs_tuner) - if count == 0: - return Result.SUCCESS - p_evals = ct.cast(p_evaluations, ct.c_void_p) - if p_evals.value is None: - raise Error(Result(Result.ERROR_INVALID_VALUE)) - evals = [Evaluation.from_handle(ccs_evaluation(p_evaluations[i])) for i in range(count)] - tell(Tuner.from_handle(tun), evals) - return Result.SUCCESS - except Exception as e: - return Error.set_error(e) - - def get_optima_wrapper(tun, features, count, p_evaluations, p_count): - try: - tun = ct.cast(tun, ccs_tuner) - p_evals = ct.cast(p_evaluations, ct.c_void_p) - p_c = ct.cast(p_count, ct.c_void_p) - optima = get_optima(Tuner.from_handle(tun), Features.from_handle(features) if features else None) - count_ret = len(optima) - if p_evals.value is not None and count < count_ret: - raise Error(Result(Result.ERROR_INVALID_VALUE)) - if p_evals.value is not None: - for i in range(count_ret): - p_evaluations[i] = optima[i].handle.value - for i in range(count_ret, count): - p_evaluations[i] = None - if p_c.value is not None: - p_count[0] = count_ret - return Result.SUCCESS - except Exception as e: - return Error.set_error(e) - - def get_history_wrapper(tun, features, count, p_evaluations, p_count): - try: - tun = ct.cast(tun, ccs_tuner) - p_evals = ct.cast(p_evaluations, ct.c_void_p) - p_c = ct.cast(p_count, ct.c_void_p) - history = get_history(Tuner.from_handle(tun), Features.from_handle(features) if features else None) - count_ret = (len(history) if history else 0) - if p_evals.value is not None and count < count_ret: - raise Error(Result(Result.ERROR_INVALID_VALUE)) - if p_evals.value is not None: - for i in range(count_ret): - p_evaluations[i] = history[i].handle.value - for i in range(count_ret, count): - p_evaluations[i] = None - if p_c.value is not None: - p_count[0] = count_ret - return Result.SUCCESS - except Exception as e: - return Error.set_error(e) - - if suggest is not None: - def suggest_wrapper(tun, features, p_configuration): - try: - tun = ct.cast(tun, ccs_tuner) - configuration = suggest(Tuner.from_handle(tun), Features.from_handle(features) if features else None) - res = ccs_retain_object(configuration.handle) - Error.check(res) - p_configuration[0] = configuration.handle.value - return Result.SUCCESS - except Exception as e: - return Error.set_error(e) - else: - suggest_wrapper = 0 - - if serialize is not None: - def serialize_wrapper(tun, state_size, p_state, p_state_size): - try: - tun = ct.cast(tun, ccs_tuner) - p_s = ct.cast(p_state, ct.c_void_p) - p_sz = ct.cast(p_state_size, ct.c_void_p) - state = serialize(Tuner.from_handle(tun), True if state_size == 0 else False) - if p_s.value is not None and state_size < ct.sizeof(state): - raise Error(Result(Result.ERROR_INVALID_VALUE)) - if p_s.value is not None: - ct.memmove(p_s, ct.byref(state), ct.sizeof(state)) - if p_sz.value is not None: - p_state_size[0] = ct.sizeof(state) - return Result.SUCCESS - except Exception as e: - return Error.set_error(e) - else: - serialize_wrapper = 0 - - if deserialize is not None: - def deserialize_wrapper(o_space, size_history, p_history, num_optima, p_optima, state_size, p_state, p_tuner_data): - try: - o_space = ct.cast(o_space, ccs_objective_space) - p_h = ct.cast(p_history, ct.c_void_p) - p_o = ct.cast(p_optima, ct.c_void_p) - p_s = ct.cast(p_state, ct.c_void_p) - p_t = ct.cast(p_tuner_data, ct.c_void_p) - if p_h.value is None: - history = [] - else: - history = [Evaluation.from_handle(ccs_evaluation(p_h[i])) for i in range(size_history)] - if p_o.value is None: - optima = [] - else: - optima = [Evaluation.from_handle(ccs_evaluation(p_o[i])) for i in range(num_optima)] - if p_s.value is None: - state = None - else: - state = ct.cast(p_s, POINTER(c_byte * state_size)) - tuner_data = deserialize(ObjectiveSpace.from_handle(o_space), history, optima, state) - c_tuner_data = ct.py_object(tuner_data) - p_t[0] = c_tuner_data - ct.pythonapi.Py_IncRef(c_tuner_data) - return Result.SUCCESS - except Exception as e: - return Error.set_error(e) - else: - deserialize_wrapper = 0 - - delete_wrapper_func = ccs_user_defined_tuner_del_type(delete_wrapper) - ask_wrapper_func = ccs_user_defined_tuner_ask_type(ask_wrapper) - tell_wrapper_func = ccs_user_defined_tuner_tell_type(tell_wrapper) - get_optima_wrapper_func = ccs_user_defined_tuner_get_optima_type(get_optima_wrapper) - get_history_wrapper_func = ccs_user_defined_tuner_get_history_type(get_history_wrapper) - suggest_wrapper_func = ccs_user_defined_tuner_suggest_type(suggest_wrapper) - serialize_wrapper_func = ccs_user_defined_tuner_serialize_type(serialize_wrapper) - deserialize_wrapper_func = ccs_user_defined_tuner_deserialize_type(deserialize_wrapper) - vec.delete = delete_wrapper_func - vec.ask = ask_wrapper_func - vec.tell = tell_wrapper_func - vec.get_optima = get_optima_wrapper_func - vec.get_history = get_history_wrapper_func - vec.suggest = suggest_wrapper_func - vec.serialize = serialize_wrapper_func - vec.deserialize = deserialize_wrapper_func - - setattr(vec, '_wrappers', ( - delete_wrapper, - ask_wrapper, - tell_wrapper, - get_optima_wrapper, - get_history_wrapper, - suggest_wrapper, - serialize_wrapper, - deserialize_wrapper, - delete_wrapper_func, - ask_wrapper_func, - tell_wrapper_func, - get_optima_wrapper_func, - get_history_wrapper_func, - suggest_wrapper_func, - serialize_wrapper_func, - deserialize_wrapper_func)) - return vec - - class UserDefinedTuner(Tuner): def __init__(self, handle = None, retain = False, auto_release = True, name = "", objective_space = None, delete = None, ask = None, tell = None, get_optima = None, get_history = None, suggest = None, serialize = None, deserialize = None, tuner_data = None ): @@ -382,7 +191,7 @@ def __init__(self, handle = None, retain = False, auto_release = True, if ask is None or tell is None or get_optima is None or get_history is None: raise Error(Result(Result.ERROR_INVALID_VALUE)) - vec = _wrap_user_defined_tuner_callbacks(delete, ask, tell, get_optima, get_history, suggest, serialize, deserialize) + vec = self.get_vector(delete, ask, tell, get_optima, get_history, suggest, serialize, deserialize) if tuner_data is not None: c_tuner_data = ct.py_object(tuner_data) else: @@ -397,17 +206,6 @@ def __init__(self, handle = None, retain = False, auto_release = True, else: super().__init__(handle = handle, retain = retain, auto_release = auto_release) - @classmethod - def deserialize(cls, delete, ask, tell, get_optima, get_history, suggest = None, serialize = None, deserialize = None, tuner_data = None, format = 'binary', handle_map = None, path = None, buffer = None, file_descriptor = None, callback = None, callback_data = None): - if ask is None or tell is None or get_optima is None or get_history is None: - raise Error(Result(Result.ERROR_INVALID_VALUE)) - vec = _wrap_user_defined_tuner_callbacks(delete, ask, tell, get_optima, get_history, suggest, serialize, deserialize) - res = super().deserialize(format = format, handle_map = handle_map, vector = vec, data = tuner_data, path = path, buffer = buffer, file_descriptor = file_descriptor, callback = callback, callback_data = callback_data) - ct.pythonapi.Py_IncRef(ct.py_object(vec)) - if tuner_data is not None: - ct.pythonapi.Py_IncRef(ct.py_object(tuner_data)) - return res - @property def tuner_data(self): if hasattr(self, "_tuner_data"): @@ -421,4 +219,195 @@ def tuner_data(self): self._tuner_data = None return self._tuner_data + @classmethod + def get_vector(self, delete = None, ask = None, tell = None, get_optima = None, get_history = None, suggest = None, serialize = None, deserialize = None): + vec = UserDefinedTunerVector() + def delete_wrapper(tun): + try: + tun = ct.cast(tun, ccs_tuner) + o = Object.from_handle(tun) + tdata = o.tuner_data + if delete is not None: + delete(o) + if tdata is not None: + ct.pythonapi.Py_DecRef(ct.py_object(tdata)) + ct.pythonapi.Py_DecRef(ct.py_object(vec)) + return Result.SUCCESS + except Exception as e: + return Error.set_error(e) + + def ask_wrapper(tun, features, count, p_configurations, p_count): + try: + tun = ct.cast(tun, ccs_tuner) + p_confs = ct.cast(p_configurations, ct.c_void_p) + p_c = ct.cast(p_count, ct.c_void_p) + (configurations, count_ret) = ask(Tuner.from_handle(tun), Features.from_handle(features) if features else None, count if p_confs.value else None) + if p_confs.value is not None and count < count_ret: + raise Error(Result(Result.ERROR_INVALID_VALUE)) + if p_confs.value is not None: + for i in range(len(configurations)): + res = ccs_retain_object(configurations[i].handle) + Error.check(res) + p_configurations[i] = configurations[i].handle.value + for i in range(len(configurations), count): + p_configurations[i] = None + if p_c.value is not None: + p_count[0] = count_ret + return Result.SUCCESS + except Exception as e: + return Error.set_error(e) + + def tell_wrapper(tun, count, p_evaluations): + try: + tun = ct.cast(tun, ccs_tuner) + if count == 0: + return Result.SUCCESS + p_evals = ct.cast(p_evaluations, ct.c_void_p) + if p_evals.value is None: + raise Error(Result(Result.ERROR_INVALID_VALUE)) + evals = [Evaluation.from_handle(ccs_evaluation(p_evaluations[i])) for i in range(count)] + tell(Tuner.from_handle(tun), evals) + return Result.SUCCESS + except Exception as e: + return Error.set_error(e) + + def get_optima_wrapper(tun, features, count, p_evaluations, p_count): + try: + tun = ct.cast(tun, ccs_tuner) + p_evals = ct.cast(p_evaluations, ct.c_void_p) + p_c = ct.cast(p_count, ct.c_void_p) + optima = get_optima(Tuner.from_handle(tun), Features.from_handle(features) if features else None) + count_ret = len(optima) + if p_evals.value is not None and count < count_ret: + raise Error(Result(Result.ERROR_INVALID_VALUE)) + if p_evals.value is not None: + for i in range(count_ret): + p_evaluations[i] = optima[i].handle.value + for i in range(count_ret, count): + p_evaluations[i] = None + if p_c.value is not None: + p_count[0] = count_ret + return Result.SUCCESS + except Exception as e: + return Error.set_error(e) + + def get_history_wrapper(tun, features, count, p_evaluations, p_count): + try: + tun = ct.cast(tun, ccs_tuner) + p_evals = ct.cast(p_evaluations, ct.c_void_p) + p_c = ct.cast(p_count, ct.c_void_p) + history = get_history(Tuner.from_handle(tun), Features.from_handle(features) if features else None) + count_ret = (len(history) if history else 0) + if p_evals.value is not None and count < count_ret: + raise Error(Result(Result.ERROR_INVALID_VALUE)) + if p_evals.value is not None: + for i in range(count_ret): + p_evaluations[i] = history[i].handle.value + for i in range(count_ret, count): + p_evaluations[i] = None + if p_c.value is not None: + p_count[0] = count_ret + return Result.SUCCESS + except Exception as e: + return Error.set_error(e) + + if suggest is not None: + def suggest_wrapper(tun, features, p_configuration): + try: + tun = ct.cast(tun, ccs_tuner) + configuration = suggest(Tuner.from_handle(tun), Features.from_handle(features) if features else None) + res = ccs_retain_object(configuration.handle) + Error.check(res) + p_configuration[0] = configuration.handle.value + return Result.SUCCESS + except Exception as e: + return Error.set_error(e) + else: + suggest_wrapper = 0 + + if serialize is not None: + def serialize_wrapper(tun, state_size, p_state, p_state_size): + try: + tun = ct.cast(tun, ccs_tuner) + p_s = ct.cast(p_state, ct.c_void_p) + p_sz = ct.cast(p_state_size, ct.c_void_p) + state = serialize(Tuner.from_handle(tun), True if state_size == 0 else False) + if p_s.value is not None and state_size < ct.sizeof(state): + raise Error(Result(Result.ERROR_INVALID_VALUE)) + if p_s.value is not None: + ct.memmove(p_s, ct.byref(state), ct.sizeof(state)) + if p_sz.value is not None: + p_state_size[0] = ct.sizeof(state) + return Result.SUCCESS + except Exception as e: + return Error.set_error(e) + else: + serialize_wrapper = 0 + + if deserialize is not None: + def deserialize_wrapper(o_space, size_history, p_history, num_optima, p_optima, state_size, p_state, p_tuner_data): + try: + o_space = ct.cast(o_space, ccs_objective_space) + p_h = ct.cast(p_history, ct.c_void_p) + p_o = ct.cast(p_optima, ct.c_void_p) + p_s = ct.cast(p_state, ct.c_void_p) + p_t = ct.cast(p_tuner_data, ct.c_void_p) + if p_h.value is None: + history = [] + else: + history = [Evaluation.from_handle(ccs_evaluation(p_h[i])) for i in range(size_history)] + if p_o.value is None: + optima = [] + else: + optima = [Evaluation.from_handle(ccs_evaluation(p_o[i])) for i in range(num_optima)] + if p_s.value is None: + state = None + else: + state = ct.cast(p_s, POINTER(c_byte * state_size)) + tuner_data = deserialize(ObjectiveSpace.from_handle(o_space), history, optima, state) + c_tuner_data = ct.py_object(tuner_data) + p_t[0] = c_tuner_data + ct.pythonapi.Py_IncRef(c_tuner_data) + return Result.SUCCESS + except Exception as e: + return Error.set_error(e) + else: + deserialize_wrapper = 0 + + delete_wrapper_func = ccs_user_defined_tuner_del_type(delete_wrapper) + ask_wrapper_func = ccs_user_defined_tuner_ask_type(ask_wrapper) + tell_wrapper_func = ccs_user_defined_tuner_tell_type(tell_wrapper) + get_optima_wrapper_func = ccs_user_defined_tuner_get_optima_type(get_optima_wrapper) + get_history_wrapper_func = ccs_user_defined_tuner_get_history_type(get_history_wrapper) + suggest_wrapper_func = ccs_user_defined_tuner_suggest_type(suggest_wrapper) + serialize_wrapper_func = ccs_user_defined_tuner_serialize_type(serialize_wrapper) + deserialize_wrapper_func = ccs_user_defined_tuner_deserialize_type(deserialize_wrapper) + vec.delete = delete_wrapper_func + vec.ask = ask_wrapper_func + vec.tell = tell_wrapper_func + vec.get_optima = get_optima_wrapper_func + vec.get_history = get_history_wrapper_func + vec.suggest = suggest_wrapper_func + vec.serialize = serialize_wrapper_func + vec.deserialize = deserialize_wrapper_func + + setattr(vec, '_wrappers', ( + delete_wrapper, + ask_wrapper, + tell_wrapper, + get_optima_wrapper, + get_history_wrapper, + suggest_wrapper, + serialize_wrapper, + deserialize_wrapper, + delete_wrapper_func, + ask_wrapper_func, + tell_wrapper_func, + get_optima_wrapper_func, + get_history_wrapper_func, + suggest_wrapper_func, + serialize_wrapper_func, + deserialize_wrapper_func)) + return vec + Tuner.UserDefined = UserDefinedTuner diff --git a/bindings/python/test/test_features_tuner.py b/bindings/python/test/test_features_tuner.py index cd6ce770..a9c5f49c 100644 --- a/bindings/python/test/test_features_tuner.py +++ b/bindings/python/test/test_features_tuner.py @@ -118,6 +118,9 @@ def suggest(tuner, features): else: return choice(optis).configuration + def get_vector_data(otype, name, cb_data): + return (ccs.UserDefinedTuner.get_vector(delete = delete, ask = ask, tell = tell, get_optima = get_optima, get_history = get_history, suggest = suggest), TunerData()) + (fs, os) = self.create_tuning_problem() t = ccs.UserDefinedTuner(name = "tuner", objective_space = os, delete = delete, ask = ask, tell = tell, get_optima = get_optima, get_history = get_history, suggest = suggest, tuner_data = TunerData()) t2 = ccs.Object.from_handle(t.handle) @@ -155,7 +158,7 @@ def suggest(tuner, features): self.assertTrue(t.suggest(features_off) in [x.configuration for x in optims]) # test serialization buff = t.serialize() - t_copy = ccs.UserDefinedTuner.deserialize(buffer = buff, delete = delete, ask = ask, tell = tell, get_optima = get_optima, get_history = get_history, suggest = suggest, tuner_data = TunerData()) + t_copy = ccs.UserDefinedTuner.deserialize(buffer = buff, vector_callback = get_vector_data) hist = t_copy.history() self.assertEqual(200, len(hist)) optims_2 = t_copy.optima() diff --git a/bindings/python/test/test_tree_space.py b/bindings/python/test/test_tree_space.py index 39f1ef4e..1ea927a3 100644 --- a/bindings/python/test/test_tree_space.py +++ b/bindings/python/test/test_tree_space.py @@ -54,6 +54,9 @@ def get_child(tree_space, parent, child_index): arity = 0 if arity < 0 else arity return ccs.Tree(arity = arity, value = (4 - child_depth)*100 + child_index) + def get_vector_data(otype, name, cb_data): + return (ccs.DynamicTreeSpace.get_vector(delete = delete, get_child = get_child), None) + tree = ccs.Tree(arity = 4, value = 400) ts = ccs.DynamicTreeSpace(name = 'space', tree = tree, delete = delete, get_child = get_child) self.assertEqual( ccs.ObjectType.TREE_SPACE, ts.object_type ) @@ -75,7 +78,7 @@ def get_child(tree_space, parent, child_index): self.assertTrue( ts.check_configuration(tc) ) buff = ts.serialize() - ts2 = ccs.DynamicTreeSpace.deserialize(buffer = buff, delete = delete, get_child = get_child) + ts2 = ccs.DynamicTreeSpace.deserialize(buffer = buff, vector_callback = get_vector_data) self.assertEqual( [400, 301, 201], ts2.get_values_at_position([1, 1]) ) def test_tree_configuration(self): diff --git a/bindings/python/test/test_tree_tuner.py b/bindings/python/test/test_tree_tuner.py index 9afe2a5a..1f45f221 100644 --- a/bindings/python/test/test_tree_tuner.py +++ b/bindings/python/test/test_tree_tuner.py @@ -107,6 +107,9 @@ def suggest(tuner, features): else: return choice(tuner.tuner_data.optima).configuration + def get_vector_data(otype, name, cb_data): + return (ccs.UserDefinedTuner.get_vector(delete = delete, ask = ask, tell = tell, get_optima = get_optima, get_history = get_history, suggest = suggest), TunerData()) + os = self.create_tuning_problem() t = ccs.UserDefinedTuner(name = "tuner", objective_space = os, delete = delete, ask = ask, tell = tell, get_optima = get_optima, get_history = get_history, suggest = suggest, tuner_data = TunerData()) t2 = ccs.Object.from_handle(t.handle) @@ -125,7 +128,7 @@ def suggest(tuner, features): self.assertTrue(all(best >= x.objective_values[0] for x in hist)) self.assertTrue(t.suggest() in [x.configuration for x in optims]) buff = t.serialize() - t_copy = ccs.UserDefinedTuner.deserialize(buffer = buff, delete = delete, ask = ask, tell = tell, get_optima = get_optima, get_history = get_history, suggest = suggest, tuner_data = TunerData()) + t_copy = ccs.UserDefinedTuner.deserialize(buffer = buff, vector_callback = get_vector_data) hist = t_copy.history() self.assertEqual(200, len(hist)) optims_2 = t_copy.optima() diff --git a/bindings/python/test/test_tuner.py b/bindings/python/test/test_tuner.py index 577ba5d6..32ad30a9 100644 --- a/bindings/python/test/test_tuner.py +++ b/bindings/python/test/test_tuner.py @@ -101,6 +101,9 @@ def suggest(tuner, features): else: return choice(tuner.tuner_data.optima).configuration + def get_vector_data(otype, name, cb_data): + return (ccs.UserDefinedTuner.get_vector(delete = delete, ask = ask, tell = tell, get_optima = get_optima, get_history = get_history, suggest = suggest), TunerData()) + os = self.create_tuning_problem() t = ccs.UserDefinedTuner(name = "tuner", objective_space = os, delete = delete, ask = ask, tell = tell, get_optima = get_optima, get_history = get_history, suggest = suggest, tuner_data = TunerData()) t2 = ccs.Object.from_handle(t.handle) @@ -126,7 +129,7 @@ def suggest(tuner, features): # test serialization buff = t.serialize() - t_copy = ccs.UserDefinedTuner.deserialize(buffer = buff, delete = delete, ask = ask, tell = tell, get_optima = get_optima, get_history = get_history, suggest = suggest, tuner_data = TunerData()) + t_copy = ccs.UserDefinedTuner.deserialize(buffer = buff, vector_callback = get_vector_data) hist = t_copy.history() self.assertEqual(200, len(hist)) optims_2 = t_copy.optima() @@ -137,7 +140,7 @@ def suggest(tuner, features): self.assertTrue(t_copy.suggest() in [x.configuration for x in optims_2]) t.serialize(path = 'tuner.ccs') - t_copy = ccs.UserDefinedTuner.deserialize(path = 'tuner.ccs', delete = delete, ask = ask, tell = tell, get_optima = get_optima, get_history = get_history, suggest = suggest, tuner_data = TunerData()) + t_copy = ccs.UserDefinedTuner.deserialize(path = 'tuner.ccs', vector_callback = get_vector_data) hist = t_copy.history() self.assertEqual(200, len(hist)) optims_2 = t_copy.optima() @@ -152,7 +155,7 @@ def suggest(tuner, features): t.serialize(file_descriptor = file.fileno()) file.close() file = open( 'tuner.ccs', "rb") - t_copy = ccs.UserDefinedTuner.deserialize(file_descriptor = file.fileno(), delete = delete, ask = ask, tell = tell, get_optima = get_optima, get_history = get_history, suggest = suggest, tuner_data = TunerData()) + t_copy = ccs.UserDefinedTuner.deserialize(file_descriptor = file.fileno(), vector_callback = get_vector_data) file.close() hist = t_copy.history() self.assertEqual(200, len(hist)) diff --git a/include/cconfigspace/base.h b/include/cconfigspace/base.h index 8580feab..98e60d07 100644 --- a/include/cconfigspace/base.h +++ b/include/cconfigspace/base.h @@ -965,9 +965,9 @@ enum ccs_serialize_option_e { typedef enum ccs_serialize_option_e ccs_serialize_option_t; /** - * The type of CCS object deserialization callbacks. - * This callback is used to deserialize object information that were created by - * the serialization callback. + * The type of CCS object data deserialization callbacks. This callback is + * used to deserialize object data information that were created by the + * serialization callback. * @param[in, out] object a CCS object * @param[in] serialize_data_size the size of the memory pointed to by * \p serialize_data @@ -981,12 +981,41 @@ typedef enum ccs_serialize_option_e ccs_serialize_option_t; * @remarks * This function must be thread-safe for serialization to be thread safe. */ -typedef ccs_result_t (*ccs_object_deserialize_callback_t)( +typedef ccs_result_t (*ccs_object_deserialize_data_callback_t)( ccs_object_t object, size_t serialize_data_size, const char *serialize_data, void *callback_user_data); +/** + * The type of CCS user object vector deserialization callbacks. + * This callback is used to obtain the vector to use for a user defined object. + * @param[in] type the type of CCS object for which the vector is required + * @param[in] name the name of the object as it was orignally provided at + * creation + * @param[in] callback_user_data the pointer provided when the callback + * was attached. + * @param[out] vector_ret a pointer that will hold the returned pointer to the + * vector structure + * @param[out] data_ret a pointer that will optionally hold the user data value + * to use when initializing the user object. + * If the provided vector holds the relevant + * deserialization callback, this value will be + * ignored. + * @return #CCS_RESULT_SUCCESS on success + * @return #CCS_RESULT_ERROR_INVALID_TYPE if the object type is not expected by + * the callback + * @return #CCS_RESULT_ERROR_INVALID_NAME if the object name is not recognized + * @remarks + * This function must be thread-safe for deserialization to be thread safe. + */ +typedef ccs_result_t (*ccs_object_deserialize_vector_callback_t)( + ccs_object_type_t type, + const char *name, + void *callback_user_data, + void **vector_ret, + void **data_ret); + /** * The different deserialization options. */ @@ -1001,15 +1030,11 @@ enum ccs_deserialize_option_e { */ CCS_DESERIALIZE_OPTION_HANDLE_MAP, /** - * The next parameter is a pointer to a ccs object vector struct, for - * user defined tuners - */ - CCS_DESERIALIZE_OPTION_VECTOR, - /** - * The next parameter is a pointer to a ccs object internal data, for - * user defined tuners + * The next parameter is a pointer to a callback of type + * ccs_object_deserialize_vector_callback_t and its user data, that + * will return the ccs object vector struct, for user defined objects */ - CCS_DESERIALIZE_OPTION_DATA, + CCS_DESERIALIZE_OPTION_VECTOR_CALLBACK, /** * The file descriptor operation is non-blocking. The next parameter is * a pointer to a void * variable (initialized to NULL) that will hold @@ -1019,12 +1044,12 @@ enum ccs_deserialize_option_e { */ CCS_DESERIALIZE_OPTION_NON_BLOCKING, /** - * The next parameters are a deserialization callback and it's + * The next parameters are a deserialization callback and its * user_data. This callback will be called for all objects that had * their user_data serialized. If no such callback is provided the * object's user_data value will not be set. */ - CCS_DESERIALIZE_OPTION_CALLBACK, + CCS_DESERIALIZE_OPTION_DATA_CALLBACK, /** Guard */ CCS_DESERIALIZE_OPTION_MAX, /** Try forcing 32 bits value for bindings */ diff --git a/src/cconfigspace_deserialize.h b/src/cconfigspace_deserialize.h index ff71bdbb..6221093f 100644 --- a/src/cconfigspace_deserialize.h +++ b/src/cconfigspace_deserialize.h @@ -54,12 +54,11 @@ _ccs_object_deserialize_options( CCS_CHECK_OBJ(opts->handle_map, CCS_OBJECT_TYPE_MAP); opts->map_values = CCS_TRUE; break; - case CCS_DESERIALIZE_OPTION_VECTOR: - opts->vector = va_arg(args, void *); - CCS_CHECK_PTR(opts->vector); - break; - case CCS_DESERIALIZE_OPTION_DATA: - opts->data = va_arg(args, void *); + case CCS_DESERIALIZE_OPTION_VECTOR_CALLBACK: + opts->deserialize_vector_callback = + va_arg(args, ccs_object_deserialize_vector_callback_t); + CCS_CHECK_PTR(opts->deserialize_vector_callback); + opts->deserialize_vector_user_data = va_arg(args, void *); break; case CCS_DESERIALIZE_OPTION_NON_BLOCKING: CCS_REFUTE( @@ -70,11 +69,11 @@ _ccs_object_deserialize_options( va_arg(args, _ccs_file_descriptor_state_t **); CCS_CHECK_PTR(opts->ppfd_state); break; - case CCS_DESERIALIZE_OPTION_CALLBACK: - opts->deserialize_callback = - va_arg(args, ccs_object_deserialize_callback_t); - CCS_CHECK_PTR(opts->deserialize_callback); - opts->deserialize_user_data = va_arg(args, void *); + case CCS_DESERIALIZE_OPTION_DATA_CALLBACK: + opts->deserialize_data_callback = + va_arg(args, ccs_object_deserialize_data_callback_t); + CCS_CHECK_PTR(opts->deserialize_data_callback); + opts->deserialize_data_user_data = va_arg(args, void *); break; default: CCS_RAISE( diff --git a/src/cconfigspace_internal.h b/src/cconfigspace_internal.h index e503162d..4fb4c980 100644 --- a/src/cconfigspace_internal.h +++ b/src/cconfigspace_internal.h @@ -1299,10 +1299,10 @@ struct _ccs_object_deserialize_options_s { ccs_map_t handle_map; ccs_bool_t map_values; _ccs_file_descriptor_state_t **ppfd_state; - void *vector; - void *data; - ccs_object_deserialize_callback_t deserialize_callback; - void *deserialize_user_data; + ccs_object_deserialize_vector_callback_t deserialize_vector_callback; + void *deserialize_vector_user_data; + ccs_object_deserialize_data_callback_t deserialize_data_callback; + void *deserialize_data_user_data; }; typedef struct _ccs_object_deserialize_options_s _ccs_object_deserialize_options_t; @@ -1449,10 +1449,10 @@ _ccs_object_deserialize_user_data( CCS_VALIDATE(_ccs_deserialize_bin_size( &serialize_data_size, buffer_size, buffer)); if (serialize_data_size) { - if (opts->deserialize_callback) - CCS_VALIDATE(opts->deserialize_callback( + if (opts->deserialize_data_callback) + CCS_VALIDATE(opts->deserialize_data_callback( object, serialize_data_size, *buffer, - opts->deserialize_user_data)); + opts->deserialize_data_user_data)); *buffer_size -= serialize_data_size; *buffer += serialize_data_size; } diff --git a/src/tree_space_deserialize.h b/src/tree_space_deserialize.h index f9808126..3414ffb3 100644 --- a/src/tree_space_deserialize.h +++ b/src/tree_space_deserialize.h @@ -88,27 +88,38 @@ _ccs_deserialize_bin_tree_space_dynamic( _ccs_object_deserialize_options_t *opts, _ccs_tree_space_common_data_mock_t *data) { - _ccs_blob_t blob = {0, NULL}; - ccs_dynamic_tree_space_vector_t *vector = - (ccs_dynamic_tree_space_vector_t *)opts->vector; - ccs_result_t res = CCS_RESULT_SUCCESS; + _ccs_blob_t blob = {0, NULL}; + ccs_dynamic_tree_space_vector_t *vector = NULL; + void *tree_space_data = NULL; + ccs_result_t res = CCS_RESULT_SUCCESS; + CCS_VALIDATE_ERR_GOTO( res, _ccs_deserialize_bin_ccs_tree_space_common_data( data, version, buffer_size, buffer, opts), end); - CCS_VALIDATE(_ccs_deserialize_bin_ccs_blob(&blob, buffer_size, buffer)); + CCS_VALIDATE_ERR_GOTO( + res, + _ccs_deserialize_bin_ccs_blob( + &blob, buffer_size, buffer), + end); + + CCS_VALIDATE_ERR_GOTO( + res, + opts->deserialize_vector_callback( + CCS_OBJECT_TYPE_TREE_SPACE, + data->name, + opts->deserialize_vector_user_data, + (void**)&vector, &tree_space_data), + end); - void *tree_space_data; if (vector->deserialize_state) CCS_VALIDATE_ERR_GOTO( res, vector->deserialize_state( data->tree, data->feature_space, blob.sz, blob.blob, &tree_space_data), - tree_space); - else - tree_space_data = opts->data; + end); CCS_VALIDATE_ERR_GOTO( res, @@ -117,10 +128,6 @@ _ccs_deserialize_bin_tree_space_dynamic( data->feature_space, data->rng, vector, tree_space_data, tree_space_ret), end); - goto end; -tree_space: - ccs_release_object(*tree_space_ret); - *tree_space_ret = NULL; end: if (data->feature_space) ccs_release_object(data->feature_space); @@ -144,6 +151,8 @@ _ccs_deserialize_bin_tree_space( ccs_tree_space_type_t stype; CCS_VALIDATE( _ccs_peek_bin_ccs_tree_space_type(&stype, buffer_size, buffer)); + if (stype == CCS_TREE_SPACE_TYPE_DYNAMIC) + CCS_CHECK_PTR(opts->deserialize_vector_callback); _ccs_tree_space_common_data_mock_t data = { CCS_TREE_SPACE_TYPE_STATIC, NULL, NULL, NULL, NULL, NULL}; diff --git a/src/tuner_deserialize.h b/src/tuner_deserialize.h index d3335457..40293814 100644 --- a/src/tuner_deserialize.h +++ b/src/tuner_deserialize.h @@ -179,16 +179,25 @@ _ccs_deserialize_bin_user_defined_tuner( NULL, NULL}, {0, NULL}}; - ccs_user_defined_tuner_vector_t *vector = - (ccs_user_defined_tuner_vector_t *)opts->vector; - ccs_result_t res = CCS_RESULT_SUCCESS; + ccs_user_defined_tuner_vector_t *vector = NULL; + void *tuner_data = NULL; + ccs_result_t res = CCS_RESULT_SUCCESS; + CCS_VALIDATE_ERR_GOTO( res, _ccs_deserialize_bin_ccs_user_defined_tuner_data( &data, version, buffer_size, buffer, opts), end); - void *tuner_data; + CCS_VALIDATE_ERR_GOTO( + res, + opts->deserialize_vector_callback( + CCS_OBJECT_TYPE_TUNER, + data.base_data.common_data.name, + opts->deserialize_vector_user_data, + (void **)&vector, &tuner_data), + end); + if (vector->deserialize_state) CCS_VALIDATE_ERR_GOTO( res, @@ -200,8 +209,6 @@ _ccs_deserialize_bin_user_defined_tuner( data.base_data.optima, data.blob.sz, data.blob.blob, &tuner_data), end); - else - tuner_data = opts->data; CCS_VALIDATE_ERR_GOTO( res, @@ -247,7 +254,7 @@ _ccs_deserialize_bin_tuner( ccs_tuner_type_t ttype; CCS_VALIDATE(_ccs_peek_bin_ccs_tuner_type(&ttype, buffer_size, buffer)); if (ttype == CCS_TUNER_TYPE_USER_DEFINED) - CCS_CHECK_PTR(opts->vector); + CCS_CHECK_PTR(opts->deserialize_vector_callback); new_opts.map_values = CCS_TRUE; CCS_VALIDATE(ccs_create_map(&new_opts.handle_map)); diff --git a/tests/test_categorical_parameter.c b/tests/test_categorical_parameter.c index 2104fd21..8f11658c 100644 --- a/tests/test_categorical_parameter.c +++ b/tests/test_categorical_parameter.c @@ -179,7 +179,7 @@ test_create(void) err = ccs_object_deserialize( (ccs_object_t *)¶meter, CCS_SERIALIZE_FORMAT_BINARY, CCS_SERIALIZE_OPERATION_MEMORY, buff_size, buff, - CCS_DESERIALIZE_OPTION_CALLBACK, &deserialize_callback, + CCS_DESERIALIZE_OPTION_DATA_CALLBACK, &deserialize_callback, (void *)0xbeefdead, CCS_DESERIALIZE_OPTION_END); assert(err == CCS_RESULT_SUCCESS); free(buff); diff --git a/tests/test_dynamic_tree_space.c b/tests/test_dynamic_tree_space.c index adf0cb01..3aa5c68f 100644 --- a/tests/test_dynamic_tree_space.c +++ b/tests/test_dynamic_tree_space.c @@ -35,6 +35,30 @@ my_tree_get_child( return CCS_RESULT_SUCCESS; } +ccs_dynamic_tree_space_vector_t vector = { + &my_tree_del, &my_tree_get_child, NULL, NULL}; + +ccs_result_t +deserialize_vector_callback( + ccs_object_type_t type, + const char *name, + void *callback_user_data, + void **vector_ret, + void **data_ret) +{ + (void)name; + (void)callback_user_data; + switch (type) { + case CCS_OBJECT_TYPE_TREE_SPACE: + *vector_ret = (void *)&vector; + *data_ret = NULL; + break; + default: + return CCS_RESULT_ERROR_INVALID_TYPE; + } + return CCS_RESULT_SUCCESS; +} + void test_dynamic_tree_space(void) { @@ -51,8 +75,6 @@ test_dynamic_tree_space(void) const char *name; ccs_tree_configuration_t config, configs[NUM_SAMPLES]; - ccs_dynamic_tree_space_vector_t vector = { - &my_tree_del, &my_tree_get_child, NULL, NULL}; err = ccs_create_tree(4, ccs_int(4 * 100), &root); assert(err == CCS_RESULT_SUCCESS); err = ccs_create_rng(&rng); @@ -176,7 +198,8 @@ test_dynamic_tree_space(void) err = ccs_object_deserialize( (ccs_object_t *)&tree_space, CCS_SERIALIZE_FORMAT_BINARY, CCS_SERIALIZE_OPERATION_MEMORY, buff_size, buff, - CCS_DESERIALIZE_OPTION_VECTOR, &vector, + CCS_DESERIALIZE_OPTION_VECTOR_CALLBACK, + &deserialize_vector_callback, (void*)NULL, CCS_DESERIALIZE_OPTION_END); assert(err == CCS_RESULT_SUCCESS); free(buff); diff --git a/tests/test_user_defined_features_tuner.c b/tests/test_user_defined_features_tuner.c index d319c55b..46a86bd9 100644 --- a/tests/test_user_defined_features_tuner.c +++ b/tests/test_user_defined_features_tuner.c @@ -147,6 +147,28 @@ ccs_user_defined_tuner_vector_t tuner_last_vector = { NULL, NULL}; +ccs_result_t +deserialize_vector_callback( + ccs_object_type_t type, + const char *name, + void *callback_user_data, + void **vector_ret, + void **data_ret) +{ + (void)name; + (void)callback_user_data; + switch (type) { + case CCS_OBJECT_TYPE_TUNER: + *vector_ret = (void *)&tuner_last_vector; + *data_ret = calloc(1, sizeof(tuner_last_t)); + assert(*data_ret); + break; + default: + return CCS_RESULT_ERROR_INVALID_TYPE; + } + return CCS_RESULT_SUCCESS; +} + void test(void) { @@ -263,15 +285,12 @@ test(void) CCS_SERIALIZE_OPTION_END); assert(err == CCS_RESULT_SUCCESS); - tuner_data = (tuner_last_t *)calloc(1, sizeof(tuner_last_t)); - assert(tuner_data); - err = ccs_object_deserialize( (ccs_object_t *)&tuner_copy, CCS_SERIALIZE_FORMAT_BINARY, CCS_SERIALIZE_OPERATION_MEMORY, buff_size, buff, CCS_DESERIALIZE_OPTION_HANDLE_MAP, map, - CCS_DESERIALIZE_OPTION_VECTOR, &tuner_last_vector, - CCS_DESERIALIZE_OPTION_DATA, tuner_data, + CCS_DESERIALIZE_OPTION_VECTOR_CALLBACK, + &deserialize_vector_callback, (void*)NULL, CCS_DESERIALIZE_OPTION_END); assert(err == CCS_RESULT_SUCCESS); diff --git a/tests/test_user_defined_tree_tuner.c b/tests/test_user_defined_tree_tuner.c index 8bfa2977..64c6fb4d 100644 --- a/tests/test_user_defined_tree_tuner.c +++ b/tests/test_user_defined_tree_tuner.c @@ -193,6 +193,28 @@ ccs_user_defined_tuner_vector_t tuner_last_vector = { NULL, NULL}; +ccs_result_t +deserialize_vector_callback( + ccs_object_type_t type, + const char *name, + void *callback_user_data, + void **vector_ret, + void **data_ret) +{ + (void)name; + (void)callback_user_data; + switch (type) { + case CCS_OBJECT_TYPE_TUNER: + *vector_ret = (void *)&tuner_last_vector; + *data_ret = calloc(1, sizeof(tuner_last_t)); + assert(*data_ret); + break; + default: + return CCS_RESULT_ERROR_INVALID_TYPE; + } + return CCS_RESULT_SUCCESS; +} + void test(void) { @@ -274,15 +296,12 @@ test(void) CCS_SERIALIZE_OPTION_END); assert(err == CCS_RESULT_SUCCESS); - tuner_data = (tuner_last_t *)calloc(1, sizeof(tuner_last_t)); - assert(tuner_data); - err = ccs_object_deserialize( (ccs_object_t *)&tuner_copy, CCS_SERIALIZE_FORMAT_BINARY, CCS_SERIALIZE_OPERATION_MEMORY, buff_size, buff, CCS_DESERIALIZE_OPTION_HANDLE_MAP, map, - CCS_DESERIALIZE_OPTION_VECTOR, &tuner_last_vector, - CCS_DESERIALIZE_OPTION_DATA, tuner_data, + CCS_DESERIALIZE_OPTION_VECTOR_CALLBACK, + &deserialize_vector_callback, (void*)NULL, CCS_DESERIALIZE_OPTION_END); assert(err == CCS_RESULT_SUCCESS); diff --git a/tests/test_user_defined_tuner.c b/tests/test_user_defined_tuner.c index e56e9d1c..c590fddd 100644 --- a/tests/test_user_defined_tuner.c +++ b/tests/test_user_defined_tuner.c @@ -140,6 +140,28 @@ ccs_user_defined_tuner_vector_t tuner_last_vector = { NULL, NULL}; +ccs_result_t +deserialize_vector_callback( + ccs_object_type_t type, + const char *name, + void *callback_user_data, + void **vector_ret, + void **data_ret) +{ + (void)name; + (void)callback_user_data; + switch (type) { + case CCS_OBJECT_TYPE_TUNER: + *vector_ret = (void *)&tuner_last_vector; + *data_ret = calloc(1, sizeof(tuner_last_t)); + assert(*data_ret); + break; + default: + return CCS_RESULT_ERROR_INVALID_TYPE; + } + return CCS_RESULT_SUCCESS; +} + void test(void) { @@ -217,15 +239,12 @@ test(void) CCS_SERIALIZE_OPTION_END); assert(err == CCS_RESULT_SUCCESS); - tuner_data = (tuner_last_t *)calloc(1, sizeof(tuner_last_t)); - assert(tuner_data); - err = ccs_object_deserialize( (ccs_object_t *)&tuner_copy, CCS_SERIALIZE_FORMAT_BINARY, CCS_SERIALIZE_OPERATION_MEMORY, buff_size, buff, CCS_DESERIALIZE_OPTION_HANDLE_MAP, map, - CCS_DESERIALIZE_OPTION_VECTOR, &tuner_last_vector, - CCS_DESERIALIZE_OPTION_DATA, tuner_data, + CCS_DESERIALIZE_OPTION_VECTOR_CALLBACK, + &deserialize_vector_callback, (void*)NULL, CCS_DESERIALIZE_OPTION_END); assert(err == CCS_RESULT_SUCCESS);