Skip to content

Commit

Permalink
fix[next]: Fix annex & type preservation in inline_lambdas (#1760)
Browse files Browse the repository at this point in the history
Co-authored-by: SF-N <[email protected]>
  • Loading branch information
tehrengruber and SF-N authored Dec 1, 2024
1 parent d581060 commit a26d91f
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 9 deletions.
11 changes: 5 additions & 6 deletions src/gt4py/next/iterator/transforms/inline_lambdas.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ def new_name(name):

if all(eligible_params):
new_expr.location = node.location
return new_expr
else:
new_expr = ir.FunCall(
fun=ir.Lambda(
Expand All @@ -111,11 +110,11 @@ def new_name(name):
args=[arg for arg, eligible in zip(node.args, eligible_params) if not eligible],
location=node.location,
)
for attr in ("type", "recorded_shifts", "domain"):
if hasattr(node.annex, attr):
setattr(new_expr.annex, attr, getattr(node.annex, attr))
itir_inference.copy_type(from_=node, to=new_expr, allow_untyped=True)
return new_expr
for attr in ("type", "recorded_shifts", "domain"):
if hasattr(node.annex, attr):
setattr(new_expr.annex, attr, getattr(node.annex, attr))
itir_inference.copy_type(from_=node, to=new_expr, allow_untyped=True)
return new_expr


@dataclasses.dataclass
Expand Down
5 changes: 4 additions & 1 deletion src/gt4py/next/iterator/transforms/remap_symbols.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from gt4py.eve import NodeTranslator, PreserveLocationVisitor, SymbolTableTrait
from gt4py.next.iterator import ir
from gt4py.next.iterator.type_system import inference as type_inference


class RemapSymbolRefs(PreserveLocationVisitor, NodeTranslator):
Expand Down Expand Up @@ -46,7 +47,9 @@ def visit_SymRef(
self, node: ir.SymRef, *, name_map: Dict[str, str], active: Optional[Set[str]] = None
):
if active and node.id in active:
return ir.SymRef(id=name_map.get(node.id, node.id))
new_ref = ir.SymRef(id=name_map.get(node.id, node.id))
type_inference.copy_type(from_=node, to=new_ref, allow_untyped=True)
return new_ref
return node

def generic_visit( # type: ignore[override]
Expand Down
7 changes: 5 additions & 2 deletions src/gt4py/next/iterator/type_system/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,14 +95,17 @@ def _set_node_type(node: itir.Node, type_: ts.TypeSpec) -> None:
node.type = type_


def copy_type(from_: itir.Node, to: itir.Node, allow_untyped=False) -> None:
def copy_type(from_: itir.Node, to: itir.Node, allow_untyped: bool = False) -> None:
"""
Copy type from one node to another.
This function mainly exists for readability reasons.
"""
assert allow_untyped is not None or isinstance(from_.type, ts.TypeSpec)
_set_node_type(to, from_.type) # type: ignore[arg-type]
if from_.type is None:
assert allow_untyped
return
_set_node_type(to, from_.type)


def on_inferred(callback: Callable, *args: Union[ts.TypeSpec, ObservableTypeSynthesizer]) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,10 @@ def test_inline_lambda_args():
)
inlined = InlineLambdas.apply(testee, opcount_preserving=True, force_inline_lambda_args=True)
assert inlined == expected


def test_type_preservation():
testee = im.let("a", "b")("a")
testee.type = testee.annex.type = ts.ScalarType(kind=ts.ScalarKind.FLOAT32)
inlined = InlineLambdas.apply(testee)
assert inlined.type == inlined.annex.type == ts.ScalarType(kind=ts.ScalarKind.FLOAT32)

0 comments on commit a26d91f

Please sign in to comment.