From 6df02814228655d5c32fbba9a5fccc71aa43a314 Mon Sep 17 00:00:00 2001 From: seem Date: Tue, 31 May 2022 19:23:59 +0200 Subject: [PATCH] enable more robust multiple dispatch with `plum` --- fastcore/_nbdev.py | 6 +- fastcore/basics.py | 2 + fastcore/dispatch.py | 202 +++----- fastcore/imports.py | 10 + fastcore/transform.py | 49 +- nbs/01_basics.ipynb | 16 +- nbs/04_dispatch.ipynb | 1023 +++++++--------------------------------- nbs/05_transform.ipynb | 111 +++-- settings.ini | 2 +- setup.py | 2 +- 10 files changed, 385 insertions(+), 1038 deletions(-) diff --git a/fastcore/_nbdev.py b/fastcore/_nbdev.py index 13acbd97b..20e0a973d 100644 --- a/fastcore/_nbdev.py +++ b/fastcore/_nbdev.py @@ -227,10 +227,8 @@ "do_request": "03b_net.ipynb", "start_server": "03b_net.ipynb", "start_client": "03b_net.ipynb", - "lenient_issubclass": "04_dispatch.ipynb", - "sorted_topologically": "04_dispatch.ipynb", - "TypeDispatch": "04_dispatch.ipynb", - "DispatchReg": "04_dispatch.ipynb", + "FastFunction": "04_dispatch.ipynb", + "FastDispatcher": "04_dispatch.ipynb", "typedispatch": "04_dispatch.ipynb", "retain_meta": "04_dispatch.ipynb", "default_set_meta": "04_dispatch.ipynb", diff --git a/fastcore/basics.py b/fastcore/basics.py index c23e9d179..a5b46edf5 100644 --- a/fastcore/basics.py +++ b/fastcore/basics.py @@ -888,6 +888,8 @@ def copy_func(f): fn = FunctionType(f.__code__, f.__globals__, f.__name__, f.__defaults__, f.__closure__) fn.__kwdefaults__ = f.__kwdefaults__ fn.__dict__.update(f.__dict__) + fn.__annotations__.update(f.__annotations__) + fn.__qualname__ = f.__qualname__ return fn # Cell diff --git a/fastcore/dispatch.py b/fastcore/dispatch.py index a498f6676..de6f8b0dc 100644 --- a/fastcore/dispatch.py +++ b/fastcore/dispatch.py @@ -4,154 +4,96 @@ from __future__ import annotations -__all__ = ['lenient_issubclass', 'sorted_topologically', 'TypeDispatch', 'DispatchReg', 'typedispatch', 'cast', - 'retain_meta', 'default_set_meta', 'retain_type', 'retain_types', 
'explode_types'] +__all__ = ['FastFunction', 'FastDispatcher', 'typedispatch', 'cast', 'retain_meta', 'default_set_meta', 'retain_type', + 'retain_types', 'explode_types'] # Cell #nbdev_comment from __future__ import annotations from .imports import * from .foundation import * from .utils import * +from .meta import delegates from collections import defaultdict +from plum import Function, Dispatcher # Cell -def lenient_issubclass(cls, types): - "If possible return whether `cls` is a subclass of `types`, otherwise return False." - if cls is object and types is not object: return False # treat `object` as highest level - try: return isinstance(cls, types) or issubclass(cls, types) - except: return False +def _eval_annotations(f): + "Evaluate future annotations before passing to plum to support backported union operator `|`" + f = copy_func(f) + for k, v in type_hints(f).items(): f.__annotations__[k] = Union[v] if isinstance(v, tuple) else v + return f # Cell -def sorted_topologically(iterable, *, cmp=operator.lt, reverse=False): - "Return a new list containing all items from the iterable sorted topologically" - l,res = L(list(iterable)),[] - for _ in range(len(l)): - t = l.reduce(lambda x,y: y if cmp(y,x) else x) - res.append(t), l.remove(t) - return res[::-1] if reverse else res +def _pt_repr(o): + "Concise repr of plum types" + n = type(o).__name__ + if n == 'Tuple': return f"{n.lower()}[{','.join(_pt_repr(t) for t in o._el_types)}]" + if n == 'List': return f'{n.lower()}[{_pt_repr(o._el_type)}]' + if n == 'Dict': return f'{n.lower()}[{_pt_repr(o._key_type)},{_pt_repr(o._value_type)}]' + if n in ('Sequence','Iterable'): return f'{n}[{_pt_repr(o._el_type)}]' + if n == 'VarArgs': return f'{n}[{_pt_repr(o.type)}]' + if n == 'Union': return '|'.join(sorted(t.__name__ for t in (o.get_types()))) + assert len(o.get_types()) == 1 + return o.get_types()[0].__name__ # Cell -def _chk_defaults(f, ann): - pass -# Implementation removed until we can figure out how to do this 
without `inspect` module -# try: # Some callables don't have signatures, so ignore those errors -# params = list(inspect.signature(f).parameters.values())[:min(len(ann),2)] -# if any(p.default!=inspect.Parameter.empty for p in params): -# warn(f"{f.__name__} has default params. These will be ignored.") -# except ValueError: pass - -# Cell -def _p2_anno(f): - "Get the 1st 2 annotations of `f`, defaulting to `object`" - hints = type_hints(f) - ann = [o for n,o in hints.items() if n!='return'] - if callable(f): _chk_defaults(f, ann) - while len(ann)<2: ann.append(object) - return ann[:2] +class FastFunction(Function): + def __repr__(self): + return '\n'.join(f"{f.__name__}: ({','.join(_pt_repr(t) for t in s.types)}) -> {_pt_repr(r)}" + for s, (f, r) in self.methods.items()) -# Cell -class _TypeDict: - def __init__(self): self.d,self.cache = {},{} - - def _reset(self): - self.d = {k:self.d[k] for k in sorted_topologically(self.d, cmp=lenient_issubclass)} - self.cache = {} - - def add(self, t, f): - "Add type `t` and function `f`" - if not isinstance(t, tuple): t = tuple(L(union2tuple(t))) - for t_ in t: self.d[t_] = f - self._reset() - - def all_matches(self, k): - "Find first matching type that is a super-class of `k`" - if k not in self.cache: - types = [f for f in self.d if lenient_issubclass(k,f)] - self.cache[k] = [self.d[o] for o in types] - return self.cache[k] - - def __getitem__(self, k): - "Find first matching type that is a super-class of `k`" - res = self.all_matches(k) - return res[0] if len(res) else None - - def __repr__(self): return self.d.__repr__() - def first(self): return first(self.d.values()) + @delegates(Function.dispatch) + def dispatch(self, f=None, **kwargs): return super().dispatch(_eval_annotations(f), **kwargs) -# Cell -class TypeDispatch: - "Dictionary-like object; `__getitem__` matches keys of types using `issubclass`" - def __init__(self, funcs=(), bases=()): - self.funcs,self.bases = _TypeDict(),L(bases).filter(is_not(None)) - for o in 
L(funcs): self.add(o) - self.inst = None - self.owner = None - - def add(self, f): - "Add type `t` and function `f`" - if isinstance(f, staticmethod): a0,a1 = _p2_anno(f.__func__) - else: a0,a1 = _p2_anno(f) - t = self.funcs.d.get(a0) - if t is None: - t = _TypeDict() - self.funcs.add(a0, t) - t.add(a1, f) - - def first(self): - "Get first function in ordered dict of type:func." - return self.funcs.first().first() - - def returns(self, x): - "Get the return type of annotation of `x`." - return anno_ret(self[type(x)]) - - def _attname(self,k): return getattr(k,'__name__',str(k)) - def __repr__(self): - r = [f'({self._attname(k)},{self._attname(l)}) -> {getattr(v, "__name__", type(v).__name__)}' - for k in self.funcs.d for l,v in self.funcs[k].d.items()] - r = r + [o.__repr__() for o in self.bases] - return '\n'.join(r) - - def __call__(self, *args, **kwargs): - ts = L(args).map(type)[:2] - f = self[tuple(ts)] - if not f: return args[0] - if isinstance(f, staticmethod): f = f.__func__ - elif self.inst is not None: f = MethodType(f, self.inst) - elif self.owner is not None: f = MethodType(f, self.owner) - return f(*args, **kwargs) - - def __get__(self, inst, owner): - self.inst = inst - self.owner = owner - return self - - def __getitem__(self, k): - "Find first matching type that is a super-class of `k`" - k = L(k) - while len(k)<2: k.append(object) - r = self.funcs.all_matches(k[0]) - for t in r: - o = t[k[1]] - if o is not None: return o - for base in self.bases: - res = base[k] - if res is not None: return res - return None + def __getitem__(self, ts): + "Return the most-specific matching method with fewest parameters" + ts = L(ts) + nargs = min(len(o) for o in self.methods.keys()) + while len(ts) < nargs: ts.append(object) + return self.invoke(*ts) # Cell -class DispatchReg: - "A global registry for `TypeDispatch` objects keyed by function name" - def __init__(self): self.d = defaultdict(TypeDispatch) - def __call__(self, f): - if isinstance(f, (classmethod, 
staticmethod)): nm = f'{f.__func__.__qualname__}' - else: nm = f'{f.__qualname__}' - if isinstance(f, classmethod): f=f.__func__ - self.d[nm].add(f) - return self.d[nm] - -typedispatch = DispatchReg() +class FastDispatcher(Dispatcher): + def _get_function(self, method, owner): + "Adapted from `Dispatcher._get_function` to use `FastFunction`" + name = method.__name__ + if owner: + if owner not in self._classes: self._classes[owner] = {} + namespace = self._classes[owner] + else: namespace = self._functions + if name not in namespace: namespace[name] = FastFunction(method, owner=owner) + return namespace[name] + + @delegates(Dispatcher.__call__, but='method') + def __call__(self, f, **kwargs): return super().__call__(_eval_annotations(f), **kwargs) + + def _to(self, cls, nm, f, **kwargs): + nf = copy_func(f) + nf.__qualname__ = f'{cls.__name__}.{nm}' # plum uses __qualname__ to infer f's owner + pf = self(nf, **kwargs) + # plum uses __set_name__ to resolve a plum.Function's owner + # since we assign after class creation, __set_name__ must be called directly + # source: https://docs.python.org/3/reference/datamodel.html#object.__set_name__ + pf.__set_name__(cls, nm) + pf = pf.resolve() + setattr(cls, nm, pf) + return pf + + def to(self, cls): + "Decorator: dispatch `f` to `cls.f`" + def _inner(f, **kwargs): + nm = f.__name__ + # check __dict__ to avoid inherited methods but use getattr so pf.__get__ is called, which plum relies on + if nm in cls.__dict__: + pf = getattr(cls, nm) + if not hasattr(pf, 'dispatch'): pf = self._to(cls, nm, pf, **kwargs) + pf.dispatch(f) + else: pf = self._to(cls, nm, f, **kwargs) + return pf + return _inner + +typedispatch = FastDispatcher() # Cell #nbdev_comment _all_=['cast'] diff --git a/fastcore/imports.py b/fastcore/imports.py index 086c899fd..12665e594 100644 --- a/fastcore/imports.py +++ b/fastcore/imports.py @@ -1,5 +1,6 @@ import sys,os,re,typing,itertools,operator,functools,math,warnings,functools,io,enum +from copy import copy 
from operator import itemgetter,attrgetter from warnings import warn from typing import Iterable,Generator,Sequence,Iterator,List,Set,Dict,Union,Optional,Tuple @@ -14,6 +15,15 @@ MethodDescriptorType = type(str.join) from types import BuiltinFunctionType,BuiltinMethodType,MethodType,FunctionType,SimpleNamespace +#Patch autoreload (if it's loaded) to work with plum +try: from IPython import get_ipython +except ImportError: pass +else: + ip = get_ipython() + if ip is not None and 'IPython.extensions.storemagic' in ip.extension_manager.loaded: + from plum.autoreload import activate + activate() + NoneType = type(None) string_classes = (str,bytes) diff --git a/fastcore/transform.py b/fastcore/transform.py index 7f3e2efc8..d5c861605 100644 --- a/fastcore/transform.py +++ b/fastcore/transform.py @@ -9,34 +9,33 @@ from .utils import * from .dispatch import * import inspect +from plum import add_conversion_method # Cell _tfm_methods = 'encodes','decodes','setups' +def _is_tfm_method(n, f): return n in _tfm_methods and callable(f) + class _TfmDict(dict): - def __setitem__(self,k,v): - if k not in _tfm_methods or not callable(v): return super().__setitem__(k,v) - if k not in self: super().__setitem__(k,TypeDispatch()) - self[k].add(v) + def __setitem__(self, k, v): super().__setitem__(k, typedispatch(v) if _is_tfm_method(k, v) else v) # Cell class _TfmMeta(type): def __new__(cls, name, bases, dict): + # _TfmMeta.__call__ shadows the signature of inheriting classes, set it back res = super().__new__(cls, name, bases, dict) - for nm in _tfm_methods: - base_td = [getattr(b,nm,None) for b in bases] - if nm in res.__dict__: getattr(res,nm).bases = base_td - else: setattr(res, nm, TypeDispatch(bases=base_td)) res.__signature__ = inspect.signature(res.__init__) return res def __call__(cls, *args, **kwargs): - f = args[0] if args else None - n = getattr(f,'__name__',None) - if callable(f) and n in _tfm_methods: - getattr(cls,n).add(f) - return f - return super().__call__(*args, 
**kwargs) + f = first(args) + n = getattr(f, '__name__', None) + if _is_tfm_method(n, f): return typedispatch.to(cls)(f) + obj = super().__call__(*args, **kwargs) + # _TfmMeta.__new__ replaces cls.__signature__ which breaks the signature of callable + # instances of cls; fix it + if hasattr(obj, '__call__'): obj.__signature__ = inspect.signature(obj.__call__) + return obj @classmethod def __prepare__(cls, name, bases): return _TfmDict() @@ -60,13 +59,14 @@ def __init__(self, enc=None, dec=None, split_idx=None, order=None): self.init_enc = enc or dec if not self.init_enc: return - self.encodes,self.decodes,self.setups = TypeDispatch(),TypeDispatch(),TypeDispatch() + def identity(x): return x + for n in _tfm_methods: setattr(self,n,FastFunction(identity).dispatch(identity)) if enc: - self.encodes.add(enc) + self.encodes.dispatch(enc) self.order = getattr(enc,'order',self.order) if len(type_hints(enc)) > 0: self.input_types = union2tuple(first(type_hints(enc).values())) self._name = _get_name(enc) - if dec: self.decodes.add(dec) + if dec: self.decodes.dispatch(dec) @property def name(self): return getattr(self, '_name', _get_name(self)) @@ -85,13 +85,24 @@ def _call(self, fn, x, split_idx=None, **kwargs): def _do_call(self, f, x, **kwargs): if not _is_tuple(x): if f is None: return x - ret = f.returns(x) if hasattr(f,'returns') else None - return retain_type(f(x, **kwargs), x, ret) + ts = [type(self),type(x)] if hasattr(f,'instance') else [type(x)] + _, ret = f.resolve_method(*ts) + ret = ret._type + # plum reads empty return annotation as object, retain_type expects it as None + if ret is object: ret = None + return retain_type(f(x,**kwargs), x, ret) res = tuple(self._do_call(f, x_, **kwargs) for x_ in x) return retain_type(res, x) + def encodes(self, x): return x + def decodes(self, x): return x + def setups(self, dl): return dl add_docs(Transform, decode="Delegate to decodes to undo transform", setup="Delegate to setups to set up transform") +# Cell +#Implement 
the Transform convention that a None return annotation disables conversion +add_conversion_method(object, NoneType, lambda x: x) + # Cell class InplaceTransform(Transform): "A `Transform` that modifies in-place and just returns whatever it's passed" diff --git a/nbs/01_basics.ipynb b/nbs/01_basics.ipynb index cb3c66a0f..756ca23de 100644 --- a/nbs/01_basics.ipynb +++ b/nbs/01_basics.ipynb @@ -804,7 +804,7 @@ { "data": { "text/markdown": [ - "

