Skip to content

Commit

Permalink
Lintrule collections abc (#460)
Browse files Browse the repository at this point in the history
* Added lint rule to catch the deprecated ABC imports in collections

* Added an extra test case

* Updated comments

* Formatted file

* Fixed errors produced by make

* Added auto fix for cases of mixed ABC imports

* Fixed typos

* Updated generated docs

* Rename rule module to match class name, simplify rule message

* Update generated docs

* Added typing to variables

* Updated import search to use matcher

* Changed conditional in visit_ImportAlias to use matcher

* Added additional test case

* Switched libcst._nodes.statement imports to cst. imports

* Added new testcase

* Removed unnecessary matching

* Removed isinstance in favor of libcst.ensure_type

---------

Co-authored-by: Amethyst Reese <[email protected]>
  • Loading branch information
surge119 and amyreese authored Jun 5, 2024
1 parent cfc8829 commit 4cd8a6e
Show file tree
Hide file tree
Showing 2 changed files with 350 additions and 0 deletions.
42 changes: 42 additions & 0 deletions docs/guide/builtins.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ Built-in Rules
- :class:`CollapseIsinstanceChecks`
- :class:`ComparePrimitivesByEqual`
- :class:`CompareSingletonPrimitivesByIs`
- :class:`DeprecatedABCImport`
- :class:`DeprecatedUnittestAsserts`
- :class:`NoAssertTrueForComparisons`
- :class:`NoInheritFromObject`
Expand Down Expand Up @@ -241,6 +242,47 @@ Built-in Rules
# suggested fix
x is not False
.. class:: DeprecatedABCImport

Checks for the use of the deprecated collections ABC import. Since python 3.3,
the Collections Abstract Base Classes (ABC) have been moved to `collections.abc`.
These ABCs are import errors starting in Python 3.10.

.. attribute:: MESSAGE

ABCs must be imported from collections.abc

.. attribute:: AUTOFIX
:type: Yes

.. attribute:: PYTHON_VERSION
:type: '>= 3.3'

.. attribute:: VALID

.. code:: python
from collections.abc import Container
.. code:: python
from collections.abc import Container, Hashable
.. attribute:: INVALID

.. code:: python
from collections import Container
# suggested fix
from collections.abc import Container
.. code:: python
from collections import Container, Hashable
# suggested fix
from collections.abc import Container, Hashable
.. class:: DeprecatedUnittestAsserts

Discourages the use of various deprecated unittest.TestCase functions
Expand Down
308 changes: 308 additions & 0 deletions src/fixit/rules/deprecated_abc_import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,308 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional, Union

import libcst as cst

import libcst.matchers as m

from fixit import Invalid, LintRule, Valid


# The ABCs that have been moved to `collections.abc`
ABCS = frozenset(
{
"AsyncGenerator",
"AsyncIterable",
"AsyncIterator",
"Awaitable",
"Buffer",
"ByteString",
"Callable",
"Collection",
"Container",
"Coroutine",
"Generator",
"Hashable",
"ItemsView",
"Iterable",
"Iterator",
"KeysView",
"Mapping",
"MappingView",
"MutableMapping",
"MutableSequence",
"MutableSet",
"Reversible",
"Sequence",
"Set",
"Sized",
"ValuesView",
}
)


class DeprecatedABCImport(LintRule):
"""
Checks for the use of the deprecated collections ABC import. Since python 3.3,
the Collections Abstract Base Classes (ABC) have been moved to `collections.abc`.
These ABCs are import errors starting in Python 3.10.
"""

MESSAGE = "ABCs must be imported from collections.abc"
PYTHON_VERSION = ">= 3.3"

VALID = [
Valid("from collections.abc import Container"),
Valid("from collections.abc import Container, Hashable"),
Valid("from collections.abc import (Container, Hashable)"),
Valid("from collections import defaultdict"),
Valid("from collections import abc"),
Valid("import collections"),
Valid("import collections.abc"),
Valid("import collections.abc.Container"),
Valid(
"""
class MyTest(collections.Something):
def test(self):
pass
"""
),
]
INVALID = [
Invalid(
"from collections import Container",
expected_replacement="from collections.abc import Container",
),
Invalid(
"from collections import Container, Hashable",
expected_replacement="from collections.abc import Container, Hashable",
),
Invalid(
"from collections import (Container, Hashable)",
expected_replacement="from collections.abc import (Container, Hashable)",
),
Invalid(
"import collections.Container",
expected_replacement="import collections.abc.Container",
),
Invalid(
"import collections.Container as cont",
expected_replacement="import collections.abc.Container as cont",
),
Invalid(
"from collections import defaultdict, Container",
expected_replacement="from collections import defaultdict\nfrom collections.abc import Container",
),
Invalid(
"from collections import defaultdict\nfrom collections import Container",
expected_replacement="from collections import defaultdict\nfrom collections.abc import Container",
),
Invalid(
"""
class MyTest(collections.Container):
def test(self):
pass
""",
expected_replacement="""
class MyTest(collections.abc.Container):
def test(self):
pass
""",
),
]

def __init__(self) -> None:
super().__init__()
# If the module needs to updated
self.update_module: bool = False
# The original imports
self.imports_names: List[str] = []

def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
"""
This catches the `from collections import <ABC>` cases
"""
# Get imports in this statement
import_names = (
[name.name.value for name in node.names]
if type(node.names) is tuple
else []
)
# Filter the imports for ABC imports
import_names_in_abc = [name in ABCS for name in import_names]
if (
node.module
and node.module.value == "collections"
and any(import_names_in_abc)
):
# Replacing the case where there are ABCs mixed with non-ABCs requires
# splitting a single import statement into two separate imports. This
# cannot be achieved in this method and is offloaded to leaving the module.
if not all(import_names_in_abc):
# We set this variable which triggers the `self.report` to be called
# in `leave_Module`. We report in the `leave_Module`
# so that we can add an additional `SimpleStatementLine` for the new
# import
self.update_module = True
self.imports_names = import_names
else:
self.report(
node,
replacement=node.with_changes(
module=cst.Attribute(
value=cst.Name(value="collections"),
attr=cst.Name(value="abc"),
)
),
)

def get_import_from(
self, node: Union[cst.SimpleStatementLine, cst.BaseCompoundStatement]
) -> Optional[cst.ImportFrom]:
"""
Iterate over a Statement Sequence and return a Statement if it is a
`cst.ImportFrom` statement.
"""
imp = m.findall(
node,
m.ImportFrom(
module=m.Name("collections"),
names=m.OneOf(
[m.ImportAlias(name=m.Name(n)) for n in self.imports_names]
),
),
)
return imp[0] if len(imp) > 0 and isinstance(imp[0], cst.ImportFrom) else None

def leave_Module(self, node: cst.Module) -> None:
"""
While leaving the module, check if we need to split up imports.
"""
if self.update_module:
# Filter the ABCs and non-ABCs
abcs: List[str] = []
non_abcs: List[str] = []
for name in self.imports_names:
(non_abcs, abcs)[name in ABCS].append(name)

node_body = list(node.body)

# Iterate over the module to find bad imports
for idx, statement in enumerate(node_body):
# Find if the statement is the one we are searching for
import_statement = self.get_import_from(statement)
if import_statement:
# Remove the original import statement
node_body.remove(statement)
# Add the non ABC imports
node_body.insert(
idx,
cst.SimpleStatementLine(
body=(
cst.ImportFrom(
module=cst.Name(value="collections"),
names=[
cst.ImportAlias(name=cst.Name(value=imp))
for imp in non_abcs
],
),
)
),
)
# Add the ABC imports
node_body.insert(
idx + 1,
cst.SimpleStatementLine(
body=(
cst.ImportFrom(
module=cst.Attribute(
value=cst.Name(value="collections"),
attr=cst.Name(value="abc"),
),
names=[
cst.ImportAlias(name=cst.Name(value=imp))
for imp in abcs
],
),
)
),
)

self.report(node, replacement=node.with_changes(body=node_body))

def visit_ImportAlias(self, node: cst.ImportAlias) -> None:
"""
This catches the `import collections.<ABC>` cases.
"""
if m.matches(
node,
m.ImportAlias(
name=m.Attribute(
value=m.Name("collections"),
attr=m.OneOf(*[m.Name(abc) for abc in ABCS]),
)
),
):
self.report(
node,
replacement=node.with_changes(
name=cst.Attribute(
value=cst.Attribute(
value=cst.Name(value="collections"),
attr=cst.Name(value="abc"),
),
attr=cst.ensure_type(node.name, cst.Attribute).attr,
)
),
)

def visit_ClassDef(self, node: cst.ClassDef) -> None:
# Iterate over inherited Classes and search for `collections.<ABC>`
for base in node.bases:
if m.matches(
base,
m.Arg(
value=m.Attribute(
value=m.Name("collections"),
attr=m.OneOf(*[m.Name(abc) for abc in ABCS]),
)
),
):
# Report + replace `collections.<ABC>` with `collections.abc.<ABC>`
# while keeping the remaining classes.
self.report(
node,
replacement=node.with_changes(
bases=[
(
cst.Arg(
value=cst.Attribute(
value=cst.Attribute(
value=cst.Name("collections"),
attr=cst.Name("abc"),
),
attr=base.value.attr,
),
)
if m.matches(
base,
m.Arg(
value=m.Attribute(
value=m.Name("collections"),
attr=m.OneOf(
*[m.Name(abc) for abc in ABCS]
),
)
),
)
and isinstance(base.value, cst.Attribute)
else base
)
for base in node.bases
]
),
)

0 comments on commit 4cd8a6e

Please sign in to comment.