diff --git a/bindings/python/Makefile.am b/bindings/python/Makefile.am index 574feb22..11d69202 100644 --- a/bindings/python/Makefile.am +++ b/bindings/python/Makefile.am @@ -48,6 +48,6 @@ LIBEXT = .so endif check: FORCE - PYTHONPATH=$(srcdir)/ LIBCCONFIGSPACE_SO_=$(top_builddir)/src/.libs/libcconfigspace$(LIBEXT) python3 -m unittest discover -s $(srcdir)/test/ + PYTHONPATH=$(srcdir)/ LIBCCONFIGSPACE_SO_=$(abs_top_builddir)/src/.libs/libcconfigspace$(LIBEXT) python3 -m unittest discover -s $(srcdir)/test/ FORCE: diff --git a/bindings/python/cconfigspace/base.py b/bindings/python/cconfigspace/base.py index ad559566..dcbafcc0 100644 --- a/bindings/python/cconfigspace/base.py +++ b/bindings/python/cconfigspace/base.py @@ -1,5 +1,5 @@ import ctypes as ct -import json +import pickle import sys import traceback from . import libcconfigspace @@ -106,8 +106,8 @@ def __repr__(self): class CEnumeration(ct.c_int, metaclass=CEnumerationType): _members_ = {} - def __init__(self, value): - ct.c_int.__init__(self, value) + def __init__(*args): + ct.c_int.__init__(*args) @property def name(self): @@ -395,10 +395,9 @@ class DeserializeOption(CEnumeration): _members_ = [ ('END', 0), 'HANDLE_MAP', - 'VECTOR', - 'DATA', + 'VECTOR_CALLBACK', 'NON_BLOCKING', - 'CALLBACK' + 'DATA_CALLBACK' ] def _ccs_get_function(method, argtypes = [], restype = Result): @@ -422,7 +421,8 @@ def _ccs_get_function(method, argtypes = [], restype = Result): ccs_object_get_user_data = _ccs_get_function("ccs_object_get_user_data", [ccs_object, ct.POINTER(ct.c_void_p)]) ccs_object_serialize_callback_type = ct.CFUNCTYPE(Result, ccs_object, ct.c_size_t, ct.c_void_p, ct.POINTER(ct.c_size_t), ct.c_void_p) ccs_object_set_serialize_callback = _ccs_get_function("ccs_object_set_serialize_callback", [ccs_object, ccs_object_serialize_callback_type, ct.c_void_p]) -ccs_object_deserialize_callback_type = ct.CFUNCTYPE(Result, ccs_object, ct.c_size_t, ct.c_void_p, ct.c_void_p) +ccs_object_deserialize_data_callback_type = ct.CFUNCTYPE(Result, ccs_object, ct.c_size_t, ct.c_void_p, ct.POINTER(ct.py_object)) +ccs_object_deserialize_vector_callback_type = ct.CFUNCTYPE(Result, ObjectType, ct.c_char_p, ct.c_void_p, ct.POINTER(ct.c_void_p), ct.POINTER(ct.py_object)) # Variadic methods ccs_object_serialize = getattr(libcconfigspace, "ccs_object_serialize") ccs_object_serialize.argtypes = ccs_object, SerializeFormat, SerializeOperation, @@ -520,10 +520,10 @@ def from_handle(cls, h, retain = True): auto_release = True return cls._from_handle(h, retain, auto_release) - def set_destroy_callback(self, callback, user_data = None): - _set_destroy_callback(self.handle, callback, user_data = user_data) + def set_destroy_callback(self, callback): + _set_destroy_callback(self.handle, callback) - def serialize(self, format = 'binary', path = None, file_descriptor = None, callback = None, callback_data = None): + def serialize(self, format = 'binary', path = None, file_descriptor = None, callback = None): if format != 'binary': raise Error(Result(Result.ERROR_INVALID_VALUE)) if path and file_descriptor: @@ -532,7 +532,7 @@ def serialize(self, format = 'binary', path = None, file_descriptor = None, call if callback: cb_wrapper = _get_serialize_callback_wrapper(callback) cb_wrapper_func = ccs_object_serialize_callback_type(cb_wrapper) - options = [SerializeOption.CALLBACK, cb_wrapper_func, ct.py_object(callback_data)] + options + options = [SerializeOption.CALLBACK, cb_wrapper_func, ct.py_object()] + options elif _default_user_data_serializer: options = [SerializeOption.CALLBACK, _default_user_data_serializer, ct.py_object()] + options if path: @@ -556,7 +556,7 @@ def serialize(self, format = 'binary', path = None, file_descriptor = None, call return v @classmethod - def deserialize(cls, format = 'binary', handle_map = None, vector = None, data = None, path = None, buffer = None, file_descriptor = None, callback = None, callback_data = None): + def deserialize(cls, format = 'binary', handle_map = None, vector_callback = None, path = None, buffer = None, file_descriptor = None, callback = None): if format != 'binary': raise Error(Result(Result.ERROR_INVALID_VALUE)) mode_count = 0; @@ -572,16 +572,16 @@ def deserialize(cls, format = 'binary', handle_map = None, vector = None, data = options = [DeserializeOption.END] if handle_map: options = [DeserializeOption.HANDLE_MAP, handle_map.handle] + options - if vector: - options = [DeserializeOption.VECTOR, ct.byref(vector)] + options - if data: - options = [DeserializeOption.DATA, ct.py_object(data)] + options + if vector_callback: + vector_cb_wrapper = _get_deserialize_vector_callback_wrapper(vector_callback) + vector_cb_wrapper_func = ccs_object_deserialize_vector_callback_type(vector_cb_wrapper) + options = [DeserializeOption.VECTOR_CALLBACK, vector_cb_wrapper_func, ct.py_object()] + options if callback: - cb_wrapper = _get_deserialize_callback_wrapper(callback) - cb_wrapper_func = ccs_object_deserialize_callback_type(cb_wrapper) - options = [DeserializeOption.CALLBACK, cb_wrapper_func, ct.py_object(callback_data)] + options + cb_wrapper = _get_deserialize_data_callback_wrapper(callback) + cb_wrapper_func = ccs_object_deserialize_data_callback_type(cb_wrapper) + options = [DeserializeOption.DATA_CALLBACK, cb_wrapper_func, ct.py_object()] + options elif _default_user_data_deserializer: - options = [DeserializeOption.CALLBACK, _default_user_data_deserializer, ct.py_object()] + options + options = [DeserializeOption.DATA_CALLBACK, _default_user_data_deserializer, ct.py_object()] + options if buffer: s = ct.c_size_t(ct.sizeof(buffer)) res = ccs_object_deserialize(ct.byref(o), SerializeFormat.BINARY, SerializeOperation.MEMORY, s, buffer, *options) @@ -602,73 +602,58 @@ def deserialize(cls, format = 'binary', handle_map = None, vector = None, data = def _get_serialize_callback_wrapper(callback): def serialize_callback_wrapper(obj, serialize_data_size, serialize_data, serialize_data_size_ret, cb_data): try: - p_sd = ct.cast(serialize_data, ct.c_void_p) - p_sdsz = ct.cast(serialize_data_size_ret, ct.POINTER(ct.c_size_t)) - cb_data = ct.cast(cb_data, ct.c_void_p) - if cb_data: - cb_data = ct.cast(cb_data, ct.py_object).value - else: - cb_data = None - serialized = callback(Object.from_handle(ccs_object(obj)), cb_data, True if serialize_data_size == 0 else False) - if p_sd and serialize_data_size < ct.sizeof(serialized): + serialized = callback(Object.from_handle(obj).user_data) + state = ct.create_string_buffer(serialized, len(serialized)) + if serialize_data and serialize_data_size < ct.sizeof(state): raise Error(Result(Result.ERROR_INVALID_VALUE)) - if p_sd: - ct.memmove(p_sd, ct.byref(serialized), ct.sizeof(serialized)) - if p_sdsz: - p_sdsz[0] = ct.sizeof(serialized) + if serialize_data: + ct.memmove(serialize_data, ct.byref(state), ct.sizeof(state)) + if serialize_data_size_ret: + serialize_data_size_ret[0] = ct.sizeof(state) return Result.SUCCESS except Exception as e: return Error.set_error(e) return serialize_callback_wrapper -def _get_deserialize_callback_wrapper(callback): - def deserialize_callback_wrapper(obj, serialize_data_size, serialize_data, cb_data): +def _get_deserialize_data_callback_wrapper(callback): + def deserialize_data_callback_wrapper(obj, serialize_data_size, p_serialize_data, cb_data): try: - p_sd = ct.cast(serialize_data, ct.c_void_p) - cb_data = ct.cast(cb_data, ct.c_void_p) - if cb_data: - cb_data = ct.cast(cb_data, ct.py_object).value - else: - cb_data = None - if p_sd: - serialized = ct.cast(p_sd, ct.POINTER(ct.c_byte * serialize_data_size)) - else: - serialized = None - callback(Object.from_handle(ccs_object(obj)), serialized, cb_data) + user_data = callback(ct.string_at(p_serialize_data, serialize_data_size)) + Object.from_handle(ccs_object(obj)).user_data = user_data return Result.SUCCESS except Exception as e: return Error.set_error(e) - return deserialize_callback_wrapper + return deserialize_data_callback_wrapper -def _json_user_data_serializer(obj, data, size): - string = json.dumps(obj.user_data).encode("ascii") - return ct.create_string_buffer(string) - -def _json_user_data_deserializer(obj, serialized, data): - serialized = ct.cast(serialized, ct.c_char_p) - obj.user_data = json.loads(serialized.value) +def _get_deserialize_vector_callback_wrapper(callback): + def deserialize_vector_callback_wrapper(obj_type, name, callback_user_data, vector_ret, data_ret): + try: + (vector, data) = callback(obj_type.value, name.decode()) + c_vector = ct.py_object(vector) + c_data = ct.py_object(data) + vector_ret[0] = ct.cast(ct.byref(vector), ct.c_void_p) + data_ret[0] = c_data + ct.pythonapi.Py_IncRef(c_vector) + ct.pythonapi.Py_IncRef(c_data) + return Result.SUCCESS + except Exception as e: + return Error.set_error(e) + return deserialize_vector_callback_wrapper -_json_user_data_serializer_wrap = _get_serialize_callback_wrapper(_json_user_data_serializer) -_json_user_data_serializer_func = ccs_object_serialize_callback_type(_json_user_data_serializer_wrap) +def _pickle_user_data_serializer(user_data): + return pickle.dumps(user_data) -_json_user_data_deserializer_wrap = _get_deserialize_callback_wrapper(_json_user_data_deserializer) -_json_user_data_deserializer_func = ccs_object_deserialize_callback_type(_json_user_data_deserializer_wrap) +def _pickle_user_data_deserializer(serialized): + return pickle.loads(serialized) -_default_user_data_serializer = _json_user_data_serializer_func -_default_user_data_deserializer = _json_user_data_deserializer_func +_pickle_user_data_serializer_wrap = _get_serialize_callback_wrapper(_pickle_user_data_serializer) +_pickle_user_data_serializer_func = ccs_object_serialize_callback_type(_pickle_user_data_serializer_wrap) -# Delete wrappers are responsible for deregistering the object data_store -def _register_vector(handle, vector_data): - value = handle.value - if value in _data_store: - raise Error(Result(Result.ERROR_INVALID_VALUE)) - _data_store[value] = dict.fromkeys(['callbacks', 'user_data', 'serialize_calback', 'strings']) - _data_store[value]['callbacks'] = vector_data - _data_store[value]['strings'] = [] +_pickle_user_data_deserializer_wrap = _get_deserialize_data_callback_wrapper(_pickle_user_data_deserializer) +_pickle_user_data_deserializer_func = ccs_object_deserialize_data_callback_type(_pickle_user_data_deserializer_wrap) -def _unregister_vector(handle): - value = handle.value - del _data_store[value] +_default_user_data_serializer = _pickle_user_data_serializer_func +_default_user_data_deserializer = _pickle_user_data_deserializer_func # If objects don't have a user-defined del operation, then the first time a # data needs to be registered a destruction callback is attached. @@ -707,20 +692,20 @@ def _register_serialize_callback(handle, callback_data): _register_destroy_callback(handle) _data_store[value]['serialize_calback'] = callback_data -def deserialize(format = 'binary', handle_map = None, path = None, buffer = None, callback = None, callback_data = None): - return Object.deserialize(format = format, handle_map = handle_map, path = path, buffer = buffer, callback = callback, callback_data = callback_data) +def deserialize(format = 'binary', handle_map = None, path = None, buffer = None, file_descriptor = None, vector_callback = None, vector_callback_data = None, callback = None, callback_data = None): + return Object.deserialize(format = format, handle_map = handle_map, path = path, buffer = buffer, file_descriptor = file_descriptor, vector_callback = vector_callback, callback = callback) -def _set_destroy_callback(handle, callback, user_data = None): +def _set_destroy_callback(handle, callback): if callback is None: raise Error(Result(Result.ERROR_INVALID_VALUE)) def cb_wrapper(obj, data): callback(Object.from_handle(obj), data) cb_wrapper_func = ccs_object_destroy_callback_type(cb_wrapper) - res = ccs_object_set_destroy_callback(handle, cb_wrapper_func, user_data) + res = ccs_object_set_destroy_callback(handle, cb_wrapper_func, None) Error.check(res) - _register_callback(handle, [callback, cb_wrapper, cb_wrapper_func, user_data]) + _register_callback(handle, [callback, cb_wrapper, cb_wrapper_func]) -def _set_serialize_callback(handle, callback, user_data = None): +def _set_serialize_callback(handle, callback): if callback is None: res = ccs_object_set_serialize_callback(handle, None, None) Error.check(res) @@ -728,9 +713,9 @@ def _set_serialize_callback(handle, callback, user_data = None): else: cb_wrapper = _get_serialize_callback_wrapper(callback) cb_wrapper_func = ccs_object_serialize_callback_type(cb_wrapper) - res = ccs_object_set_serialize_callback(handle, cb_wrapper_func, user_data) + res = ccs_object_set_serialize_callback(handle, cb_wrapper_func, None) Error.check(res) - _register_serialize_callback(handle, [callback, cb_wrapper, cb_wrapper_func, user_data]) + _register_serialize_callback(handle, [callback, cb_wrapper, cb_wrapper_func]) _ccs_id = 0 def _ccs_get_id(): diff --git a/bindings/python/cconfigspace/configuration_space.py b/bindings/python/cconfigspace/configuration_space.py index 2de0bbb6..1f4822a0 100644 --- a/bindings/python/cconfigspace/configuration_space.py +++ b/bindings/python/cconfigspace/configuration_space.py @@ -22,7 +22,7 @@ class ConfigurationSpace(Context): def __init__(self, handle = None, retain = False, auto_release = True, - name = "", parameters = None, conditions = None, forbidden_clauses = None, feature_space = None, rng = None): + name = "", parameters = None, conditions = None, forbidden_clauses = None, feature_space = None, rng = None, binding = None): if handle is None: count = len(parameters) @@ -30,11 +30,12 @@ def __init__(self, handle = None, retain = False, auto_release = True, if feature_space is not None: ctx_params = ctx_params + list(feature_space.parameters) ctx = dict(zip([x.name for x in ctx_params], ctx_params)) + extra = (ctx, binding) if forbidden_clauses is not None: numfc = len(forbidden_clauses) if numfc > 0: - forbidden_clauses = [ parser.parse(fc, extra = ctx) if isinstance(fc, str) else fc for fc in forbidden_clauses ] + forbidden_clauses = [ parser.parse(fc, extra = extra) if isinstance(fc, str) else fc for fc in forbidden_clauses ] fcv = (ccs_expression * numfc)(*[x.handle.value for x in forbidden_clauses]) else: fcv = None @@ -45,7 +46,7 @@ def __init__(self, handle = None, retain = False, auto_release = True, if conditions is not None: indexdict = dict(reversed(ele) for ele in enumerate(parameters)) cv = (ccs_expression * count)() - conditions = dict( (k, parser.parse(v, extra = ctx) if isinstance(v, str) else v) for (k, v) in conditions.items() ) + conditions = dict( (k, parser.parse(v, extra = extra) if isinstance(v, str) else v) for (k, v) in conditions.items() ) for (k, v) in conditions.items(): if isinstance(k, Parameter): cv[indexdict[k]] = v.handle.value diff --git a/bindings/python/cconfigspace/expression.py b/bindings/python/cconfigspace/expression.py index b47f31a5..9911399d 100644 --- a/bindings/python/cconfigspace/expression.py +++ b/bindings/python/cconfigspace/expression.py @@ -24,7 +24,8 @@ class ExpressionType(CEnumeration): 'IN', 'LIST', 'LITERAL', - 'VARIABLE' ] + 'VARIABLE', + 'USER_DEFINED' ] class AssociativityType(CEnumeration): _members_ = [ @@ -171,6 +172,13 @@ def __str__(self): else: return "{} {} {}".format(nds[0], symbol, nds[1]) + @classmethod + def get_function_vector_data(cls, name, binding = {}): + proc = binding[name] + def evaluate(expr, *args): + return proc(*args) + return (Expression.UserDefined.get_vector(evaluate = evaluate), None) + class ExpressionOr(Expression): def __init__(self, handle = None, retain = False, auto_release = True, @@ -476,6 +484,156 @@ def __str__(self): Expression.List = ExpressionList +ccs_user_defined_expression_del_type = ct.CFUNCTYPE(Result, ccs_expression) +ccs_user_defined_expression_eval_type = ct.CFUNCTYPE(Result, ccs_expression, ct.c_size_t, ct.POINTER(Datum), ct.POINTER(Datum)) +ccs_user_defined_expression_serialize_type = ct.CFUNCTYPE(Result, ccs_expression, ct.c_size_t, ct.c_void_p, ct.POINTER(ct.c_size_t)) +ccs_user_defined_expression_deserialize_type = ct.CFUNCTYPE(Result, ct.c_size_t, ct.c_void_p, ct.POINTER(ct.py_object)) + +class UserDefinedExpressionVector(ct.Structure): + _fields_ = [ + ('delete', ccs_user_defined_expression_del_type), + ('evaluate', ccs_user_defined_expression_eval_type), + ('serialize', ccs_user_defined_expression_serialize_type), + ('deserialize', ccs_user_defined_expression_deserialize_type) ] + +ccs_create_user_defined_expression = _ccs_get_function("ccs_create_user_defined_expression", [ct.c_char_p, ct.c_size_t, ct.POINTER(Datum), ct.POINTER(UserDefinedExpressionVector), ct.py_object, ct.POINTER(ccs_expression)]) +ccs_user_defined_expression_get_name = _ccs_get_function("ccs_user_defined_expression_get_name", [ccs_expression, ct.POINTER(ct.c_char_p)]) +ccs_user_defined_expression_get_expression_data = _ccs_get_function("ccs_user_defined_expression_get_expression_data", [ccs_expression, ct.POINTER(ct.c_void_p)]) + +class ExpressionUserDefined(Expression): + def __init__(self, handle = None, retain = False, auto_release = True, + name = "", nodes = [], delete = None, evaluate = None, serialize = None, deserialize = None, expression_data = None): + if handle is None: + if evaluate is None: + raise Error(Result(Result.ERROR_INVALID_VALUE)) + + vec = self.get_vector(delete, evaluate, serialize, deserialize) + c_expression_data = None + if expression_data is not None: + c_expression_data = ct.py_object(expression_data) + handle = ccs_expression() + sz = len(nodes) + v = (Datum*sz)() + ss = [] + for i in range(sz): + v[i].set_value(nodes[i], string_store = ss) + res = ccs_create_user_defined_expression(str.encode(name), sz, v, ct.byref(vec), c_expression_data, ct.byref(handle)) + Error.check(res) + super().__init__(handle = handle, retain = False) + ct.pythonapi.Py_IncRef(ct.py_object(vec)) + if c_expression_data is not None: + ct.pythonapi.Py_IncRef(c_expression_data) + else: + super().__init__(handle = handle, retain = retain, auto_release = auto_release) + + @property + def name(self): + if hasattr(self, "_name"): + return self._name + v = ct.c_char_p() + res = ccs_user_defined_expression_get_name(self.handle, ct.byref(v)) + Error.check(res) + self._name = v.value.decode() + return self._name + + @property + def expression_data(self): + if hasattr(self, "_expression_data"): + return self._expression_data + v = ct.c_void_p() + res = ccs_user_defined_expression_get_expression_data(self.handle, ct.byref(v)) + Error.check(res) + if v: + self._expression_data = ct.cast(v, ct.py_object).value + else: + self._expression_data = None + return self._expression_data + + def __str__(self): + return "{}({})".format(self.name, ", ".join(map(str, self.nodes))) + + @classmethod + def get_vector(self, delete = None, evaluate = None, serialize = None, deserialize = None): + vec = UserDefinedExpressionVector() + setattr(vec, '_string_store', list()) + setattr(vec, '_object_store', list()) + def delete_wrapper(expr): + try: + o = Object.from_handle(expr) + edata = o.expression_data + if delete is not None: + delete(o) + if edata is not None: + ct.pythonapi.Py_DecRef(ct.py_object(edata)) + ct.pythonapi.Py_DecRef(ct.py_object(vec)) + return Result.SUCCESS + except Exception as e: + return Error.set_error(e) + + def evaluate_wrapper(expr, num_values, p_values, p_value_ret): + try: + if num_values == 0: + value_ret = evaluate(Expression.from_handle(expr)) + else: + values = tuple(p_values[i].value for i in range(num_values)) + value_ret = evaluate(Expression.from_handle(expr), *values) + p_value_ret[0].set_value(value_ret, string_store = getattr(vec, '_string_store'), object_store = getattr(vec, '_object_store')) + return Result.SUCCESS + except Exception as e: + return Error.set_error(e) + + if serialize is not None: + def serialize_wrapper(expr, state_size, p_state, p_state_size): + try: + serialized = serialize(Expression.from_handle(expr)) + state = ct.create_string_buffer(serialized, len(serialized)) + if p_state and state_size < ct.sizeof(state): + raise Error(Result(Result.ERROR_INVALID_VALUE)) + if p_state: + ct.memmove(p_state, ct.byref(state), ct.sizeof(state)) + if p_state_size: + p_state_size[0] = ct.sizeof(state) + return Result.SUCCESS + except Exception as e: + return Error.set_error(e) + else: + serialize_wrapper = 0 + + if deserialize is not None: + def deserialize_wrapper(state_size, p_state, p_expression_data): + try: + expression_data = deserialize(ct.string_at(p_state, state_size)) + c_expression_data = ct.py_object(expression_data) + p_expression_data[0] = c_expression_data + ct.pythonapi.Py_IncRef(c_expression_data) + return Result.SUCCESS + except Exception as e: + return Error.set_error(e) + else: + deserialize_wrapper = 0 + + delete_wrapper_func = ccs_user_defined_expression_del_type(delete_wrapper) + evaluate_wrapper_func = ccs_user_defined_expression_eval_type(evaluate_wrapper) + serialize_wrapper_func = ccs_user_defined_expression_serialize_type(serialize_wrapper) + deserialize_wrapper_func = ccs_user_defined_expression_deserialize_type(deserialize_wrapper) + vec.delete = delete_wrapper_func + vec.evaluate = evaluate_wrapper_func + vec.serialize = serialize_wrapper_func + vec.deserialize = deserialize_wrapper_func + + setattr(vec, '_wrappers', ( + delete_wrapper, + evaluate_wrapper, + serialize_wrapper, + deserialize_wrapper, + delete_wrapper_func, + evaluate_wrapper_func, + serialize_wrapper_func, + deserialize_wrapper_func)) + return vec + +Expression.UserDefined = ExpressionUserDefined + setattr(Expression, 'EXPRESSION_MAP', { ExpressionType.OR: ExpressionOr, ExpressionType.AND: ExpressionAnd, @@ -497,4 +655,5 @@ def __str__(self): ExpressionType.LIST: ExpressionList, ExpressionType.LITERAL: ExpressionLiteral, ExpressionType.VARIABLE: ExpressionVariable, + ExpressionType.USER_DEFINED: ExpressionUserDefined, }) diff --git a/bindings/python/cconfigspace/expression_parser.py b/bindings/python/cconfigspace/expression_parser.py index 1fdd866f..dfbaf603 100644 --- a/bindings/python/cconfigspace/expression_parser.py +++ b/bindings/python/cconfigspace/expression_parser.py @@ -18,16 +18,21 @@ | value; list: '[' list_item ']' | '[' ']'; +user_defined: identifier '(' list_item ')' + | identifier '(' ')'; list_item: list_item ',' value | value; value: none | true | false | string - | identifier + | user_defined + | variable | integer | float; +variable: identifier; + terminals none: /%s/ {%d}; true: /%s/ {%d}; @@ -43,9 +48,9 @@ ccs_terminal_regexp[TerminalType.STRING], ccs_terminal_precedence[TerminalType.STRING], ccs_terminal_regexp[TerminalType.INTEGER], ccs_terminal_precedence[TerminalType.INTEGER], ccs_terminal_regexp[TerminalType.FLOAT], ccs_terminal_precedence[TerminalType.FLOAT]) - _actions = {} _expr_actions = [ lambda _, n: n[1] ] + for i in range(ExpressionType.OR, ExpressionType.LIST): if ccs_expression_arity[i] == 1: _expr_actions.append((lambda s: lambda _, n: Expression.EXPRESSION_MAP[s](node = n[1]))(i)) @@ -57,14 +62,27 @@ lambda _, n: Expression.List(values = n[1]), lambda _, n: Expression.List(values = []) ] + +def wrap_user_function(proc): + def wrap(expr, *args): + return proc(*args) + return wrap + +_actions["user_defined"] = [ + lambda p, n: Expression.UserDefined(name = n[0], nodes = n[2], evaluate = wrap_user_function(p.extra[1][n[0]])), + lambda p, n: Expression.UserDefined(name = n[0], evaluate = wrap_user_function(p.extra[1][n[0]])) +] _actions["list_item"] = [ lambda _, n: n[0] + [n[2]], lambda _, n: [n[0]] ] +_actions["variable"] = [ + lambda p, n: Expression.Variable(parameter = p.extra[0][n[0]] if isinstance(p.extra[0], dict) else p.extra[0].parameter_by_name(n[0])) +] _actions["none"] = lambda _, value: Expression.Literal(value = None) _actions["true"] = lambda _, value: Expression.Literal(value = True) _actions["false"] = lambda _, value: Expression.Literal(value = False) -_actions["identifier"] = lambda p, value: Expression.Variable(parameter = p.extra[value] if isinstance(p.extra, dict) else p.extra.parameter_by_name(value)) +_actions["identifier"] = lambda _, value: value _actions["string"] = lambda _, value: Expression.Literal(value = eval(value)) _actions["float"] = lambda _, value: Expression.Literal(value = float(value)) _actions["integer"] = lambda _, value: Expression.Literal(value = int(value)) @@ -73,5 +91,5 @@ parser = Parser(_g, actions=_actions) -def parse(expr): - return parser.parse(expr) +def parse(expr, context = {}, binding = {}): + return parser.parse(expr, extra = (context, binding)) diff --git a/bindings/python/cconfigspace/objective_space.py b/bindings/python/cconfigspace/objective_space.py index a97f88ab..ea8ecc65 100644 --- a/bindings/python/cconfigspace/objective_space.py +++ b/bindings/python/cconfigspace/objective_space.py @@ -20,7 +20,7 @@ class ObjectiveType(CEnumeration): class ObjectiveSpace(Context): def __init__(self, handle = None, retain = False, auto_release = True, - name = "", search_space = None, parameters = [], objectives = [], types = None): + name = "", search_space = None, parameters = [], objectives = [], types = None, binding = None): if handle is None: count = len(parameters) @@ -34,7 +34,8 @@ def __init__(self, handle = None, retain = False, auto_release = True, if isinstance(objectives, dict): types = objectives.values() objectives = objectives.keys() - objectives = [ parser.parse(objective, extra = ctx) if isinstance(objective, str) else objective for objective in objectives ] + extra = (ctx, binding) + objectives = [ parser.parse(objective, extra = extra) if isinstance(objective, str) else objective for objective in objectives ] sz = len(objectives) if types: if len(types) != sz: diff --git a/bindings/python/cconfigspace/tree_space.py b/bindings/python/cconfigspace/tree_space.py index 1b43042f..9e993bcc 100644 --- a/bindings/python/cconfigspace/tree_space.py +++ b/bindings/python/cconfigspace/tree_space.py @@ -1,8 +1,9 @@ import ctypes as ct from . import libcconfigspace -from .base import Object, Error, Result, ccs_rng, ccs_tree, ccs_tree_space, ccs_feature_space, ccs_features, ccs_tree_configuration, Datum, ccs_bool, _ccs_get_function, CEnumeration, _register_vector, _unregister_vector, ccs_retain_object +from .base import Object, Error, Result, ccs_rng, ccs_tree, ccs_tree_space, ccs_feature_space, ccs_features, ccs_tree_configuration, Datum, ccs_bool, _ccs_get_function, CEnumeration, ccs_retain_object from .rng import Rng from .tree import Tree +from .feature_space import FeatureSpace class TreeSpaceType(CEnumeration): _members_ = [ @@ -165,7 +166,7 @@ def __init__(self, handle = None, retain = False, auto_release = True, ccs_dynamic_tree_space_del_type = ct.CFUNCTYPE(Result, ccs_tree_space) ccs_dynamic_tree_space_get_child_type = ct.CFUNCTYPE(Result, ccs_tree_space, ccs_tree, ct.c_size_t, ct.POINTER(ccs_tree)) ccs_dynamic_tree_space_serialize_type = ct.CFUNCTYPE(Result, ccs_tree_space, ct.c_size_t, ct.c_void_p, ct.POINTER(ct.c_size_t)) -ccs_dynamic_tree_space_deserialize_type = ct.CFUNCTYPE(Result, ccs_tree_space, ct.c_size_t, ct.c_void_p) +ccs_dynamic_tree_space_deserialize_type = ct.CFUNCTYPE(Result, ccs_tree, ccs_feature_space, ct.c_size_t, ct.c_void_p, ct.POINTER(ct.py_object)) class DynamicTreeSpaceVector(ct.Structure): _fields_ = [ @@ -177,73 +178,6 @@ class DynamicTreeSpaceVector(ct.Structure): ccs_create_dynamic_tree_space = _ccs_get_function("ccs_create_dynamic_tree_space", [ct.c_char_p, ccs_tree, ccs_feature_space, ccs_rng, ct.POINTER(DynamicTreeSpaceVector), ct.py_object, ct.POINTER(ccs_tree_space)]) ccs_dynamic_tree_space_get_tree_space_data = _ccs_get_function("ccs_dynamic_tree_space_get_tree_space_data", [ccs_tree_space, ct.POINTER(ct.c_void_p)]) -def _wrap_user_defined_callbacks(delete, get_child, serialize, deserialize): - def delete_wrapper(ts): - try: - ts = ct.cast(ts, ccs_tree_space) - if delete is not None: - delete(Object.from_handle(ts)) - _unregister_vector(ts) - return Result.SUCCESS - except Exception as e: - return Error.set_error(e) - - def get_child_wrapper(ts, parent, index, p_child): - try: - ts = ct.cast(ts, ccs_tree_space) - parent = ct.cast(parent, ccs_tree) - child = get_child(TreeSpace.from_handle(ts), Tree.from_handle(parent), index) - res = ccs_retain_object(child.handle) - Error.check(res) - p_child[0] = child.handle.value - return Result.SUCCESS - except Exception as e: - return Error.set_error(e) - - if serialize is not None: - def serialize_wrapper(ts, state_size, p_state, p_state_size): - try: - ts = ct.cast(ts, ccs_tree_space) - p_s = ct.cast(p_state, ct.c_void_p) - p_sz = ct.cast(p_state_size, ct.c_void_p) - state = serialize(TreeSpace.from_handle(ts), True if state_size == 0 else False) - if p_s.value is not None and state_size < ct.sizeof(state): - raise Error(Result(Result.ERROR_INVALID_VALUE)) - if p_s.value is not None: - ct.memmove(p_s, ct.byref(state), ct.sizeof(state)) - if p_sz.value is not None: - p_state_size[0] = ct.sizeof(state) - return Result.SUCCESS - except Exception as e: - return Error.set_error(e) - else: - serialize_wrapper = 0 - - if deserialize is not None: - def deserialize_wrapper(ts, state_size, p_state): - try: - ts = ct.cast(ts, ccs_tree_space) - p_s = ct.cast(p_state, ct.c_void_p) - if p_s.value is None: - state = None - else: - state = ct.cast(p_s, POINTER(c_byte * state_size)) - deserialize(TreeSpace.from_handle(ts), state) - return Result.SUCCESS - except Exception as e: - return Error.set_error(e) - else: - deserialize_wrapper = 0 - - return (delete_wrapper, - get_child_wrapper, - serialize_wrapper, - deserialize_wrapper, - ccs_dynamic_tree_space_del_type(delete_wrapper), - ccs_dynamic_tree_space_get_child_type(get_child_wrapper), - ccs_dynamic_tree_space_serialize_type(serialize_wrapper), - ccs_dynamic_tree_space_deserialize_type(deserialize_wrapper)) - class DynamicTreeSpace(TreeSpace): def __init__(self, handle = None, retain = False, auto_release = True, @@ -252,20 +186,7 @@ def __init__(self, handle = None, retain = False, auto_release = True, if get_child is None: raise Error(Result(Result.ERROR_INVALID_VALUE)) - (delete_wrapper, - get_child_wrapper, - serialize_wrapper, - deserialize_wrapper, - delete_wrapper_func, - get_child_wrapper_func, - serialize_wrapper_func, - deserialize_wrapper_func) = _wrap_user_defined_callbacks(delete, get_child, serialize, deserialize) - handle = ccs_tree_space() - vec = DynamicTreeSpaceVector() - vec.delete = delete_wrapper_func - vec.get_child = get_child_wrapper_func - vec.serialize = serialize_wrapper_func - vec.deserialize = deserialize_wrapper_func + vec = self.get_vector(delete, get_child, serialize, deserialize) if tree_space_data is not None: c_tree_space_data = ct.py_object(tree_space_data) else: @@ -274,35 +195,16 @@ def __init__(self, handle = None, retain = False, auto_release = True, rng = rng.handle if feature_space is not None: feature_space = feature_space.handle + handle = ccs_tree_space() res = ccs_create_dynamic_tree_space(str.encode(name), tree.handle, feature_space, rng, ct.byref(vec), c_tree_space_data, ct.byref(handle)) Error.check(res) super().__init__(handle = handle, retain = False) - _register_vector(handle, [delete_wrapper, get_child_wrapper, serialize_wrapper, deserialize_wrapper, delete_wrapper_func, get_child_wrapper_func, serialize_wrapper_func, deserialize_wrapper_func, tree_space_data]) + ct.pythonapi.Py_IncRef(ct.py_object(vec)) + if c_tree_space_data is not None: + ct.pythonapi.Py_IncRef(c_tree_space_data) else: super().__init__(handle = handle, retain = retain, auto_release = auto_release) - @classmethod - def deserialize(cls, delete, get_child, serialize = None, deserialize = None, tree_space_data = None, format = 'binary', handle_map = None, path = None, buffer = None, file_descriptor = None, callback = None, callback_data = None): - if get_child is None: - raise Error(Result(Result.ERROR_INVALID_VALUE)) - (delete_wrapper, - get_child_wrapper, - serialize_wrapper, - deserialize_wrapper, - delete_wrapper_func, - get_child_wrapper_func, - serialize_wrapper_func, - deserialize_wrapper_func) = _wrap_user_defined_callbacks(delete, get_child, serialize, deserialize) - handle = ccs_tree_space() - vector = DynamicTreeSpaceVector() - vector.delete = delete_wrapper_func - vector.get_child = get_child_wrapper_func - vector.serialize = serialize_wrapper_func - vector.deserialize = deserialize_wrapper_func - res = super().deserialize(format = format, handle_map = handle_map, vector = vector, data = tree_space_data, path = path, buffer = buffer, file_descriptor = file_descriptor, callback = callback, callback_data = callback_data) - _register_vector(res.handle, [delete_wrapper, get_child_wrapper, serialize_wrapper, deserialize_wrapper, delete_wrapper_func, get_child_wrapper_func, serialize_wrapper_func, deserialize_wrapper_func, tree_space_data]) - return res - @property def tree_space_data(self): if hasattr(self, "_tree_space_data"): @@ -316,6 +218,87 @@ def tree_space_data(self): self._tree_space_data = None return self._tree_space_data + @classmethod + def get_vector(self, delete = None, get_child = None, serialize = None, deserialize = None): + if get_child is None: + raise Error(Result(Result.ERROR_INVALID_VALUE)) + vec = DynamicTreeSpaceVector() + def delete_wrapper(ts): + try: + ts = ct.cast(ts, ccs_tree_space) + o = Object.from_handle(ts) + tsdata = o.tree_space_data + if delete is not None: + delete(o) + if tsdata is not None: + ct.pythonapi.Py_DecRef(ct.py_object(tsdata)) + ct.pythonapi.Py_DecRef(ct.py_object(vec)) + return Result.SUCCESS + except Exception as e: + return Error.set_error(e) + + def get_child_wrapper(ts, parent, index, p_child): + try: + ts = ct.cast(ts, ccs_tree_space) + parent = ct.cast(parent, ccs_tree) + child = get_child(TreeSpace.from_handle(ts), Tree.from_handle(parent), index) + res = ccs_retain_object(child.handle) + Error.check(res) + p_child[0] = child.handle.value + return Result.SUCCESS + except Exception as e: + return Error.set_error(e) + + if serialize is not None: + def serialize_wrapper(ts, state_size, p_state, p_state_size): + try: + serialized = serialize(TreeSpace.from_handle(ts)) + state = ct.create_string_buffer(serialized, len(serialized)) + if p_state and state_size < ct.sizeof(state): + raise Error(Result(Result.ERROR_INVALID_VALUE)) + if p_state: + ct.memmove(p_state, ct.byref(state), ct.sizeof(state)) + if p_state_size: + p_state_size[0] = ct.sizeof(state) + return Result.SUCCESS + except Exception as e: + return Error.set_error(e) + else: + serialize_wrapper = 0 + + if deserialize is not None: + def deserialize_wrapper(tree, feature_space, state_size, p_state, p_tree_space_data): + try: + tree_space_data = deserialize(Tree.from_handle(tree), FeatureSpace.from_handle(feature_space) if feature_space else None, ct.string_at(p_state, state_size)) + c_tree_space_data = ct.py_object(tree_space_data) + p_tree_space_data[0] = c_tree_space_data + ct.pythonapi.Py_IncRef(c_tree_space_data) + return Result.SUCCESS + except Exception as e: + return Error.set_error(e) + else: + deserialize_wrapper = 0 + + delete_wrapper_func = ccs_dynamic_tree_space_del_type(delete_wrapper) + get_child_wrapper_func = ccs_dynamic_tree_space_get_child_type(get_child_wrapper) + serialize_wrapper_func = ccs_dynamic_tree_space_serialize_type(serialize_wrapper) + deserialize_wrapper_func = ccs_dynamic_tree_space_deserialize_type(deserialize_wrapper) + vec.delete = delete_wrapper_func + vec.get_child = get_child_wrapper_func + vec.serialize = serialize_wrapper_func + vec.deserialize = deserialize_wrapper_func + + setattr(vec, '_wrappers', ( + delete_wrapper, + get_child_wrapper, + serialize_wrapper, + deserialize_wrapper, + delete_wrapper_func, + get_child_wrapper_func, + serialize_wrapper_func, + deserialize_wrapper_func)) + return vec + TreeSpace.Dynamic = DynamicTreeSpace from .tree_configuration import TreeConfiguration diff --git a/bindings/python/cconfigspace/tuner.py b/bindings/python/cconfigspace/tuner.py index f4df3bdb..67b8c5b8 100644 --- a/bindings/python/cconfigspace/tuner.py +++ b/bindings/python/cconfigspace/tuner.py @@ -1,5 +1,6 @@ import ctypes as ct -from .base import Object, Error, CEnumeration, Result, _ccs_get_function, ccs_context, ccs_parameter, ccs_search_space, ccs_search_configuration, ccs_feature_space, ccs_features, Datum, ccs_objective_space, ccs_evaluation, ccs_tuner, ccs_retain_object, _register_vector, _unregister_vector +import sys +from .base import Object, Error, CEnumeration, Result, _ccs_get_function, ccs_context, ccs_parameter, ccs_search_space, ccs_search_configuration, ccs_feature_space, ccs_features, Datum, ccs_objective_space, ccs_evaluation, ccs_tuner, ccs_retain_object from .context import Context from .parameter import Parameter from .features import Features @@ -167,7 +168,7 @@ def __init__(self, handle = None, retain = False, auto_release = True, ccs_user_defined_tuner_get_history_type = ct.CFUNCTYPE(Result, ccs_tuner, ccs_features, ct.c_size_t, ct.POINTER(ccs_evaluation), ct.POINTER(ct.c_size_t)) ccs_user_defined_tuner_suggest_type = ct.CFUNCTYPE(Result, ccs_tuner, ccs_features, ct.POINTER(ccs_search_configuration)) ccs_user_defined_tuner_serialize_type = ct.CFUNCTYPE(Result, ccs_tuner, ct.c_size_t, ct.c_void_p, ct.POINTER(ct.c_size_t)) -ccs_user_defined_tuner_deserialize_type = ct.CFUNCTYPE(Result, ccs_tuner, ct.c_size_t, ct.POINTER(ccs_evaluation), ct.c_size_t, ct.POINTER(ccs_evaluation), ct.c_size_t, ct.c_void_p) +ccs_user_defined_tuner_deserialize_type = ct.CFUNCTYPE(Result, ccs_objective_space, ct.c_size_t, ct.POINTER(ccs_evaluation), ct.c_size_t, ct.POINTER(ccs_evaluation), ct.c_size_t, ct.c_void_p, ct.POINTER(ct.py_object)) class UserDefinedTunerVector(ct.Structure): _fields_ = [ @@ -183,169 +184,6 @@ class UserDefinedTunerVector(ct.Structure): ccs_create_user_defined_tuner = _ccs_get_function("ccs_create_user_defined_tuner", [ct.c_char_p, ccs_objective_space, ct.POINTER(UserDefinedTunerVector), ct.py_object, ct.POINTER(ccs_tuner)]) ccs_user_defined_tuner_get_tuner_data = _ccs_get_function("ccs_user_defined_tuner_get_tuner_data", [ccs_tuner, ct.POINTER(ct.c_void_p)]) -def _wrap_user_defined_tuner_callbacks(delete, ask, tell, get_optima, get_history, suggest, serialize, deserialize): - def delete_wrapper(tun): - try: - tun = ct.cast(tun, ccs_tuner) - if delete is not None: - delete(Object.from_handle(tun)) - _unregister_vector(tun) - return Result.SUCCESS - except Exception as e: - return Error.set_error(e) - - def ask_wrapper(tun, features, count, p_configurations, p_count): - try: - tun = ct.cast(tun, ccs_tuner) - p_confs = ct.cast(p_configurations, ct.c_void_p) - p_c = ct.cast(p_count, ct.c_void_p) - (configurations, count_ret) = ask(Tuner.from_handle(tun), Features.from_handle(features) if features else None, count if p_confs.value else None) - if p_confs.value is not None and count < count_ret: - raise Error(Result(Result.ERROR_INVALID_VALUE)) - if p_confs.value is not None: - for i in range(len(configurations)): - res = ccs_retain_object(configurations[i].handle) - Error.check(res) - p_configurations[i] = configurations[i].handle.value - for i in range(len(configurations), count): - p_configurations[i] = None - if p_c.value is not None: - p_count[0] = count_ret - return Result.SUCCESS - except Exception as e: - return Error.set_error(e) - - def tell_wrapper(tun, count, p_evaluations): - try: - tun = ct.cast(tun, ccs_tuner) - if count == 0: - return Result.SUCCESS - p_evals = ct.cast(p_evaluations, ct.c_void_p) - if p_evals.value is None: - raise Error(Result(Result.ERROR_INVALID_VALUE)) - evals = [Evaluation.from_handle(ccs_evaluation(p_evaluations[i])) for i in range(count)] - tell(Tuner.from_handle(tun), evals) - return Result.SUCCESS - except Exception as e: - return Error.set_error(e) - - def get_optima_wrapper(tun, features, count, p_evaluations, p_count): - try: - tun = ct.cast(tun, ccs_tuner) - p_evals = ct.cast(p_evaluations, ct.c_void_p) - p_c = ct.cast(p_count, ct.c_void_p) - optima = get_optima(Tuner.from_handle(tun), Features.from_handle(features) if features else None) - count_ret = len(optima) - if p_evals.value is not None and count < count_ret: - raise Error(Result(Result.ERROR_INVALID_VALUE)) - if p_evals.value is not None: - for i in range(count_ret): - p_evaluations[i] = optima[i].handle.value - for i in range(count_ret, count): - p_evaluations[i] = None - if p_c.value is not None: - p_count[0] = count_ret - return Result.SUCCESS - except Exception as e: - return Error.set_error(e) - - def get_history_wrapper(tun, features, count, p_evaluations, p_count): - try: - tun = ct.cast(tun, ccs_tuner) - p_evals = ct.cast(p_evaluations, ct.c_void_p) - p_c = ct.cast(p_count, ct.c_void_p) - history = get_history(Tuner.from_handle(tun), Features.from_handle(features) if features else None) - count_ret = (len(history) if history else 0) - if p_evals.value is not None and count < count_ret: - raise Error(Result(Result.ERROR_INVALID_VALUE)) - if p_evals.value is not None: - for i in range(count_ret): - p_evaluations[i] = history[i].handle.value - for i in range(count_ret, count): - p_evaluations[i] = None - if p_c.value is not None: - p_count[0] = count_ret - return Result.SUCCESS - except Exception as e: - return Error.set_error(e) - - if suggest is not None: - def suggest_wrapper(tun, features, p_configuration): - try: - tun = ct.cast(tun, ccs_tuner) - configuration = suggest(Tuner.from_handle(tun), Features.from_handle(features) if features else None) - res = ccs_retain_object(configuration.handle) - Error.check(res) - p_configuration[0] = configuration.handle.value - return Result.SUCCESS - except Exception as e: - return Error.set_error(e) - else: - suggest_wrapper = 0 - - if serialize is not None: - def serialize_wrapper(tun, state_size, p_state, p_state_size): - try: - tun = ct.cast(tun, ccs_tuner) - p_s = ct.cast(p_state, ct.c_void_p) - p_sz = ct.cast(p_state_size, ct.c_void_p) - state = serialize(Tuner.from_handle(tun), True if state_size == 0 else False) - if p_s.value is not None and state_size < ct.sizeof(state): - raise Error(Result(Result.ERROR_INVALID_VALUE)) - if p_s.value is not None: - ct.memmove(p_s, ct.byref(state), ct.sizeof(state)) - if p_sz.value is not None: - p_state_size[0] = ct.sizeof(state) - return Result.SUCCESS - except Exception as e: - return Error.set_error(e) - else: - serialize_wrapper = 0 - - if deserialize is not None: - def deserialize_wrapper(tun, size_history, p_history, num_optima, p_optima, state_size, p_state): - try: - tun = ct.cast(tun, ccs_tuner) - p_h = ct.cast(p_history, ct.c_void_p) - p_o = ct.cast(p_optima, ct.c_void_p) - p_s = ct.cast(p_state, ct.c_void_p) - if p_h.value is None: - history = [] - else: - history = [Evaluation.from_handle(ccs_evaluation(p_h[i])) for i in range(size_history)] - if p_o.value is None: - optima = [] - else: - optima = [Evaluation.from_handle(ccs_evaluation(p_o[i])) for i in range(num_optima)] - if p_s.value is None: - state = None - else: - state = ct.cast(p_s, POINTER(c_byte * state_size)) - deserialize(Tuner.from_handle(tun), history, optima, state) - return Result.SUCCESS - except Exception as e: - return Error.set_error(e) - else: - deserialize_wrapper = 0 - - return (delete_wrapper, - ask_wrapper, - tell_wrapper, - get_optima_wrapper, - get_history_wrapper, - suggest_wrapper, - serialize_wrapper, - deserialize_wrapper, - ccs_user_defined_tuner_del_type(delete_wrapper), - ccs_user_defined_tuner_ask_type(ask_wrapper), - ccs_user_defined_tuner_tell_type(tell_wrapper), - ccs_user_defined_tuner_get_optima_type(get_optima_wrapper), - ccs_user_defined_tuner_get_history_type(get_history_wrapper), - ccs_user_defined_tuner_suggest_type(suggest_wrapper), - ccs_user_defined_tuner_serialize_type(serialize_wrapper), - ccs_user_defined_tuner_deserialize_type(deserialize_wrapper)) - - class UserDefinedTuner(Tuner): def __init__(self, handle = None, retain = False, auto_release = True, name = "", objective_space = None, delete = None, ask = None, tell = None, get_optima = None, get_history = None, suggest = None, serialize = None, deserialize = None, tuner_data = None ): @@ -353,76 +191,21 @@ def __init__(self, handle = None, retain = False, auto_release = True, if ask is None or tell is None or get_optima is None or get_history is None: raise Error(Result(Result.ERROR_INVALID_VALUE)) - (delete_wrapper, - ask_wrapper, - tell_wrapper, - get_optima_wrapper, - get_history_wrapper, - suggest_wrapper, - serialize_wrapper, - deserialize_wrapper, - delete_wrapper_func, - ask_wrapper_func, - tell_wrapper_func, - get_optima_wrapper_func, - get_history_wrapper_func, - suggest_wrapper_func, - serialize_wrapper_func, - deserialize_wrapper_func) = _wrap_user_defined_tuner_callbacks(delete, ask, tell, get_optima, get_history, suggest, serialize, deserialize) - handle = ccs_tuner() - vec = UserDefinedTunerVector() - vec.delete = delete_wrapper_func - vec.ask = ask_wrapper_func - vec.tell = tell_wrapper_func - vec.get_optima = get_optima_wrapper_func - vec.get_history = get_history_wrapper_func - vec.suggest = suggest_wrapper_func - vec.serialize = serialize_wrapper_func - vec.deserialize = deserialize_wrapper_func + vec = self.get_vector(delete, ask, tell, get_optima, get_history, suggest, serialize, deserialize) if tuner_data is not None: c_tuner_data = ct.py_object(tuner_data) else: c_tuner_data = None + handle = ccs_tuner() res = ccs_create_user_defined_tuner(str.encode(name), objective_space.handle, ct.byref(vec), c_tuner_data, ct.byref(handle)) Error.check(res) super().__init__(handle = handle, retain = False) - _register_vector(handle, [delete_wrapper, ask_wrapper, tell_wrapper, get_optima_wrapper, get_history_wrapper, suggest_wrapper, serialize_wrapper, deserialize_wrapper, delete_wrapper_func, ask_wrapper_func, tell_wrapper_func, get_optima_wrapper_func, get_history_wrapper_func, suggest_wrapper_func, serialize_wrapper_func, deserialize_wrapper_func, tuner_data]) + ct.pythonapi.Py_IncRef(ct.py_object(vec)) + if c_tuner_data is not None: + ct.pythonapi.Py_IncRef(c_tuner_data) else: super().__init__(handle = handle, retain = retain, auto_release = auto_release) - @classmethod - def deserialize(cls, delete, ask, tell, get_optima, get_history, suggest = None, serialize = None, deserialize = None, tuner_data = None, format = 'binary', handle_map = None, path = None, buffer = None, file_descriptor = None, callback = None, callback_data = None): - if ask is None or tell is None or get_optima is None or get_history is None: - raise Error(Result(Result.ERROR_INVALID_VALUE)) - (delete_wrapper, - ask_wrapper, - tell_wrapper, - get_optima_wrapper, - get_history_wrapper, - suggest_wrapper, - serialize_wrapper, - deserialize_wrapper, - delete_wrapper_func, - ask_wrapper_func, - tell_wrapper_func, - get_optima_wrapper_func, - get_history_wrapper_func, - suggest_wrapper_func, - serialize_wrapper_func, - deserialize_wrapper_func) = _wrap_user_defined_tuner_callbacks(delete, ask, tell, get_optima, get_history, suggest, serialize, deserialize) - vector = UserDefinedTunerVector() - vector.delete = delete_wrapper_func - vector.ask = ask_wrapper_func - vector.tell = tell_wrapper_func - vector.get_optima = get_optima_wrapper_func - vector.get_history = get_history_wrapper_func - vector.suggest = suggest_wrapper_func - vector.serialize = serialize_wrapper_func - vector.deserialize = deserialize_wrapper_func - res = super().deserialize(format = format, handle_map = handle_map, vector = vector, data = tuner_data, path = path, buffer = buffer, file_descriptor = file_descriptor, callback = callback, callback_data = callback_data) - _register_vector(res.handle, [delete_wrapper, ask_wrapper, tell_wrapper, get_optima_wrapper, get_history_wrapper, suggest_wrapper, serialize_wrapper, deserialize_wrapper, delete_wrapper_func, ask_wrapper_func, tell_wrapper_func, get_optima_wrapper_func, get_history_wrapper_func, suggest_wrapper_func, serialize_wrapper_func, deserialize_wrapper_func, tuner_data]) - return res - @property def tuner_data(self): if hasattr(self, "_tuner_data"): @@ -436,4 +219,184 @@ def tuner_data(self): self._tuner_data = None return self._tuner_data + @classmethod + def get_vector(self, delete = None, ask = None, tell = None, get_optima = None, get_history = None, suggest = None, serialize = None, deserialize = None): + vec = UserDefinedTunerVector() + def delete_wrapper(tun): + try: + tun = ct.cast(tun, ccs_tuner) + o = Object.from_handle(tun) + tdata = o.tuner_data + if delete is not None: + delete(o) + if tdata is not None: + ct.pythonapi.Py_DecRef(ct.py_object(tdata)) + ct.pythonapi.Py_DecRef(ct.py_object(vec)) + return Result.SUCCESS + except Exception as e: + return Error.set_error(e) + + def ask_wrapper(tun, features, count, p_configurations, p_count): + try: + tun = ct.cast(tun, ccs_tuner) + p_confs = ct.cast(p_configurations, ct.c_void_p) + p_c = ct.cast(p_count, ct.c_void_p) + (configurations, count_ret) = ask(Tuner.from_handle(tun), Features.from_handle(features) if features else None, count if p_confs.value else None) + if p_confs.value is not None and count < count_ret: + raise Error(Result(Result.ERROR_INVALID_VALUE)) + if p_confs.value is not None: + for i in range(len(configurations)): + res = ccs_retain_object(configurations[i].handle) + Error.check(res) + p_configurations[i] = configurations[i].handle.value + for i in range(len(configurations), count): + p_configurations[i] = None + if p_c.value is not None: + p_count[0] = count_ret + return Result.SUCCESS + except Exception as e: + return Error.set_error(e) + + def tell_wrapper(tun, count, p_evaluations): + try: + tun = ct.cast(tun, ccs_tuner) + if count == 0: + return Result.SUCCESS + p_evals = ct.cast(p_evaluations, ct.c_void_p) + if p_evals.value is None: + raise Error(Result(Result.ERROR_INVALID_VALUE)) + evals = [Evaluation.from_handle(ccs_evaluation(p_evaluations[i])) for i in range(count)] + tell(Tuner.from_handle(tun), evals) + return Result.SUCCESS + except Exception as e: + return Error.set_error(e) + + def get_optima_wrapper(tun, features, count, p_evaluations, p_count): + try: + tun = ct.cast(tun, ccs_tuner) + p_evals = ct.cast(p_evaluations, ct.c_void_p) + p_c = ct.cast(p_count, ct.c_void_p) + optima = get_optima(Tuner.from_handle(tun), Features.from_handle(features) if features else None) + count_ret = len(optima) + if p_evals.value is not None and count < count_ret: + raise Error(Result(Result.ERROR_INVALID_VALUE)) + if p_evals.value is not None: + for i in range(count_ret): + p_evaluations[i] = optima[i].handle.value + for i in range(count_ret, count): + p_evaluations[i] = None + if p_c.value is not None: + p_count[0] = count_ret + return Result.SUCCESS + except Exception as e: + return Error.set_error(e) + + def get_history_wrapper(tun, features, count, p_evaluations, p_count): + try: + tun = ct.cast(tun, ccs_tuner) + p_evals = ct.cast(p_evaluations, ct.c_void_p) + p_c = ct.cast(p_count, ct.c_void_p) + history = get_history(Tuner.from_handle(tun), Features.from_handle(features) if features else None) + count_ret = (len(history) if history else 0) + if p_evals.value is not None and count < count_ret: + raise Error(Result(Result.ERROR_INVALID_VALUE)) + if p_evals.value is not None: + for i in range(count_ret): + p_evaluations[i] = history[i].handle.value + for i in range(count_ret, count): + p_evaluations[i] = None + if p_c.value is not None: + p_count[0] = count_ret + return Result.SUCCESS + except Exception as e: + return Error.set_error(e) + + if suggest is not None: + def suggest_wrapper(tun, features, p_configuration): + try: + tun = ct.cast(tun, ccs_tuner) + configuration = suggest(Tuner.from_handle(tun), Features.from_handle(features) if features else None) + res = ccs_retain_object(configuration.handle) + Error.check(res) + p_configuration[0] = configuration.handle.value + return Result.SUCCESS + except Exception as e: + return Error.set_error(e) + else: + suggest_wrapper = 0 + + if serialize is not None: + def serialize_wrapper(tun, state_size, p_state, p_state_size): + try: + serialized = serialize(Tuner.from_handle(tun)) + state = ct.create_string_buffer(serialized, len(serialized)) + if p_state and state_size < ct.sizeof(state): + raise Error(Result(Result.ERROR_INVALID_VALUE)) + if p_state: + ct.memmove(p_state, ct.byref(state), ct.sizeof(state)) + if p_state_size: + p_state_size[0] = ct.sizeof(state) + return Result.SUCCESS + except Exception as e: + return Error.set_error(e) + else: + serialize_wrapper = 0 + + if deserialize is not None: + def deserialize_wrapper(o_space, size_history, p_history, num_optima, p_optima, state_size, p_state, p_tuner_data): + try: + if p_history: + history = [Evaluation.from_handle(p_history[i]) for i in range(size_history)] + else: + history = [] + if p_optima: + optima = [Evaluation.from_handle(p_optima[i]) for i in range(num_optima)] + else: + optima = [] + tuner_data = deserialize(ObjectiveSpace.from_handle(o_space), history, optima, ct.string_at(p_state, state_size)) + c_tuner_data = ct.py_object(tuner_data) + p_tuner_data[0] = c_tuner_data + ct.pythonapi.Py_IncRef(c_tuner_data) + return Result.SUCCESS + except Exception as e: + return Error.set_error(e) + else: + deserialize_wrapper = 0 + + delete_wrapper_func = ccs_user_defined_tuner_del_type(delete_wrapper) + ask_wrapper_func = ccs_user_defined_tuner_ask_type(ask_wrapper) + tell_wrapper_func = ccs_user_defined_tuner_tell_type(tell_wrapper) + get_optima_wrapper_func = ccs_user_defined_tuner_get_optima_type(get_optima_wrapper) + get_history_wrapper_func = ccs_user_defined_tuner_get_history_type(get_history_wrapper) + suggest_wrapper_func = ccs_user_defined_tuner_suggest_type(suggest_wrapper) + serialize_wrapper_func = ccs_user_defined_tuner_serialize_type(serialize_wrapper) + deserialize_wrapper_func = ccs_user_defined_tuner_deserialize_type(deserialize_wrapper) + vec.delete = delete_wrapper_func + vec.ask = ask_wrapper_func + vec.tell = tell_wrapper_func + vec.get_optima = get_optima_wrapper_func + vec.get_history = get_history_wrapper_func + vec.suggest = suggest_wrapper_func + vec.serialize = serialize_wrapper_func + vec.deserialize = deserialize_wrapper_func + + setattr(vec, '_wrappers', ( + delete_wrapper, + ask_wrapper, + tell_wrapper, + get_optima_wrapper, + get_history_wrapper, + suggest_wrapper, + serialize_wrapper, + deserialize_wrapper, + delete_wrapper_func, + ask_wrapper_func, + tell_wrapper_func, + get_optima_wrapper_func, + get_history_wrapper_func, + suggest_wrapper_func, + serialize_wrapper_func, + deserialize_wrapper_func)) + return vec + Tuner.UserDefined = UserDefinedTuner diff --git a/bindings/python/test/test_expression.py b/bindings/python/test/test_expression.py index 4a62ad91..4da6ce14 100644 --- a/bindings/python/test/test_expression.py +++ b/bindings/python/test/test_expression.py @@ -1,5 +1,7 @@ import unittest import sys +import random +import pickle sys.path.insert(1, '.') sys.path.insert(1, '..') import cconfigspace as ccs @@ -59,5 +61,31 @@ def test_binary(self): self.assertEqual( "true || false", str(e) ) self.assertTrue( e.eval() ) + def test_user_defined(self): + def my_rand(expr, limit): + return expr.expression_data.randrange(limit) + + def my_serialize(expr): + return pickle.dumps(expr.expression_data) + + def my_deserialize(state): + return pickle.loads(state) + + def get_vector_data(otype, name): + self.assertEqual( ccs.ObjectType.EXPRESSION, otype ) + self.assertEqual( "rand", name) + return (ccs.Expression.UserDefined.get_vector(evaluate = my_rand, serialize = my_serialize, deserialize = my_deserialize), None) + + limit = 10 + e = ccs.Expression.UserDefined(name = "rand", nodes = [limit], expression_data = random.Random(), evaluate = my_rand, serialize = my_serialize, deserialize = my_deserialize) + self.assertEqual( "rand(10)", str(e) ) + evals = [ e.eval() for i in range(100) ] + self.assertTrue( all(i >= 0 and i < limit for i in evals) ) + + buff = e.serialize() + e_copy = ccs.deserialize(buffer = buff, vector_callback = get_vector_data) + + self.assertTrue( all(e.eval() == e_copy.eval() for i in range(100)) ) + if __name__ == '__main__': unittest.main() diff --git a/bindings/python/test/test_expression_parser.py b/bindings/python/test/test_expression_parser.py index 0a24efd0..9a9e5f72 100644 --- a/bindings/python/test/test_expression_parser.py +++ b/bindings/python/test/test_expression_parser.py @@ -57,6 +57,25 @@ def test_none(self): self.assertIsNone( res.eval() ) self.assertEqual( "none", res.__str__() ) + def test_function(self): + def func(a, b): + return a * b + l = locals() + + exp = "func(3, 4)" + res = ccs.parse(exp, binding = l) + self.assertIsInstance( res, ccs.Expression.UserDefined ) + self.assertEqual( 12, res.eval() ) + self.assertEqual( "func(3, 4)", res.__str__() ) + + def get_vector_data(otype, name): + self.assertEqual( ccs.ObjectType.EXPRESSION, otype ) + return ccs.Expression.get_function_vector_data(name, binding = l) + + buff = res.serialize() + res_copy = ccs.deserialize(buffer = buff, vector_callback = get_vector_data) + self.assertEqual( "func(3, 4)", res_copy.__str__() ) + self.assertEqual( 12, res_copy.eval() ) if __name__ == '__main__': unittest.main() diff --git a/bindings/python/test/test_features_tuner.py b/bindings/python/test/test_features_tuner.py index cd6ce770..b6bdef9b 100644 --- a/bindings/python/test/test_features_tuner.py +++ b/bindings/python/test/test_features_tuner.py @@ -118,6 +118,9 @@ def suggest(tuner, features): else: return choice(optis).configuration + def get_vector_data(otype, name): + return (ccs.UserDefinedTuner.get_vector(delete = delete, ask = ask, tell = tell, get_optima = get_optima, get_history = get_history, suggest = suggest), TunerData()) + (fs, os) = self.create_tuning_problem() t = ccs.UserDefinedTuner(name = "tuner", objective_space = os, delete = delete, ask = ask, tell = tell, get_optima = get_optima, get_history = get_history, suggest = suggest, tuner_data = TunerData()) t2 = ccs.Object.from_handle(t.handle) @@ -155,7 +158,7 @@ def suggest(tuner, features): self.assertTrue(t.suggest(features_off) in [x.configuration for x in optims]) # test serialization buff = t.serialize() - t_copy = ccs.UserDefinedTuner.deserialize(buffer = buff, delete = delete, ask = ask, tell = tell, get_optima = get_optima, get_history = get_history, suggest = suggest, tuner_data = TunerData()) + t_copy = ccs.deserialize(buffer = buff, vector_callback = get_vector_data) hist = t_copy.history() self.assertEqual(200, len(hist)) optims_2 = t_copy.optima() diff --git a/bindings/python/test/test_tree_space.py b/bindings/python/test/test_tree_space.py index 39f1ef4e..d6862dbb 100644 --- a/bindings/python/test/test_tree_space.py +++ b/bindings/python/test/test_tree_space.py @@ -54,6 +54,9 @@ def get_child(tree_space, parent, child_index): arity = 0 if arity < 0 else arity return ccs.Tree(arity = arity, value = (4 - child_depth)*100 + child_index) + def get_vector_data(otype, name): + return (ccs.DynamicTreeSpace.get_vector(delete = delete, get_child = get_child), None) + tree = ccs.Tree(arity = 4, value = 400) ts = ccs.DynamicTreeSpace(name = 'space', tree = tree, delete = delete, get_child = get_child) self.assertEqual( ccs.ObjectType.TREE_SPACE, ts.object_type ) @@ -75,7 +78,7 @@ def get_child(tree_space, parent, child_index): self.assertTrue( ts.check_configuration(tc) ) buff = ts.serialize() - ts2 = ccs.DynamicTreeSpace.deserialize(buffer = buff, delete = delete, get_child = get_child) + ts2 = ccs.deserialize(buffer = buff, vector_callback = get_vector_data) self.assertEqual( [400, 301, 201], ts2.get_values_at_position([1, 1]) ) def test_tree_configuration(self): diff --git a/bindings/python/test/test_tree_tuner.py b/bindings/python/test/test_tree_tuner.py index 9afe2a5a..9461b0ee 100644 --- a/bindings/python/test/test_tree_tuner.py +++ b/bindings/python/test/test_tree_tuner.py @@ -107,6 +107,9 @@ def suggest(tuner, features): else: return choice(tuner.tuner_data.optima).configuration + def get_vector_data(otype, name): + return (ccs.UserDefinedTuner.get_vector(delete = delete, ask = ask, tell = tell, get_optima = get_optima, get_history = get_history, suggest = suggest), TunerData()) + os = self.create_tuning_problem() t = ccs.UserDefinedTuner(name = "tuner", objective_space = os, delete = delete, ask = ask, tell = tell, get_optima = get_optima, get_history = get_history, suggest = suggest, tuner_data = TunerData()) t2 = ccs.Object.from_handle(t.handle) @@ -125,7 +128,7 @@ def suggest(tuner, features): self.assertTrue(all(best >= x.objective_values[0] for x in hist)) self.assertTrue(t.suggest() in [x.configuration for x in optims]) buff = t.serialize() - t_copy = ccs.UserDefinedTuner.deserialize(buffer = buff, delete = delete, ask = ask, tell = tell, get_optima = get_optima, get_history = get_history, suggest = suggest, tuner_data = TunerData()) + t_copy = ccs.deserialize(buffer = buff, vector_callback = get_vector_data) hist = t_copy.history() self.assertEqual(200, len(hist)) optims_2 = t_copy.optima() diff --git a/bindings/python/test/test_tuner.py b/bindings/python/test/test_tuner.py index 577ba5d6..f65ae2e1 100644 --- a/bindings/python/test/test_tuner.py +++ b/bindings/python/test/test_tuner.py @@ -101,6 +101,9 @@ def suggest(tuner, features): else: return choice(tuner.tuner_data.optima).configuration + def get_vector_data(otype, name): + return (ccs.UserDefinedTuner.get_vector(delete = delete, ask = ask, tell = tell, get_optima = get_optima, get_history = get_history, suggest = suggest), TunerData()) + os = self.create_tuning_problem() t = ccs.UserDefinedTuner(name = "tuner", objective_space = os, delete = delete, ask = ask, tell = tell, get_optima = get_optima, get_history = get_history, suggest = suggest, tuner_data = TunerData()) t2 = ccs.Object.from_handle(t.handle) @@ -126,7 +129,7 @@ def suggest(tuner, features): # test serialization buff = t.serialize() - t_copy = ccs.UserDefinedTuner.deserialize(buffer = buff, delete = delete, ask = ask, tell = tell, get_optima = get_optima, get_history = get_history, suggest = suggest, tuner_data = TunerData()) + t_copy = ccs.deserialize(buffer = buff, vector_callback = get_vector_data) hist = t_copy.history() self.assertEqual(200, len(hist)) optims_2 = t_copy.optima() @@ -137,7 +140,7 @@ def suggest(tuner, features): self.assertTrue(t_copy.suggest() in [x.configuration for x in optims_2]) t.serialize(path = 'tuner.ccs') - t_copy = ccs.UserDefinedTuner.deserialize(path = 'tuner.ccs', delete = delete, ask = ask, tell = tell, get_optima = get_optima, get_history = get_history, suggest = suggest, tuner_data = TunerData()) + t_copy = ccs.deserialize(path = 'tuner.ccs', vector_callback = get_vector_data) hist = t_copy.history() self.assertEqual(200, len(hist)) optims_2 = t_copy.optima() @@ -152,7 +155,7 @@ def suggest(tuner, features): t.serialize(file_descriptor = file.fileno()) file.close() file = open( 'tuner.ccs', "rb") - t_copy = ccs.UserDefinedTuner.deserialize(file_descriptor = file.fileno(), delete = delete, ask = ask, tell = tell, get_optima = get_optima, get_history = get_history, suggest = suggest, tuner_data = TunerData()) + t_copy = ccs.deserialize(file_descriptor = file.fileno(), vector_callback = get_vector_data) file.close() hist = t_copy.history() self.assertEqual(200, len(hist)) diff --git a/bindings/ruby/Makefile.am b/bindings/ruby/Makefile.am index 73971155..79586395 100644 --- a/bindings/ruby/Makefile.am +++ b/bindings/ruby/Makefile.am @@ -51,6 +51,6 @@ LIBEXT = .so endif check: FORCE - LIBCCONFIGSPACE_SO=$(top_builddir)/src/.libs/libcconfigspace$(LIBEXT) SRC_DIR=$(srcdir) rake -f $(srcdir)/rakefile test + LIBCCONFIGSPACE_SO=$(abs_top_builddir)/src/.libs/libcconfigspace$(LIBEXT) SRC_DIR=$(srcdir) rake -f $(srcdir)/rakefile test FORCE: diff --git a/bindings/ruby/cconfigspace.gemspec b/bindings/ruby/cconfigspace.gemspec index 2ed6bcc2..e880107c 100644 --- a/bindings/ruby/cconfigspace.gemspec +++ b/bindings/ruby/cconfigspace.gemspec @@ -10,6 +10,6 @@ Gem::Specification.new do |s| s.license = 'BSD-3-Clause' s.required_ruby_version = '>= 2.3.0' s.add_dependency 'ffi', '~> 1.13', '>=1.13.0' - s.add_dependency 'ffi-value', '~> 0.1', '>=0.1.1' + s.add_dependency 'ffi-value', '~> 0.1', '>=0.1.3' s.add_dependency 'whittle', '~> 0.0', '>=0.0.8' end diff --git a/bindings/ruby/lib/cconfigspace/base.rb b/bindings/ruby/lib/cconfigspace/base.rb index 0a2bd2e5..5931efa6 100644 --- a/bindings/ruby/lib/cconfigspace/base.rb +++ b/bindings/ruby/lib/cconfigspace/base.rb @@ -1,5 +1,4 @@ require 'singleton' -require 'yaml' module CCS extend FFI::Library @@ -9,7 +8,59 @@ module CCS ffi_lib "cconfigspace" end + module MemoryAccessor + def self.included(mod) + mod.class_eval { + alias read_ccs_float_t read_double + alias get_ccs_float_t get_double + alias read_array_of_ccs_float_t read_array_of_double + alias write_array_of_ccs_float_t write_array_of_double + alias read_ccs_int_t read_int64 + alias get_ccs_int_t get_int64 + alias read_array_of_ccs_int_t read_array_of_int64 + alias write_array_of_ccs_int_t write_array_of_int64 + alias read_ccs_bool_t read_int32 + alias read_ccs_evaluation_result_t read_int32 + alias read_ccs_hash_t read_uint32 + if FFI.find_type(:size_t).size == 8 + alias read_size_t read_uint64 + alias write_size_t write_uint64 + alias write_array_of_size_t write_array_of_uint64 + alias read_array_of_size_t read_array_of_uint64 + else + alias read_size_t read_uint32 + alias write_size_t write_uint32 + alias write_array_of_size_t write_array_of_uint32 + alias read_array_of_size_t read_array_of_uint32 + end + alias read_ccs_object_t read_pointer + alias read_ccs_rng_t read_ccs_object_t + alias read_ccs_distribution_t read_ccs_object_t + alias read_ccs_parameter_t read_ccs_object_t + alias read_ccs_expression_t read_ccs_object_t + alias read_ccs_context_t read_ccs_object_t + alias read_ccs_distribution_space_t read_ccs_object_t + alias read_ccs_search_space_t read_ccs_object_t + alias read_ccs_configuration_space_t read_ccs_object_t + alias read_ccs_binding_t read_ccs_object_t + alias read_ccs_search_configuration_t read_ccs_object_t + alias read_ccs_configuration_t read_ccs_object_t + alias read_ccs_feature_space_t read_ccs_object_t + alias read_ccs_features_t read_ccs_object_t + alias read_ccs_objective_space_t read_ccs_object_t + alias read_ccs_evaluation_t read_ccs_object_t + alias read_ccs_tuner_t read_ccs_object_t + alias read_ccs_map_t read_ccs_object_t + alias read_ccs_error_stack_t read_ccs_object_t + alias read_ccs_tree_t read_ccs_object_t + alias read_ccs_tree_space_t read_ccs_object_t + alias read_ccs_tree_configuration_t read_ccs_object_t + } + end + end + class Pointer < FFI::Pointer + include MemoryAccessor def initialize(*args) if args.length == 2 then super(CCS::find_type(args[0]), args[1]) @@ -17,29 +68,10 @@ def initialize(*args) super(*args) end end - alias read_ccs_float_t read_double - alias get_ccs_float_t get_double - alias read_array_of_ccs_float_t read_array_of_double - alias write_array_of_ccs_float_t write_array_of_double - alias read_ccs_int_t read_int64 - alias get_ccs_int_t get_int64 - alias read_array_of_ccs_int_t read_array_of_int64 - alias write_array_of_ccs_int_t write_array_of_int64 - alias read_ccs_bool_t read_int32 - alias read_ccs_evaluation_result_t read_int32 - alias read_ccs_hash_t read_uint32 - if FFI.find_type(:size_t).size == 8 - alias read_size_t read_uint64 - alias write_size_t write_uint64 - alias write_array_of_size_t write_array_of_uint64 - else - alias read_size_t read_uint32 - alias write_size_t write_uint32 - alias write_array_of_size_t write_array_of_uint32 - end end class MemoryPointer < FFI::MemoryPointer + include MemoryAccessor def initialize(size, count = 1, clear = true) if size.is_a?(Symbol) size = CCS::find_type(size) @@ -54,31 +86,6 @@ def initialize(size, count = 1, clear = true) typedef :int32, :ccs_evaluation_result_t typedef :uint32, :ccs_hash_t - class MemoryPointer - alias read_ccs_float_t read_double - alias get_ccs_float_t get_double - alias read_array_of_ccs_float_t read_array_of_double - alias write_array_of_ccs_float_t write_array_of_double - alias read_ccs_int_t read_int64 - alias get_ccs_int_t get_int64 - alias read_array_of_ccs_int_t read_array_of_int64 - alias write_array_of_ccs_int_t write_array_of_int64 - alias read_ccs_bool_t read_int32 - alias read_ccs_evaluation_result_t read_int32 - alias read_ccs_hash_t read_uint32 - if FFI.find_type(:size_t).size == 8 - alias read_size_t read_uint64 - alias write_size_t write_uint64 - alias write_array_of_size_t write_array_of_uint64 - alias read_array_of_size_t read_array_of_uint64 - else - alias read_size_t read_uint32 - alias write_size_t write_uint32 - alias write_array_of_size_t write_array_of_uint32 - alias read_array_of_size_t read_array_of_uint32 - end - end - class Version < FFI::Struct layout :revision, :uint16, :patch, :uint16, @@ -112,30 +119,6 @@ class Version < FFI::Struct typedef :ccs_object_t, :ccs_tree_t typedef :ccs_object_t, :ccs_tree_space_t typedef :ccs_object_t, :ccs_tree_configuration_t - class MemoryPointer - alias read_ccs_object_t read_pointer - alias read_ccs_rng_t read_ccs_object_t - alias read_ccs_distribution_t read_ccs_object_t - alias read_ccs_parameter_t read_ccs_object_t - alias read_ccs_expression_t read_ccs_object_t - alias read_ccs_context_t read_ccs_object_t - alias read_ccs_distribution_space_t read_ccs_object_t - alias read_ccs_search_space_t read_ccs_object_t - alias read_ccs_configuration_space_t read_ccs_object_t - alias read_ccs_binding_t read_ccs_object_t - alias read_ccs_search_configuration_t read_ccs_object_t - alias read_ccs_configuration_t read_ccs_object_t - alias read_ccs_feature_space_t read_ccs_object_t - alias read_ccs_features_t read_ccs_object_t - alias read_ccs_objective_space_t read_ccs_object_t - alias read_ccs_evaluation_t read_ccs_object_t - alias read_ccs_tuner_t read_ccs_object_t - alias read_ccs_map_t read_ccs_object_t - alias read_ccs_error_stack_t read_ccs_object_t - alias read_ccs_tree_t read_ccs_object_t - alias read_ccs_tree_space_t read_ccs_object_t - alias read_ccs_tree_configuration_t read_ccs_object_t - end ObjectType = enum FFI::Type::INT32, :ccs_object_type_t, [ :CCS_OBJECT_TYPE_RNG, @@ -189,7 +172,7 @@ class MemoryPointer :CCS_RESULT_ERROR_INVALID_TREE_SPACE, -28, :CCS_RESULT_ERROR_INVALID_DISTRIBUTION_SPACE, -29 ] - class MemoryPointer + module MemoryAccessor def read_ccs_object_type_t ObjectType.from_native(read_int32, nil) end @@ -221,7 +204,7 @@ def read_ccs_result_t :CCS_NUMERIC_TYPE_INT, DataType.to_native(:CCS_DATA_TYPE_INT, nil), :CCS_NUMERIC_TYPE_FLOAT, DataType.to_native(:CCS_DATA_TYPE_FLOAT, nil) ] - class MemoryPointer + module MemoryAccessor def read_ccs_numeric_type_t NumericType.from_native(read_int32, nil) end @@ -244,10 +227,9 @@ def read_ccs_numeric_type_t DeserializeOptions = enum FFI::Type::INT32, :ccs_deserialize_option_t, [ :CCS_DESERIALIZE_OPTION_END, 0, :CCS_DESERIALIZE_OPTION_HANDLE_MAP, - :CCS_DESERIALIZE_OPTION_VECTOR, - :CCS_DESERIALIZE_OPTION_DATA, + :CCS_DESERIALIZE_OPTION_VECTOR_CALLBACK, :CCS_DESERIALIZE_OPTION_NON_BLOCKING, - :CCS_DESERIALIZE_OPTION_CALLBACK ] + :CCS_DESERIALIZE_OPTION_DATA_CALLBACK ] class Numeric < FFI::Union layout :f, :ccs_float_t, @@ -476,13 +458,13 @@ def self.from_value(v) end end typedef Datum.by_value, :ccs_datum_t - class MemoryPointer + module MemoryAccessor def read_ccs_datum_t Datum::new(self).value end def read_array_of_ccs_datum_t(length) - length.times.collect { |i| Datum::new(self[i]).value } + length.times.collect { |i| Datum::new(self + i * Datum.size).value } end end @@ -501,7 +483,8 @@ def read_array_of_ccs_datum_t(length) attach_function :ccs_object_get_user_data, [:ccs_object_t, :pointer], :ccs_result_t callback :ccs_object_serialize_callback, [:ccs_object_t, :size_t, :pointer, :pointer, :value], :ccs_result_t attach_function :ccs_object_set_serialize_callback, [:ccs_object_t, :ccs_object_serialize_callback, :value], :ccs_result_t - callback :ccs_object_deserialize_callback, [:ccs_object_t, :size_t, :pointer, :value], :ccs_result_t + callback :ccs_object_deserialize_data_callback, [:ccs_object_t, :size_t, :pointer, :value], :ccs_result_t + callback :ccs_object_deserialize_vector_callback, [:ccs_object_type_t, :string, :value, :pointer, :pointer], :ccs_result_t attach_function :ccs_object_serialize, [:ccs_object_t, :ccs_serialize_format_t, :ccs_serialize_operation_t, :varargs], :ccs_result_t attach_function :ccs_object_deserialize, [:ccs_object_t, :ccs_serialize_format_t, :ccs_serialize_operation_t, :varargs], :ccs_result_t @@ -707,23 +690,23 @@ def to_ptr @handle end - def set_destroy_callback(user_data: nil, &block) - CCS.set_destroy_callback(@handle, user_data: user_data, &block) + def set_destroy_callback(&block) + CCS.set_destroy_callback(@handle, &block) self end - def set_serialize_callback(user_data: nil, &block) - CCS.set_serialize_callback(@handle, user_data: user_data, &block) + def set_serialize_callback(&block) + CCS.set_serialize_callback(@handle, &block) self end - def serialize(format: :binary, path: nil, file_descriptor: nil, callback: nil, callback_data: nil) + def serialize(format: :binary, path: nil, file_descriptor: nil, callback: nil) raise CCSError, :CCS_RESULT_ERROR_INVALID_VALUE if format != :binary raise CCSError, :CCS_RESULT_ERROR_INVALID_VALUE if path && file_descriptor options = [] if callback cb_wrapper = CCS.get_serialize_wrapper(&callback) - options.concat [:ccs_serialize_option_t, :CCS_SERIALIZE_OPTION_CALLBACK, :ccs_object_serialize_callback, cb_wrapper, :value, callback_data] + options.concat [:ccs_serialize_option_t, :CCS_SERIALIZE_OPTION_CALLBACK, :ccs_object_serialize_callback, cb_wrapper, :value, nil] elsif CCS.default_user_data_serializer options.concat [:ccs_serialize_option_t, :CCS_SERIALIZE_OPTION_CALLBACK, :ccs_object_serialize_callback, CCS.default_user_data_serializer, :value, nil] end @@ -750,7 +733,7 @@ def serialize(format: :binary, path: nil, file_descriptor: nil, callback: nil, c return result end - def self.deserialize(format: :binary, handle_map: nil, vector: nil, data: nil, path: nil, buffer: nil, file_descriptor: nil, callback: nil, callback_data: nil) + def self.deserialize(format: :binary, handle_map: nil, path: nil, buffer: nil, file_descriptor: nil, vector_callback: nil, vector_callback_data: nil, callback: nil) raise CCSError, :CCS_RESULT_ERROR_INVALID_VALUE if format != :binary format = :CCS_SERIALIZE_FORMAT_BINARY mode_count = 0 @@ -761,13 +744,15 @@ def self.deserialize(format: :binary, handle_map: nil, vector: nil, data: nil, p ptr = MemoryPointer::new(:ccs_object_t) options = [] options.concat [:ccs_deserialize_option_t, :CCS_DESERIALIZE_OPTION_HANDLE_MAP, :ccs_map_t, handle_map.handle] if handle_map - options.concat [:ccs_deserialize_option_t, :CCS_DESERIALIZE_OPTION_VECTOR, :pointer, vector] if vector - options.concat [:ccs_deserialize_option_t, :CCS_DESERIALIZE_OPTION_DATA, :value, data] if data + if vector_callback + cb_wrapper = CCS.get_deserialize_vector_callback_wrapper(&vector_callback) + options.concat [:ccs_deserialize_option_t, :CCS_DESERIALIZE_OPTION_VECTOR_CALLBACK, :ccs_object_deserialize_vector_callback, cb_wrapper, :value, vector_callback_data] + end if callback - cb_wrapper = CCS.get_deserialize_wrapper(&callback) - options.concat [:ccs_deserialize_option_t, :CCS_DESERIALIZE_OPTION_CALLBACK, :ccs_object_deserialize_callback, cb_wrapper, :value, callback_data] + cb_wrapper = CCS.get_deserialize_data_callback_wrapper(&callback) + options.concat [:ccs_deserialize_option_t, :CCS_DESERIALIZE_OPTION_DATA_CALLBACK, :ccs_object_deserialize_data_callback, cb_wrapper, :value, nil] elsif CCS.default_user_data_deserializer - options.concat [:ccs_deserialize_option_t, :CCS_DESERIALIZE_OPTION_CALLBACK, :ccs_object_deserialize_callback, CCS.default_user_data_deserializer, :value, nil] + options.concat [:ccs_deserialize_option_t, :CCS_DESERIALIZE_OPTION_DATA_CALLBACK, :ccs_object_deserialize_data_callback, CCS.default_user_data_deserializer, :value, nil] end options.concat [:ccs_deserialize_option_t, :CCS_DESERIALIZE_OPTION_END] if buffer @@ -850,22 +835,36 @@ def self.register_serialize_callback(handle, callback_data) @@data_store[value][:serialize_calback] = callback_data end - def self.set_destroy_callback(handle, user_data: nil, &block) + def self.set_destroy_callback(handle, &block) raise CCSError, :CCS_RESULT_ERROR_INVALID_VALUE if !block cb_wrapper = lambda { |object, data| block.call(Object.from_handle(object), data) } - CCS.error_check CCS.ccs_object_set_destroy_callback(handle, cb_wrapper, user_data) - register_callback(handle, [cb_wrapper, user_data]) + CCS.error_check CCS.ccs_object_set_destroy_callback(handle, cb_wrapper, nil) + register_callback(handle, [cb_wrapper]) end def self.get_serialize_wrapper(&block) - lambda { |object, serialize_data_size, serialize_data, serialize_data_size_ret, cb_data| + lambda { |obj, serialize_data_size, serialize_data, serialize_data_size_ret, _| + begin + serialized = block.call(Object.from_handle(obj).user_data) + raise CCSError, :CCS_RESULT_ERROR_INVALID_VALUE if !serialized.kind_of?(String) + raise CCSError, :CCS_RESULT_ERROR_INVALID_VALUE if !serialize_data.null? && serialize_data_size < serialized.bytesize + serialize_data.write_bytes(serialized, 0, serialized.bytesize) unless serialize_data.null? + Pointer.new(serialize_data_size_ret).write_size_t(serialized.bytesize) unless serialize_data_size_ret.null? + CCSError.to_native(:CCS_RESULT_SUCCESS) + rescue => e + CCS.set_error(e) + end + } + end + + def self.get_deserialize_data_callback_wrapper(&block) + lambda { |obj, serialize_data_size, serialize_data, _| begin - serialized = block.call(Object.from_handle(object), cb_data, serialize_data_size == 0 ? true : false) - raise CCSError, :CCS_RESULT_ERROR_INVALID_VALUE if !serialize_data.null? && serialize_data_size < serialized.size - serialize_data.write_bytes(serialized.read_bytes(serialized.size)) unless serialize_data.null? - Pointer.new(serialize_data_size_ret).write_size_t(serialized.size) unless serialize_data_size_ret.null? + serialized = serialize_data.null? ? nil : serialize_data.read_bytes(serialize_data_size) + user_data = block.call(serialized) + Object.from_handle(obj).user_data = user_data CCSError.to_native(:CCS_RESULT_SUCCESS) rescue => e CCS.set_error(e) @@ -873,11 +872,14 @@ def self.get_serialize_wrapper(&block) } end - def self.get_deserialize_wrapper(&block) - lambda { |obj, serialize_data_size, serialize_data, cb_data| + def self.get_deserialize_vector_callback_wrapper(&block) + lambda { |obj_type, name, _, vector_ret, data_ret| begin - serialized = serialize_data.null? ? nil : serialize_data.slice(0, serialize_data_size) - block.call(Object.from_handle(obj), serialized, cb_data) + vector, data = block.call(obj_type, name) + vector_ret.write_pointer(vector.pointer) + data_ret.write_value(data) + FFI.inc_ref(vector) + FFI.inc_ref(data) CCSError.to_native(:CCS_RESULT_SUCCESS) rescue => e CCS.set_error(e) @@ -885,35 +887,34 @@ def self.get_deserialize_wrapper(&block) } end - @yaml_user_data_serializer = get_serialize_wrapper { |obj, _, size| - FFI::MemoryPointer.from_string(YAML.dump(obj.user_data)) + @marshal_user_data_serializer = get_serialize_wrapper { |user_data| + Marshal.dump(user_data) } - @yaml_user_data_deserializer = get_deserialize_wrapper { |obj, serialized, _| - obj.user_data = YAML.load(serialized.read_string) + @marshal_user_data_deserializer = get_deserialize_data_callback_wrapper { |serialized| + Marshal.load(serialized) } class << self attr_accessor :default_user_data_serializer, :default_user_data_deserializer end - self.default_user_data_serializer = @yaml_user_data_serializer - self.default_user_data_deserializer = @yaml_user_data_deserializer + self.default_user_data_serializer = @marshal_user_data_serializer + self.default_user_data_deserializer = @marshal_user_data_deserializer - def self.set_serialize_callback(handle, user_data: nil, &block) + def self.set_serialize_callback(handle, &block) if block cb_wrapper = get_serialize_wrapper(&block) - cb_data = [cb_wrapper, user_data] + cb_data = [cb_wrapper] else cb_wrapper = nil - user_data = nil cb_data = nil end - CCS.error_check CCS.ccs_object_set_serialize_callback(handle, cb_wrapper, user_data) + CCS.error_check CCS.ccs_object_set_serialize_callback(handle, cb_wrapper, nil) register_serialize_callback(handle, cb_data) end - def self.deserialize(format: :binary, handle_map: nil, path: nil, buffer: nil, callback: nil, callback_data: nil) - return CCS::Object.deserialize(format: format, handle_map: handle_map, path: path, buffer: buffer, callback: callback, callback_data: callback_data) + def self.deserialize(format: :binary, handle_map: nil, path: nil, buffer: nil, file_descriptor: nil, vector_callback: nil, callback: nil) + return CCS::Object.deserialize(format: format, handle_map: handle_map, path: path, buffer: buffer, file_descriptor: file_descriptor, vector_callback: vector_callback, callback: callback) end end diff --git a/bindings/ruby/lib/cconfigspace/configuration_space.rb b/bindings/ruby/lib/cconfigspace/configuration_space.rb index 58812ab4..abc69dd4 100644 --- a/bindings/ruby/lib/cconfigspace/configuration_space.rb +++ b/bindings/ruby/lib/cconfigspace/configuration_space.rb @@ -18,7 +18,7 @@ class ConfigurationSpace < Context add_handle_array_property :forbidden_clauses, :ccs_expression_t, :ccs_configuration_space_get_forbidden_clauses, memoize: true def initialize(handle = nil, retain: false, auto_release: true, - name: "", parameters: nil, conditions: nil, forbidden_clauses: nil, feature_space: nil, rng: nil) + name: "", parameters: nil, conditions: nil, forbidden_clauses: nil, feature_space: nil, rng: nil, binding: nil) if handle super(handle, retain: retain, auto_release: auto_release) else @@ -32,8 +32,7 @@ def initialize(handle = nil, retain: false, auto_release: true, ctx = ctx_params.map { |p| [p.name, p] }.to_h if forbidden_clauses - p = ExpressionParser::new(ctx) - forbidden_clauses = forbidden_clauses.collect { |e| e.kind_of?(String) ? p.parse(e) : e } + forbidden_clauses = forbidden_clauses.collect { |e| e.kind_of?(String) ? CCS.parse(e, context: ctx, binding: binding) : e } fccount = forbidden_clauses.size fcptr = MemoryPointer::new(:ccs_expression_t, fccount) fcptr.write_array_of_pointer(forbidden_clauses.collect(&:handle)) @@ -44,8 +43,7 @@ def initialize(handle = nil, retain: false, auto_release: true, if conditions indexdict = parameters.each_with_index.to_h - p = ExpressionParser::new(ctx) - conditions = conditions.transform_values { |v| v.kind_of?(String) ? p.parse(v) : v } + conditions = conditions.transform_values { |v| v.kind_of?(String) ? CCS.parse(v, context: ctx, binding: binding) : v } cond_handles = [0]*count conditions.each do |k, v| index = case k diff --git a/bindings/ruby/lib/cconfigspace/distribution.rb b/bindings/ruby/lib/cconfigspace/distribution.rb index 6a8f0afa..088bac1a 100644 --- a/bindings/ruby/lib/cconfigspace/distribution.rb +++ b/bindings/ruby/lib/cconfigspace/distribution.rb @@ -7,7 +7,7 @@ module CCS :CCS_DISTRIBUTION_TYPE_MIXTURE, :CCS_DISTRIBUTION_TYPE_MULTIVARIATE ] - class MemoryPointer + module MemoryAccessor def read_ccs_distribution_type_t DistributionType.from_native(read_int32, nil) end @@ -17,7 +17,7 @@ def read_ccs_distribution_type_t :CCS_SCALE_TYPE_LINEAR, :CCS_SCALE_TYPE_LOGARITHMIC ] - class MemoryPointer + module MemoryAccessor def read_ccs_scale_type_t ScaleType.from_native(read_int32, nil) end diff --git a/bindings/ruby/lib/cconfigspace/evaluation.rb b/bindings/ruby/lib/cconfigspace/evaluation.rb index 323a917a..f5da8428 100644 --- a/bindings/ruby/lib/cconfigspace/evaluation.rb +++ b/bindings/ruby/lib/cconfigspace/evaluation.rb @@ -5,7 +5,7 @@ module CCS :CCS_COMPARISON_WORSE, 1, :CCS_COMPARISON_NOT_COMPARABLE, 2 ] - class MemoryPointer + module MemoryAccessor def read_ccs_comparison_t Comparison.from_native(read_int32, nil) end diff --git a/bindings/ruby/lib/cconfigspace/expression.rb b/bindings/ruby/lib/cconfigspace/expression.rb index af7c980a..6081aca4 100644 --- a/bindings/ruby/lib/cconfigspace/expression.rb +++ b/bindings/ruby/lib/cconfigspace/expression.rb @@ -19,9 +19,10 @@ module CCS :CCS_EXPRESSION_TYPE_IN, :CCS_EXPRESSION_TYPE_LIST, :CCS_EXPRESSION_TYPE_LITERAL, - :CCS_EXPRESSION_TYPE_VARIABLE + :CCS_EXPRESSION_TYPE_VARIABLE, + :CCS_EXPRESSION_TYPE_USER_DEFINED, ] - class MemoryPointer + module MemoryAccessor def read_ccs_expression_type_t ExpressionType.from_native(read_int32, nil) end @@ -32,7 +33,7 @@ def read_ccs_expression_type_t :CCS_ASSOCIATIVITY_TYPE_LEFT_TO_RIGHT, :CCS_ASSOCIATIVITY_TYPE_RIGHT_TO_LEFT ] - class MemoryPointer + module MemoryAccessor def read_ccs_associativity_type_t AssociativityType.from_native(read_int32, nil) end @@ -115,7 +116,8 @@ def self.expression_map CCS_EXPRESSION_TYPE_IN: In, CCS_EXPRESSION_TYPE_LIST: List, CCS_EXPRESSION_TYPE_LITERAL: Literal, - CCS_EXPRESSION_TYPE_VARIABLE: Variable + CCS_EXPRESSION_TYPE_VARIABLE: Variable, + CCS_EXPRESSION_TYPE_USER_DEFINED: UserDefined, } end @@ -182,6 +184,13 @@ def to_s "#{nds[0]} #{symbol} #{nds[1]}" end end + + def self.get_function_vector_data(name, binding: TOPLEVEL_BINDING) + m = binding.receiver.method(name.to_sym) + evaluate = lambda { |expr, *args| m.call(*args) } + [CCS::Expression::UserDefined.get_vector(eval: evaluate), nil] + end + end class ExpressionOr < Expression @@ -554,4 +563,130 @@ def to_s Expression::List = ExpressionList + callback :ccs_user_defined_expression_del, [:ccs_expression_t], :ccs_result_t + callback :ccs_user_defined_expression_eval, [:ccs_expression_t, :size_t, :pointer, :pointer], :ccs_result_t + callback :ccs_user_defined_expression_serialize, [:ccs_expression_t, :size_t, :pointer, :pointer], :ccs_result_t + callback :ccs_user_defined_expression_deserialize, [:size_t, :pointer, :pointer], :ccs_result_t + + class UserDefinedExpressionVector < FFI::Struct + attr_accessor :wrappers + attr_accessor :string_store + attr_accessor :object_store + layout :del, :ccs_user_defined_expression_del, + :eval, :ccs_user_defined_expression_eval, + :serialize, :ccs_user_defined_expression_serialize, + :deserialize, :ccs_user_defined_expression_deserialize + end + typedef UserDefinedExpressionVector.by_value, :ccs_user_defined_expression_vector_t + + attach_function :ccs_create_user_defined_expression, [:string, :size_t, :pointer, UserDefinedExpressionVector.by_ref, :value, :pointer], :ccs_result_t + attach_function :ccs_user_defined_expression_get_expression_data, [:ccs_expression_t, :pointer], :ccs_result_t + attach_function :ccs_user_defined_expression_get_name, [:ccs_expression_t, :pointer], :ccs_result_t + class ExpressionUserDefined < Expression + add_property :expression_data, :value, :ccs_user_defined_expression_get_expression_data, memoize: true + + def initialize(handle = nil, retain: false, auto_release: true, + name: nil, nodes: [], del: nil, eval: nil, serialize: nil, deserialize: nil, expression_data: nil) + if handle + super(handle, retain: retain, auto_release: auto_release) + else + raise CCSError, :CCS_RESULT_ERROR_INVALID_VALUE if name.nil? || eval.nil? + vector = ExpressionUserDefined.get_vector(del: del, eval: eval, serialize: serialize, deserialize: deserialize) + count = nodes.size + p_values = MemoryPointer::new(:ccs_datum_t, count) + ss = [] + os = [] + ptr = MemoryPointer::new(:ccs_expression_t) + nodes.each_with_index { |n, i| Datum::new(p_values[i]).set_value(n, string_store: ss, object_store: os) } + CCS.error_check CCS.ccs_create_user_defined_expression(name, count, p_values, vector, expression_data, ptr) + handle = ptr.read_ccs_expression_t + super(handle, retain: false) + FFI.inc_ref(vector) + FFI.inc_ref(expression_data) unless expression_data.nil? + end + end + + def name + @name ||= begin + ptr = MemoryPointer::new(:pointer) + CCS.error_check CCS.ccs_user_defined_expression_get_name(@handle, ptr) + ptr.read_pointer.read_string + end + end + + def self.get_vector(del: nil, eval: nil, serialize: nil, deserialize: nil) + vector = UserDefinedExpressionVector::new + vector.string_store = [] + vector.object_store = [] + delwrapper = lambda { |expr| + begin + o = CCS::Object.from_handle(expr) + edata = o.expression_data + del.call(o) if del + FFI.dec_ref(edata) unless edata.nil? + FFI.dec_ref(vector) + CCSError.to_native(:CCS_RESULT_SUCCESS) + rescue => e + CCS.set_error(e) + end + } + evalwrapper = lambda { |expr, num_values, p_values, p_value_ret| + begin + values = Pointer.new(p_values).read_array_of_ccs_datum_t(num_values) + value_ret = eval.call(Expression.from_handle(expr), *values) + Datum::new(p_value_ret).set_value(value_ret, string_store: vector.string_store, object_store: vector.object_store) + CCSError.to_native(:CCS_RESULT_SUCCESS) + rescue => e + CCS.set_error(e) + end + } + serializewrapper = + if serialize + lambda { |tun, state_size, p_state, p_state_size| + begin + state = serialize.call(Expression.from_handle(tun)) + raise CCSError, :CCS_RESULT_ERROR_INVALID_VALUE if !state.kind_of?(String) + raise CCSError, :CCS_RESULT_ERROR_INVALID_VALUE if !p_state.null? && state_size < state.bytesize + p_state.write_bytes(state, 0, state.bytesize) unless p_state.null? + Pointer.new(p_state_size).write_size_t(state.bytesize) unless p_state_size.null? + CCSError.to_native(:CCS_RESULT_SUCCESS) + rescue => e + CCS.set_error(e) + end + } + else + nil + end + deserializewrapper = + if deserialize + lambda { |state_size, p_state, p_expression_data| + begin + state = p_state.null? ? nil : p_state.read_bytes(state_size) + expression_data = deserialize.call(state) + p_expression_data.write_value(expression_data) + FFI.inc_ref(expression_data) + CCSError.to_native(:CCS_RESULT_SUCCESS) + rescue => e + CCS.set_error(e) + end + } + else + nil + end + vector[:del] = delwrapper + vector[:eval] = evalwrapper + vector[:serialize] = serializewrapper + vector[:deserialize] = deserializewrapper + vector.wrappers = [delwrapper, evalwrapper, serializewrapper, deserializewrapper] + vector + end + + def to_s + "#{name}(#{nodes.collect(&:to_s).join(", ")})" + end + + end + + Expression::UserDefined = ExpressionUserDefined + end diff --git a/bindings/ruby/lib/cconfigspace/expression_parser.rb b/bindings/ruby/lib/cconfigspace/expression_parser.rb index bb66d965..f0030f29 100644 --- a/bindings/ruby/lib/cconfigspace/expression_parser.rb +++ b/bindings/ruby/lib/cconfigspace/expression_parser.rb @@ -19,13 +19,22 @@ module CCS class ExpressionParser < Whittle::Parser class << self attr_accessor :context + attr_accessor :bind + attr_accessor :mutex + attr_reader :parser end - def initialize(context = nil) - @context = context - end - def parse(*args) - self.class.context = @context - super + + @mutex = Mutex.new + + def parse(input, context: nil, binding: TOPLEVEL_BINDING, **options) + self.class.mutex.synchronize { + self.class.context = context + self.class.bind = binding + expr = super(input, options) + self.class.context = nil + self.class.bind = nil + expr + } end ExpressionSymbols.reverse_each { |k, v| @@ -53,17 +62,21 @@ def parse(*args) Expression::Literal::new(value: Float(num)) } rule(:integer => Regexp.new(TerminalRegexp[:CCS_TERMINAL_TYPE_INTEGER])).as { |num| Expression::Literal::new(value: Integer(num)) } - rule(:identifier => /[:a-zA-Z_][a-zA-Z_0-9]*/).as { |identifier| - Expression::Variable::new(parameter: context.kind_of?(Context) ? context.parameter_by_name(identifier) : context[identifier]) } + rule(:identifier => /[:a-zA-Z_][a-zA-Z_0-9]*/) + rule(:variable) do |r| + r[:identifier].as { |identifier| Expression::Variable::new(parameter: context.kind_of?(Context) ? context.parameter_by_name(identifier) : context[identifier]) } + end rule(:string => Regexp.new(TerminalRegexp[:CCS_TERMINAL_TYPE_STRING])).as { |str| Expression::Literal::new(value: eval(str)) } + rule(:value) do |r| r[:none] r[:true] r[:false] r[:string] - r[:identifier] + r[:user_defined] + r[:variable] r[:float] r[:integer] end @@ -78,6 +91,19 @@ def parse(*args) r["[", "]"].as { |_, _| Expression::List::new(values: []) } end + rule(:user_defined) do |r| + r[:identifier, "(", :list_item, ")"].as do |e, _, l, _| + m = bind.receiver.method(e.to_sym) + evaluate = lambda { |expr, *args| m.call(*args) } + Expression::UserDefined.new(name: e, nodes: l, eval: evaluate) + end + r[:identifier, "(", ")"].as do |e, _, _| + m = bind.receiver.method(e.to_sym) + evaluate = lambda { |expr, *args| m.call } + Expression::UserDefined.new(name: e, eval: evaluate) + end + end + rule(:expr) do |r| r["(", :expr, ")"].as { |_, exp, _| exp } ExpressionSymbols.reverse_each { |k, v| @@ -95,6 +121,13 @@ def parse(*args) end start(:expr) + + end + + @parser = ExpressionParser.new + + def self.parse(input, context: nil, binding: TOPLEVEL_BINDING, **params) + @parser.parse(input, context: context, binding: binding, **params) end end diff --git a/bindings/ruby/lib/cconfigspace/objective_space.rb b/bindings/ruby/lib/cconfigspace/objective_space.rb index bacba6da..a7acf265 100644 --- a/bindings/ruby/lib/cconfigspace/objective_space.rb +++ b/bindings/ruby/lib/cconfigspace/objective_space.rb @@ -3,7 +3,7 @@ module CCS :CCS_OBJECTIVE_TYPE_MINIMIZE, :CCS_OBJECTIVE_TYPE_MAXIMIZE ] - class MemoryPointer + module MemoryAccessor def read_ccs_objective_type_t ObjectiveType.from_native(read_int32, nil) end @@ -28,7 +28,7 @@ class ObjectiveSpace < Context add_handle_property :search_space, :ccs_search_space_t, :ccs_objective_space_get_search_space, memoize: true def initialize(handle = nil, retain: false, auto_release: true, - name: "", search_space: nil, parameters: [], objectives: [], types: nil) + name: "", search_space: nil, parameters: [], objectives: [], types: nil, binding: nil) if handle super(handle, retain: retain, auto_release: auto_release) else @@ -41,14 +41,7 @@ def initialize(handle = nil, retain: false, auto_release: true, types = objectives.values objectives = objectives.keys end - p = ExpressionParser::new(ctx) - objectives = objectives.collect { |e| - if e.kind_of? String - e = p.parse(e) - else - e - end - } + objectives = objectives.collect { |e| e.kind_of?(String) ? e = CCS.parse(e, context: ctx, binding: binding) : e } ocount = objectives.length if types raise CCSError, :CCS_RESULT_ERROR_INVALID_VALUE if types.size != ocount diff --git a/bindings/ruby/lib/cconfigspace/parameter.rb b/bindings/ruby/lib/cconfigspace/parameter.rb index fb12ee12..9522b83f 100644 --- a/bindings/ruby/lib/cconfigspace/parameter.rb +++ b/bindings/ruby/lib/cconfigspace/parameter.rb @@ -14,7 +14,7 @@ def self.get_id :CCS_PARAMETER_TYPE_DISCRETE, :CCS_PARAMETER_TYPE_STRING ] - class MemoryPointer + module MemoryAccessor def read_ccs_parameter_type_t ParameterType.from_native(read_int32, nil) end diff --git a/bindings/ruby/lib/cconfigspace/tree_space.rb b/bindings/ruby/lib/cconfigspace/tree_space.rb index 75f7c919..786618ca 100644 --- a/bindings/ruby/lib/cconfigspace/tree_space.rb +++ b/bindings/ruby/lib/cconfigspace/tree_space.rb @@ -4,7 +4,7 @@ module CCS :CCS_TREE_SPACE_TYPE_STATIC, :CCS_TREE_SPACE_TYPE_DYNAMIC ] - class MemoryPointer + module MemoryAccessor def read_ccs_tree_space_type_t TreeSpaceType.from_native(read_int32, nil) end @@ -120,9 +120,10 @@ def initialize(handle = nil, retain: false, auto_release: true, callback :ccs_dynamic_tree_space_del, [:ccs_tree_space_t], :ccs_result_t callback :ccs_dynamic_tree_space_get_child, [:ccs_tree_space_t, :ccs_tree_t, :size_t, :pointer], :ccs_result_t callback :ccs_dynamic_tree_space_serialize, [:ccs_tree_space_t, :size_t, :pointer, :pointer], :ccs_result_t - callback :ccs_dynamic_tree_space_deserialize, [:ccs_tree_space_t, :size_t, :pointer], :ccs_result_t + callback :ccs_dynamic_tree_space_deserialize, [:ccs_tree_t, :ccs_feature_space_t, :size_t, :pointer, :pointer], :ccs_result_t class DynamicTreeSpaceVector < FFI::Struct + attr_accessor :wrappers layout :del, :ccs_dynamic_tree_space_del, :get_child, :ccs_dynamic_tree_space_get_child, :serialize, :ccs_dynamic_tree_space_serialize, @@ -130,59 +131,6 @@ class DynamicTreeSpaceVector < FFI::Struct end typedef DynamicTreeSpaceVector.by_value, :ccs_dynamic_tree_space_vector_t - def self.wrap_dynamic_tree_space_callbacks(del, get_child, serialize, deserialize) - delwrapper = lambda { |ts| - begin - del.call(CCS::Object.from_handle(ts)) if del - CCS.unregister_vector(ts) - CCSError.to_native(:CCS_RESULT_SUCCESS) - rescue => e - CCS.set_error(e) - end - } - get_childwrapper = lambda { |ts, parent, index, p_child| - begin - child = get_child.call(TreeSpace.from_handle(ts), Tree.from_handle(parent), index) - CCS.error_check CCS.ccs_retain_object(child.handle) - Pointer.new(p_child).write_pointer(child.handle) - CCSError.to_native(:CCS_RESULT_SUCCESS) - rescue => e - CCS.set_error(e) - end - } - serializewrapper = - if serialize - lambda { |ts, state_size, p_state, p_state_size| - begin - state = serialize(TreeSpace.from_handle(ts), state_size == 0 ? true : false) - raise CCSError, :CCS_RESULT_ERROR_INVALID_VALUE if !p_state.null? && state_size < state.size - p_state.write_bytes(state.read_bytes(state.size)) unless p_state.null? - Pointer.new(p_state_size).write_size_t(state.size) unless p_state_size.null? - CCSError.to_native(:CCS_RESULT_SUCCESS) - rescue => e - CCS.set_error(e) - end - } - else - nil - end - deserializewrapper = - if deserialize - lambda { |ts, state_size, p_state| - begin - state = p_state.null? ? nil : p_state.slice(0, state_size) - deserialize(TreeSpace.from_handle(ts), state) - CCSError.to_native(:CCS_RESULT_SUCCESS) - rescue => e - CCS.set_error(e) - end - } - else - nil - end - return [delwrapper, get_childwrapper, serializewrapper, deserializewrapper] - end - attach_function :ccs_create_dynamic_tree_space, [:string, :ccs_tree_t, :ccs_feature_space_t, :ccs_rng_t, DynamicTreeSpaceVector.by_ref, :value, :pointer], :ccs_result_t attach_function :ccs_dynamic_tree_space_get_tree_space_data, [:ccs_tree_space_t, :pointer], :ccs_result_t @@ -196,33 +144,79 @@ def initialize(handle = nil, retain: false, auto_release: true, super(handle, retain: retain, auto_release: auto_release) else raise CCSError, :CCS_RESULT_ERROR_INVALID_VALUE if get_child.nil? - delwrapper, get_childwrapper, serializewrapper, deserializewrapper = - CCS.wrap_dynamic_tree_space_callbacks(del, get_child, serialize, deserialize) - vector = DynamicTreeSpaceVector::new - vector[:del] = delwrapper - vector[:get_child] = get_childwrapper - vector[:serialize] = serializewrapper - vector[:deserialize] = deserializewrapper + vector = DynamicTreeSpace.get_vector(del: del, get_child: get_child, serialize: serialize, deserialize: deserialize) ptr = MemoryPointer::new(:ccs_tree_space_t) CCS.error_check CCS.ccs_create_dynamic_tree_space(name, tree, feature_space, rng, vector, tree_space_data, ptr) h = ptr.read_ccs_tree_space_t super(h, retain: false) - CCS.register_vector(h, [delwrapper, get_childwrapper, serializewrapper, deserializewrapper, tree_space_data]) + FFI.inc_ref(vector) + FFI.inc_ref(tree_space_data) unless tree_space_data.nil? end end - def self.deserialize(del: nil, get_child: nil, serialize: nil, deserialize: nil, tree_space_data: nil, format: :binary, handle_map: nil, path: nil, buffer: nil, file_descriptor: nil, callback: nil, callback_data: nil) - raise CCSError, :CCS_RESULT_ERROR_INVALID_VALUE if get_child.nil? - delwrapper, get_childwrapper, serializewrapper, deserializewrapper = - CCS.wrap_dynamic_tree_space_callbacks(del, get_child, serialize, deserialize) + def self.get_vector(del: nil, get_child: nil, serialize: nil, deserialize: nil) vector = DynamicTreeSpaceVector::new + delwrapper = lambda { |ts| + begin + o = CCS::Object.from_handle(ts) + tsdata = o.tree_space_data + del.call(o) if del + FFI.dec_ref(tsdata) unless tsdata.nil? + FFI.dec_ref(vector) + CCSError.to_native(:CCS_RESULT_SUCCESS) + rescue => e + CCS.set_error(e) + end + } + get_childwrapper = lambda { |ts, parent, index, p_child| + begin + child = get_child.call(TreeSpace.from_handle(ts), Tree.from_handle(parent), index) + CCS.error_check CCS.ccs_retain_object(child.handle) + Pointer.new(p_child).write_pointer(child.handle) + CCSError.to_native(:CCS_RESULT_SUCCESS) + rescue => e + CCS.set_error(e) + end + } + serializewrapper = + if serialize + lambda { |ts, state_size, p_state, p_state_size| + begin + state = serialize.call(TreeSpace.from_handle(ts)) + raise CCSError, :CCS_RESULT_ERROR_INVALID_VALUE if !state.kind_of?(String) + raise CCSError, :CCS_RESULT_ERROR_INVALID_VALUE if !p_state.null? && state_size < state.bytesize + p_state.write_bytes(state, 0, state.bytesize) unless p_state.null? + Pointer.new(p_state_size).write_size_t(state.bytesize) unless p_state_size.null? + CCSError.to_native(:CCS_RESULT_SUCCESS) + rescue => e + CCS.set_error(e) + end + } + else + nil + end + deserializewrapper = + if deserialize + lambda { |t, feature_space, state_size, p_state, p_tree_space_data| + begin + state = p_state.null? ? nil : p_state.read_bytes(state_size) + tree_space_data = deserialize.call(Tree.from_handle(t), feature_space.null? ? nil : FeatureSpace.from_handle(feature_space), state) + p_tree_space_data.write_value(tree_space_data) + FFI.inc_ref(tree_space_data) + CCSError.to_native(:CCS_RESULT_SUCCESS) + rescue => e + CCS.set_error(e) + end + } + else + nil + end vector[:del] = delwrapper vector[:get_child] = get_childwrapper vector[:serialize] = serializewrapper vector[:deserialize] = deserializewrapper - res = super(format: format, handle_map: handle_map, vector: vector.to_ptr, data: tree_space_data, path: path, buffer: buffer, file_descriptor: file_descriptor, callback: callback, callback_data: callback_data) - CCS.register_vector(res.handle, [delwrapper, get_childwrapper, serializewrapper, deserializewrapper, tree_space_data]) - res + vector.wrappers = [delwrapper, get_childwrapper, serializewrapper, deserializewrapper] + vector end end diff --git a/bindings/ruby/lib/cconfigspace/tuner.rb b/bindings/ruby/lib/cconfigspace/tuner.rb index c9dd8983..2bfc976c 100644 --- a/bindings/ruby/lib/cconfigspace/tuner.rb +++ b/bindings/ruby/lib/cconfigspace/tuner.rb @@ -4,7 +4,7 @@ module CCS :CCS_TUNER_TYPE_RANDOM, :CCS_TUNER_TYPE_USER_DEFINED ] - class MemoryPointer + module MemoryAccessor def read_ccs_tuner_type_t TunerType.from_native(read_int32, nil) end @@ -32,7 +32,7 @@ def self.from_handle(handle, retain: true, auto_release: true) CCS.error_check CCS.ccs_tuner_get_type(handle, ptr) case ptr.read_ccs_tuner_type_t when :CCS_TUNER_TYPE_RANDOM - RandomTuner + RandomTuner when :CCS_TUNER_TYPE_USER_DEFINED UserDefinedTuner else @@ -122,9 +122,10 @@ def initialize(handle = nil, retain: false, auto_release: true, callback :ccs_user_defined_tuner_get_history, [:ccs_tuner_t, :ccs_features_t, :size_t, :pointer, :pointer], :ccs_result_t callback :ccs_user_defined_tuner_suggest, [:ccs_tuner_t, :ccs_features_t, :pointer], :ccs_result_t callback :ccs_user_defined_tuner_serialize, [:ccs_tuner_t, :size_t, :pointer, :pointer], :ccs_result_t - callback :ccs_user_defined_tuner_deserialize, [:ccs_tuner_t, :size_t, :pointer, :size_t, :pointer, :size_t, :pointer], :ccs_result_t + callback :ccs_user_defined_tuner_deserialize, [:ccs_objective_space_t, :size_t, :pointer, :size_t, :pointer, :size_t, :pointer, :pointer], :ccs_result_t class UserDefinedTunerVector < FFI::Struct + attr_accessor :wrappers layout :del, :ccs_user_defined_tuner_del, :ask, :ccs_user_defined_tuner_ask, :tell, :ccs_user_defined_tuner_tell, @@ -136,128 +137,6 @@ class UserDefinedTunerVector < FFI::Struct end typedef UserDefinedTunerVector.by_value, :ccs_user_defined_tuner_vector_t - def self.wrap_user_defined_tuner_callbacks(del, ask, tell, get_optima, get_history, suggest, serialize, deserialize) - delwrapper = lambda { |tun| - begin - del.call(CCS::Object.from_handle(tun)) if del - CCS.unregister_vector(tun) - CCSError.to_native(:CCS_RESULT_SUCCESS) - rescue => e - CCS.set_error(e) - end - } - askwrapper = lambda { |tun, features, count, p_configurations, p_count| - begin - configurations, count_ret = ask.call(Tuner.from_handle(tun), features.null? ? nil : Features.from_handle(features), p_configurations.null? ? nil : count) - raise CCSError, :CCS_RESULT_ERROR_INVALID_VALUE if !p_configurations.null? && count < count_ret - if !p_configurations.null? - configurations.each_with_index { |c, i| - err = CCS.ccs_retain_object(c.handle) - CCS.error_check(err) - p_configurations.put_pointer(i*8, c.handle) - } - (count_ret...count).each { |i| p_configurations[i].put_pointer(i*8, 0) } - end - Pointer.new(p_count).write_size_t(count_ret) unless p_count.null? - CCSError.to_native(:CCS_RESULT_SUCCESS) - rescue => e - CCS.set_error(e) - end - } - tellwrapper = lambda { |tun, count, p_evaluations| - begin - if count > 0 - evals = count.times.collect { |i| Evaluation::from_handle(p_evaluations.get_pointer(i*8)) } - tell.call(Tuner.from_handle(tun), evals) - end - CCSError.to_native(:CCS_RESULT_SUCCESS) - rescue => e - CCS.set_error(e) - end - } - get_optimawrapper = lambda { |tun, features, count, p_evaluations, p_count| - begin - optima = get_optima.call(Tuner.from_handle(tun), features.null? ? nil : Features.from_handle(features)) - raise CCSError, :CCS_RESULT_ERROR_INVALID_VALUE if !p_evaluations.null? && count < optima.size - unless p_evaluations.null? - optima.each_with_index { |o, i| - p_evaluations.put_pointer(8*i, o.handle) - } - ((optima.size)...count).each { |i| p_evaluations.put_pointer(8*i, 0) } - end - Pointer.new(p_count).write_size_t(optima.size) unless p_count.null? - CCSError.to_native(:CCS_RESULT_SUCCESS) - rescue => e - CCS.set_error(e) - end - } - get_historywrapper = lambda { |tun, features, count, p_evaluations, p_count| - begin - history = get_history.call(Tuner.from_handle(tun), features.null? ? nil : Features.from_handle(features)) - raise CCSError, :CCS_RESULT_ERROR_INVALID_VALUE if !p_evaluations.null? && count < history.size - unless p_evaluations.null? - history.each_with_index { |e, i| - p_evaluations.put_pointer(8*i, e.handle) - } - ((history.size)...count).each { |i| p_evaluations.put_pointer(8*i, 0) } - end - Pointer.new(p_count).write_size_t(history.size) unless p_count.null? - CCSError.to_native(:CCS_RESULT_SUCCESS) - rescue => e - CCS.set_error(e) - end - } - suggestwrapper = - if suggest - lambda { |tun, features, p_configuration| - begin - configuration = suggest.call(Tuner.from_handle(tun), features.null? ? nil : Features.from_handle(features)) - err = CCS.ccs_retain_object(configuration.handle) - CCS.error_check(err) - p_configuration.write_pointer(configuration.handle) - CCSError.to_native(:CCS_RESULT_SUCCESS) - rescue => e - CCS.set_error(e) - end - } - else - nil - end - serializewrapper = - if serialize - lambda { |tun, state_size, p_state, p_state_size| - begin - state = serialize(Tuner.from_handle(tun), state_size == 0 ? true : false) - raise CCSError, :CCS_RESULT_ERROR_INVALID_VALUE if !p_state.null? && state_size < state.size - p_state.write_bytes(state.read_bytes(state.size)) unless p_state.null? - Pointer.new(p_state_size).write_size_t(state.size) unless p_state_size.null? - CCSError.to_native(:CCS_RESULT_SUCCESS) - rescue => e - CCS.set_error(e) - end - } - else - nil - end - deserializewrapper = - if deserialize - lambda { |tun, history_size, p_history, num_optima, p_optima, state_size, p_state| - begin - history = p_history.null? ? [] : history_size.times.collect { |i| Evaluation::from_handle(p_p_history.get_pointer(i*8)) } - optima = p_optima.null? ? [] : num_optima.times.collect { |i| Evaluation::from_handle(p_optima.get_pointer(i*8)) } - state = p_state.null? ? nil : p_state.slice(0, state_size) - deserialize(Tuner.from_handle(tun), history, optima, state) - CCSError.to_native(:CCS_RESULT_SUCCESS) - rescue => e - CCS.set_error(e) - end - } - else - nil - end - return [delwrapper, askwrapper, tellwrapper, get_optimawrapper, get_historywrapper, suggestwrapper, serializewrapper, deserializewrapper] - end - attach_function :ccs_create_user_defined_tuner, [:string, :ccs_objective_space_t, UserDefinedTunerVector.by_ref, :value, :pointer], :ccs_result_t attach_function :ccs_user_defined_tuner_get_tuner_data, [:ccs_tuner_t, :pointer], :ccs_result_t class UserDefinedTuner < Tuner @@ -269,31 +148,143 @@ def initialize(handle = nil, retain: false, auto_release: true, if handle super(handle, retain: retain, auto_release: auto_release) else - raise CCSError, :CCS_RESULT_ERROR_INVALID_VALUE if ask.nil? || tell.nil? || get_optima.nil? || get_history.nil? - delwrapper, askwrapper, tellwrapper, get_optimawrapper, get_historywrapper, suggestwrapper, serializewrapper, deserializewrapper = - CCS.wrap_user_defined_tuner_callbacks(del, ask, tell, get_optima, get_history, suggest, serialize, deserialize) - vector = UserDefinedTunerVector::new - vector[:del] = delwrapper - vector[:ask] = askwrapper - vector[:tell] = tellwrapper - vector[:get_optima] = get_optimawrapper - vector[:get_history] = get_historywrapper - vector[:suggest] = suggestwrapper - vector[:serialize] = serializewrapper - vector[:deserialize] = deserializewrapper + raise CCSError, :CCS_RESULT_ERROR_INVALID_VALUE if ask.nil? || tell.nil? || get_optima.nil? || get_history.nil? + vector = UserDefinedTuner.get_vector(del: del, ask: ask, tell: tell, get_optima: get_optima, get_history: get_history, suggest: suggest, serialize: serialize, deserialize: deserialize) ptr = MemoryPointer::new(:ccs_tuner_t) CCS.error_check CCS.ccs_create_user_defined_tuner(name, objective_space, vector, tuner_data, ptr) handle = ptr.read_ccs_tuner_t super(handle, retain: false) - CCS.register_vector(handle, [delwrapper, askwrapper, tellwrapper, get_optimawrapper, get_historywrapper, suggestwrapper, serializewrapper, deserializewrapper, tuner_data]) + FFI.inc_ref(vector) + FFI.inc_ref(tuner_data) unless tuner_data.nil? end end - def self.deserialize(del: nil, ask: nil, tell: nil, get_optima: nil, get_history: nil, suggest: nil, serialize: nil, deserialize: nil, tuner_data: nil, format: :binary, handle_map: nil, path: nil, buffer: nil, file_descriptor: nil, callback: nil, callback_data: nil) - raise CCSError, :CCS_RESULT_ERROR_INVALID_VALUE if ask.nil? || tell.nil? || get_optima.nil? || get_history.nil? - delwrapper, askwrapper, tellwrapper, get_optimawrapper, get_historywrapper, suggestwrapper, serializewrapper, deserializewrapper = - CCS.wrap_user_defined_tuner_callbacks(del, ask, tell, get_optima, get_history, suggest, serialize, deserialize) + def self.get_vector(del: nil, ask: nil, tell: nil, get_optima: nil, get_history: nil, suggest: nil, serialize: nil, deserialize: nil) vector = UserDefinedTunerVector::new + delwrapper = lambda { |tun| + begin + o = CCS::Object.from_handle(tun) + tdata = o.tuner_data + del.call(o) if del + FFI.dec_ref(tdata) unless tdata.nil? + FFI.dec_ref(vector) + CCSError.to_native(:CCS_RESULT_SUCCESS) + rescue => e + CCS.set_error(e) + end + } + askwrapper = lambda { |tun, features, count, p_configurations, p_count| + begin + configurations, count_ret = ask.call(Tuner.from_handle(tun), features.null? ? nil : Features.from_handle(features), p_configurations.null? ? nil : count) + raise CCSError, :CCS_RESULT_ERROR_INVALID_VALUE if !p_configurations.null? && count < count_ret + if !p_configurations.null? + configurations.each_with_index { |c, i| + err = CCS.ccs_retain_object(c.handle) + CCS.error_check(err) + p_configurations.put_pointer(i*8, c.handle) + } + (count_ret...count).each { |i| p_configurations[i].put_pointer(i*8, 0) } + end + Pointer.new(p_count).write_size_t(count_ret) unless p_count.null? + CCSError.to_native(:CCS_RESULT_SUCCESS) + rescue => e + CCS.set_error(e) + end + } + tellwrapper = lambda { |tun, count, p_evaluations| + begin + if count > 0 + evals = count.times.collect { |i| Evaluation::from_handle(p_evaluations.get_pointer(i*8)) } + tell.call(Tuner.from_handle(tun), evals) + end + CCSError.to_native(:CCS_RESULT_SUCCESS) + rescue => e + CCS.set_error(e) + end + } + get_optimawrapper = lambda { |tun, features, count, p_evaluations, p_count| + begin + optima = get_optima.call(Tuner.from_handle(tun), features.null? ? nil : Features.from_handle(features)) + raise CCSError, :CCS_RESULT_ERROR_INVALID_VALUE if !p_evaluations.null? && count < optima.size + unless p_evaluations.null? + optima.each_with_index { |o, i| + p_evaluations.put_pointer(8*i, o.handle) + } + ((optima.size)...count).each { |i| p_evaluations.put_pointer(8*i, 0) } + end + Pointer.new(p_count).write_size_t(optima.size) unless p_count.null? + CCSError.to_native(:CCS_RESULT_SUCCESS) + rescue => e + CCS.set_error(e) + end + } + get_historywrapper = lambda { |tun, features, count, p_evaluations, p_count| + begin + history = get_history.call(Tuner.from_handle(tun), features.null? ? nil : Features.from_handle(features)) + raise CCSError, :CCS_RESULT_ERROR_INVALID_VALUE if !p_evaluations.null? && count < history.size + unless p_evaluations.null? + history.each_with_index { |e, i| + p_evaluations.put_pointer(8*i, e.handle) + } + ((history.size)...count).each { |i| p_evaluations.put_pointer(8*i, 0) } + end + Pointer.new(p_count).write_size_t(history.size) unless p_count.null? + CCSError.to_native(:CCS_RESULT_SUCCESS) + rescue => e + CCS.set_error(e) + end + } + suggestwrapper = + if suggest + lambda { |tun, features, p_configuration| + begin + configuration = suggest.call(Tuner.from_handle(tun), features.null? ? nil : Features.from_handle(features)) + err = CCS.ccs_retain_object(configuration.handle) + CCS.error_check(err) + p_configuration.write_pointer(configuration.handle) + CCSError.to_native(:CCS_RESULT_SUCCESS) + rescue => e + CCS.set_error(e) + end + } + else + nil + end + serializewrapper = + if serialize + lambda { |tun, state_size, p_state, p_state_size| + begin + state = serialize.call(Tuner.from_handle(tun)) + raise CCSError, :CCS_RESULT_ERROR_INVALID_VALUE if !state.kind_of?(String) + raise CCSError, :CCS_RESULT_ERROR_INVALID_VALUE if !p_state.null? && state_size < state.bytesize + p_state.write_bytes(state, 0, state.bytesize) unless p_state.null? + Pointer.new(p_state_size).write_size_t(state.bytesize) unless p_state_size.null? + CCSError.to_native(:CCS_RESULT_SUCCESS) + rescue => e + CCS.set_error(e) + end + } + else + nil + end + deserializewrapper = + if deserialize + lambda { |o_space, history_size, p_history, num_optima, p_optima, state_size, p_state, p_tuner_data| + begin + history = p_history.null? ? [] : history_size.times.collect { |i| Evaluation::from_handle(p_p_history.get_pointer(i*8)) } + optima = p_optima.null? ? [] : num_optima.times.collect { |i| Evaluation::from_handle(p_optima.get_pointer(i*8)) } + state = p_state.null? ? nil : p_state.read_bytes(state_size) + tuner_data = deserialize.call(ObjectiveSpace.from_handle(o_space), history, optima, state) + p_tuner_data.write_value(tuner_data) + FFI.inc_ref(tuner_data) + CCSError.to_native(:CCS_RESULT_SUCCESS) + rescue => e + CCS.set_error(e) + end + } + else + nil + end vector[:del] = delwrapper vector[:ask] = askwrapper vector[:tell] = tellwrapper @@ -302,10 +293,10 @@ def self.deserialize(del: nil, ask: nil, tell: nil, get_optima: nil, get_history vector[:suggest] = suggestwrapper vector[:serialize] = serializewrapper vector[:deserialize] = deserializewrapper - res = super(format: format, handle_map: handle_map, vector: vector.to_ptr, data: tuner_data, path: path, buffer: buffer, file_descriptor: file_descriptor, callback: callback, callback_data: callback_data) - CCS.register_vector(res.handle, [delwrapper, askwrapper, tellwrapper, get_optimawrapper, get_historywrapper, suggestwrapper, serializewrapper, deserializewrapper, tuner_data]) - res + vector.wrappers = [delwrapper, askwrapper, tellwrapper, get_optimawrapper, get_historywrapper, suggestwrapper, serializewrapper, deserializewrapper] + vector end + end Tuner::UserDefined = UserDefinedTuner diff --git a/bindings/ruby/test/test_expression.rb b/bindings/ruby/test/test_expression.rb index a4dd1c82..f1b4eb1e 100644 --- a/bindings/ruby/test/test_expression.rb +++ b/bindings/ruby/test/test_expression.rb @@ -64,4 +64,41 @@ def test_binary assert_equal( "true || false", e.to_s) assert_equal( true, e.eval ) end + + def test_user_defined + my_rand = lambda { |expr, limit| + expr.expression_data.rand(limit) + } + + my_serialize = lambda { |expr| + Marshal.dump(expr.expression_data) + } + + my_deserialize = lambda { |state| + Marshal.load(state) + } + + get_vector_data = lambda { |otype, name| + assert_equal(:CCS_OBJECT_TYPE_EXPRESSION, otype) + assert_equal("rand", name) + [CCS::Expression::UserDefined.get_vector(eval: my_rand, serialize: my_serialize, deserialize: my_deserialize), nil] + } + + limit = 10 + e = CCS::Expression::UserDefined.new(name: 'rand', nodes: [limit], expression_data: Random.new, eval: my_rand, serialize: my_serialize, deserialize: my_deserialize) + assert_equal( "rand(10)", e.to_s ) + + 100.times { + i = e.eval + assert(i >= 0 && i < limit) + } + + buff = e.serialize + + e_copy = CCS::deserialize(buffer: buff, vector_callback: get_vector_data) + + 100.times { + assert( e.eval == e_copy.eval ) + } + end end diff --git a/bindings/ruby/test/test_expression_parser.rb b/bindings/ruby/test/test_expression_parser.rb index 8e7f045f..6620a832 100644 --- a/bindings/ruby/test/test_expression_parser.rb +++ b/bindings/ruby/test/test_expression_parser.rb @@ -7,67 +7,82 @@ def setup end def test_parse - m = CCS::ExpressionParser.new.method(:parse) exp = "1.0 + 1 == 2 || +1 == 3e0 && \"y\\nes\" == 'no' " - res = m[exp] + res = CCS.parse(exp) assert( res.kind_of? CCS::Expression ) assert_equal( "1.0 + 1 == 2 || +1 == 3.0 && \"y\\nes\" == \"no\"", res.to_s ) end def test_parse_priority - m = CCS::ExpressionParser.new.method(:parse) exp = "(1 + 3) * 2" - res = m[exp] + res = CCS.parse(exp) assert( res.kind_of? CCS::Expression::Multiply ) assert_equal( exp, res.to_s ) end def test_associativity - m = CCS::ExpressionParser.new.method(:parse) exp = "5 - 2 - 1" - res = m[exp] + res = CCS.parse(exp) assert( res.kind_of? CCS::Expression::Substract ) assert_equal( 2, res.eval ) exp = "5 - +(+2 - 1)" - res = m[exp] + res = CCS.parse(exp) assert( res.kind_of? CCS::Expression::Substract ) assert_equal( 4, res.eval ) end def test_in - m = CCS::ExpressionParser.new.method(:parse) exp = "5 # [3.0, 5]" - res = m[exp] + res = CCS.parse(exp) assert( res.kind_of? CCS::Expression::In ) assert_equal( true, res.eval ) exp = "5 # [3.0, 4]" - res = m[exp] + res = CCS.parse(exp) assert( res.kind_of? CCS::Expression::In ) assert_equal( false, res.eval ) end def test_boolean - m = CCS::ExpressionParser.new.method(:parse) exp = "true" - res = m[exp] + res = CCS.parse(exp) assert( res.kind_of? CCS::Expression::Literal ) assert_equal( true, res.eval ) assert_equal( "true", res.to_s ) exp = "false" - res = m[exp] + res = CCS.parse(exp) assert( res.kind_of? CCS::Expression::Literal ) assert_equal( false, res.eval ) assert_equal( "false", res.to_s ) end def test_none - m = CCS::ExpressionParser.new.method(:parse) exp = "none" - res = m[exp] + res = CCS.parse(exp) assert( res.kind_of? CCS::Expression::Literal ) assert_nil( res.eval ) assert_equal( "none", res.to_s ) end + def test_function + def func(a, b) + a * b + end + exp = "func(3, 4)" + res = CCS.parse(exp, binding: binding) + assert( res.kind_of? CCS::Expression::UserDefined ) + assert_equal( "func(3, 4)", res.to_s ) + assert_equal( 12, res.eval ) + + get_vector_data = lambda { |otype, name| + assert_equal(:CCS_OBJECT_TYPE_EXPRESSION, otype) + CCS::Expression::get_function_vector_data(name, binding: binding) + } + + buff = res.serialize + res_copy = CCS::deserialize(buffer: buff, vector_callback: get_vector_data) + assert_equal( "func(3, 4)", res_copy.to_s ) + assert_equal( 12, res_copy.eval ) + end + end diff --git a/bindings/ruby/test/test_features_tuner.rb b/bindings/ruby/test/test_features_tuner.rb index 83fb45c5..fd2d20ba 100644 --- a/bindings/ruby/test/test_features_tuner.rb +++ b/bindings/ruby/test/test_features_tuner.rb @@ -125,6 +125,12 @@ def test_user_defined optis.sample.configuration end } + get_vector_data = lambda { |otype, name| + assert_equal(:CCS_OBJECT_TYPE_TUNER, otype) + assert_equal("tuner", name) + [CCS::UserDefinedTuner.get_vector(del: del, ask: ask, tell: tell, get_optima: get_optima, get_history: get_history, suggest: suggest), TunerData.new] + } + fs, os = create_tuning_problem t = CCS::UserDefinedTuner::new(name: "tuner", objective_space: os, del: del, ask: ask, tell: tell, get_optima: get_optima, get_history: get_history, suggest: suggest, tuner_data: TunerData.new) t2 = CCS::Object::from_handle(t) @@ -160,7 +166,7 @@ def test_user_defined } buff = t.serialize - t_copy = CCS::UserDefinedTuner::deserialize(buffer: buff, del: del, ask: ask, tell: tell, get_optima: get_optima, get_history: get_history, suggest: suggest, tuner_data: TunerData.new) + t_copy = CCS::deserialize(buffer: buff, vector_callback: get_vector_data) hist = t_copy.history assert_equal(200, hist.size) assert_equal(t.num_optima, t_copy.num_optima) diff --git a/bindings/ruby/test/test_tree_space.rb b/bindings/ruby/test/test_tree_space.rb index bdd0c714..ae8eaa03 100644 --- a/bindings/ruby/test/test_tree_space.rb +++ b/bindings/ruby/test/test_tree_space.rb @@ -54,6 +54,11 @@ def test_dynamic_tree_space arity = 0 if arity < 0 CCS::Tree.new(arity: arity, value: (4 - child_depth)*100 + child_index) } + get_vector_data = lambda { |otype, name| + assert_equal(:CCS_OBJECT_TYPE_TREE_SPACE, otype) + assert_equal('space', name) + [CCS::DynamicTreeSpace.get_vector(del: del, get_child: get_child), nil] + } tree = CCS::Tree.new(arity: 4, value: 400) ts = CCS::DynamicTreeSpace.new(name: 'space', tree: tree, del: del, get_child: get_child) @@ -74,7 +79,7 @@ def test_dynamic_tree_space } buff = ts.serialize - ts2 = CCS::DynamicTreeSpace.deserialize(buffer: buff, del: del, get_child: get_child) + ts2 = CCS::deserialize(buffer: buff, vector_callback: get_vector_data) assert_equal( [400, 301, 201], ts2.get_values_at_position([1, 1]) ) end diff --git a/bindings/ruby/test/test_tree_tuner.rb b/bindings/ruby/test/test_tree_tuner.rb index bbd8a498..1e46e1be 100644 --- a/bindings/ruby/test/test_tree_tuner.rb +++ b/bindings/ruby/test/test_tree_tuner.rb @@ -115,6 +115,12 @@ def test_user_defined tuner.tuner_data.optima.sample.configuration end } + get_vector_data = lambda { |otype, name| + assert_equal(:CCS_OBJECT_TYPE_TUNER, otype) + assert_equal("tuner", name) + [CCS::UserDefinedTuner.get_vector(del: del, ask: ask, tell: tell, get_optima: get_optima, get_history: get_history, suggest: suggest), TreeTunerData.new] + } + os = create_tuning_problem t = CCS::UserDefinedTuner.new(name: "tuner", objective_space: os, del: del, ask: ask, tell: tell, get_optima: get_optima, get_history: get_history, suggest: suggest, tuner_data: TreeTunerData.new) t2 = CCS::Object::from_handle(t) @@ -139,7 +145,7 @@ def test_user_defined assert_equal(hist.map { |e| e.objective_values.first }.max, best) assert(optims.map(&:configuration).include?(t.suggest)) buff = t.serialize - t_copy = CCS::UserDefinedTuner.deserialize(buffer: buff, del: del, ask: ask, tell: tell, get_optima: get_optima, get_history: get_history, suggest: suggest, tuner_data: TreeTunerData.new) + t_copy = CCS.deserialize(buffer: buff, vector_callback: get_vector_data) hist = t_copy.history assert_equal(200, hist.size) optims_2 = t_copy.optima diff --git a/bindings/ruby/test/test_tuner.rb b/bindings/ruby/test/test_tuner.rb index b4d9082a..961cfbb3 100644 --- a/bindings/ruby/test/test_tuner.rb +++ b/bindings/ruby/test/test_tuner.rb @@ -106,6 +106,12 @@ def test_user_defined tuner.tuner_data.optima.sample.configuration end } + get_vector_data = lambda { |otype, name| + assert_equal(:CCS_OBJECT_TYPE_TUNER, otype) + assert_equal("tuner", name) + [CCS::UserDefinedTuner.get_vector(del: del, ask: ask, tell: tell, get_optima: get_optima, get_history: get_history, suggest: suggest), TunerData.new] + } + os = create_tuning_problem t = CCS::UserDefinedTuner::new(name: "tuner", objective_space: os, del: del, ask: ask, tell: tell, get_optima: get_optima, get_history: get_history, suggest: suggest, tuner_data: TunerData.new) t2 = CCS::Object::from_handle(t) @@ -132,7 +138,7 @@ def test_user_defined assert( t.optima.collect(&:configuration).include?(t.suggest) ) buff = t.serialize - t_copy = CCS::UserDefinedTuner::deserialize(buffer: buff, del: del, ask: ask, tell: tell, get_optima: get_optima, get_history: get_history, suggest: suggest, tuner_data: TunerData.new) + t_copy = CCS::deserialize(buffer: buff, vector_callback: get_vector_data) hist = t_copy.history assert_equal(200, hist.size) assert_equal(t.num_optima, t_copy.num_optima) @@ -141,7 +147,7 @@ def test_user_defined assert( t_copy.optima.collect(&:configuration).include?(t_copy.suggest) ) t.serialize(path: 'tuner.ccs') - t_copy = CCS::UserDefinedTuner::deserialize(path: 'tuner.ccs', del: del, ask: ask, tell: tell, get_optima: get_optima, get_history: get_history, suggest: suggest, tuner_data: TunerData.new) + t_copy = CCS::deserialize(path: 'tuner.ccs', vector_callback: get_vector_data) hist = t_copy.history assert_equal(200, hist.size) assert_equal(t.num_optima, t_copy.num_optima) @@ -154,7 +160,7 @@ def test_user_defined t.serialize(file_descriptor: f.fileno) f.close f = File.open('tuner.ccs', "rb") - t_copy = CCS::UserDefinedTuner::deserialize(file_descriptor: f.fileno, del: del, ask: ask, tell: tell, get_optima: get_optima, get_history: get_history, suggest: suggest, tuner_data: TunerData.new) + t_copy = CCS::deserialize(file_descriptor: f.fileno, vector_callback: get_vector_data) f.close hist = t_copy.history assert_equal(200, hist.size) diff --git a/include/cconfigspace/base.h b/include/cconfigspace/base.h index 8580feab..98e60d07 100644 --- a/include/cconfigspace/base.h +++ b/include/cconfigspace/base.h @@ -965,9 +965,9 @@ enum ccs_serialize_option_e { typedef enum ccs_serialize_option_e ccs_serialize_option_t; /** - * The type of CCS object deserialization callbacks. - * This callback is used to deserialize object information that were created by - * the serialization callback. + * The type of CCS object data deserialization callbacks. This callback is + * used to deserialize object data information that were created by the + * serialization callback. * @param[in, out] object a CCS object * @param[in] serialize_data_size the size of the memory pointed to by * \p serialize_data @@ -981,12 +981,41 @@ typedef enum ccs_serialize_option_e ccs_serialize_option_t; * @remarks * This function must be thread-safe for serialization to be thread safe. */ -typedef ccs_result_t (*ccs_object_deserialize_callback_t)( +typedef ccs_result_t (*ccs_object_deserialize_data_callback_t)( ccs_object_t object, size_t serialize_data_size, const char *serialize_data, void *callback_user_data); +/** + * The type of CCS user object vector deserialization callbacks. + * This callback is used to obtain the vector to use for a user defined object. + * @param[in] type the type of CCS object for which the vector is required + * @param[in] name the name of the object as it was orignally provided at + * creation + * @param[in] callback_user_data the pointer provided when the callback + * was attached. + * @param[out] vector_ret a pointer that will hold the returned pointer to the + * vector structure + * @param[out] data_ret a pointer that will optionally hold the user data value + * to use when initializing the user object. + * If the provided vector holds the relevant + * deserialization callback, this value will be + * ignored. + * @return #CCS_RESULT_SUCCESS on success + * @return #CCS_RESULT_ERROR_INVALID_TYPE if the object type is not expected by + * the callback + * @return #CCS_RESULT_ERROR_INVALID_NAME if the object name is not recognized + * @remarks + * This function must be thread-safe for deserialization to be thread safe. + */ +typedef ccs_result_t (*ccs_object_deserialize_vector_callback_t)( + ccs_object_type_t type, + const char *name, + void *callback_user_data, + void **vector_ret, + void **data_ret); + /** * The different deserialization options. */ @@ -1001,15 +1030,11 @@ enum ccs_deserialize_option_e { */ CCS_DESERIALIZE_OPTION_HANDLE_MAP, /** - * The next parameter is a pointer to a ccs object vector struct, for - * user defined tuners - */ - CCS_DESERIALIZE_OPTION_VECTOR, - /** - * The next parameter is a pointer to a ccs object internal data, for - * user defined tuners + * The next parameter is a pointer to a callback of type + * ccs_object_deserialize_vector_callback_t and its user data, that + * will return the ccs object vector struct, for user defined objects */ - CCS_DESERIALIZE_OPTION_DATA, + CCS_DESERIALIZE_OPTION_VECTOR_CALLBACK, /** * The file descriptor operation is non-blocking. The next parameter is * a pointer to a void * variable (initialized to NULL) that will hold @@ -1019,12 +1044,12 @@ enum ccs_deserialize_option_e { */ CCS_DESERIALIZE_OPTION_NON_BLOCKING, /** - * The next parameters are a deserialization callback and it's + * The next parameters are a deserialization callback and its * user_data. This callback will be called for all objects that had * their user_data serialized. If no such callback is provided the * object's user_data value will not be set. */ - CCS_DESERIALIZE_OPTION_CALLBACK, + CCS_DESERIALIZE_OPTION_DATA_CALLBACK, /** Guard */ CCS_DESERIALIZE_OPTION_MAX, /** Try forcing 32 bits value for bindings */ diff --git a/include/cconfigspace/expression.h b/include/cconfigspace/expression.h index 5ae4af4a..1c189b9b 100644 --- a/include/cconfigspace/expression.h +++ b/include/cconfigspace/expression.h @@ -57,6 +57,8 @@ enum ccs_expression_type_e { CCS_EXPRESSION_TYPE_LITERAL, /** Variable */ CCS_EXPRESSION_TYPE_VARIABLE, + /** User defined */ + CCS_EXPRESSION_TYPE_USER_DEFINED, /** Guard */ CCS_EXPRESSION_TYPE_MAX, /** Try forcing 32 bits value for bindings */ @@ -79,7 +81,7 @@ typedef enum ccs_expression_type_e ccs_expression_type_t; * - 6 : POSITIVE, NEGATIVE, NOT * - 7 : IN * - max - 1: LIST - * - max : LITERAL, VARIABLE + * - max : LITERAL, VARIABLE, USER_DEFINED * * Those are similar to C's precedence */ @@ -117,7 +119,7 @@ typedef enum ccs_associativity_type_e ccs_associativity_type_t; * - right: POSITIVE, NEGATIVE, NOT * - left: IN * - left: LIST - * - none: LITERAL, VARIABLE + * - none: LITERAL, VARIABLE, USER_DEFINED */ extern const ccs_associativity_type_t ccs_expression_associativity[]; @@ -143,6 +145,7 @@ extern const ccs_associativity_type_t ccs_expression_associativity[]; * - LIST: NULL * - LITERAL: NULL * - VARIABLE: NULL + * - USER_DEFINED: NULL */ extern const char *ccs_expression_symbols[]; @@ -158,6 +161,7 @@ extern const char *ccs_expression_symbols[]; * - 2: IN * - -1: LIST * - 0: LITERAL, VARIABLE + * - -1: USER_DEFINED */ extern const int ccs_expression_arity[]; @@ -349,6 +353,86 @@ ccs_create_literal(ccs_datum_t value, ccs_expression_t *expression_ret); extern ccs_result_t ccs_create_variable(ccs_parameter_t parameter, ccs_expression_t *expression_ret); +/** + * A structure that define the callbacks the user must provide to create a user + * defined expression. + */ +struct ccs_user_defined_expression_vector_s { + /** + * The deletion callback that will be called once the reference count + * of the expression reaches 0. + */ + ccs_result_t (*del)(ccs_expression_t expression); + + /** + * The expression evaluation interface. + */ + ccs_result_t (*eval)( + ccs_expression_t expression, + size_t num_values, + ccs_datum_t *values, + ccs_datum_t *value_ret); + + /** + * The expression serialization interface, can be NULL. + */ + ccs_result_t (*serialize_user_state)( + ccs_expression_t expression, + size_t sate_size, + void *state, + size_t *state_size_ret); + + /** + * The expression deserialization interface, can be NULL. + */ + ccs_result_t (*deserialize_state)( + size_t state_size, + const void *state, + void **expression_data_ret); +}; + +/** + * a commodity type to represent a user defined expression callback vector. + */ +typedef struct ccs_user_defined_expression_vector_s + ccs_user_defined_expression_vector_t; + +/** + * Create a new user defined expression. + * @param[in] name the name of the expression + * @param[in] num_nodes the number of the expression children nodes. Must be + * compatible with the arity of the expression + * @param[in] nodes an array of \p num_nodes expressions + * @param[in] vector the vector of callbacks implementing the expression + * interface + * @param[in] expression_data a pointer to the expression internal data + * structures. Can be NULL + * @param[out] expression_ret a pointer to the variable that will hold the newly + * created expression + * @return #CCS_RESULT_SUCCESS on success + * @return #CCS_RESULT_ERROR_INVALID_OBJECT if one the nodes given is of type + * #CCS_DATA_TYPE_OBJECT but the object is not a valid CCS object + * @return #CCS_RESULT_ERROR_INVALID_VALUE if \p name is NULL; or if one the + * nodes given is of type #CCS_DATA_TYPE_OBJECT but is neither a + * #CCS_OBJECT_TYPE_PARAMETER nor a #CCS_OBJECT_TYPE_EXPRESSION; or if one the + * nodes given node is not a type #CCS_DATA_TYPE_OBJECT, #CCS_DATA_TYPE_NONE, + * #CCS_DATA_TYPE_INT, #CCS_DATA_TYPE_FLOAT, #CCS_DATA_TYPE_BOOL, or + * #CCS_DATA_TYPE_STRING; or if \p expression_ret is NULL; or if \p vector is + * NULL; or if any non optional interface pointer is NULL + * @return #CCS_RESULT_ERROR_OUT_OF_MEMORY if there was not enough memory to + * allocate the new expression instance + * @remarks + * This function is thread-safe + */ +extern ccs_result_t +ccs_create_user_defined_expression( + const char *name, + size_t num_nodes, + ccs_datum_t *nodes, + ccs_user_defined_expression_vector_t *vector, + void *expression_data, + ccs_expression_t *expression_ret); + /** * Get the type of an expression. * @param[in] expression @@ -427,6 +511,43 @@ ccs_variable_get_parameter( ccs_expression_t expression, ccs_parameter_t *parameter_ret); +/** + * Get the name of a user defined expression. + * @param[in] expression + * @param[out] name_ret a pointer to the variable that will contain a pointer to + * the name of the expression + * @return #CCS_RESULT_SUCCESS on success + * @return #CCS_RESULT_ERROR_INVALID_VALUE if \p name_ret is NULL + * @return #CCS_RESULT_ERROR_INVALID_OBJECT if \p expression is not a valid CCS + * expression + * @return #CCS_RESULT_ERROR_INVALID_EXPRESSION if \p expression is not a user + * defined expression + * @remarks + * This function is thread-safe + */ +extern ccs_result_t +ccs_user_defined_expression_get_name( + ccs_expression_t expression, + const char **name_ret); + +/** + * Get the user defined expression internal data pointer. + * @param[in] expression + * @param[out] expression_data_ret + * @return #CCS_RESULT_SUCCESS on success + * @return #CCS_RESULT_ERROR_INVALID_OBJECT if \p expression is not a valid CCS + * expression + * @return #CCS_RESULT_ERROR_INVALID_EXPRESSION if \p expression is not a user + * defined expression + * @return #CCS_RESULT_ERROR_INVALID_VALUE if \p expression_data_ret is NULL + * @remarks + * This function is thread-safe + */ +extern ccs_result_t +ccs_user_defined_expression_get_expression_data( + ccs_expression_t expression, + void **expression_data_ret); + /** * Get the value of an expression, in a given list of bindings. * @param[in] expression diff --git a/include/cconfigspace/tree_space.h b/include/cconfigspace/tree_space.h index 64bb3808..58510d7b 100644 --- a/include/cconfigspace/tree_space.h +++ b/include/cconfigspace/tree_space.h @@ -107,12 +107,15 @@ struct ccs_dynamic_tree_space_vector_s { /** * The tree space deserialization interface, can be NULL. In this case, - * only the tree is deserialized. + * only the tree is deserialized. Must return the tree space data + * to use at initialization */ ccs_result_t (*deserialize_state)( - ccs_tree_space_t tree_space, - size_t state_size, - const void *state); + ccs_tree_t tree, + ccs_feature_space_t feature_space, + size_t state_size, + const void *state, + void **tree_space_data_ret); }; /** diff --git a/include/cconfigspace/tuner.h b/include/cconfigspace/tuner.h index 154ebe38..c586db43 100644 --- a/include/cconfigspace/tuner.h +++ b/include/cconfigspace/tuner.h @@ -345,17 +345,19 @@ struct ccs_user_defined_tuner_vector_s { size_t *state_size_ret); /** - * The tuner deserialization interface, can be NULL, in which case, - * the history will be set through the tell interface + * The tuner deserialization interface, can be NULL, in which + * case, the history will be set through the tell interface. Must + * return the tuner data to use at initialization */ ccs_result_t (*deserialize_state)( - ccs_tuner_t tuner, - size_t size_history, - ccs_evaluation_t *history, - size_t num_optima, - ccs_evaluation_t *optima, - size_t state_size, - const void *state); + ccs_objective_space_t objective_space, + size_t size_history, + ccs_evaluation_t *history, + size_t num_optima, + ccs_evaluation_t *optima, + size_t state_size, + const void *state, + void **tuner_data_ret); }; /** diff --git a/samples/Makefile.am b/samples/Makefile.am index eca46bda..8248ab53 100644 --- a/samples/Makefile.am +++ b/samples/Makefile.am @@ -1,6 +1,6 @@ AM_COLOR_TESTS = yes -test_ruby_CFLAGS = -I$(top_srcdir)/include -Wall -Wextra $(GSL_CFLAGS) $(RUBY_CFLAGS) -Wno-deprecated-declarations +test_ruby_CFLAGS = -I$(top_srcdir)/include -Wall -Wextra $(GSL_CFLAGS) $(RUBY_CFLAGS) -Wno-deprecated-declarations -Wno-unused-parameter if !ISMACOS if STRICT diff --git a/src/cconfigspace_deserialize.h b/src/cconfigspace_deserialize.h index ff71bdbb..aeada3d0 100644 --- a/src/cconfigspace_deserialize.h +++ b/src/cconfigspace_deserialize.h @@ -54,12 +54,12 @@ _ccs_object_deserialize_options( CCS_CHECK_OBJ(opts->handle_map, CCS_OBJECT_TYPE_MAP); opts->map_values = CCS_TRUE; break; - case CCS_DESERIALIZE_OPTION_VECTOR: - opts->vector = va_arg(args, void *); - CCS_CHECK_PTR(opts->vector); - break; - case CCS_DESERIALIZE_OPTION_DATA: - opts->data = va_arg(args, void *); + case CCS_DESERIALIZE_OPTION_VECTOR_CALLBACK: + opts->deserialize_vector_callback = va_arg( + args, ccs_object_deserialize_vector_callback_t); + CCS_CHECK_PTR(opts->deserialize_vector_callback); + opts->deserialize_vector_user_data = + va_arg(args, void *); break; case CCS_DESERIALIZE_OPTION_NON_BLOCKING: CCS_REFUTE( @@ -70,11 +70,11 @@ _ccs_object_deserialize_options( va_arg(args, _ccs_file_descriptor_state_t **); CCS_CHECK_PTR(opts->ppfd_state); break; - case CCS_DESERIALIZE_OPTION_CALLBACK: - opts->deserialize_callback = - va_arg(args, ccs_object_deserialize_callback_t); - CCS_CHECK_PTR(opts->deserialize_callback); - opts->deserialize_user_data = va_arg(args, void *); + case CCS_DESERIALIZE_OPTION_DATA_CALLBACK: + opts->deserialize_data_callback = va_arg( + args, ccs_object_deserialize_data_callback_t); + CCS_CHECK_PTR(opts->deserialize_data_callback); + opts->deserialize_data_user_data = va_arg(args, void *); break; default: CCS_RAISE( diff --git a/src/cconfigspace_internal.h b/src/cconfigspace_internal.h index e503162d..78e2ae85 100644 --- a/src/cconfigspace_internal.h +++ b/src/cconfigspace_internal.h @@ -1296,13 +1296,13 @@ _ccs_peek_bin_ccs_object_internal( } struct _ccs_object_deserialize_options_s { - ccs_map_t handle_map; - ccs_bool_t map_values; - _ccs_file_descriptor_state_t **ppfd_state; - void *vector; - void *data; - ccs_object_deserialize_callback_t deserialize_callback; - void *deserialize_user_data; + ccs_map_t handle_map; + ccs_bool_t map_values; + _ccs_file_descriptor_state_t **ppfd_state; + ccs_object_deserialize_vector_callback_t deserialize_vector_callback; + void *deserialize_vector_user_data; + ccs_object_deserialize_data_callback_t deserialize_data_callback; + void *deserialize_data_user_data; }; typedef struct _ccs_object_deserialize_options_s _ccs_object_deserialize_options_t; @@ -1449,10 +1449,10 @@ _ccs_object_deserialize_user_data( CCS_VALIDATE(_ccs_deserialize_bin_size( &serialize_data_size, buffer_size, buffer)); if (serialize_data_size) { - if (opts->deserialize_callback) - CCS_VALIDATE(opts->deserialize_callback( + if (opts->deserialize_data_callback) + CCS_VALIDATE(opts->deserialize_data_callback( object, serialize_data_size, *buffer, - opts->deserialize_user_data)); + opts->deserialize_data_user_data)); *buffer_size -= serialize_data_size; *buffer += serialize_data_size; } diff --git a/src/expression.c b/src/expression.c index ced48729..6efec332 100644 --- a/src/expression.c +++ b/src/expression.c @@ -16,7 +16,8 @@ const int ccs_expression_precedence[] = { 6, 6, 6, 7, 8, - 9, 9 + 9, 9, + 9 }; const ccs_associativity_type_t ccs_expression_associativity[] = { @@ -29,7 +30,8 @@ const ccs_associativity_type_t ccs_expression_associativity[] = { CCS_ASSOCIATIVITY_TYPE_RIGHT_TO_LEFT, CCS_ASSOCIATIVITY_TYPE_RIGHT_TO_LEFT, CCS_ASSOCIATIVITY_TYPE_RIGHT_TO_LEFT, CCS_ASSOCIATIVITY_TYPE_LEFT_TO_RIGHT, CCS_ASSOCIATIVITY_TYPE_LEFT_TO_RIGHT, - CCS_ASSOCIATIVITY_TYPE_NONE, CCS_ASSOCIATIVITY_TYPE_NONE + CCS_ASSOCIATIVITY_TYPE_NONE, CCS_ASSOCIATIVITY_TYPE_NONE, + CCS_ASSOCIATIVITY_TYPE_NONE, }; const char *ccs_expression_symbols[] = { @@ -42,7 +44,8 @@ const char *ccs_expression_symbols[] = { "+", "-", "!", "#", NULL, - NULL, NULL + NULL, NULL, + NULL, }; const int ccs_expression_arity[] = { @@ -55,7 +58,8 @@ const int ccs_expression_arity[] = { 1, 1, 1, 2, -1, - 0, 0 + 0, 0, + -1 }; const int ccs_terminal_precedence[] = { @@ -93,8 +97,8 @@ ccs_expression_get_ops(ccs_expression_t expression) static ccs_result_t _ccs_expression_del(ccs_object_t o) { - ccs_expression_t d = (ccs_expression_t)o; - _ccs_expression_data_t *data = d->data; + ccs_expression_t e = (ccs_expression_t)o; + _ccs_expression_data_t *data = e->data; for (size_t i = 0; i < data->num_nodes; i++) ccs_release_object(data->nodes[i]); return CCS_RESULT_SUCCESS; @@ -205,7 +209,7 @@ _ccs_expression_serialize( do { \ _ccs_expression_ops_t *ops = \ ccs_expression_get_ops(expression); \ - CCS_VALIDATE(ops->eval(expression->data, expr_ctx, result)); \ + CCS_VALIDATE(ops->eval(expression, expr_ctx, result)); \ } while (0) static inline ccs_result_t @@ -248,20 +252,14 @@ _ccs_expr_node_eval( static ccs_result_t _ccs_expr_or_eval( - _ccs_expression_data_t *data, - _ccs_expr_ctx_t *expr_ctx, - ccs_datum_t *result) -{ - ccs_datum_t left; - ccs_datum_t right; - // avoid inactive branch suppressing a parameter parameter - // if the other branch is valid. - // TODO: use EVAL_LEFT_RIGHT(data, num_bindings, bindings, left, right, - // NULL, NULL); - CCS_VALIDATE( - _ccs_expr_node_eval(data->nodes[0], expr_ctx, &left, NULL)); - CCS_VALIDATE( - _ccs_expr_node_eval(data->nodes[1], expr_ctx, &right, NULL)); + ccs_expression_t e, + _ccs_expr_ctx_t *expr_ctx, + ccs_datum_t *result) +{ + _ccs_expression_data_t *data = e->data; + ccs_datum_t left, right; + + EVAL_LEFT_RIGHT(data, expr_ctx, left, right, NULL, NULL); CCS_REFUTE( left.type != CCS_DATA_TYPE_BOOL && left.type != CCS_DATA_TYPE_INACTIVE, @@ -291,12 +289,13 @@ static _ccs_expression_ops_t _ccs_expr_or_ops = { static ccs_result_t _ccs_expr_and_eval( - _ccs_expression_data_t *data, - _ccs_expr_ctx_t *expr_ctx, - ccs_datum_t *result) + ccs_expression_t e, + _ccs_expr_ctx_t *expr_ctx, + ccs_datum_t *result) { - ccs_datum_t left; - ccs_datum_t right; + _ccs_expression_data_t *data = e->data; + ccs_datum_t left, right; + EVAL_LEFT_RIGHT(data, expr_ctx, left, right, NULL, NULL); RETURN_IF_INACTIVE(left, result); RETURN_IF_INACTIVE(right, result); @@ -432,14 +431,14 @@ _ccs_datum_cmp_generic(ccs_datum_t *a, ccs_datum_t *b, ccs_int_t *cmp) static ccs_result_t _ccs_expr_equal_eval( - _ccs_expression_data_t *data, - _ccs_expr_ctx_t *expr_ctx, - ccs_datum_t *result) + ccs_expression_t e, + _ccs_expr_ctx_t *expr_ctx, + ccs_datum_t *result) { - ccs_datum_t left; - ccs_datum_t right; - ccs_parameter_type_t htl = CCS_PARAMETER_TYPE_MAX; - ccs_parameter_type_t htr = CCS_PARAMETER_TYPE_MAX; + _ccs_expression_data_t *data = e->data; + ccs_datum_t left, right; + ccs_parameter_type_t htl = CCS_PARAMETER_TYPE_MAX; + ccs_parameter_type_t htr = CCS_PARAMETER_TYPE_MAX; EVAL_LEFT_RIGHT(data, expr_ctx, left, right, &htl, &htr); RETURN_IF_INACTIVE(left, result); @@ -466,14 +465,14 @@ static _ccs_expression_ops_t _ccs_expr_equal_ops = { static ccs_result_t _ccs_expr_not_equal_eval( - _ccs_expression_data_t *data, - _ccs_expr_ctx_t *expr_ctx, - ccs_datum_t *result) + ccs_expression_t e, + _ccs_expr_ctx_t *expr_ctx, + ccs_datum_t *result) { - ccs_datum_t left; - ccs_datum_t right; - ccs_parameter_type_t htl = CCS_PARAMETER_TYPE_MAX; - ccs_parameter_type_t htr = CCS_PARAMETER_TYPE_MAX; + _ccs_expression_data_t *data = e->data; + ccs_datum_t left, right; + ccs_parameter_type_t htl = CCS_PARAMETER_TYPE_MAX; + ccs_parameter_type_t htr = CCS_PARAMETER_TYPE_MAX; EVAL_LEFT_RIGHT(data, expr_ctx, left, right, &htl, &htr); RETURN_IF_INACTIVE(left, result); @@ -500,14 +499,14 @@ static _ccs_expression_ops_t _ccs_expr_not_equal_ops = { static ccs_result_t _ccs_expr_less_eval( - _ccs_expression_data_t *data, - _ccs_expr_ctx_t *expr_ctx, - ccs_datum_t *result) + ccs_expression_t e, + _ccs_expr_ctx_t *expr_ctx, + ccs_datum_t *result) { - ccs_datum_t left; - ccs_datum_t right; - ccs_parameter_type_t htl = CCS_PARAMETER_TYPE_MAX; - ccs_parameter_type_t htr = CCS_PARAMETER_TYPE_MAX; + _ccs_expression_data_t *data = e->data; + ccs_datum_t left, right; + ccs_parameter_type_t htl = CCS_PARAMETER_TYPE_MAX; + ccs_parameter_type_t htr = CCS_PARAMETER_TYPE_MAX; EVAL_LEFT_RIGHT(data, expr_ctx, left, right, &htl, &htr); CCS_REFUTE( @@ -549,14 +548,14 @@ static _ccs_expression_ops_t _ccs_expr_less_ops = { static ccs_result_t _ccs_expr_greater_eval( - _ccs_expression_data_t *data, - _ccs_expr_ctx_t *expr_ctx, - ccs_datum_t *result) + ccs_expression_t e, + _ccs_expr_ctx_t *expr_ctx, + ccs_datum_t *result) { - ccs_datum_t left; - ccs_datum_t right; - ccs_parameter_type_t htl = CCS_PARAMETER_TYPE_MAX; - ccs_parameter_type_t htr = CCS_PARAMETER_TYPE_MAX; + _ccs_expression_data_t *data = e->data; + ccs_datum_t left, right; + ccs_parameter_type_t htl = CCS_PARAMETER_TYPE_MAX; + ccs_parameter_type_t htr = CCS_PARAMETER_TYPE_MAX; EVAL_LEFT_RIGHT(data, expr_ctx, left, right, &htl, &htr); CCS_REFUTE( @@ -598,14 +597,14 @@ static _ccs_expression_ops_t _ccs_expr_greater_ops = { static ccs_result_t _ccs_expr_less_or_equal_eval( - _ccs_expression_data_t *data, - _ccs_expr_ctx_t *expr_ctx, - ccs_datum_t *result) + ccs_expression_t e, + _ccs_expr_ctx_t *expr_ctx, + ccs_datum_t *result) { - ccs_datum_t left; - ccs_datum_t right; - ccs_parameter_type_t htl = CCS_PARAMETER_TYPE_MAX; - ccs_parameter_type_t htr = CCS_PARAMETER_TYPE_MAX; + _ccs_expression_data_t *data = e->data; + ccs_datum_t left, right; + ccs_parameter_type_t htl = CCS_PARAMETER_TYPE_MAX; + ccs_parameter_type_t htr = CCS_PARAMETER_TYPE_MAX; EVAL_LEFT_RIGHT(data, expr_ctx, left, right, &htl, &htr); CCS_REFUTE( @@ -647,14 +646,14 @@ static _ccs_expression_ops_t _ccs_expr_less_or_equal_ops = { static ccs_result_t _ccs_expr_greater_or_equal_eval( - _ccs_expression_data_t *data, - _ccs_expr_ctx_t *expr_ctx, - ccs_datum_t *result) + ccs_expression_t e, + _ccs_expr_ctx_t *expr_ctx, + ccs_datum_t *result) { - ccs_datum_t left; - ccs_datum_t right; - ccs_parameter_type_t htl = CCS_PARAMETER_TYPE_MAX; - ccs_parameter_type_t htr = CCS_PARAMETER_TYPE_MAX; + _ccs_expression_data_t *data = e->data; + ccs_datum_t left, right; + ccs_parameter_type_t htl = CCS_PARAMETER_TYPE_MAX; + ccs_parameter_type_t htr = CCS_PARAMETER_TYPE_MAX; EVAL_LEFT_RIGHT(data, expr_ctx, left, right, &htl, &htr); CCS_REFUTE( @@ -703,11 +702,12 @@ _ccs_expression_list_eval_node( static ccs_result_t _ccs_expr_in_eval( - _ccs_expression_data_t *data, - _ccs_expr_ctx_t *expr_ctx, - ccs_datum_t *result) + ccs_expression_t e, + _ccs_expr_ctx_t *expr_ctx, + ccs_datum_t *result) { - ccs_expression_type_t etype = CCS_EXPR_TYPE(data->nodes[1]); + _ccs_expression_data_t *data = e->data; + ccs_expression_type_t etype = CCS_EXPR_TYPE(data->nodes[1]); CCS_REFUTE( etype != CCS_EXPRESSION_TYPE_LIST, CCS_RESULT_ERROR_INVALID_VALUE); @@ -746,12 +746,13 @@ static _ccs_expression_ops_t _ccs_expr_in_ops = { static ccs_result_t _ccs_expr_add_eval( - _ccs_expression_data_t *data, - _ccs_expr_ctx_t *expr_ctx, - ccs_datum_t *result) + ccs_expression_t e, + _ccs_expr_ctx_t *expr_ctx, + ccs_datum_t *result) { - ccs_datum_t left; - ccs_datum_t right; + _ccs_expression_data_t *data = e->data; + ccs_datum_t left, right; + EVAL_LEFT_RIGHT(data, expr_ctx, left, right, NULL, NULL); RETURN_IF_INACTIVE(left, result); RETURN_IF_INACTIVE(right, result); @@ -784,12 +785,13 @@ static _ccs_expression_ops_t _ccs_expr_add_ops = { static ccs_result_t _ccs_expr_substract_eval( - _ccs_expression_data_t *data, - _ccs_expr_ctx_t *expr_ctx, - ccs_datum_t *result) + ccs_expression_t e, + _ccs_expr_ctx_t *expr_ctx, + ccs_datum_t *result) { - ccs_datum_t left; - ccs_datum_t right; + _ccs_expression_data_t *data = e->data; + ccs_datum_t left, right; + EVAL_LEFT_RIGHT(data, expr_ctx, left, right, NULL, NULL); RETURN_IF_INACTIVE(left, result); RETURN_IF_INACTIVE(right, result); @@ -822,12 +824,13 @@ static _ccs_expression_ops_t _ccs_expr_substract_ops = { static ccs_result_t _ccs_expr_multiply_eval( - _ccs_expression_data_t *data, - _ccs_expr_ctx_t *expr_ctx, - ccs_datum_t *result) + ccs_expression_t e, + _ccs_expr_ctx_t *expr_ctx, + ccs_datum_t *result) { - ccs_datum_t left; - ccs_datum_t right; + _ccs_expression_data_t *data = e->data; + ccs_datum_t left, right; + EVAL_LEFT_RIGHT(data, expr_ctx, left, right, NULL, NULL); RETURN_IF_INACTIVE(left, result); RETURN_IF_INACTIVE(right, result); @@ -860,12 +863,13 @@ static _ccs_expression_ops_t _ccs_expr_multiply_ops = { static ccs_result_t _ccs_expr_divide_eval( - _ccs_expression_data_t *data, - _ccs_expr_ctx_t *expr_ctx, - ccs_datum_t *result) + ccs_expression_t e, + _ccs_expr_ctx_t *expr_ctx, + ccs_datum_t *result) { - ccs_datum_t left; - ccs_datum_t right; + _ccs_expression_data_t *data = e->data; + ccs_datum_t left, right; + EVAL_LEFT_RIGHT(data, expr_ctx, left, right, NULL, NULL); RETURN_IF_INACTIVE(left, result); RETURN_IF_INACTIVE(right, result); @@ -910,12 +914,13 @@ static _ccs_expression_ops_t _ccs_expr_divide_ops = { static ccs_result_t _ccs_expr_modulo_eval( - _ccs_expression_data_t *data, - _ccs_expr_ctx_t *expr_ctx, - ccs_datum_t *result) + ccs_expression_t e, + _ccs_expr_ctx_t *expr_ctx, + ccs_datum_t *result) { - ccs_datum_t left; - ccs_datum_t right; + _ccs_expression_data_t *data = e->data; + ccs_datum_t left, right; + EVAL_LEFT_RIGHT(data, expr_ctx, left, right, NULL, NULL); RETURN_IF_INACTIVE(left, result); RETURN_IF_INACTIVE(right, result); @@ -960,11 +965,13 @@ static _ccs_expression_ops_t _ccs_expr_modulo_ops = { static ccs_result_t _ccs_expr_positive_eval( - _ccs_expression_data_t *data, - _ccs_expr_ctx_t *expr_ctx, - ccs_datum_t *result) + ccs_expression_t e, + _ccs_expr_ctx_t *expr_ctx, + ccs_datum_t *result) { - ccs_datum_t node; + _ccs_expression_data_t *data = e->data; + ccs_datum_t node; + EVAL_NODE(data, expr_ctx, node, NULL); RETURN_IF_INACTIVE(node, result); CCS_REFUTE( @@ -982,11 +989,13 @@ static _ccs_expression_ops_t _ccs_expr_positive_ops = { static ccs_result_t _ccs_expr_negative_eval( - _ccs_expression_data_t *data, - _ccs_expr_ctx_t *expr_ctx, - ccs_datum_t *result) + ccs_expression_t e, + _ccs_expr_ctx_t *expr_ctx, + ccs_datum_t *result) { - ccs_datum_t node; + _ccs_expression_data_t *data = e->data; + ccs_datum_t node; + EVAL_NODE(data, expr_ctx, node, NULL); RETURN_IF_INACTIVE(node, result); CCS_REFUTE( @@ -1008,11 +1017,13 @@ static _ccs_expression_ops_t _ccs_expr_negative_ops = { static ccs_result_t _ccs_expr_not_eval( - _ccs_expression_data_t *data, - _ccs_expr_ctx_t *expr_ctx, - ccs_datum_t *result) + ccs_expression_t e, + _ccs_expr_ctx_t *expr_ctx, + ccs_datum_t *result) { - ccs_datum_t node; + _ccs_expression_data_t *data = e->data; + ccs_datum_t node; + EVAL_NODE(data, expr_ctx, node, NULL); RETURN_IF_INACTIVE(node, result); CCS_REFUTE( @@ -1029,16 +1040,16 @@ static _ccs_expression_ops_t _ccs_expr_not_ops = { static ccs_result_t _ccs_expr_list_eval( - _ccs_expression_data_t *data, - _ccs_expr_ctx_t *expr_ctx, - ccs_datum_t *result) + ccs_expression_t e, + _ccs_expr_ctx_t *expr_ctx, + ccs_datum_t *result) { - (void)data; + (void)e; (void)expr_ctx; (void)result; CCS_RAISE( CCS_RESULT_ERROR_UNSUPPORTED_OPERATION, - "Lists cannot be avaluated"); + "Lists cannot be evaluated"); } static _ccs_expression_ops_t _ccs_expr_list_ops = { @@ -1142,13 +1153,15 @@ _ccs_expression_literal_serialize( static ccs_result_t _ccs_expr_literal_eval( - _ccs_expression_data_t *data, - _ccs_expr_ctx_t *expr_ctx, - ccs_datum_t *result) + ccs_expression_t e, + _ccs_expr_ctx_t *expr_ctx, + ccs_datum_t *result) { (void)expr_ctx; + _ccs_expression_data_t *data = e->data; _ccs_expression_literal_data_t *d = (_ccs_expression_literal_data_t *)data; + *result = d->value; return CCS_RESULT_SUCCESS; } @@ -1263,15 +1276,17 @@ _ccs_expression_variable_serialize( static ccs_result_t _ccs_expr_variable_eval( - _ccs_expression_data_t *data, - _ccs_expr_ctx_t *expr_ctx, - ccs_datum_t *result) + ccs_expression_t e, + _ccs_expr_ctx_t *expr_ctx, + ccs_datum_t *result) { + _ccs_expression_data_t *data = e->data; _ccs_expression_variable_data_t *d = (_ccs_expression_variable_data_t *)data; size_t num_bindings = expr_ctx->num_bindings; ccs_binding_t *bindings = expr_ctx->bindings; ccs_bool_t found = CCS_FALSE; + CCS_REFUTE(!num_bindings, CCS_RESULT_ERROR_INVALID_OBJECT); for (size_t i = 0; i < num_bindings; i++) { CCS_VALIDATE(ccs_binding_get_value_by_parameter( @@ -1289,6 +1304,141 @@ static _ccs_expression_ops_t _ccs_expr_variable_ops = { &_ccs_expression_variable_serialize}, &_ccs_expr_variable_eval}; +static ccs_result_t +_ccs_expr_user_defined_del(ccs_object_t o) +{ + _ccs_expression_user_defined_data_t *d = + (_ccs_expression_user_defined_data_t *)((ccs_expression_t)o) + ->data; + + ccs_result_t err; + err = d->vector.del((ccs_expression_t)o); + for (size_t i = 0; i < d->expr.num_nodes; i++) + ccs_release_object(d->expr.nodes[i]); + return err; +} + +static inline ccs_result_t +_ccs_serialize_bin_size_ccs_expression_user_defined( + ccs_expression_t expression, + size_t *cum_size, + _ccs_object_serialize_options_t *opts) +{ + _ccs_expression_user_defined_data_t *data = + (_ccs_expression_user_defined_data_t *)(expression->data); + CCS_VALIDATE(_ccs_serialize_bin_size_ccs_expression_data( + &data->expr, cum_size, opts)); + *cum_size += _ccs_serialize_bin_size_string(data->name); + size_t state_size = 0; + if (data->vector.serialize_user_state) + CCS_VALIDATE(data->vector.serialize_user_state( + expression, 0, NULL, &state_size)); + *cum_size += _ccs_serialize_bin_size_size(state_size); + *cum_size += state_size; + return CCS_RESULT_SUCCESS; +} + +static inline ccs_result_t +_ccs_serialize_bin_ccs_expression_user_defined( + ccs_expression_t expression, + size_t *buffer_size, + char **buffer, + _ccs_object_serialize_options_t *opts) +{ + _ccs_expression_user_defined_data_t *data = + (_ccs_expression_user_defined_data_t *)(expression->data); + CCS_VALIDATE(_ccs_serialize_bin_ccs_expression_data( + &data->expr, buffer_size, buffer, opts)); + CCS_VALIDATE( + _ccs_serialize_bin_string(data->name, buffer_size, buffer)); + size_t state_size = 0; + if (data->vector.serialize_user_state) + CCS_VALIDATE(data->vector.serialize_user_state( + expression, 0, NULL, &state_size)); + CCS_VALIDATE(_ccs_serialize_bin_size(state_size, buffer_size, buffer)); + if (state_size) { + CCS_REFUTE( + *buffer_size < state_size, + CCS_RESULT_ERROR_NOT_ENOUGH_DATA); + CCS_VALIDATE(data->vector.serialize_user_state( + expression, state_size, *buffer, NULL)); + *buffer_size -= state_size; + *buffer += state_size; + } + return CCS_RESULT_SUCCESS; +} + +static ccs_result_t +_ccs_expression_user_defined_serialize_size( + ccs_object_t object, + ccs_serialize_format_t format, + size_t *cum_size, + _ccs_object_serialize_options_t *opts) +{ + switch (format) { + case CCS_SERIALIZE_FORMAT_BINARY: + CCS_VALIDATE( + _ccs_serialize_bin_size_ccs_expression_user_defined( + (ccs_expression_t)object, cum_size, opts)); + break; + default: + CCS_RAISE( + CCS_RESULT_ERROR_INVALID_VALUE, + "Unsupported serialization format: %d", format); + } + return CCS_RESULT_SUCCESS; +} + +static ccs_result_t +_ccs_expression_user_defined_serialize( + ccs_object_t object, + ccs_serialize_format_t format, + size_t *buffer_size, + char **buffer, + _ccs_object_serialize_options_t *opts) +{ + switch (format) { + case CCS_SERIALIZE_FORMAT_BINARY: + CCS_VALIDATE(_ccs_serialize_bin_ccs_expression_user_defined( + (ccs_expression_t)object, buffer_size, buffer, opts)); + break; + default: + CCS_RAISE( + CCS_RESULT_ERROR_INVALID_VALUE, + "Unsupported serialization format: %d", format); + } + return CCS_RESULT_SUCCESS; +} + +static ccs_result_t +_ccs_expr_user_defined_eval( + ccs_expression_t e, + _ccs_expr_ctx_t *expr_ctx, + ccs_datum_t *result) +{ + _ccs_expression_data_t *data = e->data; + _ccs_expression_user_defined_data_t *d = + (_ccs_expression_user_defined_data_t *)data; + ccs_datum_t *values = NULL; + + if (data->num_nodes) + values = (ccs_datum_t *)alloca( + sizeof(ccs_datum_t) * data->num_nodes); + for (size_t i = 0; i < data->num_nodes; i++) { + CCS_VALIDATE(_ccs_expr_node_eval( + data->nodes[i], expr_ctx, values + i, NULL)); + RETURN_IF_INACTIVE(values[i], result); + } + CCS_VALIDATE(d->vector.eval(e, data->num_nodes, values, result)); + return CCS_RESULT_SUCCESS; +} + +static _ccs_expression_ops_t _ccs_expr_user_defined_ops = { + {&_ccs_expr_user_defined_del, + &_ccs_expression_user_defined_serialize_size, + &_ccs_expression_user_defined_serialize}, + &_ccs_expr_user_defined_eval}; + static inline _ccs_expression_ops_t * _ccs_expression_ops_broker(ccs_expression_type_t expression_type) { @@ -1353,6 +1503,9 @@ _ccs_expression_ops_broker(ccs_expression_type_t expression_type) case CCS_EXPRESSION_TYPE_VARIABLE: return &_ccs_expr_variable_ops; break; + case CCS_EXPRESSION_TYPE_USER_DEFINED: + return &_ccs_expr_user_defined_ops; + break; default: return NULL; } @@ -1432,6 +1585,54 @@ ccs_create_variable(ccs_parameter_t parameter, ccs_expression_t *expression_ret) return err; } +static inline ccs_result_t +_ccs_validate_nodes(size_t num_nodes, ccs_datum_t *nodes) +{ + for (size_t i = 0; i < num_nodes; i++) { + if (nodes[i].type == CCS_DATA_TYPE_OBJECT) { + ccs_object_type_t object_type = + CCS_OBJ_TYPE(nodes[i].value.o); + CCS_REFUTE( + object_type != CCS_OBJECT_TYPE_PARAMETER && + object_type != + CCS_OBJECT_TYPE_EXPRESSION, + CCS_RESULT_ERROR_INVALID_VALUE); + } else + CCS_REFUTE( + nodes[i].type < CCS_DATA_TYPE_NONE || + nodes[i].type > CCS_DATA_TYPE_STRING, + CCS_RESULT_ERROR_INVALID_VALUE); + } + return CCS_RESULT_SUCCESS; +} + +static inline ccs_result_t +_ccs_create_nodes( + size_t num_nodes, + ccs_datum_t *nodes, + ccs_expression_t *nodes_ret) +{ + for (size_t i = 0; i < num_nodes; i++) { + if (nodes[i].type == CCS_DATA_TYPE_OBJECT) { + ccs_object_type_t t = CCS_OBJ_TYPE(nodes[i].value.o); + if (t == CCS_OBJECT_TYPE_EXPRESSION) { + CCS_VALIDATE( + ccs_retain_object(nodes[i].value.o)); + nodes_ret[i] = + (ccs_expression_t)nodes[i].value.o; + } else { + CCS_VALIDATE(ccs_create_variable( + (ccs_parameter_t)nodes[i].value.o, + nodes_ret + i)); + } + } else { + CCS_VALIDATE( + ccs_create_literal(nodes[i], nodes_ret + i)); + } + } + return CCS_RESULT_SUCCESS; +} + ccs_result_t ccs_create_expression( ccs_expression_type_t type, @@ -1449,27 +1650,13 @@ ccs_create_expression( CCS_REFUTE( arity >= 0 && num_nodes != (size_t)arity, CCS_RESULT_ERROR_INVALID_VALUE); - ccs_result_t err; - for (size_t i = 0; i < num_nodes; i++) { - if (nodes[i].type == CCS_DATA_TYPE_OBJECT) { - ccs_object_type_t object_type = - CCS_OBJ_TYPE(nodes[i].value.o); - CCS_REFUTE( - object_type != CCS_OBJECT_TYPE_PARAMETER && - object_type != - CCS_OBJECT_TYPE_EXPRESSION, - CCS_RESULT_ERROR_INVALID_VALUE); - } else - CCS_REFUTE( - nodes[i].type < CCS_DATA_TYPE_NONE || - nodes[i].type > CCS_DATA_TYPE_STRING, - CCS_RESULT_ERROR_INVALID_VALUE); - } + CCS_VALIDATE(_ccs_validate_nodes(num_nodes, nodes)); - uintptr_t mem = (uintptr_t)calloc( - 1, sizeof(struct _ccs_expression_s) + - sizeof(struct _ccs_expression_data_s) + - num_nodes * sizeof(ccs_expression_t)); + ccs_result_t err; + uintptr_t mem = (uintptr_t)calloc( + 1, sizeof(struct _ccs_expression_s) + + sizeof(struct _ccs_expression_data_s) + + num_nodes * sizeof(ccs_expression_t)); CCS_REFUTE(!mem, CCS_RESULT_ERROR_OUT_OF_MEMORY); ccs_expression_t expression = (ccs_expression_t)mem; @@ -1484,33 +1671,10 @@ ccs_create_expression( expression_data->nodes = (ccs_expression_t *)(mem + sizeof(struct _ccs_expression_s) + sizeof(struct _ccs_expression_data_s)); - for (size_t i = 0; i < num_nodes; i++) { - if (nodes[i].type == CCS_DATA_TYPE_OBJECT) { - ccs_object_type_t t = CCS_OBJ_TYPE(nodes[i].value.o); - if (t == CCS_OBJECT_TYPE_EXPRESSION) { - CCS_VALIDATE_ERR_GOTO( - err, - ccs_retain_object(nodes[i].value.o), - cleanup); - expression_data->nodes[i] = - (ccs_expression_t)nodes[i].value.o; - } else { - CCS_VALIDATE_ERR_GOTO( - err, - ccs_create_variable( - (ccs_parameter_t)nodes[i] - .value.o, - expression_data->nodes + i), - cleanup); - } - } else { - CCS_VALIDATE_ERR_GOTO( - err, - ccs_create_literal( - nodes[i], expression_data->nodes + i), - cleanup); - } - } + CCS_VALIDATE_ERR_GOTO( + err, + _ccs_create_nodes(num_nodes, nodes, expression_data->nodes), + cleanup); expression->data = expression_data; *expression_ret = expression; return CCS_RESULT_SUCCESS; @@ -1548,6 +1712,67 @@ ccs_create_unary_expression( return CCS_RESULT_SUCCESS; } +ccs_result_t +ccs_create_user_defined_expression( + const char *name, + size_t num_nodes, + ccs_datum_t *nodes, + ccs_user_defined_expression_vector_t *vector, + void *expr_data, + ccs_expression_t *expression_ret) +{ + CCS_CHECK_ARY(num_nodes, nodes); + CCS_CHECK_PTR(expression_ret); + CCS_CHECK_PTR(vector); + CCS_CHECK_PTR(vector->del); + CCS_CHECK_PTR(vector->eval); + CCS_VALIDATE(_ccs_validate_nodes(num_nodes, nodes)); + + ccs_result_t err; + uintptr_t mem = (uintptr_t)calloc( + 1, sizeof(struct _ccs_expression_s) + + sizeof(struct _ccs_expression_user_defined_data_s) + + num_nodes * sizeof(ccs_expression_t) + strlen(name) + + 1); + CCS_REFUTE(!mem, CCS_RESULT_ERROR_OUT_OF_MEMORY); + + ccs_expression_t expression; + expression = (ccs_expression_t)mem; + _ccs_object_init( + &(expression->obj), CCS_OBJECT_TYPE_EXPRESSION, + (_ccs_object_ops_t *)_ccs_expression_ops_broker( + CCS_EXPRESSION_TYPE_USER_DEFINED)); + _ccs_expression_user_defined_data_t *expression_data; + expression_data = (_ccs_expression_user_defined_data_t + *)(mem + sizeof(struct _ccs_expression_s)); + expression_data->expr.type = CCS_EXPRESSION_TYPE_USER_DEFINED; + expression_data->expr.num_nodes = num_nodes; + expression_data->expr.nodes = + (ccs_expression_t + *)(mem + sizeof(struct _ccs_expression_s) + sizeof(struct _ccs_expression_user_defined_data_s)); + expression_data->name = + (const char + *)(mem + sizeof(struct _ccs_expression_s) + sizeof(struct _ccs_expression_user_defined_data_s) + sizeof(ccs_expression_t) * num_nodes); + expression_data->vector = *vector; + expression_data->expression_data = expr_data; + strcpy((char *)expression_data->name, name); + CCS_VALIDATE_ERR_GOTO( + err, + _ccs_create_nodes(num_nodes, nodes, expression_data->expr.nodes), + cleanup); + expression->data = (_ccs_expression_data_t *)expression_data; + *expression_ret = expression; + return CCS_RESULT_SUCCESS; +cleanup: + for (size_t i = 0; i < num_nodes; i++) { + if (expression_data->expr.nodes[i]) + ccs_release_object(expression_data->expr.nodes[i]); + } + _ccs_object_deinit(&(expression->obj)); + free((void *)mem); + return err; +} + ccs_result_t ccs_expression_eval( ccs_expression_t expression, @@ -1666,6 +1891,38 @@ ccs_variable_get_parameter( return CCS_RESULT_SUCCESS; } +ccs_result_t +ccs_user_defined_expression_get_expression_data( + ccs_expression_t expression, + void **expression_data_ret) +{ + CCS_CHECK_OBJ(expression, CCS_OBJECT_TYPE_EXPRESSION); + CCS_CHECK_PTR(expression_data_ret); + _ccs_expression_user_defined_data_t *d = + (_ccs_expression_user_defined_data_t *)expression->data; + CCS_REFUTE( + d->expr.type != CCS_EXPRESSION_TYPE_USER_DEFINED, + CCS_RESULT_ERROR_INVALID_EXPRESSION); + *expression_data_ret = d->expression_data; + return CCS_RESULT_SUCCESS; +} + +ccs_result_t +ccs_user_defined_expression_get_name( + ccs_expression_t expression, + const char **name_ret) +{ + CCS_CHECK_OBJ(expression, CCS_OBJECT_TYPE_EXPRESSION); + CCS_CHECK_PTR(name_ret); + _ccs_expression_user_defined_data_t *d = + (_ccs_expression_user_defined_data_t *)expression->data; + CCS_REFUTE( + d->expr.type != CCS_EXPRESSION_TYPE_USER_DEFINED, + CCS_RESULT_ERROR_INVALID_EXPRESSION); + *name_ret = d->name; + return CCS_RESULT_SUCCESS; +} + #undef utarray_oom #define utarray_oom() \ { \ diff --git a/src/expression_deserialize.h b/src/expression_deserialize.h index 8ac6e7b3..27c5cb5e 100644 --- a/src/expression_deserialize.h +++ b/src/expression_deserialize.h @@ -151,6 +151,79 @@ _ccs_deserialize_bin_expression_general( return res; } +struct _ccs_expression_user_defined_data_mock_s { + _ccs_expression_data_mock_t expr; + const char *name; + _ccs_blob_t blob; +}; +typedef struct _ccs_expression_user_defined_data_mock_s + _ccs_expression_user_defined_data_mock_t; + +static inline ccs_result_t +_ccs_deserialize_bin_ccs_expression_user_defined_data( + _ccs_expression_user_defined_data_mock_t *data, + uint32_t version, + size_t *buffer_size, + const char **buffer, + _ccs_object_deserialize_options_t *opts) +{ + CCS_VALIDATE(_ccs_deserialize_bin_ccs_expression_data( + &data->expr, version, buffer_size, buffer, opts)); + CCS_VALIDATE( + _ccs_deserialize_bin_string(&data->name, buffer_size, buffer)); + CCS_VALIDATE(_ccs_deserialize_bin_ccs_blob( + &data->blob, buffer_size, buffer)); + return CCS_RESULT_SUCCESS; +} + +static inline ccs_result_t +_ccs_deserialize_bin_expression_user_defined( + ccs_expression_t *expression_ret, + uint32_t version, + size_t *buffer_size, + const char **buffer, + _ccs_object_deserialize_options_t *opts) +{ + _ccs_expression_user_defined_data_mock_t data; + ccs_user_defined_expression_vector_t *vector = NULL; + void *expression_data = NULL; + ccs_result_t res = CCS_RESULT_SUCCESS; + + CCS_VALIDATE(_ccs_deserialize_bin_ccs_expression_user_defined_data( + &data, version, buffer_size, buffer, opts)); + + CCS_VALIDATE_ERR_GOTO( + res, + opts->deserialize_vector_callback( + CCS_OBJECT_TYPE_EXPRESSION, data.name, + opts->deserialize_vector_user_data, (void **)&vector, + &expression_data), + end); + + if (vector->deserialize_state) + CCS_VALIDATE_ERR_GOTO( + res, + vector->deserialize_state( + data.blob.sz, data.blob.blob, &expression_data), + end); + + CCS_VALIDATE_ERR_GOTO( + res, + ccs_create_user_defined_expression( + data.name, data.expr.num_nodes, data.expr.nodes, vector, + expression_data, expression_ret), + end); + +end: + if (data.expr.nodes) { + for (size_t i = 0; i < data.expr.num_nodes; i++) + if (data.expr.nodes[i].type == CCS_DATA_TYPE_OBJECT) + ccs_release_object(data.expr.nodes[i].value.o); + free(data.expr.nodes); + } + return res; +} + static inline ccs_result_t _ccs_deserialize_bin_expression( ccs_expression_t *expression_ret, @@ -175,6 +248,11 @@ _ccs_deserialize_bin_expression( expression_ret, version, buffer_size, buffer, &new_opts)); break; + case CCS_EXPRESSION_TYPE_USER_DEFINED: + CCS_VALIDATE(_ccs_deserialize_bin_expression_user_defined( + expression_ret, version, buffer_size, buffer, + &new_opts)); + break; default: CCS_REFUTE( dtype < CCS_EXPRESSION_TYPE_OR || diff --git a/src/expression_internal.h b/src/expression_internal.h index 89a905bd..640d6e76 100644 --- a/src/expression_internal.h +++ b/src/expression_internal.h @@ -15,9 +15,9 @@ struct _ccs_expression_ops_s { _ccs_object_ops_t obj_ops; ccs_result_t (*eval)( - _ccs_expression_data_t *data, - _ccs_expr_ctx_t *expr_ctx, - ccs_datum_t *result); + ccs_expression_t expression, + _ccs_expr_ctx_t *expr_ctx, + ccs_datum_t *result); }; typedef struct _ccs_expression_ops_s _ccs_expression_ops_t; @@ -44,4 +44,13 @@ struct _ccs_expression_variable_data_s { }; typedef struct _ccs_expression_variable_data_s _ccs_expression_variable_data_t; +struct _ccs_expression_user_defined_data_s { + _ccs_expression_data_t expr; + ccs_user_defined_expression_vector_t vector; + void *expression_data; + const char *name; +}; +typedef struct _ccs_expression_user_defined_data_s + _ccs_expression_user_defined_data_t; + #endif //_EXPRESSION_INTERNAL_H diff --git a/src/tree_space_deserialize.h b/src/tree_space_deserialize.h index 94cb977a..9a2d8aa3 100644 --- a/src/tree_space_deserialize.h +++ b/src/tree_space_deserialize.h @@ -88,32 +88,42 @@ _ccs_deserialize_bin_tree_space_dynamic( _ccs_object_deserialize_options_t *opts, _ccs_tree_space_common_data_mock_t *data) { - _ccs_blob_t blob = {0, NULL}; - ccs_dynamic_tree_space_vector_t *vector = - (ccs_dynamic_tree_space_vector_t *)opts->vector; - ccs_result_t res = CCS_RESULT_SUCCESS; + _ccs_blob_t blob = {0, NULL}; + ccs_dynamic_tree_space_vector_t *vector = NULL; + void *tree_space_data = NULL; + ccs_result_t res = CCS_RESULT_SUCCESS; + CCS_VALIDATE_ERR_GOTO( res, _ccs_deserialize_bin_ccs_tree_space_common_data( data, version, buffer_size, buffer, opts), end); - CCS_VALIDATE(_ccs_deserialize_bin_ccs_blob(&blob, buffer_size, buffer)); + CCS_VALIDATE_ERR_GOTO( + res, _ccs_deserialize_bin_ccs_blob(&blob, buffer_size, buffer), + end); + CCS_VALIDATE_ERR_GOTO( res, - ccs_create_dynamic_tree_space( - data->name, data->tree, data->feature_space, data->rng, - vector, opts->data, tree_space_ret), + opts->deserialize_vector_callback( + CCS_OBJECT_TYPE_TREE_SPACE, data->name, + opts->deserialize_vector_user_data, (void **)&vector, + &tree_space_data), end); + if (vector->deserialize_state) CCS_VALIDATE_ERR_GOTO( res, vector->deserialize_state( - *tree_space_ret, blob.sz, blob.blob), - tree_space); - goto end; -tree_space: - ccs_release_object(*tree_space_ret); - *tree_space_ret = NULL; + data->tree, data->feature_space, blob.sz, + blob.blob, &tree_space_data), + end); + + CCS_VALIDATE_ERR_GOTO( + res, + ccs_create_dynamic_tree_space( + data->name, data->tree, data->feature_space, data->rng, + vector, tree_space_data, tree_space_ret), + end); end: if (data->feature_space) ccs_release_object(data->feature_space); @@ -137,6 +147,8 @@ _ccs_deserialize_bin_tree_space( ccs_tree_space_type_t stype; CCS_VALIDATE( _ccs_peek_bin_ccs_tree_space_type(&stype, buffer_size, buffer)); + if (stype == CCS_TREE_SPACE_TYPE_DYNAMIC) + CCS_CHECK_PTR(opts->deserialize_vector_callback); _ccs_tree_space_common_data_mock_t data = { CCS_TREE_SPACE_TYPE_STATIC, NULL, NULL, NULL, NULL, NULL}; diff --git a/src/tuner_deserialize.h b/src/tuner_deserialize.h index 220aeb9e..9dec92e2 100644 --- a/src/tuner_deserialize.h +++ b/src/tuner_deserialize.h @@ -179,32 +179,44 @@ _ccs_deserialize_bin_user_defined_tuner( NULL, NULL}, {0, NULL}}; - ccs_user_defined_tuner_vector_t *vector = - (ccs_user_defined_tuner_vector_t *)opts->vector; - ccs_result_t res = CCS_RESULT_SUCCESS; + ccs_user_defined_tuner_vector_t *vector = NULL; + void *tuner_data = NULL; + ccs_result_t res = CCS_RESULT_SUCCESS; + CCS_VALIDATE_ERR_GOTO( res, _ccs_deserialize_bin_ccs_user_defined_tuner_data( &data, version, buffer_size, buffer, opts), end); + CCS_VALIDATE_ERR_GOTO( res, - ccs_create_user_defined_tuner( - data.base_data.common_data.name, - data.base_data.common_data.objective_space, vector, - opts->data, tuner_ret), + opts->deserialize_vector_callback( + CCS_OBJECT_TYPE_TUNER, data.base_data.common_data.name, + opts->deserialize_vector_user_data, (void **)&vector, + &tuner_data), end); + if (vector->deserialize_state) CCS_VALIDATE_ERR_GOTO( res, vector->deserialize_state( - *tuner_ret, data.base_data.size_history, + data.base_data.common_data.objective_space, + data.base_data.size_history, data.base_data.history, data.base_data.size_optima, data.base_data.optima, data.blob.sz, - data.blob.blob), - tuner); - else + data.blob.blob, &tuner_data), + end); + + CCS_VALIDATE_ERR_GOTO( + res, + ccs_create_user_defined_tuner( + data.base_data.common_data.name, + data.base_data.common_data.objective_space, vector, + tuner_data, tuner_ret), + end); + if (!vector->deserialize_state) CCS_VALIDATE_ERR_GOTO( res, vector->tell( @@ -241,7 +253,7 @@ _ccs_deserialize_bin_tuner( ccs_tuner_type_t ttype; CCS_VALIDATE(_ccs_peek_bin_ccs_tuner_type(&ttype, buffer_size, buffer)); if (ttype == CCS_TUNER_TYPE_USER_DEFINED) - CCS_CHECK_PTR(opts->vector); + CCS_CHECK_PTR(opts->deserialize_vector_callback); new_opts.map_values = CCS_TRUE; CCS_VALIDATE(ccs_create_map(&new_opts.handle_map)); diff --git a/tests/test_categorical_parameter.c b/tests/test_categorical_parameter.c index 2104fd21..8f11658c 100644 --- a/tests/test_categorical_parameter.c +++ b/tests/test_categorical_parameter.c @@ -179,7 +179,7 @@ test_create(void) err = ccs_object_deserialize( (ccs_object_t *)¶meter, CCS_SERIALIZE_FORMAT_BINARY, CCS_SERIALIZE_OPERATION_MEMORY, buff_size, buff, - CCS_DESERIALIZE_OPTION_CALLBACK, &deserialize_callback, + CCS_DESERIALIZE_OPTION_DATA_CALLBACK, &deserialize_callback, (void *)0xbeefdead, CCS_DESERIALIZE_OPTION_END); assert(err == CCS_RESULT_SUCCESS); free(buff); diff --git a/tests/test_dynamic_tree_space.c b/tests/test_dynamic_tree_space.c index adf0cb01..da794558 100644 --- a/tests/test_dynamic_tree_space.c +++ b/tests/test_dynamic_tree_space.c @@ -35,24 +35,46 @@ my_tree_get_child( return CCS_RESULT_SUCCESS; } +ccs_dynamic_tree_space_vector_t vector = { + &my_tree_del, &my_tree_get_child, NULL, NULL}; + +ccs_result_t +deserialize_vector_callback( + ccs_object_type_t type, + const char *name, + void *callback_user_data, + void **vector_ret, + void **data_ret) +{ + (void)name; + (void)callback_user_data; + switch (type) { + case CCS_OBJECT_TYPE_TREE_SPACE: + *vector_ret = (void *)&vector; + *data_ret = NULL; + break; + default: + return CCS_RESULT_ERROR_INVALID_TYPE; + } + return CCS_RESULT_SUCCESS; +} + void test_dynamic_tree_space(void) { - ccs_result_t err; - ccs_bool_t is_valid; - ccs_tree_t root, tree; - ccs_tree_space_t tree_space; - ccs_tree_space_type_t tree_type; - ccs_rng_t rng, rng2; - size_t position_size, *position, depths[5]; - ccs_datum_t value, *values; - ccs_float_t areas[5] = {1.0, 4.0, 2.0, 1.0, 0.0}; - ccs_float_t inv_sum; - const char *name; - ccs_tree_configuration_t config, configs[NUM_SAMPLES]; - - ccs_dynamic_tree_space_vector_t vector = { - &my_tree_del, &my_tree_get_child, NULL, NULL}; + ccs_result_t err; + ccs_bool_t is_valid; + ccs_tree_t root, tree; + ccs_tree_space_t tree_space; + ccs_tree_space_type_t tree_type; + ccs_rng_t rng, rng2; + size_t position_size, *position, depths[5]; + ccs_datum_t value, *values; + ccs_float_t areas[5] = {1.0, 4.0, 2.0, 1.0, 0.0}; + ccs_float_t inv_sum; + const char *name; + ccs_tree_configuration_t config, configs[NUM_SAMPLES]; + err = ccs_create_tree(4, ccs_int(4 * 100), &root); assert(err == CCS_RESULT_SUCCESS); err = ccs_create_rng(&rng); @@ -176,7 +198,8 @@ test_dynamic_tree_space(void) err = ccs_object_deserialize( (ccs_object_t *)&tree_space, CCS_SERIALIZE_FORMAT_BINARY, CCS_SERIALIZE_OPERATION_MEMORY, buff_size, buff, - CCS_DESERIALIZE_OPTION_VECTOR, &vector, + CCS_DESERIALIZE_OPTION_VECTOR_CALLBACK, + &deserialize_vector_callback, (void *)NULL, CCS_DESERIALIZE_OPTION_END); assert(err == CCS_RESULT_SUCCESS); free(buff); diff --git a/tests/test_expression.c b/tests/test_expression.c index e5dca951..9e420f04 100644 --- a/tests/test_expression.c +++ b/tests/test_expression.c @@ -3,6 +3,7 @@ #include #include #include +#include double d = -2.0; @@ -1237,6 +1238,161 @@ test_deserialize(void) free(buff); } +static ccs_result_t +my_del(ccs_expression_t expression) +{ + ccs_result_t res; + void *data; + res = ccs_user_defined_expression_get_expression_data( + expression, &data); + assert(res == CCS_RESULT_SUCCESS); + gsl_rng_free((gsl_rng *)data); + return CCS_RESULT_SUCCESS; +} + +static ccs_result_t +my_eval(ccs_expression_t expression, + size_t num_values, + ccs_datum_t *values, + ccs_datum_t *value_ret) +{ + ccs_result_t res; + void *data; + assert(num_values == 1); + assert(values); + assert(value_ret); + res = ccs_user_defined_expression_get_expression_data( + expression, &data); + assert(res == CCS_RESULT_SUCCESS); + *value_ret = ccs_int( + gsl_rng_uniform_int((gsl_rng *)data, values[0].value.i)); + return CCS_RESULT_SUCCESS; +} + +static ccs_result_t +my_serialize_user_state( + ccs_expression_t expression, + size_t sate_size, + void *state, + size_t *state_size_ret) +{ + ccs_result_t res; + void *data; + assert(state || state_size_ret); + res = ccs_user_defined_expression_get_expression_data( + expression, &data); + assert(res == CCS_RESULT_SUCCESS); + size_t sz = gsl_rng_size((gsl_rng *)data); + void *pstate = gsl_rng_state((gsl_rng *)data); + if (state_size_ret) + *state_size_ret = sz; + if (state) { + assert(sate_size >= sz); + memcpy(state, pstate, sz); + } + return CCS_RESULT_SUCCESS; +} + +static ccs_result_t +my_deserialize_state( + size_t state_size, + const void *state, + void **expression_data_ret) +{ + assert(state); + assert(expression_data_ret); + gsl_rng *grng = gsl_rng_alloc(gsl_rng_mt19937); + assert(grng); + memcpy(gsl_rng_state(grng), state, state_size); + *expression_data_ret = (void *)grng; + return CCS_RESULT_SUCCESS; +} + +static ccs_user_defined_expression_vector_t my_vector = { + &my_del, &my_eval, &my_serialize_user_state, &my_deserialize_state}; + +ccs_result_t +deserialize_vector_callback( + ccs_object_type_t type, + const char *name, + void *callback_user_data, + void **vector_ret, + void **data_ret) +{ + assert(callback_user_data == NULL); + assert(type == CCS_OBJECT_TYPE_EXPRESSION); + assert(!strcmp(name, "my_rand")); + assert(data_ret); + assert(vector_ret); + *vector_ret = (void *)&my_vector; + *data_ret = NULL; + return CCS_RESULT_SUCCESS; +} + +void +test_user_defined(void) +{ + ccs_datum_t limit = ccs_int(10); + ccs_datum_t result, result_copy; + ccs_result_t err; + ccs_expression_t expression, expression_copy; + gsl_rng *grng; + char *buff; + size_t buff_size; + + grng = gsl_rng_alloc(gsl_rng_mt19937); + assert(grng); + + err = ccs_create_user_defined_expression( + "my_rand", 1, &limit, &my_vector, (void *)grng, &expression); + assert(err == CCS_RESULT_SUCCESS); + for (size_t i = 0; i < 100; i++) { + err = ccs_expression_eval(expression, 0, NULL, &result); + assert(err == CCS_RESULT_SUCCESS); + assert(result.type == CCS_DATA_TYPE_INT); + assert(result.value.i >= 0); + assert(result.value.i < limit.value.i); + } + + err = ccs_object_serialize( + expression, CCS_SERIALIZE_FORMAT_BINARY, + CCS_SERIALIZE_OPERATION_SIZE, &buff_size, + CCS_SERIALIZE_OPTION_END); + assert(err == CCS_RESULT_SUCCESS); + buff = (char *)malloc(buff_size); + assert(buff); + + err = ccs_object_serialize( + expression, CCS_SERIALIZE_FORMAT_BINARY, + CCS_SERIALIZE_OPERATION_MEMORY, buff_size, buff, + CCS_SERIALIZE_OPTION_END); + assert(err == CCS_RESULT_SUCCESS); + + err = ccs_object_deserialize( + (ccs_object_t *)&expression_copy, CCS_SERIALIZE_FORMAT_BINARY, + CCS_SERIALIZE_OPERATION_MEMORY, buff_size, buff, + CCS_DESERIALIZE_OPTION_VECTOR_CALLBACK, + &deserialize_vector_callback, (void *)NULL, + CCS_DESERIALIZE_OPTION_END); + assert(err == CCS_RESULT_SUCCESS); + + for (size_t i = 0; i < 100; i++) { + err = ccs_expression_eval(expression, 0, NULL, &result); + assert(err == CCS_RESULT_SUCCESS); + err = ccs_expression_eval( + expression_copy, 0, NULL, &result_copy); + assert(err == CCS_RESULT_SUCCESS); + assert(result.type == result_copy.type); + assert(result.value.i == result_copy.value.i); + } + + free(buff); + err = ccs_release_object(expression_copy); + assert(err == CCS_RESULT_SUCCESS); + err = ccs_release_object(expression); + assert(err == CCS_RESULT_SUCCESS); +} + int main(void) { @@ -1259,6 +1415,7 @@ main(void) test_arithmetic_greater(); test_arithmetic_less_or_equal(); test_arithmetic_greater_or_equal(); + test_user_defined(); test_compound(); test_in(); test_get_parameters(); diff --git a/tests/test_user_defined_features_tuner.c b/tests/test_user_defined_features_tuner.c index d319c55b..4f3a90a5 100644 --- a/tests/test_user_defined_features_tuner.c +++ b/tests/test_user_defined_features_tuner.c @@ -147,6 +147,28 @@ ccs_user_defined_tuner_vector_t tuner_last_vector = { NULL, NULL}; +ccs_result_t +deserialize_vector_callback( + ccs_object_type_t type, + const char *name, + void *callback_user_data, + void **vector_ret, + void **data_ret) +{ + (void)name; + (void)callback_user_data; + switch (type) { + case CCS_OBJECT_TYPE_TUNER: + *vector_ret = (void *)&tuner_last_vector; + *data_ret = calloc(1, sizeof(tuner_last_t)); + assert(*data_ret); + break; + default: + return CCS_RESULT_ERROR_INVALID_TYPE; + } + return CCS_RESULT_SUCCESS; +} + void test(void) { @@ -263,15 +285,12 @@ test(void) CCS_SERIALIZE_OPTION_END); assert(err == CCS_RESULT_SUCCESS); - tuner_data = (tuner_last_t *)calloc(1, sizeof(tuner_last_t)); - assert(tuner_data); - err = ccs_object_deserialize( (ccs_object_t *)&tuner_copy, CCS_SERIALIZE_FORMAT_BINARY, CCS_SERIALIZE_OPERATION_MEMORY, buff_size, buff, CCS_DESERIALIZE_OPTION_HANDLE_MAP, map, - CCS_DESERIALIZE_OPTION_VECTOR, &tuner_last_vector, - CCS_DESERIALIZE_OPTION_DATA, tuner_data, + CCS_DESERIALIZE_OPTION_VECTOR_CALLBACK, + &deserialize_vector_callback, (void *)NULL, CCS_DESERIALIZE_OPTION_END); assert(err == CCS_RESULT_SUCCESS); diff --git a/tests/test_user_defined_tree_tuner.c b/tests/test_user_defined_tree_tuner.c index 8bfa2977..258b34b7 100644 --- a/tests/test_user_defined_tree_tuner.c +++ b/tests/test_user_defined_tree_tuner.c @@ -193,6 +193,28 @@ ccs_user_defined_tuner_vector_t tuner_last_vector = { NULL, NULL}; +ccs_result_t +deserialize_vector_callback( + ccs_object_type_t type, + const char *name, + void *callback_user_data, + void **vector_ret, + void **data_ret) +{ + (void)name; + (void)callback_user_data; + switch (type) { + case CCS_OBJECT_TYPE_TUNER: + *vector_ret = (void *)&tuner_last_vector; + *data_ret = calloc(1, sizeof(tuner_last_t)); + assert(*data_ret); + break; + default: + return CCS_RESULT_ERROR_INVALID_TYPE; + } + return CCS_RESULT_SUCCESS; +} + void test(void) { @@ -274,15 +296,12 @@ test(void) CCS_SERIALIZE_OPTION_END); assert(err == CCS_RESULT_SUCCESS); - tuner_data = (tuner_last_t *)calloc(1, sizeof(tuner_last_t)); - assert(tuner_data); - err = ccs_object_deserialize( (ccs_object_t *)&tuner_copy, CCS_SERIALIZE_FORMAT_BINARY, CCS_SERIALIZE_OPERATION_MEMORY, buff_size, buff, CCS_DESERIALIZE_OPTION_HANDLE_MAP, map, - CCS_DESERIALIZE_OPTION_VECTOR, &tuner_last_vector, - CCS_DESERIALIZE_OPTION_DATA, tuner_data, + CCS_DESERIALIZE_OPTION_VECTOR_CALLBACK, + &deserialize_vector_callback, (void *)NULL, CCS_DESERIALIZE_OPTION_END); assert(err == CCS_RESULT_SUCCESS); diff --git a/tests/test_user_defined_tuner.c b/tests/test_user_defined_tuner.c index e56e9d1c..ad857e4b 100644 --- a/tests/test_user_defined_tuner.c +++ b/tests/test_user_defined_tuner.c @@ -140,6 +140,28 @@ ccs_user_defined_tuner_vector_t tuner_last_vector = { NULL, NULL}; +ccs_result_t +deserialize_vector_callback( + ccs_object_type_t type, + const char *name, + void *callback_user_data, + void **vector_ret, + void **data_ret) +{ + (void)name; + (void)callback_user_data; + switch (type) { + case CCS_OBJECT_TYPE_TUNER: + *vector_ret = (void *)&tuner_last_vector; + *data_ret = calloc(1, sizeof(tuner_last_t)); + assert(*data_ret); + break; + default: + return CCS_RESULT_ERROR_INVALID_TYPE; + } + return CCS_RESULT_SUCCESS; +} + void test(void) { @@ -217,15 +239,12 @@ test(void) CCS_SERIALIZE_OPTION_END); assert(err == CCS_RESULT_SUCCESS); - tuner_data = (tuner_last_t *)calloc(1, sizeof(tuner_last_t)); - assert(tuner_data); - err = ccs_object_deserialize( (ccs_object_t *)&tuner_copy, CCS_SERIALIZE_FORMAT_BINARY, CCS_SERIALIZE_OPERATION_MEMORY, buff_size, buff, CCS_DESERIALIZE_OPTION_HANDLE_MAP, map, - CCS_DESERIALIZE_OPTION_VECTOR, &tuner_last_vector, - CCS_DESERIALIZE_OPTION_DATA, tuner_data, + CCS_DESERIALIZE_OPTION_VECTOR_CALLBACK, + &deserialize_vector_callback, (void *)NULL, CCS_DESERIALIZE_OPTION_END); assert(err == CCS_RESULT_SUCCESS);