Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Explicit refcount bindings #19

Merged
merged 7 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion bindings/python/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,6 @@ LIBEXT = .so
endif

check: FORCE
PYTHONPATH=$(srcdir)/ LIBCCONFIGSPACE_SO_=$(top_builddir)/src/.libs/libcconfigspace$(LIBEXT) python3 -m unittest discover -s $(srcdir)/test/
PYTHONPATH=$(srcdir)/ LIBCCONFIGSPACE_SO_=$(abs_top_builddir)/src/.libs/libcconfigspace$(LIBEXT) python3 -m unittest discover -s $(srcdir)/test/

FORCE:
143 changes: 64 additions & 79 deletions bindings/python/cconfigspace/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import ctypes as ct
import json
import pickle
import sys
import traceback
from . import libcconfigspace
Expand Down Expand Up @@ -106,8 +106,8 @@ def __repr__(self):

class CEnumeration(ct.c_int, metaclass=CEnumerationType):
_members_ = {}
def __init__(self, value):
ct.c_int.__init__(self, value)
def __init__(*args):
ct.c_int.__init__(*args)

@property
def name(self):
Expand Down Expand Up @@ -395,10 +395,9 @@ class DeserializeOption(CEnumeration):
_members_ = [
('END', 0),
'HANDLE_MAP',
'VECTOR',
'DATA',
'VECTOR_CALLBACK',
'NON_BLOCKING',
'CALLBACK'
'DATA_CALLBACK'
]

def _ccs_get_function(method, argtypes = [], restype = Result):
Expand All @@ -422,7 +421,8 @@ def _ccs_get_function(method, argtypes = [], restype = Result):
ccs_object_get_user_data = _ccs_get_function("ccs_object_get_user_data", [ccs_object, ct.POINTER(ct.c_void_p)])
ccs_object_serialize_callback_type = ct.CFUNCTYPE(Result, ccs_object, ct.c_size_t, ct.c_void_p, ct.POINTER(ct.c_size_t), ct.c_void_p)
ccs_object_set_serialize_callback = _ccs_get_function("ccs_object_set_serialize_callback", [ccs_object, ccs_object_serialize_callback_type, ct.c_void_p])
ccs_object_deserialize_callback_type = ct.CFUNCTYPE(Result, ccs_object, ct.c_size_t, ct.c_void_p, ct.c_void_p)
ccs_object_deserialize_data_callback_type = ct.CFUNCTYPE(Result, ccs_object, ct.c_size_t, ct.c_void_p, ct.POINTER(ct.py_object))
ccs_object_deserialize_vector_callback_type = ct.CFUNCTYPE(Result, ObjectType, ct.c_char_p, ct.c_void_p, ct.POINTER(ct.c_void_p), ct.POINTER(ct.py_object))
# Variadic methods
ccs_object_serialize = getattr(libcconfigspace, "ccs_object_serialize")
ccs_object_serialize.argtypes = ccs_object, SerializeFormat, SerializeOperation,
Expand Down Expand Up @@ -520,10 +520,10 @@ def from_handle(cls, h, retain = True):
auto_release = True
return cls._from_handle(h, retain, auto_release)

def set_destroy_callback(self, callback, user_data = None):
_set_destroy_callback(self.handle, callback, user_data = user_data)
def set_destroy_callback(self, callback):
_set_destroy_callback(self.handle, callback)