noop[source]

\n", + "

noop[source]

\n", "\n", "> noop(**`x`**=*`None`*, **\\*`args`**, **\\*\\*`kwargs`**)\n", "\n", @@ -840,7 +840,7 @@ { "data": { "text/markdown": [ - "

noops[source]

\n", + "

noops[source]

\n", "\n", "> noops(**`self`**, **`x`**=*`None`*, **\\*`args`**, **\\*\\*`kwargs`**)\n", "\n", @@ -4676,6 +4676,8 @@ " fn = FunctionType(f.__code__, f.__globals__, f.__name__, f.__defaults__, f.__closure__)\n", " fn.__kwdefaults__ = f.__kwdefaults__\n", " fn.__dict__.update(f.__dict__)\n", + " fn.__annotations__.update(f.__annotations__)\n", + " fn.__qualname__ = f.__qualname__\n", " return fn" ] }, @@ -5635,7 +5637,7 @@ { "data": { "text/markdown": [ - "

ipython_shell[source]

\n", + "

ipython_shell[source]

\n", "\n", "> ipython_shell()\n", "\n", @@ -5661,7 +5663,7 @@ { "data": { "text/markdown": [ - "

in_ipython[source]

\n", + "

in_ipython[source]

\n", "\n", "> in_ipython()\n", "\n", @@ -5687,7 +5689,7 @@ { "data": { "text/markdown": [ - "

