Skip to content

Commit

Permalink
WIP: User defined expression.
Browse files Browse the repository at this point in the history
  • Loading branch information
Kerilk committed Aug 26, 2024
1 parent 04e867e commit 3de3b8b
Show file tree
Hide file tree
Showing 15 changed files with 1,205 additions and 269 deletions.
164 changes: 163 additions & 1 deletion bindings/python/cconfigspace/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ class ExpressionType(CEnumeration):
'IN',
'LIST',
'LITERAL',
'VARIABLE' ]
'VARIABLE',
'USER_DEFINED' ]

class AssociativityType(CEnumeration):
_members_ = [
Expand Down Expand Up @@ -476,6 +477,166 @@ def __str__(self):

Expression.List = ExpressionList

ccs_user_defined_expression_del_type = ct.CFUNCTYPE(Result, ccs_expression)
ccs_user_defined_expression_eval_type = ct.CFUNCTYPE(Result, ccs_expression, ct.c_size_t, ct.POINTER(Datum), ct.POINTER(Datum))
ccs_user_defined_expression_serialize_type = ct.CFUNCTYPE(Result, ccs_expression, ct.c_size_t, ct.c_void_p, ct.POINTER(ct.c_size_t))
ccs_user_defined_expression_deserialize_type = ct.CFUNCTYPE(Result, ct.c_size_t, ct.c_void_p, ct.POINTER(ct.c_void_p))

class UserDefinedExpressionVector(ct.Structure):
_fields_ = [
('delete', ccs_user_defined_expression_del_type),
('evaluate', ccs_user_defined_expression_eval_type),
('serialize', ccs_user_defined_expression_serialize_type),
('deserialize', ccs_user_defined_expression_deserialize_type) ]

ccs_create_user_defined_expression = _ccs_get_function("ccs_create_user_defined_expression", [ct.c_char_p, ct.c_size_t, ct.POINTER(Datum), ct.POINTER(UserDefinedExpressionVector), ct.py_object, ct.POINTER(ccs_expression)])
ccs_user_defined_expression_get_name = _ccs_get_function("ccs_user_defined_expression_get_name", [ccs_expression, ct.POINTER(ct.c_char_p)])
ccs_user_defined_expresion_get_expression_data = _ccs_get_function("ccs_user_defined_expression_get_expression_data", [ccs_expression, ct.POINTER(ct.c_void_p)])

class ExpressionUserDefined(Expression):
def __init__(self, handle = None, retain = False, auto_release = True,
name = "", nodes = [], delete = None, evaluate = None, serialize = None, deserialize = None, expression_data = None):
if handle is None:
if evaluate is None:
raise Error(Result(Result.ERROR_INVALID_VALUE))

vec = self.get_vector(delete, evaluate, serialize, deserialize)
c_expression_data = None
if expression_data is not None:
c_expression_data = ct.py_object(expression_data)
handle = ccs_expression()
sz = len(nodes)
v = (Datum*sz)()
ss = []
for i in range(sz):
v[i].set_value(values[i], string_store = ss)
res = ccs_create_user_defined_expression(str.encode(name), sz, v, ct.byref(vec), c_expression_data, ct.byref(handle))
Error.check(res)
super().__init__(handle = handle, retain = False)
ct.pythonapi.Py_IncRef(ct.py_object(vec))
if c_expression_data is not None:
ct.pythonapi.Py_IncRef(c_expression_data)
else:
super().__init__(handle = handle, retain = retain, auto_release = auto_release)

@property
def name(self):
if hasattr(self, "_name"):
return self._name
v = ct.c_char_p()
res = ccs_expression_get_name(self.handle, ct.byref(v))
Error.check(res)
self._name = v.value.decode()
return self._name

@property
def expression_data(self):
if hasattr(self, "_expression_data"):
return self._expression_data
v = ct.c_void_p()
res = ccs_user_defined_expression_get_expression_data(self.handle, ct.byref(v))
Error.check(res)
if v:
self._expression_data = ct.cast(v, ct.py_object).value
else:
self._expression_data = None
return self._expression_data

def __str__(self):
return "{}({})".format(self.name, ", ".join(map(str, self.nodes)))

@classmethod
def get_vector(self, delete = None, evaluate = None, serialize = None, deserialize = None):
vec = UserDefinedExpressionVector()
setattr(vec, '_string_store', list())
setattr(vec, '_object_store', list())
def delete_wrapper(expr):
try:
expr = ct.cast(expr, ccs_expression)
o = Object.from_handle(expr)
edata = o.expression_data
if delete is not None:
delete(o)
if edata is not None:
ct.pythonapi.Py_DecRef(ct.py_object(edata))
ct.pythonapi.Py_DecRef(ct.py_object(vec))
return Result.SUCCESS
except Exception as e:
return Error.set_error(e)

def evaluate_wrapper(expr, num_values, p_values, p_value_ret):
try:
expr = ct.cast(expr, ccs_expression)
if num_values == 0:
value_ret = evaluate(Expression.from_handle(expr))
else:
values = tuple(p_values[i].value for i in range(num_values))
value_ret = evaluate(Expression.from_handle(expr), *values)
p_value_ret[0].set_value(value_ret, string_store = getattr(vec, '_string_store'), object_store = getattr(vec, '_object_store'))
return Result.SUCCESS
except Exception as e:
return Error.set_error(e)

if serialize is not None:
def serialize_wrapper(expr, state_size, p_state, p_state_size):
try:
expr = ct.cast(expr, ccs_expression)
p_s = ct.cast(p_state, ct.c_void_p)
p_sz = ct.cast(p_state_size, ct.c_void_p)
state = serialize(Expression.from_handle(expr), True if state_size == 0 else False)
if p_s.value is not None and state_size < ct.sizeof(state):
raise Error(Result(Result.ERROR_INVALID_VALUE))
if p_s.value is not None:
ct.memmove(p_s, ct.byref(state), ct.sizeof(state))
if p_sz.value is not None:
p_state_size[0] = ct.sizeof(state)
return Result.SUCCESS
except Exception as e:
return Error.set_error(e)
else:
serialize_wrapper = 0

if deserialize is not None:
def deserialize_wrapper(state_size, p_state, p_expression_data):
try:
p_s = ct.cast(p_state, ct.c_void_p)
p_e = ct.cast(p_expression_data, ct.c_void_p)
if p_s.value is None:
state = None
else:
state = ct.cast(p_s, POINTER(c_byte * state_size))
expression_data = deserialize(state)
c_expression_data = ct.py_object(expression_data)
p_e[0] = c_expression_data
ct.pythonapi.Py_IncRef(c_expression_data)
return Result.SUCCESS
except Exception as e:
return Error.set_error(e)
else:
deserialize_wrapper = 0

delete_wrapper_func = ccs_user_defined_expression_del_type(delete_wrapper)
evaluate_wrapper_func = ccs_user_defined_expression_eval_type(evaluate_wrapper)
serialize_wrapper_func = ccs_user_defined_expression_serialize_type(serialize_wrapper)
deserialize_wrapper_func = ccs_user_defined_expression_deserialize_type(deserialize_wrapper)
vec.delete = delete_wrapper_func
vec.evaluate = evaluate_wrapper_func
vec.serialize = serialize_wrapper_func
vec.deserialize = deserialize_wrapper_func

setattr(vec, '_wrappers', (
delete_wrapper,
evaluate_wrapper,
serialize_wrapper,
deserialize_wrapper,
delete_wrapper_func,
evaluate_wrapper_func,
serialize_wrapper_func,
deserialize_wrapper_func))
return vec

Expression.UserDefined = ExpressionUserDefined

setattr(Expression, 'EXPRESSION_MAP', {
ExpressionType.OR: ExpressionOr,
ExpressionType.AND: ExpressionAnd,
Expand All @@ -497,4 +658,5 @@ def __str__(self):
ExpressionType.LIST: ExpressionList,
ExpressionType.LITERAL: ExpressionLiteral,
ExpressionType.VARIABLE: ExpressionVariable,
ExpressionType.USER_DEFINED: ExpressionUserDefined,
})
130 changes: 57 additions & 73 deletions bindings/ruby/lib/cconfigspace/base.rb
Original file line number Diff line number Diff line change
Expand Up @@ -9,37 +9,70 @@ module CCS
ffi_lib "cconfigspace"
end

module MemoryAccessor
def self.included(mod)
mod.class_eval {
alias read_ccs_float_t read_double
alias get_ccs_float_t get_double
alias read_array_of_ccs_float_t read_array_of_double
alias write_array_of_ccs_float_t write_array_of_double
alias read_ccs_int_t read_int64
alias get_ccs_int_t get_int64
alias read_array_of_ccs_int_t read_array_of_int64
alias write_array_of_ccs_int_t write_array_of_int64
alias read_ccs_bool_t read_int32
alias read_ccs_evaluation_result_t read_int32
alias read_ccs_hash_t read_uint32
if FFI.find_type(:size_t).size == 8
alias read_size_t read_uint64
alias write_size_t write_uint64
alias write_array_of_size_t write_array_of_uint64
alias read_array_of_size_t read_array_of_uint64
else
alias read_size_t read_uint32
alias write_size_t write_uint32
alias write_array_of_size_t write_array_of_uint32
alias read_array_of_size_t read_array_of_uint32
end
alias read_ccs_object_t read_pointer
alias read_ccs_rng_t read_ccs_object_t
alias read_ccs_distribution_t read_ccs_object_t
alias read_ccs_parameter_t read_ccs_object_t
alias read_ccs_expression_t read_ccs_object_t
alias read_ccs_context_t read_ccs_object_t
alias read_ccs_distribution_space_t read_ccs_object_t
alias read_ccs_search_space_t read_ccs_object_t
alias read_ccs_configuration_space_t read_ccs_object_t
alias read_ccs_binding_t read_ccs_object_t
alias read_ccs_search_configuration_t read_ccs_object_t
alias read_ccs_configuration_t read_ccs_object_t
alias read_ccs_feature_space_t read_ccs_object_t
alias read_ccs_features_t read_ccs_object_t
alias read_ccs_objective_space_t read_ccs_object_t
alias read_ccs_evaluation_t read_ccs_object_t
alias read_ccs_tuner_t read_ccs_object_t
alias read_ccs_map_t read_ccs_object_t
alias read_ccs_error_stack_t read_ccs_object_t
alias read_ccs_tree_t read_ccs_object_t
alias read_ccs_tree_space_t read_ccs_object_t
alias read_ccs_tree_configuration_t read_ccs_object_t
}
end
end

class Pointer < FFI::Pointer
include MemoryAccessor
def initialize(*args)
if args.length == 2 then
super(CCS::find_type(args[0]), args[1])
else
super(*args)
end
end
alias read_ccs_float_t read_double
alias get_ccs_float_t get_double
alias read_array_of_ccs_float_t read_array_of_double
alias write_array_of_ccs_float_t write_array_of_double
alias read_ccs_int_t read_int64
alias get_ccs_int_t get_int64
alias read_array_of_ccs_int_t read_array_of_int64
alias write_array_of_ccs_int_t write_array_of_int64
alias read_ccs_bool_t read_int32
alias read_ccs_evaluation_result_t read_int32
alias read_ccs_hash_t read_uint32
if FFI.find_type(:size_t).size == 8
alias read_size_t read_uint64
alias write_size_t write_uint64
alias write_array_of_size_t write_array_of_uint64
else
alias read_size_t read_uint32
alias write_size_t write_uint32
alias write_array_of_size_t write_array_of_uint32
end
end

class MemoryPointer < FFI::MemoryPointer
include MemoryAccessor
def initialize(size, count = 1, clear = true)
if size.is_a?(Symbol)
size = CCS::find_type(size)
Expand All @@ -54,31 +87,6 @@ def initialize(size, count = 1, clear = true)
typedef :int32, :ccs_evaluation_result_t
typedef :uint32, :ccs_hash_t

class MemoryPointer
alias read_ccs_float_t read_double
alias get_ccs_float_t get_double
alias read_array_of_ccs_float_t read_array_of_double
alias write_array_of_ccs_float_t write_array_of_double
alias read_ccs_int_t read_int64
alias get_ccs_int_t get_int64
alias read_array_of_ccs_int_t read_array_of_int64
alias write_array_of_ccs_int_t write_array_of_int64
alias read_ccs_bool_t read_int32
alias read_ccs_evaluation_result_t read_int32
alias read_ccs_hash_t read_uint32
if FFI.find_type(:size_t).size == 8
alias read_size_t read_uint64
alias write_size_t write_uint64
alias write_array_of_size_t write_array_of_uint64
alias read_array_of_size_t read_array_of_uint64
else
alias read_size_t read_uint32
alias write_size_t write_uint32
alias write_array_of_size_t write_array_of_uint32
alias read_array_of_size_t read_array_of_uint32
end
end

class Version < FFI::Struct
layout :revision, :uint16,
:patch, :uint16,
Expand Down Expand Up @@ -112,30 +120,6 @@ class Version < FFI::Struct
typedef :ccs_object_t, :ccs_tree_t
typedef :ccs_object_t, :ccs_tree_space_t
typedef :ccs_object_t, :ccs_tree_configuration_t
class MemoryPointer
alias read_ccs_object_t read_pointer
alias read_ccs_rng_t read_ccs_object_t
alias read_ccs_distribution_t read_ccs_object_t
alias read_ccs_parameter_t read_ccs_object_t
alias read_ccs_expression_t read_ccs_object_t
alias read_ccs_context_t read_ccs_object_t
alias read_ccs_distribution_space_t read_ccs_object_t
alias read_ccs_search_space_t read_ccs_object_t
alias read_ccs_configuration_space_t read_ccs_object_t
alias read_ccs_binding_t read_ccs_object_t
alias read_ccs_search_configuration_t read_ccs_object_t
alias read_ccs_configuration_t read_ccs_object_t
alias read_ccs_feature_space_t read_ccs_object_t
alias read_ccs_features_t read_ccs_object_t
alias read_ccs_objective_space_t read_ccs_object_t
alias read_ccs_evaluation_t read_ccs_object_t
alias read_ccs_tuner_t read_ccs_object_t
alias read_ccs_map_t read_ccs_object_t
alias read_ccs_error_stack_t read_ccs_object_t
alias read_ccs_tree_t read_ccs_object_t
alias read_ccs_tree_space_t read_ccs_object_t
alias read_ccs_tree_configuration_t read_ccs_object_t
end

ObjectType = enum FFI::Type::INT32, :ccs_object_type_t, [
:CCS_OBJECT_TYPE_RNG,
Expand Down Expand Up @@ -189,7 +173,7 @@ class MemoryPointer
:CCS_RESULT_ERROR_INVALID_TREE_SPACE, -28,
:CCS_RESULT_ERROR_INVALID_DISTRIBUTION_SPACE, -29 ]

class MemoryPointer
module MemoryAccessor
def read_ccs_object_type_t
ObjectType.from_native(read_int32, nil)
end
Expand Down Expand Up @@ -221,7 +205,7 @@ def read_ccs_result_t
:CCS_NUMERIC_TYPE_INT, DataType.to_native(:CCS_DATA_TYPE_INT, nil),
:CCS_NUMERIC_TYPE_FLOAT, DataType.to_native(:CCS_DATA_TYPE_FLOAT, nil) ]

