diff --git a/example.py b/example.py index bdc1d92..4f7d985 100644 --- a/example.py +++ b/example.py @@ -5,18 +5,18 @@ import pathlib sys.path.append(".") -from extism import Function, host_fn, ValType, Plugin, set_log_file +from extism import Function, host_fn, ValType, Plugin, set_log_file, Json +from typing import Annotated set_log_file("stderr", "trace") -@host_fn -def hello_world(plugin, input_, output, a_string): +@host_fn(user_data=b"Hello again!") +def hello_world(inp: Annotated[dict, Json], *a_string) -> Annotated[dict, Json]: print("Hello from Python!") print(a_string) - print(input_) - print(plugin.input_string(input_[0])) - output[0] = input_[0] + inp["roundtrip"] = 1 + return inp # Compare against Python implementation. @@ -36,16 +36,7 @@ def main(args): hash = hashlib.sha256(wasm).hexdigest() manifest = {"wasm": [{"data": wasm, "hash": hash}]} - functions = [ - Function( - "hello_world", - [ValType.I64], - [ValType.I64], - hello_world, - "Hello again!", - ) - ] - plugin = Plugin(manifest, wasi=True, functions=functions) + plugin = Plugin(manifest, wasi=True) print(plugin.id) # Call `count_vowels` wasm_vowel_count = plugin.call("count_vowels", data) @@ -55,6 +46,7 @@ def main(args): print("Number of vowels:", j["count"]) assert j["count"] == count_vowels(data) + assert j["roundtrip"] == 1 if __name__ == "__main__": diff --git a/extism/__init__.py b/extism/__init__.py index 35077b1..9da3dba 100644 --- a/extism/__init__.py +++ b/extism/__init__.py @@ -14,6 +14,9 @@ ValType, Val, CurrentPlugin, + Codec, + Json, + Pickle, ) __all__ = [ @@ -28,4 +31,7 @@ "Function", "ValType", "Val", + "Codec", + "Json", + "Pickle", ] diff --git a/extism/extism.py b/extism/extism.py index de38d2b..e34fdc8 100644 --- a/extism/extism.py +++ b/extism/extism.py @@ -2,12 +2,100 @@ import os from base64 import b64encode from cffi import FFI -from typing import Any, Callable, Dict, List, Union, Literal, Optional +from typing import ( + Annotated, + get_args, + get_origin, + get_type_hints, + Any, + Callable, + List, + Union, + Literal, + Optional, + Tuple, +) from enum import Enum from uuid import UUID from extism_sys import lib as _lib, ffi as _ffi # type: ignore -from typing import Annotated from annotated_types import Gt +import functools +import pickle + + +HOST_FN_REGISTRY: List[Any] = [] + + +class Json: + """ + Typing metadata: indicates that an extism host function parameter (or return value) + should be encoded (or decoded) using ``json``. + + .. sourcecode:: python + + @extism.host_fn() + def load(input: typing.Annotated[dict, extism.Json]): + # input will be a dictionary decoded from json input. + input.get("hello", None) + + @extism.host_fn() + def load(input: int) -> typing.Annotated[dict, extism.Json]: + return { + 'hello': 3 + } + + """ + + ... + + +class Pickle: + """ + Typing metadata: indicates that an extism host function parameter (or return value) + should be encoded (or decoded) using ``pickle``. + + .. sourcecode:: python + + class Grimace: + ... + + @extism.host_fn() + def load(input: typing.Annotated[Grimace, extism.Pickle]): + # input will be an instance of Grimace! + ... + + @extism.host_fn() + def load(input: int) -> typing.Annotated[Grimace, extism.Pickle]: + return Grimace() + + """ + + ... + + +class Codec: + """ + Typing metadata: indicates that an extism host function parameter (or return value) + should be transformed with the provided function. + + .. sourcecode:: python + + import json + + @extism.host_fn() + def load(input: typing.Annotated[str, extism.Codec(lambda inp: inp.decode(encoding = 'shift_jis'))]): + # you can accept shift-jis bytes as input! + ... + + mojibake_factory = lambda out: out.encode(encoding='utf8').decode(encoding='latin1').encode() + + @extism.host_fn() + def load(input: int) -> typing.Annotated[str, extism.Codec(mojibake_factory)]: + return "get ready for some mojibake 🎉" + """ + + def __init__(self, codec): + self.codec = codec class Error(Exception): @@ -19,6 +107,36 @@ class Error(Exception): ... +class ValType(Enum): + """ + An enumeration of all available `Wasm value types `_. + """ + + I32 = 0 + I64 = 1 + F32 = 2 + F64 = 3 + V128 = 4 + FUNC_REF = 5 + EXTERN_REF = 6 + + +class Val: + """ + Low-level WebAssembly value with associated :py:class:`ValType`. + """ + + def __init__(self, t: ValType, v): + self.t = t + self.value = v + + def __repr__(self): + return f"Val({self.t}, {self.value})" + + def _assign(self, v): + self.value = v + + class _Base64Encoder(json.JSONEncoder): # pylint: disable=method-hidden def default(self, o): @@ -90,6 +208,7 @@ def __init__(self, name: str, args, returns, f, *user_data): self.user_data = _ffi.new_handle(user_data) else: self.user_data = _ffi.NULL + self.pointer = _lib.extism_function_new( name.encode(), args, @@ -115,6 +234,153 @@ def __del__(self): _lib.extism_function_free(self.pointer) +def _map_arg(arg_name, xs) -> Tuple[ValType, Callable[[Any, Any], Any]]: + if xs == str: + return (ValType.I64, lambda plugin, slot: plugin.input_string(slot)) + + if xs == bytes: + return (ValType.I64, lambda plugin, slot: plugin.input_bytes(slot)) + + if xs == int: + return (ValType.I64, lambda _, slot: slot.value) + + if xs == float: + return (ValType.F64, lambda _, slot: slot.value) + + if xs == bool: + return (ValType.I32, lambda _, slot: slot.value) + + metadata = getattr(xs, "__metadata__", ()) + for item in metadata: + if item == Json: + return ( + ValType.I64, + lambda plugin, slot: json.loads(plugin.input_string(slot)), + ) + + if item == Pickle: + return ( + ValType.I64, + lambda plugin, slot: pickle.loads(plugin.input_bytes(slot)), + ) + + if isinstance(item, Codec): + return ( + ValType.I64, + lambda plugin, slot: item.codec(plugin.input_bytes(slot)), + ) + + raise TypeError("Could not infer input type for argument %s" % arg_name) + + +def _map_ret(xs) -> List[Tuple[ValType, Callable[[Any, Any, Any], Any]]]: + if xs == str: + return [ + (ValType.I64, lambda plugin, slot, value: plugin.return_string(slot, value)) + ] + + if xs == bytes: + return [ + (ValType.I64, lambda plugin, slot, value: plugin.return_bytes(slot, value)) + ] + + if xs == int: + return [(ValType.I64, lambda _, slot, value: slot.assign(value))] + + if xs == float: + return [(ValType.F64, lambda _, slot, value: slot.assign(value))] + + if xs == bool: + return [(ValType.I32, lambda _, slot, value: slot.assign(value))] + + if get_origin(xs) == tuple: + return functools.reduce(lambda lhs, rhs: lhs + _map_ret(rhs), get_args(xs), []) + + metadata = getattr(xs, "__metadata__", ()) + for item in metadata: + if item == Json: + return [ + ( + ValType.I64, + lambda plugin, slot, value: plugin.return_string( + slot, json.dumps(value) + ), + ) + ] + + if item == Pickle: + return [ + ( + ValType.I64, + lambda plugin, slot, value: plugin.return_bytes( + slot, pickle.dumps(value) + ), + ) + ] + + if isinstance(item, Codec): + return [ + ( + ValType.I64, + lambda plugin, slot, value: plugin.return_bytes( + slot, item.codec(value) + ), + ) + ] + + raise TypeError("Could not infer return type") + + +class ExplicitFunction(Function): + def __init__(self, name, namespace, args, returns, func, user_data): + self.func = func + + super().__init__(name, args, returns, handle_args, *user_data) + if namespace is not None: + self.set_namespace(namespace) + + functools.update_wrapper(self, func) + + def __call__(self, *args, **kwargs): + return self.func(*args, **kwargs) + + +class TypeInferredFunction(ExplicitFunction): + def __init__(self, name, namespace, func, user_data): + hints = get_type_hints(func, include_extras=True) + if len(hints) == 0: + raise TypeError( + "Host function must include Python type annotations or explicitly list arguments." + ) + + arg_names = [arg for arg in hints.keys() if arg != "return"] + returns = hints.pop("return", None) + + args = [_map_arg(arg, hints[arg]) for arg in arg_names] + returns = [] if returns is None else _map_ret(returns) + + def inner_func(plugin, inputs, outputs, *user_data): + inner_args = [ + extract(plugin, slot) for ((_, extract), slot) in zip(args, inputs) + ] + + if user_data is not None: + inner_args += list(user_data) + + result = func(*inner_args) + for (_, emplace), slot in zip(returns, outputs): + emplace(plugin, slot, result) + + super().__init__( + name, + namespace, + [typ for (typ, _) in args], + [typ for (typ, _) in returns], + inner_func, + user_data, + ) + + class CancelHandle: def __init__(self, ptr): self.pointer = ptr @@ -133,7 +399,8 @@ class Plugin: :param config: An optional JSON-serializable object holding a map of configuration keys and values. :param functions: An optional list of host :py:class:`functions <.extism.Function>` to - expose to the guest program. + expose to the guest program. Defaults to all registered ``@host_fn()``'s + if not given. """ def __init__( @@ -141,7 +408,7 @@ def __init__( plugin: Union[str, bytes, dict], wasi: bool = False, config: Optional[Any] = None, - functions: Optional[List[Function]] = None, + functions: Optional[List[Function]] = HOST_FN_REGISTRY, ): wasm = _wasm(plugin) self.functions = functions @@ -196,7 +463,7 @@ def call( function_name: str, data: Union[str, bytes], parse: Callable[[Any], Any] = lambda xs: bytes(xs), - ): + ) -> Any: """ Call a function by name with the provided input data @@ -261,33 +528,6 @@ def _convert_output(x, v): raise Error("Unsupported return type: " + str(x.t)) -class ValType(Enum): - """ - An enumeration of all available `Wasm value types `_. - """ - - I32 = 0 - I64 = 1 - F32 = 2 - F64 = 3 - V128 = 4 - FUNC_REF = 5 - EXTERN_REF = 6 - - -class Val: - """ - Low-level WebAssembly value with associated :py:class:`ValType`. - """ - - def __init__(self, t: ValType, v): - self.t = t - self.value = v - - def __repr__(self): - return f"Val({self.t}, {self.value})" - - class CurrentPlugin: """ This object is accessible when calling from the guest :py:class:`Plugin` into the host via @@ -366,7 +606,7 @@ def input_string(self, input: Val) -> str: .. sourcecode:: python - @extism.host_fn + @extism.host_fn(signature=([extism.ValType.I64], [])) def hello_world(plugin, params, results): my_str = plugin.input_string(params[0]) print(my_str) @@ -407,9 +647,7 @@ def hello_world(plugin, params, results): with open("example.wasm", "rb") as wasm_file: data = wasm_file.read() - with extism.Plugin(data, functions=[ - extism.Function("hello_world", [extism.ValType.I64], [], hello_world).with_namespace("example") - ]) as plugin: + with extism.Plugin(data, functions=[hello_world]) as plugin: plugin.call("my_func", "") :param input: The input value that references a string. @@ -418,46 +656,124 @@ def hello_world(plugin, params, results): def host_fn( - func: Union[ - Any, - Callable[[CurrentPlugin, List[Val], List[Val]], List[Val]], - Callable[[CurrentPlugin, List[Val], List[Val], Optional[Any]], List[Val]], - ] + name: Optional[str] = None, + namespace: Optional[str] = None, + signature: Optional[Tuple[List[ValType], List[ValType]]] = None, + user_data: Optional[bytes | List[bytes]] = None, ): """ - A decorator for creating host functions, this decorator wraps a function - that takes the following parameters: + A decorator for creating host functions. Host functions are installed into a thread-local + registry. + + :param name: The function name to expose to the guest plugin. If not given, inferred from the + wrapped function name. + :param namespace: The namespace to install the function into; defaults to "env" if not given. + :param signature: A tuple of two arrays representing the function parameter types and return value types. + If not given, types will be inferred from ``typing`` annotations. + :param userdata: Any custom userdata to associate with the function. + + Supported Inferred Types + ------------------------ + + - ``typing.Annotated[Any, extism.Json]``: In both parameter and return + positions. Written to extism memory; offset encoded in return value as + ``I64``. + - ``typing.Annotated[Any, extism.Pickle]``: In both parameter and return + positions. Written to extism memory; offset encoded in return value as + ``I64``. + - ``str``, ``bytes``: In both parameter and return + positions. Written to extism memory; offset encoded in return value as + ``I64``. + - ``int``: In both parameter and return positions. Encoded as ``I64``. + - ``float``: In both parameter and return positions. Encoded as ``F64``. + - ``bool``: In both parameter and return positions. Encoded as ``I32``. + - ``typing.Tuple[]``: In return position; expands + return list to include all member type encodings. + + .. sourcecode:: python + + import typing + import extism + + @extism.host_fn() + def greet(who: str) -> str: + return "hello %s" % who + + @extism.host_fn() + def load(input: typing.Annotated[dict, extism.Json]) -> typing.Tuple[int, int]: + # input will be a dictionary decoded from json input. The tuple will be returned + # two I64 values. + return (3, 4) + + @extism.host_fn() + def return_many_encoded() -> typing.Tuple(int, typing.Annotated[dict, extism.Json]): + # we auto-encoded any Json-annotated return values, even in a tuple + return (32, {"hello": "world"}) + + class Gromble: + ... + + @extism.host_fn() + def everyone_loves_a_pickle(grumble: typing.Annotated[Gromble, extism.Pickle]) -> typing.Annotated[Gromble, extism.Pickle]: + # you can pass pickled objects in and out of host funcs + return Gromble() + + @extism.host_fn(signature=([extism.ValType.I64], [])) + def more_control( + current_plugin: extism.CurrentPlugin, + params: typing.List[extism.Val], + results: typing.List[extism.Val], + *user_data + ): + # if you need more control, you can specify the wasm-level input + # and output types explicitly. + ... - - ``current_plugin``: :py:class:`CurrentPlugin <.CurrentPlugin>` - - ``inputs``: :py:class:`List[Val] <.Val>` - - ``outputs``: :py:class:`List[Val] <.Val>` - - ``user_data``: any number of values passed as user data - - The function should return a list of `Val`. """ + if user_data is None: + user_data = [] + elif isinstance(user_data, bytes): + user_data = [user_data] + + def outer(func): + n = name or func.__name__ + + idx = len(HOST_FN_REGISTRY).to_bytes(length=4, byteorder="big") + user_data.append(idx) + fn = ( + TypeInferredFunction(n, namespace, func, user_data) + if signature is None + else ExplicitFunction( + n, namespace, signature[0], signature[1], func, user_data + ) + ) + HOST_FN_REGISTRY.append(fn) + return fn - @_ffi.callback( - "void(ExtismCurrentPlugin*, const ExtismVal*, ExtismSize, ExtismVal*, ExtismSize, void*)" - ) - def handle_args(current, inputs, n_inputs, outputs, n_outputs, user_data): - inp = [] - outp = [] + return outer - for i in range(n_inputs): - inp.append(_convert_value(inputs[i])) - for i in range(n_outputs): - outp.append(_convert_value(outputs[i])) +@_ffi.callback( + "void(ExtismCurrentPlugin*, const ExtismVal*, ExtismSize, ExtismVal*, ExtismSize, void*)" +) +def handle_args(current, inputs, n_inputs, outputs, n_outputs, user_data): + inp = [] + outp = [] - cast_func: Any = func + for i in range(n_inputs): + inp.append(_convert_value(inputs[i])) - if user_data == _ffi.NULL: - cast_func(CurrentPlugin(current), inp, outp) - else: - udata = _ffi.from_handle(user_data) - cast_func(CurrentPlugin(current), inp, outp, *udata) + for i in range(n_outputs): + outp.append(_convert_value(outputs[i])) + + if user_data == _ffi.NULL: + udata = [] + else: + udata = list(_ffi.from_handle(user_data)) + + idx = int.from_bytes(udata.pop(), byteorder="big") - for i in range(n_outputs): - _convert_output(outputs[i], outp[i]) + HOST_FN_REGISTRY[idx](CurrentPlugin(current), inp, outp, *udata) - return handle_args + for i in range(n_outputs): + _convert_output(outputs[i], outp[i]) diff --git a/justfile b/justfile index 717286d..86ae995 100644 --- a/justfile +++ b/justfile @@ -19,8 +19,19 @@ prepare: fi test: prepare + #!/bin/bash + set -eou pipefail poetry run python -m unittest discover + set +e + msg=$(2>&1 poetry run python example.py) + if [ $? != 0 ]; then + >&2 echo "$msg" + exit 1 + else + echo -e 'poetry run python example.py... \x1b[32mok\x1b[0m' + fi + poetry *args: prepare #!/bin/bash poetry $args diff --git a/poetry.lock b/poetry.lock index 31f05e7..3457942 100644 --- a/poetry.lock +++ b/poetry.lock @@ -22,6 +22,34 @@ files = [ {file = "annotated_types-0.6.0.tar.gz", hash = "sha256:563339e807e53ffd9c267e99fc6d9ea23eb8443c08f112651963e24e22f84a5d"}, ] +[[package]] +name = "appnope" +version = "0.1.3" +description = "Disable App Nap on macOS >= 10.9" +optional = false +python-versions = "*" +files = [ + {file = "appnope-0.1.3-py2.py3-none-any.whl", hash = "sha256:265a455292d0bd8a72453494fa24df5a11eb18373a60c7c0430889f22548605e"}, + {file = "appnope-0.1.3.tar.gz", hash = "sha256:02bd91c4de869fbb1e1c50aafc4098827a7a54ab2f39d9dcba6c9547ed920e24"}, +] + +[[package]] +name = "asttokens" +version = "2.4.0" +description = "Annotate AST trees with source code positions" +optional = false +python-versions = "*" +files = [ + {file = "asttokens-2.4.0-py2.py3-none-any.whl", hash = "sha256:cf8fc9e61a86461aa9fb161a14a0841a03c405fa829ac6b202670b3495d2ce69"}, + {file = "asttokens-2.4.0.tar.gz", hash = "sha256:2e0171b991b2c959acc6c49318049236844a5da1d65ba2672c4880c1c894834e"}, +] + +[package.dependencies] +six = ">=1.12.0" + +[package.extras] +test = ["astroid", "pytest"] + [[package]] name = "babel" version = "2.13.0" @@ -36,6 +64,17 @@ files = [ [package.extras] dev = ["freezegun (>=1.0,<2.0)", "pytest (>=6.0)", "pytest-cov"] +[[package]] +name = "backcall" +version = "0.2.0" +description = "Specifications for callback functions passed in to an API" +optional = false +python-versions = "*" +files = [ + {file = "backcall-0.2.0-py2.py3-none-any.whl", hash = "sha256:fbbce6a29f263178a1f7915c1940bde0ec2b2a967566fe1c65c1dfb7422bd255"}, + {file = "backcall-0.2.0.tar.gz", hash = "sha256:5cbdbf27be5e7cfadb448baf0aa95508f91f2bbc6c6437cd9cd06e2a4c215e1e"}, +] + [[package]] name = "black" version = "23.9.1" @@ -281,6 +320,17 @@ files = [ {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] +[[package]] +name = "decorator" +version = "5.1.1" +description = "Decorators for Humans" +optional = false +python-versions = ">=3.5" +files = [ + {file = "decorator-5.1.1-py3-none-any.whl", hash = "sha256:b8c3f85900b9dc423225913c5aace94729fe1fa9763b38939a95226f02d37186"}, + {file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"}, +] + [[package]] name = "docutils" version = "0.18.1" @@ -292,6 +342,34 @@ files = [ {file = "docutils-0.18.1.tar.gz", hash = "sha256:679987caf361a7539d76e584cbeddc311e3aee937877c87346f31debc63e9d06"}, ] +[[package]] +name = "exceptiongroup" +version = "1.1.3" +description = "Backport of PEP 654 (exception groups)" +optional = false +python-versions = ">=3.7" +files = [ + {file = "exceptiongroup-1.1.3-py3-none-any.whl", hash = "sha256:343280667a4585d195ca1cf9cef84a4e178c4b6cf2274caef9859782b567d5e3"}, + {file = "exceptiongroup-1.1.3.tar.gz", hash = "sha256:097acd85d473d75af5bb98e41b61ff7fe35efe6675e4f9370ec6ec5126d160e9"}, +] + +[package.extras] +test = ["pytest (>=6)"] + +[[package]] +name = "executing" +version = "2.0.0" +description = "Get the currently executing AST node of a frame, and other information" +optional = false +python-versions = "*" +files = [ + {file = "executing-2.0.0-py2.py3-none-any.whl", hash = "sha256:06df6183df67389625f4e763921c6cf978944721abf3e714000200aab95b0657"}, + {file = "executing-2.0.0.tar.gz", hash = "sha256:0ff053696fdeef426cda5bd18eacd94f82c91f49823a2e9090124212ceea9b08"}, +] + +[package.extras] +tests = ["asttokens (>=2.1.0)", "coverage", "coverage-enable-subprocess", "ipython", "littleutils", "pytest", "rich"] + [[package]] name = "extism-sys" version = "0.5.3" @@ -351,6 +429,81 @@ docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker perf = ["ipython"] testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)", "pytest-ruff"] +[[package]] +name = "ipdb" +version = "0.13.13" +description = "IPython-enabled pdb" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +files = [ + {file = "ipdb-0.13.13-py3-none-any.whl", hash = "sha256:45529994741c4ab6d2388bfa5d7b725c2cf7fe9deffabdb8a6113aa5ed449ed4"}, + {file = "ipdb-0.13.13.tar.gz", hash = "sha256:e3ac6018ef05126d442af680aad863006ec19d02290561ac88b8b1c0b0cfc726"}, +] + +[package.dependencies] +decorator = {version = "*", markers = "python_version > \"3.6\""} +ipython = {version = ">=7.31.1", markers = "python_version > \"3.6\""} +tomli = {version = "*", markers = "python_version > \"3.6\" and python_version < \"3.11\""} + +[[package]] +name = "ipython" +version = "8.16.1" +description = "IPython: Productive Interactive Computing" +optional = false +python-versions = ">=3.9" +files = [ + {file = "ipython-8.16.1-py3-none-any.whl", hash = "sha256:0852469d4d579d9cd613c220af7bf0c9cc251813e12be647cb9d463939db9b1e"}, + {file = "ipython-8.16.1.tar.gz", hash = "sha256:ad52f58fca8f9f848e256c629eff888efc0528c12fe0f8ec14f33205f23ef938"}, +] + +[package.dependencies] +appnope = {version = "*", markers = "sys_platform == \"darwin\""} +backcall = "*" +colorama = {version = "*", markers = "sys_platform == \"win32\""} +decorator = "*" +exceptiongroup = {version = "*", markers = "python_version < \"3.11\""} +jedi = ">=0.16" +matplotlib-inline = "*" +pexpect = {version = ">4.3", markers = "sys_platform != \"win32\""} +pickleshare = "*" +prompt-toolkit = ">=3.0.30,<3.0.37 || >3.0.37,<3.1.0" +pygments = ">=2.4.0" +stack-data = "*" +traitlets = ">=5" +typing-extensions = {version = "*", markers = "python_version < \"3.10\""} + +[package.extras] +all = ["black", "curio", "docrepr", "exceptiongroup", "ipykernel", "ipyparallel", "ipywidgets", "matplotlib", "matplotlib (!=3.2.0)", "nbconvert", "nbformat", "notebook", "numpy (>=1.21)", "pandas", "pytest (<7)", "pytest (<7.1)", "pytest-asyncio", "qtconsole", "setuptools (>=18.5)", "sphinx (>=1.3)", "sphinx-rtd-theme", "stack-data", "testpath", "trio", "typing-extensions"] +black = ["black"] +doc = ["docrepr", "exceptiongroup", "ipykernel", "matplotlib", "pytest (<7)", "pytest (<7.1)", "pytest-asyncio", "setuptools (>=18.5)", "sphinx (>=1.3)", "sphinx-rtd-theme", "stack-data", "testpath", "typing-extensions"] +kernel = ["ipykernel"] +nbconvert = ["nbconvert"] +nbformat = ["nbformat"] +notebook = ["ipywidgets", "notebook"] +parallel = ["ipyparallel"] +qtconsole = ["qtconsole"] +test = ["pytest (<7.1)", "pytest-asyncio", "testpath"] +test-extra = ["curio", "matplotlib (!=3.2.0)", "nbformat", "numpy (>=1.21)", "pandas", "pytest (<7.1)", "pytest-asyncio", "testpath", "trio"] + +[[package]] +name = "jedi" +version = "0.19.1" +description = "An autocompletion tool for Python that can be used for text editors." +optional = false +python-versions = ">=3.6" +files = [ + {file = "jedi-0.19.1-py2.py3-none-any.whl", hash = "sha256:e983c654fe5c02867aef4cdfce5a2fbb4a50adc0af145f70504238f18ef5e7e0"}, + {file = "jedi-0.19.1.tar.gz", hash = "sha256:cf0496f3651bc65d7174ac1b7d043eff454892c708a87d1b683e57b569927ffd"}, +] + +[package.dependencies] +parso = ">=0.8.3,<0.9.0" + +[package.extras] +docs = ["Jinja2 (==2.11.3)", "MarkupSafe (==1.1.1)", "Pygments (==2.8.1)", "alabaster (==0.7.12)", "babel (==2.9.1)", "chardet (==4.0.0)", "commonmark (==0.8.1)", "docutils (==0.17.1)", "future (==0.18.2)", "idna (==2.10)", "imagesize (==1.2.0)", "mock (==1.0.1)", "packaging (==20.9)", "pyparsing (==2.4.7)", "pytz (==2021.1)", "readthedocs-sphinx-ext (==2.1.4)", "recommonmark (==0.5.0)", "requests (==2.25.1)", "six (==1.15.0)", "snowballstemmer (==2.1.0)", "sphinx (==1.8.5)", "sphinx-rtd-theme (==0.4.3)", "sphinxcontrib-serializinghtml (==1.1.4)", "sphinxcontrib-websupport (==1.2.4)", "urllib3 (==1.26.4)"] +qa = ["flake8 (==5.0.4)", "mypy (==0.971)", "types-setuptools (==67.2.0.1)"] +testing = ["Django", "attrs", "colorama", "docopt", "pytest (<7.0.0)"] + [[package]] name = "jinja2" version = "3.1.2" @@ -437,6 +590,20 @@ files = [ {file = "MarkupSafe-2.1.3.tar.gz", hash = "sha256:af598ed32d6ae86f1b747b82783958b1a4ab8f617b06fe68795c7f026abbdcad"}, ] +[[package]] +name = "matplotlib-inline" +version = "0.1.6" +description = "Inline Matplotlib backend for Jupyter" +optional = false +python-versions = ">=3.5" +files = [ + {file = "matplotlib-inline-0.1.6.tar.gz", hash = "sha256:f887e5f10ba98e8d2b150ddcf4702c1e5f8b3a20005eb0f74bfdbd360ee6f304"}, + {file = "matplotlib_inline-0.1.6-py3-none-any.whl", hash = "sha256:f1f41aab5328aa5aaea9b16d083b128102f8712542f819fe7e6a420ff581b311"}, +] + +[package.dependencies] +traitlets = "*" + [[package]] name = "mypy" version = "1.5.1" @@ -505,6 +672,21 @@ files = [ {file = "packaging-23.2.tar.gz", hash = "sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5"}, ] +[[package]] +name = "parso" +version = "0.8.3" +description = "A Python Parser" +optional = false +python-versions = ">=3.6" +files = [ + {file = "parso-0.8.3-py2.py3-none-any.whl", hash = "sha256:c001d4636cd3aecdaf33cbb40aebb59b094be2a74c556778ef5576c175e19e75"}, + {file = "parso-0.8.3.tar.gz", hash = "sha256:8c07be290bb59f03588915921e29e8a50002acaf2cdc5fa0e0114f91709fafa0"}, +] + +[package.extras] +qa = ["flake8 (==3.8.3)", "mypy (==0.782)"] +testing = ["docopt", "pytest (<6.0.0)"] + [[package]] name = "pathspec" version = "0.11.2" @@ -516,6 +698,31 @@ files = [ {file = "pathspec-0.11.2.tar.gz", hash = "sha256:e0d8d0ac2f12da61956eb2306b69f9469b42f4deb0f3cb6ed47b9cce9996ced3"}, ] +[[package]] +name = "pexpect" +version = "4.8.0" +description = "Pexpect allows easy control of interactive console applications." +optional = false +python-versions = "*" +files = [ + {file = "pexpect-4.8.0-py2.py3-none-any.whl", hash = "sha256:0b48a55dcb3c05f3329815901ea4fc1537514d6ba867a152b581d69ae3710937"}, + {file = "pexpect-4.8.0.tar.gz", hash = "sha256:fc65a43959d153d0114afe13997d439c22823a27cefceb5ff35c2178c6784c0c"}, +] + +[package.dependencies] +ptyprocess = ">=0.5" + +[[package]] +name = "pickleshare" +version = "0.7.5" +description = "Tiny 'shelve'-like database with concurrency support" +optional = false +python-versions = "*" +files = [ + {file = "pickleshare-0.7.5-py2.py3-none-any.whl", hash = "sha256:9649af414d74d4df115d5d718f82acb59c9d418196b7b4290ed47a12ce62df56"}, + {file = "pickleshare-0.7.5.tar.gz", hash = "sha256:87683d47965c1da65cdacaf31c8441d12b8044cdec9aca500cd78fc2c683afca"}, +] + [[package]] name = "platformdirs" version = "3.11.0" @@ -531,6 +738,45 @@ files = [ docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.1)", "sphinx-autodoc-typehints (>=1.24)"] test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.4)", "pytest-cov (>=4.1)", "pytest-mock (>=3.11.1)"] +[[package]] +name = "prompt-toolkit" +version = "3.0.39" +description = "Library for building powerful interactive command lines in Python" +optional = false +python-versions = ">=3.7.0" +files = [ + {file = "prompt_toolkit-3.0.39-py3-none-any.whl", hash = "sha256:9dffbe1d8acf91e3de75f3b544e4842382fc06c6babe903ac9acb74dc6e08d88"}, + {file = "prompt_toolkit-3.0.39.tar.gz", hash = "sha256:04505ade687dc26dc4284b1ad19a83be2f2afe83e7a828ace0c72f3a1df72aac"}, +] + +[package.dependencies] +wcwidth = "*" + +[[package]] +name = "ptyprocess" +version = "0.7.0" +description = "Run a subprocess in a pseudo terminal" +optional = false +python-versions = "*" +files = [ + {file = "ptyprocess-0.7.0-py2.py3-none-any.whl", hash = "sha256:4b41f3967fce3af57cc7e94b888626c18bf37a083e3651ca8feeb66d492fef35"}, + {file = "ptyprocess-0.7.0.tar.gz", hash = "sha256:5c5d0a3b48ceee0b48485e0c26037c0acd7d29765ca3fbb5cb3831d347423220"}, +] + +[[package]] +name = "pure-eval" +version = "0.2.2" +description = "Safely evaluate AST nodes without side effects" +optional = false +python-versions = "*" +files = [ + {file = "pure_eval-0.2.2-py3-none-any.whl", hash = "sha256:01eaab343580944bc56080ebe0a674b39ec44a945e6d09ba7db3cb8cec289350"}, + {file = "pure_eval-0.2.2.tar.gz", hash = "sha256:2b45320af6dfaa1750f543d714b6d1c520a1688dec6fd24d339063ce0aaa9ac3"}, +] + +[package.extras] +tests = ["pytest"] + [[package]] name = "pycparser" version = "2.21" @@ -577,6 +823,17 @@ urllib3 = ">=1.21.1,<3" socks = ["PySocks (>=1.5.6,!=1.5.7)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] +[[package]] +name = "six" +version = "1.16.0" +description = "Python 2 and 3 compatibility utilities" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" +files = [ + {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"}, + {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, +] + [[package]] name = "snowballstemmer" version = "2.2.0" @@ -779,6 +1036,25 @@ Sphinx = ">=5" lint = ["docutils-stubs", "flake8", "mypy"] test = ["pytest"] +[[package]] +name = "stack-data" +version = "0.6.3" +description = "Extract data from python stack frames and tracebacks for informative displays" +optional = false +python-versions = "*" +files = [ + {file = "stack_data-0.6.3-py3-none-any.whl", hash = "sha256:d5558e0c25a4cb0853cddad3d77da9891a08cb85dd9f9f91b9f8cd66e511e695"}, + {file = "stack_data-0.6.3.tar.gz", hash = "sha256:836a778de4fec4dcd1dcd89ed8abff8a221f58308462e1c4aa2a3cf30148f0b9"}, +] + +[package.dependencies] +asttokens = ">=2.1.0" +executing = ">=1.2.0" +pure-eval = "*" + +[package.extras] +tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"] + [[package]] name = "tomli" version = "2.0.1" @@ -790,6 +1066,21 @@ files = [ {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, ] +[[package]] +name = "traitlets" +version = "5.11.2" +description = "Traitlets Python configuration system" +optional = false +python-versions = ">=3.8" +files = [ + {file = "traitlets-5.11.2-py3-none-any.whl", hash = "sha256:98277f247f18b2c5cabaf4af369187754f4fb0e85911d473f72329db8a7f4fae"}, + {file = "traitlets-5.11.2.tar.gz", hash = "sha256:7564b5bf8d38c40fa45498072bf4dc5e8346eb087bbf1e2ae2d8774f6a0f078e"}, +] + +[package.extras] +docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"] +test = ["argcomplete (>=3.0.3)", "mypy (>=1.5.1)", "pre-commit", "pytest (>=7.0,<7.5)", "pytest-mock", "pytest-mypy-testing"] + [[package]] name = "typing-extensions" version = "4.8.0" @@ -818,6 +1109,17 @@ secure = ["certifi", "cryptography (>=1.9)", "idna (>=2.0.0)", "pyopenssl (>=17. socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] +[[package]] +name = "wcwidth" +version = "0.2.8" +description = "Measures the displayed width of unicode strings in a terminal" +optional = false +python-versions = "*" +files = [ + {file = "wcwidth-0.2.8-py2.py3-none-any.whl", hash = "sha256:77f719e01648ed600dfa5402c347481c0992263b81a027344f3e1ba25493a704"}, + {file = "wcwidth-0.2.8.tar.gz", hash = "sha256:8705c569999ffbb4f6a87c6d1b80f324bd6db952f5eb0b95bc07517f4c1813d4"}, +] + [[package]] name = "zipp" version = "3.17.0" @@ -836,4 +1138,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "bd6b6d2345af9c5ccd7f9f3159ad81e2e98869b51d5e35dda5ad267f3f9ef9ea" +content-hash = "74e92a5dca3335c6a4cc73bf48030d131be41f376fe68bf77b174847338b6ca3" diff --git a/pyproject.toml b/pyproject.toml index 52f1378..40b0545 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ sphinx = "^7.2.6" sphinx-rtd-theme = "^1.3.0" sphinx-autodoc-typehints = "^1.24.0" mypy = "^1.5.1" +ipdb = "^0.13.13" [build-system] requires = ["poetry-core>=1.0.0"] diff --git a/tests/test_extism.py b/tests/test_extism.py index 7e620e5..7bc4840 100644 --- a/tests/test_extism.py +++ b/tests/test_extism.py @@ -6,6 +6,17 @@ from threading import Thread from datetime import datetime, timedelta from os.path import join, dirname +import typing +import pickle + + +# A pickle-able object. +class Gribble: + def __init__(self, v): + self.v = v + + def frobbitz(self): + return "gromble %s" % self.v class TestExtism(unittest.TestCase): @@ -55,26 +66,79 @@ def test_extism_plugin_timeout(self): ) def test_extism_host_function(self): - @extism.host_fn + @extism.host_fn( + signature=([extism.ValType.I64], [extism.ValType.I64]), user_data=b"test" + ) def hello_world(plugin, params, results, user_data): offs = plugin.alloc(len(user_data)) mem = plugin.memory(offs) mem[:] = user_data results[0].value = offs.offset - f = [ - extism.Function( - "hello_world", - [extism.ValType.I64], - [extism.ValType.I64], - hello_world, - b"test", - ) - ] - plugin = extism.Plugin(self._manifest(functions=True), functions=f, wasi=True) + plugin = extism.Plugin( + self._manifest(functions=True), functions=[hello_world], wasi=True + ) res = plugin.call("count_vowels", "aaa") self.assertEqual(res, b"test") + def test_inferred_extism_host_function(self): + @extism.host_fn(user_data=b"test") + def hello_world(inp: str, *user_data) -> str: + return "hello world: %s %s" % (inp, user_data[0].decode("utf-8")) + + plugin = extism.Plugin( + self._manifest(functions=True), functions=[hello_world], wasi=True + ) + res = plugin.call("count_vowels", "aaa") + self.assertEqual(res, b'hello world: {"count": 3} test') + + def test_inferred_json_param_extism_host_function(self): + @extism.host_fn(user_data=b"test") + def hello_world(inp: typing.Annotated[dict, extism.Json], *user_data) -> str: + return "hello world: %s %s" % (inp["count"], user_data[0].decode("utf-8")) + + plugin = extism.Plugin( + self._manifest(functions=True), functions=[hello_world], wasi=True + ) + res = plugin.call("count_vowels", "aaa") + self.assertEqual(res, b"hello world: 3 test") + + def test_codecs(self): + @extism.host_fn(user_data=b"test") + def hello_world( + inp: typing.Annotated[ + str, extism.Codec(lambda xs: xs.decode().replace("o", "u")) + ], + *user_data + ) -> typing.Annotated[ + str, extism.Codec(lambda xs: xs.replace("u", "a").encode()) + ]: + return inp + + foo = b"bar" + plugin = extism.Plugin( + self._manifest(functions=True), functions=[hello_world], wasi=True + ) + res = plugin.call("count_vowels", "aaa") + # Iiiiiii + self.assertEqual(res, b'{"caant": 3}') # stand it, I know you planned it + + def test_inferred_pickle_return_param_extism_host_function(self): + @extism.host_fn(user_data=b"test") + def hello_world( + inp: typing.Annotated[dict, extism.Json], *user_data + ) -> typing.Annotated[Gribble, extism.Pickle]: + return Gribble("robble") + + plugin = extism.Plugin( + self._manifest(functions=True), functions=[hello_world], wasi=True + ) + res = plugin.call("count_vowels", "aaa") + + result = pickle.loads(res) + self.assertIsInstance(result, Gribble) + self.assertEqual(result.frobbitz(), "gromble robble") + def test_extism_plugin_cancel(self): plugin = extism.Plugin(self._loop_manifest()) cancel_handle = plugin.cancel_handle()