diff --git a/docs/guide/builtins.rst b/docs/guide/builtins.rst index ea6f396d..6f2b1484 100644 --- a/docs/guide/builtins.rst +++ b/docs/guide/builtins.rst @@ -22,6 +22,7 @@ Built-in Rules - :class:`CollapseIsinstanceChecks` - :class:`ComparePrimitivesByEqual` - :class:`CompareSingletonPrimitivesByIs` +- :class:`DeprecatedABCImport` - :class:`DeprecatedUnittestAsserts` - :class:`NoAssertTrueForComparisons` - :class:`NoInheritFromObject` @@ -241,6 +242,47 @@ Built-in Rules # suggested fix x is not False +.. class:: DeprecatedABCImport + + Checks for the use of the deprecated collections ABC import. Since python 3.3, + the Collections Abstract Base Classes (ABC) have been moved to `collections.abc`. + These ABCs are import errors starting in Python 3.10. + + .. attribute:: MESSAGE + + ABCs must be imported from collections.abc + + .. attribute:: AUTOFIX + :type: Yes + + .. attribute:: PYTHON_VERSION + :type: '>= 3.3' + + .. attribute:: VALID + + .. code:: python + + from collections.abc import Container + .. code:: python + + from collections.abc import Container, Hashable + + .. attribute:: INVALID + + .. code:: python + + from collections import Container + + # suggested fix + from collections.abc import Container + + .. code:: python + + from collections import Container, Hashable + + # suggested fix + from collections.abc import Container, Hashable + .. class:: DeprecatedUnittestAsserts Discourages the use of various deprecated unittest.TestCase functions diff --git a/src/fixit/rules/deprecated_abc_import.py b/src/fixit/rules/deprecated_abc_import.py new file mode 100644 index 00000000..14257d5f --- /dev/null +++ b/src/fixit/rules/deprecated_abc_import.py @@ -0,0 +1,308 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List, Optional, Union + +import libcst as cst + +import libcst.matchers as m + +from fixit import Invalid, LintRule, Valid + + +# The ABCs that have been moved to `collections.abc` +ABCS = frozenset( + { + "AsyncGenerator", + "AsyncIterable", + "AsyncIterator", + "Awaitable", + "Buffer", + "ByteString", + "Callable", + "Collection", + "Container", + "Coroutine", + "Generator", + "Hashable", + "ItemsView", + "Iterable", + "Iterator", + "KeysView", + "Mapping", + "MappingView", + "MutableMapping", + "MutableSequence", + "MutableSet", + "Reversible", + "Sequence", + "Set", + "Sized", + "ValuesView", + } +) + + +class DeprecatedABCImport(LintRule): + """ + Checks for the use of the deprecated collections ABC import. Since python 3.3, + the Collections Abstract Base Classes (ABC) have been moved to `collections.abc`. + These ABCs are import errors starting in Python 3.10. + """ + + MESSAGE = "ABCs must be imported from collections.abc" + PYTHON_VERSION = ">= 3.3" + + VALID = [ + Valid("from collections.abc import Container"), + Valid("from collections.abc import Container, Hashable"), + Valid("from collections.abc import (Container, Hashable)"), + Valid("from collections import defaultdict"), + Valid("from collections import abc"), + Valid("import collections"), + Valid("import collections.abc"), + Valid("import collections.abc.Container"), + Valid( + """ + class MyTest(collections.Something): + def test(self): + pass + """ + ), + ] + INVALID = [ + Invalid( + "from collections import Container", + expected_replacement="from collections.abc import Container", + ), + Invalid( + "from collections import Container, Hashable", + expected_replacement="from collections.abc import Container, Hashable", + ), + Invalid( + "from collections import (Container, Hashable)", + expected_replacement="from collections.abc import (Container, Hashable)", + ), + Invalid( + "import collections.Container", + expected_replacement="import collections.abc.Container", + ), + Invalid( + "import collections.Container as cont", + expected_replacement="import collections.abc.Container as cont", + ), + Invalid( + "from collections import defaultdict, Container", + expected_replacement="from collections import defaultdict\nfrom collections.abc import Container", + ), + Invalid( + "from collections import defaultdict\nfrom collections import Container", + expected_replacement="from collections import defaultdict\nfrom collections.abc import Container", + ), + Invalid( + """ + class MyTest(collections.Container): + def test(self): + pass + """, + expected_replacement=""" + class MyTest(collections.abc.Container): + def test(self): + pass + """, + ), + ] + + def __init__(self) -> None: + super().__init__() + # If the module needs to updated + self.update_module: bool = False + # The original imports + self.imports_names: List[str] = [] + + def visit_ImportFrom(self, node: cst.ImportFrom) -> None: + """ + This catches the `from collections import ` cases + """ + # Get imports in this statement + import_names = ( + [name.name.value for name in node.names] + if type(node.names) is tuple + else [] + ) + # Filter the imports for ABC imports + import_names_in_abc = [name in ABCS for name in import_names] + if ( + node.module + and node.module.value == "collections" + and any(import_names_in_abc) + ): + # Replacing the case where there are ABCs mixed with non-ABCs requires + # splitting a single import statement into two separate imports. This + # cannot be achieved in this method and is offloaded to leaving the module. + if not all(import_names_in_abc): + # We set this variable which triggers the `self.report` to be called + # in `leave_Module`. We report in the `leave_Module` + # so that we can add an additional `SimpleStatementLine` for the new + # import + self.update_module = True + self.imports_names = import_names + else: + self.report( + node, + replacement=node.with_changes( + module=cst.Attribute( + value=cst.Name(value="collections"), + attr=cst.Name(value="abc"), + ) + ), + ) + + def get_import_from( + self, node: Union[cst.SimpleStatementLine, cst.BaseCompoundStatement] + ) -> Optional[cst.ImportFrom]: + """ + Iterate over a Statement Sequence and return a Statement if it is a + `cst.ImportFrom` statement. + """ + imp = m.findall( + node, + m.ImportFrom( + module=m.Name("collections"), + names=m.OneOf( + [m.ImportAlias(name=m.Name(n)) for n in self.imports_names] + ), + ), + ) + return imp[0] if len(imp) > 0 and isinstance(imp[0], cst.ImportFrom) else None + + def leave_Module(self, node: cst.Module) -> None: + """ + While leaving the module, check if we need to split up imports. + """ + if self.update_module: + # Filter the ABCs and non-ABCs + abcs: List[str] = [] + non_abcs: List[str] = [] + for name in self.imports_names: + (non_abcs, abcs)[name in ABCS].append(name) + + node_body = list(node.body) + + # Iterate over the module to find bad imports + for idx, statement in enumerate(node_body): + # Find if the statement is the one we are searching for + import_statement = self.get_import_from(statement) + if import_statement: + # Remove the original import statement + node_body.remove(statement) + # Add the non ABC imports + node_body.insert( + idx, + cst.SimpleStatementLine( + body=( + cst.ImportFrom( + module=cst.Name(value="collections"), + names=[ + cst.ImportAlias(name=cst.Name(value=imp)) + for imp in non_abcs + ], + ), + ) + ), + ) + # Add the ABC imports + node_body.insert( + idx + 1, + cst.SimpleStatementLine( + body=( + cst.ImportFrom( + module=cst.Attribute( + value=cst.Name(value="collections"), + attr=cst.Name(value="abc"), + ), + names=[ + cst.ImportAlias(name=cst.Name(value=imp)) + for imp in abcs + ], + ), + ) + ), + ) + + self.report(node, replacement=node.with_changes(body=node_body)) + + def visit_ImportAlias(self, node: cst.ImportAlias) -> None: + """ + This catches the `import collections.` cases. + """ + if m.matches( + node, + m.ImportAlias( + name=m.Attribute( + value=m.Name("collections"), + attr=m.OneOf(*[m.Name(abc) for abc in ABCS]), + ) + ), + ): + self.report( + node, + replacement=node.with_changes( + name=cst.Attribute( + value=cst.Attribute( + value=cst.Name(value="collections"), + attr=cst.Name(value="abc"), + ), + attr=cst.ensure_type(node.name, cst.Attribute).attr, + ) + ), + ) + + def visit_ClassDef(self, node: cst.ClassDef) -> None: + # Iterate over inherited Classes and search for `collections.` + for base in node.bases: + if m.matches( + base, + m.Arg( + value=m.Attribute( + value=m.Name("collections"), + attr=m.OneOf(*[m.Name(abc) for abc in ABCS]), + ) + ), + ): + # Report + replace `collections.` with `collections.abc.` + # while keeping the remaining classes. + self.report( + node, + replacement=node.with_changes( + bases=[ + ( + cst.Arg( + value=cst.Attribute( + value=cst.Attribute( + value=cst.Name("collections"), + attr=cst.Name("abc"), + ), + attr=base.value.attr, + ), + ) + if m.matches( + base, + m.Arg( + value=m.Attribute( + value=m.Name("collections"), + attr=m.OneOf( + *[m.Name(abc) for abc in ABCS] + ), + ) + ), + ) + and isinstance(base.value, cst.Attribute) + else base + ) + for base in node.bases + ] + ), + )