From aad948eb296fc86df21eb841b71e0819247f632e Mon Sep 17 00:00:00 2001 From: seem Date: Tue, 31 May 2022 19:23:59 +0200 Subject: [PATCH] enable more robust multiple dispatch with `plum` --- fastcore/_nbdev.py | 6 +- fastcore/basics.py | 2 + fastcore/dispatch.py | 219 ++++----- fastcore/imports.py | 10 + fastcore/transform.py | 52 +- nbs/01_basics.ipynb | 16 +- nbs/04_dispatch.ipynb | 1066 +++++++++++++++++++++++++--------------- nbs/05_transform.ipynb | 124 +++-- setup.py | 2 +- 9 files changed, 907 insertions(+), 590 deletions(-) diff --git a/fastcore/_nbdev.py b/fastcore/_nbdev.py index 13acbd97b..20e0a973d 100644 --- a/fastcore/_nbdev.py +++ b/fastcore/_nbdev.py @@ -227,10 +227,8 @@ "do_request": "03b_net.ipynb", "start_server": "03b_net.ipynb", "start_client": "03b_net.ipynb", - "lenient_issubclass": "04_dispatch.ipynb", - "sorted_topologically": "04_dispatch.ipynb", - "TypeDispatch": "04_dispatch.ipynb", - "DispatchReg": "04_dispatch.ipynb", + "FastFunction": "04_dispatch.ipynb", + "FastDispatcher": "04_dispatch.ipynb", "typedispatch": "04_dispatch.ipynb", "retain_meta": "04_dispatch.ipynb", "default_set_meta": "04_dispatch.ipynb", diff --git a/fastcore/basics.py b/fastcore/basics.py index c23e9d179..a5b46edf5 100644 --- a/fastcore/basics.py +++ b/fastcore/basics.py @@ -888,6 +888,8 @@ def copy_func(f): fn = FunctionType(f.__code__, f.__globals__, f.__name__, f.__defaults__, f.__closure__) fn.__kwdefaults__ = f.__kwdefaults__ fn.__dict__.update(f.__dict__) + fn.__annotations__.update(f.__annotations__) + fn.__qualname__ = f.__qualname__ return fn # Cell diff --git a/fastcore/dispatch.py b/fastcore/dispatch.py index a498f6676..4c45757a1 100644 --- a/fastcore/dispatch.py +++ b/fastcore/dispatch.py @@ -4,154 +4,113 @@ from __future__ import annotations -__all__ = ['lenient_issubclass', 'sorted_topologically', 'TypeDispatch', 'DispatchReg', 'typedispatch', 'cast', - 'retain_meta', 'default_set_meta', 'retain_type', 'retain_types', 'explode_types'] +__all__ = ['FastFunction', 'FastDispatcher', 'typedispatch', 'cast', 'retain_meta', 'default_set_meta', 'retain_type', + 'retain_types', 'explode_types'] # Cell #nbdev_comment from __future__ import annotations from .imports import * from .foundation import * from .utils import * +from .meta import delegates from collections import defaultdict +from plum import add_conversion_method, dispatch, Function, Dispatcher # Cell -def lenient_issubclass(cls, types): - "If possible return whether `cls` is a subclass of `types`, otherwise return False." - if cls is object and types is not object: return False # treat `object` as highest level - try: return isinstance(cls, types) or issubclass(cls, types) - except: return False +def _eval_annotations(f): + # Evaluate future annotations before passing to plum to support backported union operator `|` + # TODO: Could raise deprecation warning here on tuples... + f = copy_func(f) + for k, v in type_hints(f).items(): f.__annotations__[k] = Union[v] if isinstance(v, tuple) else v + return f # Cell -def sorted_topologically(iterable, *, cmp=operator.lt, reverse=False): - "Return a new list containing all items from the iterable sorted topologically" - l,res = L(list(iterable)),[] - for _ in range(len(l)): - t = l.reduce(lambda x,y: y if cmp(y,x) else x) - res.append(t), l.remove(t) - return res[::-1] if reverse else res +def _pt_repr(o): + n = type(o).__name__ + if n == 'Tuple': return f"{n.lower()}[{','.join(_pt_repr(t) for t in o._el_types)}]" + if n == 'List': return f'{n.lower()}[{_pt_repr(o._el_type)}]' + if n == 'Dict': return f'{n.lower()}[{_pt_repr(o._key_type)},{_pt_repr(o._value_type)}]' + if n in ('Sequence','Iterable'): return f'{n}[{_pt_repr(o._el_type)}]' + if n == 'VarArgs': return f'{n}[{_pt_repr(o.type)}]' + if n == 'Union': return '|'.join(sorted(t.__name__ for t in (o.get_types()))) + assert len(o.get_types()) == 1 + return o.get_types()[0].__name__ # Cell -def _chk_defaults(f, ann): - pass -# Implementation removed until we can figure out how to do this without `inspect` module -# try: # Some callables don't have signatures, so ignore those errors -# params = list(inspect.signature(f).parameters.values())[:min(len(ann),2)] -# if any(p.default!=inspect.Parameter.empty for p in params): -# warn(f"{f.__name__} has default params. These will be ignored.") -# except ValueError: pass +class FastFunction(Function): + def __init__(self, fs, owner=None): + if not isinstance(fs, (tuple,list)): fs=(fs,) + super().__init__(fs[0], owner) + for f in fs: self.dispatch(f) + @delegates(Function.dispatch) + def dispatch(self, f=None, **kwargs): return super().dispatch(_eval_annotations(f), **kwargs) +# def register(self, signature, f, precedence=0, return_type=object, delayed=None): +# return super().register(signature, _eval_annotations(f), precedence, return_type, delayed) + def __repr__(self): return '\n'.join(f"{f.__name__}: ({','.join(_pt_repr(t) for t in s.types)}) -> {_pt_repr(r)}" + for s, (f, r) in self.methods.items()) + def __getitem__(self, ts): + ts = L(ts) + nargs = min(len(o) for o in self.methods.keys()) + while len(ts) < nargs: ts.append(object) + return self.invoke(*ts) # Cell -def _p2_anno(f): - "Get the 1st 2 annotations of `f`, defaulting to `object`" - hints = type_hints(f) - ann = [o for n,o in hints.items() if n!='return'] - if callable(f): _chk_defaults(f, ann) - while len(ann)<2: ann.append(object) - return ann[:2] +class FastDispatcher(Dispatcher): + def _get_function(self, method, owner): + "Adapted from Dispatcher._get_function to use FastFunction" + name = method.__name__ + if owner: + if owner not in self._classes: self._classes[owner] = {} + namespace = self._classes[owner] + else: namespace = self._functions + if name not in namespace: namespace[name] = FastFunction(method, owner=owner) + return namespace[name] + + def _call(self, f, **kwargs): return super().__call__(_eval_annotations(f), **kwargs) + + @delegates(Dispatcher.__call__, but='method') + def __call__(self, f, **kwargs): + ann,glb,loc = get_annotations_ex(f) + if ann: + k,cls = next(iter(ann.items())) + if k == 'self': + cls = union2tuple(eval_type(cls, glb, loc)) + return self.to(cls)(f, **kwargs) + return self._call(f, **kwargs) + + def _to(self, cls, nm, f, **kwargs): + nf = copy_func(f) + nf.__qualname__ = f'{cls.__name__}.{nm}' # plum uses __qualname__ to infer f's owner + pf = self._call(nf, **kwargs) + # plum uses __set_name__ to resolve a plum.Function's owner + # since we assign after class creation, __set_name__ must be called directly + # source: https://docs.python.org/3/reference/datamodel.html#object.__set_name__ + pf.__set_name__(cls, nm) + pf = pf.resolve() + setattr(cls, nm, pf) + return pf + + def to(self, cls): + "Decorator: dispatch `f` to `cls.f`" + if not isinstance(cls, (tuple,list)): cls = (cls,) + @delegates(self.__call__) + def inner(f, **kwargs): + for c_ in cls: + nm = f.__name__ + # check __dict__ to avoid inherited methods + # but use getattr so that pf.__get__ is called + pf = getattr(c_,nm) if nm in c_.__dict__ else None + if pf is not None: + if not hasattr(pf, 'dispatch'): + pf = self._to(c_, nm, pf, **kwargs) + pf.dispatch(f) + else: pf = self._to(c_, nm, f, **kwargs) + return pf + return inner # Cell -class _TypeDict: - def __init__(self): self.d,self.cache = {},{} - - def _reset(self): - self.d = {k:self.d[k] for k in sorted_topologically(self.d, cmp=lenient_issubclass)} - self.cache = {} - - def add(self, t, f): - "Add type `t` and function `f`" - if not isinstance(t, tuple): t = tuple(L(union2tuple(t))) - for t_ in t: self.d[t_] = f - self._reset() - - def all_matches(self, k): - "Find first matching type that is a super-class of `k`" - if k not in self.cache: - types = [f for f in self.d if lenient_issubclass(k,f)] - self.cache[k] = [self.d[o] for o in types] - return self.cache[k] - - def __getitem__(self, k): - "Find first matching type that is a super-class of `k`" - res = self.all_matches(k) - return res[0] if len(res) else None - - def __repr__(self): return self.d.__repr__() - def first(self): return first(self.d.values()) - -# Cell -class TypeDispatch: - "Dictionary-like object; `__getitem__` matches keys of types using `issubclass`" - def __init__(self, funcs=(), bases=()): - self.funcs,self.bases = _TypeDict(),L(bases).filter(is_not(None)) - for o in L(funcs): self.add(o) - self.inst = None - self.owner = None - - def add(self, f): - "Add type `t` and function `f`" - if isinstance(f, staticmethod): a0,a1 = _p2_anno(f.__func__) - else: a0,a1 = _p2_anno(f) - t = self.funcs.d.get(a0) - if t is None: - t = _TypeDict() - self.funcs.add(a0, t) - t.add(a1, f) - - def first(self): - "Get first function in ordered dict of type:func." - return self.funcs.first().first() - - def returns(self, x): - "Get the return type of annotation of `x`." - return anno_ret(self[type(x)]) - - def _attname(self,k): return getattr(k,'__name__',str(k)) - def __repr__(self): - r = [f'({self._attname(k)},{self._attname(l)}) -> {getattr(v, "__name__", type(v).__name__)}' - for k in self.funcs.d for l,v in self.funcs[k].d.items()] - r = r + [o.__repr__() for o in self.bases] - return '\n'.join(r) - - def __call__(self, *args, **kwargs): - ts = L(args).map(type)[:2] - f = self[tuple(ts)] - if not f: return args[0] - if isinstance(f, staticmethod): f = f.__func__ - elif self.inst is not None: f = MethodType(f, self.inst) - elif self.owner is not None: f = MethodType(f, self.owner) - return f(*args, **kwargs) - - def __get__(self, inst, owner): - self.inst = inst - self.owner = owner - return self - - def __getitem__(self, k): - "Find first matching type that is a super-class of `k`" - k = L(k) - while len(k)<2: k.append(object) - r = self.funcs.all_matches(k[0]) - for t in r: - o = t[k[1]] - if o is not None: return o - for base in self.bases: - res = base[k] - if res is not None: return res - return None - -# Cell -class DispatchReg: - "A global registry for `TypeDispatch` objects keyed by function name" - def __init__(self): self.d = defaultdict(TypeDispatch) - def __call__(self, f): - if isinstance(f, (classmethod, staticmethod)): nm = f'{f.__func__.__qualname__}' - else: nm = f'{f.__qualname__}' - if isinstance(f, classmethod): f=f.__func__ - self.d[nm].add(f) - return self.d[nm] - -typedispatch = DispatchReg() +typedispatch = FastDispatcher() # Cell #nbdev_comment _all_=['cast'] diff --git a/fastcore/imports.py b/fastcore/imports.py index 086c899fd..12665e594 100644 --- a/fastcore/imports.py +++ b/fastcore/imports.py @@ -1,5 +1,6 @@ import sys,os,re,typing,itertools,operator,functools,math,warnings,functools,io,enum +from copy import copy from operator import itemgetter,attrgetter from warnings import warn from typing import Iterable,Generator,Sequence,Iterator,List,Set,Dict,Union,Optional,Tuple @@ -14,6 +15,15 @@ MethodDescriptorType = type(str.join) from types import BuiltinFunctionType,BuiltinMethodType,MethodType,FunctionType,SimpleNamespace +#Patch autoreload (if its loaded) to work with plum +try: from IPython import get_ipython +except ImportError: pass +else: + ip = get_ipython() + if ip is not None and 'IPython.extensions.storemagic' in ip.extension_manager.loaded: + from plum.autoreload import activate + activate() + NoneType = type(None) string_classes = (str,bytes) diff --git a/fastcore/transform.py b/fastcore/transform.py index 7f3e2efc8..07d891e28 100644 --- a/fastcore/transform.py +++ b/fastcore/transform.py @@ -9,34 +9,34 @@ from .utils import * from .dispatch import * import inspect +from copy import copy +from plum import add_conversion_method # Cell _tfm_methods = 'encodes','decodes','setups' +def _is_tfm_method(n, f): return n in _tfm_methods and callable(f) + class _TfmDict(dict): - def __setitem__(self,k,v): - if k not in _tfm_methods or not callable(v): return super().__setitem__(k,v) - if k not in self: super().__setitem__(k,TypeDispatch()) - self[k].add(v) + def __setitem__(self, k, v): super().__setitem__(k, typedispatch(v) if _is_tfm_method(k, v) else v) # Cell class _TfmMeta(type): def __new__(cls, name, bases, dict): + # _TfmMeta.__call__ shadows the signature of inheriting classes, set it back res = super().__new__(cls, name, bases, dict) - for nm in _tfm_methods: - base_td = [getattr(b,nm,None) for b in bases] - if nm in res.__dict__: getattr(res,nm).bases = base_td - else: setattr(res, nm, TypeDispatch(bases=base_td)) res.__signature__ = inspect.signature(res.__init__) return res def __call__(cls, *args, **kwargs): - f = args[0] if args else None - n = getattr(f,'__name__',None) - if callable(f) and n in _tfm_methods: - getattr(cls,n).add(f) - return f - return super().__call__(*args, **kwargs) + f = first(args) + n = getattr(f, '__name__', None) + if _is_tfm_method(n, f): return typedispatch.to(cls)(f) + obj = super().__call__(*args, **kwargs) + # _TfmMeta.__new__ replaces cls.__signature__ which breaks the signature of a callable + # instances of cls, fix it + if hasattr(obj, '__call__'): obj.__signature__ = inspect.signature(obj.__call__) + return obj @classmethod def __prepare__(cls, name, bases): return _TfmDict() @@ -60,19 +60,20 @@ def __init__(self, enc=None, dec=None, split_idx=None, order=None): self.init_enc = enc or dec if not self.init_enc: return - self.encodes,self.decodes,self.setups = TypeDispatch(),TypeDispatch(),TypeDispatch() + def identity(x): return x + for n in _tfm_methods: setattr(self,n,FastFunction(identity).dispatch(identity)) if enc: - self.encodes.add(enc) + self.encodes.dispatch(enc) self.order = getattr(enc,'order',self.order) if len(type_hints(enc)) > 0: self.input_types = union2tuple(first(type_hints(enc).values())) self._name = _get_name(enc) - if dec: self.decodes.add(dec) + if dec: self.decodes.dispatch(dec) @property def name(self): return getattr(self, '_name', _get_name(self)) def __call__(self, x, **kwargs): return self._call('encodes', x, **kwargs) def decode (self, x, **kwargs): return self._call('decodes', x, **kwargs) - def __repr__(self): return f'{self.name}:\nencodes: {self.encodes}decodes: {self.decodes}' + def __repr__(self): return f'{self.name}:\nencodes: {self.encodes}\ndecodes: {self.decodes}' def setup(self, items=None, train_setup=False): train_setup = train_setup if self.train_setup is None else self.train_setup @@ -85,13 +86,24 @@ def _call(self, fn, x, split_idx=None, **kwargs): def _do_call(self, f, x, **kwargs): if not _is_tuple(x): if f is None: return x - ret = f.returns(x) if hasattr(f,'returns') else None - return retain_type(f(x, **kwargs), x, ret) + ts = [type(self),type(x)] if hasattr(f,'instance') else [type(x)] + _, ret = f.resolve_method(*ts) + ret = ret._type + # plum reads empty return annotation as object, retain_type expects it as None + if ret is object: ret = None + return retain_type(f(x,**kwargs), x, ret) res = tuple(self._do_call(f, x_, **kwargs) for x_ in x) return retain_type(res, x) + def encodes(self, x): return x + def decodes(self, x): return x + def setups(self, dl): return dl add_docs(Transform, decode="Delegate to decodes to undo transform", setup="Delegate to setups to set up transform") +# Cell +#Implement the Transform convention that a None return annotation disables conversion +add_conversion_method(object, NoneType, lambda x: x) + # Cell class InplaceTransform(Transform): "A `Transform` that modifies in-place and just returns whatever it's passed" diff --git a/nbs/01_basics.ipynb b/nbs/01_basics.ipynb index cb3c66a0f..756ca23de 100644 --- a/nbs/01_basics.ipynb +++ b/nbs/01_basics.ipynb @@ -804,7 +804,7 @@ { "data": { "text/markdown": [ - "

