From fce1adb2c6c42af1d9c0a005bfc83a5b820e9a51 Mon Sep 17 00:00:00 2001 From: seem Date: Tue, 14 Jun 2022 10:50:34 +1000 Subject: [PATCH] fix python 3.7 compatibility; address review comments --- fastcore/transform.py | 16 +++-- nbs/05_transform.ipynb | 130 ++++++++++------------------------------- setup.py | 2 +- 3 files changed, 41 insertions(+), 107 deletions(-) diff --git a/fastcore/transform.py b/fastcore/transform.py index 06f43004..59755996 100644 --- a/fastcore/transform.py +++ b/fastcore/transform.py @@ -16,7 +16,6 @@ import inspect from copy import copy from plum import add_conversion_method, dispatch, Function -from typing import get_args, get_origin # Cell # Convert tuple annotations to unions to work with plum @@ -27,9 +26,13 @@ def _annot_tuple_to_union(f): def _dispatch(f): return dispatch(_annot_tuple_to_union(f)) +def _pf_dispatch(pf, f): return pf.dispatch(_annot_tuple_to_union(f)) + def _dispatch_method(f, cls): - f = copy(f) n = f.__name__ + # Use __dict__ to avoid searching base classes + if n in cls.__dict__: return _pf_dispatch(getattr(cls, n), f) + f = copy(f) # plum uses __qualname__ to infer f's owner f.__qualname__ = f'{cls.__name__}.{n}' pf = _dispatch(f) @@ -40,8 +43,6 @@ def _dispatch_method(f, cls): pf.__set_name__(cls, n) return pf -def _pf_dispatch(pf, f): return pf.dispatch(_annot_tuple_to_union(f)) - # Cell _tfm_methods = 'encodes','decodes','setups' @@ -98,6 +99,10 @@ def _pt_repr(o): def _pf_repr(pf): return '\n'.join(f"{f.__name__}: ({','.join(_pt_repr(t) for t in s.types)}) -> {_pt_repr(r)}" for s, (f, r) in pf.methods.items()) +# Cell +def _union_to_tuple(t): + return t.__args__ if getattr(t,'__origin__',None) is Union else t + # Cell class Transform(metaclass=_TfmMeta): "Delegates (`__call__`,`decode`,`setup`) to (encodes,decodes,setups) if `split_idx` matches" @@ -114,9 +119,8 @@ def identity(x): return x _pf_dispatch(self.encodes, enc) self.order = getattr(enc,'order',self.order) if len(type_hints(enc)) > 0: - self.input_types = first(type_hints(enc).values()) # Convert Union to tuple, remove once the rest of fastai supports Union - if get_origin(self.input_types) is Union: self.input_types=get_args(self.input_types) + self.input_types = _union_to_tuple(first(type_hints(enc).values())) self._name = _get_name(enc) if dec: _pf_dispatch(self.decodes, dec) diff --git a/nbs/05_transform.ipynb b/nbs/05_transform.ipynb index 23155b2e..bcf1d847 100644 --- a/nbs/05_transform.ipynb +++ b/nbs/05_transform.ipynb @@ -23,8 +23,7 @@ "from fastcore.dispatch import *\n", "import inspect\n", "from copy import copy\n", - "from plum import add_conversion_method, dispatch, Function\n", - "from typing import get_args, get_origin" + "from plum import add_conversion_method, dispatch, Function" ] }, { @@ -78,9 +77,13 @@ "\n", "def _dispatch(f): return dispatch(_annot_tuple_to_union(f))\n", "\n", + "def _pf_dispatch(pf, f): return pf.dispatch(_annot_tuple_to_union(f))\n", + "\n", "def _dispatch_method(f, cls):\n", - " f = copy(f)\n", " n = f.__name__\n", + " # Use __dict__ to avoid searching base classes\n", + " if n in cls.__dict__: return _pf_dispatch(getattr(cls, n), f)\n", + " f = copy(f)\n", " # plum uses __qualname__ to infer f's owner\n", " f.__qualname__ = f'{cls.__name__}.{n}'\n", " pf = _dispatch(f)\n", @@ -89,9 +92,7 @@ " # since we assign after class creation, __set_name__ must be called directly\n", " # source: https://docs.python.org/3/reference/datamodel.html#object.__set_name__\n", " pf.__set_name__(cls, n)\n", - " return pf\n", - "\n", - "def _pf_dispatch(pf, f): return pf.dispatch(_annot_tuple_to_union(f))" + " return pf" ] }, { @@ -225,6 +226,28 @@ "test_eq(_pf_repr(_f), '_f1: (int,dict[str,float]) -> float\\n_f2: (int,tuple[str,float]) -> float')" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#export\n", + "def _union_to_tuple(t):\n", + " return t.__args__ if getattr(t,'__origin__',None) is Union else t" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "test_eq(_union_to_tuple(Union[int,Union[str,None]]), (int,str,NoneType))\n", + "test_eq(_union_to_tuple(Tuple[int,str]), Tuple[int,str])\n", + "test_eq(_union_to_tuple(int), int)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -247,9 +270,8 @@ " _pf_dispatch(self.encodes, enc)\n", " self.order = getattr(enc,'order',self.order)\n", " if len(type_hints(enc)) > 0:\n", - " self.input_types = first(type_hints(enc).values())\n", " # Convert Union to tuple, remove once the rest of fastai supports Union\n", - " if get_origin(self.input_types) is Union: self.input_types=get_args(self.input_types)\n", + " self.input_types = _union_to_tuple(first(type_hints(enc).values()))\n", " self._name = _get_name(enc)\n", " if dec: _pf_dispatch(self.decodes, dec)\n", "\n", @@ -520,15 +542,6 @@ "`Transform` can be used as a decorator to turn a function into a `Transform`." ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from nbdev.showdoc import _format_cls_doc, _format_func_doc" - ] - }, { "cell_type": "code", "execution_count": null, @@ -680,18 +693,6 @@ "test_eq(f(['a','b','c']), \"['a', 'b', 'c']_1\") # input is of type list" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "@Transform\n", - "def f(x:(int,float)): return x+1\n", - "test_eq(f(0), 1)\n", - "test_eq(f('a'), 'a')" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -929,77 +930,6 @@ "test_eq(f.decode(t), [1,2])" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def encodes(self, x): pass" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Promise(obj= with 2 method(s)>)" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "AL(encodes)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - ":\n", - "encodes: : (object) -> object\n", - "decodes: identity: (object) -> object" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "AL(lambda x: x)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "__main__.AL" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "type(AL(lambda x: x))" - ] - }, { "cell_type": "markdown", "metadata": {}, diff --git a/setup.py b/setup.py index 6eb65456..5b93cb4c 100644 --- a/setup.py +++ b/setup.py @@ -26,7 +26,7 @@ min_python = cfg['min_python'] lic = licenses[cfg['license']] -requirements = ['pip', 'packaging', 'plum-dispatch>=1.5.16'] +requirements = ['pip', 'packaging', 'plum-dispatch>=1.6'] if cfg.get('requirements'): requirements += cfg.get('requirements','').split() if cfg.get('pip_requirements'): requirements += cfg.get('pip_requirements','').split() dev_requirements = (cfg.get('dev_requirements') or '').split()