in_colab[source]

\n", + "

in_colab[source]

\n", "\n", "> in_colab()\n", "\n", @@ -5713,7 +5715,7 @@ { "data": { "text/markdown": [ - "

in_jupyter[source]

\n", + "

in_jupyter[source]

\n", "\n", "> in_jupyter()\n", "\n", @@ -5739,7 +5741,7 @@ { "data": { "text/markdown": [ - "

in_notebook[source]

\n", + "

in_notebook[source]

\n", "\n", "> in_notebook()\n", "\n", diff --git a/nbs/04_dispatch.ipynb b/nbs/04_dispatch.ipynb index 08b6bcbea..6bf65d2d4 100644 --- a/nbs/04_dispatch.ipynb +++ b/nbs/04_dispatch.ipynb @@ -20,8 +20,10 @@ "from fastcore.imports import *\n", "from fastcore.foundation import *\n", "from fastcore.utils import *\n", + "from fastcore.meta import delegates\n", "\n", - "from collections import defaultdict" + "from collections import defaultdict\n", + "from plum import Function, Dispatcher" ] }, { @@ -41,42 +43,18 @@ "source": [ "# Type dispatch\n", "\n", - "> Basic single and dual parameter dispatch" + "> Multiple dispatch, extending [plum](https://github.com/wesselb/plum)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Helpers" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#export\n", - "def lenient_issubclass(cls, types):\n", - " \"If possible return whether `cls` is a subclass of `types`, otherwise return False.\"\n", - " if cls is object and types is not object: return False # treat `object` as highest level\n", - " try: return isinstance(cls, types) or issubclass(cls, types)\n", - " except: return False" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "assert not lenient_issubclass(typing.Collection, list)\n", - "assert lenient_issubclass(list, typing.Collection)\n", - "assert lenient_issubclass(typing.Collection, object)\n", - "assert lenient_issubclass(typing.List, typing.Collection)\n", - "assert not lenient_issubclass(typing.Collection, typing.List)\n", - "assert not lenient_issubclass(object, typing.Callable)" + "Type dispatch, or [multiple dispatch](https://en.wikipedia.org/wiki/Multiple_dispatch#Julia), allows you to change the way a function behaves based on the input types it receives. 
This is a prominent feature in some programming languages like [Julia](https://docs.julialang.org/en/v1/manual/methods/).\n", + "\n", + "Type dispatch allows you to have a common API for functions that do similar tasks. This is especially useful in data science, where the same operation (e.g. normalize, categorize) requires an implementation that depends on its input type (e.g. numpy array, pandas dataframe, pytorch tensor).\n", + "\n", + "Fastcore uses and extends the wonderful [plum](https://github.com/wesselb/plum) library's implementation of multiple dispatch for Python. Be sure to view their [informative documentation](https://github.com/wesselb/plum#basic-usage) as well." ] }, { @@ -86,34 +64,11 @@ "outputs": [], "source": [ "#export\n", - "def sorted_topologically(iterable, *, cmp=operator.lt, reverse=False):\n", - " \"Return a new list containing all items from the iterable sorted topologically\"\n", - " l,res = L(list(iterable)),[]\n", - " for _ in range(len(l)):\n", - " t = l.reduce(lambda x,y: y if cmp(y,x) else x)\n", - " res.append(t), l.remove(t)\n", - " return res[::-1] if reverse else res" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "td = [3, 1, 2, 5]\n", - "test_eq(sorted_topologically(td), [1, 2, 3, 5])\n", - "test_eq(sorted_topologically(td, reverse=True), [5, 3, 2, 1])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "td = {int:1, numbers.Number:2, numbers.Integral:3}\n", - "test_eq(sorted_topologically(td, cmp=lenient_issubclass), [int, numbers.Integral, numbers.Number])" + "def _eval_annotations(f):\n", + " \"Evaluate future annotations before passing to plum to support backported union operator `|`\"\n", + " f = copy_func(f)\n", + " for k, v in type_hints(f).items(): f.__annotations__[k] = Union[v] if isinstance(v, tuple) else v\n", + " return f" ] }, { @@ -122,26 +77,13 @@ "metadata": {}, "outputs": [], 
"source": [ - "td = [numbers.Integral, tuple, list, int, dict]\n", - "td = sorted_topologically(td, cmp=lenient_issubclass)\n", - "assert td.index(int) < td.index(numbers.Integral)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#export\n", - "def _chk_defaults(f, ann):\n", - " pass\n", - "# Implementation removed until we can figure out how to do this without `inspect` module\n", - "# try: # Some callables don't have signatures, so ignore those errors\n", - "# params = list(inspect.signature(f).parameters.values())[:min(len(ann),2)]\n", - "# if any(p.default!=inspect.Parameter.empty for p in params):\n", - "# warn(f\"{f.__name__} has default params. These will be ignored.\")\n", - "# except ValueError: pass" + "#hide\n", + "def f(x:int|str) -> float: pass\n", + "test_eq(_eval_annotations(f).__annotations__, {'x': typing.Union[int, str], 'return': float})\n", + "def f(x:(int,str)) -> float: pass\n", + "test_eq(_eval_annotations(f).__annotations__, {'x': typing.Union[int, str], 'return': float})\n", + "def f(x): pass\n", + "test_eq(_eval_annotations(f).__annotations__, {})" ] }, { @@ -151,13 +93,17 @@ "outputs": [], "source": [ "#export\n", - "def _p2_anno(f):\n", - " \"Get the 1st 2 annotations of `f`, defaulting to `object`\"\n", - " hints = type_hints(f)\n", - " ann = [o for n,o in hints.items() if n!='return']\n", - " if callable(f): _chk_defaults(f, ann)\n", - " while len(ann)<2: ann.append(object)\n", - " return ann[:2]" + "def _pt_repr(o):\n", + " \"Concise repr of plum types\"\n", + " n = type(o).__name__\n", + " if n == 'Tuple': return f\"{n.lower()}[{','.join(_pt_repr(t) for t in o._el_types)}]\"\n", + " if n == 'List': return f'{n.lower()}[{_pt_repr(o._el_type)}]'\n", + " if n == 'Dict': return f'{n.lower()}[{_pt_repr(o._key_type)},{_pt_repr(o._value_type)}]'\n", + " if n in ('Sequence','Iterable'): return f'{n}[{_pt_repr(o._el_type)}]'\n", + " if n == 'VarArgs': return 
f'{n}[{_pt_repr(o.type)}]'\n", + " if n == 'Union': return '|'.join(sorted(t.__name__ for t in (o.get_types())))\n", + " assert len(o.get_types()) == 1\n", + " return o.get_types()[0].__name__" ] }, { @@ -167,116 +113,26 @@ "outputs": [], "source": [ "#hide\n", - "def _f(a): pass\n", - "test_eq(_p2_anno(_f), (object,object))\n", - "def _f(a, b): pass\n", - "test_eq(_p2_anno(_f), (object,object))\n", - "def _f(a:None, b)->str: pass\n", - "test_eq(_p2_anno(_f), (NoneType,object))\n", - "def _f(a:str, b)->float: pass\n", - "test_eq(_p2_anno(_f), (str,object))\n", - "def _f(a:None, b:str)->float: pass\n", - "test_eq(_p2_anno(_f), (NoneType,str))\n", - "def _f(a:int, b:int)->float: pass\n", - "test_eq(_p2_anno(_f), (int,int))\n", - "def _f(self, a:int, b:int): pass\n", - "test_eq(_p2_anno(_f), (int,int))\n", - "def _f(a:int, b:str)->float: pass\n", - "test_eq(_p2_anno(_f), (int,str))\n", - "test_eq(_p2_anno(attrgetter('foo')), (object,object))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "([object, object], [int, object])" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "#hide\n", - "# Disabled until _chk_defaults fixed\n", - "# def _f(x:int,y:int=10): pass\n", - "# test_warns(lambda: _p2_anno(_f))\n", - "def _f(x:int,y=10): pass\n", - "_p2_anno(None),_p2_anno(_f)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## TypeDispatch" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Type dispatch, or [Multiple dispatch](https://en.wikipedia.org/wiki/Multiple_dispatch#Julia), allows you to change the way a function behaves based upon the input types it recevies. This is a prominent feature in some programming languages like Julia. 
For example, this is a [conceptual example](https://en.wikipedia.org/wiki/Multiple_dispatch#Julia) of how multiple dispatch works in Julia, returning different values depending on the input types of x and y:\n", - "\n", - "```julia\n", - "collide_with(x::Asteroid, y::Asteroid) = ... \n", - "# deal with asteroid hitting asteroid\n", - "\n", - "collide_with(x::Asteroid, y::Spaceship) = ... \n", - "# deal with asteroid hitting spaceship\n", - "\n", - "collide_with(x::Spaceship, y::Asteroid) = ... \n", - "# deal with spaceship hitting asteroid\n", + "from typing import Dict, List, Iterable, Sequence, Tuple\n", + "from plum.type import VarArgs, ptype\n", "\n", - "collide_with(x::Spaceship, y::Spaceship) = ... \n", - "# deal with spaceship hitting spaceship\n", - "```\n", - "\n", - "Type dispatch can be especially useful in data science, where you might allow different input types (i.e. numpy arrays and pandas dataframes) to function that processes data. Type dispatch allows you to have a common API for functions that do similar tasks.\n", - "\n", - "The `TypeDispatch` class allows us to achieve type dispatch in Python. It contains a dictionary that maps types from type annotations to functions, which ensures that the proper function is called when passed inputs." 
+ "test_eq(_pt_repr(ptype(int)), 'int')\n", + "test_eq(_pt_repr(ptype(Union[int, str])), 'int|str')\n", + "test_eq(_pt_repr(ptype(Tuple[int, str])), 'tuple[int,str]')\n", + "test_eq(_pt_repr(ptype(List[int])), 'list[int]')\n", + "test_eq(_pt_repr(ptype(Sequence[int])), 'Sequence[int]')\n", + "test_eq(_pt_repr(ptype(Iterable[int])), 'Iterable[int]')\n", + "test_eq(_pt_repr(ptype(Dict[str, int])), 'dict[str,int]')\n", + "test_eq(_pt_repr(ptype(VarArgs[str])), 'VarArgs[str]')\n", + "test_eq(_pt_repr(ptype(Dict[Tuple[Union[int,str],float], List[Tuple[object]]])),\n", + " 'dict[tuple[int|str,float],list[tuple[object]]]')" ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "#export\n", - "class _TypeDict:\n", - " def __init__(self): self.d,self.cache = {},{}\n", - "\n", - " def _reset(self):\n", - " self.d = {k:self.d[k] for k in sorted_topologically(self.d, cmp=lenient_issubclass)}\n", - " self.cache = {}\n", - "\n", - " def add(self, t, f):\n", - " \"Add type `t` and function `f`\"\n", - " if not isinstance(t, tuple): t = tuple(L(union2tuple(t)))\n", - " for t_ in t: self.d[t_] = f\n", - " self._reset()\n", - "\n", - " def all_matches(self, k):\n", - " \"Find first matching type that is a super-class of `k`\"\n", - " if k not in self.cache:\n", - " types = [f for f in self.d if lenient_issubclass(k,f)]\n", - " self.cache[k] = [self.d[o] for o in types]\n", - " return self.cache[k]\n", - "\n", - " def __getitem__(self, k):\n", - " \"Find first matching type that is a super-class of `k`\"\n", - " res = self.all_matches(k)\n", - " return res[0] if len(res) else None\n", - "\n", - " def __repr__(self): return self.d.__repr__()\n", - " def first(self): return first(self.d.values())" + "## FastFunction -" ] }, { @@ -286,92 +142,34 @@ "outputs": [], "source": [ "#export\n", - "class TypeDispatch:\n", - " \"Dictionary-like object; `__getitem__` matches keys of types using `issubclass`\"\n", - " def 
__init__(self, funcs=(), bases=()):\n", - " self.funcs,self.bases = _TypeDict(),L(bases).filter(is_not(None))\n", - " for o in L(funcs): self.add(o)\n", - " self.inst = None\n", - " self.owner = None\n", - "\n", - " def add(self, f):\n", - " \"Add type `t` and function `f`\"\n", - " if isinstance(f, staticmethod): a0,a1 = _p2_anno(f.__func__)\n", - " else: a0,a1 = _p2_anno(f)\n", - " t = self.funcs.d.get(a0)\n", - " if t is None:\n", - " t = _TypeDict()\n", - " self.funcs.add(a0, t)\n", - " t.add(a1, f)\n", - "\n", - " def first(self):\n", - " \"Get first function in ordered dict of type:func.\"\n", - " return self.funcs.first().first()\n", - "\n", - " def returns(self, x):\n", - " \"Get the return type of annotation of `x`.\"\n", - " return anno_ret(self[type(x)])\n", - "\n", - " def _attname(self,k): return getattr(k,'__name__',str(k))\n", + "class FastFunction(Function):\n", " def __repr__(self):\n", - " r = [f'({self._attname(k)},{self._attname(l)}) -> {getattr(v, \"__name__\", type(v).__name__)}'\n", - " for k in self.funcs.d for l,v in self.funcs[k].d.items()]\n", - " r = r + [o.__repr__() for o in self.bases]\n", - " return '\\n'.join(r)\n", - "\n", - " def __call__(self, *args, **kwargs):\n", - " ts = L(args).map(type)[:2]\n", - " f = self[tuple(ts)]\n", - " if not f: return args[0]\n", - " if isinstance(f, staticmethod): f = f.__func__\n", - " elif self.inst is not None: f = MethodType(f, self.inst)\n", - " elif self.owner is not None: f = MethodType(f, self.owner)\n", - " return f(*args, **kwargs)\n", + " return '\\n'.join(f\"{f.__name__}: ({','.join(_pt_repr(t) for t in s.types)}) -> {_pt_repr(r)}\"\n", + " for s, (f, r) in self.methods.items())\n", "\n", - " def __get__(self, inst, owner):\n", - " self.inst = inst\n", - " self.owner = owner\n", - " return self\n", + " @delegates(Function.dispatch)\n", + " def dispatch(self, f=None, **kwargs): return super().dispatch(_eval_annotations(f), **kwargs)\n", "\n", - " def __getitem__(self, k):\n", - " \"Find 
first matching type that is a super-class of `k`\"\n", - " k = L(k)\n", - " while len(k)<2: k.append(object)\n", - " r = self.funcs.all_matches(k[0])\n", - " for t in r:\n", - " o = t[k[1]]\n", - " if o is not None: return o\n", - " for base in self.bases:\n", - " res = base[k]\n", - " if res is not None: return res\n", - " return None" + " def __getitem__(self, ts):\n", + " \"Return the most-specific matching method with fewest parameters\"\n", + " ts = L(ts)\n", + " nargs = min(len(o) for o in self.methods.keys())\n", + " while len(ts) < nargs: ts.append(object)\n", + " return self.invoke(*ts)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "To demonstrate how `TypeDispatch` works, we define a set of functions that accept a variety of input types, specified with different type annotations:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def f2(x:int, y:float): return x+y #int and float for 2nd arg\n", - "def f_nin(x:numbers.Integral)->int: return x+1 #integral numeric\n", - "def f_ni2(x:int): return x #integer\n", - "def f_bll(x:bool|list): return x #bool or list\n", - "def f_num(x:numbers.Number): return x #Number (root of numerics) " + "`FastFunction` extends `plum.Function` with the following functionality." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "We can optionally initialize `TypeDispatch` with a list of functions we want to search. 
Printing an instance of `TypeDispatch` will display convenient mapping of types -> functions:" + "`FastFunction` has a concise `repr`:" ] }, { @@ -382,12 +180,7 @@ { "data": { "text/plain": [ - "(bool,object) -> f_bll\n", - "(int,object) -> f_ni2\n", - "(Integral,object) -> f_nin\n", - "(Number,object) -> f_num\n", - "(list,object) -> f_bll\n", - "(object,object) -> NoneType" + "f: (int) -> float" ] }, "execution_count": null, @@ -396,39 +189,16 @@ } ], "source": [ - "t = TypeDispatch([f_nin,f_ni2,f_num,f_bll,None])\n", - "t" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Note that only the first two arguments are used for `TypeDispatch`. If your function only contains one argument, the second parameter will be shown as `object`. If you pass `None` into `TypeDispatch`, then this will be displayed as `(object, object) -> NoneType`.\n", - "\n", - "`TypeDispatch` is a dictionary-like object, which means that you can retrieve a function by the associated type annotation. 
For example, the statement:\n", - "\n", - "```py\n", - "t[float]\n", - "```\n", - "Will return `f_num` because that is the matching function that has a type annotation that is a super-class of of `float` - `numbers.Number`:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "assert issubclass(float, numbers.Number)\n", - "test_eq(t[float], f_num)" + "def f(x: int) -> float: pass\n", + "f = FastFunction(f).dispatch(f)\n", + "f" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "The same is true for other types as well:" + "`FastFunction` supports fastcore's backport of the `|` operator on types:" ] }, { @@ -437,169 +207,20 @@ "metadata": {}, "outputs": [], "source": [ - "test_eq(t[np.int32], f_nin)\n", - "test_eq(t[bool], f_bll)\n", - "test_eq(t[list], f_bll)\n", - "test_eq(t[np.int32], f_nin)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "If you try to get a type that doesn't match, `TypeDispatch` will return `None`:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "test_eq(t[str], None)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "data": { - "text/markdown": [ - "

TypeDispatch.add[source]

\n", - "\n", - "> TypeDispatch.add(**`f`**)\n", - "\n", - "Add type `t` and function `f`" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "show_doc(TypeDispatch.add)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "This method allows you to add an additional function to an existing `TypeDispatch` instance :" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(bool,object) -> f_bll\n", - "(int,object) -> f_ni2\n", - "(Integral,object) -> f_nin\n", - "(Number,object) -> f_num\n", - "(list,object) -> f_bll\n", - "(typing.Collection,object) -> f_col\n", - "(object,object) -> NoneType" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "def f_col(x:typing.Collection): return x\n", - "t.add(f_col)\n", - "test_eq(t[str], f_col)\n", - "t" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "If you accidentally add the same function more than once things will still work as expected:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "t.add(f_ni2) \n", - "test_eq(t[int], f_ni2)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "However, if you add a function that has a type collision that raises an ambiguity, this will automatically resolve to the latest function added:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def f_ni3(z:int): return z # collides with f_ni2 with same type annotations\n", - "t.add(f_ni3) \n", - "test_eq(t[int], f_ni3)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Using `bases`:" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The argument `bases` can optionally accept a single 
instance of `TypeDispatch` or a collection (i.e. a tuple or list) of `TypeDispatch` objects. This can provide functionality similar to multiple inheritance. \n", + "def f1(x): return 'obj'\n", + "def f2(x: int|str): return 'int|str'\n", + "f = FastFunction(f1).dispatch(f1).dispatch(f2)\n", "\n", - "These are searched for matching functions if no match in your list of functions:" + "test_eq(f(0), 'int|str')\n", + "test_eq(f(''), 'int|str')\n", + "test_eq(f(0.0), 'obj')" ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(str,object) -> f_str\n", - "(bool,object) -> f_bll\n", - "(int,object) -> f_ni2\n", - "(Integral,object) -> f_nin\n", - "(Number,object) -> f_num\n", - "(list,object) -> f_bll\n", - "(object,object) -> NoneType" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], "source": [ - "def f_str(x:str): return x+'1'\n", - "\n", - "t = TypeDispatch([f_nin,f_ni2,f_num,f_bll,None])\n", - "t2 = TypeDispatch(f_str, bases=t) # you can optionally supply a list of TypeDispatch objects for `bases`.\n", - "t2" + "Indexing a `FastFunction` works like [`plum.Function.invoke`](https://github.com/wesselb/plum#directly-invoke-a-method) but returns the most-specific matching method with the fewest parameters:" ] }, { @@ -608,60 +229,23 @@ "metadata": {}, "outputs": [], "source": [ - "test_eq(t2[int], f_ni2) # searches `t` b/c not found in `t2`\n", - "test_eq(t2[np.int32], f_nin) # searches `t` b/c not found in `t2`\n", - "test_eq(t2[float], f_num) # searches `t` b/c not found in `t2`\n", - "test_eq(t2[bool], f_bll) # searches `t` b/c not found in `t2`\n", - "test_eq(t2[str], f_str) # found in `t`!\n", - "test_eq(t2('a'), 'a1') # found in `t`!, and uses __call__\n", + "def f1(a: int, b, c): return 'int, 3 args'\n", + "def f2(a: int, b, c, d): return 'int, 4 args'\n", + "def f3(a: float, b, c): return 'float, 3 args'\n", + "def 
f4(a: float, b: str, c): return 'float, str, 3 args'\n", + "f = FastFunction(f1).dispatch(f1).dispatch(f2).dispatch(f3).dispatch(f4)\n", "\n", - "o = np.int32(1)\n", - "test_eq(t2(o), 2) # found in `t2` and uses __call__" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Up To Two Arguments" + "test_eq(f[int](0,0,0), 'int, 3 args')\n", + "test_eq(f[float](0,0,0), 'float, 3 args')\n", + "test_eq(f[float](0,0,0), 'float, 3 args')\n", + "test_eq(f[float, str](0,0,0), 'float, str, 3 args')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "`TypeDispatch` supports up to two arguments when searching for the appropriate function. The following functions `f1` and `f2` both have two parameters:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(int,float) -> f2\n", - "(Integral,object) -> f1" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "def f1(x:numbers.Integral, y): return x+1 #Integral is a numeric type\n", - "def f2(x:int, y:float): return x+y\n", - "t = TypeDispatch([f1,f2])\n", - "t" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - " You can lookup functions from a `TypeDispatch` instance with two parameters like this:" + "## FastDispatcher -" ] }, { @@ -670,154 +254,61 @@ "metadata": {}, "outputs": [], "source": [ - "test_eq(t[np.int32], f1)\n", - "test_eq(t[int,float], f2)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Keep in mind that anything beyond the first two parameters are ignored, and any collisions will be resolved in favor of the most recent function added. 
In the below example, `f1` is ignored in favor of `f2` because the first two parameters have identical type hints:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(str,int) -> f2" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "def f1(a:str, b:int, c:list): return a\n", - "def f2(a: str, b:int): return b\n", - "t = TypeDispatch([f1,f2])\n", - "test_eq(t[str, int], f2)\n", - "t" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Matching" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "`Type Dispatch` matches types with functions according to whether the supplied class is a subclass or the same class of the type annotation(s) of associated functions. \n", + "#export\n", + "class FastDispatcher(Dispatcher):\n", + " def _get_function(self, method, owner):\n", + " \"Adapted from `Dispatcher._get_function` to use `FastFunction`\"\n", + " name = method.__name__\n", + " if owner:\n", + " if owner not in self._classes: self._classes[owner] = {}\n", + " namespace = self._classes[owner]\n", + " else: namespace = self._functions\n", + " if name not in namespace: namespace[name] = FastFunction(method, owner=owner)\n", + " return namespace[name]\n", "\n", - "Let's consider an example where we try to retrieve the function corresponding to types of `[np.int32, float]`.\n", + " @delegates(Dispatcher.__call__, but='method')\n", + " def __call__(self, f, **kwargs): return super().__call__(_eval_annotations(f), **kwargs)\n", "\n", - "In this scenario, `f2` will not be matched. 
This is because the first type annotation of `f2`, `int`, is not a superclass (or the same class) of `np.int32`:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def f1(x:numbers.Integral, y): return x+1\n", - "def f2(x:int, y:float): return x+y\n", - "t = TypeDispatch([f1,f2])\n", + " def _to(self, cls, nm, f, **kwargs):\n", + " nf = copy_func(f)\n", + " nf.__qualname__ = f'{cls.__name__}.{nm}' # plum uses __qualname__ to infer f's owner\n", + " pf = self(nf, **kwargs)\n", + " # plum uses __set_name__ to resolve a plum.Function's owner\n", + " # since we assign after class creation, __set_name__ must be called directly\n", + " # source: https://docs.python.org/3/reference/datamodel.html#object.__set_name__\n", + " pf.__set_name__(cls, nm)\n", + " pf = pf.resolve()\n", + " setattr(cls, nm, pf)\n", + " return pf\n", "\n", - "assert not issubclass(np.int32, int)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Instead, `f1` is a valid match, as its first argument is annoted with the type `numbers.Integeral`, which `np.int32` is a subclass of: " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "assert issubclass(np.int32, numbers.Integral)\n", - "test_eq(t[np.int32,float], f1) " - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "In `f1` , the 2nd parameter `y` is not annotated, which means `TypeDispatch` will match anything where the first argument matches `int` that is not matched with anything else:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "assert issubclass(int, numbers.Integral) # int is a subclass of numbers.Integral\n", - "test_eq(t[int], f1)\n", - "test_eq(t[int,int], f1)" + " def to(self, cls):\n", + " \"Decorator: dispatch `f` to `cls.f`\"\n", + " def _inner(f, **kwargs):\n", + " nm = f.__name__\n", + " # 
check __dict__ to avoid inherited methods but use getattr so pf.__get__ is called, which plum relies on\n", + " if nm in cls.__dict__:\n", + " pf = getattr(cls, nm)\n", + " if not hasattr(pf, 'dispatch'): pf = self._to(cls, nm, pf, **kwargs)\n", + " pf.dispatch(f)\n", + " else: pf = self._to(cls, nm, f, **kwargs)\n", + " return pf\n", + " return _inner\n", + "\n", + "typedispatch = FastDispatcher()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "If no match is possible, `None` is returned:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "test_eq(t[float,float], None)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "data": { - "text/markdown": [ - "

TypeDispatch.__call__[source]

\n", - "\n", - "> TypeDispatch.__call__(**\\*`args`**, **\\*\\*`kwargs`**)\n", - "\n", - "Call self as a function." - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "show_doc(TypeDispatch.__call__)" + "`FastDispatcher` extends `plum.Dispatcher` with the following functionality." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "`TypeDispatch` is also callable. When you call an instance of `TypeDispatch`, it will execute the relevant function:" + "Dispatching with `FastDispatcher` returns a `FastFunction`:" ] }, { @@ -826,23 +317,17 @@ "metadata": {}, "outputs": [], "source": [ - "def f_arr(x:np.ndarray): return x.sum()\n", - "def f_int(x:np.int32): return x+1\n", - "t = TypeDispatch([f_arr, f_int])\n", - "\n", - "arr = np.array([5,4,3,2,1])\n", - "test_eq(t(arr), 15) # dispatches to f_arr\n", + "@typedispatch\n", + "def f(x): return 'obj'\n", "\n", - "o = np.int32(1)\n", - "test_eq(t(o), 2) # dispatches to f_int\n", - "assert t.first() is not None " + "assert isinstance(f, FastFunction)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "You can also call an instance of of `TypeDispatch` when there are two parameters:" + "`FastDispatcher` supports fastcore's backport of the `|` operator on types:" ] }, { @@ -851,61 +336,19 @@ "metadata": {}, "outputs": [], "source": [ - "def f1(x:numbers.Integral, y): return x+1\n", - "def f2(x:int, y:float): return x+y\n", - "t = TypeDispatch([f1,f2])\n", + "@typedispatch\n", + "def f(x:int|str): return 'int|str'\n", "\n", - "test_eq(t(3,2.0), 5)\n", - "test_eq(t(3,2), 4)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "When no match is found, a `TypeDispatch` instance becomes an identity function. This default behavior is leveraged by fasatai for data transformations to provide a sensible default when a matching function cannot be found." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "test_eq(t('a'), 'a')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "data": { - "text/markdown": [ - "

TypeDispatch.returns[source]

\n", - "\n", - "> TypeDispatch.returns(**`x`**)\n", - "\n", - "Get the return type of annotation of `x`." - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "show_doc(TypeDispatch.returns)" + "test_eq(f(0), 'int|str')\n", + "test_eq(f(''), 'int|str')\n", + "test_eq(f(0.0), 'obj')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "You can optionally pass an object to `TypeDispatch.returns` and get the return type annotation back:" + "... `FastDispatcher.multi` works too:" ] }, { @@ -914,33 +357,21 @@ "metadata": {}, "outputs": [], "source": [ - "def f1(x:int) -> np.ndarray: return np.array(x)\n", - "def f2(x:str) -> float: return List\n", - "def f3(x:float): return List # f3 has no return type annotation\n", - "\n", - "t = TypeDispatch([f1, f2, f3])\n", - "\n", - "test_eq(t.returns(1), np.ndarray) # dispatched to f1\n", - "test_eq(t.returns('Hello'), float) # dispatched to f2\n", - "test_eq(t.returns(1.0), None) # dispatched to f3\n", + "@typedispatch.multi([bool],[list])\n", + "def f(x: bool|list): return 'bool|list'\n", + "@typedispatch\n", + "def f(x: int): return 'int'\n", "\n", - "class _Test: pass\n", - "_test = _Test()\n", - "test_eq(t.returns(_test), None) # type `_Test` not found, so None returned" + "test_eq(f(True), 'bool|list')\n", + "test_eq(f([]), 'bool|list')\n", + "test_eq(f(0), 'int')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "#### Using TypeDispatch With Methods" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "You can use `TypeDispatch` when defining methods as well:" + "`FastDispatcher.to` lets you dynamically dispatch to class instance methods:" ] }, { @@ -949,76 +380,23 @@ "metadata": {}, "outputs": [], "source": [ - "def m_nin(self, x:str|numbers.Integral): return str(x)+'1'\n", - "def m_bll(self, x:bool): self.foo='a'\n", - "def m_num(self, x:numbers.Number): return x*2\n", + "class A:\n", + " @typedispatch\n", + " def 
f(self, x): return 'obj'\n", "\n", - "t = TypeDispatch([m_nin,m_num,m_bll])\n", - "class A: f = t # set class attribute `f` equal to a TypeDispatch instance\n", - " \n", - "a = A()\n", - "test_eq(a.f(1), '11') #dispatch to m_nin\n", - "test_eq(a.f(1.), 2.) #dispatch to m_num\n", - "test_is(a.f.inst, a)\n", + "@typedispatch.to(A)\n", + "def f(self, x:int): return 'int'\n", "\n", - "a.f(False) # this triggers t.m_bll to run, which sets self.foo to 'a'\n", - "test_eq(a.foo, 'a')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "As discussed in `TypeDispatch.__call__`, when there is not a match, `TypeDispatch.__call__` becomes an identity function. In the below example, a tuple does not match any type annotations so a tuple is returned:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "test_eq(a.f(()), ()) " - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We extend the previous example by using `bases` to add an additional method that supports tuples:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def m_tup(self, x:tuple): return x+(1,)\n", - "t2 = TypeDispatch(m_tup, bases=t)\n", - "\n", - "class A2: f = t2\n", - "a2 = A2()\n", - "test_eq(a2.f(1), '11')\n", - "test_eq(a2.f(1.), 2.)\n", - "test_is(a2.f.inst, a2)\n", - "a2.f(False)\n", - "test_eq(a2.foo, 'a')\n", - "test_eq(a2.f(()), (1,))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Using TypeDispatch With Class Methods" + "a = A()\n", + "test_eq(a.f(0), 'int')\n", + "test_eq(a.f(''), 'obj')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "You can use `TypeDispatch` when defining class methods too:" + "### Tests -" ] }, { @@ -1027,46 +405,20 @@ "metadata": {}, "outputs": [], "source": [ - "def m_nin(cls, x:str|numbers.Integral): return str(x)+'1'\n", - "def m_bll(cls, x:bool): 
cls.foo='a'\n", - "def m_num(cls, x:numbers.Number): return x*2\n", + "#hide\n", + "#Call `to` twice consecutively\n", + "class A: pass\n", "\n", - "t = TypeDispatch([m_nin,m_num,m_bll])\n", - "class A: f = t # set class attribute `f` equal to a TypeDispatch\n", + "@typedispatch.to(A)\n", + "def f(self, x:int): return 'int'\n", "\n", - "test_eq(A.f(1), '11') #dispatch to m_nin\n", - "test_eq(A.f(1.), 2.) #dispatch to m_num\n", - "test_is(A.f.owner, A)\n", + "a = A()\n", + "test_eq(a.f(0), 'int')\n", "\n", - "A.f(False) # this triggers t.m_bll to run, which sets A.foo to 'a'\n", - "test_eq(A.foo, 'a')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## typedispatch Decorator" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#export\n", - "class DispatchReg:\n", - " \"A global registry for `TypeDispatch` objects keyed by function name\"\n", - " def __init__(self): self.d = defaultdict(TypeDispatch)\n", - " def __call__(self, f):\n", - " if isinstance(f, (classmethod, staticmethod)): nm = f'{f.__func__.__qualname__}'\n", - " else: nm = f'{f.__qualname__}'\n", - " if isinstance(f, classmethod): f=f.__func__\n", - " self.d[nm].add(f)\n", - " return self.d[nm]\n", + "@typedispatch.to(A)\n", + "def f(self, x:str): return 'str'\n", "\n", - "typedispatch = DispatchReg()" + "test_eq(a.f(''), 'str')" ] }, { @@ -1075,34 +427,17 @@ "metadata": {}, "outputs": [], "source": [ - "@typedispatch\n", - "def f_td_test(x, y): return f'{x}{y}'\n", - "@typedispatch\n", - "def f_td_test(x:numbers.Integral|int, y): return x+1\n", - "@typedispatch\n", - "def f_td_test(x:int, y:float): return x+y\n", - "@typedispatch\n", - "def f_td_test(x:int, y:int): return x*y\n", + "#hide\n", + "#Call `to` on an ordinary method (not a `FastFunction`)\n", + "class A:\n", + " def f(self, x): return 'obj'\n", "\n", - "test_eq(f_td_test(3,2.0), 5)\n", - "assert issubclass(int, numbers.Integral)\n", - 
"test_eq(f_td_test(3,2), 6)\n", + "@typedispatch.to(A)\n", + "def f(self, x:int): return 'int'\n", "\n", - "test_eq(f_td_test('a','b'), 'ab')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Using typedispatch With other decorators" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "You can use `typedispatch` with `classmethod` and `staticmethod` decorator" + "a = A()\n", + "test_eq(a.f(0), 'int')\n", + "test_eq(a.f(''), 'obj')" ] }, { @@ -1111,19 +446,19 @@ "metadata": {}, "outputs": [], "source": [ + "#hide\n", + "#Calling `to` when there is a matching inherited method doesn't alter the base class\n", + "#but still dispatches to it\n", "class A:\n", + " def f(self, x): return 'A'\n", + "Af = A.f\n", + "class B(A):\n", " @typedispatch\n", - " def f_td_test(self, x:numbers.Integral, y): return x+1\n", - " @typedispatch\n", - " @classmethod\n", - " def f_td_test(cls, x:int, y:float): return x+y\n", - " @typedispatch\n", - " @staticmethod\n", - " def f_td_test(x:int, y:int): return x*y\n", - " \n", - "test_eq(A.f_td_test(3,2), 6)\n", - "test_eq(A.f_td_test(3,2.0), 5)\n", - "test_eq(A().f_td_test(3,'2.0'), 4)" + " def f(self, x:int): return 'B'\n", + "test_is(Af, A.f)\n", + "b = B()\n", + "test_eq(b.f(0), 'B')\n", + "test_eq(b.f(''), 'A')" ] }, { diff --git a/nbs/05_transform.ipynb b/nbs/05_transform.ipynb index e0bfbbaa0..18b6758f6 100644 --- a/nbs/05_transform.ipynb +++ b/nbs/05_transform.ipynb @@ -20,7 +20,8 @@ "from fastcore.foundation import *\n", "from fastcore.utils import *\n", "from fastcore.dispatch import *\n", - "import inspect" + "import inspect\n", + "from plum import add_conversion_method" ] }, { @@ -69,11 +70,10 @@ "#export\n", "_tfm_methods = 'encodes','decodes','setups'\n", "\n", + "def _is_tfm_method(n, f): return n in _tfm_methods and callable(f)\n", + "\n", "class _TfmDict(dict):\n", - " def __setitem__(self,k,v):\n", - " if k not in _tfm_methods or not callable(v): return 
super().__setitem__(k,v)\n", - " if k not in self: super().__setitem__(k,TypeDispatch())\n", - " self[k].add(v)" + " def __setitem__(self, k, v): super().__setitem__(k, typedispatch(v) if _is_tfm_method(k, v) else v)" ] }, { @@ -85,21 +85,20 @@ "#export\n", "class _TfmMeta(type):\n", " def __new__(cls, name, bases, dict):\n", + " # _TfmMeta.__call__ shadows the signature of inheriting classes, set it back\n", " res = super().__new__(cls, name, bases, dict)\n", - " for nm in _tfm_methods:\n", - " base_td = [getattr(b,nm,None) for b in bases]\n", - " if nm in res.__dict__: getattr(res,nm).bases = base_td\n", - " else: setattr(res, nm, TypeDispatch(bases=base_td))\n", " res.__signature__ = inspect.signature(res.__init__)\n", " return res\n", "\n", " def __call__(cls, *args, **kwargs):\n", - " f = args[0] if args else None\n", - " n = getattr(f,'__name__',None)\n", - " if callable(f) and n in _tfm_methods:\n", - " getattr(cls,n).add(f)\n", - " return f\n", - " return super().__call__(*args, **kwargs)\n", + " f = first(args)\n", + " n = getattr(f, '__name__', None)\n", + " if _is_tfm_method(n, f): return typedispatch.to(cls)(f)\n", + " obj = super().__call__(*args, **kwargs)\n", + " # _TfmMeta.__new__ replaces cls.__signature__ which breaks the signature of a callable\n", + " # instances of cls, fix it\n", + " if hasattr(obj, '__call__'): obj.__signature__ = inspect.signature(obj.__call__)\n", + " return obj\n", "\n", " @classmethod\n", " def __prepare__(cls, name, bases): return _TfmDict()" @@ -144,13 +143,14 @@ " self.init_enc = enc or dec\n", " if not self.init_enc: return\n", "\n", - " self.encodes,self.decodes,self.setups = TypeDispatch(),TypeDispatch(),TypeDispatch()\n", + " def identity(x): return x\n", + " for n in _tfm_methods: setattr(self,n,FastFunction(identity).dispatch(identity))\n", " if enc:\n", - " self.encodes.add(enc)\n", + " self.encodes.dispatch(enc)\n", " self.order = getattr(enc,'order',self.order)\n", " if len(type_hints(enc)) > 0: 
self.input_types = union2tuple(first(type_hints(enc).values()))\n", " self._name = _get_name(enc)\n", - " if dec: self.decodes.add(dec)\n", + " if dec: self.decodes.dispatch(dec)\n", "\n", " @property\n", " def name(self): return getattr(self, '_name', _get_name(self))\n", @@ -169,14 +169,32 @@ " def _do_call(self, f, x, **kwargs):\n", " if not _is_tuple(x):\n", " if f is None: return x\n", - " ret = f.returns(x) if hasattr(f,'returns') else None\n", - " return retain_type(f(x, **kwargs), x, ret)\n", + " ts = [type(self),type(x)] if hasattr(f,'instance') else [type(x)]\n", + " _, ret = f.resolve_method(*ts)\n", + " ret = ret._type\n", + " # plum reads empty return annotation as object, retain_type expects it as None\n", + " if ret is object: ret = None\n", + " return retain_type(f(x,**kwargs), x, ret)\n", " res = tuple(self._do_call(f, x_, **kwargs) for x_ in x)\n", " return retain_type(res, x)\n", + " def encodes(self, x): return x\n", + " def decodes(self, x): return x\n", + " def setups(self, dl): return dl\n", "\n", "add_docs(Transform, decode=\"Delegate to decodes to undo transform\", setup=\"Delegate to setups to set up transform\")" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#export\n", + "#Implement the Transform convention that a None return annotation disables conversion\n", + "add_conversion_method(object, NoneType, lambda x: x)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -368,6 +386,44 @@ "test_eq_type(f3(2), 2)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Transforms can be created from class methods too:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class A:\n", + " @classmethod\n", + " def create(cls, x:int): return x+1\n", + "test_eq(Transform(A.create)(1), 2)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ 
+ "#hide\n", + "# Test extension of a tfm method defined in the class\n", + "class A(Transform):\n", + " def encodes(self, x): return 'obj'\n", + "\n", + "@A\n", + "def encodes(self, x:int): return 'int'\n", + "\n", + "a = A()\n", + "test_eq(a.encodes(0), 'int')\n", + "test_eq(a.encodes(0.0), 'obj')" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -845,7 +901,7 @@ "def encodes(self, x:str): return x+'hello'\n", "\n", "@B\n", - "def encodes(self, x)->None: return str(x)+'!'" + "def encodes(self, x): return str(x)+'!'" ] }, { @@ -1015,8 +1071,7 @@ "data": { "text/plain": [ "A:\n", - "encodes: (object,object) -> noop\n", - "decodes: (object,object) -> noop" + "encodes: noop: (object,VarArgs[object]) -> objectdecodes: noop: (object,VarArgs[object]) -> object" ] }, "execution_count": null, @@ -1046,8 +1101,7 @@ "data": { "text/plain": [ "A -- {'a': 1, 'b': 2}:\n", - "encodes: (object,object) -> noop\n", - "decodes: " + "encodes: noop: (object,VarArgs[object]) -> objectdecodes: decodes: (object,object) -> object" ] }, "execution_count": null, @@ -1933,13 +1987,6 @@ "from nbdev.export import notebook2script\n", "notebook2script()" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/settings.ini b/settings.ini index 625c459c4..b5e60ec8e 100644 --- a/settings.ini +++ b/settings.ini @@ -7,7 +7,7 @@ author = Jeremy Howard and Sylvain Gugger author_email = infos@fast.ai copyright = fast.ai branch = master -version = 1.4.6 +version = 1.5.0 min_python = 3.7 audience = Developers language = English diff --git a/setup.py b/setup.py index 1fa9526b5..aea178003 100644 --- a/setup.py +++ b/setup.py @@ -26,7 +26,7 @@ min_python = cfg['min_python'] lic = licenses[cfg['license']] -requirements = ['pip', 'packaging'] +requirements = ['pip', 'packaging', 'plum-dispatch>=1.6'] if cfg.get('requirements'): requirements += cfg.get('requirements','').split() if cfg.get('pip_requirements'): 
requirements += cfg.get('pip_requirements','').split() dev_requirements = (cfg.get('dev_requirements') or '').split()