Skip to content

Commit

Permalink
refactor transform
Browse files Browse the repository at this point in the history
  • Loading branch information
seeM committed Jul 1, 2022
1 parent 06922b7 commit a52b382
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 13 deletions.
19 changes: 13 additions & 6 deletions fastcore/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
# Cell
_tfm_methods = 'encodes','decodes','setups'

def _is_tfm_method(n, f): return n in _tfm_methods and callable(f)

class _TfmDict(dict):
def __setitem__(self,k,v):
if k not in _tfm_methods or not callable(v): return super().__setitem__(k,v)
def __setitem__(self, k, v):
if not _is_tfm_method(k, v): return super().__setitem__(k,v)
if k not in self: super().__setitem__(k,TypeDispatch())
self[k].add(v)

Expand All @@ -27,16 +29,21 @@ def __new__(cls, name, bases, dict):
base_td = [getattr(b,nm,None) for b in bases]
if nm in res.__dict__: getattr(res,nm).bases = base_td
else: setattr(res, nm, TypeDispatch(bases=base_td))
# _TfmMeta.__call__ shadows the signature of inheriting classes, set it back
res.__signature__ = inspect.signature(res.__init__)
return res

def __call__(cls, *args, **kwargs):
f = args[0] if args else None
n = getattr(f,'__name__',None)
if callable(f) and n in _tfm_methods:
f = first(args)
n = getattr(f, '__name__', None)
if _is_tfm_method(n, f):
getattr(cls,n).add(f)
return f
return super().__call__(*args, **kwargs)
obj = super().__call__(*args, **kwargs)
# _TfmMeta.__new__ replaces cls.__signature__ which breaks the signature of a callable
# instances of cls, fix it
if hasattr(obj, '__call__'): obj.__signature__ = inspect.signature(obj.__call__)
return obj

@classmethod
def __prepare__(cls, name, bases): return _TfmDict()
Expand Down
59 changes: 52 additions & 7 deletions nbs/05_transform.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,11 @@
"#export\n",
"_tfm_methods = 'encodes','decodes','setups'\n",
"\n",
"def _is_tfm_method(n, f): return n in _tfm_methods and callable(f)\n",
"\n",
"class _TfmDict(dict):\n",
" def __setitem__(self,k,v):\n",
" if k not in _tfm_methods or not callable(v): return super().__setitem__(k,v)\n",
" def __setitem__(self, k, v):\n",
" if not _is_tfm_method(k, v): return super().__setitem__(k,v)\n",
" if k not in self: super().__setitem__(k,TypeDispatch())\n",
" self[k].add(v)"
]
Expand All @@ -90,16 +92,21 @@
" base_td = [getattr(b,nm,None) for b in bases]\n",
" if nm in res.__dict__: getattr(res,nm).bases = base_td\n",
" else: setattr(res, nm, TypeDispatch(bases=base_td))\n",
" # _TfmMeta.__call__ shadows the signature of inheriting classes, set it back\n",
" res.__signature__ = inspect.signature(res.__init__)\n",
" return res\n",
"\n",
" def __call__(cls, *args, **kwargs):\n",
" f = args[0] if args else None\n",
" n = getattr(f,'__name__',None)\n",
" if callable(f) and n in _tfm_methods:\n",
" f = first(args)\n",
" n = getattr(f, '__name__', None)\n",
" if _is_tfm_method(n, f):\n",
" getattr(cls,n).add(f)\n",
" return f\n",
" return super().__call__(*args, **kwargs)\n",
" obj = super().__call__(*args, **kwargs)\n",
" # _TfmMeta.__new__ replaces cls.__signature__ which breaks the signature of a callable\n",
" # instances of cls, fix it\n",
" if hasattr(obj, '__call__'): obj.__signature__ = inspect.signature(obj.__call__)\n",
" return obj\n",
"\n",
" @classmethod\n",
" def __prepare__(cls, name, bases): return _TfmDict()"
Expand Down Expand Up @@ -368,6 +375,44 @@
"test_eq_type(f3(2), 2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Transforms can be created from class methods too:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class A:\n",
" @classmethod\n",
" def create(cls, x:int): return x+1\n",
"test_eq(Transform(A.create)(1), 2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#hide\n",
"# Test extension of a tfm method defined in the class\n",
"class A(Transform):\n",
" def encodes(self, x): return 'obj'\n",
"\n",
"@A\n",
"def encodes(self, x:int): return 'int'\n",
"\n",
"a = A()\n",
"test_eq(a.encodes(0), 'int')\n",
"test_eq(a.encodes(0.0), 'obj')"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -845,7 +890,7 @@
"def encodes(self, x:str): return x+'hello'\n",
"\n",
"@B\n",
"def encodes(self, x)->None: return str(x)+'!'"
"def encodes(self, x): return str(x)+'!'"
]
},
{
Expand Down

0 comments on commit a52b382

Please sign in to comment.