diff --git a/extism/__init__.py b/extism/__init__.py index ce4a2c4..de51448 100644 --- a/extism/__init__.py +++ b/extism/__init__.py @@ -15,6 +15,7 @@ ValType, Val, CurrentPlugin, + TypedPlugin, Codec, Json, Pickle, @@ -30,6 +31,7 @@ "Memory", "host_fn", "CurrentPlugin", + "TypedPlugin", "Function", "ValType", "Val", diff --git a/extism/extism.py b/extism/extism.py index dedd51e..5c67ca3 100644 --- a/extism/extism.py +++ b/extism/extism.py @@ -21,6 +21,7 @@ from annotated_types import Gt import functools import pickle +import inspect HOST_FN_REGISTRY: List[Any] = [] @@ -840,3 +841,105 @@ def handle_args(current, inputs, n_inputs, outputs, n_outputs, user_data): for i in range(n_outputs): _convert_output(outputs[i], outp[i]) + + +def _get_typed_plugin_parser(xs): + string = lambda x: x[:].decode() + if xs == str: + return string + + if xs == bytes: + return lambda x: x[:] + + metadata = getattr(xs, "__metadata__", ()) + for item in metadata: + if item == Json: + return lambda x: json.loads(x[:]) + + if item == Pickle: + return lambda x: pickle.loads(x[:]) + + if isinstance(item, Codec): + return lambda x: item.codec(x[:]) + + raise TypeError("Could not infer return type") + + +def _get_typed_plugin_encoder(xs): + if xs == str: + return lambda x: x + + if xs == bytes: + return lambda x: x + + metadata = getattr(xs, "__metadata__", ()) + for item in metadata: + if item == Json: + return lambda x: json.dumps(x) + + if item == Pickle: + return lambda x: pickle.dumps(x) + + if isinstance(item, Codec): + return lambda x: item.codec(x[:]) + + raise TypeError("Could not infer input type") + + +class TypedPlugin: + """ + Allows plugins to be defined as classes that get transformed into Plugin function calls + """ + + plugin: Plugin + + def __init__(self, *args, **kw): + """ + Initialize a typed plugin, this will check all the class methods function names to make sure + they're registered. Since this wraps `Plugin.call` the behavior is the same when the method is + untyped, however type annotations can be included to specify a particular encoding. + + :param plugin: An extisting plugin object or parameters to be forwarded to `Plugin.__init__` + """ + if len(args) > 0 and isinstance(args[0], Plugin): + self.plugin = args[0] + else: + self.plugin = Plugin(*args, **kw) + + # Wrap methods + methods = inspect.getmembers(self, predicate=inspect.ismethod) + for name, m in methods: + if name == "__init__": + continue + + if not self.plugin.function_exists(name): + raise Error(f"Function not found in {self.__class__.__name__}: {name}") + + hints = get_type_hints(m, include_extras=True) + if len(hints) > 2: + raise Error( + f"TypedPlugin methods should take a single input parameter, there are {len(hints)} in {name}" + ) + + parse_return = _get_typed_plugin_parser(hints["return"]) + encode_input = lambda x: x + + n = 0 + for k, v in hints.items(): + if k == "return": + continue + n += 1 + encode = _get_typed_plugin_encoder(v) + if n == 0: + + def func(input): + return self.plugin.call(name, b"", parse=parse_return) + + else: + + def func(input): + return self.plugin.call( + name, encode_input(input), parse=parse_return + ) + + self.__setattr__(name, func) diff --git a/tests/test_extism.py b/tests/test_extism.py index 3d42539..a483126 100644 --- a/tests/test_extism.py +++ b/tests/test_extism.py @@ -19,6 +19,19 @@ def frobbitz(self): return "gromble %s" % self.v +class Typed(extism.TypedPlugin): + def count_vowels(self, input: str) -> typing.Annotated[str, extism.Json]: + raise NotImplementedError + + +class TypedIntCodec(extism.TypedPlugin): + def count_vowels( + self, + input: str, + ) -> typing.Annotated[int, extism.Codec(lambda x: json.loads(x[:])["count"])]: + raise NotImplementedError + + class TestExtism(unittest.TestCase): def test_call_plugin(self): plugin = extism.Plugin(self._manifest()) @@ -159,6 +172,23 @@ def cancel(handle): Thread(target=cancel, args=[cancel_handle]).run() self.assertRaises(extism.Error, lambda: plugin.call("infinite_loop", b"")) + def test_typed_plugin(self): + t = Typed(extism.Plugin(self._count_vowels_wasm(), wasi=True)) + res = t.count_vowels("foobar") + self.assertEqual(type(res), dict) + self.assertEqual(res, {"count": 3, "total": 3, "vowels": "aeiouAEIOU"}) + + def test_typed_plugin_codec(self): + t = TypedIntCodec(self._count_vowels_wasm(), wasi=True) + res = t.count_vowels("foobar") + self.assertEqual(type(res), int) + self.assertEqual(res, 3) + + def test_failed_typed_plugin(self): + self.assertRaises( + extism.Error, lambda: TypedIntCodec(self._loop_manifest(), wasi=True) + ) + def _manifest(self, functions=False): wasm = self._count_vowels_wasm(functions) hash = hashlib.sha256(wasm).hexdigest()