diff --git a/fastcore/transform.py b/fastcore/transform.py index 7f3e2efc..b8bba612 100644 --- a/fastcore/transform.py +++ b/fastcore/transform.py @@ -13,9 +13,11 @@ # Cell _tfm_methods = 'encodes','decodes','setups' +def _is_tfm_method(n, f): return n in _tfm_methods and callable(f) + class _TfmDict(dict): - def __setitem__(self,k,v): - if k not in _tfm_methods or not callable(v): return super().__setitem__(k,v) + def __setitem__(self, k, v): + if not _is_tfm_method(k, v): return super().__setitem__(k,v) if k not in self: super().__setitem__(k,TypeDispatch()) self[k].add(v) @@ -27,16 +29,21 @@ def __new__(cls, name, bases, dict): base_td = [getattr(b,nm,None) for b in bases] if nm in res.__dict__: getattr(res,nm).bases = base_td else: setattr(res, nm, TypeDispatch(bases=base_td)) + # _TfmMeta.__call__ shadows the signature of inheriting classes, set it back res.__signature__ = inspect.signature(res.__init__) return res def __call__(cls, *args, **kwargs): - f = args[0] if args else None - n = getattr(f,'__name__',None) - if callable(f) and n in _tfm_methods: + f = first(args) + n = getattr(f, '__name__', None) + if _is_tfm_method(n, f): getattr(cls,n).add(f) return f - return super().__call__(*args, **kwargs) + obj = super().__call__(*args, **kwargs) + # _TfmMeta.__new__ replaces cls.__signature__ which breaks the signature of a callable + # instances of cls, fix it + if hasattr(obj, '__call__'): obj.__signature__ = inspect.signature(obj.__call__) + return obj @classmethod def __prepare__(cls, name, bases): return _TfmDict() diff --git a/nbs/05_transform.ipynb b/nbs/05_transform.ipynb index e0bfbbaa..9bc78109 100644 --- a/nbs/05_transform.ipynb +++ b/nbs/05_transform.ipynb @@ -69,9 +69,11 @@ "#export\n", "_tfm_methods = 'encodes','decodes','setups'\n", "\n", + "def _is_tfm_method(n, f): return n in _tfm_methods and callable(f)\n", + "\n", "class _TfmDict(dict):\n", - " def __setitem__(self,k,v):\n", - " if k not in _tfm_methods or not callable(v): return super().__setitem__(k,v)\n", + " def __setitem__(self, k, v):\n", + " if not _is_tfm_method(k, v): return super().__setitem__(k,v)\n", " if k not in self: super().__setitem__(k,TypeDispatch())\n", " self[k].add(v)" ] @@ -90,16 +92,21 @@ " base_td = [getattr(b,nm,None) for b in bases]\n", " if nm in res.__dict__: getattr(res,nm).bases = base_td\n", " else: setattr(res, nm, TypeDispatch(bases=base_td))\n", + " # _TfmMeta.__call__ shadows the signature of inheriting classes, set it back\n", " res.__signature__ = inspect.signature(res.__init__)\n", " return res\n", "\n", " def __call__(cls, *args, **kwargs):\n", - " f = args[0] if args else None\n", - " n = getattr(f,'__name__',None)\n", - " if callable(f) and n in _tfm_methods:\n", + " f = first(args)\n", + " n = getattr(f, '__name__', None)\n", + " if _is_tfm_method(n, f):\n", " getattr(cls,n).add(f)\n", " return f\n", - " return super().__call__(*args, **kwargs)\n", + " obj = super().__call__(*args, **kwargs)\n", + " # _TfmMeta.__new__ replaces cls.__signature__ which breaks the signature of a callable\n", + " # instances of cls, fix it\n", + " if hasattr(obj, '__call__'): obj.__signature__ = inspect.signature(obj.__call__)\n", + " return obj\n", "\n", " @classmethod\n", " def __prepare__(cls, name, bases): return _TfmDict()" @@ -368,6 +375,44 @@ "test_eq_type(f3(2), 2)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Transforms can be created from class methods too:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class A:\n", + " @classmethod\n", + " def create(cls, x:int): return x+1\n", + "test_eq(Transform(A.create)(1), 2)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#hide\n", + "# Test extension of a tfm method defined in the class\n", + "class A(Transform):\n", + " def encodes(self, x): return 'obj'\n", + "\n", + "@A\n", + "def encodes(self, x:int): return 'int'\n", + "\n", + "a = A()\n", + "test_eq(a.encodes(0), 'int')\n", + "test_eq(a.encodes(0.0), 'obj')" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -845,7 +890,7 @@ "def encodes(self, x:str): return x+'hello'\n", "\n", "@B\n", - "def encodes(self, x)->None: return str(x)+'!'" + "def encodes(self, x): return str(x)+'!'" ] }, {