Skip to content

Commit

Permalink
🐛 Fix yield_from_generator being too broad
Browse files Browse the repository at this point in the history
  • Loading branch information
foosel committed Feb 24, 2021
1 parent f0f297e commit 48422eb
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 8 deletions.
31 changes: 23 additions & 8 deletions octoprint_codemods/yield_from_generator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union
from typing import Union, cast

import libcst as cst
import libcst.matchers as m
Expand All @@ -20,14 +20,29 @@ def leave_For(
self, original_node: cst.For, updated_node: cst.For
) -> Union[cst.For, cst.SimpleStatementLine]:
if m.matches(
updated_node.body,
m.IndentedBlock(body=[m.SimpleStatementLine(body=[m.Expr(value=m.Yield())])]),
updated_node,
m.For(
target=m.Name(),
body=m.IndentedBlock(
body=[m.SimpleStatementLine(body=[m.Expr(value=m.Yield(m.Name()))])]
),
),
):
self._report_node(original_node)
self.count += 1
updated_node = cst.SimpleStatementLine(
body=[cst.Expr(value=cst.Yield(value=cst.From(item=updated_node.iter)))]
)
target = updated_node.target.value
block = cast(cst.IndentedBlock, updated_node.body)
simple_stmt = cast(cst.SimpleStatementLine, block.body[0])
expr_stmt = cast(cst.Expr, simple_stmt.body[0])
yield_stmt = cast(cst.Yield, expr_stmt.value)
yielded = cast(cst.Name, yield_stmt.value).value

if target == yielded:
self._report_node(original_node)
self.count += 1
updated_node = cst.SimpleStatementLine(
body=[
cst.Expr(value=cst.Yield(value=cst.From(item=updated_node.iter)))
]
)
return updated_node


Expand Down
20 changes: 20 additions & 0 deletions tests/expected/yield_from_generator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
data = (x for x in range(10))


def fnord(x):
return x


def a():
yield from data

Expand All @@ -11,3 +15,19 @@ def b():

def c():
yield from data


def d():
for x in data:
yield fnord(x)


def e():
for x in data:
yield x, True


def f():
l = "aaaaaaaaaaaaaaaaaaa"
for x in range(3):
yield l[x:]
20 changes: 20 additions & 0 deletions tests/fixtures/yield_from_generator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
data = (x for x in range(10))


def fnord(x):
return x


def a():
for entry in data:
yield entry
Expand All @@ -13,3 +17,19 @@ def b():

def c():
yield from data


def d():
for x in data:
yield fnord(x)


def e():
for x in data:
yield x, True


def f():
l = "aaaaaaaaaaaaaaaaaaa"
for x in range(3):
yield l[x:]

0 comments on commit 48422eb

Please sign in to comment.