From 47c6950a4c141532cd196491bf9c9a696d7979a8 Mon Sep 17 00:00:00 2001 From: Adam Johnson Date: Wed, 22 Sep 2021 11:12:52 +0100 Subject: [PATCH] Fix for star imports Fixes #70. --- HISTORY.rst | 4 ++++ src/django_upgrade/ast.py | 5 +++-- src/django_upgrade/data.py | 4 +++- .../fixers/compatibility_imports_1_11.py | 6 +++--- .../fixers/compatibility_imports_1_9.py | 6 +++--- src/django_upgrade/fixers/django_urls.py | 6 +++--- src/django_upgrade/fixers/jsonfield.py | 6 +++--- src/django_upgrade/fixers/null_boolean_field.py | 4 ++-- .../fixers/postgres_float_range_field.py | 6 +++--- src/django_upgrade/fixers/queryset_paginator.py | 4 ++-- .../fixers/timezone_fixedoffset.py | 6 +++--- src/django_upgrade/fixers/utils_encoding.py | 4 ++-- src/django_upgrade/fixers/utils_http.py | 4 ++-- src/django_upgrade/fixers/utils_translation.py | 6 +++--- tests/fixers/test_compatibility_imports_1_9.py | 16 ++++++++++++++++ 15 files changed, 55 insertions(+), 32 deletions(-) diff --git a/HISTORY.rst b/HISTORY.rst index 2572cc112..865c2b0f3 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -2,6 +2,10 @@ History ======= +* Fix import fixers to not crash on star imports (``from foo import *``). + + Thanks to Mikhail for the report in `Issue #70 `__. + 1.3.0 (2021-09-22) ------------------ diff --git a/src/django_upgrade/ast.py b/src/django_upgrade/ast.py index 913f78182..afdf763b5 100644 --- a/src/django_upgrade/ast.py +++ b/src/django_upgrade/ast.py @@ -16,5 +16,6 @@ def ast_start_offset(node: Union[ast.expr, ast.keyword, ast.stmt]) -> Offset: return Offset(node.lineno, node.col_offset) -# def ast_end_offset(node: Union[ast.expr, ast.keyword, ast.stmt]) -> Offset: -# return Offset(node.end_lineno, node.end_col_offset) +def is_rewritable_import_from(node: ast.ImportFrom) -> bool: + # Not relative import or import * + return node.level == 0 and not (len(node.names) == 1 and node.names[0].name == "*") diff --git a/src/django_upgrade/data.py b/src/django_upgrade/data.py index 8704d7a36..ff1f64493 100644 --- a/src/django_upgrade/data.py +++ b/src/django_upgrade/data.py @@ -84,7 +84,9 @@ def visit( ) ): state.from_imports[node.module].update( - name.name for name in node.names if not name.asname + name.name + for name in node.names + if name.asname is None and name.name != "*" ) for name in reversed(node._fields): diff --git a/src/django_upgrade/fixers/compatibility_imports_1_11.py b/src/django_upgrade/fixers/compatibility_imports_1_11.py index b22f6a92b..e8391c5ee 100644 --- a/src/django_upgrade/fixers/compatibility_imports_1_11.py +++ b/src/django_upgrade/fixers/compatibility_imports_1_11.py @@ -8,7 +8,7 @@ from tokenize_rt import Offset -from django_upgrade.ast import ast_start_offset +from django_upgrade.ast import ast_start_offset, is_rewritable_import_from from django_upgrade.data import Fixer, State, TokenFunc from django_upgrade.tokens import update_import_modules @@ -40,8 +40,8 @@ def visit_ImportFrom( parent: ast.AST, ) -> Iterable[Tuple[Offset, TokenFunc]]: if ( - node.level == 0 - and node.module in REWRITES + node.module in REWRITES + and is_rewritable_import_from(node) and any(alias.name in REWRITES[node.module] for alias in node.names) ): yield ast_start_offset(node), partial( diff --git a/src/django_upgrade/fixers/compatibility_imports_1_9.py b/src/django_upgrade/fixers/compatibility_imports_1_9.py index b4fa746ee..1a4c8e2b7 100644 --- a/src/django_upgrade/fixers/compatibility_imports_1_9.py +++ b/src/django_upgrade/fixers/compatibility_imports_1_9.py @@ -8,7 +8,7 @@ from tokenize_rt import Offset -from django_upgrade.ast import ast_start_offset +from django_upgrade.ast import ast_start_offset, is_rewritable_import_from from django_upgrade.data import Fixer, State, TokenFunc from django_upgrade.tokens import update_import_modules @@ -32,8 +32,8 @@ def visit_ImportFrom( parent: ast.AST, ) -> Iterable[Tuple[Offset, TokenFunc]]: if ( - node.level == 0 - and node.module in REWRITES + node.module in REWRITES + and is_rewritable_import_from(node) and any(alias.name in REWRITES[node.module] for alias in node.names) ): yield ast_start_offset(node), partial( diff --git a/src/django_upgrade/fixers/django_urls.py b/src/django_upgrade/fixers/django_urls.py index 4e6b18e1f..ebe68ed7a 100644 --- a/src/django_upgrade/fixers/django_urls.py +++ b/src/django_upgrade/fixers/django_urls.py @@ -10,7 +10,7 @@ from tokenize_rt import Offset, Token -from django_upgrade.ast import ast_start_offset +from django_upgrade.ast import ast_start_offset, is_rewritable_import_from from django_upgrade.compat import str_removeprefix, str_removesuffix from django_upgrade.data import Fixer, State, TokenFunc from django_upgrade.tokens import ( @@ -35,8 +35,8 @@ def visit_ImportFrom( parent: ast.AST, ) -> Iterable[Tuple[Offset, TokenFunc]]: if ( - node.level == 0 - and node.module == "django.conf.urls" + node.module == "django.conf.urls" + and is_rewritable_import_from(node) and any(alias.name in ("include", "url") for alias in node.names) ): yield ast_start_offset(node), partial( diff --git a/src/django_upgrade/fixers/jsonfield.py b/src/django_upgrade/fixers/jsonfield.py index 090b5f06b..a378d3af6 100644 --- a/src/django_upgrade/fixers/jsonfield.py +++ b/src/django_upgrade/fixers/jsonfield.py @@ -8,7 +8,7 @@ from tokenize_rt import Offset -from django_upgrade.ast import ast_start_offset +from django_upgrade.ast import ast_start_offset, is_rewritable_import_from from django_upgrade.data import Fixer, State, TokenFunc from django_upgrade.tokens import update_import_modules @@ -44,8 +44,8 @@ def visit_ImportFrom( parent: ast.AST, ) -> Iterable[Tuple[Offset, TokenFunc]]: if ( - node.level == 0 - and node.module in REWRITES + node.module in REWRITES + and is_rewritable_import_from(node) and any(alias.name in REWRITES[node.module] for alias in node.names) ): yield ast_start_offset(node), partial( diff --git a/src/django_upgrade/fixers/null_boolean_field.py b/src/django_upgrade/fixers/null_boolean_field.py index 62d335744..54a0f5c47 100644 --- a/src/django_upgrade/fixers/null_boolean_field.py +++ b/src/django_upgrade/fixers/null_boolean_field.py @@ -8,7 +8,7 @@ from tokenize_rt import Offset, Token -from django_upgrade.ast import ast_start_offset +from django_upgrade.ast import ast_start_offset, is_rewritable_import_from from django_upgrade.data import Fixer, State, TokenFunc from django_upgrade.tokens import ( CODE, @@ -31,7 +31,7 @@ def visit_ImportFrom( node: ast.ImportFrom, parent: ast.AST, ) -> Iterable[Tuple[Offset, TokenFunc]]: - if node.level == 0 and node.module == "django.db.models": + if is_rewritable_import_from(node) and node.module == "django.db.models": yield ast_start_offset(node), partial( update_import_names, node=node, diff --git a/src/django_upgrade/fixers/postgres_float_range_field.py b/src/django_upgrade/fixers/postgres_float_range_field.py index 0f376701e..4f14654d6 100644 --- a/src/django_upgrade/fixers/postgres_float_range_field.py +++ b/src/django_upgrade/fixers/postgres_float_range_field.py @@ -8,7 +8,7 @@ from tokenize_rt import Offset -from django_upgrade.ast import ast_start_offset +from django_upgrade.ast import ast_start_offset, is_rewritable_import_from from django_upgrade.data import Fixer, State, TokenFunc from django_upgrade.tokens import find_and_replace_name, update_import_names @@ -35,8 +35,8 @@ def visit_ImportFrom( parent: ast.AST, ) -> Iterable[Tuple[Offset, TokenFunc]]: if ( - node.level == 0 - and node.module in MODULES + node.module in MODULES + and is_rewritable_import_from(node) and any(alias.name in NAME_MAP for alias in node.names) ): yield ast_start_offset(node), partial( diff --git a/src/django_upgrade/fixers/queryset_paginator.py b/src/django_upgrade/fixers/queryset_paginator.py index 8b1542886..0dd519a06 100644 --- a/src/django_upgrade/fixers/queryset_paginator.py +++ b/src/django_upgrade/fixers/queryset_paginator.py @@ -8,7 +8,7 @@ from tokenize_rt import Offset -from django_upgrade.ast import ast_start_offset +from django_upgrade.ast import ast_start_offset, is_rewritable_import_from from django_upgrade.data import Fixer, State, TokenFunc from django_upgrade.tokens import find_and_replace_name, update_import_names @@ -29,7 +29,7 @@ def visit_ImportFrom( node: ast.ImportFrom, parent: ast.AST, ) -> Iterable[Tuple[Offset, TokenFunc]]: - if node.level == 0 and node.module == MODULE: + if node.module == MODULE and is_rewritable_import_from(node): yield ast_start_offset(node), partial( update_import_names, node=node, name_map=NAMES ) diff --git a/src/django_upgrade/fixers/timezone_fixedoffset.py b/src/django_upgrade/fixers/timezone_fixedoffset.py index 26895768f..48ea5f348 100644 --- a/src/django_upgrade/fixers/timezone_fixedoffset.py +++ b/src/django_upgrade/fixers/timezone_fixedoffset.py @@ -8,7 +8,7 @@ from tokenize_rt import Offset, Token -from django_upgrade.ast import ast_start_offset +from django_upgrade.ast import ast_start_offset, is_rewritable_import_from from django_upgrade.data import Fixer, State, TokenFunc from django_upgrade.tokens import ( OP, @@ -36,8 +36,8 @@ def visit_ImportFrom( parent: ast.AST, ) -> Iterable[Tuple[Offset, TokenFunc]]: if ( - node.level == 0 - and node.module == MODULE + node.module == MODULE + and is_rewritable_import_from(node) and any(alias.name == OLD_NAME for alias in node.names) ): yield ast_start_offset(node), partial(fix_import_from, node=node) diff --git a/src/django_upgrade/fixers/utils_encoding.py b/src/django_upgrade/fixers/utils_encoding.py index d3dc53a2a..910390ad3 100644 --- a/src/django_upgrade/fixers/utils_encoding.py +++ b/src/django_upgrade/fixers/utils_encoding.py @@ -8,7 +8,7 @@ from tokenize_rt import Offset -from django_upgrade.ast import ast_start_offset +from django_upgrade.ast import ast_start_offset, is_rewritable_import_from from django_upgrade.data import Fixer, State, TokenFunc from django_upgrade.tokens import find_and_replace_name, update_import_names @@ -30,7 +30,7 @@ def visit_ImportFrom( node: ast.ImportFrom, parent: ast.AST, ) -> Iterable[Tuple[Offset, TokenFunc]]: - if node.level == 0 and node.module == MODULE: + if node.module == MODULE and is_rewritable_import_from(node): yield ast_start_offset(node), partial( update_import_names, node=node, name_map=NAMES ) diff --git a/src/django_upgrade/fixers/utils_http.py b/src/django_upgrade/fixers/utils_http.py index e64c0b8fd..67569f3c8 100644 --- a/src/django_upgrade/fixers/utils_http.py +++ b/src/django_upgrade/fixers/utils_http.py @@ -8,7 +8,7 @@ from tokenize_rt import Offset, Token -from django_upgrade.ast import ast_start_offset +from django_upgrade.ast import ast_start_offset, is_rewritable_import_from from django_upgrade.data import Fixer, State, TokenFunc from django_upgrade.tokens import ( extract_indent, @@ -40,7 +40,7 @@ def visit_ImportFrom( node: ast.ImportFrom, parent: ast.AST, ) -> Iterable[Tuple[Offset, TokenFunc]]: - if node.level == 0 and node.module == MODULE: + if node.module == MODULE and is_rewritable_import_from(node): name_map = {} urllib_names = {} for alias in node.names: diff --git a/src/django_upgrade/fixers/utils_translation.py b/src/django_upgrade/fixers/utils_translation.py index 71c62d952..ae4a34d6c 100644 --- a/src/django_upgrade/fixers/utils_translation.py +++ b/src/django_upgrade/fixers/utils_translation.py @@ -8,7 +8,7 @@ from tokenize_rt import Offset -from django_upgrade.ast import ast_start_offset +from django_upgrade.ast import ast_start_offset, is_rewritable_import_from from django_upgrade.data import Fixer, State, TokenFunc from django_upgrade.tokens import find_and_replace_name, update_import_names @@ -34,8 +34,8 @@ def visit_ImportFrom( parent: ast.AST, ) -> Iterable[Tuple[Offset, TokenFunc]]: if ( - node.level == 0 - and node.module == MODULE + node.module == MODULE + and is_rewritable_import_from(node) and any(alias.name in NAME_MAP for alias in node.names) ): yield ast_start_offset(node), partial( diff --git a/tests/fixers/test_compatibility_imports_1_9.py b/tests/fixers/test_compatibility_imports_1_9.py index 214d9b836..d66cec623 100644 --- a/tests/fixers/test_compatibility_imports_1_9.py +++ b/tests/fixers/test_compatibility_imports_1_9.py @@ -25,6 +25,22 @@ def test_unrecognized_import_format(): ) +def test_import_star(): + check_transformed( + """\ + from django.forms.forms import * + + pretty_name() + """, + """\ + from django.forms.forms import * + + pretty_name() + """, + settings, + ) + + def test_name_imported(): check_transformed( """\