diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 7f8742f..63ae3dd 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -7,36 +7,36 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
- python-version: ["3.8", "3.9"]
+ python-version: ["3.8", "3.9", "3.10", "3.11"]
steps:
- - uses: actions/checkout@v4
- - name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v5.1.0
- with:
- python-version: ${{ matrix.python-version }}
-
- - name: Install poetry
- run: make poetry-download
-
- - name: Set up cache
- uses: actions/cache@v4.0.2
- with:
- path: .venv
- key: venv-${{ matrix.python-version }}-${{ hashFiles('pyproject.toml') }}-${{ hashFiles('poetry.lock') }}
- - name: Install dependencies
- run: |
- poetry config virtualenvs.in-project true
- poetry install
-
- - name: Run style checks
- run: |
- make check-codestyle
-
- - name: Run tests
- run: |
- make test
-
- - name: Run safety checks
- run: |
- make check-safety
+ - uses: actions/checkout@v4
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v5.1.0
+ with:
+ python-version: ${{ matrix.python-version }}
+
+ - name: Install poetry
+ run: make poetry-download
+
+ - name: Set up cache
+ uses: actions/cache@v4.0.2
+ with:
+ path: .venv
+ key: venv-${{ matrix.python-version }}-${{ hashFiles('pyproject.toml') }}-${{ hashFiles('poetry.lock') }}
+ - name: Install dependencies
+ run: |
+ poetry config virtualenvs.in-project true
+ poetry install
+
+ - name: Run style checks
+ run: |
+ make check-codestyle
+
+ - name: Run tests
+ run: |
+ make test
+
+ - name: Run safety checks
+ run: |
+ make check-safety
diff --git a/.gitignore b/.gitignore
index f089344..4f2ee93 100644
--- a/.gitignore
+++ b/.gitignore
@@ -34,3 +34,4 @@
# Python
dist/
__pycache__/
+.coverage
diff --git a/Makefile b/Makefile
index a7fb1eb..f378543 100644
--- a/Makefile
+++ b/Makefile
@@ -53,7 +53,7 @@ pre-commit-install:
#* Formatters
.PHONY: codestyle
codestyle:
- poetry run pyupgrade --exit-zero-even-if-changed --py38-plus **/*.py
+ poetry run pyupgrade --exit-zero-even-if-changed --py38-plus src/**/*.py
poetry run isort --settings-path pyproject.toml ./
poetry run black --config pyproject.toml ./
@@ -83,8 +83,8 @@ mypy:
.PHONY: check-safety
check-safety:
poetry check
- poetry run safety check --full-report
- poetry run bandit -ll --recursive PyCXpress tests
+ -poetry run safety check --full-report
+ poetry run bandit -ll --recursive src/PyCXpress tests
.PHONY: lint
lint: test check-codestyle mypy check-safety
diff --git a/assets/images/coverage.svg b/assets/images/coverage.svg
index 0644a48..53e7fcb 100644
--- a/assets/images/coverage.svg
+++ b/assets/images/coverage.svg
@@ -9,13 +9,13 @@
-
+
coverage
coverage
- 26%
- 26%
+ 40%
+ 40%
diff --git a/poetry.lock b/poetry.lock
index 30d7fa6..981f7f3 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -112,6 +112,31 @@ d = ["aiohttp (>=3.7.4)", "aiohttp (>=3.7.4,!=3.9.0)"]
jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"]
uvloop = ["uvloop (>=0.15.2)"]
+[[package]]
+name = "build"
+version = "1.2.1"
+description = "A simple, correct Python build frontend"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "build-1.2.1-py3-none-any.whl", hash = "sha256:75e10f767a433d9a86e50d83f418e83efc18ede923ee5ff7df93b6cb0306c5d4"},
+ {file = "build-1.2.1.tar.gz", hash = "sha256:526263f4870c26f26c433545579475377b2b7588b6f1eac76a001e873ae3e19d"},
+]
+
+[package.dependencies]
+colorama = {version = "*", markers = "os_name == \"nt\""}
+importlib-metadata = {version = ">=4.6", markers = "python_full_version < \"3.10.2\""}
+packaging = ">=19.1"
+pyproject_hooks = "*"
+tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""}
+
+[package.extras]
+docs = ["furo (>=2023.08.17)", "sphinx (>=7.0,<8.0)", "sphinx-argparse-cli (>=1.5)", "sphinx-autodoc-typehints (>=1.10)", "sphinx-issues (>=3.0.0)"]
+test = ["build[uv,virtualenv]", "filelock (>=3)", "pytest (>=6.2.4)", "pytest-cov (>=2.12)", "pytest-mock (>=2)", "pytest-rerunfailures (>=9.1)", "pytest-xdist (>=1.34)", "setuptools (>=42.0.0)", "setuptools (>=56.0.0)", "setuptools (>=56.0.0)", "setuptools (>=67.8.0)", "wheel (>=0.36.0)"]
+typing = ["build[uv]", "importlib-metadata (>=5.1)", "mypy (>=1.9.0,<1.10.0)", "tomli", "typing-extensions (>=3.7.4.3)"]
+uv = ["uv (>=0.1.18)"]
+virtualenv = ["virtualenv (>=20.0.35)"]
+
[[package]]
name = "certifi"
version = "2024.2.2"
@@ -570,6 +595,25 @@ files = [
{file = "idna-3.7.tar.gz", hash = "sha256:028ff3aadf0609c1fd278d8ea3089299412a7a8b9bd005dd08b9f8285bcb5cfc"},
]
+[[package]]
+name = "importlib-metadata"
+version = "7.1.0"
+description = "Read metadata from Python packages"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "importlib_metadata-7.1.0-py3-none-any.whl", hash = "sha256:30962b96c0c223483ed6cc7280e7f0199feb01a0e40cfae4d4450fc6fab1f570"},
+ {file = "importlib_metadata-7.1.0.tar.gz", hash = "sha256:b78938b926ee8d5f020fc4772d487045805a55ddbad2ecf21c6d60938dc7fcd2"},
+]
+
+[package.dependencies]
+zipp = ">=0.5"
+
+[package.extras]
+docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
+perf = ["ipython"]
+testing = ["flufl.flake8", "importlib-resources (>=1.3)", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"]
+
[[package]]
name = "iniconfig"
version = "2.0.0"
@@ -1145,6 +1189,20 @@ typing-extensions = {version = ">=3.10.0", markers = "python_version < \"3.10\""
spelling = ["pyenchant (>=3.2,<4.0)"]
testutils = ["gitpython (>3)"]
+[[package]]
+name = "pyproject-hooks"
+version = "1.0.0"
+description = "Wrappers to call pyproject.toml-based build backend hooks."
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "pyproject_hooks-1.0.0-py3-none-any.whl", hash = "sha256:283c11acd6b928d2f6a7c73fa0d01cb2bdc5f07c57a2eeb6e83d5e56b97976f8"},
+ {file = "pyproject_hooks-1.0.0.tar.gz", hash = "sha256:f271b298b97f5955d53fb12b72c1fb1948c22c1a6b70b315c54cedaca0264ef5"},
+]
+
+[package.dependencies]
+tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""}
+
[[package]]
name = "pytest"
version = "8.1.1"
@@ -1592,13 +1650,13 @@ zstd = ["zstandard (>=0.18.0)"]
[[package]]
name = "virtualenv"
-version = "20.25.1"
+version = "20.25.2"
description = "Virtual Python Environment builder"
optional = false
python-versions = ">=3.7"
files = [
- {file = "virtualenv-20.25.1-py3-none-any.whl", hash = "sha256:961c026ac520bac5f69acb8ea063e8a4f071bcc9457b9c1f28f6b085c511583a"},
- {file = "virtualenv-20.25.1.tar.gz", hash = "sha256:e08e13ecdca7a0bd53798f356d5831434afa5b07b93f0abdf0797b7a06ffe197"},
+ {file = "virtualenv-20.25.2-py3-none-any.whl", hash = "sha256:6e1281a57849c8a54da89ba82e5eb7c8937b9d057ff01aaf5bc9afaa3552e90f"},
+ {file = "virtualenv-20.25.2.tar.gz", hash = "sha256:fa7edb8428620518010928242ec17aa7132ae435319c29c1651d1cf4c4173aad"},
]
[package.dependencies]
@@ -1607,10 +1665,25 @@ filelock = ">=3.12.2,<4"
platformdirs = ">=3.9.1,<5"
[package.extras]
-docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.2)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"]
+docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.2,!=7.3)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"]
test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8)", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10)"]
+[[package]]
+name = "zipp"
+version = "3.18.1"
+description = "Backport of pathlib-compatible object wrapper for zip files"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "zipp-3.18.1-py3-none-any.whl", hash = "sha256:206f5a15f2af3dbaee80769fb7dc6f249695e940acca08dfb2a4769fe61e538b"},
+ {file = "zipp-3.18.1.tar.gz", hash = "sha256:2884ed22e7d8961de1c9a05142eb69a247f120291bc0206a00a7642f09b5b715"},
+]
+
+[package.extras]
+docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
+testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"]
+
[metadata]
lock-version = "2.0"
python-versions = "^3.8"
-content-hash = "8193a02a1aa5dee0c7e18950f25b6cfcc9ed8b84f70b511a236c220d242feb40"
+content-hash = "17346d0057aa8cd3c4ddd4f4ae5c53402353e2038de2e110ec61eca54db7ffee"
diff --git a/pyproject.toml b/pyproject.toml
index d24d6ef..1d1dcd3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -46,6 +46,7 @@ numpy = "^1.22"
bandit = "^1.7.1"
black = {version = "^24.4", allow-prereleases = true}
darglint = "^1.8.1"
+build = "^1.2.1"
isort = {extras = ["colors"], version = "^5.10.1"}
mypy = "^1.0"
mypy-extensions = "^0.4.3"
@@ -121,6 +122,8 @@ warn_unreachable = true
warn_unused_configs = true
warn_unused_ignores = true
+disable_error_code = ["no-untyped-def"]
+
[tool.pytest.ini_options]
# https://docs.pytest.org/en/6.2.x/customize.html#pyproject-toml
@@ -147,4 +150,4 @@ branch = true
[coverage.report]
fail_under = 50
-show_missing = true
\ No newline at end of file
+show_missing = true
diff --git a/src/PyCXpress/__init__.py b/src/PyCXpress/__init__.py
index 131685a..140adf7 100644
--- a/src/PyCXpress/__init__.py
+++ b/src/PyCXpress/__init__.py
@@ -1,4 +1,3 @@
-# type: ignore[attr-defined]
"""PyCXpress is a high-performance hybrid framework that seamlessly integrates Python and C++ to harness the flexibility of Python and the speed of C++ for efficient and expressive computation, particularly in the realm of deep learning and numerical computing."""
__all__ = [
@@ -12,13 +11,9 @@
]
import sys
+from importlib import metadata as importlib_metadata
from pathlib import Path
-if sys.version_info >= (3, 8):
- from importlib import metadata as importlib_metadata
-else:
- import importlib_metadata
-
def get_version() -> str:
try:
@@ -30,9 +25,14 @@ def get_version() -> str:
version: str = get_version()
-from .core import TensorMeta, ModelAnnotationCreator, ModelAnnotationType, ModelRuntimeType
-from .core import convert_to_spec_tuple
+from .core import (
+ ModelAnnotationCreator,
+ ModelAnnotationType,
+ ModelRuntimeType,
+ TensorMeta,
+ convert_to_spec_tuple,
+)
def get_include() -> str:
- return str(Path(__file__).parent.absolute()/"include")
\ No newline at end of file
+ return str(Path(__file__).parent.absolute() / "include")
diff --git a/src/PyCXpress/__main__.py b/src/PyCXpress/__main__.py
index ec7a1bc..5327f33 100644
--- a/src/PyCXpress/__main__.py
+++ b/src/PyCXpress/__main__.py
@@ -1,21 +1,21 @@
-# type: ignore[attr-defined]
# pylint: disable=missing-function-docstring
import argparse
import sys
import sysconfig
-from PyCXpress import version, get_include
from pybind11 import get_include as pybind11_include
+from PyCXpress import get_include, version
+
def print_includes() -> None:
- dirs = set([
+ dirs = {
sysconfig.get_path("include"),
sysconfig.get_path("platinclude"),
pybind11_include(),
get_include(),
- ])
+ }
print(" ".join(f"-I {d}" for d in dirs))
@@ -39,5 +39,6 @@ def main() -> None:
if args.includes:
print_includes()
+
if __name__ == "__main__":
main()
diff --git a/src/PyCXpress/core.py b/src/PyCXpress/core.py
index 03d6f50..8b07812 100644
--- a/src/PyCXpress/core.py
+++ b/src/PyCXpress/core.py
@@ -1,44 +1,49 @@
+# mypy: disable_error_code="type-arg,arg-type,union-attr,operator,assignment,misc"
import logging
logger = logging.getLogger(__name__)
-import numpy as np
-# import tensorflow as tf
-
-from typing import List, Tuple, Iterable, Callable, Dict, Optional, Union
+from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
from collections import namedtuple
from dataclasses import dataclass
+from enum import Enum, auto
+
import numpy as np
from numpy.typing import DTypeLike
-from enum import Enum, auto
+
+# import tensorflow as tf
-def get_c_type(t: DTypeLike) -> (str, int):
+def get_c_type(t: DTypeLike) -> Tuple[str, int]:
dtype = np.dtype(t)
- relation = {np.dtype('bool'): 'bool',
- np.dtype('int8'): 'int8_t',
- np.dtype('int16'): 'int16_t',
- np.dtype('int32'): 'int32_t',
- np.dtype('int64'): 'int64_t',
- np.dtype('uint8'): 'uint8_t',
- np.dtype('uint16'): 'uint16_t',
- np.dtype('uint32'): 'uint32_t',
- np.dtype('uint64'): 'uint64_t',
- np.dtype('float32'): 'float',
- np.dtype('float64'): 'double'}
+ relation = {
+ np.dtype("bool"): "bool",
+ np.dtype("int8"): "int8_t",
+ np.dtype("int16"): "int16_t",
+ np.dtype("int32"): "int32_t",
+ np.dtype("int64"): "int64_t",
+ np.dtype("uint8"): "uint8_t",
+ np.dtype("uint16"): "uint16_t",
+ np.dtype("uint32"): "uint32_t",
+ np.dtype("uint64"): "uint64_t",
+ np.dtype("float32"): "float",
+ np.dtype("float64"): "double",
+ }
return relation.get(dtype, "char"), dtype.itemsize or 1
@dataclass
class TensorMeta:
dtype: DTypeLike # the data type similar to np.int_
- shape: Union[int, Iterable[int], Callable[..., Union[int, Iterable[int]]]] # the maximal size of each dimension
+ shape: Union[
+ int, Iterable[int], Callable[..., Union[int, Iterable[int]]]
+ ] # the maximal size of each dimension
name: Optional[str] = None
doc: Optional[str] = None
- def to_dict(self, *args, **kwargs) -> dict:
+ def to_dict(self, *args, **kwargs) -> Dict:
assert self.name is not None
max_size = self.shape
@@ -51,12 +56,12 @@ def to_dict(self, *args, **kwargs) -> dict:
return {
"name": self.name,
"dtype": dtype,
- "shape": tuple(round(-i) if i<0 else None for i in max_size),
+ "shape": tuple(round(-i) if i < 0 else None for i in max_size),
"buffer_size": np.prod([round(abs(i)) for i in max_size]) * itemsize,
- "doc": f"" if self.doc is None else self.doc
+ "doc": f"" if self.doc is None else self.doc,
}
- def setdefault(self, name: str):
+ def setdefault(self, name: str) -> str:
if self.name is None:
self.name = name
return self.name
@@ -68,18 +73,29 @@ class ModelAnnotationType(Enum):
Operator = auto()
HyperParams = auto()
+
class ModelRuntimeType(Enum):
GraphExecution = auto()
EagerExecution = auto()
OfflineExecution = auto()
+
@dataclass
class TensorWithShape:
- data: Optional[np.array] = None
+ data: Optional[np.ndarray] = None
shape: Optional[Tuple] = None
+
class ModelAnnotationCreator(type):
- def __new__(mcs, name, bases, attrs, fields: Dict[str, TensorMeta], type: ModelAnnotationType, mode: ModelRuntimeType):
+ def __new__(
+ mcs,
+ name,
+ bases,
+ attrs,
+ fields: Dict[str, TensorMeta],
+ type: ModelAnnotationType,
+ mode: ModelRuntimeType,
+ ):
if type == ModelAnnotationType.Input:
generate_property = mcs.generate_input_property
elif type == ModelAnnotationType.Output:
@@ -91,8 +107,9 @@ def __new__(mcs, name, bases, attrs, fields: Dict[str, TensorMeta], type: ModelA
field_meta.setdefault(field_name)
attrs[field_name] = generate_property(field_meta)
- get_buffer_shape, set_buffer_value, init_func = mcs.general_funcs(name, [field_meta.name for field_meta in fields.values()])
-
+ get_buffer_shape, set_buffer_value, init_func = mcs.general_funcs(
+ name, [field_meta.name for field_meta in fields.values()]
+ )
attrs["__init__"] = init_func
attrs["set_buffer_value"] = set_buffer_value
@@ -100,11 +117,10 @@ def __new__(mcs, name, bases, attrs, fields: Dict[str, TensorMeta], type: ModelA
attrs["get_buffer_shape"] = get_buffer_shape
attrs.setdefault("__slots__", []).append("__buffer_data__")
-
return super().__new__(mcs, name, bases, attrs)
@staticmethod
- def general_funcs(name: str, field_names: list):
+ def general_funcs(name: str, field_names: List[str]):
def get_buffer_shape(self, name: str):
buffer = getattr(self.__buffer_data__, name)
return buffer.shape
@@ -115,7 +131,9 @@ def set_buffer_value(self, name: str, value):
def init_func(self):
_BufferData_ = namedtuple("_BufferData_", field_names)
- self.__buffer_data__ = _BufferData_(*tuple(TensorWithShape() for _ in field_names))
+ self.__buffer_data__ = _BufferData_(
+ *tuple(TensorWithShape() for _ in field_names)
+ )
return get_buffer_shape, set_buffer_value, init_func
@@ -137,24 +155,28 @@ def generate_output_property(field: TensorMeta):
def get_func(self):
logger.warning(f"Only read the data field {field.name} in debugging mode")
buffer = getattr(self.__buffer_data__, field.name)
- return buffer.data[:np.prod(buffer.shape)].reshape(buffer.shape)
+ return buffer.data[: np.prod(buffer.shape)].reshape(buffer.shape)
def set_func(self, data):
buffer = getattr(self.__buffer_data__, field.name)
buffer.shape = data.shape
- buffer.data[:np.prod(data.shape)] = data.flatten()
+ buffer.data[: np.prod(data.shape)] = data.flatten()
def del_func(_):
raise AssertionError("Not supported for output tensor")
return property(fget=get_func, fset=set_func, fdel=del_func, doc=field.doc)
+
def convert_to_spec_tuple(fields: Iterable[TensorMeta]):
- return tuple((v["name"], v["dtype"], v["buffer_size"]) for v in [v.to_dict() for v in fields])
+ return tuple(
+ (v["name"], v["dtype"], v["buffer_size"]) for v in [v.to_dict() for v in fields]
+ )
def main():
pass
+
if __name__ == "__main__":
main()
diff --git a/src/PyCXpress/example/model.py b/src/PyCXpress/example/model.py
index 8945554..6db9563 100644
--- a/src/PyCXpress/example/model.py
+++ b/src/PyCXpress/example/model.py
@@ -1,48 +1,87 @@
+# mypy: disable_error_code="type-arg,attr-defined"
import os
-os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
+
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import logging
+
logging.basicConfig(level=logging.DEBUG)
-from pathlib import Path
import sys
+from pathlib import Path
+
import numpy as np
-sys.path.append(str(Path(__file__).parent/".."/"src"/"python"))
-from PyCXpress import TensorMeta, ModelAnnotationCreator, ModelAnnotationType, ModelRuntimeType
-from PyCXpress import convert_to_spec_tuple
+sys.path.append(str(Path(__file__).parent / ".." / "src" / "python"))
+
from contextlib import nullcontext
-def show(a: np.array):
+from PyCXpress import (
+ ModelAnnotationCreator,
+ ModelAnnotationType,
+ ModelRuntimeType,
+ TensorMeta,
+ convert_to_spec_tuple,
+)
+
+
+def show(a: np.ndarray):
logging.info(f"array data type: {a.dtype}")
logging.info(f"array data shape: {a.shape}")
logging.info(f"array data: ")
logging.info(a)
+
InputFields = dict(
- data_to_be_reshaped=TensorMeta(dtype=np.float_,
- shape=(100,),
- ),
- new_2d_shape=TensorMeta(dtype=np.uint8,
- shape=-2,)
+ data_to_be_reshaped=TensorMeta(
+ dtype=np.float_,
+ shape=(100,),
+ ),
+ new_2d_shape=TensorMeta(
+ dtype=np.uint8,
+ shape=-2,
+ ),
)
-class InputDataSet(metaclass=ModelAnnotationCreator, fields=InputFields, type=ModelAnnotationType.Input, mode=ModelRuntimeType.EagerExecution):
+class InputDataSet(
+ metaclass=ModelAnnotationCreator,
+ fields=InputFields,
+ type=ModelAnnotationType.Input,
+ mode=ModelRuntimeType.EagerExecution,
+):
pass
OutputFields = dict(
- output_a=TensorMeta(dtype=np.float_,
- shape=(10, 10),),
+ output_a=TensorMeta(
+ dtype=np.float_,
+ shape=(10, 10),
+ ),
)
-class OutputDataSet(metaclass=ModelAnnotationCreator, fields=OutputFields, type=ModelAnnotationType.Output, mode=ModelRuntimeType.EagerExecution):
+class OutputDataSet(
+ metaclass=ModelAnnotationCreator,
+ fields=OutputFields,
+ type=ModelAnnotationType.Output,
+ mode=ModelRuntimeType.EagerExecution,
+):
pass
def init():
- return InputDataSet(), OutputDataSet(), tuple((*convert_to_spec_tuple(InputFields.values()), *convert_to_spec_tuple(OutputFields.values()))), tuple(OutputFields.keys())
+ return (
+ InputDataSet(),
+ OutputDataSet(),
+ tuple(
+ (
+ *convert_to_spec_tuple(InputFields.values()),
+ *convert_to_spec_tuple(OutputFields.values()),
+ )
+ ),
+ tuple(OutputFields.keys()),
+ )
+
def model(input: InputDataSet, output: OutputDataSet):
with nullcontext():
@@ -51,6 +90,7 @@ def model(input: InputDataSet, output: OutputDataSet):
output.output_a = input.data_to_be_reshaped.reshape(input.new_2d_shape)
# print(output.output_a)
+
def main():
input_data, output_data, spec = init()
print(spec)
@@ -59,11 +99,12 @@ def main():
print(input_data.data_to_be_reshaped)
input_data.set_buffer_value("new_2d_shape", np.array([3, 4]).astype(np.uint8))
print(input_data.new_2d_shape)
- output_data.set_buffer_value("output_a", np.arange(12)*0)
+ output_data.set_buffer_value("output_a", np.arange(12) * 0)
model(input_data, output_data)
print(output_data.output_a)
print(output_data.get_buffer_shape("output_a"))
+
if __name__ == "__main__":
main()
diff --git a/src/PyCXpress/include/PyCXpress/core.hpp b/src/PyCXpress/include/PyCXpress/core.hpp
index 3707853..52ca3a5 100644
--- a/src/PyCXpress/include/PyCXpress/core.hpp
+++ b/src/PyCXpress/include/PyCXpress/core.hpp
@@ -13,19 +13,18 @@
#include "utils.hpp"
#if !defined(PYCXPRESS_EXPORT)
-# if defined(WIN32) || defined(_WIN32)
-# define PYCXPRESS_EXPORT __declspec(dllexport)
-# else
-# define PYCXPRESS_EXPORT __attribute__((visibility("default")))
-# endif
-#endif
+#if defined(WIN32) || defined(_WIN32)
+#define PYCXPRESS_EXPORT __declspec(dllexport)
+#else
+#define PYCXPRESS_EXPORT __attribute__((visibility("default")))
+#endif
+#endif
namespace PyCXpress {
namespace py = pybind11;
using namespace utils;
-class PYCXPRESS_EXPORT
-Buffer {
+class PYCXPRESS_EXPORT Buffer {
typedef unsigned char Bytes;
template
@@ -116,8 +115,7 @@ Buffer {
py::array (*m_converter)(const std::vector &, void *);
};
-class PYCXPRESS_EXPORT
-PythonInterpreter {
+class PYCXPRESS_EXPORT PythonInterpreter {
public:
explicit PythonInterpreter(bool init_signal_handlers = true, int argc = 0,
const char *const *argv = nullptr,
@@ -226,4 +224,4 @@ PythonInterpreter {
}; // namespace PyCXpress
-#endif // __PYCXPRESS_HPP__
\ No newline at end of file
+#endif // __PYCXPRESS_HPP__
diff --git a/tests/basic/test_pkg.py b/tests/basic/test_pkg.py
new file mode 100644
index 0000000..21df992
--- /dev/null
+++ b/tests/basic/test_pkg.py
@@ -0,0 +1,16 @@
+from importlib import import_module
+from pathlib import Path
+
+import pytest
+
+from PyCXpress import get_include
+
+
+@pytest.mark.parametrize(
+ ("include_suffix",),
+ [
+ ("PyCXpress/include",),
+ ],
+)
+def test_get_include(include_suffix):
+ assert get_include().endswith(include_suffix)