class MemoryPointer
module MemoryAccessor
def read_ccs_numeric_type_t
NumericType.from_native(read_int32, nil)
end
Expand Down Expand Up @@ -475,13 +459,13 @@ def self.from_value(v)
end
end
typedef Datum.by_value, :ccs_datum_t
class MemoryPointer
module MemoryAccessor
def read_ccs_datum_t
Datum::new(self).value
end

def read_array_of_ccs_datum_t(length)
length.times.collect { |i| Datum::new(self[i]).value }
length.times.collect { |i| Datum::new(self + i * Datum.size).value }
end
end

Expand Down
4 changes: 2 additions & 2 deletions bindings/ruby/lib/cconfigspace/distribution.rb
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ module CCS
:CCS_DISTRIBUTION_TYPE_MIXTURE,
:CCS_DISTRIBUTION_TYPE_MULTIVARIATE
]
class MemoryPointer
module MemoryAccessor
def read_ccs_distribution_type_t
DistributionType.from_native(read_int32, nil)
end
Expand All @@ -17,7 +17,7 @@ def read_ccs_distribution_type_t
:CCS_SCALE_TYPE_LINEAR,
:CCS_SCALE_TYPE_LOGARITHMIC
]
class MemoryPointer
module MemoryAccessor
def read_ccs_scale_type_t
ScaleType.from_native(read_int32, nil)
end
Expand Down
2 changes: 1 addition & 1 deletion bindings/ruby/lib/cconfigspace/evaluation.rb
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ module CCS
:CCS_COMPARISON_WORSE, 1,
:CCS_COMPARISON_NOT_COMPARABLE, 2
]
class MemoryPointer
module MemoryAccessor
def read_ccs_comparison_t
Comparison.from_native(read_int32, nil)
end
Expand Down
Loading

0 comments on commit 3de3b8b

Please sign in to comment.