noop[source]

\n", + "

noop[source]

\n", "\n", "> noop(**`x`**=*`None`*, **\\*`args`**, **\\*\\*`kwargs`**)\n", "\n", @@ -840,7 +840,7 @@ { "data": { "text/markdown": [ - "

noops[source]

\n", + "

noops[source]

\n", "\n", "> noops(**`self`**, **`x`**=*`None`*, **\\*`args`**, **\\*\\*`kwargs`**)\n", "\n", @@ -4676,6 +4676,8 @@ " fn = FunctionType(f.__code__, f.__globals__, f.__name__, f.__defaults__, f.__closure__)\n", " fn.__kwdefaults__ = f.__kwdefaults__\n", " fn.__dict__.update(f.__dict__)\n", + " fn.__annotations__.update(f.__annotations__)\n", + " fn.__qualname__ = f.__qualname__\n", " return fn" ] }, @@ -5635,7 +5637,7 @@ { "data": { "text/markdown": [ - "

ipython_shell[source]

\n", + "

ipython_shell[source]

\n", "\n", "> ipython_shell()\n", "\n", @@ -5661,7 +5663,7 @@ { "data": { "text/markdown": [ - "

in_ipython[source]

\n", + "

in_ipython[source]

\n", "\n", "> in_ipython()\n", "\n", @@ -5687,7 +5689,7 @@ { "data": { "text/markdown": [ - "

