diff --git a/src/fixit/rule.py b/src/fixit/rule.py index 4138c220..a68a5c02 100644 --- a/src/fixit/rule.py +++ b/src/fixit/rule.py @@ -148,11 +148,16 @@ def node_comments(self, node: CSTNode) -> Generator[str, None, None]: # comments at the start of the file are part of the module header rather than # part of the first statement's leading_lines, so we need to look there in case # the reported node is part of the first statement. - parent = self.get_metadata(ParentNodeProvider, node) - if isinstance(parent, Module) and parent.body and parent.body[0] == node: - for line in parent.header: + if isinstance(node, Module): + for line in node.header: if line.comment: yield line.comment.value + else: + parent = self.get_metadata(ParentNodeProvider, node) + if isinstance(parent, Module) and parent.body and parent.body[0] == node: + for line in parent.header: + if line.comment: + yield line.comment.value def ignore_lint(self, node: CSTNode) -> bool: """ diff --git a/src/fixit/tests/rule.py b/src/fixit/tests/rule.py index 2c8ae6d3..1e7c887e 100644 --- a/src/fixit/tests/rule.py +++ b/src/fixit/tests/rule.py @@ -66,6 +66,10 @@ def test_timing_hook(self) -> None: class ExerciseReportRule(LintRule): MESSAGE = "message on the class" + def visit_Module(self, node: cst.Module) -> bool: + self.report(node, "Module") + return False + def visit_ClassDef(self, node: cst.ClassDef) -> bool: self.report(node, "class def") return False @@ -89,39 +93,99 @@ def setUp(self) -> None: def test_pass_happy(self) -> None: runner = LintRunner(Path("fake.py"), b"pass") - (violation,) = list(runner.collect_violations(self.rules, Config())) - self.assertIsInstance(violation.node, cst.Pass) + + # Since the "pass" code is part of a Module and ExerciseReportRule() visit's the Module + # 2 violations are collected. + module_violation, pass_violation = list( + runner.collect_violations(self.rules, Config()) + ) + self.assertIsInstance(module_violation.node, cst.Module) + self.assertIsInstance(pass_violation.node, cst.Pass) + + self.assertEqual( + module_violation, + LintViolation( + "ExerciseReport", + CodeRange(start=CodePosition(1, 0), end=CodePosition(2, 0)), + "Module", + module_violation.node, + None, + ), + ) self.assertEqual( - violation, + pass_violation, LintViolation( "ExerciseReport", CodeRange(start=CodePosition(1, 0), end=CodePosition(1, 4)), "I pass", - violation.node, + pass_violation.node, None, ), ) def test_ellipsis_position_override(self) -> None: runner = LintRunner(Path("fake.py"), b"...") - (violation,) = list(runner.collect_violations(self.rules, Config())) - self.assertIsInstance(violation.node, cst.Ellipsis) + + # Since the "..." code is part of a Module and ExerciseReportRule() visit's the Module + # 2 violations are collected. + module_violation, ellipses_violation = list( + runner.collect_violations(self.rules, Config()) + ) + self.assertIsInstance(module_violation.node, cst.Module) + self.assertIsInstance(ellipses_violation.node, cst.Ellipsis) + self.assertEqual( - violation, + module_violation, + LintViolation( + "ExerciseReport", + CodeRange(start=CodePosition(1, 0), end=CodePosition(2, 0)), + "Module", + module_violation.node, + None, + ), + ) + self.assertEqual( + ellipses_violation, LintViolation( "ExerciseReport", CodeRange(start=CodePosition(1, 1), end=CodePosition(2, 0)), "I ellipse", - violation.node, + ellipses_violation.node, None, ), ) def test_del_no_message(self) -> None: runner = LintRunner(Path("fake.py"), b"del foo") + + # Since the "del foo" code is part of a Module and ExerciseReportRule() visit's the Module + # 2 violations are collected. violations = list(runner.collect_violations(self.rules, Config())) - self.assertEqual(len(violations), 1) - self.assertEqual(violations[0].message, "message on the class") + module_violation, del_violation = violations + self.assertEqual(len(violations), 2) + self.assertIsInstance(module_violation.node, cst.Module) + self.assertIsInstance(del_violation.node, cst.Del) + + self.assertEqual( + module_violation, + LintViolation( + "ExerciseReport", + CodeRange(start=CodePosition(1, 0), end=CodePosition(2, 0)), + "Module", + module_violation.node, + None, + ), + ) + self.assertEqual( + del_violation, + LintViolation( + "ExerciseReport", + CodeRange(start=CodePosition(1, 0), end=CodePosition(1, 7)), + "message on the class", + del_violation.node, + None, + ), + ) def test_ignore_lint(self) -> None: idx = 0 @@ -225,12 +289,24 @@ class Foo(object): ) if message and position: - self.assertEqual(len(violations), 1) + self.assertIn(len(violations), (1, 2)) + + # There's always going to be at least 1 violation (A module node violation). + # So, let's just assert it is a module then remove it to make the test simpler. + if len(violations) == 2: + self.assertIsInstance(violations[0].node, cst.Module) + violations.pop(0) + (violation,) = violations + self.assertEqual(violation.message, message) self.assertEqual(violation.range.start, CodePosition(*position)) else: + if len(violations) == 1: + self.assertIsInstance(violations[0].node, cst.Module) + violations.pop(0) + self.assertEqual( len(violations), 0, "Unexpected lint errors reported" )