Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bug[next]: Fix astype for local fields #1761

Merged
merged 20 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions src/gt4py/next/ffront/foast_to_gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,11 +359,7 @@ def _visit_astype(self, node: foast.Call, **kwargs: Any) -> itir.Expr:
obj, new_type = self.visit(node.args[0], **kwargs), node.args[1].id

def create_cast(expr: itir.Expr, t: tuple[ts.TypeSpec]) -> itir.FunCall:
if isinstance(t[0], ts.FieldType):
return im.cast_as_fieldop(str(new_type))(expr)
else:
assert isinstance(t[0], ts.ScalarType)
return im.call("cast_")(expr, str(new_type))
return _map(im.lambda_("val")(im.call("cast_")("val", str(new_type))), (expr,), t)

if not isinstance(node.type, ts.TupleType): # to keep the IR simpler
return create_cast(obj, (node.args[0].type,))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,29 +118,41 @@ class PythonCodegen(codegen.TemplatedGenerator):
as in the case of field domain definitions, for sybolic array shape and map range.
"""

SymRef = as_fmt("{id}")
Literal = as_fmt("{value}")

def _visit_deref(self, node: gtir.FunCall) -> str:
assert len(node.args) == 1
if isinstance(node.args[0], gtir.SymRef):
return self.visit(node.args[0])
raise NotImplementedError(f"Unexpected deref with arg type '{type(node.args[0])}'.")

def visit_FunCall(self, node: gtir.FunCall) -> str:
if cpm.is_call_to(node, "deref"):
return self._visit_deref(node)
def visit_FunCall(self, node: gtir.FunCall, args_map: dict[str, gtir.Node]) -> str:
if isinstance(node.fun, gtir.Lambda):
# update the mapping from lambda parameters to corresponding argument expressions
lambda_args_map = args_map | {
p.id: arg for p, arg in zip(node.fun.params, node.args, strict=True)
}
return self.visit(node.fun.expr, args_map=lambda_args_map)
elif cpm.is_call_to(node, "deref"):
assert len(node.args) == 1
if not isinstance(node.args[0], gtir.SymRef):
# shift expressions are not expected in this visitor context
raise NotImplementedError(f"Unexpected deref with arg type '{type(node.args[0])}'.")
return self.visit(node.args[0], args_map=args_map)
elif isinstance(node.fun, gtir.SymRef):
args = self.visit(node.args)
args = self.visit(node.args, args_map=args_map)
builtin_name = str(node.fun.id)
return format_builtin(builtin_name, *args)
raise NotImplementedError(f"Unexpected 'FunCall' node ({node}).")

def visit_SymRef(self, node: gtir.SymRef, args_map: dict[str, gtir.Node]) -> str:
symbol = str(node.id)
if symbol in args_map:
return self.visit(args_map[symbol], args_map=args_map)
return symbol


get_source = PythonCodegen.apply
"""
Specialized visit method for symbolic expressions.
def get_source(node: gtir.Node) -> str:
"""
Specialized visit method for symbolic expressions.

Returns:
A string containing the Python code corresponding to a symbolic expression
"""
The visitor uses `args_map` to map lambda parameters to the corresponding argument expressions.

Returns:
A string containing the Python code corresponding to a symbolic expression
"""
return PythonCodegen.apply(node, args_map={})
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,22 @@ def testee(a: cases.IFloatField) -> gtx.Field[[IDim], int64]:
)


def test_astype_int_local_field(unstructured_case):
@gtx.field_operator
def testee(a: gtx.Field[[Vertex], np.float64]) -> gtx.Field[[Edge], int64]:
tmp = astype(a(E2V), int64)
return neighbor_sum(tmp, axis=E2VDim)

e2v_table = unstructured_case.offset_provider["E2V"].ndarray

cases.verify_with_default_data(
unstructured_case,
testee,
ref=lambda a: np.sum(a.astype(int64)[e2v_table], axis=1, initial=0),
comparison=lambda a, b: np.all(a == b),
)


@pytest.mark.uses_tuple_returns
def test_astype_on_tuples(cartesian_case):
@gtx.field_operator
Expand Down
16 changes: 15 additions & 1 deletion tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,9 +283,22 @@ def foo(a: gtx.Field[[TDim], float64]):

parsed = FieldOperatorParser.apply_to_function(foo)
lowered = FieldOperatorLowering.apply(parsed)
lowered_inlined = inline_lambdas.InlineLambdas.apply(lowered)

reference = im.cast_as_fieldop("int32")("a")

assert lowered_inlined.expr == reference


def test_astype_local_field():
def foo(a: gtx.Field[gtx.Dims[Vertex, V2EDim], float64]):
return astype(a, int32)

parsed = FieldOperatorParser.apply_to_function(foo)
lowered = FieldOperatorLowering.apply(parsed)

reference = im.op_as_fieldop(im.map_(im.lambda_("val")(im.call("cast_")("val", "int32"))))("a")

assert lowered.expr == reference


Expand All @@ -295,10 +308,11 @@ def foo(a: float64):

parsed = FieldOperatorParser.apply_to_function(foo)
lowered = FieldOperatorLowering.apply(parsed)
lowered_inlined = inline_lambdas.InlineLambdas.apply(lowered)

reference = im.call("cast_")("a", "int32")

assert lowered.expr == reference
assert lowered_inlined.expr == reference


def test_astype_tuple():
Expand Down
Loading