From 05722475a3808df77fa222975463e003ee0c2246 Mon Sep 17 00:00:00 2001 From: Nick Wang Date: Wed, 17 Apr 2024 11:56:54 +0000 Subject: [PATCH] format; by pass mypy error for now; --- .github/workflows/build.yml | 62 ++++++++--------- .gitignore | 1 + Makefile | 6 +- assets/images/coverage.svg | 6 +- poetry.lock | 83 +++++++++++++++++++++-- pyproject.toml | 5 +- src/PyCXpress/__init__.py | 18 ++--- src/PyCXpress/__main__.py | 9 +-- src/PyCXpress/core.py | 86 +++++++++++++++--------- src/PyCXpress/example/model.py | 75 ++++++++++++++++----- src/PyCXpress/include/PyCXpress/core.hpp | 20 +++--- tests/basic/test_pkg.py | 16 +++++ 12 files changed, 271 insertions(+), 116 deletions(-) create mode 100644 tests/basic/test_pkg.py diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 7f8742f..63ae3dd 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -7,36 +7,36 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.8", "3.9"] + python-version: ["3.8", "3.9", "3.10", "3.11"] steps: - - uses: actions/checkout@v4 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5.1.0 - with: - python-version: ${{ matrix.python-version }} - - - name: Install poetry - run: make poetry-download - - - name: Set up cache - uses: actions/cache@v4.0.2 - with: - path: .venv - key: venv-${{ matrix.python-version }}-${{ hashFiles('pyproject.toml') }}-${{ hashFiles('poetry.lock') }} - - name: Install dependencies - run: | - poetry config virtualenvs.in-project true - poetry install - - - name: Run style checks - run: | - make check-codestyle - - - name: Run tests - run: | - make test - - - name: Run safety checks - run: | - make check-safety + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5.1.0 + with: + python-version: ${{ matrix.python-version }} + + - name: Install poetry + run: make poetry-download + + - name: Set up cache + uses: actions/cache@v4.0.2 + with: + path: .venv + key: venv-${{ matrix.python-version }}-${{ hashFiles('pyproject.toml') }}-${{ hashFiles('poetry.lock') }} + - name: Install dependencies + run: | + poetry config virtualenvs.in-project true + poetry install + + - name: Run style checks + run: | + make check-codestyle + + - name: Run tests + run: | + make test + + - name: Run safety checks + run: | + make check-safety diff --git a/.gitignore b/.gitignore index f089344..4f2ee93 100644 --- a/.gitignore +++ b/.gitignore @@ -34,3 +34,4 @@ # Python dist/ __pycache__/ +.coverage diff --git a/Makefile b/Makefile index a7fb1eb..f378543 100644 --- a/Makefile +++ b/Makefile @@ -53,7 +53,7 @@ pre-commit-install: #* Formatters .PHONY: codestyle codestyle: - poetry run pyupgrade --exit-zero-even-if-changed --py38-plus **/*.py + poetry run pyupgrade --exit-zero-even-if-changed --py38-plus src/**/*.py poetry run isort --settings-path pyproject.toml ./ poetry run black --config pyproject.toml ./ @@ -83,8 +83,8 @@ mypy: .PHONY: check-safety check-safety: poetry check - poetry run safety check --full-report - poetry run bandit -ll --recursive PyCXpress tests + -poetry run safety check --full-report + poetry run bandit -ll --recursive src/PyCXpress tests .PHONY: lint lint: test check-codestyle mypy check-safety diff --git a/assets/images/coverage.svg b/assets/images/coverage.svg index 0644a48..53e7fcb 100644 --- a/assets/images/coverage.svg +++ b/assets/images/coverage.svg @@ -9,13 +9,13 @@ - + coverage coverage - 26% - 26% + 40% + 40% diff --git a/poetry.lock b/poetry.lock index 30d7fa6..981f7f3 100644 --- a/poetry.lock +++ b/poetry.lock @@ -112,6 +112,31 @@ d = ["aiohttp (>=3.7.4)", "aiohttp (>=3.7.4,!=3.9.0)"] jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] uvloop = ["uvloop (>=0.15.2)"] +[[package]] +name = "build" +version = "1.2.1" +description = "A simple, correct Python build frontend" +optional = false +python-versions = ">=3.8" +files = [ + {file = "build-1.2.1-py3-none-any.whl", hash = "sha256:75e10f767a433d9a86e50d83f418e83efc18ede923ee5ff7df93b6cb0306c5d4"}, + {file = "build-1.2.1.tar.gz", hash = "sha256:526263f4870c26f26c433545579475377b2b7588b6f1eac76a001e873ae3e19d"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "os_name == \"nt\""} +importlib-metadata = {version = ">=4.6", markers = "python_full_version < \"3.10.2\""} +packaging = ">=19.1" +pyproject_hooks = "*" +tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} + +[package.extras] +docs = ["furo (>=2023.08.17)", "sphinx (>=7.0,<8.0)", "sphinx-argparse-cli (>=1.5)", "sphinx-autodoc-typehints (>=1.10)", "sphinx-issues (>=3.0.0)"] +test = ["build[uv,virtualenv]", "filelock (>=3)", "pytest (>=6.2.4)", "pytest-cov (>=2.12)", "pytest-mock (>=2)", "pytest-rerunfailures (>=9.1)", "pytest-xdist (>=1.34)", "setuptools (>=42.0.0)", "setuptools (>=56.0.0)", "setuptools (>=56.0.0)", "setuptools (>=67.8.0)", "wheel (>=0.36.0)"] +typing = ["build[uv]", "importlib-metadata (>=5.1)", "mypy (>=1.9.0,<1.10.0)", "tomli", "typing-extensions (>=3.7.4.3)"] +uv = ["uv (>=0.1.18)"] +virtualenv = ["virtualenv (>=20.0.35)"] + [[package]] name = "certifi" version = "2024.2.2" @@ -570,6 +595,25 @@ files = [ {file = "idna-3.7.tar.gz", hash = "sha256:028ff3aadf0609c1fd278d8ea3089299412a7a8b9bd005dd08b9f8285bcb5cfc"}, ] +[[package]] +name = "importlib-metadata" +version = "7.1.0" +description = "Read metadata from Python packages" +optional = false +python-versions = ">=3.8" +files = [ + {file = "importlib_metadata-7.1.0-py3-none-any.whl", hash = "sha256:30962b96c0c223483ed6cc7280e7f0199feb01a0e40cfae4d4450fc6fab1f570"}, + {file = "importlib_metadata-7.1.0.tar.gz", hash = "sha256:b78938b926ee8d5f020fc4772d487045805a55ddbad2ecf21c6d60938dc7fcd2"}, +] + +[package.dependencies] +zipp = ">=0.5" + +[package.extras] +docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] +perf = ["ipython"] +testing = ["flufl.flake8", "importlib-resources (>=1.3)", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1)"] + [[package]] name = "iniconfig" version = "2.0.0" @@ -1145,6 +1189,20 @@ typing-extensions = {version = ">=3.10.0", markers = "python_version < \"3.10\"" spelling = ["pyenchant (>=3.2,<4.0)"] testutils = ["gitpython (>3)"] +[[package]] +name = "pyproject-hooks" +version = "1.0.0" +description = "Wrappers to call pyproject.toml-based build backend hooks." +optional = false +python-versions = ">=3.7" +files = [ + {file = "pyproject_hooks-1.0.0-py3-none-any.whl", hash = "sha256:283c11acd6b928d2f6a7c73fa0d01cb2bdc5f07c57a2eeb6e83d5e56b97976f8"}, + {file = "pyproject_hooks-1.0.0.tar.gz", hash = "sha256:f271b298b97f5955d53fb12b72c1fb1948c22c1a6b70b315c54cedaca0264ef5"}, +] + +[package.dependencies] +tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} + [[package]] name = "pytest" version = "8.1.1" @@ -1592,13 +1650,13 @@ zstd = ["zstandard (>=0.18.0)"] [[package]] name = "virtualenv" -version = "20.25.1" +version = "20.25.2" description = "Virtual Python Environment builder" optional = false python-versions = ">=3.7" files = [ - {file = "virtualenv-20.25.1-py3-none-any.whl", hash = "sha256:961c026ac520bac5f69acb8ea063e8a4f071bcc9457b9c1f28f6b085c511583a"}, - {file = "virtualenv-20.25.1.tar.gz", hash = "sha256:e08e13ecdca7a0bd53798f356d5831434afa5b07b93f0abdf0797b7a06ffe197"}, + {file = "virtualenv-20.25.2-py3-none-any.whl", hash = "sha256:6e1281a57849c8a54da89ba82e5eb7c8937b9d057ff01aaf5bc9afaa3552e90f"}, + {file = "virtualenv-20.25.2.tar.gz", hash = "sha256:fa7edb8428620518010928242ec17aa7132ae435319c29c1651d1cf4c4173aad"}, ] [package.dependencies] @@ -1607,10 +1665,25 @@ filelock = ">=3.12.2,<4" platformdirs = ">=3.9.1,<5" [package.extras] -docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.2)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"] +docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.2,!=7.3)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"] test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8)", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10)"] +[[package]] +name = "zipp" +version = "3.18.1" +description = "Backport of pathlib-compatible object wrapper for zip files" +optional = false +python-versions = ">=3.8" +files = [ + {file = "zipp-3.18.1-py3-none-any.whl", hash = "sha256:206f5a15f2af3dbaee80769fb7dc6f249695e940acca08dfb2a4769fe61e538b"}, + {file = "zipp-3.18.1.tar.gz", hash = "sha256:2884ed22e7d8961de1c9a05142eb69a247f120291bc0206a00a7642f09b5b715"}, +] + +[package.extras] +docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] +testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"] + [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "8193a02a1aa5dee0c7e18950f25b6cfcc9ed8b84f70b511a236c220d242feb40" +content-hash = "17346d0057aa8cd3c4ddd4f4ae5c53402353e2038de2e110ec61eca54db7ffee" diff --git a/pyproject.toml b/pyproject.toml index d24d6ef..1d1dcd3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,7 @@ numpy = "^1.22" bandit = "^1.7.1" black = {version = "^24.4", allow-prereleases = true} darglint = "^1.8.1" +build = "^1.2.1" isort = {extras = ["colors"], version = "^5.10.1"} mypy = "^1.0" mypy-extensions = "^0.4.3" @@ -121,6 +122,8 @@ warn_unreachable = true warn_unused_configs = true warn_unused_ignores = true +disable_error_code = ["no-untyped-def"] + [tool.pytest.ini_options] # https://docs.pytest.org/en/6.2.x/customize.html#pyproject-toml @@ -147,4 +150,4 @@ branch = true [coverage.report] fail_under = 50 -show_missing = true \ No newline at end of file +show_missing = true diff --git a/src/PyCXpress/__init__.py b/src/PyCXpress/__init__.py index 131685a..140adf7 100644 --- a/src/PyCXpress/__init__.py +++ b/src/PyCXpress/__init__.py @@ -1,4 +1,3 @@ -# type: ignore[attr-defined] """PyCXpress is a high-performance hybrid framework that seamlessly integrates Python and C++ to harness the flexibility of Python and the speed of C++ for efficient and expressive computation, particularly in the realm of deep learning and numerical computing.""" __all__ = [ @@ -12,13 +11,9 @@ ] import sys +from importlib import metadata as importlib_metadata from pathlib import Path -if sys.version_info >= (3, 8): - from importlib import metadata as importlib_metadata -else: - import importlib_metadata - def get_version() -> str: try: @@ -30,9 +25,14 @@ def get_version() -> str: version: str = get_version() -from .core import TensorMeta, ModelAnnotationCreator, ModelAnnotationType, ModelRuntimeType -from .core import convert_to_spec_tuple +from .core import ( + ModelAnnotationCreator, + ModelAnnotationType, + ModelRuntimeType, + TensorMeta, + convert_to_spec_tuple, +) def get_include() -> str: - return str(Path(__file__).parent.absolute()/"include") \ No newline at end of file + return str(Path(__file__).parent.absolute() / "include") diff --git a/src/PyCXpress/__main__.py b/src/PyCXpress/__main__.py index ec7a1bc..5327f33 100644 --- a/src/PyCXpress/__main__.py +++ b/src/PyCXpress/__main__.py @@ -1,21 +1,21 @@ -# type: ignore[attr-defined] # pylint: disable=missing-function-docstring import argparse import sys import sysconfig -from PyCXpress import version, get_include from pybind11 import get_include as pybind11_include +from PyCXpress import get_include, version + def print_includes() -> None: - dirs = set([ + dirs = { sysconfig.get_path("include"), sysconfig.get_path("platinclude"), pybind11_include(), get_include(), - ]) + } print(" ".join(f"-I {d}" for d in dirs)) @@ -39,5 +39,6 @@ def main() -> None: if args.includes: print_includes() + if __name__ == "__main__": main() diff --git a/src/PyCXpress/core.py b/src/PyCXpress/core.py index 03d6f50..8b07812 100644 --- a/src/PyCXpress/core.py +++ b/src/PyCXpress/core.py @@ -1,44 +1,49 @@ +# mypy: disable_error_code="type-arg,arg-type,union-attr,operator,assignment,misc" import logging logger = logging.getLogger(__name__) -import numpy as np -# import tensorflow as tf - -from typing import List, Tuple, Iterable, Callable, Dict, Optional, Union +from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union from collections import namedtuple from dataclasses import dataclass +from enum import Enum, auto + import numpy as np from numpy.typing import DTypeLike -from enum import Enum, auto + +# import tensorflow as tf -def get_c_type(t: DTypeLike) -> (str, int): +def get_c_type(t: DTypeLike) -> Tuple[str, int]: dtype = np.dtype(t) - relation = {np.dtype('bool'): 'bool', - np.dtype('int8'): 'int8_t', - np.dtype('int16'): 'int16_t', - np.dtype('int32'): 'int32_t', - np.dtype('int64'): 'int64_t', - np.dtype('uint8'): 'uint8_t', - np.dtype('uint16'): 'uint16_t', - np.dtype('uint32'): 'uint32_t', - np.dtype('uint64'): 'uint64_t', - np.dtype('float32'): 'float', - np.dtype('float64'): 'double'} + relation = { + np.dtype("bool"): "bool", + np.dtype("int8"): "int8_t", + np.dtype("int16"): "int16_t", + np.dtype("int32"): "int32_t", + np.dtype("int64"): "int64_t", + np.dtype("uint8"): "uint8_t", + np.dtype("uint16"): "uint16_t", + np.dtype("uint32"): "uint32_t", + np.dtype("uint64"): "uint64_t", + np.dtype("float32"): "float", + np.dtype("float64"): "double", + } return relation.get(dtype, "char"), dtype.itemsize or 1 @dataclass class TensorMeta: dtype: DTypeLike # the data type similar to np.int_ - shape: Union[int, Iterable[int], Callable[..., Union[int, Iterable[int]]]] # the maximal size of each dimension + shape: Union[ + int, Iterable[int], Callable[..., Union[int, Iterable[int]]] + ] # the maximal size of each dimension name: Optional[str] = None doc: Optional[str] = None - def to_dict(self, *args, **kwargs) -> dict: + def to_dict(self, *args, **kwargs) -> Dict: assert self.name is not None max_size = self.shape @@ -51,12 +56,12 @@ def to_dict(self, *args, **kwargs) -> dict: return { "name": self.name, "dtype": dtype, - "shape": tuple(round(-i) if i<0 else None for i in max_size), + "shape": tuple(round(-i) if i < 0 else None for i in max_size), "buffer_size": np.prod([round(abs(i)) for i in max_size]) * itemsize, - "doc": f"" if self.doc is None else self.doc + "doc": f"" if self.doc is None else self.doc, } - def setdefault(self, name: str): + def setdefault(self, name: str) -> str: if self.name is None: self.name = name return self.name @@ -68,18 +73,29 @@ class ModelAnnotationType(Enum): Operator = auto() HyperParams = auto() + class ModelRuntimeType(Enum): GraphExecution = auto() EagerExecution = auto() OfflineExecution = auto() + @dataclass class TensorWithShape: - data: Optional[np.array] = None + data: Optional[np.ndarray] = None shape: Optional[Tuple] = None + class ModelAnnotationCreator(type): - def __new__(mcs, name, bases, attrs, fields: Dict[str, TensorMeta], type: ModelAnnotationType, mode: ModelRuntimeType): + def __new__( + mcs, + name, + bases, + attrs, + fields: Dict[str, TensorMeta], + type: ModelAnnotationType, + mode: ModelRuntimeType, + ): if type == ModelAnnotationType.Input: generate_property = mcs.generate_input_property elif type == ModelAnnotationType.Output: @@ -91,8 +107,9 @@ def __new__(mcs, name, bases, attrs, fields: Dict[str, TensorMeta], type: ModelA field_meta.setdefault(field_name) attrs[field_name] = generate_property(field_meta) - get_buffer_shape, set_buffer_value, init_func = mcs.general_funcs(name, [field_meta.name for field_meta in fields.values()]) - + get_buffer_shape, set_buffer_value, init_func = mcs.general_funcs( + name, [field_meta.name for field_meta in fields.values()] + ) attrs["__init__"] = init_func attrs["set_buffer_value"] = set_buffer_value @@ -100,11 +117,10 @@ def __new__(mcs, name, bases, attrs, fields: Dict[str, TensorMeta], type: ModelA attrs["get_buffer_shape"] = get_buffer_shape attrs.setdefault("__slots__", []).append("__buffer_data__") - return super().__new__(mcs, name, bases, attrs) @staticmethod - def general_funcs(name: str, field_names: list): + def general_funcs(name: str, field_names: List[str]): def get_buffer_shape(self, name: str): buffer = getattr(self.__buffer_data__, name) return buffer.shape @@ -115,7 +131,9 @@ def set_buffer_value(self, name: str, value): def init_func(self): _BufferData_ = namedtuple("_BufferData_", field_names) - self.__buffer_data__ = _BufferData_(*tuple(TensorWithShape() for _ in field_names)) + self.__buffer_data__ = _BufferData_( + *tuple(TensorWithShape() for _ in field_names) + ) return get_buffer_shape, set_buffer_value, init_func @@ -137,24 +155,28 @@ def generate_output_property(field: TensorMeta): def get_func(self): logger.warning(f"Only read the data field {field.name} in debugging mode") buffer = getattr(self.__buffer_data__, field.name) - return buffer.data[:np.prod(buffer.shape)].reshape(buffer.shape) + return buffer.data[: np.prod(buffer.shape)].reshape(buffer.shape) def set_func(self, data): buffer = getattr(self.__buffer_data__, field.name) buffer.shape = data.shape - buffer.data[:np.prod(data.shape)] = data.flatten() + buffer.data[: np.prod(data.shape)] = data.flatten() def del_func(_): raise AssertionError("Not supported for output tensor") return property(fget=get_func, fset=set_func, fdel=del_func, doc=field.doc) + def convert_to_spec_tuple(fields: Iterable[TensorMeta]): - return tuple((v["name"], v["dtype"], v["buffer_size"]) for v in [v.to_dict() for v in fields]) + return tuple( + (v["name"], v["dtype"], v["buffer_size"]) for v in [v.to_dict() for v in fields] + ) def main(): pass + if __name__ == "__main__": main() diff --git a/src/PyCXpress/example/model.py b/src/PyCXpress/example/model.py index 8945554..6db9563 100644 --- a/src/PyCXpress/example/model.py +++ b/src/PyCXpress/example/model.py @@ -1,48 +1,87 @@ +# mypy: disable_error_code="type-arg,attr-defined" import os -os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' + +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" import logging + logging.basicConfig(level=logging.DEBUG) -from pathlib import Path import sys +from pathlib import Path + import numpy as np -sys.path.append(str(Path(__file__).parent/".."/"src"/"python")) -from PyCXpress import TensorMeta, ModelAnnotationCreator, ModelAnnotationType, ModelRuntimeType -from PyCXpress import convert_to_spec_tuple +sys.path.append(str(Path(__file__).parent / ".." / "src" / "python")) + from contextlib import nullcontext -def show(a: np.array): +from PyCXpress import ( + ModelAnnotationCreator, + ModelAnnotationType, + ModelRuntimeType, + TensorMeta, + convert_to_spec_tuple, +) + + +def show(a: np.ndarray): logging.info(f"array data type: {a.dtype}") logging.info(f"array data shape: {a.shape}") logging.info(f"array data: ") logging.info(a) + InputFields = dict( - data_to_be_reshaped=TensorMeta(dtype=np.float_, - shape=(100,), - ), - new_2d_shape=TensorMeta(dtype=np.uint8, - shape=-2,) + data_to_be_reshaped=TensorMeta( + dtype=np.float_, + shape=(100,), + ), + new_2d_shape=TensorMeta( + dtype=np.uint8, + shape=-2, + ), ) -class InputDataSet(metaclass=ModelAnnotationCreator, fields=InputFields, type=ModelAnnotationType.Input, mode=ModelRuntimeType.EagerExecution): +class InputDataSet( + metaclass=ModelAnnotationCreator, + fields=InputFields, + type=ModelAnnotationType.Input, + mode=ModelRuntimeType.EagerExecution, +): pass OutputFields = dict( - output_a=TensorMeta(dtype=np.float_, - shape=(10, 10),), + output_a=TensorMeta( + dtype=np.float_, + shape=(10, 10), + ), ) -class OutputDataSet(metaclass=ModelAnnotationCreator, fields=OutputFields, type=ModelAnnotationType.Output, mode=ModelRuntimeType.EagerExecution): +class OutputDataSet( + metaclass=ModelAnnotationCreator, + fields=OutputFields, + type=ModelAnnotationType.Output, + mode=ModelRuntimeType.EagerExecution, +): pass def init(): - return InputDataSet(), OutputDataSet(), tuple((*convert_to_spec_tuple(InputFields.values()), *convert_to_spec_tuple(OutputFields.values()))), tuple(OutputFields.keys()) + return ( + InputDataSet(), + OutputDataSet(), + tuple( + ( + *convert_to_spec_tuple(InputFields.values()), + *convert_to_spec_tuple(OutputFields.values()), + ) + ), + tuple(OutputFields.keys()), + ) + def model(input: InputDataSet, output: OutputDataSet): with nullcontext(): @@ -51,6 +90,7 @@ def model(input: InputDataSet, output: OutputDataSet): output.output_a = input.data_to_be_reshaped.reshape(input.new_2d_shape) # print(output.output_a) + def main(): input_data, output_data, spec = init() print(spec) @@ -59,11 +99,12 @@ def main(): print(input_data.data_to_be_reshaped) input_data.set_buffer_value("new_2d_shape", np.array([3, 4]).astype(np.uint8)) print(input_data.new_2d_shape) - output_data.set_buffer_value("output_a", np.arange(12)*0) + output_data.set_buffer_value("output_a", np.arange(12) * 0) model(input_data, output_data) print(output_data.output_a) print(output_data.get_buffer_shape("output_a")) + if __name__ == "__main__": main() diff --git a/src/PyCXpress/include/PyCXpress/core.hpp b/src/PyCXpress/include/PyCXpress/core.hpp index 3707853..52ca3a5 100644 --- a/src/PyCXpress/include/PyCXpress/core.hpp +++ b/src/PyCXpress/include/PyCXpress/core.hpp @@ -13,19 +13,18 @@ #include "utils.hpp" #if !defined(PYCXPRESS_EXPORT) -# if defined(WIN32) || defined(_WIN32) -# define PYCXPRESS_EXPORT __declspec(dllexport) -# else -# define PYCXPRESS_EXPORT __attribute__((visibility("default"))) -# endif -#endif +#if defined(WIN32) || defined(_WIN32) +#define PYCXPRESS_EXPORT __declspec(dllexport) +#else +#define PYCXPRESS_EXPORT __attribute__((visibility("default"))) +#endif +#endif namespace PyCXpress { namespace py = pybind11; using namespace utils; -class PYCXPRESS_EXPORT -Buffer { +class PYCXPRESS_EXPORT Buffer { typedef unsigned char Bytes; template @@ -116,8 +115,7 @@ Buffer { py::array (*m_converter)(const std::vector &, void *); }; -class PYCXPRESS_EXPORT -PythonInterpreter { +class PYCXPRESS_EXPORT PythonInterpreter { public: explicit PythonInterpreter(bool init_signal_handlers = true, int argc = 0, const char *const *argv = nullptr, @@ -226,4 +224,4 @@ PythonInterpreter { }; // namespace PyCXpress -#endif // __PYCXPRESS_HPP__ \ No newline at end of file +#endif // __PYCXPRESS_HPP__ diff --git a/tests/basic/test_pkg.py b/tests/basic/test_pkg.py new file mode 100644 index 0000000..21df992 --- /dev/null +++ b/tests/basic/test_pkg.py @@ -0,0 +1,16 @@ +from importlib import import_module +from pathlib import Path + +import pytest + +from PyCXpress import get_include + + +@pytest.mark.parametrize( + ("include_suffix",), + [ + ("PyCXpress/include",), + ], +) +def test_get_include(include_suffix): + assert get_include().endswith(include_suffix)