Skip to content

Commit

Permalink
fix python 3.7 compatibility; address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
seeM committed Jun 14, 2022
1 parent 5d1a117 commit fce1adb
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 107 deletions.
16 changes: 10 additions & 6 deletions fastcore/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import inspect
from copy import copy
from plum import add_conversion_method, dispatch, Function
from typing import get_args, get_origin

# Cell
# Convert tuple annotations to unions to work with plum
Expand All @@ -27,9 +26,13 @@ def _annot_tuple_to_union(f):

def _dispatch(f): return dispatch(_annot_tuple_to_union(f))

def _pf_dispatch(pf, f): return pf.dispatch(_annot_tuple_to_union(f))

def _dispatch_method(f, cls):
f = copy(f)
n = f.__name__
# Use __dict__ to avoid searching base classes
if n in cls.__dict__: return _pf_dispatch(getattr(cls, n), f)
f = copy(f)
# plum uses __qualname__ to infer f's owner
f.__qualname__ = f'{cls.__name__}.{n}'
pf = _dispatch(f)
Expand All @@ -40,8 +43,6 @@ def _dispatch_method(f, cls):
pf.__set_name__(cls, n)
return pf

def _pf_dispatch(pf, f): return pf.dispatch(_annot_tuple_to_union(f))

# Cell
_tfm_methods = 'encodes','decodes','setups'

Expand Down Expand Up @@ -98,6 +99,10 @@ def _pt_repr(o):
def _pf_repr(pf): return '\n'.join(f"{f.__name__}: ({','.join(_pt_repr(t) for t in s.types)}) -> {_pt_repr(r)}"
for s, (f, r) in pf.methods.items())

# Cell
def _union_to_tuple(t):
return t.__args__ if getattr(t,'__origin__',None) is Union else t

# Cell
class Transform(metaclass=_TfmMeta):
"Delegates (`__call__`,`decode`,`setup`) to (<code>encodes</code>,<code>decodes</code>,<code>setups</code>) if `split_idx` matches"
Expand All @@ -114,9 +119,8 @@ def identity(x): return x
_pf_dispatch(self.encodes, enc)
self.order = getattr(enc,'order',self.order)
if len(type_hints(enc)) > 0:
self.input_types = first(type_hints(enc).values())
# Convert Union to tuple, remove once the rest of fastai supports Union
if get_origin(self.input_types) is Union: self.input_types=get_args(self.input_types)
self.input_types = _union_to_tuple(first(type_hints(enc).values()))
self._name = _get_name(enc)
if dec: _pf_dispatch(self.decodes, dec)

Expand Down
130 changes: 30 additions & 100 deletions nbs/05_transform.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@
"from fastcore.dispatch import *\n",
"import inspect\n",
"from copy import copy\n",
"from plum import add_conversion_method, dispatch, Function\n",
"from typing import get_args, get_origin"
"from plum import add_conversion_method, dispatch, Function"
]
},
{
Expand Down Expand Up @@ -78,9 +77,13 @@
"\n",
"def _dispatch(f): return dispatch(_annot_tuple_to_union(f))\n",
"\n",
"def _pf_dispatch(pf, f): return pf.dispatch(_annot_tuple_to_union(f))\n",
"\n",
"def _dispatch_method(f, cls):\n",
" f = copy(f)\n",
" n = f.__name__\n",
" # Use __dict__ to avoid searching base classes\n",
" if n in cls.__dict__: return _pf_dispatch(getattr(cls, n), f)\n",
" f = copy(f)\n",
" # plum uses __qualname__ to infer f's owner\n",
" f.__qualname__ = f'{cls.__name__}.{n}'\n",
" pf = _dispatch(f)\n",
Expand All @@ -89,9 +92,7 @@
" # since we assign after class creation, __set_name__ must be called directly\n",
" # source: https://docs.python.org/3/reference/datamodel.html#object.__set_name__\n",
" pf.__set_name__(cls, n)\n",
" return pf\n",
"\n",
"def _pf_dispatch(pf, f): return pf.dispatch(_annot_tuple_to_union(f))"
" return pf"
]
},
{
Expand Down Expand Up @@ -225,6 +226,28 @@
"test_eq(_pf_repr(_f), '_f1: (int,dict[str,float]) -> float\\n_f2: (int,tuple[str,float]) -> float')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"def _union_to_tuple(t):\n",
" return t.__args__ if getattr(t,'__origin__',None) is Union else t"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"test_eq(_union_to_tuple(Union[int,Union[str,None]]), (int,str,NoneType))\n",
"test_eq(_union_to_tuple(Tuple[int,str]), Tuple[int,str])\n",
"test_eq(_union_to_tuple(int), int)"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -247,9 +270,8 @@
" _pf_dispatch(self.encodes, enc)\n",
" self.order = getattr(enc,'order',self.order)\n",
" if len(type_hints(enc)) > 0:\n",
" self.input_types = first(type_hints(enc).values())\n",
" # Convert Union to tuple, remove once the rest of fastai supports Union\n",
" if get_origin(self.input_types) is Union: self.input_types=get_args(self.input_types)\n",
" self.input_types = _union_to_tuple(first(type_hints(enc).values()))\n",
" self._name = _get_name(enc)\n",
" if dec: _pf_dispatch(self.decodes, dec)\n",
"\n",
Expand Down Expand Up @@ -520,15 +542,6 @@
"`Transform` can be used as a decorator to turn a function into a `Transform`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from nbdev.showdoc import _format_cls_doc, _format_func_doc"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -680,18 +693,6 @@
"test_eq(f(['a','b','c']), \"['a', 'b', 'c']_1\") # input is of type list"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@Transform\n",
"def f(x:(int,float)): return x+1\n",
"test_eq(f(0), 1)\n",
"test_eq(f('a'), 'a')"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -929,77 +930,6 @@
"test_eq(f.decode(t), [1,2])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def encodes(self, x): pass"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Promise(obj=<function <function AL.encodes at 0x11e3fb670> with 2 method(s)>)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"AL(encodes)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<lambda>:\n",
"encodes: <lambda>: (object) -> object\n",
"decodes: identity: (object) -> object"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"AL(lambda x: x)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"__main__.AL"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"type(AL(lambda x: x))"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
min_python = cfg['min_python']
lic = licenses[cfg['license']]

requirements = ['pip', 'packaging', 'plum-dispatch>=1.5.16']
requirements = ['pip', 'packaging', 'plum-dispatch>=1.6']
if cfg.get('requirements'): requirements += cfg.get('requirements','').split()
if cfg.get('pip_requirements'): requirements += cfg.get('pip_requirements','').split()
dev_requirements = (cfg.get('dev_requirements') or '').split()
Expand Down

0 comments on commit fce1adb

Please sign in to comment.