diff --git a/.vscode/settings.json b/.vscode/settings.json
index 542b4d9..7b73167 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -4,6 +4,18 @@
],
"files.associations": {
"tuple": "cpp",
- "array": "cpp"
+ "array": "cpp",
+ "ostream": "cpp",
+ "type_traits": "cpp",
+ "optional": "cpp",
+ "*.tcc": "cpp",
+ "random": "cpp",
+ "fstream": "cpp",
+ "functional": "cpp",
+ "istream": "cpp",
+ "limits": "cpp",
+ "sstream": "cpp",
+ "streambuf": "cpp",
+ "complex": "cpp"
}
-}
\ No newline at end of file
+}
diff --git a/assets/images/coverage.svg b/assets/images/coverage.svg
index cd01045..a5804ea 100644
--- a/assets/images/coverage.svg
+++ b/assets/images/coverage.svg
@@ -9,13 +9,13 @@
-
+
coverage
coverage
- 36%
- 36%
+ 46%
+ 46%
diff --git a/poetry.lock b/poetry.lock
index e14aeb9..0d07d24 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -601,6 +601,17 @@ docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1
testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)"]
typing = ["typing-extensions (>=4.8)"]
+[[package]]
+name = "find-libpython"
+version = "0.4.0"
+description = "Finds the libpython associated with your environment, wherever it may be hiding"
+optional = false
+python-versions = "*"
+files = [
+ {file = "find_libpython-0.4.0-py3-none-any.whl", hash = "sha256:034a4253bd57da3408aefc59aeac1650150f6c1f42e10fdd31615cf1df0842e3"},
+ {file = "find_libpython-0.4.0.tar.gz", hash = "sha256:46f9cdcd397ddb563b2d7592ded3796a41c1df5222443bd9d981721c906c03e6"},
+]
+
[[package]]
name = "identify"
version = "2.5.36"
@@ -1718,4 +1729,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
[metadata]
lock-version = "2.0"
python-versions = "^3.8"
-content-hash = "32d677ef3e28ddb1500c57d1c74333018f904d6feaccd0114eadef94712b1e87"
+content-hash = "8629062450c9584ff2073b754933b8e71114770d9b9247e711ed9e4180dee02a"
diff --git a/pyproject.toml b/pyproject.toml
index bd2a49b..cc02fc3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -5,7 +5,7 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry]
name = "PyCXpress"
-version = "0.0.4"
+version = "0.0.5"
description = "PyCXpress is a high-performance hybrid framework that seamlessly integrates Python and C++ to harness the flexibility of Python and the speed of C++ for efficient and expressive computation, particularly in the realm of deep learning and numerical computing."
readme = "README.md"
authors = ["chaoqing "]
@@ -63,6 +63,7 @@ pytest-cov = "^5.0.0"
[tool.poetry.group.dev.dependencies]
debugpy = "^1.8.1"
+find-libpython = "^0.4.0"
[tool.black]
# https://github.com/psf/black
diff --git a/src/PyCXpress/__init__.py b/src/PyCXpress/__init__.py
index c48fb19..cfe7537 100644
--- a/src/PyCXpress/__init__.py
+++ b/src/PyCXpress/__init__.py
@@ -11,7 +11,6 @@
"version",
]
-import sys
from importlib import metadata as importlib_metadata
from pathlib import Path
@@ -32,8 +31,8 @@ def get_version() -> str:
ModelRuntimeType,
TensorMeta,
convert_to_spec_tuple,
- pycxpress_debugger,
)
+from .debugger import pycxpress_debugger
def get_include() -> str:
diff --git a/src/PyCXpress/core.py b/src/PyCXpress/core.py
index ef2f1ef..7ae79c2 100644
--- a/src/PyCXpress/core.py
+++ b/src/PyCXpress/core.py
@@ -1,78 +1,21 @@
# mypy: disable_error_code="type-arg,arg-type,union-attr,operator,assignment,misc"
-import logging
-
-logger = logging.getLogger(__name__)
-
-
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
-import os
from collections import namedtuple
from dataclasses import dataclass
from enum import Enum, auto
+from itertools import chain
import numpy as np
from numpy.typing import DTypeLike
-# import tensorflow as tf
-
-
-def pycxpress_debugger(
- host: Optional[str] = None,
- port: Optional[int] = None,
- debugger: Optional[str] = None,
-):
- if debugger is None:
- return
-
- if host is None:
- host = os.environ.get("PYCXPRESS_DEBUGGER_HOST", "localhost")
-
- if port is None:
- port = os.environ.get("PYCXPRESS_DEBUGGER_PORT", 5678)
-
- if debugger.lower() == "pycharm":
- try:
- import pydevd_pycharm
-
- pydevd_pycharm.settrace(
- host, port=port, stdoutToServer=True, stderrToServer=True, suspend=True
- )
- except ConnectionRefusedError:
- logger.warning(
- "Can not connect to Python debug server (maybe not started?)"
- )
- logger.warning(
- "Use PYCXPRESS_DEBUGGER_TYPE=debugpy instead as Pycharm professional edition is needed for Python debug server feature."
- )
- elif debugger.lower() == "debugpy":
- import debugpy
-
- debugpy.listen((host, port))
- logger.info(f"debugpy listen on {host}:{port}, please use VSCode to attach")
- debugpy.wait_for_client()
- else:
- logger.warning(
- f"Only PYCXPRESS_DEBUGGER_TYPE=debugpy|pycharm supported but {debugger} provided"
- )
-
-
-def get_c_type(t: DTypeLike) -> Tuple[str, int]:
- dtype = np.dtype(t)
- relation = {
- np.dtype("bool"): "bool",
- np.dtype("int8"): "int8_t",
- np.dtype("int16"): "int16_t",
- np.dtype("int32"): "int32_t",
- np.dtype("int64"): "int64_t",
- np.dtype("uint8"): "uint8_t",
- np.dtype("uint16"): "uint16_t",
- np.dtype("uint32"): "uint32_t",
- np.dtype("uint64"): "uint64_t",
- np.dtype("float32"): "float",
- np.dtype("float64"): "double",
- }
- return relation.get(dtype, "char"), dtype.itemsize or 1
+from .interface import (
+ InputTensorProtocol,
+ ModelProtocol,
+ OutputTensorProtocol,
+ TensorBufferProtocol,
+)
+from .utils import get_c_type, logger
@dataclass
@@ -162,11 +105,11 @@ def __new__(
@staticmethod
def general_funcs(name: str, field_names: List[str]):
- def get_buffer_shape(self, name: str):
- buffer = getattr(self.__buffer_data__, name)
- return buffer.shape
+ def get_buffer_shape(self, name: str) -> Tuple[int]:
+ shape: Tuple[int] = getattr(self.__buffer_data__, name).shape
+ return shape
- def set_buffer_value(self, name: str, value):
+ def set_buffer_value(self, name: str, value: np.ndarray) -> None:
buffer = getattr(self.__buffer_data__, name)
buffer.data = value
@@ -209,9 +152,20 @@ def del_func(_):
return property(fget=get_func, fset=set_func, fdel=del_func, doc=field.doc)
-def convert_to_spec_tuple(fields: Iterable[TensorMeta]):
- return tuple(
- (v["name"], v["dtype"], v["buffer_size"]) for v in [v.to_dict() for v in fields]
+def convert_to_spec_tuple(
+ inputFields: Iterable[TensorMeta], outputFields: Iterable
+) -> Iterable[TensorBufferProtocol]:
+ return chain.from_iterable(
+ [
+ (
+ (v["name"], v["dtype"], v["buffer_size"], False)
+ for v in [v.to_dict() for v in inputFields]
+ ),
+ (
+ (v["name"], v["dtype"], v["buffer_size"], True)
+ for v in [v.to_dict() for v in outputFields]
+ ),
+ ]
)
diff --git a/src/PyCXpress/debugger.py b/src/PyCXpress/debugger.py
new file mode 100644
index 0000000..4216486
--- /dev/null
+++ b/src/PyCXpress/debugger.py
@@ -0,0 +1,54 @@
+from typing import Optional
+
+import os
+
+from .utils import logger
+
+_debugger_status_ = [False]
+
+
+def pycxpress_debugger(
+ host: Optional[str] = None,
+ port: Optional[int] = None,
+ debugger: Optional[str] = None,
+):
+ if _debugger_status_[0] == True:
+ return
+
+ if debugger is None:
+ return
+
+ if host is None:
+ host = os.environ.get("PYCXPRESS_DEBUGGER_HOST", "localhost")
+
+ if port is None:
+ port = int(os.environ.get("PYCXPRESS_DEBUGGER_PORT", "5678"))
+
+ if debugger.lower() == "pycharm":
+ try:
+ import pydevd_pycharm
+
+ pydevd_pycharm.settrace(
+ host, port=port, stdoutToServer=True, stderrToServer=True, suspend=True
+ )
+ _debugger_status_[0] = True
+
+ except ConnectionRefusedError:
+ logger.warning(
+ "Can not connect to Python debug server (maybe not started?)"
+ )
+ logger.warning(
+ "Use PYCXPRESS_DEBUGGER_TYPE=debugpy instead as Pycharm professional edition is needed for Python debug server feature."
+ )
+ elif debugger.lower() == "debugpy":
+ import debugpy
+
+ _debugger_status_[0] = True
+
+ debugpy.listen((host, port))
+ logger.info(f"debugpy listen on {host}:{port}, please use VSCode to attach")
+ debugpy.wait_for_client()
+ else:
+ logger.warning(
+ f"Only PYCXPRESS_DEBUGGER_TYPE=debugpy|pycharm supported but {debugger} provided"
+ )
diff --git a/src/PyCXpress/example/Makefile b/src/PyCXpress/example/Makefile
index 0fde39c..06aa78b 100644
--- a/src/PyCXpress/example/Makefile
+++ b/src/PyCXpress/example/Makefile
@@ -1,10 +1,12 @@
# Compiler
CC = c++
+PYTHONPATH=../../
+LD_PRELOAD:=$(shell find_libpython)
# Compiler flags
CFLAGS = -g -Wall -std=c++17 -fPIC
CFLAGS += $(shell python3-config --cflags --ldflags --embed)
-CFLAGS += $(shell PYTHONPATH=../../ python3 -m PyCXpress --includes)
+CFLAGS += $(shell PYTHONPATH=$(PYTHONPATH) python3 -m PyCXpress --includes)
# The build target executable
TARGET = example.out
@@ -36,7 +38,8 @@ clean:
rm -f $(OBJECTS) $(DEPENDS) $(TARGET)
run: $(TARGET)
- PYTHONPATH=../src/ ./$(TARGET)
+ $(info [reminding]: use LD_PRELOAD to load libpython before numpy import)
+ PYTHONPATH=$(PYTHONPATH) LD_PRELOAD=$(LD_PRELOAD) ./$(TARGET)
memcheck: $(TARGET)
valgrind --leak-check=full --show-leak-kinds=all --track-origins=yes -s ./$(TARGET)
diff --git a/src/PyCXpress/example/main.cpp b/src/PyCXpress/example/main.cpp
index a336c57..82a9fd3 100644
--- a/src/PyCXpress/example/main.cpp
+++ b/src/PyCXpress/example/main.cpp
@@ -8,7 +8,7 @@
namespace pcx = PyCXpress;
-void show_test(pcx::PythonInterpreter &python) {
+void show_test(pcx::Model &python) {
std::vector data(12);
for (size_t i = 0; i < 12; i++) {
data[i] = i;
@@ -41,12 +41,19 @@ void show_test(pcx::PythonInterpreter &python) {
int main(int argc, char *argv[]) {
auto &python = utils::Singleton::Instance();
+ auto &model0 = python.create_model("model.Model");
+ auto &model1 = python.create_model("model.Model", "odd");
int loop_times = 3;
+
while (loop_times--) {
std::cout << "looping " << loop_times << std::endl;
- show_test(python);
+ if (loop_times % 2 == 0) {
+ show_test(model0);
+ } else {
+ show_test(model1);
+ }
}
return 0;
-}
\ No newline at end of file
+}
diff --git a/src/PyCXpress/example/model.py b/src/PyCXpress/example/model.py
index e6f48b5..79b6c4e 100644
--- a/src/PyCXpress/example/model.py
+++ b/src/PyCXpress/example/model.py
@@ -1,4 +1,4 @@
-# mypy: disable_error_code="type-arg,attr-defined"
+# mypy: disable_error_code="arg-type,type-arg,attr-defined"
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
@@ -69,30 +69,36 @@ class OutputDataSet(
pass
-def init():
- return (
- InputDataSet(),
- OutputDataSet(),
- tuple(
- (
- *convert_to_spec_tuple(InputFields.values()),
- *convert_to_spec_tuple(OutputFields.values()),
- )
- ),
- tuple(OutputFields.keys()),
- )
+class Model:
+ def __init__(self):
+ self.input = None
+ self.output = None
+ def initialize(self):
+ self.input, self.output = InputDataSet(), OutputDataSet()
-def model(input: InputDataSet, output: OutputDataSet):
- with nullcontext():
- # print(input.data_to_be_reshaped)
- # print(input.new_2d_shape)
- output.output_a = input.data_to_be_reshaped.reshape(input.new_2d_shape)
- # print(output.output_a)
+ return (
+ self.input,
+ self.output,
+ tuple(convert_to_spec_tuple(InputFields.values(), OutputFields.values())),
+ )
+
+ def run(self):
+ self.model(self.input, self.output)
+
+ @staticmethod
+ def model(input: InputDataSet, output: OutputDataSet):
+ with nullcontext():
+ # print(input.data_to_be_reshaped)
+ # print(input.new_2d_shape)
+ output.output_a = input.data_to_be_reshaped.reshape(input.new_2d_shape)
+ # print(output.output_a)
def main():
- input_data, output_data, spec, _ = init()
+
+ model = Model()
+ input_data, output_data, spec = model.initialize()
print(spec)
input_data.set_buffer_value("data_to_be_reshaped", np.arange(12, dtype=np.float_))
@@ -101,7 +107,7 @@ def main():
print(input_data.new_2d_shape)
output_data.set_buffer_value("output_a", np.arange(12) * 0)
- model(input_data, output_data)
+ model.run()
print(output_data.output_a)
print(output_data.get_buffer_shape("output_a"))
diff --git a/src/PyCXpress/include/PyCXpress/core.hpp b/src/PyCXpress/include/PyCXpress/core.hpp
index 52ca3a5..531d5b3 100644
--- a/src/PyCXpress/include/PyCXpress/core.hpp
+++ b/src/PyCXpress/include/PyCXpress/core.hpp
@@ -28,13 +28,17 @@ class PYCXPRESS_EXPORT Buffer {
typedef unsigned char Bytes;
template
- static py::array __to_array(const std::vector &shape, void *data) {
+ static py::array __to_array(const std::vector &shape, void *data,
+ size_t max_size) {
std::vector stride(shape.size());
*stride.rbegin() = sizeof(T);
auto ps = shape.rbegin();
for (auto pt = stride.rbegin() + 1; pt != stride.rend(); pt++, ps++) {
*pt = *(pt - 1) * (*ps);
}
+ if (max_size < stride.front() * shape.front()) {
+ throw std::runtime_error("Buffer size is too small");
+ }
return py::array_t{shape, std::move(stride), (T *)(data),
py::none()};
}
@@ -99,49 +103,68 @@ class PYCXPRESS_EXPORT Buffer {
}
void *set(const std::vector &shape) {
- m_array = m_converter(shape, m_data);
+ m_array = m_converter(shape, m_data, m_size);
return m_data;
}
- py::array &get() { return m_array; }
+ inline size_t itemsize() const { return m_size / m_length; }
+
+ py::array &array() { return m_array; }
- void reset() { m_array = m_converter({m_length}, m_data); }
+ void reset() { m_array = m_converter({m_length}, m_data, m_size); }
private:
size_t m_size;
size_t m_length;
Bytes *m_data;
py::array m_array;
- py::array (*m_converter)(const std::vector &, void *);
+ py::array (*m_converter)(const std::vector &, void *, size_t);
};
-class PYCXPRESS_EXPORT PythonInterpreter {
+class PYCXPRESS_EXPORT Model {
public:
- explicit PythonInterpreter(bool init_signal_handlers = true, int argc = 0,
- const char *const *argv = nullptr,
- bool add_program_dir_to_path = true) {
- initialize(init_signal_handlers, argc, argv, add_program_dir_to_path);
+ explicit Model(const std::string &path) {
+ std::vector module_name(path.data(), path.data() + path.length());
+ if (module_name.empty() || module_name.back() == '.') {
+ throw std::runtime_error("No model class provided");
+ }
+ auto iter = module_name.rbegin();
+ while (iter + 1 != module_name.rend() && '.' != *iter) {
+ ++iter;
+ }
+ if (iter + 1 == module_name.rend()) {
+ throw std::runtime_error("not module provided");
+ }
+ auto ith = std::distance(iter, module_name.rend());
+ module_name[ith - 1] = 0;
+ module_name.push_back('\0');
+ initialize(module_name.data(), module_name.data() + ith);
}
- PythonInterpreter(const PythonInterpreter &) = delete;
- PythonInterpreter(PythonInterpreter &&other) noexcept {
- other.is_valid = false;
+ Model(const Model &) = delete;
+ Model(Model &&) = delete;
+ Model &operator=(const Model &) = delete;
+ Model &operator=(Model &&) = delete;
+
+ ~Model() {
+ m_buffers.clear();
+ m_output_buffer_sizes.clear();
+ m_model = py::none();
+ m_input = py::none();
+ m_output = py::none();
}
- PythonInterpreter &operator=(const PythonInterpreter &) = delete;
- PythonInterpreter &operator=(PythonInterpreter &&) = delete;
- ~PythonInterpreter() { finalize(); }
void *set_buffer(const std::string &name,
const std::vector &shape) {
auto &buf = m_buffers[name];
void *p = buf.set(shape);
- m_py_input.attr("set_buffer_value")(name, buf.get());
+ m_input.attr("set_buffer_value")(name, buf.array());
return p;
}
std::pair> get_buffer(const std::string &name) {
- auto &array = m_buffers[name].get();
+ auto &array = m_buffers[name].array();
auto pShape = m_output_buffer_sizes.find(name);
if (pShape == m_output_buffer_sizes.end()) {
return std::make_pair(
@@ -154,11 +177,14 @@ class PYCXPRESS_EXPORT PythonInterpreter {
}
void run() {
- p_pkg->attr("model")(m_py_input, m_py_output);
+ m_model.attr("run")();
+
+ auto get_buffer_shape = m_output.attr("get_buffer_shape");
+
for (auto &kv : m_output_buffer_sizes) {
kv.second.clear();
- py::tuple shape = m_py_output.attr("get_buffer_shape")(kv.first);
+ py::tuple shape = get_buffer_shape(kv.first);
for (auto &d : shape) {
kv.second.push_back(d.cast());
@@ -166,45 +192,64 @@ class PYCXPRESS_EXPORT PythonInterpreter {
}
}
- void show_buffer(const std::string &name) {
- auto &buf = m_buffers[name];
- p_pkg->attr("show")(buf.get());
- }
-
private:
- void initialize(bool init_signal_handlers, int argc,
- const char *const *argv, bool add_program_dir_to_path) {
- py::initialize_interpreter(true, 0, nullptr, true);
-
- p_pkg = std::make_unique(py::module_::import("model"));
- py::print(p_pkg->attr("__file__"));
+ void initialize(const char *module, const char *name) {
+ m_model = py::module_::import(module).attr(name)();
- py::tuple spec, output_fields;
- std::tie(m_py_input, m_py_output, spec, output_fields) =
- p_pkg->attr("init")()
- .cast<
- std::tuple>();
+ py::tuple spec;
+ std::tie(m_input, m_output, spec) =
+ m_model.attr("initialize")()
+ .cast>();
+ auto set_buffer_value = m_output.attr("set_buffer_value");
for (auto d = spec.begin(); d != spec.end(); d++) {
- auto meta = d->cast();
- m_buffers.insert(std::make_pair(
- meta[0].cast(),
- Buffer{meta[2].cast(), meta[1].cast()}));
+ auto meta = d->cast();
+ const auto name = meta[0].cast();
+ auto buf =
+ m_buffers.insert({name, Buffer{meta[2].cast(),
+ meta[1].cast()}});
+ if (meta[3].cast()) {
+ m_output_buffer_sizes[name] = {};
+ auto &buffer = buf.first->second;
+ buffer.reset();
+ set_buffer_value(name, buffer.array());
+ }
}
+ }
- for (auto d = output_fields.begin(); d != output_fields.end(); d++) {
- const auto name = d->cast();
- m_output_buffer_sizes[name] = {};
- auto &buf = m_buffers[name];
- buf.reset();
- m_py_output.attr("set_buffer_value")(name, buf.get());
- }
+
+ std::map m_buffers;
+ std::map> m_output_buffer_sizes;
+
+ py::object m_model;
+ py::object m_input;
+ py::object m_output;
+};
+
+class PYCXPRESS_EXPORT PythonInterpreter {
+public:
+ explicit PythonInterpreter() {}
+
+ PythonInterpreter(const PythonInterpreter &) = delete;
+ PythonInterpreter(PythonInterpreter &&other) noexcept {
+ other.is_valid = false;
}
+ PythonInterpreter &operator=(const PythonInterpreter &) = delete;
+ PythonInterpreter &operator=(PythonInterpreter &&) = delete;
+ ~PythonInterpreter() { finalize(); }
+
+ void initialize(bool init_signal_handlers = true, int argc = 0,
+ const char *const *argv = nullptr,
+ bool add_program_dir_to_path = true) {
+ // TODO: maybe explicitly `dlopen("/path/to/libpython3.x.so", RTLD_NOW |
+ // RTLD_GLOBAL)` to avoid numpy import error
+ py::initialize_interpreter(init_signal_handlers, argc, argv,
+ add_program_dir_to_path);
+ is_valid = true;
+ }
void finalize() {
- p_pkg = nullptr;
- m_py_input = py::none();
- m_py_output = py::none();
+ m_models.clear();
if (is_valid) {
py::finalize_interpreter();
@@ -212,14 +257,23 @@ class PYCXPRESS_EXPORT PythonInterpreter {
}
}
- bool is_valid = true;
- std::unique_ptr p_pkg;
-
- std::map m_buffers;
- std::map> m_output_buffer_sizes;
+ Model &create_model(const std::string &path,
+ const std::string &name = "default") {
+ if (!is_valid) {
+ initialize();
+ }
+ if (m_models.find(name) == m_models.end()) {
+ m_models[name] = std::make_unique(path);
+ } else {
+ std::cerr << "Warning: Model with name " << name
+ << " already exists" << std::endl;
+ }
+ return *m_models[name].get();
+ }
- py::object m_py_input;
- py::object m_py_output;
+private:
+ bool is_valid = false;
+ std::map> m_models;
};
}; // namespace PyCXpress
diff --git a/src/PyCXpress/interface.py b/src/PyCXpress/interface.py
new file mode 100644
index 0000000..634a098
--- /dev/null
+++ b/src/PyCXpress/interface.py
@@ -0,0 +1,48 @@
+# mypy: disable_error_code="type-arg"
+from typing import List, Protocol, Tuple
+
+from abc import abstractmethod
+
+import numpy as np
+
+
+class InputTensorProtocol(Protocol):
+ def __init__(self, *args, **kwargs): ...
+
+ @abstractmethod
+ def set_buffer_value(self, name: str, value: np.ndarray) -> None: ...
+
+
+class OutputTensorProtocol(Protocol):
+ def __init__(self, *args, **kwargs): ...
+
+ @abstractmethod
+ def get_buffer_shape(self, name: str) -> Tuple[int]: ...
+
+ @abstractmethod
+ def set_buffer_value(self, name: str, value: np.ndarray) -> None: ...
+
+
+TensorBufferProtocol = Tuple[
+ str, # name
+ str, # dtype
+ int, # buffer size
+ bool, # is_output
+]
+
+
+class ModelProtocol(Protocol):
+ def __init__(self, *args, **kwargs): ...
+
+ @abstractmethod
+ def initialize(
+ self,
+ ) -> Tuple[
+ InputTensorProtocol, OutputTensorProtocol, Tuple[TensorBufferProtocol]
+ ]: ...
+
+ @abstractmethod
+ def run(self) -> None: ...
+
+ @abstractmethod
+ def reset(self) -> None: ...
diff --git a/src/PyCXpress/utils.py b/src/PyCXpress/utils.py
new file mode 100644
index 0000000..590a880
--- /dev/null
+++ b/src/PyCXpress/utils.py
@@ -0,0 +1,26 @@
+from typing import Tuple
+
+import logging
+
+import numpy as np
+from numpy.typing import DTypeLike
+
+logger = logging.getLogger("PyCXpress")
+
+
+def get_c_type(t: DTypeLike) -> Tuple[str, int]:
+ dtype = np.dtype(t)
+ relation = {
+ np.dtype("bool"): "bool",
+ np.dtype("int8"): "int8_t",
+ np.dtype("int16"): "int16_t",
+ np.dtype("int32"): "int32_t",
+ np.dtype("int64"): "int64_t",
+ np.dtype("uint8"): "uint8_t",
+ np.dtype("uint16"): "uint16_t",
+ np.dtype("uint32"): "uint32_t",
+ np.dtype("uint64"): "uint64_t",
+ np.dtype("float32"): "float",
+ np.dtype("float64"): "double",
+ }
+ return relation.get(dtype, "char"), dtype.itemsize or 1