diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 2864c7f727..a4e111e785 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -519,6 +519,20 @@ def _impl(it: itir.Expr) -> itir.FunCall: return _impl +def index(dim: common.Dimension) -> itir.FunCall: + """ + Create a call to the `index` builtin, shorthand for `call("index")(axis)`, + after converting the given dimension to `itir.AxisLiteral`. + + Args: + dim: the dimension corresponding to the index axis. + + Returns: + A function that constructs a Field of indices in the given dimension. + """ + return call("index")(itir.AxisLiteral(value=dim.value, kind=dim.kind)) + + def map_(op): """Create a `map_` call.""" return call(call("map_")(op)) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 60dcd8ddc9..94ab3a6f76 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -18,7 +18,7 @@ from gt4py.next import common as gtx_common, utils as gtx_utils from gt4py.next.ffront import fbuiltins as gtx_fbuiltins from gt4py.next.iterator import ir as gtir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, domain_utils from gt4py.next.iterator.type_system import type_specifications as itir_ts from gt4py.next.program_processors.runners.dace_common import utility as dace_utils from gt4py.next.program_processors.runners.dace_fieldview import ( @@ -277,20 +277,31 @@ def extract_domain(node: gtir.Node) -> FieldopDomain: the corresponding lower and upper bounds. The returned lower bound is inclusive, the upper bound is exclusive: [lower_bound, upper_bound[ """ - assert cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain")) domain = [] - for named_range in node.args: - assert cpm.is_call_to(named_range, "named_range") - assert len(named_range.args) == 3 - axis = named_range.args[0] - assert isinstance(axis, gtir.AxisLiteral) - lower_bound, upper_bound = ( - dace.symbolic.pystr_to_symbolic(gtir_python_codegen.get_source(arg)) - for arg in named_range.args[1:3] - ) - dim = gtx_common.Dimension(axis.value, axis.kind) - domain.append((dim, lower_bound, upper_bound)) + + def parse_range_boundary(expr: gtir.Expr) -> str: + return dace.symbolic.pystr_to_symbolic(gtir_python_codegen.get_source(expr)) + + if cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain")): + for named_range in node.args: + assert cpm.is_call_to(named_range, "named_range") + assert len(named_range.args) == 3 + axis = named_range.args[0] + assert isinstance(axis, gtir.AxisLiteral) + lower_bound, upper_bound = (parse_range_boundary(arg) for arg in named_range.args[1:3]) + dim = gtx_common.Dimension(axis.value, axis.kind) + domain.append((dim, lower_bound, upper_bound)) + + elif isinstance(node, domain_utils.SymbolicDomain): + assert str(node.grid_type) in {"cartesian_domain", "unstructured_domain"} + for dim, drange in node.ranges.items(): + domain.append( + (dim, parse_range_boundary(drange.start), parse_range_boundary(drange.stop)) + ) + + else: + raise ValueError(f"Invalid domain {node}.") return domain @@ -545,6 +556,51 @@ def construct_output(inner_data: FieldopData) -> FieldopData: return result_temps +def translate_index( + node: gtir.Node, + sdfg: dace.SDFG, + state: dace.SDFGState, + sdfg_builder: gtir_sdfg.SDFGBuilder, +) -> FieldopResult: + """ + Lowers the `index` builtin function to a mapped tasklet that writes the dimension + index values to a transient array. The extent of the index range is taken from + the domain information that should be present in the node annex. + """ + assert "domain" in node.annex + domain = extract_domain(node.annex.domain) + assert len(domain) == 1 + dim, lower_bound, upper_bound = domain[0] + dim_index = dace_gtir_utils.get_map_variable(dim) + + field_dims, field_offset, field_shape = _get_field_layout(domain) + field_type = ts.FieldType(field_dims, dace_utils.as_itir_type(INDEX_DTYPE)) + + output, _ = sdfg.add_temp_transient(field_shape, INDEX_DTYPE) + output_node = state.add_access(output) + + sdfg_builder.add_mapped_tasklet( + "index", + state, + map_ranges={ + dim_index: f"{lower_bound}:{upper_bound}", + }, + inputs={}, + code=f"__val = {dim_index}", + outputs={ + "__val": dace.Memlet( + data=output_node.data, + subset=sbs.Range.from_indices(_get_domain_indices(field_dims, field_offset)), + ) + }, + input_nodes={}, + output_nodes={output_node.data: output_node}, + external_edges=True, + ) + + return FieldopData(output_node, field_type, field_offset) + + def _get_data_nodes( sdfg: dace.SDFG, state: dace.SDFGState, @@ -777,6 +833,7 @@ def translate_symbol_ref( translate_as_fieldop, translate_broadcast_scalar, translate_if, + translate_index, translate_literal, translate_make_tuple, translate_tuple_get, diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index f15287e64c..6b5e164458 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -568,6 +568,8 @@ def visit_FunCall( # use specialized dataflow builder classes for each builtin function if cpm.is_call_to(node, "if_"): return gtir_builtin_translators.translate_if(node, sdfg, head_state, self) + elif cpm.is_call_to(node, "index"): + return gtir_builtin_translators.translate_index(node, sdfg, head_state, self) elif cpm.is_call_to(node, "make_tuple"): return gtir_builtin_translators.translate_make_tuple(node, sdfg, head_state, self) elif cpm.is_call_to(node, "tuple_get"): diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 01fd18897d..349d3e9f70 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -154,7 +154,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (STARTS_FROM_GTIR_PROGRAM, SKIP, UNSUPPORTED_MESSAGE), ] GTIR_DACE_SKIP_TEST_LIST = DOMAIN_INFERENCE_SKIP_LIST + [ - (USES_INDEX_BUILTIN, XFAIL, UNSUPPORTED_MESSAGE), (USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE), (USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE), (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index f5191fbaaa..c7466b853f 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -12,15 +12,15 @@ Note: this test module covers the fieldview flavour of ITIR. """ -import copy import functools import numpy as np import pytest -from gt4py.next import common as gtx_common, constructors +from gt4py.next import common as gtx_common from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.transforms import infer_domain from gt4py.next.type_system import type_specifications as ts from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( @@ -1973,3 +1973,49 @@ def test_gtir_if_values(): sdfg(a, b, c, **FSYMBOLS) assert np.allclose(c, np.where(a < b, a, b)) + + +def test_gtir_index(): + MARGIN = 2 + assert MARGIN < N + domain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (0, "size")}) + subdomain = im.domain( + gtx_common.GridType.CARTESIAN, ranges={IDim: (MARGIN, im.minus("size", MARGIN))} + ) + + testee = gtir.Program( + id="gtir_cast", + function_definitions=[], + params=[ + gtir.Sym(id="x", type=ts.FieldType(dims=[IDim], dtype=SIZE_TYPE)), + gtir.Sym(id="size", type=SIZE_TYPE), + ], + declarations=[], + body=[ + gtir.SetAt( + expr=im.let("i", im.index(IDim))( + im.op_as_fieldop("plus", domain)( + "i", + im.as_fieldop( + im.lambda_("a")(im.deref(im.shift(IDim.value, 1)("a"))), subdomain + )("i"), + ) + ), + domain=subdomain, + target=gtir.SymRef(id="x"), + ) + ], + ) + + v = np.empty(N, dtype=np.int32) + + # we need to run domain inference in order to add the domain annex information to the index node. + testee = infer_domain.infer_program(testee, offset_provider=CARTESIAN_OFFSETS) + sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + + ref = np.concatenate( + (v[:MARGIN], np.arange(MARGIN, N - MARGIN, dtype=np.int32), v[N - MARGIN :]) + ) + + sdfg(v, **FSYMBOLS) + np.allclose(v, ref)