diff --git a/bindings/python/cconfigspace/base.py b/bindings/python/cconfigspace/base.py index d25593f7..da1707cf 100644 --- a/bindings/python/cconfigspace/base.py +++ b/bindings/python/cconfigspace/base.py @@ -395,10 +395,9 @@ class DeserializeOption(CEnumeration): _members_ = [ ('END', 0), 'HANDLE_MAP', - 'VECTOR', - 'DATA', + 'VECTOR_CALLBACK', 'NON_BLOCKING', - 'CALLBACK' + 'DATA_CALLBACK' ] def _ccs_get_function(method, argtypes = [], restype = Result): @@ -422,7 +421,8 @@ def _ccs_get_function(method, argtypes = [], restype = Result): ccs_object_get_user_data = _ccs_get_function("ccs_object_get_user_data", [ccs_object, ct.POINTER(ct.c_void_p)]) ccs_object_serialize_callback_type = ct.CFUNCTYPE(Result, ccs_object, ct.c_size_t, ct.c_void_p, ct.POINTER(ct.c_size_t), ct.c_void_p) ccs_object_set_serialize_callback = _ccs_get_function("ccs_object_set_serialize_callback", [ccs_object, ccs_object_serialize_callback_type, ct.c_void_p]) -ccs_object_deserialize_callback_type = ct.CFUNCTYPE(Result, ccs_object, ct.c_size_t, ct.c_void_p, ct.c_void_p) +ccs_object_deserialize_data_callback_type = ct.CFUNCTYPE(Result, ccs_object, ct.c_size_t, ct.c_void_p, ct.c_void_p) +ccs_object_deserialize_vector_callback_type = ct.CFUNCTYPE(Result, ct.c_int, ct.c_char_p, ct.c_void_p, ct.POINTER(ct.c_void_p), ct.POINTER(ct.py_object)) # Variadic methods ccs_object_serialize = getattr(libcconfigspace, "ccs_object_serialize") ccs_object_serialize.argtypes = ccs_object, SerializeFormat, SerializeOperation, @@ -556,7 +556,7 @@ def serialize(self, format = 'binary', path = None, file_descriptor = None, call return v @classmethod - def deserialize(cls, format = 'binary', handle_map = None, vector = None, data = None, path = None, buffer = None, file_descriptor = None, callback = None, callback_data = None): + def deserialize(cls, format = 'binary', handle_map = None, vector_callback = None, vector_callback_data = None, path = None, buffer = None, file_descriptor = None, callback = None, callback_data = None): if format != 'binary': raise Error(Result(Result.ERROR_INVALID_VALUE)) mode_count = 0; @@ -572,16 +572,16 @@ def deserialize(cls, format = 'binary', handle_map = None, vector = None, data = options = [DeserializeOption.END] if handle_map: options = [DeserializeOption.HANDLE_MAP, handle_map.handle] + options - if vector: - options = [DeserializeOption.VECTOR, ct.byref(vector)] + options - if data: - options = [DeserializeOption.DATA, ct.py_object(data)] + options + if vector_callback: + vector_cb_wrapper = _get_deserialize_vector_callback_wrapper(vector_callback) + vector_cb_wrapper_func = ccs_object_deserialize_vector_callback_type(vector_cb_wrapper) + options = [DeserializeOption.VECTOR_CALLBACK, vector_cb_wrapper_func, ct.py_object(vector_callback_data)] + options if callback: - cb_wrapper = _get_deserialize_callback_wrapper(callback) - cb_wrapper_func = ccs_object_deserialize_callback_type(cb_wrapper) - options = [DeserializeOption.CALLBACK, cb_wrapper_func, ct.py_object(callback_data)] + options + cb_wrapper = _get_deserialize_data_callback_wrapper(callback) + cb_wrapper_func = ccs_object_deserialize_data_callback_type(cb_wrapper) + options = [DeserializeOption.DATA_CALLBACK, cb_wrapper_func, ct.py_object(callback_data)] + options elif _default_user_data_deserializer: - options = [DeserializeOption.CALLBACK, _default_user_data_deserializer, ct.py_object()] + options + options = [DeserializeOption.DATA_CALLBACK, _default_user_data_deserializer, ct.py_object()] + options if buffer: s = ct.c_size_t(ct.sizeof(buffer)) res = ccs_object_deserialize(ct.byref(o), SerializeFormat.BINARY, SerializeOperation.MEMORY, s, buffer, *options) @@ -621,8 +621,8 @@ def serialize_callback_wrapper(obj, serialize_data_size, serialize_data, seriali return Error.set_error(e) return serialize_callback_wrapper -def _get_deserialize_callback_wrapper(callback): - def deserialize_callback_wrapper(obj, serialize_data_size, serialize_data, cb_data): +def _get_deserialize_data_callback_wrapper(callback): + def deserialize_data_callback_wrapper(obj, serialize_data_size, serialize_data, cb_data): try: p_sd = ct.cast(serialize_data, ct.c_void_p) cb_data = ct.cast(cb_data, ct.c_void_p) @@ -638,7 +638,24 @@ def deserialize_callback_wrapper(obj, serialize_data_size, serialize_data, cb_da return Result.SUCCESS except Exception as e: return Error.set_error(e) - return deserialize_callback_wrapper + return deserialize_data_callback_wrapper + +def _get_deserialize_vector_callback_wrapper(callback): + def deserialize_vector_callback_wrapper(obj_type, name, callback_user_data, vector_ret, data_ret): + try: + cb_data = ct.cast(callback_user_data, ct.py_object).value if callback_user_data else None + o_type = ObjectType(obj_type) + (vector, data) = callback(o_type, name, cb_data) + c_vector = ct.py_object(vector) + c_data = ct.py_object(data) + vector_ret[0] = ct.cast(ct.byref(vector), ct.c_void_p) + data_ret[0] = c_data + ct.pythonapi.Py_IncRef(c_vector) + ct.pythonapi.Py_IncRef(c_data) + return Result.SUCCESS + except Exception as e: + return Error.set_error(e) + return deserialize_vector_callback_wrapper def _json_user_data_serializer(obj, data, size): string = json.dumps(obj.user_data).encode("ascii") @@ -651,8 +668,8 @@ def _json_user_data_deserializer(obj, serialized, data): _json_user_data_serializer_wrap = _get_serialize_callback_wrapper(_json_user_data_serializer) _json_user_data_serializer_func = ccs_object_serialize_callback_type(_json_user_data_serializer_wrap) -_json_user_data_deserializer_wrap = _get_deserialize_callback_wrapper(_json_user_data_deserializer) -_json_user_data_deserializer_func = ccs_object_deserialize_callback_type(_json_user_data_deserializer_wrap) +_json_user_data_deserializer_wrap = _get_deserialize_data_callback_wrapper(_json_user_data_deserializer) +_json_user_data_deserializer_func = ccs_object_deserialize_data_callback_type(_json_user_data_deserializer_wrap) _default_user_data_serializer = _json_user_data_serializer_func _default_user_data_deserializer = _json_user_data_deserializer_func @@ -694,8 +711,8 @@ def _register_serialize_callback(handle, callback_data): _register_destroy_callback(handle) _data_store[value]['serialize_calback'] = callback_data -def deserialize(format = 'binary', handle_map = None, path = None, buffer = None, callback = None, callback_data = None): - return Object.deserialize(format = format, handle_map = handle_map, path = path, buffer = buffer, callback = callback, callback_data = callback_data) +def deserialize(format = 'binary', handle_map = None, path = None, buffer = None, file_descriptor = None, vector_callback = None, vector_callback_data = None, callback = None, callback_data = None): + return Object.deserialize(format = format, handle_map = handle_map, path = path, buffer = buffer, file_descriptor = file_descriptor, vector_callback = vector_callback, vector_callback_data = vector_callback_data, callback = callback, callback_data = callback_data) def _set_destroy_callback(handle, callback, user_data = None): if callback is None: diff --git a/bindings/python/cconfigspace/tree_space.py b/bindings/python/cconfigspace/tree_space.py index 9d737d34..43297068 100644 --- a/bindings/python/cconfigspace/tree_space.py +++ b/bindings/python/cconfigspace/tree_space.py @@ -178,93 +178,6 @@ class DynamicTreeSpaceVector(ct.Structure): ccs_create_dynamic_tree_space = _ccs_get_function("ccs_create_dynamic_tree_space", [ct.c_char_p, ccs_tree, ccs_feature_space, ccs_rng, ct.POINTER(DynamicTreeSpaceVector), ct.py_object, ct.POINTER(ccs_tree_space)]) ccs_dynamic_tree_space_get_tree_space_data = _ccs_get_function("ccs_dynamic_tree_space_get_tree_space_data", [ccs_tree_space, ct.POINTER(ct.c_void_p)]) -def _wrap_user_defined_callbacks(delete, get_child, serialize, deserialize): - vec = DynamicTreeSpaceVector() - def delete_wrapper(ts): - try: - ts = ct.cast(ts, ccs_tree_space) - o = Object.from_handle(ts) - tsdata = o.tree_space_data - if delete is not None: - delete(o) - if tsdata is not None: - ct.pythonapi.Py_DecRef(ct.py_object(tsdata)) - ct.pythonapi.Py_DecRef(ct.py_object(vec)) - return Result.SUCCESS - except Exception as e: - return Error.set_error(e) - - def get_child_wrapper(ts, parent, index, p_child): - try: - ts = ct.cast(ts, ccs_tree_space) - parent = ct.cast(parent, ccs_tree) - child = get_child(TreeSpace.from_handle(ts), Tree.from_handle(parent), index) - res = ccs_retain_object(child.handle) - Error.check(res) - p_child[0] = child.handle.value - return Result.SUCCESS - except Exception as e: - return Error.set_error(e) - - if serialize is not None: - def serialize_wrapper(ts, state_size, p_state, p_state_size): - try: - ts = ct.cast(ts, ccs_tree_space) - p_s = ct.cast(p_state, ct.c_void_p) - p_sz = ct.cast(p_state_size, ct.c_void_p) - state = serialize(TreeSpace.from_handle(ts), True if state_size == 0 else False) - if p_s.value is not None and state_size < ct.sizeof(state): - raise Error(Result(Result.ERROR_INVALID_VALUE)) - if p_s.value is not None: - ct.memmove(p_s, ct.byref(state), ct.sizeof(state)) - if p_sz.value is not None: - p_state_size[0] = ct.sizeof(state) - return Result.SUCCESS - except Exception as e: - return Error.set_error(e) - else: - serialize_wrapper = 0 - - if deserialize is not None: - def deserialize_wrapper(tree, feature_space, state_size, p_state, p_tree_space_data): - try: - t = ct.cast(tree, ccs_tree) - p_s = ct.cast(p_state, ct.c_void_p) - p_t = ct.cast(p_tree_space_data, ct.c_void_p) - if p_s.value is None: - state = None - else: - state = ct.cast(p_s, POINTER(c_byte * state_size)) - tree_space_data = deserialize(Tree.from_handle(t), FeatureSpace.from_handle(feature_space) if feature_space else None, state) - c_tree_space_data = ct.py_object(tree_space_data) - p_t[0] = c_tree_space_data - ct.pythonapi.Py_IncRef(c_tree_space_data) - return Result.SUCCESS - except Exception as e: - return Error.set_error(e) - else: - deserialize_wrapper = 0 - - delete_wrapper_func = ccs_dynamic_tree_space_del_type(delete_wrapper) - get_child_wrapper_func = ccs_dynamic_tree_space_get_child_type(get_child_wrapper) - serialize_wrapper_func = ccs_dynamic_tree_space_serialize_type(serialize_wrapper) - deserialize_wrapper_func = ccs_dynamic_tree_space_deserialize_type(deserialize_wrapper) - vec.delete = delete_wrapper_func - vec.get_child = get_child_wrapper_func - vec.serialize = serialize_wrapper_func - vec.deserialize = deserialize_wrapper_func - - setattr(vec, '_wrappers', ( - delete_wrapper, - get_child_wrapper, - serialize_wrapper, - deserialize_wrapper, - delete_wrapper_func, - get_child_wrapper_func, - serialize_wrapper_func, - deserialize_wrapper_func)) - return vec - class DynamicTreeSpace(TreeSpace): def __init__(self, handle = None, retain = False, auto_release = True, @@ -273,7 +186,7 @@ def __init__(self, handle = None, retain = False, auto_release = True, if get_child is None: raise Error(Result(Result.ERROR_INVALID_VALUE)) - vec = _wrap_user_defined_callbacks(delete, get_child, serialize, deserialize) + vec = self.get_vector(delete, get_child, serialize, deserialize) if tree_space_data is not None: c_tree_space_data = ct.py_object(tree_space_data) else: @@ -292,18 +205,6 @@ def __init__(self, handle = None, retain = False, auto_release = True, else: super().__init__(handle = handle, retain = retain, auto_release = auto_release) - @classmethod - def deserialize(cls, delete, get_child, serialize = None, deserialize = None, tree_space_data = None, format = 'binary', handle_map = None, path = None, buffer = None, file_descriptor = None, callback = None, callback_data = None): - if get_child is None: - raise Error(Result(Result.ERROR_INVALID_VALUE)) - - vec = _wrap_user_defined_callbacks(delete, get_child, serialize, deserialize) - res = super().deserialize(format = format, handle_map = handle_map, vector = vec, data = tree_space_data, path = path, buffer = buffer, file_descriptor = file_descriptor, callback = callback, callback_data = callback_data) - ct.pythonapi.Py_IncRef(ct.py_object(vec)) - if tree_space_data is not None: - ct.pythonapi.Py_IncRef(ct.py_object(tree_space_data)) - return res - @property def tree_space_data(self): if hasattr(self, "_tree_space_data"): @@ -317,6 +218,96 @@ def tree_space_data(self): self._tree_space_data = None return self._tree_space_data + @classmethod + def get_vector(self, delete = None, get_child = None, serialize = None, deserialize = None): + if get_child is None: + raise Error(Result(Result.ERROR_INVALID_VALUE)) + vec = DynamicTreeSpaceVector() + def delete_wrapper(ts): + try: + ts = ct.cast(ts, ccs_tree_space) + o = Object.from_handle(ts) + tsdata = o.tree_space_data + if delete is not None: + delete(o) + if tsdata is not None: + ct.pythonapi.Py_DecRef(ct.py_object(tsdata)) + ct.pythonapi.Py_DecRef(ct.py_object(vec)) + return Result.SUCCESS + except Exception as e: + return Error.set_error(e) + + def get_child_wrapper(ts, parent, index, p_child): + try: + ts = ct.cast(ts, ccs_tree_space) + parent = ct.cast(parent, ccs_tree) + child = get_child(TreeSpace.from_handle(ts), Tree.from_handle(parent), index) + res = ccs_retain_object(child.handle) + Error.check(res) + p_child[0] = child.handle.value + return Result.SUCCESS + except Exception as e: + return Error.set_error(e) + + if serialize is not None: + def serialize_wrapper(ts, state_size, p_state, p_state_size): + try: + ts = ct.cast(ts, ccs_tree_space) + p_s = ct.cast(p_state, ct.c_void_p) + p_sz = ct.cast(p_state_size, ct.c_void_p) + state = serialize(TreeSpace.from_handle(ts), True if state_size == 0 else False) + if p_s.value is not None and state_size < ct.sizeof(state): + raise Error(Result(Result.ERROR_INVALID_VALUE)) + if p_s.value is not None: + ct.memmove(p_s, ct.byref(state), ct.sizeof(state)) + if p_sz.value is not None: + p_state_size[0] = ct.sizeof(state) + return Result.SUCCESS + except Exception as e: + return Error.set_error(e) + else: + serialize_wrapper = 0 + + if deserialize is not None: + def deserialize_wrapper(tree, feature_space, state_size, p_state, p_tree_space_data): + try: + t = ct.cast(tree, ccs_tree) + p_s = ct.cast(p_state, ct.c_void_p) + p_t = ct.cast(p_tree_space_data, ct.c_void_p) + if p_s.value is None: + state = None + else: + state = ct.cast(p_s, POINTER(c_byte * state_size)) + tree_space_data = deserialize(Tree.from_handle(t), FeatureSpace.from_handle(feature_space) if feature_space else None, state) + c_tree_space_data = ct.py_object(tree_space_data) + p_t[0] = c_tree_space_data + ct.pythonapi.Py_IncRef(c_tree_space_data) + return Result.SUCCESS + except Exception as e: + return Error.set_error(e) + else: + deserialize_wrapper = 0 + + delete_wrapper_func = ccs_dynamic_tree_space_del_type(delete_wrapper) + get_child_wrapper_func = ccs_dynamic_tree_space_get_child_type(get_child_wrapper) + serialize_wrapper_func = ccs_dynamic_tree_space_serialize_type(serialize_wrapper) + deserialize_wrapper_func = ccs_dynamic_tree_space_deserialize_type(deserialize_wrapper) + vec.delete = delete_wrapper_func + vec.get_child = get_child_wrapper_func + vec.serialize = serialize_wrapper_func + vec.deserialize = deserialize_wrapper_func + + setattr(vec, '_wrappers', ( + delete_wrapper, + get_child_wrapper, + serialize_wrapper, + deserialize_wrapper, + delete_wrapper_func, + get_child_wrapper_func, + serialize_wrapper_func, + deserialize_wrapper_func)) + return vec + TreeSpace.Dynamic = DynamicTreeSpace from .tree_configuration import TreeConfiguration diff --git a/bindings/python/cconfigspace/tuner.py b/bindings/python/cconfigspace/tuner.py index 55881bb4..fe12930a 100644 --- a/bindings/python/cconfigspace/tuner.py +++ b/bindings/python/cconfigspace/tuner.py @@ -184,197 +184,6 @@ class UserDefinedTunerVector(ct.Structure): ccs_create_user_defined_tuner = _ccs_get_function("ccs_create_user_defined_tuner", [ct.c_char_p, ccs_objective_space, ct.POINTER(UserDefinedTunerVector), ct.py_object, ct.POINTER(ccs_tuner)]) ccs_user_defined_tuner_get_tuner_data = _ccs_get_function("ccs_user_defined_tuner_get_tuner_data", [ccs_tuner, ct.POINTER(ct.c_void_p)]) -def _wrap_user_defined_tuner_callbacks(delete, ask, tell, get_optima, get_history, suggest, serialize, deserialize): - vec = UserDefinedTunerVector() - def delete_wrapper(tun): - try: - tun = ct.cast(tun, ccs_tuner) - o = Object.from_handle(tun) - tdata = o.tuner_data - if delete is not None: - delete(o) - if tdata is not None: - ct.pythonapi.Py_DecRef(ct.py_object(tdata)) - ct.pythonapi.Py_DecRef(ct.py_object(vec)) - return Result.SUCCESS - except Exception as e: - return Error.set_error(e) - - def ask_wrapper(tun, features, count, p_configurations, p_count): - try: - tun = ct.cast(tun, ccs_tuner) - p_confs = ct.cast(p_configurations, ct.c_void_p) - p_c = ct.cast(p_count, ct.c_void_p) - (configurations, count_ret) = ask(Tuner.from_handle(tun), Features.from_handle(features) if features else None, count if p_confs.value else None) - if p_confs.value is not None and count < count_ret: - raise Error(Result(Result.ERROR_INVALID_VALUE)) - if p_confs.value is not None: - for i in range(len(configurations)): - res = ccs_retain_object(configurations[i].handle) - Error.check(res) - p_configurations[i] = configurations[i].handle.value - for i in range(len(configurations), count): - p_configurations[i] = None - if p_c.value is not None: - p_count[0] = count_ret - return Result.SUCCESS - except Exception as e: - return Error.set_error(e) - - def tell_wrapper(tun, count, p_evaluations): - try: - tun = ct.cast(tun, ccs_tuner) - if count == 0: - return Result.SUCCESS - p_evals = ct.cast(p_evaluations, ct.c_void_p) - if p_evals.value is None: - raise Error(Result(Result.ERROR_INVALID_VALUE)) - evals = [Evaluation.from_handle(ccs_evaluation(p_evaluations[i])) for i in range(count)] - tell(Tuner.from_handle(tun), evals) - return Result.SUCCESS - except Exception as e: - return Error.set_error(e) - - def get_optima_wrapper(tun, features, count, p_evaluations, p_count): - try: - tun = ct.cast(tun, ccs_tuner) - p_evals = ct.cast(p_evaluations, ct.c_void_p) - p_c = ct.cast(p_count, ct.c_void_p) - optima = get_optima(Tuner.from_handle(tun), Features.from_handle(features) if features else None) - count_ret = len(optima) - if p_evals.value is not None and count < count_ret: - raise Error(Result(Result.ERROR_INVALID_VALUE)) - if p_evals.value is not None: - for i in range(count_ret): - p_evaluations[i] = optima[i].handle.value - for i in range(count_ret, count): - p_evaluations[i] = None - if p_c.value is not None: - p_count[0] = count_ret - return Result.SUCCESS - except Exception as e: - return Error.set_error(e) - - def get_history_wrapper(tun, features, count, p_evaluations, p_count): - try: - tun = ct.cast(tun, ccs_tuner) - p_evals = ct.cast(p_evaluations, ct.c_void_p) - p_c = ct.cast(p_count, ct.c_void_p) - history = get_history(Tuner.from_handle(tun), Features.from_handle(features) if features else None) - count_ret = (len(history) if history else 0) - if p_evals.value is not None and count < count_ret: - raise Error(Result(Result.ERROR_INVALID_VALUE)) - if p_evals.value is not None: - for i in range(count_ret): - p_evaluations[i] = history[i].handle.value - for i in range(count_ret, count): - p_evaluations[i] = None - if p_c.value is not None: - p_count[0] = count_ret - return Result.SUCCESS - except Exception as e: - return Error.set_error(e) - - if suggest is not None: - def suggest_wrapper(tun, features, p_configuration): - try: - tun = ct.cast(tun, ccs_tuner) - configuration = suggest(Tuner.from_handle(tun), Features.from_handle(features) if features else None) - res = ccs_retain_object(configuration.handle) - Error.check(res) - p_configuration[0] = configuration.handle.value - return Result.SUCCESS - except Exception as e: - return Error.set_error(e) - else: - suggest_wrapper = 0 - - if serialize is not None: - def serialize_wrapper(tun, state_size, p_state, p_state_size): - try: - tun = ct.cast(tun, ccs_tuner) - p_s = ct.cast(p_state, ct.c_void_p) - p_sz = ct.cast(p_state_size, ct.c_void_p) - state = serialize(Tuner.from_handle(tun), True if state_size == 0 else False) - if p_s.value is not None and state_size < ct.sizeof(state): - raise Error(Result(Result.ERROR_INVALID_VALUE)) - if p_s.value is not None: - ct.memmove(p_s, ct.byref(state), ct.sizeof(state)) - if p_sz.value is not None: - p_state_size[0] = ct.sizeof(state) - return Result.SUCCESS - except Exception as e: - return Error.set_error(e) - else: - serialize_wrapper = 0 - - if deserialize is not None: - def deserialize_wrapper(o_space, size_history, p_history, num_optima, p_optima, state_size, p_state, p_tuner_data): - try: - o_space = ct.cast(o_space, ccs_objective_space) - p_h = ct.cast(p_history, ct.c_void_p) - p_o = ct.cast(p_optima, ct.c_void_p) - p_s = ct.cast(p_state, ct.c_void_p) - p_t = ct.cast(p_tuner_data, ct.c_void_p) - if p_h.value is None: - history = [] - else: - history = [Evaluation.from_handle(ccs_evaluation(p_h[i])) for i in range(size_history)] - if p_o.value is None: - optima = [] - else: - optima = [Evaluation.from_handle(ccs_evaluation(p_o[i])) for i in range(num_optima)] - if p_s.value is None: - state = None - else: - state = ct.cast(p_s, POINTER(c_byte * state_size)) - tuner_data = deserialize(ObjectiveSpace.from_handle(o_space), history, optima, state) - c_tuner_data = ct.py_object(tuner_data) - p_t[0] = c_tuner_data - ct.pythonapi.Py_IncRef(c_tuner_data) - return Result.SUCCESS - except Exception as e: - return Error.set_error(e) - else: - deserialize_wrapper = 0 - - delete_wrapper_func = ccs_user_defined_tuner_del_type(delete_wrapper) - ask_wrapper_func = ccs_user_defined_tuner_ask_type(ask_wrapper) - tell_wrapper_func = ccs_user_defined_tuner_tell_type(tell_wrapper) - get_optima_wrapper_func = ccs_user_defined_tuner_get_optima_type(get_optima_wrapper) - get_history_wrapper_func = ccs_user_defined_tuner_get_history_type(get_history_wrapper) - suggest_wrapper_func = ccs_user_defined_tuner_suggest_type(suggest_wrapper) - serialize_wrapper_func = ccs_user_defined_tuner_serialize_type(serialize_wrapper) - deserialize_wrapper_func = ccs_user_defined_tuner_deserialize_type(deserialize_wrapper) - vec.delete = delete_wrapper_func - vec.ask = ask_wrapper_func - vec.tell = tell_wrapper_func - vec.get_optima = get_optima_wrapper_func - vec.get_history = get_history_wrapper_func - vec.suggest = suggest_wrapper_func - vec.serialize = serialize_wrapper_func - vec.deserialize = deserialize_wrapper_func - - setattr(vec, '_wrappers', ( - delete_wrapper, - ask_wrapper, - tell_wrapper, - get_optima_wrapper, - get_history_wrapper, - suggest_wrapper, - serialize_wrapper, - deserialize_wrapper, - delete_wrapper_func, - ask_wrapper_func, - tell_wrapper_func, - get_optima_wrapper_func, - get_history_wrapper_func, - suggest_wrapper_func, - serialize_wrapper_func, - deserialize_wrapper_func)) - return vec - - class UserDefinedTuner(Tuner): def __init__(self, handle = None, retain = False, auto_release = True, name = "", objective_space = None, delete = None, ask = None, tell = None, get_optima = None, get_history = None, suggest = None, serialize = None, deserialize = None, tuner_data = None ): @@ -382,7 +191,7 @@ def __init__(self, handle = None, retain = False, auto_release = True, if ask is None or tell is None or get_optima is None or get_history is None: raise Error(Result(Result.ERROR_INVALID_VALUE)) - vec = _wrap_user_defined_tuner_callbacks(delete, ask, tell, get_optima, get_history, suggest, serialize, deserialize) + vec = self.get_vector(delete, ask, tell, get_optima, get_history, suggest, serialize, deserialize) if tuner_data is not None: c_tuner_data = ct.py_object(tuner_data) else: @@ -397,17 +206,6 @@ def __init__(self, handle = None, retain = False, auto_release = True, else: super().__init__(handle = handle, retain = retain, auto_release = auto_release) - @classmethod - def deserialize(cls, delete, ask, tell, get_optima, get_history, suggest = None, serialize = None, deserialize = None, tuner_data = None, format = 'binary', handle_map = None, path = None, buffer = None, file_descriptor = None, callback = None, callback_data = None): - if ask is None or tell is None or get_optima is None or get_history is None: - raise Error(Result(Result.ERROR_INVALID_VALUE)) - vec = _wrap_user_defined_tuner_callbacks(delete, ask, tell, get_optima, get_history, suggest, serialize, deserialize) - res = super().deserialize(format = format, handle_map = handle_map, vector = vec, data = tuner_data, path = path, buffer = buffer, file_descriptor = file_descriptor, callback = callback, callback_data = callback_data) - ct.pythonapi.Py_IncRef(ct.py_object(vec)) - if tuner_data is not None: - ct.pythonapi.Py_IncRef(ct.py_object(tuner_data)) - return res - @property def tuner_data(self): if hasattr(self, "_tuner_data"): @@ -421,4 +219,195 @@ def tuner_data(self): self._tuner_data = None return self._tuner_data + @classmethod + def get_vector(self, delete = None, ask = None, tell = None, get_optima = None, get_history = None, suggest = None, serialize = None, deserialize = None): + vec = UserDefinedTunerVector() + def delete_wrapper(tun): + try: + tun = ct.cast(tun, ccs_tuner) + o = Object.from_handle(tun) + tdata = o.tuner_data + if delete is not None: + delete(o) + if tdata is not None: + ct.pythonapi.Py_DecRef(ct.py_object(tdata)) + ct.pythonapi.Py_DecRef(ct.py_object(vec)) + return Result.SUCCESS + except Exception as e: + return Error.set_error(e) + + def ask_wrapper(tun, features, count, p_configurations, p_count): + try: + tun = ct.cast(tun, ccs_tuner) + p_confs = ct.cast(p_configurations, ct.c_void_p) + p_c = ct.cast(p_count, ct.c_void_p) + (configurations, count_ret) = ask(Tuner.from_handle(tun), Features.from_handle(features) if features else None, count if p_confs.value else None) + if p_confs.value is not None and count < count_ret: + raise Error(Result(Result.ERROR_INVALID_VALUE)) + if p_confs.value is not None: + for i in range(len(configurations)): + res = ccs_retain_object(configurations[i].handle) + Error.check(res) + p_configurations[i] = configurations[i].handle.value + for i in range(len(configurations), count): + p_configurations[i] = None + if p_c.value is not None: + p_count[0] = count_ret + return Result.SUCCESS + except Exception as e: + return Error.set_error(e) + + def tell_wrapper(tun, count, p_evaluations): + try: + tun = ct.cast(tun, ccs_tuner) + if count == 0: + return Result.SUCCESS + p_evals = ct.cast(p_evaluations, ct.c_void_p) + if p_evals.value is None: + raise Error(Result(Result.ERROR_INVALID_VALUE)) + evals = [Evaluation.from_handle(ccs_evaluation(p_evaluations[i])) for i in range(count)] + tell(Tuner.from_handle(tun), evals) + return Result.SUCCESS + except Exception as e: + return Error.set_error(e) + + def get_optima_wrapper(tun, features, count, p_evaluations, p_count): + try: + tun = ct.cast(tun, ccs_tuner) + p_evals = ct.cast(p_evaluations, ct.c_void_p) + p_c = ct.cast(p_count, ct.c_void_p) + optima = get_optima(Tuner.from_handle(tun), Features.from_handle(features) if features else None) + count_ret = len(optima) + if p_evals.value is not None and count < count_ret: + raise Error(Result(Result.ERROR_INVALID_VALUE)) + if p_evals.value is not None: + for i in range(count_ret): + p_evaluations[i] = optima[i].handle.value + for i in range(count_ret, count): + p_evaluations[i] = None + if p_c.value is not None: + p_count[0] = count_ret + return Result.SUCCESS + except Exception as e: + return Error.set_error(e) + + def get_history_wrapper(tun, features, count, p_evaluations, p_count): + try: + tun = ct.cast(tun, ccs_tuner) + p_evals = ct.cast(p_evaluations, ct.c_void_p) + p_c = ct.cast(p_count, ct.c_void_p) + history = get_history(Tuner.from_handle(tun), Features.from_handle(features) if features else None) + count_ret = (len(history) if history else 0) + if p_evals.value is not None and count < count_ret: + raise Error(Result(Result.ERROR_INVALID_VALUE)) + if p_evals.value is not None: + for i in range(count_ret): + p_evaluations[i] = history[i].handle.value + for i in range(count_ret, count): + p_evaluations[i] = None + if p_c.value is not None: + p_count[0] = count_ret + return Result.SUCCESS + except Exception as e: + return Error.set_error(e) + + if suggest is not None: + def suggest_wrapper(tun, features, p_configuration): + try: + tun = ct.cast(tun, ccs_tuner) + configuration = suggest(Tuner.from_handle(tun), Features.from_handle(features) if features else None) + res = ccs_retain_object(configuration.handle) + Error.check(res) + p_configuration[0] = configuration.handle.value + return Result.SUCCESS + except Exception as e: + return Error.set_error(e) + else: + suggest_wrapper = 0 + + if serialize is not None: + def serialize_wrapper(tun, state_size, p_state, p_state_size): + try: + tun = ct.cast(tun, ccs_tuner) + p_s = ct.cast(p_state, ct.c_void_p) + p_sz = ct.cast(p_state_size, ct.c_void_p) + state = serialize(Tuner.from_handle(tun), True if state_size == 0 else False) + if p_s.value is not None and state_size < ct.sizeof(state): + raise Error(Result(Result.ERROR_INVALID_VALUE)) + if p_s.value is not None: + ct.memmove(p_s, ct.byref(state), ct.sizeof(state)) + if p_sz.value is not None: + p_state_size[0] = ct.sizeof(state) + return Result.SUCCESS + except Exception as e: + return Error.set_error(e) + else: + serialize_wrapper = 0 + + if deserialize is not None: + def deserialize_wrapper(o_space, size_history, p_history, num_optima, p_optima, state_size, p_state, p_tuner_data): + try: + o_space = ct.cast(o_space, ccs_objective_space) + p_h = ct.cast(p_history, ct.c_void_p) + p_o = ct.cast(p_optima, ct.c_void_p) + p_s = ct.cast(p_state, ct.c_void_p) + p_t = ct.cast(p_tuner_data, ct.c_void_p) + if p_h.value is None: + history = [] + else: + history = [Evaluation.from_handle(ccs_evaluation(p_h[i])) for i in range(size_history)] + if p_o.value is None: + optima = [] + else: + optima = [Evaluation.from_handle(ccs_evaluation(p_o[i])) for i in range(num_optima)] + if p_s.value is None: + state = None + else: + state = ct.cast(p_s, POINTER(c_byte * state_size)) + tuner_data = deserialize(ObjectiveSpace.from_handle(o_space), history, optima, state) + c_tuner_data = ct.py_object(tuner_data) + p_t[0] = c_tuner_data + ct.pythonapi.Py_IncRef(c_tuner_data) + return Result.SUCCESS + except Exception as e: + return Error.set_error(e) + else: + deserialize_wrapper = 0 + + delete_wrapper_func = ccs_user_defined_tuner_del_type(delete_wrapper) + ask_wrapper_func = ccs_user_defined_tuner_ask_type(ask_wrapper) + tell_wrapper_func = ccs_user_defined_tuner_tell_type(tell_wrapper) + get_optima_wrapper_func = ccs_user_defined_tuner_get_optima_type(get_optima_wrapper) + get_history_wrapper_func = ccs_user_defined_tuner_get_history_type(get_history_wrapper) + suggest_wrapper_func = ccs_user_defined_tuner_suggest_type(suggest_wrapper) + serialize_wrapper_func = ccs_user_defined_tuner_serialize_type(serialize_wrapper) + deserialize_wrapper_func = ccs_user_defined_tuner_deserialize_type(deserialize_wrapper) + vec.delete = delete_wrapper_func + vec.ask = ask_wrapper_func + vec.tell = tell_wrapper_func + vec.get_optima = get_optima_wrapper_func + vec.get_history = get_history_wrapper_func + vec.suggest = suggest_wrapper_func + vec.serialize = serialize_wrapper_func + vec.deserialize = deserialize_wrapper_func + + setattr(vec, '_wrappers', ( + delete_wrapper, + ask_wrapper, + tell_wrapper, + get_optima_wrapper, + get_history_wrapper, + suggest_wrapper, + serialize_wrapper, + deserialize_wrapper, + delete_wrapper_func, + ask_wrapper_func, + tell_wrapper_func, + get_optima_wrapper_func, + get_history_wrapper_func, + suggest_wrapper_func, + serialize_wrapper_func, + deserialize_wrapper_func)) + return vec + Tuner.UserDefined = UserDefinedTuner diff --git a/bindings/python/test/test_features_tuner.py b/bindings/python/test/test_features_tuner.py index cd6ce770..759d25d9 100644 --- a/bindings/python/test/test_features_tuner.py +++ b/bindings/python/test/test_features_tuner.py @@ -118,6 +118,9 @@ def suggest(tuner, features): else: return choice(optis).configuration + def get_vector_data(otype, name, cb_data): + return (ccs.UserDefinedTuner.get_vector(delete = delete, ask = ask, tell = tell, get_optima = get_optima, get_history = get_history, suggest = suggest), TunerData()) + (fs, os) = self.create_tuning_problem() t = ccs.UserDefinedTuner(name = "tuner", objective_space = os, delete = delete, ask = ask, tell = tell, get_optima = get_optima, get_history = get_history, suggest = suggest, tuner_data = TunerData()) t2 = ccs.Object.from_handle(t.handle) @@ -155,7 +158,7 @@ def suggest(tuner, features): self.assertTrue(t.suggest(features_off) in [x.configuration for x in optims]) # test serialization buff = t.serialize() - t_copy = ccs.UserDefinedTuner.deserialize(buffer = buff, delete = delete, ask = ask, tell = tell, get_optima = get_optima, get_history = get_history, suggest = suggest, tuner_data = TunerData()) + t_copy = ccs.deserialize(buffer = buff, vector_callback = get_vector_data) hist = t_copy.history() self.assertEqual(200, len(hist)) optims_2 = t_copy.optima() diff --git a/bindings/python/test/test_tree_space.py b/bindings/python/test/test_tree_space.py index 39f1ef4e..d65f1908 100644 --- a/bindings/python/test/test_tree_space.py +++ b/bindings/python/test/test_tree_space.py @@ -54,6 +54,9 @@ def get_child(tree_space, parent, child_index): arity = 0 if arity < 0 else arity return ccs.Tree(arity = arity, value = (4 - child_depth)*100 + child_index) + def get_vector_data(otype, name, cb_data): + return (ccs.DynamicTreeSpace.get_vector(delete = delete, get_child = get_child), None) + tree = ccs.Tree(arity = 4, value = 400) ts = ccs.DynamicTreeSpace(name = 'space', tree = tree, delete = delete, get_child = get_child) self.assertEqual( ccs.ObjectType.TREE_SPACE, ts.object_type ) @@ -75,7 +78,7 @@ def get_child(tree_space, parent, child_index): self.assertTrue( ts.check_configuration(tc) ) buff = ts.serialize() - ts2 = ccs.DynamicTreeSpace.deserialize(buffer = buff, delete = delete, get_child = get_child) + ts2 = ccs.deserialize(buffer = buff, vector_callback = get_vector_data) self.assertEqual( [400, 301, 201], ts2.get_values_at_position([1, 1]) ) def test_tree_configuration(self): diff --git a/bindings/python/test/test_tree_tuner.py b/bindings/python/test/test_tree_tuner.py index 9afe2a5a..08d2176f 100644 --- a/bindings/python/test/test_tree_tuner.py +++ b/bindings/python/test/test_tree_tuner.py @@ -107,6 +107,9 @@ def suggest(tuner, features): else: return choice(tuner.tuner_data.optima).configuration + def get_vector_data(otype, name, cb_data): + return (ccs.UserDefinedTuner.get_vector(delete = delete, ask = ask, tell = tell, get_optima = get_optima, get_history = get_history, suggest = suggest), TunerData()) + os = self.create_tuning_problem() t = ccs.UserDefinedTuner(name = "tuner", objective_space = os, delete = delete, ask = ask, tell = tell, get_optima = get_optima, get_history = get_history, suggest = suggest, tuner_data = TunerData()) t2 = ccs.Object.from_handle(t.handle) @@ -125,7 +128,7 @@ def suggest(tuner, features): self.assertTrue(all(best >= x.objective_values[0] for x in hist)) self.assertTrue(t.suggest() in [x.configuration for x in optims]) buff = t.serialize() - t_copy = ccs.UserDefinedTuner.deserialize(buffer = buff, delete = delete, ask = ask, tell = tell, get_optima = get_optima, get_history = get_history, suggest = suggest, tuner_data = TunerData()) + t_copy = ccs.deserialize(buffer = buff, vector_callback = get_vector_data) hist = t_copy.history() self.assertEqual(200, len(hist)) optims_2 = t_copy.optima() diff --git a/bindings/python/test/test_tuner.py b/bindings/python/test/test_tuner.py index 577ba5d6..0a2ca54b 100644 --- a/bindings/python/test/test_tuner.py +++ b/bindings/python/test/test_tuner.py @@ -101,6 +101,9 @@ def suggest(tuner, features): else: return choice(tuner.tuner_data.optima).configuration + def get_vector_data(otype, name, cb_data): + return (ccs.UserDefinedTuner.get_vector(delete = delete, ask = ask, tell = tell, get_optima = get_optima, get_history = get_history, suggest = suggest), TunerData()) + os = self.create_tuning_problem() t = ccs.UserDefinedTuner(name = "tuner", objective_space = os, delete = delete, ask = ask, tell = tell, get_optima = get_optima, get_history = get_history, suggest = suggest, tuner_data = TunerData()) t2 = ccs.Object.from_handle(t.handle) @@ -126,7 +129,7 @@ def suggest(tuner, features): # test serialization buff = t.serialize() - t_copy = ccs.UserDefinedTuner.deserialize(buffer = buff, delete = delete, ask = ask, tell = tell, get_optima = get_optima, get_history = get_history, suggest = suggest, tuner_data = TunerData()) + t_copy = ccs.deserialize(buffer = buff, vector_callback = get_vector_data) hist = t_copy.history() self.assertEqual(200, len(hist)) optims_2 = t_copy.optima() @@ -137,7 +140,7 @@ def suggest(tuner, features): self.assertTrue(t_copy.suggest() in [x.configuration for x in optims_2]) t.serialize(path = 'tuner.ccs') - t_copy = ccs.UserDefinedTuner.deserialize(path = 'tuner.ccs', delete = delete, ask = ask, tell = tell, get_optima = get_optima, get_history = get_history, suggest = suggest, tuner_data = TunerData()) + t_copy = ccs.deserialize(path = 'tuner.ccs', vector_callback = get_vector_data) hist = t_copy.history() self.assertEqual(200, len(hist)) optims_2 = t_copy.optima() @@ -152,7 +155,7 @@ def suggest(tuner, features): t.serialize(file_descriptor = file.fileno()) file.close() file = open( 'tuner.ccs', "rb") - t_copy = ccs.UserDefinedTuner.deserialize(file_descriptor = file.fileno(), delete = delete, ask = ask, tell = tell, get_optima = get_optima, get_history = get_history, suggest = suggest, tuner_data = TunerData()) + t_copy = ccs.deserialize(file_descriptor = file.fileno(), vector_callback = get_vector_data) file.close() hist = t_copy.history() self.assertEqual(200, len(hist)) diff --git a/bindings/ruby/lib/cconfigspace/base.rb b/bindings/ruby/lib/cconfigspace/base.rb index 0a2bd2e5..2daa29d5 100644 --- a/bindings/ruby/lib/cconfigspace/base.rb +++ b/bindings/ruby/lib/cconfigspace/base.rb @@ -244,10 +244,9 @@ def read_ccs_numeric_type_t DeserializeOptions = enum FFI::Type::INT32, :ccs_deserialize_option_t, [ :CCS_DESERIALIZE_OPTION_END, 0, :CCS_DESERIALIZE_OPTION_HANDLE_MAP, - :CCS_DESERIALIZE_OPTION_VECTOR, - :CCS_DESERIALIZE_OPTION_DATA, + :CCS_DESERIALIZE_OPTION_VECTOR_CALLBACK, :CCS_DESERIALIZE_OPTION_NON_BLOCKING, - :CCS_DESERIALIZE_OPTION_CALLBACK ] + :CCS_DESERIALIZE_OPTION_DATA_CALLBACK ] class Numeric < FFI::Union layout :f, :ccs_float_t, @@ -501,7 +500,8 @@ def read_array_of_ccs_datum_t(length) attach_function :ccs_object_get_user_data, [:ccs_object_t, :pointer], :ccs_result_t callback :ccs_object_serialize_callback, [:ccs_object_t, :size_t, :pointer, :pointer, :value], :ccs_result_t attach_function :ccs_object_set_serialize_callback, [:ccs_object_t, :ccs_object_serialize_callback, :value], :ccs_result_t - callback :ccs_object_deserialize_callback, [:ccs_object_t, :size_t, :pointer, :value], :ccs_result_t + callback :ccs_object_deserialize_data_callback, [:ccs_object_t, :size_t, :pointer, :value], :ccs_result_t + callback :ccs_object_deserialize_vector_callback, [:ccs_object_type_t, :string, :value, :pointer, :pointer], :ccs_result_t attach_function :ccs_object_serialize, [:ccs_object_t, :ccs_serialize_format_t, :ccs_serialize_operation_t, :varargs], :ccs_result_t attach_function :ccs_object_deserialize, [:ccs_object_t, :ccs_serialize_format_t, :ccs_serialize_operation_t, :varargs], :ccs_result_t @@ -750,7 +750,7 @@ def serialize(format: :binary, path: nil, file_descriptor: nil, callback: nil, c return result end - def self.deserialize(format: :binary, handle_map: nil, vector: nil, data: nil, path: nil, buffer: nil, file_descriptor: nil, callback: nil, callback_data: nil) + def self.deserialize(format: :binary, handle_map: nil, path: nil, buffer: nil, file_descriptor: nil, vector_callback: nil, vector_callback_data: nil, callback: nil, callback_data: nil) raise CCSError, :CCS_RESULT_ERROR_INVALID_VALUE if format != :binary format = :CCS_SERIALIZE_FORMAT_BINARY mode_count = 0 @@ -761,13 +761,15 @@ def self.deserialize(format: :binary, handle_map: nil, vector: nil, data: nil, p ptr = MemoryPointer::new(:ccs_object_t) options = [] options.concat [:ccs_deserialize_option_t, :CCS_DESERIALIZE_OPTION_HANDLE_MAP, :ccs_map_t, handle_map.handle] if handle_map - options.concat [:ccs_deserialize_option_t, :CCS_DESERIALIZE_OPTION_VECTOR, :pointer, vector] if vector - options.concat [:ccs_deserialize_option_t, :CCS_DESERIALIZE_OPTION_DATA, :value, data] if data + if vector_callback + cb_wrapper = CCS.get_deserialize_vector_callback_wrapper(&vector_callback) + options.concat [:ccs_deserialize_option_t, :CCS_DESERIALIZE_OPTION_VECTOR_CALLBACK, :ccs_object_deserialize_vector_callback, cb_wrapper, :value, vector_callback_data] + end if callback - cb_wrapper = CCS.get_deserialize_wrapper(&callback) - options.concat [:ccs_deserialize_option_t, :CCS_DESERIALIZE_OPTION_CALLBACK, :ccs_object_deserialize_callback, cb_wrapper, :value, callback_data] + cb_wrapper = CCS.get_deserialize_data_callback_wrapper(&callback) + options.concat [:ccs_deserialize_option_t, :CCS_DESERIALIZE_OPTION_DATA_CALLBACK, :ccs_object_deserialize_data_callback, cb_wrapper, :value, callback_data] elsif CCS.default_user_data_deserializer - options.concat [:ccs_deserialize_option_t, :CCS_DESERIALIZE_OPTION_CALLBACK, :ccs_object_deserialize_callback, CCS.default_user_data_deserializer, :value, nil] + options.concat [:ccs_deserialize_option_t, :CCS_DESERIALIZE_OPTION_DATA_CALLBACK, :ccs_object_deserialize_data_callback, CCS.default_user_data_deserializer, :value, nil] end options.concat [:ccs_deserialize_option_t, :CCS_DESERIALIZE_OPTION_END] if buffer @@ -873,7 +875,7 @@ def self.get_serialize_wrapper(&block) } end - def self.get_deserialize_wrapper(&block) + def self.get_deserialize_data_callback_wrapper(&block) lambda { |obj, serialize_data_size, serialize_data, cb_data| begin serialized = serialize_data.null? ? nil : serialize_data.slice(0, serialize_data_size) @@ -885,11 +887,26 @@ def self.get_deserialize_wrapper(&block) } end + def self.get_deserialize_vector_callback_wrapper(&block) + lambda { |obj_type, name, callback_user_data, vector_ret, data_ret| + begin + vector, data = block.call(obj_type, name, callback_user_data) + vector_ret.write_pointer(vector.pointer) + data_ret.write_value(data) + FFI.inc_ref(vector) + FFI.inc_ref(data) + CCSError.to_native(:CCS_RESULT_SUCCESS) + rescue => e + CCS.set_error(e) + end + } + end + @yaml_user_data_serializer = get_serialize_wrapper { |obj, _, size| FFI::MemoryPointer.from_string(YAML.dump(obj.user_data)) } - @yaml_user_data_deserializer = get_deserialize_wrapper { |obj, serialized, _| + @yaml_user_data_deserializer = get_deserialize_data_callback_wrapper { |obj, serialized, _| obj.user_data = YAML.load(serialized.read_string) } @@ -912,8 +929,8 @@ def self.set_serialize_callback(handle, user_data: nil, &block) register_serialize_callback(handle, cb_data) end - def self.deserialize(format: :binary, handle_map: nil, path: nil, buffer: nil, callback: nil, callback_data: nil) - return CCS::Object.deserialize(format: format, handle_map: handle_map, path: path, buffer: buffer, callback: callback, callback_data: callback_data) + def self.deserialize(format: :binary, handle_map: nil, path: nil, buffer: nil, file_descriptor: nil, vector_callback: nil, vector_callback_data: nil, callback: nil, callback_data: nil) + return CCS::Object.deserialize(format: format, handle_map: handle_map, path: path, buffer: buffer, file_descriptor: file_descriptor, vector_callback: vector_callback, vector_callback_data: vector_callback_data, callback: callback, callback_data: callback_data) end end diff --git a/bindings/ruby/lib/cconfigspace/tree_space.rb b/bindings/ruby/lib/cconfigspace/tree_space.rb index 8667189b..a3e557d4 100644 --- a/bindings/ruby/lib/cconfigspace/tree_space.rb +++ b/bindings/ruby/lib/cconfigspace/tree_space.rb @@ -123,6 +123,7 @@ def initialize(handle = nil, retain: false, auto_release: true, callback :ccs_dynamic_tree_space_deserialize, [:ccs_tree_t, :ccs_feature_space_t, :size_t, :pointer, :pointer], :ccs_result_t class DynamicTreeSpaceVector < FFI::Struct + attr_accessor :wrappers layout :del, :ccs_dynamic_tree_space_del, :get_child, :ccs_dynamic_tree_space_get_child, :serialize, :ccs_dynamic_tree_space_serialize, @@ -130,64 +131,6 @@ class DynamicTreeSpaceVector < FFI::Struct end typedef DynamicTreeSpaceVector.by_value, :ccs_dynamic_tree_space_vector_t - def self.wrap_dynamic_tree_space_callbacks(del, get_child, serialize, deserialize) - delwrapper = lambda { |ts| - begin - o = CCS::Object.from_handle(ts) - tsdata = o.tree_space_data - del.call(o) if del - FFI.dec_ref(tsdata) unless tsdata.nil? - CCS.unregister_vector(ts) - CCSError.to_native(:CCS_RESULT_SUCCESS) - rescue => e - CCS.set_error(e) - end - } - get_childwrapper = lambda { |ts, parent, index, p_child| - begin - child = get_child.call(TreeSpace.from_handle(ts), Tree.from_handle(parent), index) - CCS.error_check CCS.ccs_retain_object(child.handle) - Pointer.new(p_child).write_pointer(child.handle) - CCSError.to_native(:CCS_RESULT_SUCCESS) - rescue => e - CCS.set_error(e) - end - } - serializewrapper = - if serialize - lambda { |ts, state_size, p_state, p_state_size| - begin - state = serialize(TreeSpace.from_handle(ts), state_size == 0 ? true : false) - raise CCSError, :CCS_RESULT_ERROR_INVALID_VALUE if !p_state.null? && state_size < state.size - p_state.write_bytes(state.read_bytes(state.size)) unless p_state.null? - Pointer.new(p_state_size).write_size_t(state.size) unless p_state_size.null? - CCSError.to_native(:CCS_RESULT_SUCCESS) - rescue => e - CCS.set_error(e) - end - } - else - nil - end - deserializewrapper = - if deserialize - lambda { |t, feature_space, state_size, p_state, p_tree_space_data| - begin - state = p_state.null? ? nil : p_state.slice(0, state_size) - tree_space_data = deserialize(Tree.from_handle(t), feature_space.null? ? nil : FeatureSpace.from_handle(feature_space), state) - p_tree_space_data.write_value(tree_space_data) - FFI.inc_ref(tree_space_data) - CCSError.to_native(:CCS_RESULT_SUCCESS) - rescue => e - CCS.set_error(e) - end - } - else - nil - end - return [delwrapper, get_childwrapper, serializewrapper, deserializewrapper] - end - attach_function :ccs_create_dynamic_tree_space, [:string, :ccs_tree_t, :ccs_feature_space_t, :ccs_rng_t, DynamicTreeSpaceVector.by_ref, :value, :pointer], :ccs_result_t attach_function :ccs_dynamic_tree_space_get_tree_space_data, [:ccs_tree_space_t, :pointer], :ccs_result_t @@ -201,35 +144,78 @@ def initialize(handle = nil, retain: false, auto_release: true, super(handle, retain: retain, auto_release: auto_release) else raise CCSError, :CCS_RESULT_ERROR_INVALID_VALUE if get_child.nil? - wrappers = CCS.wrap_dynamic_tree_space_callbacks(del, get_child, serialize, deserialize) - delwrapper, get_childwrapper, serializewrapper, deserializewrapper = wrappers - vector = DynamicTreeSpaceVector::new - vector[:del] = delwrapper - vector[:get_child] = get_childwrapper - vector[:serialize] = serializewrapper - vector[:deserialize] = deserializewrapper + vector = DynamicTreeSpace.get_vector(del: del, get_child: get_child, serialize: serialize, deserialize: deserialize) ptr = MemoryPointer::new(:ccs_tree_space_t) CCS.error_check CCS.ccs_create_dynamic_tree_space(name, tree, feature_space, rng, vector, tree_space_data, ptr) h = ptr.read_ccs_tree_space_t super(h, retain: false) - CCS.register_vector(h, wrappers) + FFI.inc_ref(vector) FFI.inc_ref(tree_space_data) unless tree_space_data.nil? end end - def self.deserialize(del: nil, get_child: nil, serialize: nil, deserialize: nil, tree_space_data: nil, format: :binary, handle_map: nil, path: nil, buffer: nil, file_descriptor: nil, callback: nil, callback_data: nil) - raise CCSError, :CCS_RESULT_ERROR_INVALID_VALUE if get_child.nil? - wrappers = CCS.wrap_dynamic_tree_space_callbacks(del, get_child, serialize, deserialize) - delwrapper, get_childwrapper, serializewrapper, deserializewrapper = wrappers + def self.get_vector(del: nil, get_child: nil, serialize: nil, deserialize: nil) vector = DynamicTreeSpaceVector::new + delwrapper = lambda { |ts| + begin + o = CCS::Object.from_handle(ts) + tsdata = o.tree_space_data + del.call(o) if del + FFI.dec_ref(tsdata) unless tsdata.nil? + FFI.dec_ref(vector) + CCSError.to_native(:CCS_RESULT_SUCCESS) + rescue => e + CCS.set_error(e) + end + } + get_childwrapper = lambda { |ts, parent, index, p_child| + begin + child = get_child.call(TreeSpace.from_handle(ts), Tree.from_handle(parent), index) + CCS.error_check CCS.ccs_retain_object(child.handle) + Pointer.new(p_child).write_pointer(child.handle) + CCSError.to_native(:CCS_RESULT_SUCCESS) + rescue => e + CCS.set_error(e) + end + } + serializewrapper = + if serialize + lambda { |ts, state_size, p_state, p_state_size| + begin + state = serialize(TreeSpace.from_handle(ts), state_size == 0 ? true : false) + raise CCSError, :CCS_RESULT_ERROR_INVALID_VALUE if !p_state.null? && state_size < state.size + p_state.write_bytes(state.read_bytes(state.size)) unless p_state.null? + Pointer.new(p_state_size).write_size_t(state.size) unless p_state_size.null? + CCSError.to_native(:CCS_RESULT_SUCCESS) + rescue => e + CCS.set_error(e) + end + } + else + nil + end + deserializewrapper = + if deserialize + lambda { |t, feature_space, state_size, p_state, p_tree_space_data| + begin + state = p_state.null? ? nil : p_state.slice(0, state_size) + tree_space_data = deserialize(Tree.from_handle(t), feature_space.null? ? nil : FeatureSpace.from_handle(feature_space), state) + p_tree_space_data.write_value(tree_space_data) + FFI.inc_ref(tree_space_data) + CCSError.to_native(:CCS_RESULT_SUCCESS) + rescue => e + CCS.set_error(e) + end + } + else + nil + end vector[:del] = delwrapper vector[:get_child] = get_childwrapper vector[:serialize] = serializewrapper vector[:deserialize] = deserializewrapper - res = super(format: format, handle_map: handle_map, vector: vector.to_ptr, data: tree_space_data, path: path, buffer: buffer, file_descriptor: file_descriptor, callback: callback, callback_data: callback_data) - CCS.register_vector(res.handle, wrappers) - FFI.inc_ref(tree_space_data) unless tree_space_data.nil? - res + vector.wrappers = [delwrapper, get_childwrapper, serializewrapper, deserializewrapper] + vector end end diff --git a/bindings/ruby/lib/cconfigspace/tuner.rb b/bindings/ruby/lib/cconfigspace/tuner.rb index 3ab74feb..841cef3f 100644 --- a/bindings/ruby/lib/cconfigspace/tuner.rb +++ b/bindings/ruby/lib/cconfigspace/tuner.rb @@ -125,6 +125,7 @@ def initialize(handle = nil, retain: false, auto_release: true, callback :ccs_user_defined_tuner_deserialize, [:ccs_objective_space_t, :size_t, :pointer, :size_t, :pointer, :size_t, :pointer, :pointer], :ccs_result_t class UserDefinedTunerVector < FFI::Struct + attr_accessor :wrappers layout :del, :ccs_user_defined_tuner_del, :ask, :ccs_user_defined_tuner_ask, :tell, :ccs_user_defined_tuner_tell, @@ -136,133 +137,6 @@ class UserDefinedTunerVector < FFI::Struct end typedef UserDefinedTunerVector.by_value, :ccs_user_defined_tuner_vector_t - def self.wrap_user_defined_tuner_callbacks(del, ask, tell, get_optima, get_history, suggest, serialize, deserialize) - delwrapper = lambda { |tun| - begin - o = CCS::Object.from_handle(tun) - tdata = o.tuner_data - del.call(o) if del - FFI.dec_ref(tdata) unless tdata.nil? - CCS.unregister_vector(tun) - CCSError.to_native(:CCS_RESULT_SUCCESS) - rescue => e - CCS.set_error(e) - end - } - askwrapper = lambda { |tun, features, count, p_configurations, p_count| - begin - configurations, count_ret = ask.call(Tuner.from_handle(tun), features.null? ? nil : Features.from_handle(features), p_configurations.null? ? nil : count) - raise CCSError, :CCS_RESULT_ERROR_INVALID_VALUE if !p_configurations.null? && count < count_ret - if !p_configurations.null? - configurations.each_with_index { |c, i| - err = CCS.ccs_retain_object(c.handle) - CCS.error_check(err) - p_configurations.put_pointer(i*8, c.handle) - } - (count_ret...count).each { |i| p_configurations[i].put_pointer(i*8, 0) } - end - Pointer.new(p_count).write_size_t(count_ret) unless p_count.null? - CCSError.to_native(:CCS_RESULT_SUCCESS) - rescue => e - CCS.set_error(e) - end - } - tellwrapper = lambda { |tun, count, p_evaluations| - begin - if count > 0 - evals = count.times.collect { |i| Evaluation::from_handle(p_evaluations.get_pointer(i*8)) } - tell.call(Tuner.from_handle(tun), evals) - end - CCSError.to_native(:CCS_RESULT_SUCCESS) - rescue => e - CCS.set_error(e) - end - } - get_optimawrapper = lambda { |tun, features, count, p_evaluations, p_count| - begin - optima = get_optima.call(Tuner.from_handle(tun), features.null? ? nil : Features.from_handle(features)) - raise CCSError, :CCS_RESULT_ERROR_INVALID_VALUE if !p_evaluations.null? && count < optima.size - unless p_evaluations.null? - optima.each_with_index { |o, i| - p_evaluations.put_pointer(8*i, o.handle) - } - ((optima.size)...count).each { |i| p_evaluations.put_pointer(8*i, 0) } - end - Pointer.new(p_count).write_size_t(optima.size) unless p_count.null? - CCSError.to_native(:CCS_RESULT_SUCCESS) - rescue => e - CCS.set_error(e) - end - } - get_historywrapper = lambda { |tun, features, count, p_evaluations, p_count| - begin - history = get_history.call(Tuner.from_handle(tun), features.null? ? nil : Features.from_handle(features)) - raise CCSError, :CCS_RESULT_ERROR_INVALID_VALUE if !p_evaluations.null? && count < history.size - unless p_evaluations.null? - history.each_with_index { |e, i| - p_evaluations.put_pointer(8*i, e.handle) - } - ((history.size)...count).each { |i| p_evaluations.put_pointer(8*i, 0) } - end - Pointer.new(p_count).write_size_t(history.size) unless p_count.null? - CCSError.to_native(:CCS_RESULT_SUCCESS) - rescue => e - CCS.set_error(e) - end - } - suggestwrapper = - if suggest - lambda { |tun, features, p_configuration| - begin - configuration = suggest.call(Tuner.from_handle(tun), features.null? ? nil : Features.from_handle(features)) - err = CCS.ccs_retain_object(configuration.handle) - CCS.error_check(err) - p_configuration.write_pointer(configuration.handle) - CCSError.to_native(:CCS_RESULT_SUCCESS) - rescue => e - CCS.set_error(e) - end - } - else - nil - end - serializewrapper = - if serialize - lambda { |tun, state_size, p_state, p_state_size| - begin - state = serialize(Tuner.from_handle(tun), state_size == 0 ? true : false) - raise CCSError, :CCS_RESULT_ERROR_INVALID_VALUE if !p_state.null? && state_size < state.size - p_state.write_bytes(state.read_bytes(state.size)) unless p_state.null? - Pointer.new(p_state_size).write_size_t(state.size) unless p_state_size.null? - CCSError.to_native(:CCS_RESULT_SUCCESS) - rescue => e - CCS.set_error(e) - end - } - else - nil - end - deserializewrapper = - if deserialize - lambda { |o_space, history_size, p_history, num_optima, p_optima, state_size, p_state, p_tuner_data| - begin - history = p_history.null? ? [] : history_size.times.collect { |i| Evaluation::from_handle(p_p_history.get_pointer(i*8)) } - optima = p_optima.null? ? [] : num_optima.times.collect { |i| Evaluation::from_handle(p_optima.get_pointer(i*8)) } - state = p_state.null? ? nil : p_state.slice(0, state_size) - tuner_data = deserialize(ObjectiveSpace.from_handle(o_space), history, optima, state) - p_tuner_data.write_value(tuner_data) - FFI.inc_ref(tuner_data) - CCSError.to_native(:CCS_RESULT_SUCCESS) - rescue => e - CCS.set_error(e) - end - } - else - nil - end - return [delwrapper, askwrapper, tellwrapper, get_optimawrapper, get_historywrapper, suggestwrapper, serializewrapper, deserializewrapper] - end - attach_function :ccs_create_user_defined_tuner, [:string, :ccs_objective_space_t, UserDefinedTunerVector.by_ref, :value, :pointer], :ccs_result_t attach_function :ccs_user_defined_tuner_get_tuner_data, [:ccs_tuner_t, :pointer], :ccs_result_t class UserDefinedTuner < Tuner @@ -275,31 +149,141 @@ def initialize(handle = nil, retain: false, auto_release: true, super(handle, retain: retain, auto_release: auto_release) else raise CCSError, :CCS_RESULT_ERROR_INVALID_VALUE if ask.nil? || tell.nil? || get_optima.nil? || get_history.nil? - wrappers = CCS.wrap_user_defined_tuner_callbacks(del, ask, tell, get_optima, get_history, suggest, serialize, deserialize) - delwrapper, askwrapper, tellwrapper, get_optimawrapper, get_historywrapper, suggestwrapper, serializewrapper, deserializewrapper = wrappers - vector = UserDefinedTunerVector::new - vector[:del] = delwrapper - vector[:ask] = askwrapper - vector[:tell] = tellwrapper - vector[:get_optima] = get_optimawrapper - vector[:get_history] = get_historywrapper - vector[:suggest] = suggestwrapper - vector[:serialize] = serializewrapper - vector[:deserialize] = deserializewrapper + vector = UserDefinedTuner.get_vector(del: del, ask: ask, tell: tell, get_optima: get_optima, get_history: get_history, suggest: suggest, serialize: serialize, deserialize: deserialize) ptr = MemoryPointer::new(:ccs_tuner_t) CCS.error_check CCS.ccs_create_user_defined_tuner(name, objective_space, vector, tuner_data, ptr) handle = ptr.read_ccs_tuner_t super(handle, retain: false) - CCS.register_vector(handle, wrappers) + FFI.inc_ref(vector) FFI.inc_ref(tuner_data) unless tuner_data.nil? end end - def self.deserialize(del: nil, ask: nil, tell: nil, get_optima: nil, get_history: nil, suggest: nil, serialize: nil, deserialize: nil, tuner_data: nil, format: :binary, handle_map: nil, path: nil, buffer: nil, file_descriptor: nil, callback: nil, callback_data: nil) - raise CCSError, :CCS_RESULT_ERROR_INVALID_VALUE if ask.nil? || tell.nil? || get_optima.nil? || get_history.nil? - wrappers = CCS.wrap_user_defined_tuner_callbacks(del, ask, tell, get_optima, get_history, suggest, serialize, deserialize) - delwrapper, askwrapper, tellwrapper, get_optimawrapper, get_historywrapper, suggestwrapper, serializewrapper, deserializewrapper = wrappers + def self.get_vector(del: nil, ask: nil, tell: nil, get_optima: nil, get_history: nil, suggest: nil, serialize: nil, deserialize: nil) vector = UserDefinedTunerVector::new + delwrapper = lambda { |tun| + begin + o = CCS::Object.from_handle(tun) + tdata = o.tuner_data + del.call(o) if del + FFI.dec_ref(tdata) unless tdata.nil? + FFI.dec_ref(vector) + CCSError.to_native(:CCS_RESULT_SUCCESS) + rescue => e + CCS.set_error(e) + end + } + askwrapper = lambda { |tun, features, count, p_configurations, p_count| + begin + configurations, count_ret = ask.call(Tuner.from_handle(tun), features.null? ? nil : Features.from_handle(features), p_configurations.null? ? nil : count) + raise CCSError, :CCS_RESULT_ERROR_INVALID_VALUE if !p_configurations.null? && count < count_ret + if !p_configurations.null? + configurations.each_with_index { |c, i| + err = CCS.ccs_retain_object(c.handle) + CCS.error_check(err) + p_configurations.put_pointer(i*8, c.handle) + } + (count_ret...count).each { |i| p_configurations[i].put_pointer(i*8, 0) } + end + Pointer.new(p_count).write_size_t(count_ret) unless p_count.null? + CCSError.to_native(:CCS_RESULT_SUCCESS) + rescue => e + CCS.set_error(e) + end + } + tellwrapper = lambda { |tun, count, p_evaluations| + begin + if count > 0 + evals = count.times.collect { |i| Evaluation::from_handle(p_evaluations.get_pointer(i*8)) } + tell.call(Tuner.from_handle(tun), evals) + end + CCSError.to_native(:CCS_RESULT_SUCCESS) + rescue => e + CCS.set_error(e) + end + } + get_optimawrapper = lambda { |tun, features, count, p_evaluations, p_count| + begin + optima = get_optima.call(Tuner.from_handle(tun), features.null? ? nil : Features.from_handle(features)) + raise CCSError, :CCS_RESULT_ERROR_INVALID_VALUE if !p_evaluations.null? && count < optima.size + unless p_evaluations.null? + optima.each_with_index { |o, i| + p_evaluations.put_pointer(8*i, o.handle) + } + ((optima.size)...count).each { |i| p_evaluations.put_pointer(8*i, 0) } + end + Pointer.new(p_count).write_size_t(optima.size) unless p_count.null? + CCSError.to_native(:CCS_RESULT_SUCCESS) + rescue => e + CCS.set_error(e) + end + } + get_historywrapper = lambda { |tun, features, count, p_evaluations, p_count| + begin + history = get_history.call(Tuner.from_handle(tun), features.null? ? nil : Features.from_handle(features)) + raise CCSError, :CCS_RESULT_ERROR_INVALID_VALUE if !p_evaluations.null? && count < history.size + unless p_evaluations.null? + history.each_with_index { |e, i| + p_evaluations.put_pointer(8*i, e.handle) + } + ((history.size)...count).each { |i| p_evaluations.put_pointer(8*i, 0) } + end + Pointer.new(p_count).write_size_t(history.size) unless p_count.null? + CCSError.to_native(:CCS_RESULT_SUCCESS) + rescue => e + CCS.set_error(e) + end + } + suggestwrapper = + if suggest + lambda { |tun, features, p_configuration| + begin + configuration = suggest.call(Tuner.from_handle(tun), features.null? ? nil : Features.from_handle(features)) + err = CCS.ccs_retain_object(configuration.handle) + CCS.error_check(err) + p_configuration.write_pointer(configuration.handle) + CCSError.to_native(:CCS_RESULT_SUCCESS) + rescue => e + CCS.set_error(e) + end + } + else + nil + end + serializewrapper = + if serialize + lambda { |tun, state_size, p_state, p_state_size| + begin + state = serialize(Tuner.from_handle(tun), state_size == 0 ? true : false) + raise CCSError, :CCS_RESULT_ERROR_INVALID_VALUE if !p_state.null? && state_size < state.size + p_state.write_bytes(state.read_bytes(state.size)) unless p_state.null? + Pointer.new(p_state_size).write_size_t(state.size) unless p_state_size.null? + CCSError.to_native(:CCS_RESULT_SUCCESS) + rescue => e + CCS.set_error(e) + end + } + else + nil + end + deserializewrapper = + if deserialize + lambda { |o_space, history_size, p_history, num_optima, p_optima, state_size, p_state, p_tuner_data| + begin + history = p_history.null? ? [] : history_size.times.collect { |i| Evaluation::from_handle(p_p_history.get_pointer(i*8)) } + optima = p_optima.null? ? [] : num_optima.times.collect { |i| Evaluation::from_handle(p_optima.get_pointer(i*8)) } + state = p_state.null? ? nil : p_state.slice(0, state_size) + tuner_data = deserialize(ObjectiveSpace.from_handle(o_space), history, optima, state) + p_tuner_data.write_value(tuner_data) + FFI.inc_ref(tuner_data) + CCSError.to_native(:CCS_RESULT_SUCCESS) + rescue => e + CCS.set_error(e) + end + } + else + nil + end vector[:del] = delwrapper vector[:ask] = askwrapper vector[:tell] = tellwrapper @@ -308,11 +292,10 @@ def self.deserialize(del: nil, ask: nil, tell: nil, get_optima: nil, get_history vector[:suggest] = suggestwrapper vector[:serialize] = serializewrapper vector[:deserialize] = deserializewrapper - res = super(format: format, handle_map: handle_map, vector: vector.to_ptr, data: tuner_data, path: path, buffer: buffer, file_descriptor: file_descriptor, callback: callback, callback_data: callback_data) - CCS.register_vector(res.handle, wrappers) - FFI.inc_ref(tuner_data) unless tuner_data.nil? - res + vector.wrappers = [delwrapper, askwrapper, tellwrapper, get_optimawrapper, get_historywrapper, suggestwrapper, serializewrapper, deserializewrapper] + vector end + end Tuner::UserDefined = UserDefinedTuner diff --git a/bindings/ruby/test/test_features_tuner.rb b/bindings/ruby/test/test_features_tuner.rb index 83fb45c5..ec1a9bfb 100644 --- a/bindings/ruby/test/test_features_tuner.rb +++ b/bindings/ruby/test/test_features_tuner.rb @@ -125,6 +125,13 @@ def test_user_defined optis.sample.configuration end } + get_vector_data = lambda { |otype, name, cb_data| + assert_equal(:CCS_OBJECT_TYPE_TUNER, otype) + assert_equal("tuner", name) + assert_nil(cb_data) + [CCS::UserDefinedTuner.get_vector(del: del, ask: ask, tell: tell, get_optima: get_optima, get_history: get_history, suggest: suggest), TunerData.new] + } + fs, os = create_tuning_problem t = CCS::UserDefinedTuner::new(name: "tuner", objective_space: os, del: del, ask: ask, tell: tell, get_optima: get_optima, get_history: get_history, suggest: suggest, tuner_data: TunerData.new) t2 = CCS::Object::from_handle(t) @@ -160,7 +167,7 @@ def test_user_defined } buff = t.serialize - t_copy = CCS::UserDefinedTuner::deserialize(buffer: buff, del: del, ask: ask, tell: tell, get_optima: get_optima, get_history: get_history, suggest: suggest, tuner_data: TunerData.new) + t_copy = CCS::deserialize(buffer: buff, vector_callback: get_vector_data) hist = t_copy.history assert_equal(200, hist.size) assert_equal(t.num_optima, t_copy.num_optima) diff --git a/bindings/ruby/test/test_tree_space.rb b/bindings/ruby/test/test_tree_space.rb index bdd0c714..d846adf4 100644 --- a/bindings/ruby/test/test_tree_space.rb +++ b/bindings/ruby/test/test_tree_space.rb @@ -54,6 +54,12 @@ def test_dynamic_tree_space arity = 0 if arity < 0 CCS::Tree.new(arity: arity, value: (4 - child_depth)*100 + child_index) } + get_vector_data = lambda { |otype, name, cb_data| + assert_equal(:CCS_OBJECT_TYPE_TREE_SPACE, otype) + assert_equal('space', name) + assert_nil(cb_data) + [CCS::DynamicTreeSpace.get_vector(del: del, get_child: get_child), nil] + } tree = CCS::Tree.new(arity: 4, value: 400) ts = CCS::DynamicTreeSpace.new(name: 'space', tree: tree, del: del, get_child: get_child) @@ -74,7 +80,7 @@ def test_dynamic_tree_space } buff = ts.serialize - ts2 = CCS::DynamicTreeSpace.deserialize(buffer: buff, del: del, get_child: get_child) + ts2 = CCS::deserialize(buffer: buff, vector_callback: get_vector_data) assert_equal( [400, 301, 201], ts2.get_values_at_position([1, 1]) ) end diff --git a/bindings/ruby/test/test_tree_tuner.rb b/bindings/ruby/test/test_tree_tuner.rb index bbd8a498..e79a8d9d 100644 --- a/bindings/ruby/test/test_tree_tuner.rb +++ b/bindings/ruby/test/test_tree_tuner.rb @@ -115,6 +115,13 @@ def test_user_defined tuner.tuner_data.optima.sample.configuration end } + get_vector_data = lambda { |otype, name, cb_data| + assert_equal(:CCS_OBJECT_TYPE_TUNER, otype) + assert_equal("tuner", name) + assert_nil(cb_data) + [CCS::UserDefinedTuner.get_vector(del: del, ask: ask, tell: tell, get_optima: get_optima, get_history: get_history, suggest: suggest), TreeTunerData.new] + } + os = create_tuning_problem t = CCS::UserDefinedTuner.new(name: "tuner", objective_space: os, del: del, ask: ask, tell: tell, get_optima: get_optima, get_history: get_history, suggest: suggest, tuner_data: TreeTunerData.new) t2 = CCS::Object::from_handle(t) @@ -139,7 +146,7 @@ def test_user_defined assert_equal(hist.map { |e| e.objective_values.first }.max, best) assert(optims.map(&:configuration).include?(t.suggest)) buff = t.serialize - t_copy = CCS::UserDefinedTuner.deserialize(buffer: buff, del: del, ask: ask, tell: tell, get_optima: get_optima, get_history: get_history, suggest: suggest, tuner_data: TreeTunerData.new) + t_copy = CCS.deserialize(buffer: buff, vector_callback: get_vector_data) hist = t_copy.history assert_equal(200, hist.size) optims_2 = t_copy.optima diff --git a/bindings/ruby/test/test_tuner.rb b/bindings/ruby/test/test_tuner.rb index b4d9082a..5d6b4323 100644 --- a/bindings/ruby/test/test_tuner.rb +++ b/bindings/ruby/test/test_tuner.rb @@ -106,6 +106,13 @@ def test_user_defined tuner.tuner_data.optima.sample.configuration end } + get_vector_data = lambda { |otype, name, cb_data| + assert_equal(:CCS_OBJECT_TYPE_TUNER, otype) + assert_equal("tuner", name) + assert_nil(cb_data) + [CCS::UserDefinedTuner.get_vector(del: del, ask: ask, tell: tell, get_optima: get_optima, get_history: get_history, suggest: suggest), TunerData.new] + } + os = create_tuning_problem t = CCS::UserDefinedTuner::new(name: "tuner", objective_space: os, del: del, ask: ask, tell: tell, get_optima: get_optima, get_history: get_history, suggest: suggest, tuner_data: TunerData.new) t2 = CCS::Object::from_handle(t) @@ -132,7 +139,7 @@ def test_user_defined assert( t.optima.collect(&:configuration).include?(t.suggest) ) buff = t.serialize - t_copy = CCS::UserDefinedTuner::deserialize(buffer: buff, del: del, ask: ask, tell: tell, get_optima: get_optima, get_history: get_history, suggest: suggest, tuner_data: TunerData.new) + t_copy = CCS::deserialize(buffer: buff, vector_callback: get_vector_data) hist = t_copy.history assert_equal(200, hist.size) assert_equal(t.num_optima, t_copy.num_optima) @@ -141,7 +148,7 @@ def test_user_defined assert( t_copy.optima.collect(&:configuration).include?(t_copy.suggest) ) t.serialize(path: 'tuner.ccs') - t_copy = CCS::UserDefinedTuner::deserialize(path: 'tuner.ccs', del: del, ask: ask, tell: tell, get_optima: get_optima, get_history: get_history, suggest: suggest, tuner_data: TunerData.new) + t_copy = CCS::deserialize(path: 'tuner.ccs', vector_callback: get_vector_data) hist = t_copy.history assert_equal(200, hist.size) assert_equal(t.num_optima, t_copy.num_optima) @@ -154,7 +161,7 @@ def test_user_defined t.serialize(file_descriptor: f.fileno) f.close f = File.open('tuner.ccs', "rb") - t_copy = CCS::UserDefinedTuner::deserialize(file_descriptor: f.fileno, del: del, ask: ask, tell: tell, get_optima: get_optima, get_history: get_history, suggest: suggest, tuner_data: TunerData.new) + t_copy = CCS::deserialize(file_descriptor: f.fileno, vector_callback: get_vector_data) f.close hist = t_copy.history assert_equal(200, hist.size) diff --git a/include/cconfigspace/base.h b/include/cconfigspace/base.h index 8580feab..98e60d07 100644 --- a/include/cconfigspace/base.h +++ b/include/cconfigspace/base.h @@ -965,9 +965,9 @@ enum ccs_serialize_option_e { typedef enum ccs_serialize_option_e ccs_serialize_option_t; /** - * The type of CCS object deserialization callbacks. - * This callback is used to deserialize object information that were created by - * the serialization callback. + * The type of CCS object data deserialization callbacks. This callback is + * used to deserialize object data information that were created by the + * serialization callback. * @param[in, out] object a CCS object * @param[in] serialize_data_size the size of the memory pointed to by * \p serialize_data @@ -981,12 +981,41 @@ typedef enum ccs_serialize_option_e ccs_serialize_option_t; * @remarks * This function must be thread-safe for serialization to be thread safe. */ -typedef ccs_result_t (*ccs_object_deserialize_callback_t)( +typedef ccs_result_t (*ccs_object_deserialize_data_callback_t)( ccs_object_t object, size_t serialize_data_size, const char *serialize_data, void *callback_user_data); +/** + * The type of CCS user object vector deserialization callbacks. + * This callback is used to obtain the vector to use for a user defined object. + * @param[in] type the type of CCS object for which the vector is required + * @param[in] name the name of the object as it was orignally provided at + * creation + * @param[in] callback_user_data the pointer provided when the callback + * was attached. + * @param[out] vector_ret a pointer that will hold the returned pointer to the + * vector structure + * @param[out] data_ret a pointer that will optionally hold the user data value + * to use when initializing the user object. + * If the provided vector holds the relevant + * deserialization callback, this value will be + * ignored. + * @return #CCS_RESULT_SUCCESS on success + * @return #CCS_RESULT_ERROR_INVALID_TYPE if the object type is not expected by + * the callback + * @return #CCS_RESULT_ERROR_INVALID_NAME if the object name is not recognized + * @remarks + * This function must be thread-safe for deserialization to be thread safe. + */ +typedef ccs_result_t (*ccs_object_deserialize_vector_callback_t)( + ccs_object_type_t type, + const char *name, + void *callback_user_data, + void **vector_ret, + void **data_ret); + /** * The different deserialization options. */ @@ -1001,15 +1030,11 @@ enum ccs_deserialize_option_e { */ CCS_DESERIALIZE_OPTION_HANDLE_MAP, /** - * The next parameter is a pointer to a ccs object vector struct, for - * user defined tuners - */ - CCS_DESERIALIZE_OPTION_VECTOR, - /** - * The next parameter is a pointer to a ccs object internal data, for - * user defined tuners + * The next parameter is a pointer to a callback of type + * ccs_object_deserialize_vector_callback_t and its user data, that + * will return the ccs object vector struct, for user defined objects */ - CCS_DESERIALIZE_OPTION_DATA, + CCS_DESERIALIZE_OPTION_VECTOR_CALLBACK, /** * The file descriptor operation is non-blocking. The next parameter is * a pointer to a void * variable (initialized to NULL) that will hold @@ -1019,12 +1044,12 @@ enum ccs_deserialize_option_e { */ CCS_DESERIALIZE_OPTION_NON_BLOCKING, /** - * The next parameters are a deserialization callback and it's + * The next parameters are a deserialization callback and its * user_data. This callback will be called for all objects that had * their user_data serialized. If no such callback is provided the * object's user_data value will not be set. */ - CCS_DESERIALIZE_OPTION_CALLBACK, + CCS_DESERIALIZE_OPTION_DATA_CALLBACK, /** Guard */ CCS_DESERIALIZE_OPTION_MAX, /** Try forcing 32 bits value for bindings */ diff --git a/src/cconfigspace_deserialize.h b/src/cconfigspace_deserialize.h index ff71bdbb..6221093f 100644 --- a/src/cconfigspace_deserialize.h +++ b/src/cconfigspace_deserialize.h @@ -54,12 +54,11 @@ _ccs_object_deserialize_options( CCS_CHECK_OBJ(opts->handle_map, CCS_OBJECT_TYPE_MAP); opts->map_values = CCS_TRUE; break; - case CCS_DESERIALIZE_OPTION_VECTOR: - opts->vector = va_arg(args, void *); - CCS_CHECK_PTR(opts->vector); - break; - case CCS_DESERIALIZE_OPTION_DATA: - opts->data = va_arg(args, void *); + case CCS_DESERIALIZE_OPTION_VECTOR_CALLBACK: + opts->deserialize_vector_callback = + va_arg(args, ccs_object_deserialize_vector_callback_t); + CCS_CHECK_PTR(opts->deserialize_vector_callback); + opts->deserialize_vector_user_data = va_arg(args, void *); break; case CCS_DESERIALIZE_OPTION_NON_BLOCKING: CCS_REFUTE( @@ -70,11 +69,11 @@ _ccs_object_deserialize_options( va_arg(args, _ccs_file_descriptor_state_t **); CCS_CHECK_PTR(opts->ppfd_state); break; - case CCS_DESERIALIZE_OPTION_CALLBACK: - opts->deserialize_callback = - va_arg(args, ccs_object_deserialize_callback_t); - CCS_CHECK_PTR(opts->deserialize_callback); - opts->deserialize_user_data = va_arg(args, void *); + case CCS_DESERIALIZE_OPTION_DATA_CALLBACK: + opts->deserialize_data_callback = + va_arg(args, ccs_object_deserialize_data_callback_t); + CCS_CHECK_PTR(opts->deserialize_data_callback); + opts->deserialize_data_user_data = va_arg(args, void *); break; default: CCS_RAISE( diff --git a/src/cconfigspace_internal.h b/src/cconfigspace_internal.h index e503162d..4fb4c980 100644 --- a/src/cconfigspace_internal.h +++ b/src/cconfigspace_internal.h @@ -1299,10 +1299,10 @@ struct _ccs_object_deserialize_options_s { ccs_map_t handle_map; ccs_bool_t map_values; _ccs_file_descriptor_state_t **ppfd_state; - void *vector; - void *data; - ccs_object_deserialize_callback_t deserialize_callback; - void *deserialize_user_data; + ccs_object_deserialize_vector_callback_t deserialize_vector_callback; + void *deserialize_vector_user_data; + ccs_object_deserialize_data_callback_t deserialize_data_callback; + void *deserialize_data_user_data; }; typedef struct _ccs_object_deserialize_options_s _ccs_object_deserialize_options_t; @@ -1449,10 +1449,10 @@ _ccs_object_deserialize_user_data( CCS_VALIDATE(_ccs_deserialize_bin_size( &serialize_data_size, buffer_size, buffer)); if (serialize_data_size) { - if (opts->deserialize_callback) - CCS_VALIDATE(opts->deserialize_callback( + if (opts->deserialize_data_callback) + CCS_VALIDATE(opts->deserialize_data_callback( object, serialize_data_size, *buffer, - opts->deserialize_user_data)); + opts->deserialize_data_user_data)); *buffer_size -= serialize_data_size; *buffer += serialize_data_size; } diff --git a/src/tree_space_deserialize.h b/src/tree_space_deserialize.h index f9808126..3414ffb3 100644 --- a/src/tree_space_deserialize.h +++ b/src/tree_space_deserialize.h @@ -88,27 +88,38 @@ _ccs_deserialize_bin_tree_space_dynamic( _ccs_object_deserialize_options_t *opts, _ccs_tree_space_common_data_mock_t *data) { - _ccs_blob_t blob = {0, NULL}; - ccs_dynamic_tree_space_vector_t *vector = - (ccs_dynamic_tree_space_vector_t *)opts->vector; - ccs_result_t res = CCS_RESULT_SUCCESS; + _ccs_blob_t blob = {0, NULL}; + ccs_dynamic_tree_space_vector_t *vector = NULL; + void *tree_space_data = NULL; + ccs_result_t res = CCS_RESULT_SUCCESS; + CCS_VALIDATE_ERR_GOTO( res, _ccs_deserialize_bin_ccs_tree_space_common_data( data, version, buffer_size, buffer, opts), end); - CCS_VALIDATE(_ccs_deserialize_bin_ccs_blob(&blob, buffer_size, buffer)); + CCS_VALIDATE_ERR_GOTO( + res, + _ccs_deserialize_bin_ccs_blob( + &blob, buffer_size, buffer), + end); + + CCS_VALIDATE_ERR_GOTO( + res, + opts->deserialize_vector_callback( + CCS_OBJECT_TYPE_TREE_SPACE, + data->name, + opts->deserialize_vector_user_data, + (void**)&vector, &tree_space_data), + end); - void *tree_space_data; if (vector->deserialize_state) CCS_VALIDATE_ERR_GOTO( res, vector->deserialize_state( data->tree, data->feature_space, blob.sz, blob.blob, &tree_space_data), - tree_space); - else - tree_space_data = opts->data; + end); CCS_VALIDATE_ERR_GOTO( res, @@ -117,10 +128,6 @@ _ccs_deserialize_bin_tree_space_dynamic( data->feature_space, data->rng, vector, tree_space_data, tree_space_ret), end); - goto end; -tree_space: - ccs_release_object(*tree_space_ret); - *tree_space_ret = NULL; end: if (data->feature_space) ccs_release_object(data->feature_space); @@ -144,6 +151,8 @@ _ccs_deserialize_bin_tree_space( ccs_tree_space_type_t stype; CCS_VALIDATE( _ccs_peek_bin_ccs_tree_space_type(&stype, buffer_size, buffer)); + if (stype == CCS_TREE_SPACE_TYPE_DYNAMIC) + CCS_CHECK_PTR(opts->deserialize_vector_callback); _ccs_tree_space_common_data_mock_t data = { CCS_TREE_SPACE_TYPE_STATIC, NULL, NULL, NULL, NULL, NULL}; diff --git a/src/tuner_deserialize.h b/src/tuner_deserialize.h index d3335457..40293814 100644 --- a/src/tuner_deserialize.h +++ b/src/tuner_deserialize.h @@ -179,16 +179,25 @@ _ccs_deserialize_bin_user_defined_tuner( NULL, NULL}, {0, NULL}}; - ccs_user_defined_tuner_vector_t *vector = - (ccs_user_defined_tuner_vector_t *)opts->vector; - ccs_result_t res = CCS_RESULT_SUCCESS; + ccs_user_defined_tuner_vector_t *vector = NULL; + void *tuner_data = NULL; + ccs_result_t res = CCS_RESULT_SUCCESS; + CCS_VALIDATE_ERR_GOTO( res, _ccs_deserialize_bin_ccs_user_defined_tuner_data( &data, version, buffer_size, buffer, opts), end); - void *tuner_data; + CCS_VALIDATE_ERR_GOTO( + res, + opts->deserialize_vector_callback( + CCS_OBJECT_TYPE_TUNER, + data.base_data.common_data.name, + opts->deserialize_vector_user_data, + (void **)&vector, &tuner_data), + end); + if (vector->deserialize_state) CCS_VALIDATE_ERR_GOTO( res, @@ -200,8 +209,6 @@ _ccs_deserialize_bin_user_defined_tuner( data.base_data.optima, data.blob.sz, data.blob.blob, &tuner_data), end); - else - tuner_data = opts->data; CCS_VALIDATE_ERR_GOTO( res, @@ -247,7 +254,7 @@ _ccs_deserialize_bin_tuner( ccs_tuner_type_t ttype; CCS_VALIDATE(_ccs_peek_bin_ccs_tuner_type(&ttype, buffer_size, buffer)); if (ttype == CCS_TUNER_TYPE_USER_DEFINED) - CCS_CHECK_PTR(opts->vector); + CCS_CHECK_PTR(opts->deserialize_vector_callback); new_opts.map_values = CCS_TRUE; CCS_VALIDATE(ccs_create_map(&new_opts.handle_map)); diff --git a/tests/test_categorical_parameter.c b/tests/test_categorical_parameter.c index 2104fd21..8f11658c 100644 --- a/tests/test_categorical_parameter.c +++ b/tests/test_categorical_parameter.c @@ -179,7 +179,7 @@ test_create(void) err = ccs_object_deserialize( (ccs_object_t *)¶meter, CCS_SERIALIZE_FORMAT_BINARY, CCS_SERIALIZE_OPERATION_MEMORY, buff_size, buff, - CCS_DESERIALIZE_OPTION_CALLBACK, &deserialize_callback, + CCS_DESERIALIZE_OPTION_DATA_CALLBACK, &deserialize_callback, (void *)0xbeefdead, CCS_DESERIALIZE_OPTION_END); assert(err == CCS_RESULT_SUCCESS); free(buff); diff --git a/tests/test_dynamic_tree_space.c b/tests/test_dynamic_tree_space.c index adf0cb01..3aa5c68f 100644 --- a/tests/test_dynamic_tree_space.c +++ b/tests/test_dynamic_tree_space.c @@ -35,6 +35,30 @@ my_tree_get_child( return CCS_RESULT_SUCCESS; } +ccs_dynamic_tree_space_vector_t vector = { + &my_tree_del, &my_tree_get_child, NULL, NULL}; + +ccs_result_t +deserialize_vector_callback( + ccs_object_type_t type, + const char *name, + void *callback_user_data, + void **vector_ret, + void **data_ret) +{ + (void)name; + (void)callback_user_data; + switch (type) { + case CCS_OBJECT_TYPE_TREE_SPACE: + *vector_ret = (void *)&vector; + *data_ret = NULL; + break; + default: + return CCS_RESULT_ERROR_INVALID_TYPE; + } + return CCS_RESULT_SUCCESS; +} + void test_dynamic_tree_space(void) { @@ -51,8 +75,6 @@ test_dynamic_tree_space(void) const char *name; ccs_tree_configuration_t config, configs[NUM_SAMPLES]; - ccs_dynamic_tree_space_vector_t vector = { - &my_tree_del, &my_tree_get_child, NULL, NULL}; err = ccs_create_tree(4, ccs_int(4 * 100), &root); assert(err == CCS_RESULT_SUCCESS); err = ccs_create_rng(&rng); @@ -176,7 +198,8 @@ test_dynamic_tree_space(void) err = ccs_object_deserialize( (ccs_object_t *)&tree_space, CCS_SERIALIZE_FORMAT_BINARY, CCS_SERIALIZE_OPERATION_MEMORY, buff_size, buff, - CCS_DESERIALIZE_OPTION_VECTOR, &vector, + CCS_DESERIALIZE_OPTION_VECTOR_CALLBACK, + &deserialize_vector_callback, (void*)NULL, CCS_DESERIALIZE_OPTION_END); assert(err == CCS_RESULT_SUCCESS); free(buff); diff --git a/tests/test_user_defined_features_tuner.c b/tests/test_user_defined_features_tuner.c index d319c55b..46a86bd9 100644 --- a/tests/test_user_defined_features_tuner.c +++ b/tests/test_user_defined_features_tuner.c @@ -147,6 +147,28 @@ ccs_user_defined_tuner_vector_t tuner_last_vector = { NULL, NULL}; +ccs_result_t +deserialize_vector_callback( + ccs_object_type_t type, + const char *name, + void *callback_user_data, + void **vector_ret, + void **data_ret) +{ + (void)name; + (void)callback_user_data; + switch (type) { + case CCS_OBJECT_TYPE_TUNER: + *vector_ret = (void *)&tuner_last_vector; + *data_ret = calloc(1, sizeof(tuner_last_t)); + assert(*data_ret); + break; + default: + return CCS_RESULT_ERROR_INVALID_TYPE; + } + return CCS_RESULT_SUCCESS; +} + void test(void) { @@ -263,15 +285,12 @@ test(void) CCS_SERIALIZE_OPTION_END); assert(err == CCS_RESULT_SUCCESS); - tuner_data = (tuner_last_t *)calloc(1, sizeof(tuner_last_t)); - assert(tuner_data); - err = ccs_object_deserialize( (ccs_object_t *)&tuner_copy, CCS_SERIALIZE_FORMAT_BINARY, CCS_SERIALIZE_OPERATION_MEMORY, buff_size, buff, CCS_DESERIALIZE_OPTION_HANDLE_MAP, map, - CCS_DESERIALIZE_OPTION_VECTOR, &tuner_last_vector, - CCS_DESERIALIZE_OPTION_DATA, tuner_data, + CCS_DESERIALIZE_OPTION_VECTOR_CALLBACK, + &deserialize_vector_callback, (void*)NULL, CCS_DESERIALIZE_OPTION_END); assert(err == CCS_RESULT_SUCCESS); diff --git a/tests/test_user_defined_tree_tuner.c b/tests/test_user_defined_tree_tuner.c index 8bfa2977..64c6fb4d 100644 --- a/tests/test_user_defined_tree_tuner.c +++ b/tests/test_user_defined_tree_tuner.c @@ -193,6 +193,28 @@ ccs_user_defined_tuner_vector_t tuner_last_vector = { NULL, NULL}; +ccs_result_t +deserialize_vector_callback( + ccs_object_type_t type, + const char *name, + void *callback_user_data, + void **vector_ret, + void **data_ret) +{ + (void)name; + (void)callback_user_data; + switch (type) { + case CCS_OBJECT_TYPE_TUNER: + *vector_ret = (void *)&tuner_last_vector; + *data_ret = calloc(1, sizeof(tuner_last_t)); + assert(*data_ret); + break; + default: + return CCS_RESULT_ERROR_INVALID_TYPE; + } + return CCS_RESULT_SUCCESS; +} + void test(void) { @@ -274,15 +296,12 @@ test(void) CCS_SERIALIZE_OPTION_END); assert(err == CCS_RESULT_SUCCESS); - tuner_data = (tuner_last_t *)calloc(1, sizeof(tuner_last_t)); - assert(tuner_data); - err = ccs_object_deserialize( (ccs_object_t *)&tuner_copy, CCS_SERIALIZE_FORMAT_BINARY, CCS_SERIALIZE_OPERATION_MEMORY, buff_size, buff, CCS_DESERIALIZE_OPTION_HANDLE_MAP, map, - CCS_DESERIALIZE_OPTION_VECTOR, &tuner_last_vector, - CCS_DESERIALIZE_OPTION_DATA, tuner_data, + CCS_DESERIALIZE_OPTION_VECTOR_CALLBACK, + &deserialize_vector_callback, (void*)NULL, CCS_DESERIALIZE_OPTION_END); assert(err == CCS_RESULT_SUCCESS); diff --git a/tests/test_user_defined_tuner.c b/tests/test_user_defined_tuner.c index e56e9d1c..c590fddd 100644 --- a/tests/test_user_defined_tuner.c +++ b/tests/test_user_defined_tuner.c @@ -140,6 +140,28 @@ ccs_user_defined_tuner_vector_t tuner_last_vector = { NULL, NULL}; +ccs_result_t +deserialize_vector_callback( + ccs_object_type_t type, + const char *name, + void *callback_user_data, + void **vector_ret, + void **data_ret) +{ + (void)name; + (void)callback_user_data; + switch (type) { + case CCS_OBJECT_TYPE_TUNER: + *vector_ret = (void *)&tuner_last_vector; + *data_ret = calloc(1, sizeof(tuner_last_t)); + assert(*data_ret); + break; + default: + return CCS_RESULT_ERROR_INVALID_TYPE; + } + return CCS_RESULT_SUCCESS; +} + void test(void) { @@ -217,15 +239,12 @@ test(void) CCS_SERIALIZE_OPTION_END); assert(err == CCS_RESULT_SUCCESS); - tuner_data = (tuner_last_t *)calloc(1, sizeof(tuner_last_t)); - assert(tuner_data); - err = ccs_object_deserialize( (ccs_object_t *)&tuner_copy, CCS_SERIALIZE_FORMAT_BINARY, CCS_SERIALIZE_OPERATION_MEMORY, buff_size, buff, CCS_DESERIALIZE_OPTION_HANDLE_MAP, map, - CCS_DESERIALIZE_OPTION_VECTOR, &tuner_last_vector, - CCS_DESERIALIZE_OPTION_DATA, tuner_data, + CCS_DESERIALIZE_OPTION_VECTOR_CALLBACK, + &deserialize_vector_callback, (void*)NULL, CCS_DESERIALIZE_OPTION_END); assert(err == CCS_RESULT_SUCCESS);