in_colab[source]

\n", + "

in_colab[source]

\n", "\n", "> in_colab()\n", "\n", @@ -5713,7 +5715,7 @@ { "data": { "text/markdown": [ - "

in_jupyter[source]

\n", + "

in_jupyter[source]

\n", "\n", "> in_jupyter()\n", "\n", @@ -5739,7 +5741,7 @@ { "data": { "text/markdown": [ - "

in_notebook[source]

\n", + "

in_notebook[source]

\n", "\n", "> in_notebook()\n", "\n", diff --git a/nbs/04_dispatch.ipynb b/nbs/04_dispatch.ipynb index 08b6bcbea..e70dcd4a8 100644 --- a/nbs/04_dispatch.ipynb +++ b/nbs/04_dispatch.ipynb @@ -20,8 +20,10 @@ "from fastcore.imports import *\n", "from fastcore.foundation import *\n", "from fastcore.utils import *\n", + "from fastcore.meta import delegates\n", "\n", - "from collections import defaultdict" + "from collections import defaultdict\n", + "from plum import add_conversion_method, dispatch, Function, Dispatcher" ] }, { @@ -45,10 +47,33 @@ ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "metadata": {}, + "outputs": [], "source": [ - "## Helpers" + "#export\n", + "def _eval_annotations(f):\n", + " # Evaluate future annotations before passing to plum to support backported union operator `|`\n", + " # TODO: Could raise deprecation warning here on tuples...\n", + " f = copy_func(f)\n", + " for k, v in type_hints(f).items(): f.__annotations__[k] = Union[v] if isinstance(v, tuple) else v\n", + " return f" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#hide\n", + "def f(x:int|str) -> float: pass\n", + "test_eq(_eval_annotations(f).__annotations__, {'x': typing.Union[int, str], 'return': float})\n", + "def f(x:(int,str)) -> float: pass\n", + "test_eq(_eval_annotations(f).__annotations__, {'x': typing.Union[int, str], 'return': float})\n", + "def f(x): pass\n", + "test_eq(_eval_annotations(f).__annotations__, {})" ] }, { @@ -58,11 +83,16 @@ "outputs": [], "source": [ "#export\n", - "def lenient_issubclass(cls, types):\n", - " \"If possible return whether `cls` is a subclass of `types`, otherwise return False.\"\n", - " if cls is object and types is not object: return False # treat `object` as highest level\n", - " try: return isinstance(cls, types) or issubclass(cls, types)\n", - " except: return False" + "def _pt_repr(o):\n", + " n = type(o).__name__\n", + " if n == 'Tuple': return f\"{n.lower()}[{','.join(_pt_repr(t) for t in o._el_types)}]\"\n", + " if n == 'List': return f'{n.lower()}[{_pt_repr(o._el_type)}]'\n", + " if n == 'Dict': return f'{n.lower()}[{_pt_repr(o._key_type)},{_pt_repr(o._value_type)}]'\n", + " if n in ('Sequence','Iterable'): return f'{n}[{_pt_repr(o._el_type)}]'\n", + " if n == 'VarArgs': return f'{n}[{_pt_repr(o.type)}]'\n", + " if n == 'Union': return '|'.join(sorted(t.__name__ for t in (o.get_types())))\n", + " assert len(o.get_types()) == 1\n", + " return o.get_types()[0].__name__" ] }, { @@ -71,12 +101,20 @@ "metadata": {}, "outputs": [], "source": [ - "assert not lenient_issubclass(typing.Collection, list)\n", - "assert lenient_issubclass(list, typing.Collection)\n", - "assert lenient_issubclass(typing.Collection, object)\n", - "assert lenient_issubclass(typing.List, typing.Collection)\n", - "assert not lenient_issubclass(typing.Collection, typing.List)\n", - "assert not lenient_issubclass(object, typing.Callable)" + "#hide\n", + "from typing import Dict, List, Iterable, Sequence, Tuple\n", + "from plum.type import VarArgs, ptype\n", + "\n", + "test_eq(_pt_repr(ptype(int)), 'int')\n", + "test_eq(_pt_repr(ptype(Union[int, str])), 'int|str')\n", + "test_eq(_pt_repr(ptype(Tuple[int, str])), 'tuple[int,str]')\n", + "test_eq(_pt_repr(ptype(List[int])), 'list[int]')\n", + "test_eq(_pt_repr(ptype(Sequence[int])), 'Sequence[int]')\n", + "test_eq(_pt_repr(ptype(Iterable[int])), 'Iterable[int]')\n", + "test_eq(_pt_repr(ptype(Dict[str, int])), 'dict[str,int]')\n", + "test_eq(_pt_repr(ptype(VarArgs[str])), 'VarArgs[str]')\n", + "test_eq(_pt_repr(ptype(Dict[Tuple[Union[int,str],float], List[Tuple[object]]])),\n", + " 'dict[tuple[int|str,float],list[tuple[object]]]')" ] }, { @@ -86,13 +124,22 @@ "outputs": [], "source": [ "#export\n", - "def sorted_topologically(iterable, *, cmp=operator.lt, reverse=False):\n", - " \"Return a new list containing all items from the iterable sorted topologically\"\n", - " l,res = L(list(iterable)),[]\n", - " for _ in range(len(l)):\n", - " t = l.reduce(lambda x,y: y if cmp(y,x) else x)\n", - " res.append(t), l.remove(t)\n", - " return res[::-1] if reverse else res" + "class FastFunction(Function):\n", + " def __init__(self, fs, owner=None):\n", + " if not isinstance(fs, (tuple,list)): fs=(fs,)\n", + " super().__init__(fs[0], owner)\n", + " for f in fs: self.dispatch(f)\n", + " @delegates(Function.dispatch)\n", + " def dispatch(self, f=None, **kwargs): return super().dispatch(_eval_annotations(f), **kwargs)\n", + "# def register(self, signature, f, precedence=0, return_type=object, delayed=None):\n", + "# return super().register(signature, _eval_annotations(f), precedence, return_type, delayed)\n", + " def __repr__(self): return '\\n'.join(f\"{f.__name__}: ({','.join(_pt_repr(t) for t in s.types)}) -> {_pt_repr(r)}\"\n", + " for s, (f, r) in self.methods.items())\n", + " def __getitem__(self, ts):\n", + " ts = L(ts)\n", + " nargs = min(len(o) for o in self.methods.keys())\n", + " while len(ts) < nargs: ts.append(object)\n", + " return self.invoke(*ts)" ] }, { @@ -101,9 +148,13 @@ "metadata": {}, "outputs": [], "source": [ - "td = [3, 1, 2, 5]\n", - "test_eq(sorted_topologically(td), [1, 2, 3, 5])\n", - "test_eq(sorted_topologically(td, reverse=True), [5, 3, 2, 1])" + "def f(a,b,c=1,d=2,e=3): return 'f'\n", + "f = FastFunction(f)\n", + "\n", + "test_eq(f[object](1,2,3),'f')\n", + "test_eq(f[object,object](1,2,3),'f')\n", + "test_eq(f[object,object,object](1,2,3),'f')\n", + "test_eq(f[object,object,object,object](1,2,3),'f')" ] }, { @@ -112,8 +163,11 @@ "metadata": {}, "outputs": [], "source": [ - "td = {int:1, numbers.Number:2, numbers.Integral:3}\n", - "test_eq(sorted_topologically(td, cmp=lenient_issubclass), [int, numbers.Integral, numbers.Number])" + "#hide\n", + "def _f1(x: int, y: Dict[str, float]) -> float: pass\n", + "def _f2(x: int, y: Tuple[str, float]) -> float: pass\n", + "_f = FastFunction(_f1).dispatch(_f1).dispatch(_f2)\n", + "test_eq(repr(_f), '_f1: (int,dict[str,float]) -> float\\n_f2: (int,tuple[str,float]) -> float')" ] }, { @@ -122,9 +176,7 @@ "metadata": {}, "outputs": [], "source": [ - "td = [numbers.Integral, tuple, list, int, dict]\n", - "td = sorted_topologically(td, cmp=lenient_issubclass)\n", - "assert td.index(int) < td.index(numbers.Integral)" + "# TODO: Add the identity function default???" ] }, { @@ -134,14 +186,65 @@ "outputs": [], "source": [ "#export\n", - "def _chk_defaults(f, ann):\n", - " pass\n", - "# Implementation removed until we can figure out how to do this without `inspect` module\n", - "# try: # Some callables don't have signatures, so ignore those errors\n", - "# params = list(inspect.signature(f).parameters.values())[:min(len(ann),2)]\n", - "# if any(p.default!=inspect.Parameter.empty for p in params):\n", - "# warn(f\"{f.__name__} has default params. These will be ignored.\")\n", - "# except ValueError: pass" + "class FastDispatcher(Dispatcher):\n", + " def _get_function(self, method, owner):\n", + " \"Adapted from Dispatcher._get_function to use FastFunction\"\n", + " name = method.__name__\n", + " if owner:\n", + " if owner not in self._classes: self._classes[owner] = {}\n", + " namespace = self._classes[owner]\n", + " else: namespace = self._functions\n", + " if name not in namespace: namespace[name] = FastFunction(method, owner=owner)\n", + " return namespace[name]\n", + "\n", + " def _call(self, f, **kwargs): return super().__call__(_eval_annotations(f), **kwargs)\n", + "\n", + " @delegates(Dispatcher.__call__, but='method')\n", + " def __call__(self, f, **kwargs):\n", + " ann,glb,loc = get_annotations_ex(f)\n", + " if ann:\n", + " k,cls = next(iter(ann.items()))\n", + " if k == 'self':\n", + " cls = union2tuple(eval_type(cls, glb, loc))\n", + " return self.to(cls)(f, **kwargs)\n", + " return self._call(f, **kwargs)\n", + "\n", + " def _to(self, cls, nm, f, **kwargs):\n", + " nf = copy_func(f)\n", + " nf.__qualname__ = f'{cls.__name__}.{nm}' # plum uses __qualname__ to infer f's owner\n", + " pf = self._call(nf, **kwargs)\n", + " # plum uses __set_name__ to resolve a plum.Function's owner\n", + " # since we assign after class creation, __set_name__ must be called directly\n", + " # source: https://docs.python.org/3/reference/datamodel.html#object.__set_name__\n", + " pf.__set_name__(cls, nm)\n", + " pf = pf.resolve()\n", + " setattr(cls, nm, pf)\n", + " return pf\n", + "\n", + " def to(self, cls):\n", + " \"Decorator: dispatch `f` to `cls.f`\"\n", + " if not isinstance(cls, (tuple,list)): cls = (cls,)\n", + " @delegates(self.__call__)\n", + " def inner(f, **kwargs):\n", + " for c_ in cls:\n", + " nm = f.__name__\n", + " # check __dict__ to avoid inherited methods\n", + " # but use getattr so that pf.__get__ is called\n", + " pf = getattr(c_,nm) if nm in c_.__dict__ else None\n", + " if pf is not None:\n", + " if not hasattr(pf, 'dispatch'):\n", + " pf = self._to(c_, nm, pf, **kwargs)\n", + " pf.dispatch(f)\n", + " else: pf = self._to(c_, nm, f, **kwargs)\n", + " return pf\n", + " return inner" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Maybe it makes more sense to use patch with dispatch somehow? Cause I'm kinda doing both... What could that look like?" ] }, { @@ -150,14 +253,17 @@ "metadata": {}, "outputs": [], "source": [ - "#export\n", - "def _p2_anno(f):\n", - " \"Get the 1st 2 annotations of `f`, defaulting to `object`\"\n", - " hints = type_hints(f)\n", - " ann = [o for n,o in hints.items() if n!='return']\n", - " if callable(f): _chk_defaults(f, ann)\n", - " while len(ann)<2: ann.append(object)\n", - " return ann[:2]" + "#TODO: Move somewhere better\n", + "_dispatch = FastDispatcher()\n", + "class A:\n", + " @_dispatch\n", + " def f(self, x): return 'obj'\n", + "@_dispatch.to(A)\n", + "def f(self, x:int): return 'int'\n", + "\n", + "a = A()\n", + "test_eq(a.f(0), 'int')\n", + "test_eq(a.f(''), 'obj')" ] }, { @@ -166,81 +272,124 @@ "metadata": {}, "outputs": [], "source": [ - "#hide\n", - "def _f(a): pass\n", - "test_eq(_p2_anno(_f), (object,object))\n", - "def _f(a, b): pass\n", - "test_eq(_p2_anno(_f), (object,object))\n", - "def _f(a:None, b)->str: pass\n", - "test_eq(_p2_anno(_f), (NoneType,object))\n", - "def _f(a:str, b)->float: pass\n", - "test_eq(_p2_anno(_f), (str,object))\n", - "def _f(a:None, b:str)->float: pass\n", - "test_eq(_p2_anno(_f), (NoneType,str))\n", - "def _f(a:int, b:int)->float: pass\n", - "test_eq(_p2_anno(_f), (int,int))\n", - "def _f(self, a:int, b:int): pass\n", - "test_eq(_p2_anno(_f), (int,int))\n", - "def _f(a:int, b:str)->float: pass\n", - "test_eq(_p2_anno(_f), (int,str))\n", - "test_eq(_p2_anno(attrgetter('foo')), (object,object))" + "_dispatch = FastDispatcher()\n", + "@_dispatch\n", + "def f(x,y): return 2\n", + "@_dispatch\n", + "def f(x): return 1\n", + "test_eq(f(0,0), 2)\n", + "test_eq(f(0), 1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "([object, object], [int, object])" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "#hide\n", - "# Disabled until _chk_defaults fixed\n", - "# def _f(x:int,y:int=10): pass\n", - "# test_warns(lambda: _p2_anno(_f))\n", - "def _f(x:int,y=10): pass\n", - "_p2_anno(None),_p2_anno(_f)" + "_dispatch = FastDispatcher()\n", + "@_dispatch\n", + "def f(x:int|str) -> str: return 'int|str'\n", + "test_eq(f(0), 'int|str')\n", + "test_eq(f(''), 'int|str')" ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "metadata": {}, + "outputs": [], "source": [ - "## TypeDispatch" + "@f.dispatch\n", + "def f(x:float|tuple) -> str: return 'float|tuple'" ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "metadata": {}, + "outputs": [], "source": [ - "Type dispatch, or [Multiple dispatch](https://en.wikipedia.org/wiki/Multiple_dispatch#Julia), allows you to change the way a function behaves based upon the input types it recevies. This is a prominent feature in some programming languages like Julia. For example, this is a [conceptual example](https://en.wikipedia.org/wiki/Multiple_dispatch#Julia) of how multiple dispatch works in Julia, returning different values depending on the input types of x and y:\n", + "test_eq(f(0), 'int|str')\n", + "test_eq(f(''), 'int|str')\n", + "test_eq(f(0.0), 'float|tuple')\n", + "test_eq(f(()), 'float|tuple')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "_dispatch = FastDispatcher()\n", + "class A:\n", + " @_dispatch\n", + " def f(self, x): return 'A'\n", + "class B(A):\n", + " @_dispatch\n", + " def f(self, x:int): return 'B'\n", + "b = B()\n", + "test_eq(b.f(0), 'B')\n", + "test_eq(b.f(''), 'A')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "_dispatch = FastDispatcher()\n", + "class A:\n", + " @_dispatch\n", + " def f(self, x:int|str): return 'int|str'\n", + " @_dispatch\n", + " def f(self, x:float|tuple): return 'float|tuple'\n", + "a = A()\n", + "test_eq(a.f(0), 'int|str')\n", + "test_eq(a.f(''), 'int|str')\n", + "test_eq(a.f(0.0), 'float|tuple')\n", + "test_eq(a.f(()), 'float|tuple')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "_dispatch = FastDispatcher()\n", + "class A: pass\n", "\n", - "```julia\n", - "collide_with(x::Asteroid, y::Asteroid) = ... \n", - "# deal with asteroid hitting asteroid\n", + "@_dispatch.to(A)\n", + "def f(self, x:int): return 'int'\n", "\n", - "collide_with(x::Asteroid, y::Spaceship) = ... \n", - "# deal with asteroid hitting spaceship\n", + "a = A()\n", + "test_eq(a.f(0), 'int')\n", "\n", - "collide_with(x::Spaceship, y::Asteroid) = ... \n", - "# deal with spaceship hitting asteroid\n", + "@_dispatch.to(A)\n", + "def f(self, x:str): return 'str'\n", "\n", - "collide_with(x::Spaceship, y::Spaceship) = ... \n", - "# deal with spaceship hitting spaceship\n", - "```\n", + "test_eq(a.f(''), 'str')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "_dispatch = FastDispatcher()\n", + "class A:\n", + " def f(self, x): return 'obj'\n", "\n", - "Type dispatch can be especially useful in data science, where you might allow different input types (i.e. numpy arrays and pandas dataframes) to function that processes data. Type dispatch allows you to have a common API for functions that do similar tasks.\n", + "@_dispatch.to(A)\n", + "def f(self, x:int): return 'int'\n", "\n", - "The `TypeDispatch` class allows us to achieve type dispatch in Python. It contains a dictionary that maps types from type annotations to functions, which ensures that the proper function is called when passed inputs." + "a = A()\n", + "test_eq(a.f(0), 'int')\n", + "test_eq(a.f(''), 'obj')" ] }, { @@ -249,34 +398,40 @@ "metadata": {}, "outputs": [], "source": [ - "#export\n", - "class _TypeDict:\n", - " def __init__(self): self.d,self.cache = {},{}\n", - "\n", - " def _reset(self):\n", - " self.d = {k:self.d[k] for k in sorted_topologically(self.d, cmp=lenient_issubclass)}\n", - " self.cache = {}\n", + "_dispatch = FastDispatcher()\n", + "class A:\n", + " def f(self, x): return 'obj'\n", "\n", - " def add(self, t, f):\n", - " \"Add type `t` and function `f`\"\n", - " if not isinstance(t, tuple): t = tuple(L(union2tuple(t)))\n", - " for t_ in t: self.d[t_] = f\n", - " self._reset()\n", + "@_dispatch\n", + "def f(self:A, x:int): return 'int'\n", "\n", - " def all_matches(self, k):\n", - " \"Find first matching type that is a super-class of `k`\"\n", - " if k not in self.cache:\n", - " types = [f for f in self.d if lenient_issubclass(k,f)]\n", - " self.cache[k] = [self.d[o] for o in types]\n", - " return self.cache[k]\n", + "a = A()\n", + "test_eq(a.f(0), 'int')\n", + "test_eq(a.f(''), 'obj')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "_dispatch = FastDispatcher()\n", + "class A: pass\n", + "class B: pass\n", "\n", - " def __getitem__(self, k):\n", - " \"Find first matching type that is a super-class of `k`\"\n", - " res = self.all_matches(k)\n", - " return res[0] if len(res) else None\n", + "@_dispatch\n", + "def f(self:A|B, x:int): return 'int'\n", "\n", - " def __repr__(self): return self.d.__repr__()\n", - " def first(self): return first(self.d.values())" + "test_eq(A().f(0), 'int')\n", + "test_eq(B().f(0), 'int')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Returns the last type" ] }, { @@ -285,66 +440,104 @@ "metadata": {}, "outputs": [], "source": [ - "#export\n", - "class TypeDispatch:\n", - " \"Dictionary-like object; `__getitem__` matches keys of types using `issubclass`\"\n", - " def __init__(self, funcs=(), bases=()):\n", - " self.funcs,self.bases = _TypeDict(),L(bases).filter(is_not(None))\n", - " for o in L(funcs): self.add(o)\n", - " self.inst = None\n", - " self.owner = None\n", - "\n", - " def add(self, f):\n", - " \"Add type `t` and function `f`\"\n", - " if isinstance(f, staticmethod): a0,a1 = _p2_anno(f.__func__)\n", - " else: a0,a1 = _p2_anno(f)\n", - " t = self.funcs.d.get(a0)\n", - " if t is None:\n", - " t = _TypeDict()\n", - " self.funcs.add(a0, t)\n", - " t.add(a1, f)\n", + "test_eq(f, B.f)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "_dispatch = FastDispatcher()\n", + "@_dispatch\n", + "def f(x:bool): return 'bool'\n", + "@_dispatch\n", + "def f(x:int): return 'int'\n", + "test_eq(f(True), 'bool')\n", + "test_eq(f(0), 'int')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# TODO: Make this syntax nicer?" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "_dispatch = FastDispatcher()\n", + "@_dispatch.multi((bool,),(list,))\n", + "def f(x:bool|list): return 'bool'\n", + "@_dispatch\n", + "def f(x:int): return 'int'\n", + "test_eq(f(True),'bool')\n", + "test_eq(f(0), 'int')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Leaves base class methods unaffected, but still searched." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "_dispatch = FastDispatcher()\n", + "class A:\n", + " def f(self, x): return 'A'\n", + "Af = A.f\n", + "class B(A):\n", + " @_dispatch\n", + " def f(self, x:int): return 'B'\n", + "test_is(Af, A.f)\n", + "b = B()\n", + "test_eq(b.f(0), 'B')\n", + "test_eq(b.f(''), 'A')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## TypeDispatch" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Type dispatch, or [Multiple dispatch](https://en.wikipedia.org/wiki/Multiple_dispatch#Julia), allows you to change the way a function behaves based upon the input types it recevies. This is a prominent feature in some programming languages like Julia. For example, this is a [conceptual example](https://en.wikipedia.org/wiki/Multiple_dispatch#Julia) of how multiple dispatch works in Julia, returning different values depending on the input types of x and y:\n", "\n", - " def first(self):\n", - " \"Get first function in ordered dict of type:func.\"\n", - " return self.funcs.first().first()\n", + "```julia\n", + "collide_with(x::Asteroid, y::Asteroid) = ... \n", + "# deal with asteroid hitting asteroid\n", "\n", - " def returns(self, x):\n", - " \"Get the return type of annotation of `x`.\"\n", - " return anno_ret(self[type(x)])\n", + "collide_with(x::Asteroid, y::Spaceship) = ... \n", + "# deal with asteroid hitting spaceship\n", "\n", - " def _attname(self,k): return getattr(k,'__name__',str(k))\n", - " def __repr__(self):\n", - " r = [f'({self._attname(k)},{self._attname(l)}) -> {getattr(v, \"__name__\", type(v).__name__)}'\n", - " for k in self.funcs.d for l,v in self.funcs[k].d.items()]\n", - " r = r + [o.__repr__() for o in self.bases]\n", - " return '\\n'.join(r)\n", + "collide_with(x::Spaceship, y::Asteroid) = ... \n", + "# deal with spaceship hitting asteroid\n", "\n", - " def __call__(self, *args, **kwargs):\n", - " ts = L(args).map(type)[:2]\n", - " f = self[tuple(ts)]\n", - " if not f: return args[0]\n", - " if isinstance(f, staticmethod): f = f.__func__\n", - " elif self.inst is not None: f = MethodType(f, self.inst)\n", - " elif self.owner is not None: f = MethodType(f, self.owner)\n", - " return f(*args, **kwargs)\n", + "collide_with(x::Spaceship, y::Spaceship) = ... \n", + "# deal with spaceship hitting spaceship\n", + "```\n", "\n", - " def __get__(self, inst, owner):\n", - " self.inst = inst\n", - " self.owner = owner\n", - " return self\n", + "Type dispatch can be especially useful in data science, where you might allow different input types (i.e. numpy arrays and pandas dataframes) to function that processes data. Type dispatch allows you to have a common API for functions that do similar tasks.\n", "\n", - " def __getitem__(self, k):\n", - " \"Find first matching type that is a super-class of `k`\"\n", - " k = L(k)\n", - " while len(k)<2: k.append(object)\n", - " r = self.funcs.all_matches(k[0])\n", - " for t in r:\n", - " o = t[k[1]]\n", - " if o is not None: return o\n", - " for base in self.bases:\n", - " res = base[k]\n", - " if res is not None: return res\n", - " return None" + "The `TypeDispatch` class allows us to achieve type dispatch in Python. It contains a dictionary that maps types from type annotations to functions, which ensures that the proper function is called when passed inputs." ] }, { @@ -360,18 +553,20 @@ "metadata": {}, "outputs": [], "source": [ - "def f2(x:int, y:float): return x+y #int and float for 2nd arg\n", - "def f_nin(x:numbers.Integral)->int: return x+1 #integral numeric\n", - "def f_ni2(x:int): return x #integer\n", - "def f_bll(x:bool|list): return x #bool or list\n", - "def f_num(x:numbers.Number): return x #Number (root of numerics) " + "# TODO: Make this work by init'ing FF with a list?" ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "metadata": {}, + "outputs": [], "source": [ - "We can optionally initialize `TypeDispatch` with a list of functions we want to search. Printing an instance of `TypeDispatch` will display convenient mapping of types -> functions:" + "def f2(x:int, y:float): return x+y #int and float for 2nd arg\n", + "def f_nin(x:numbers.Integral)->int: return x+1 #integral numeric\n", + "def f_ni2(x:int): return x #integer\n", + "def f_bll(x:bool|list): return x #bool or list\n", + "def f_num(x:numbers.Number): return x #Number (root of numerics) " ] }, { @@ -382,12 +577,10 @@ { "data": { "text/plain": [ - "(bool,object) -> f_bll\n", - "(int,object) -> f_ni2\n", - "(Integral,object) -> f_nin\n", - "(Number,object) -> f_num\n", - "(list,object) -> f_bll\n", - "(object,object) -> NoneType" + "f_nin: (Integral) -> int\n", + "f_ni2: (int) -> object\n", + "f_num: (Number) -> object\n", + "f_bll: (bool|list) -> object" ] }, "execution_count": null, @@ -396,10 +589,54 @@ } ], "source": [ - "t = TypeDispatch([f_nin,f_ni2,f_num,f_bll,None])\n", + "t = FastFunction([f_nin,f_ni2,f_num,f_bll])\n", "t" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# _dispatch = FastDispatcher()\n", + "# @_dispatch\n", + "# def f(x:int, y:float): return 'f2' #int and float for 2nd arg\n", + "# @_dispatch\n", + "# def f(x:numbers.Integral)->str: return 'f_nin' #integral numeric\n", + "# @_dispatch\n", + "# def f(x:int): return 'f_ni2' #integer\n", + "# @_dispatch.multi((bool,),(list,))\n", + "# def f(x:bool|list): return 'f_bll' #bool or list\n", + "# @_dispatch\n", + "# def f(x:numbers.Number): return 'f_num' #Number (root of numerics) " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can optionally initialize `TypeDispatch` with a list of functions we want to search. Printing an instance of `TypeDispatch` will display convenient mapping of types -> functions:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# TODO: Support None?" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# TODO: Do we need to suport __getitem__?" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -420,8 +657,9 @@ "metadata": {}, "outputs": [], "source": [ - "assert issubclass(float, numbers.Number)\n", - "test_eq(t[float], f_num)" + "# assert issubclass(float, numbers.Number)\n", + "# test_eq(t[float], f_num)\n", + "# test_eq(f(0.0), 'f_num')" ] }, { @@ -437,10 +675,13 @@ "metadata": {}, "outputs": [], "source": [ - "test_eq(t[np.int32], f_nin)\n", - "test_eq(t[bool], f_bll)\n", - "test_eq(t[list], f_bll)\n", - "test_eq(t[np.int32], f_nin)" + "# test_eq(t[np.int32], f_nin)\n", + "# test_eq(t[bool], f_bll)\n", + "# test_eq(t[list], f_bll)\n", + "# test_eq(t[np.int32], f_nin)\n", + "# test_eq(f(np.int32(0)), 'f_nin')\n", + "# test_eq(f(True), 'f_bll')\n", + "# test_eq(f([]), 'f_bll')" ] }, { @@ -456,7 +697,7 @@ "metadata": {}, "outputs": [], "source": [ - "test_eq(t[str], None)" + "# test_eq(t[str], None)" ] }, { @@ -467,11 +708,10 @@ { "data": { "text/markdown": [ - "

