Skip to content

Commit

Permalink
bug[next]: Fix astype for local fields (#1761)
Browse files Browse the repository at this point in the history
Fix astype by calling  `_map` additionally and add corresponding tests

Co-authored-by: Edoardo Paone <[email protected]>
  • Loading branch information
SF-N and edopao authored Dec 4, 2024
1 parent a2551ac commit a936761
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 23 deletions.
6 changes: 1 addition & 5 deletions src/gt4py/next/ffront/foast_to_gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,11 +359,7 @@ def _visit_astype(self, node: foast.Call, **kwargs: Any) -> itir.Expr:
obj, new_type = self.visit(node.args[0], **kwargs), node.args[1].id

def create_cast(expr: itir.Expr, t: tuple[ts.TypeSpec]) -> itir.FunCall:
if isinstance(t[0], ts.FieldType):
return im.cast_as_fieldop(str(new_type))(expr)
else:
assert isinstance(t[0], ts.ScalarType)
return im.call("cast_")(expr, str(new_type))
return _map(im.lambda_("val")(im.call("cast_")("val", str(new_type))), (expr,), t)

if not isinstance(node.type, ts.TupleType): # to keep the IR simpler
return create_cast(obj, (node.args[0].type,))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,29 +118,41 @@ class PythonCodegen(codegen.TemplatedGenerator):
as in the case of field domain definitions, for sybolic array shape and map range.
"""

SymRef = as_fmt("{id}")
Literal = as_fmt("{value}")

def _visit_deref(self, node: gtir.FunCall) -> str:
assert len(node.args) == 1
if isinstance(node.args[0], gtir.SymRef):
return self.visit(node.args[0])
raise NotImplementedError(f"Unexpected deref with arg type '{type(node.args[0])}'.")

def visit_FunCall(self, node: gtir.FunCall) -> str:
if cpm.is_call_to(node, "deref"):
return self._visit_deref(node)
def visit_FunCall(self, node: gtir.FunCall, args_map: dict[str, gtir.Node]) -> str:
if isinstance(node.fun, gtir.Lambda):
# update the mapping from lambda parameters to corresponding argument expressions
lambda_args_map = args_map | {
p.id: arg for p, arg in zip(node.fun.params, node.args, strict=True)
}
return self.visit(node.fun.expr, args_map=lambda_args_map)
elif cpm.is_call_to(node, "deref"):
assert len(node.args) == 1
if not isinstance(node.args[0], gtir.SymRef):
# shift expressions are not expected in this visitor context
raise NotImplementedError(f"Unexpected deref with arg type '{type(node.args[0])}'.")
return self.visit(node.args[0], args_map=args_map)
elif isinstance(node.fun, gtir.SymRef):
args = self.visit(node.args)
args = self.visit(node.args, args_map=args_map)
builtin_name = str(node.fun.id)
return format_builtin(builtin_name, *args)
raise NotImplementedError(f"Unexpected 'FunCall' node ({node}).")

def visit_SymRef(self, node: gtir.SymRef, args_map: dict[str, gtir.Node]) -> str:
symbol = str(node.id)
if symbol in args_map:
return self.visit(args_map[symbol], args_map=args_map)
return symbol


get_source = PythonCodegen.apply
"""
Specialized visit method for symbolic expressions.
def get_source(node: gtir.Node) -> str:
"""
Specialized visit method for symbolic expressions.
Returns:
A string containing the Python code corresponding to a symbolic expression
"""
The visitor uses `args_map` to map lambda parameters to the corresponding argument expressions.
Returns:
A string containing the Python code corresponding to a symbolic expression
"""
return PythonCodegen.apply(node, args_map={})
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,22 @@ def testee(a: cases.IFloatField) -> gtx.Field[[IDim], int64]:
)


def test_astype_int_local_field(unstructured_case):
@gtx.field_operator
def testee(a: gtx.Field[[Vertex], np.float64]) -> gtx.Field[[Edge], int64]:
tmp = astype(a(E2V), int64)
return neighbor_sum(tmp, axis=E2VDim)

e2v_table = unstructured_case.offset_provider["E2V"].ndarray

cases.verify_with_default_data(
unstructured_case,
testee,
ref=lambda a: np.sum(a.astype(int64)[e2v_table], axis=1, initial=0),
comparison=lambda a, b: np.all(a == b),
)


@pytest.mark.uses_tuple_returns
def test_astype_on_tuples(cartesian_case):
@gtx.field_operator
Expand Down
16 changes: 15 additions & 1 deletion tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,9 +283,22 @@ def foo(a: gtx.Field[[TDim], float64]):

parsed = FieldOperatorParser.apply_to_function(foo)
lowered = FieldOperatorLowering.apply(parsed)
lowered_inlined = inline_lambdas.InlineLambdas.apply(lowered)

reference = im.cast_as_fieldop("int32")("a")

assert lowered_inlined.expr == reference


def test_astype_local_field():
def foo(a: gtx.Field[gtx.Dims[Vertex, V2EDim], float64]):
return astype(a, int32)

parsed = FieldOperatorParser.apply_to_function(foo)
lowered = FieldOperatorLowering.apply(parsed)

reference = im.op_as_fieldop(im.map_(im.lambda_("val")(im.call("cast_")("val", "int32"))))("a")

assert lowered.expr == reference


Expand All @@ -295,10 +308,11 @@ def foo(a: float64):

parsed = FieldOperatorParser.apply_to_function(foo)
lowered = FieldOperatorLowering.apply(parsed)
lowered_inlined = inline_lambdas.InlineLambdas.apply(lowered)

reference = im.call("cast_")("a", "int32")

assert lowered.expr == reference
assert lowered_inlined.expr == reference


def test_astype_tuple():
Expand Down

0 comments on commit a936761

Please sign in to comment.