def serialize(self, format = 'binary', path = None, file_descriptor = None, callback = None, callback_data = None):
def serialize(self, format = 'binary', path = None, file_descriptor = None, callback = None):
if format != 'binary':
raise Error(Result(Result.ERROR_INVALID_VALUE))
if path and file_descriptor:
Expand All @@ -532,7 +532,7 @@ def serialize(self, format = 'binary', path = None, file_descriptor = None, call
if callback:
cb_wrapper = _get_serialize_callback_wrapper(callback)
cb_wrapper_func = ccs_object_serialize_callback_type(cb_wrapper)
options = [SerializeOption.CALLBACK, cb_wrapper_func, ct.py_object(callback_data)] + options
options = [SerializeOption.CALLBACK, cb_wrapper_func, ct.py_object()] + options
elif _default_user_data_serializer:
options = [SerializeOption.CALLBACK, _default_user_data_serializer, ct.py_object()] + options
if path:
Expand All @@ -556,7 +556,7 @@ def serialize(self, format = 'binary', path = None, file_descriptor = None, call
return v

@classmethod
def deserialize(cls, format = 'binary', handle_map = None, vector = None, data = None, path = None, buffer = None, file_descriptor = None, callback = None, callback_data = None):
def deserialize(cls, format = 'binary', handle_map = None, vector_callback = None, path = None, buffer = None, file_descriptor = None, callback = None):
if format != 'binary':
raise Error(Result(Result.ERROR_INVALID_VALUE))
mode_count = 0;
Expand All @@ -572,16 +572,16 @@ def deserialize(cls, format = 'binary', handle_map = None, vector = None, data =
options = [DeserializeOption.END]
if handle_map:
options = [DeserializeOption.HANDLE_MAP, handle_map.handle] + options
if vector:
options = [DeserializeOption.VECTOR, ct.byref(vector)] + options
if data:
options = [DeserializeOption.DATA, ct.py_object(data)] + options
if vector_callback:
vector_cb_wrapper = _get_deserialize_vector_callback_wrapper(vector_callback)
vector_cb_wrapper_func = ccs_object_deserialize_vector_callback_type(vector_cb_wrapper)
options = [DeserializeOption.VECTOR_CALLBACK, vector_cb_wrapper_func, ct.py_object()] + options
if callback:
cb_wrapper = _get_deserialize_callback_wrapper(callback)
cb_wrapper_func = ccs_object_deserialize_callback_type(cb_wrapper)
options = [DeserializeOption.CALLBACK, cb_wrapper_func, ct.py_object(callback_data)] + options
cb_wrapper = _get_deserialize_data_callback_wrapper(callback)
cb_wrapper_func = ccs_object_deserialize_data_callback_type(cb_wrapper)
options = [DeserializeOption.DATA_CALLBACK, cb_wrapper_func, ct.py_object()] + options
elif _default_user_data_deserializer:
options = [DeserializeOption.CALLBACK, _default_user_data_deserializer, ct.py_object()] + options
options = [DeserializeOption.DATA_CALLBACK, _default_user_data_deserializer, ct.py_object()] + options
if buffer:
s = ct.c_size_t(ct.sizeof(buffer))
res = ccs_object_deserialize(ct.byref(o), SerializeFormat.BINARY, SerializeOperation.MEMORY, s, buffer, *options)
Expand All @@ -602,73 +602,58 @@ def deserialize(cls, format = 'binary', handle_map = None, vector = None, data =
def _get_serialize_callback_wrapper(callback):
def serialize_callback_wrapper(obj, serialize_data_size, serialize_data, serialize_data_size_ret, cb_data):
try:
p_sd = ct.cast(serialize_data, ct.c_void_p)
p_sdsz = ct.cast(serialize_data_size_ret, ct.POINTER(ct.c_size_t))
cb_data = ct.cast(cb_data, ct.c_void_p)
if cb_data:
cb_data = ct.cast(cb_data, ct.py_object).value
else:
cb_data = None
serialized = callback(Object.from_handle(ccs_object(obj)), cb_data, True if serialize_data_size == 0 else False)
if p_sd and serialize_data_size < ct.sizeof(serialized):
serialized = callback(Object.from_handle(obj).user_data)
state = ct.create_string_buffer(serialized, len(serialized))
if serialize_data and serialize_data_size < ct.sizeof(state):
raise Error(Result(Result.ERROR_INVALID_VALUE))
if p_sd:
ct.memmove(p_sd, ct.byref(serialized), ct.sizeof(serialized))
if p_sdsz:
p_sdsz[0] = ct.sizeof(serialized)
if serialize_data:
ct.memmove(serialize_data, ct.byref(state), ct.sizeof(state))
if serialize_data_size_ret:
serialize_data_size_ret[0] = ct.sizeof(state)
return Result.SUCCESS
except Exception as e:
return Error.set_error(e)
return serialize_callback_wrapper

def _get_deserialize_callback_wrapper(callback):
def deserialize_callback_wrapper(obj, serialize_data_size, serialize_data, cb_data):
def _get_deserialize_data_callback_wrapper(callback):
def deserialize_data_callback_wrapper(obj, serialize_data_size, p_serialize_data, cb_data):
try:
p_sd = ct.cast(serialize_data, ct.c_void_p)
cb_data = ct.cast(cb_data, ct.c_void_p)
if cb_data:
cb_data = ct.cast(cb_data, ct.py_object).value
else:
cb_data = None
if p_sd:
serialized = ct.cast(p_sd, ct.POINTER(ct.c_byte * serialize_data_size))
else:
serialized = None
callback(Object.from_handle(ccs_object(obj)), serialized, cb_data)
user_data = callback(ct.string_at(p_serialize_data, serialize_data_size))
Object.from_handle(ccs_object(obj)).user_data = user_data
return Result.SUCCESS
except Exception as e:
return Error.set_error(e)
return deserialize_callback_wrapper
return deserialize_data_callback_wrapper

def _json_user_data_serializer(obj, data, size):
string = json.dumps(obj.user_data).encode("ascii")
return ct.create_string_buffer(string)

def _json_user_data_deserializer(obj, serialized, data):
serialized = ct.cast(serialized, ct.c_char_p)
obj.user_data = json.loads(serialized.value)
def _get_deserialize_vector_callback_wrapper(callback):
def deserialize_vector_callback_wrapper(obj_type, name, callback_user_data, vector_ret, data_ret):
try:
(vector, data) = callback(obj_type.value, name.decode())
c_vector = ct.py_object(vector)
c_data = ct.py_object(data)
vector_ret[0] = ct.cast(ct.byref(vector), ct.c_void_p)
data_ret[0] = c_data
ct.pythonapi.Py_IncRef(c_vector)
ct.pythonapi.Py_IncRef(c_data)
return Result.SUCCESS
except Exception as e:
return Error.set_error(e)
return deserialize_vector_callback_wrapper

_json_user_data_serializer_wrap = _get_serialize_callback_wrapper(_json_user_data_serializer)
_json_user_data_serializer_func = ccs_object_serialize_callback_type(_json_user_data_serializer_wrap)
def _pickle_user_data_serializer(user_data):
return pickle.dumps(user_data)

_json_user_data_deserializer_wrap = _get_deserialize_callback_wrapper(_json_user_data_deserializer)
_json_user_data_deserializer_func = ccs_object_deserialize_callback_type(_json_user_data_deserializer_wrap)
def _pickle_user_data_deserializer(serialized):
return pickle.loads(serialized)

_default_user_data_serializer = _json_user_data_serializer_func
_default_user_data_deserializer = _json_user_data_deserializer_func
_pickle_user_data_serializer_wrap = _get_serialize_callback_wrapper(_pickle_user_data_serializer)
_pickle_user_data_serializer_func = ccs_object_serialize_callback_type(_pickle_user_data_serializer_wrap)

# Delete wrappers are responsible for deregistering the object data_store
def _register_vector(handle, vector_data):
value = handle.value
if value in _data_store:
raise Error(Result(Result.ERROR_INVALID_VALUE))
_data_store[value] = dict.fromkeys(['callbacks', 'user_data', 'serialize_calback', 'strings'])
_data_store[value]['callbacks'] = vector_data
_data_store[value]['strings'] = []
_pickle_user_data_deserializer_wrap = _get_deserialize_data_callback_wrapper(_pickle_user_data_deserializer)
_pickle_user_data_deserializer_func = ccs_object_deserialize_data_callback_type(_pickle_user_data_deserializer_wrap)

def _unregister_vector(handle):
value = handle.value
del _data_store[value]
_default_user_data_serializer = _pickle_user_data_serializer_func
_default_user_data_deserializer = _pickle_user_data_deserializer_func

# If objects don't have a user-defined del operation, then the first time a
# data needs to be registered a destruction callback is attached.
Expand Down Expand Up @@ -707,30 +692,30 @@ def _register_serialize_callback(handle, callback_data):
_register_destroy_callback(handle)
_data_store[value]['serialize_calback'] = callback_data

def deserialize(format = 'binary', handle_map = None, path = None, buffer = None, callback = None, callback_data = None):
return Object.deserialize(format = format, handle_map = handle_map, path = path, buffer = buffer, callback = callback, callback_data = callback_data)
def deserialize(format = 'binary', handle_map = None, path = None, buffer = None, file_descriptor = None, vector_callback = None, vector_callback_data = None, callback = None, callback_data = None):
return Object.deserialize(format = format, handle_map = handle_map, path = path, buffer = buffer, file_descriptor = file_descriptor, vector_callback = vector_callback, callback = callback)

def _set_destroy_callback(handle, callback, user_data = None):
def _set_destroy_callback(handle, callback):
if callback is None:
raise Error(Result(Result.ERROR_INVALID_VALUE))
def cb_wrapper(obj, data):
callback(Object.from_handle(obj), data)
cb_wrapper_func = ccs_object_destroy_callback_type(cb_wrapper)
res = ccs_object_set_destroy_callback(handle, cb_wrapper_func, user_data)
res = ccs_object_set_destroy_callback(handle, cb_wrapper_func, None)
Error.check(res)
_register_callback(handle, [callback, cb_wrapper, cb_wrapper_func, user_data])
_register_callback(handle, [callback, cb_wrapper, cb_wrapper_func])

def _set_serialize_callback(handle, callback, user_data = None):
def _set_serialize_callback(handle, callback):
if callback is None:
res = ccs_object_set_serialize_callback(handle, None, None)
Error.check(res)
_register_serialize_callback(handle, None)
else:
cb_wrapper = _get_serialize_callback_wrapper(callback)
cb_wrapper_func = ccs_object_serialize_callback_type(cb_wrapper)
res = ccs_object_set_serialize_callback(handle, cb_wrapper_func, user_data)
res = ccs_object_set_serialize_callback(handle, cb_wrapper_func, None)
Error.check(res)
_register_serialize_callback(handle, [callback, cb_wrapper, cb_wrapper_func, user_data])
_register_serialize_callback(handle, [callback, cb_wrapper, cb_wrapper_func])

_ccs_id = 0
def _ccs_get_id():
Expand Down
7 changes: 4 additions & 3 deletions bindings/python/cconfigspace/configuration_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,20 @@

class ConfigurationSpace(Context):
def __init__(self, handle = None, retain = False, auto_release = True,
name = "", parameters = None, conditions = None, forbidden_clauses = None, feature_space = None, rng = None):
name = "", parameters = None, conditions = None, forbidden_clauses = None, feature_space = None, rng = None, binding = None):
if handle is None:
count = len(parameters)

ctx_params = parameters
if feature_space is not None:
ctx_params = ctx_params + list(feature_space.parameters)
ctx = dict(zip([x.name for x in ctx_params], ctx_params))
extra = (ctx, binding)

if forbidden_clauses is not None:
numfc = len(forbidden_clauses)
if numfc > 0:
forbidden_clauses = [ parser.parse(fc, extra = ctx) if isinstance(fc, str) else fc for fc in forbidden_clauses ]
forbidden_clauses = [ parser.parse(fc, extra = extra) if isinstance(fc, str) else fc for fc in forbidden_clauses ]
fcv = (ccs_expression * numfc)(*[x.handle.value for x in forbidden_clauses])
else:
fcv = None
Expand All @@ -45,7 +46,7 @@ def __init__(self, handle = None, retain = False, auto_release = True,
if conditions is not None:
indexdict = dict(reversed(ele) for ele in enumerate(parameters))
cv = (ccs_expression * count)()
conditions = dict( (k, parser.parse(v, extra = ctx) if isinstance(v, str) else v) for (k, v) in conditions.items() )
conditions = dict( (k, parser.parse(v, extra = extra) if isinstance(v, str) else v) for (k, v) in conditions.items() )
for (k, v) in conditions.items():
if isinstance(k, Parameter):
cv[indexdict[k]] = v.handle.value
Expand Down
Loading
Loading