TypeDispatch.add[source]

\n", + "

FastFunction.dispatch[source]

\n", "\n", - "> TypeDispatch.add(**`f`**)\n", - "\n", - "Add type `t` and function `f`" + "> FastFunction.dispatch(**`f`**=*`None`*)\n", + "\n" ], "text/plain": [ "" @@ -482,7 +722,7 @@ } ], "source": [ - "show_doc(TypeDispatch.add)" + "show_doc(FastFunction.dispatch)" ] }, { @@ -492,6 +732,15 @@ "This method allows you to add an additional function to an existing `TypeDispatch` instance :" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# NOTE: No support for collection" + ] + }, { "cell_type": "code", "execution_count": null, @@ -500,13 +749,11 @@ { "data": { "text/plain": [ - "(bool,object) -> f_bll\n", - "(int,object) -> f_ni2\n", - "(Integral,object) -> f_nin\n", - "(Number,object) -> f_num\n", - "(list,object) -> f_bll\n", - "(typing.Collection,object) -> f_col\n", - "(object,object) -> NoneType" + "f: (bool|list) -> object\n", + "f: (bool) -> object\n", + "f: (list) -> object\n", + "f: (int) -> object\n", + "f_tup: (tuple) -> object" ] }, "execution_count": null, @@ -515,10 +762,11 @@ } ], "source": [ - "def f_col(x:typing.Collection): return x\n", - "t.add(f_col)\n", - "test_eq(t[str], f_col)\n", - "t" + "def f_tup(x:tuple): return 'f_tup'\n", + "f.dispatch(f_tup)\n", + "#test_eq(t[str], f_tup)\n", + "test_eq(f(()), 'f_tup')\n", + "f" ] }, { @@ -532,10 +780,26 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "f: (bool|list) -> object\n", + "f: (bool) -> object\n", + "f: (list) -> object\n", + "f: (int) -> object\n", + "f_tup: (tuple) -> object" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "t.add(f_ni2) \n", - "test_eq(t[int], f_ni2)" + "f.dispatch(f_tup) \n", + "# test_eq(t[int], f_ni2)\n", + "f" ] }, { @@ -549,11 +813,27 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "f: (bool|list) -> object\n", + "f: (bool) -> object\n", + "f: (list) -> object\n", + "f_ni3: (int) -> object\n", + "f_tup: (tuple) -> object" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "def f_ni3(z:int): return z # collides with f_ni2 with same type annotations\n", - "t.add(f_ni3) \n", - "test_eq(t[int], f_ni3)" + "def f_ni3(z:int): return 'f_ni3' # collides with f_ni2 with same type annotations\n", + "f.dispatch(f_ni3) \n", + "#test_eq(t[int], f_ni3)\n", + "f" ] }, { @@ -576,30 +856,13 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(str,object) -> f_str\n", - "(bool,object) -> f_bll\n", - "(int,object) -> f_ni2\n", - "(Integral,object) -> f_nin\n", - "(Number,object) -> f_num\n", - "(list,object) -> f_bll\n", - "(object,object) -> NoneType" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "def f_str(x:str): return x+'1'\n", + "# def f_str(x:str): return x+'1'\n", "\n", - "t = TypeDispatch([f_nin,f_ni2,f_num,f_bll,None])\n", - "t2 = TypeDispatch(f_str, bases=t) # you can optionally supply a list of TypeDispatch objects for `bases`.\n", - "t2" + "# t = TypeDispatch([f_nin,f_ni2,f_num,f_bll,None])\n", + "# t2 = TypeDispatch(f_str, bases=t) # you can optionally supply a list of TypeDispatch objects for `bases`.\n", + "# t2" ] }, { @@ -608,15 +871,15 @@ "metadata": {}, "outputs": [], "source": [ - "test_eq(t2[int], f_ni2) # searches `t` b/c not found in `t2`\n", - "test_eq(t2[np.int32], f_nin) # searches `t` b/c not found in `t2`\n", - "test_eq(t2[float], f_num) # searches `t` b/c not found in `t2`\n", - "test_eq(t2[bool], f_bll) # searches `t` b/c not found in `t2`\n", - "test_eq(t2[str], f_str) # found in `t`!\n", - "test_eq(t2('a'), 'a1') # found in `t`!, and uses __call__\n", + "# test_eq(t2[int], f_ni2) # searches `t` b/c not found in `t2`\n", + "# test_eq(t2[np.int32], f_nin) # searches `t` b/c not found in `t2`\n", + "# test_eq(t2[float], f_num) # searches `t` b/c not found in `t2`\n", + "# test_eq(t2[bool], f_bll) # searches `t` b/c not found in `t2`\n", + "# test_eq(t2[str], f_str) # found in `t`!\n", + "# test_eq(t2('a'), 'a1') # found in `t`!, and uses __call__\n", "\n", - "o = np.int32(1)\n", - "test_eq(t2(o), 2) # found in `t2` and uses __call__" + "# o = np.int32(1)\n", + "# test_eq(t2(o), 2) # found in `t2` and uses __call__" ] }, { @@ -637,24 +900,12 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(int,float) -> f2\n", - "(Integral,object) -> f1" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "def f1(x:numbers.Integral, y): return x+1 #Integral is a numeric type\n", - "def f2(x:int, y:float): return x+y\n", - "t = TypeDispatch([f1,f2])\n", - "t" + "# def f1(x:numbers.Integral, y): return x+1 #Integral is a numeric type\n", + "# def f2(x:int, y:float): return x+y\n", + "# t = TypeDispatch([f1,f2])\n", + "# t" ] }, { @@ -670,8 +921,8 @@ "metadata": {}, "outputs": [], "source": [ - "test_eq(t[np.int32], f1)\n", - "test_eq(t[int,float], f2)" + "# test_eq(t[np.int32], f1)\n", + "# test_eq(t[int,float], f2)" ] }, { @@ -685,24 +936,13 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(str,int) -> f2" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "def f1(a:str, b:int, c:list): return a\n", - "def f2(a: str, b:int): return b\n", - "t = TypeDispatch([f1,f2])\n", - "test_eq(t[str, int], f2)\n", - "t" + "# def f1(a:str, b:int, c:list): return a\n", + "# def f2(a: str, b:int): return b\n", + "# t = TypeDispatch([f1,f2])\n", + "# test_eq(t[str, int], f2)\n", + "# t" ] }, { @@ -729,10 +969,11 @@ "metadata": {}, "outputs": [], "source": [ - "def f1(x:numbers.Integral, y): return x+1\n", - "def f2(x:int, y:float): return x+y\n", - "t = TypeDispatch([f1,f2])\n", - "\n", + "_dispatch = FastDispatcher()\n", + "@_dispatch\n", + "def f(x:numbers.Integral, y): return 'int'\n", + "@_dispatch\n", + "def f(x:int, y:float): return 'int,float'\n", "assert not issubclass(np.int32, int)" ] }, @@ -750,7 +991,8 @@ "outputs": [], "source": [ "assert issubclass(np.int32, numbers.Integral)\n", - "test_eq(t[np.int32,float], f1) " + "test_eq(f(0,0.0), 'int,float')\n", + "#test_eq(t[np.int32,float], f1) " ] }, { @@ -767,15 +1009,26 @@ "outputs": [], "source": [ "assert issubclass(int, numbers.Integral) # int is a subclass of numbers.Integral\n", - "test_eq(t[int], f1)\n", - "test_eq(t[int,int], f1)" + "# test_eq(t[int], f1)\n", + "# test_eq(t[int,int], f1)\n", + "test_eq(f(0,None), 'int')\n", + "test_eq(f(0,0), 'int')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "If no match is possible, `None` is returned:" + "If no match is possible, `None` is returned: **TODO**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "test_fail(f, args=(0.0,0.0), contains='Signature(builtins.float, builtins.float) could not be resolved')" ] }, { @@ -784,7 +1037,7 @@ "metadata": {}, "outputs": [], "source": [ - "test_eq(t[float,float], None)" + "# test_eq(t[float,float], None)" ] }, { @@ -795,11 +1048,10 @@ { "data": { "text/markdown": [ - "

TypeDispatch.__call__[source]

\n", + "

Function.__call__[source]

\n", "\n", - "> TypeDispatch.__call__(**\\*`args`**, **\\*\\*`kwargs`**)\n", - "\n", - "Call self as a function." + "> Function.__call__(**\\*`args`**, **\\*\\*`kw_args`**)\n", + "\n" ], "text/plain": [ "" @@ -810,7 +1062,7 @@ } ], "source": [ - "show_doc(TypeDispatch.__call__)" + "show_doc(FastFunction.__call__)" ] }, { @@ -826,16 +1078,30 @@ "metadata": {}, "outputs": [], "source": [ - "def f_arr(x:np.ndarray): return x.sum()\n", - "def f_int(x:np.int32): return x+1\n", - "t = TypeDispatch([f_arr, f_int])\n", + "_dispatch = FastDispatcher()\n", + "@_dispatch\n", + "def f(x:np.ndarray): return x.sum()\n", + "@_dispatch\n", + "def f(x:np.int32): return x+1\n", + "# t = TypeDispatch([f_arr, f_int])\n", "\n", "arr = np.array([5,4,3,2,1])\n", - "test_eq(t(arr), 15) # dispatches to f_arr\n", + "# test_eq(t(arr), 15) # dispatches to f_arr\n", + "test_eq(f(arr), 15) # dispatches to f_arr\n", "\n", "o = np.int32(1)\n", - "test_eq(t(o), 2) # dispatches to f_int\n", - "assert t.first() is not None " + "# test_eq(t(o), 2) # dispatches to f_int\n", + "test_eq(f(o), 2) # dispatches to f_int\n", + "#assert t.first() is not None" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# TODO: FF.first?" ] }, { @@ -851,12 +1117,15 @@ "metadata": {}, "outputs": [], "source": [ - "def f1(x:numbers.Integral, y): return x+1\n", - "def f2(x:int, y:float): return x+y\n", - "t = TypeDispatch([f1,f2])\n", + "_dispatch = FastDispatcher()\n", + "@_dispatch\n", + "def f(x:numbers.Integral, y): return x+1\n", + "@_dispatch\n", + "def f(x:int, y:float): return x+y\n", + "# t = TypeDispatch([f1,f2])\n", "\n", - "test_eq(t(3,2.0), 5)\n", - "test_eq(t(3,2), 4)" + "test_eq(f(3,2.0), 5)\n", + "test_eq(f(3,2), 4)" ] }, { @@ -872,33 +1141,25 @@ "metadata": {}, "outputs": [], "source": [ - "test_eq(t('a'), 'a')" + "# test_eq(f('a'), 'a')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/markdown": [ - "

TypeDispatch.returns[source]

\n", - "\n", - "> TypeDispatch.returns(**`x`**)\n", - "\n", - "Get the return type of annotation of `x`." - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ - "show_doc(TypeDispatch.returns)" + "# TODO: need returns?" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show_doc(TypeDispatch.returns)" ] }, { @@ -914,19 +1175,23 @@ "metadata": {}, "outputs": [], "source": [ - "def f1(x:int) -> np.ndarray: return np.array(x)\n", - "def f2(x:str) -> float: return List\n", - "def f3(x:float): return List # f3 has no return type annotation\n", + "_dispatch = FastDispatcher()\n", + "@_dispatch\n", + "def f(x:int) -> np.ndarray: return np.array(x)\n", + "@_dispatch\n", + "def f(x:str) -> float: return List\n", + "@_dispatch\n", + "def f(x:float): return List # f3 has no return type annotation\n", "\n", - "t = TypeDispatch([f1, f2, f3])\n", + "# t = TypeDispatch([f1, f2, f3])\n", "\n", - "test_eq(t.returns(1), np.ndarray) # dispatched to f1\n", - "test_eq(t.returns('Hello'), float) # dispatched to f2\n", - "test_eq(t.returns(1.0), None) # dispatched to f3\n", + "# test_eq(f.returns(1), np.ndarray) # dispatched to f1\n", + "# test_eq(f.returns('Hello'), float) # dispatched to f2\n", + "# test_eq(f.returns(1.0), None) # dispatched to f3\n", "\n", - "class _Test: pass\n", - "_test = _Test()\n", - "test_eq(t.returns(_test), None) # type `_Test` not found, so None returned" + "# class _Test: pass\n", + "# _test = _Test()\n", + "# test_eq(t.returns(_test), None) # type `_Test` not found, so None returned" ] }, { @@ -949,20 +1214,21 @@ "metadata": {}, "outputs": [], "source": [ - "def m_nin(self, x:str|numbers.Integral): return str(x)+'1'\n", - "def m_bll(self, x:bool): self.foo='a'\n", - "def m_num(self, x:numbers.Number): return x*2\n", + "# def m_nin(self, x:str|numbers.Integral): return str(x)+'1'\n", + "# def m_bll(self, x:bool): self.foo='a'\n", + "# def m_num(self, x:numbers.Number): return x*2\n", + "# t = FastFunction(m_nin).dispatch_multi((str,),(numbers.Integral,))(m_nin).dispatch(m_bll).dispatch(m_num)\n", "\n", - "t = TypeDispatch([m_nin,m_num,m_bll])\n", - "class A: f = t # set class attribute `f` equal to a TypeDispatch instance\n", + "# # t = TypeDispatch([m_nin,m_num,m_bll])\n", + "# class A: f = t # set class attribute `f` equal to a TypeDispatch instance\n", " \n", - "a = A()\n", - "test_eq(a.f(1), '11') #dispatch to m_nin\n", - "test_eq(a.f(1.), 2.) #dispatch to m_num\n", - "test_is(a.f.inst, a)\n", + "# a = A()\n", + "# test_eq(a.f(1), '11') #dispatch to m_nin\n", + "# test_eq(a.f(1.), 2.) #dispatch to m_num\n", + "# test_is(a.f.inst, a)\n", "\n", - "a.f(False) # this triggers t.m_bll to run, which sets self.foo to 'a'\n", - "test_eq(a.foo, 'a')" + "# a.f(False) # this triggers t.m_bll to run, which sets self.foo to 'a'\n", + "# test_eq(a.foo, 'a')" ] }, { @@ -978,7 +1244,7 @@ "metadata": {}, "outputs": [], "source": [ - "test_eq(a.f(()), ()) " + "# test_eq(a.f(()), ()) " ] }, { @@ -994,17 +1260,17 @@ "metadata": {}, "outputs": [], "source": [ - "def m_tup(self, x:tuple): return x+(1,)\n", - "t2 = TypeDispatch(m_tup, bases=t)\n", + "# def m_tup(self, x:tuple): return x+(1,)\n", + "# t2 = TypeDispatch(m_tup, bases=t)\n", "\n", - "class A2: f = t2\n", - "a2 = A2()\n", - "test_eq(a2.f(1), '11')\n", - "test_eq(a2.f(1.), 2.)\n", - "test_is(a2.f.inst, a2)\n", - "a2.f(False)\n", - "test_eq(a2.foo, 'a')\n", - "test_eq(a2.f(()), (1,))" + "# class A2: f = t2\n", + "# a2 = A2()\n", + "# test_eq(a2.f(1), '11')\n", + "# test_eq(a2.f(1.), 2.)\n", + "# test_is(a2.f.inst, a2)\n", + "# a2.f(False)\n", + "# test_eq(a2.foo, 'a')\n", + "# test_eq(a2.f(()), (1,))" ] }, { @@ -1027,19 +1293,19 @@ "metadata": {}, "outputs": [], "source": [ - "def m_nin(cls, x:str|numbers.Integral): return str(x)+'1'\n", - "def m_bll(cls, x:bool): cls.foo='a'\n", - "def m_num(cls, x:numbers.Number): return x*2\n", + "# def m_nin(cls, x:str|numbers.Integral): return str(x)+'1'\n", + "# def m_bll(cls, x:bool): cls.foo='a'\n", + "# def m_num(cls, x:numbers.Number): return x*2\n", "\n", - "t = TypeDispatch([m_nin,m_num,m_bll])\n", - "class A: f = t # set class attribute `f` equal to a TypeDispatch\n", + "# t = TypeDispatch([m_nin,m_num,m_bll])\n", + "# class A: f = t # set class attribute `f` equal to a TypeDispatch\n", "\n", - "test_eq(A.f(1), '11') #dispatch to m_nin\n", - "test_eq(A.f(1.), 2.) #dispatch to m_num\n", - "test_is(A.f.owner, A)\n", + "# test_eq(A.f(1), '11') #dispatch to m_nin\n", + "# test_eq(A.f(1.), 2.) #dispatch to m_num\n", + "# test_is(A.f.owner, A)\n", "\n", - "A.f(False) # this triggers t.m_bll to run, which sets A.foo to 'a'\n", - "test_eq(A.foo, 'a')" + "# A.f(False) # this triggers t.m_bll to run, which sets A.foo to 'a'\n", + "# test_eq(A.foo, 'a')" ] }, { @@ -1055,18 +1321,28 @@ "metadata": {}, "outputs": [], "source": [ - "#export\n", - "class DispatchReg:\n", - " \"A global registry for `TypeDispatch` objects keyed by function name\"\n", - " def __init__(self): self.d = defaultdict(TypeDispatch)\n", - " def __call__(self, f):\n", - " if isinstance(f, (classmethod, staticmethod)): nm = f'{f.__func__.__qualname__}'\n", - " else: nm = f'{f.__qualname__}'\n", - " if isinstance(f, classmethod): f=f.__func__\n", - " self.d[nm].add(f)\n", - " return self.d[nm]\n", + "# #export\n", + "# class DispatchReg:\n", + "# \"A global registry for `TypeDispatch` objects keyed by function name\"\n", + "# def __init__(self): self.d = defaultdict(TypeDispatch)\n", + "# def __call__(self, f):\n", + "# if isinstance(f, (classmethod, staticmethod)): nm = f'{f.__func__.__qualname__}'\n", + "# else: nm = f'{f.__qualname__}'\n", + "# if isinstance(f, classmethod): f=f.__func__\n", + "# self.d[nm].add(f)\n", + "# return self.d[nm]\n", "\n", - "typedispatch = DispatchReg()" + "# typedispatch = DispatchReg()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#export\n", + "typedispatch = FastDispatcher()" ] }, { @@ -1111,19 +1387,19 @@ "metadata": {}, "outputs": [], "source": [ - "class A:\n", - " @typedispatch\n", - " def f_td_test(self, x:numbers.Integral, y): return x+1\n", - " @typedispatch\n", - " @classmethod\n", - " def f_td_test(cls, x:int, y:float): return x+y\n", - " @typedispatch\n", - " @staticmethod\n", - " def f_td_test(x:int, y:int): return x*y\n", + "# class A:\n", + "# @typedispatch\n", + "# def f_td_test(self, x:numbers.Integral, y): return x+1\n", + "# @typedispatch\n", + "# @classmethod\n", + "# def f_td_test(cls, x:int, y:float): return x+y\n", + "# @typedispatch\n", + "# @staticmethod\n", + "# def f_td_test(x:int, y:int): return x*y\n", " \n", - "test_eq(A.f_td_test(3,2), 6)\n", - "test_eq(A.f_td_test(3,2.0), 5)\n", - "test_eq(A().f_td_test(3,'2.0'), 4)" + "# test_eq(A.f_td_test(3,2), 6)\n", + "# test_eq(A.f_td_test(3,2.0), 5)\n", + "# test_eq(A().f_td_test(3,'2.0'), 4)" ] }, { diff --git a/nbs/05_transform.ipynb b/nbs/05_transform.ipynb index e0bfbbaa0..3fb55e98a 100644 --- a/nbs/05_transform.ipynb +++ b/nbs/05_transform.ipynb @@ -20,7 +20,9 @@ "from fastcore.foundation import *\n", "from fastcore.utils import *\n", "from fastcore.dispatch import *\n", - "import inspect" + "import inspect\n", + "from copy import copy\n", + "from plum import add_conversion_method" ] }, { @@ -69,11 +71,10 @@ "#export\n", "_tfm_methods = 'encodes','decodes','setups'\n", "\n", + "def _is_tfm_method(n, f): return n in _tfm_methods and callable(f)\n", + "\n", "class _TfmDict(dict):\n", - " def __setitem__(self,k,v):\n", - " if k not in _tfm_methods or not callable(v): return super().__setitem__(k,v)\n", - " if k not in self: super().__setitem__(k,TypeDispatch())\n", - " self[k].add(v)" + " def __setitem__(self, k, v): super().__setitem__(k, typedispatch(v) if _is_tfm_method(k, v) else v)" ] }, { @@ -85,21 +86,20 @@ "#export\n", "class _TfmMeta(type):\n", " def __new__(cls, name, bases, dict):\n", + " # _TfmMeta.__call__ shadows the signature of inheriting classes, set it back\n", " res = super().__new__(cls, name, bases, dict)\n", - " for nm in _tfm_methods:\n", - " base_td = [getattr(b,nm,None) for b in bases]\n", - " if nm in res.__dict__: getattr(res,nm).bases = base_td\n", - " else: setattr(res, nm, TypeDispatch(bases=base_td))\n", " res.__signature__ = inspect.signature(res.__init__)\n", " return res\n", "\n", " def __call__(cls, *args, **kwargs):\n", - " f = args[0] if args else None\n", - " n = getattr(f,'__name__',None)\n", - " if callable(f) and n in _tfm_methods:\n", - " getattr(cls,n).add(f)\n", - " return f\n", - " return super().__call__(*args, **kwargs)\n", + " f = first(args)\n", + " n = getattr(f, '__name__', None)\n", + " if _is_tfm_method(n, f): return typedispatch.to(cls)(f)\n", + " obj = super().__call__(*args, **kwargs)\n", + " # _TfmMeta.__new__ replaces cls.__signature__ which breaks the signature of a callable\n", + " # instances of cls, fix it\n", + " if hasattr(obj, '__call__'): obj.__signature__ = inspect.signature(obj.__call__)\n", + " return obj\n", "\n", " @classmethod\n", " def __prepare__(cls, name, bases): return _TfmDict()" @@ -144,19 +144,20 @@ " self.init_enc = enc or dec\n", " if not self.init_enc: return\n", "\n", - " self.encodes,self.decodes,self.setups = TypeDispatch(),TypeDispatch(),TypeDispatch()\n", + " def identity(x): return x\n", + " for n in _tfm_methods: setattr(self,n,FastFunction(identity).dispatch(identity))\n", " if enc:\n", - " self.encodes.add(enc)\n", + " self.encodes.dispatch(enc)\n", " self.order = getattr(enc,'order',self.order)\n", " if len(type_hints(enc)) > 0: self.input_types = union2tuple(first(type_hints(enc).values()))\n", " self._name = _get_name(enc)\n", - " if dec: self.decodes.add(dec)\n", + " if dec: self.decodes.dispatch(dec)\n", "\n", " @property\n", " def name(self): return getattr(self, '_name', _get_name(self))\n", " def __call__(self, x, **kwargs): return self._call('encodes', x, **kwargs)\n", " def decode (self, x, **kwargs): return self._call('decodes', x, **kwargs)\n", - " def __repr__(self): return f'{self.name}:\\nencodes: {self.encodes}decodes: {self.decodes}'\n", + " def __repr__(self): return f'{self.name}:\\nencodes: {self.encodes}\\ndecodes: {self.decodes}'\n", "\n", " def setup(self, items=None, train_setup=False):\n", " train_setup = train_setup if self.train_setup is None else self.train_setup\n", @@ -169,14 +170,59 @@ " def _do_call(self, f, x, **kwargs):\n", " if not _is_tuple(x):\n", " if f is None: return x\n", - " ret = f.returns(x) if hasattr(f,'returns') else None\n", - " return retain_type(f(x, **kwargs), x, ret)\n", + " ts = [type(self),type(x)] if hasattr(f,'instance') else [type(x)]\n", + " _, ret = f.resolve_method(*ts)\n", + " ret = ret._type\n", + " # plum reads empty return annotation as object, retain_type expects it as None\n", + " if ret is object: ret = None\n", + " return retain_type(f(x,**kwargs), x, ret)\n", " res = tuple(self._do_call(f, x_, **kwargs) for x_ in x)\n", " return retain_type(res, x)\n", + " def encodes(self, x): return x\n", + " def decodes(self, x): return x\n", + " def setups(self, dl): return dl\n", "\n", "add_docs(Transform, decode=\"Delegate to decodes to undo transform\", setup=\"Delegate to setups to set up transform\")" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# TODO: Move somewhere better\n", + "class Categorize(Transform):\n", + " def encodes(self, x): return 'obj'\n", + "\n", + "@Categorize\n", + "def encodes(self, x:int): return 'int'\n", + "\n", + "c = Categorize()\n", + "test_eq(c.encodes(0), 'int')\n", + "test_eq(c.encodes(0.0), 'obj')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# TODO: Should this belong in dispatch?" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#export\n", + "#Implement the Transform convention that a None return annotation disables conversion\n", + "add_conversion_method(object, NoneType, lambda x: x)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -368,6 +414,25 @@ "test_eq_type(f3(2), 2)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Transforms can be created from class methods too:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class A:\n", + " @classmethod\n", + " def create(cls, x:int): return x+1\n", + "test_eq(Transform(A.create)(1), 2)" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -845,7 +910,7 @@ "def encodes(self, x:str): return x+'hello'\n", "\n", "@B\n", - "def encodes(self, x)->None: return str(x)+'!'" + "def encodes(self, x): return str(x)+'!'" ] }, { @@ -1015,8 +1080,8 @@ "data": { "text/plain": [ "A:\n", - "encodes: (object,object) -> noop\n", - "decodes: (object,object) -> noop" + "encodes: noop: (object,VarArgs[object]) -> object\n", + "decodes: noop: (object,VarArgs[object]) -> object" ] }, "execution_count": null, @@ -1046,8 +1111,8 @@ "data": { "text/plain": [ "A -- {'a': 1, 'b': 2}:\n", - "encodes: (object,object) -> noop\n", - "decodes: " + "encodes: noop: (object,VarArgs[object]) -> object\n", + "decodes: decodes: (object,object) -> object" ] }, "execution_count": null, @@ -1933,13 +1998,6 @@ "from nbdev.export import notebook2script\n", "notebook2script()" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/setup.py b/setup.py index 1fa9526b5..aea178003 100644 --- a/setup.py +++ b/setup.py @@ -26,7 +26,7 @@ min_python = cfg['min_python'] lic = licenses[cfg['license']] -requirements = ['pip', 'packaging'] +requirements = ['pip', 'packaging', 'plum-dispatch>=1.6'] if cfg.get('requirements'): requirements += cfg.get('requirements','').split() if cfg.get('pip_requirements'): requirements += cfg.get('pip_requirements','').split() dev_requirements = (cfg.get('dev_requirements') or '').split()