diff --git a/crates/ruff_linter/resources/test/fixtures/flake8_pytest_style/PT006.py b/crates/ruff_linter/resources/test/fixtures/flake8_pytest_style/PT006.py index f444a835efbe5..e25c70a7d7894 100644 --- a/crates/ruff_linter/resources/test/fixtures/flake8_pytest_style/PT006.py +++ b/crates/ruff_linter/resources/test/fixtures/flake8_pytest_style/PT006.py @@ -69,3 +69,10 @@ def test_implicit_str_concat_with_multi_parens(param1, param2, param3): @pytest.mark.parametrize(("param1,param2"), [(1, 2), (3, 4)]) def test_csv_with_parens(param1, param2): ... + + +parametrize = pytest.mark.parametrize(("param1,param2"), [(1, 2), (3, 4)]) + +@parametrize +def test_csv_with_parens_decorator(param1, param2): + ... diff --git a/crates/ruff_linter/src/checkers/ast/analyze/expression.rs b/crates/ruff_linter/src/checkers/ast/analyze/expression.rs index df623fdc40d6d..3e73e366ff0d7 100644 --- a/crates/ruff_linter/src/checkers/ast/analyze/expression.rs +++ b/crates/ruff_linter/src/checkers/ast/analyze/expression.rs @@ -856,6 +856,13 @@ pub(crate) fn expression(expr: &Expr, checker: &mut Checker) { checker.diagnostics.push(diagnostic); } } + if checker.any_enabled(&[ + Rule::PytestParametrizeNamesWrongType, + Rule::PytestParametrizeValuesWrongType, + Rule::PytestDuplicateParametrizeTestCases, + ]) { + flake8_pytest_style::rules::parametrize(checker, call); + } if checker.enabled(Rule::PytestUnittestAssertion) { if let Some(diagnostic) = flake8_pytest_style::rules::unittest_assertion( checker, expr, func, args, keywords, diff --git a/crates/ruff_linter/src/checkers/ast/analyze/statement.rs b/crates/ruff_linter/src/checkers/ast/analyze/statement.rs index 48d8734fbac7c..d5a01eb3ea9f9 100644 --- a/crates/ruff_linter/src/checkers/ast/analyze/statement.rs +++ b/crates/ruff_linter/src/checkers/ast/analyze/statement.rs @@ -309,13 +309,6 @@ pub(crate) fn statement(stmt: &Stmt, checker: &mut Checker) { body, ); } - if checker.any_enabled(&[ - Rule::PytestParametrizeNamesWrongType, - Rule::PytestParametrizeValuesWrongType, - Rule::PytestDuplicateParametrizeTestCases, - ]) { - flake8_pytest_style::rules::parametrize(checker, decorator_list); - } if checker.any_enabled(&[ Rule::PytestIncorrectMarkParenthesesStyle, Rule::PytestUseFixturesWithoutParameters, diff --git a/crates/ruff_linter/src/rules/flake8_pytest_style/rules/helpers.rs b/crates/ruff_linter/src/rules/flake8_pytest_style/rules/helpers.rs index edbe837e2c1dd..d09474023eb10 100644 --- a/crates/ruff_linter/src/rules/flake8_pytest_style/rules/helpers.rs +++ b/crates/ruff_linter/src/rules/flake8_pytest_style/rules/helpers.rs @@ -2,7 +2,7 @@ use std::fmt; use ruff_python_ast::helpers::map_callable; use ruff_python_ast::name::UnqualifiedName; -use ruff_python_ast::{self as ast, Decorator, Expr, Keyword}; +use ruff_python_ast::{self as ast, Decorator, Expr, ExprCall, Keyword}; use ruff_python_semantic::SemanticModel; use ruff_python_trivia::PythonWhitespace; @@ -38,9 +38,9 @@ pub(super) fn is_pytest_yield_fixture(decorator: &Decorator, semantic: &Semantic }) } -pub(super) fn is_pytest_parametrize(decorator: &Decorator, semantic: &SemanticModel) -> bool { +pub(super) fn is_pytest_parametrize(call: &ExprCall, semantic: &SemanticModel) -> bool { semantic - .resolve_qualified_name(map_callable(&decorator.expression)) + .resolve_qualified_name(&call.func) .is_some_and(|qualified_name| { matches!(qualified_name.segments(), ["pytest", "mark", "parametrize"]) }) diff --git a/crates/ruff_linter/src/rules/flake8_pytest_style/rules/parametrize.rs b/crates/ruff_linter/src/rules/flake8_pytest_style/rules/parametrize.rs index 3b2a923b45e3d..38c8542ebc29c 100644 --- a/crates/ruff_linter/src/rules/flake8_pytest_style/rules/parametrize.rs +++ b/crates/ruff_linter/src/rules/flake8_pytest_style/rules/parametrize.rs @@ -5,7 +5,7 @@ use ruff_macros::{derive_message_formats, violation}; use ruff_python_ast::comparable::ComparableExpr; use ruff_python_ast::parenthesize::parenthesized_range; use ruff_python_ast::AstNode; -use ruff_python_ast::{self as ast, Arguments, Decorator, Expr, ExprContext}; +use ruff_python_ast::{self as ast, Expr, ExprCall, ExprContext}; use ruff_python_codegen::Generator; use ruff_python_trivia::CommentRanges; use ruff_python_trivia::{SimpleTokenKind, SimpleTokenizer}; @@ -317,23 +317,21 @@ fn elts_to_csv(elts: &[Expr], generator: Generator) -> Option { /// /// This method assumes that the first argument is a string. fn get_parametrize_name_range( - decorator: &Decorator, + call: &ExprCall, expr: &Expr, comment_ranges: &CommentRanges, source: &str, ) -> Option { - decorator.expression.as_call_expr().and_then(|call| { - parenthesized_range( - expr.into(), - call.arguments.as_any_node_ref(), - comment_ranges, - source, - ) - }) + parenthesized_range( + expr.into(), + call.arguments.as_any_node_ref(), + comment_ranges, + source, + ) } /// PT006 -fn check_names(checker: &mut Checker, decorator: &Decorator, expr: &Expr) { +fn check_names(checker: &mut Checker, call: &ExprCall, expr: &Expr) { let names_type = checker.settings.flake8_pytest_style.parametrize_names_type; match expr { @@ -343,7 +341,7 @@ fn check_names(checker: &mut Checker, decorator: &Decorator, expr: &Expr) { match names_type { types::ParametrizeNameType::Tuple => { let name_range = get_parametrize_name_range( - decorator, + call, expr, checker.comment_ranges(), checker.locator().contents(), @@ -378,7 +376,7 @@ fn check_names(checker: &mut Checker, decorator: &Decorator, expr: &Expr) { } types::ParametrizeNameType::List => { let name_range = get_parametrize_name_range( - decorator, + call, expr, checker.comment_ranges(), checker.locator().contents(), @@ -797,30 +795,26 @@ fn handle_value_rows( } } -pub(crate) fn parametrize(checker: &mut Checker, decorators: &[Decorator]) { - for decorator in decorators { - if is_pytest_parametrize(decorator, checker.semantic()) { - if let Expr::Call(ast::ExprCall { - arguments: Arguments { args, .. }, - .. - }) = &decorator.expression - { - if checker.enabled(Rule::PytestParametrizeNamesWrongType) { - if let [names, ..] = &**args { - check_names(checker, decorator, names); - } - } - if checker.enabled(Rule::PytestParametrizeValuesWrongType) { - if let [names, values, ..] = &**args { - check_values(checker, names, values); - } - } - if checker.enabled(Rule::PytestDuplicateParametrizeTestCases) { - if let [_, values, ..] = &**args { - check_duplicates(checker, values); - } - } - } +pub(crate) fn parametrize(checker: &mut Checker, call: &ExprCall) { + if !is_pytest_parametrize(call, checker.semantic()) { + return; + } + + if checker.enabled(Rule::PytestParametrizeNamesWrongType) { + if let Some(names) = call.arguments.find_argument("argnames", 0) { + check_names(checker, call, names); + } + } + if checker.enabled(Rule::PytestParametrizeValuesWrongType) { + let names = call.arguments.find_argument("argnames", 0); + let values = call.arguments.find_argument("argvalues", 1); + if let (Some(names), Some(values)) = (names, values) { + check_values(checker, names, values); + } + } + if checker.enabled(Rule::PytestDuplicateParametrizeTestCases) { + if let Some(values) = call.arguments.find_argument("argvalues", 1) { + check_duplicates(checker, values); } } } diff --git a/crates/ruff_linter/src/rules/flake8_pytest_style/snapshots/ruff_linter__rules__flake8_pytest_style__tests__PT006_default.snap b/crates/ruff_linter/src/rules/flake8_pytest_style/snapshots/ruff_linter__rules__flake8_pytest_style__tests__PT006_default.snap index ddd4e40f51797..1cbbe53a90cfd 100644 --- a/crates/ruff_linter/src/rules/flake8_pytest_style/snapshots/ruff_linter__rules__flake8_pytest_style__tests__PT006_default.snap +++ b/crates/ruff_linter/src/rules/flake8_pytest_style/snapshots/ruff_linter__rules__flake8_pytest_style__tests__PT006_default.snap @@ -1,6 +1,5 @@ --- source: crates/ruff_linter/src/rules/flake8_pytest_style/mod.rs -snapshot_kind: text --- PT006.py:9:26: PT006 [*] Wrong type passed to first argument of `@pytest.mark.parametrize`; expected `tuple` | @@ -228,3 +227,23 @@ PT006.py:69:26: PT006 [*] Wrong type passed to first argument of `@pytest.mark.p 69 |+@pytest.mark.parametrize(("param1", "param2"), [(1, 2), (3, 4)]) 70 70 | def test_csv_with_parens(param1, param2): 71 71 | ... +72 72 | + +PT006.py:74:39: PT006 [*] Wrong type passed to first argument of `@pytest.mark.parametrize`; expected `tuple` + | +74 | parametrize = pytest.mark.parametrize(("param1,param2"), [(1, 2), (3, 4)]) + | ^^^^^^^^^^^^^^^^^ PT006 +75 | +76 | @parametrize + | + = help: Use a `tuple` for the first argument + +ℹ Unsafe fix +71 71 | ... +72 72 | +73 73 | +74 |-parametrize = pytest.mark.parametrize(("param1,param2"), [(1, 2), (3, 4)]) + 74 |+parametrize = pytest.mark.parametrize(("param1", "param2"), [(1, 2), (3, 4)]) +75 75 | +76 76 | @parametrize +77 77 | def test_csv_with_parens_decorator(param1, param2): diff --git a/crates/ruff_linter/src/rules/flake8_pytest_style/snapshots/ruff_linter__rules__flake8_pytest_style__tests__PT006_list.snap b/crates/ruff_linter/src/rules/flake8_pytest_style/snapshots/ruff_linter__rules__flake8_pytest_style__tests__PT006_list.snap index ed5e5b11cf60a..e05e60cfa5041 100644 --- a/crates/ruff_linter/src/rules/flake8_pytest_style/snapshots/ruff_linter__rules__flake8_pytest_style__tests__PT006_list.snap +++ b/crates/ruff_linter/src/rules/flake8_pytest_style/snapshots/ruff_linter__rules__flake8_pytest_style__tests__PT006_list.snap @@ -1,6 +1,5 @@ --- source: crates/ruff_linter/src/rules/flake8_pytest_style/mod.rs -snapshot_kind: text --- PT006.py:9:26: PT006 [*] Wrong type passed to first argument of `@pytest.mark.parametrize`; expected `list` | @@ -190,3 +189,23 @@ PT006.py:69:26: PT006 [*] Wrong type passed to first argument of `@pytest.mark.p 69 |+@pytest.mark.parametrize(["param1", "param2"], [(1, 2), (3, 4)]) 70 70 | def test_csv_with_parens(param1, param2): 71 71 | ... +72 72 | + +PT006.py:74:39: PT006 [*] Wrong type passed to first argument of `@pytest.mark.parametrize`; expected `list` + | +74 | parametrize = pytest.mark.parametrize(("param1,param2"), [(1, 2), (3, 4)]) + | ^^^^^^^^^^^^^^^^^ PT006 +75 | +76 | @parametrize + | + = help: Use a `list` for the first argument + +ℹ Unsafe fix +71 71 | ... +72 72 | +73 73 | +74 |-parametrize = pytest.mark.parametrize(("param1,param2"), [(1, 2), (3, 4)]) + 74 |+parametrize = pytest.mark.parametrize(["param1", "param2"], [(1, 2), (3, 4)]) +75 75 | +76 76 | @parametrize +77 77 | def test_csv_with_parens_decorator(param1, param2):