From 9f3b0a7508ce5a2e66bae5c528f4a1b4d4194728 Mon Sep 17 00:00:00 2001 From: SF-N Date: Mon, 4 Nov 2024 13:00:59 +0100 Subject: [PATCH] feat[next]: Index builtin (#1699) Adds index builtin for embedded and gtfn backends. --- src/gt4py/next/iterator/builtins.py | 6 ++ src/gt4py/next/iterator/embedded.py | 9 ++- src/gt4py/next/iterator/ir.py | 8 ++- src/gt4py/next/iterator/pretty_parser.py | 8 +-- .../iterator/type_system/type_synthesizer.py | 8 +++ .../codegens/gtfn/codegen.py | 2 +- .../codegens/gtfn/gtfn_ir.py | 3 +- .../codegens/gtfn/itir_to_gtfn_ir.py | 1 - .../runners/dace_iterator/__init__.py | 13 ++++- .../runners/dace_iterator/itir_to_sdfg.py | 15 ++++- tests/next_tests/definitions.py | 4 +- .../iterator_tests/test_program.py | 56 ++++++++++++++++++- 12 files changed, 117 insertions(+), 16 deletions(-) diff --git a/src/gt4py/next/iterator/builtins.py b/src/gt4py/next/iterator/builtins.py index 264ac2685c..c8edc12331 100644 --- a/src/gt4py/next/iterator/builtins.py +++ b/src/gt4py/next/iterator/builtins.py @@ -22,6 +22,11 @@ def as_fieldop(*args): raise BackendNotSelectedError() +@builtin_dispatch +def index(*args): + raise BackendNotSelectedError() + + @builtin_dispatch def deref(*args): raise BackendNotSelectedError() @@ -430,6 +435,7 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing] "unstructured_domain", "named_range", "as_fieldop", + "index", *MATH_BUILTINS, } diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 84dd9e3f72..6221c95522 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -1204,12 +1204,14 @@ def premap( def restrict(self, item: common.AnyIndexSpec) -> Self: if isinstance(item, Sequence) and all(isinstance(e, common.NamedIndex) for e in item): + assert len(item) == 1 assert isinstance(item[0], common.NamedIndex) # for mypy errors on multiple lines below d, r = item[0] assert d == self._dimension assert isinstance(r, core_defs.INTEGRAL_TYPES) + # TODO(tehrengruber): Use a regular zero dimensional field instead. return self.__class__(self._dimension, r) - # TODO set a domain... + # TODO: set a domain... raise NotImplementedError() __call__ = premap @@ -1793,6 +1795,11 @@ def impl(*args): return impl +@builtins.index.register(EMBEDDED) +def index(axis: common.Dimension) -> common.Field: + return IndexField(axis) + + @runtime.closure.register(EMBEDDED) def closure( domain_: Domain, diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index 42da4c83a6..b6f543e9d1 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -101,13 +101,18 @@ class StencilClosure(Node): domain: FunCall stencil: Expr output: Union[SymRef, FunCall] - inputs: List[SymRef] + inputs: List[Union[SymRef, FunCall]] @datamodels.validator("output") def _output_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribute, value): if isinstance(value, FunCall) and value.fun != SymRef(id="make_tuple"): raise ValueError("Only FunCall to 'make_tuple' allowed.") + @datamodels.validator("inputs") + def _input_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribute, value): + if any(isinstance(v, FunCall) and v.fun != SymRef(id="index") for v in value): + raise ValueError("Only FunCall to 'index' allowed.") + UNARY_MATH_NUMBER_BUILTINS = {"abs"} UNARY_LOGICAL_BUILTINS = {"not_"} @@ -183,6 +188,7 @@ def _output_validator(self: datamodels.DataModelTP, attribute: datamodels.Attrib "can_deref", "scan", "if_", + "index", # `index(dim)` creates a dim-field that has the current index at each point *ARITHMETIC_BUILTINS, *TYPEBUILTINS, } diff --git a/src/gt4py/next/iterator/pretty_parser.py b/src/gt4py/next/iterator/pretty_parser.py index 08459a9423..b4a673772f 100644 --- a/src/gt4py/next/iterator/pretty_parser.py +++ b/src/gt4py/next/iterator/pretty_parser.py @@ -31,9 +31,9 @@ INT_LITERAL: SIGNED_INT FLOAT_LITERAL: SIGNED_FLOAT OFFSET_LITERAL: ( INT_LITERAL | CNAME ) "ₒ" - _literal: INT_LITERAL | FLOAT_LITERAL | OFFSET_LITERAL + AXIS_LITERAL: CNAME ("ᵥ" | "ₕ") + _literal: INT_LITERAL | FLOAT_LITERAL | OFFSET_LITERAL | AXIS_LITERAL ID_NAME: CNAME - AXIS_NAME: CNAME ("ᵥ" | "ₕ") ?prec0: prec1 | "λ(" ( SYM "," )* SYM? ")" "→" prec0 -> lam @@ -84,7 +84,7 @@ else_branch_seperator: "else" if_stmt: "if" "(" prec0 ")" "{" ( stmt )* "}" else_branch_seperator "{" ( stmt )* "}" - named_range: AXIS_NAME ":" "[" prec0 "," prec0 ")" + named_range: AXIS_LITERAL ":" "[" prec0 "," prec0 ")" function_definition: ID_NAME "=" "λ(" ( SYM "," )* SYM? ")" "→" prec0 ";" declaration: ID_NAME "=" "temporary(" "domain=" prec0 "," "dtype=" TYPE_LITERAL ")" ";" stencil_closure: prec0 "←" "(" prec0 ")" "(" ( SYM_REF ", " )* SYM_REF ")" "@" prec0 ";" @@ -128,7 +128,7 @@ def OFFSET_LITERAL(self, value: lark_lexer.Token) -> ir.OffsetLiteral: def ID_NAME(self, value: lark_lexer.Token) -> str: return value.value - def AXIS_NAME(self, value: lark_lexer.Token) -> ir.AxisLiteral: + def AXIS_LITERAL(self, value: lark_lexer.Token) -> ir.AxisLiteral: name = value.value[:-1] kind = ir.DimensionKind.HORIZONTAL if value.value[-1] == "ₕ" else ir.DimensionKind.VERTICAL return ir.AxisLiteral(value=name, kind=kind) diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index c55cfd8d51..6579107197 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -189,6 +189,14 @@ def make_tuple(*args: ts.DataType) -> ts.TupleType: return ts.TupleType(types=list(args)) +@_register_builtin_type_synthesizer +def index(arg: ts.DimensionType) -> ts.FieldType: + return ts.FieldType( + dims=[arg.dim], + dtype=ts.ScalarType(kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper())), + ) + + @_register_builtin_type_synthesizer def neighbors(offset_literal: it_ts.OffsetLiteralType, it: it_ts.IteratorType) -> it_ts.ListType: assert ( diff --git a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py index 92dbcedeaa..bfc45d7944 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py @@ -260,7 +260,7 @@ def visit_Program(self, node: gtfn_ir.Program, **kwargs: Any) -> Union[str, Coll #include #include #include - + namespace generated{ namespace gtfn = ::gridtools::fn; diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py index 1995e4de0b..20a1a0cf76 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py @@ -153,7 +153,7 @@ class StencilExecution(Stmt): backend: Backend stencil: SymRef output: Union[SymRef, SidComposite] - inputs: list[Union[SymRef, SidComposite, SidFromScalar]] + inputs: list[Union[SymRef, SidComposite, SidFromScalar, FunCall]] class Scan(Node): @@ -192,6 +192,7 @@ class TemporaryAllocation(Node): "unstructured_domain", "named_range", "reduce", + "index", ] ARITHMETIC_BUILTINS = itir.ARITHMETIC_BUILTINS TYPEBUILTINS = itir.TYPEBUILTINS diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index 3bd96d14d7..fb2645208c 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -611,7 +611,6 @@ def convert_el_to_sid(el_expr: Expr, el_type: ts.ScalarType | ts.FieldType) -> E tuple_constructor=lambda *elements: SidComposite(values=list(elements)), ) - assert isinstance(lowered_input_as_sid, (SidComposite, SidFromScalar, SymRef)) lowered_inputs.append(lowered_input_as_sid) backend = Backend(domain=self.visit(domain, stencil=stencil, **kwargs)) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index dab8d29fd1..6383d4bb44 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -24,6 +24,7 @@ from gt4py.next import common from gt4py.next.ffront import decorator from gt4py.next.iterator import transforms as itir_transforms +from gt4py.next.iterator.ir import SymRef from gt4py.next.iterator.transforms import program_to_fencil from gt4py.next.iterator.type_system import inference as itir_type_inference from gt4py.next.program_processors.runners.dace_common import utility as dace_utils @@ -197,11 +198,16 @@ def __sdfg__(self, *args, **kwargs) -> dace.sdfg.sdfg.SDFG: # Halo exchange related metadata, i.e. gt4py_program_input_fields, gt4py_program_output_fields, offset_providers_per_input_field # Add them as dynamic properties to the SDFG + assert all( + isinstance(in_field, SymRef) + for closure in self.itir.closures + for in_field in closure.inputs + ) # backend only supports SymRef inputs, not `index` calls input_fields = [ - str(in_field.id) + str(in_field.id) # type: ignore[union-attr] # ensured by assert for closure in self.itir.closures for in_field in closure.inputs - if str(in_field.id) in fields + if str(in_field.id) in fields # type: ignore[union-attr] # ensured by assert ] sdfg.gt4py_program_input_fields = { in_field: dim @@ -237,6 +243,9 @@ def __sdfg__(self, *args, **kwargs) -> dace.sdfg.sdfg.SDFG: closure.stencil, num_args=len(closure.inputs) ) for param, shifts in zip(closure.inputs, params_shifts): + assert isinstance( + param, SymRef + ) # backend only supports SymRef inputs, not `index` calls if not isinstance(param.id, str): continue if param.id not in sdfg.gt4py_program_input_fields: diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index d52fbc5857..a824760ce4 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -357,7 +357,10 @@ def visit_StencilClosure( closure_state = closure_sdfg.add_state("closure_entry") closure_init_state = closure_sdfg.add_state_before(closure_state, "closure_init", True) - input_names = [str(inp.id) for inp in node.inputs] + assert all( + isinstance(inp, SymRef) for inp in node.inputs + ) # backend only supports SymRef inputs, not `index` calls + input_names = [str(inp.id) for inp in node.inputs] # type: ignore[union-attr] # ensured by assert neighbor_tables = get_used_connectivities(node, self.offset_provider) connectivity_names = [ dace_utils.connectivity_identifier(offset) for offset in neighbor_tables.keys() @@ -565,7 +568,10 @@ def _visit_scan_stencil_closure( assert isinstance(node.output, SymRef) neighbor_tables = get_used_connectivities(node, self.offset_provider) - input_names = [str(inp.id) for inp in node.inputs] + assert all( + isinstance(inp, SymRef) for inp in node.inputs + ) # backend only supports SymRef inputs, not `index` calls + input_names = [str(inp.id) for inp in node.inputs] # type: ignore[union-attr] # ensured by assert connectivity_names = [ dace_utils.connectivity_identifier(offset) for offset in neighbor_tables.keys() ] @@ -732,7 +738,10 @@ def _visit_parallel_stencil_closure( ], ) -> tuple[dace.SDFG, dict[str, str | dace.subsets.Subset], list[str]]: neighbor_tables = get_used_connectivities(node, self.offset_provider) - input_names = [str(inp.id) for inp in node.inputs] + assert all( + isinstance(inp, SymRef) for inp in node.inputs + ) # backend only supports SymRef inputs, not `index` calls + input_names = [str(inp.id) for inp in node.inputs] # type: ignore[union-attr] # ensured by assert connectivity_names = [ dace_utils.connectivity_identifier(offset) for offset in neighbor_tables.keys() ] diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 2c4102d5af..3fef43865b 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -117,6 +117,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): USES_MESH_WITH_SKIP_VALUES = "uses_mesh_with_skip_values" USES_SCALAR_IN_DOMAIN_AND_FO = "uses_scalar_in_domain_and_fo" CHECKS_SPECIFIC_ERROR = "checks_specific_error" +USES_INDEX_BUILTIN = "uses_index_builtin" # Skip messages (available format keys: 'marker', 'backend') UNSUPPORTED_MESSAGE = "'{marker}' tests not supported by '{backend}' backend" @@ -127,7 +128,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): # Common list of feature markers to skip COMMON_SKIP_TEST_LIST = [ (REQUIRES_ATLAS, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), - (STARTS_FROM_GTIR_PROGRAM, SKIP, UNSUPPORTED_MESSAGE), (USES_APPLIED_SHIFTS, XFAIL, UNSUPPORTED_MESSAGE), (USES_IF_STMTS, XFAIL, UNSUPPORTED_MESSAGE), (USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE), @@ -145,6 +145,8 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (USES_TUPLE_ARGS, XFAIL, UNSUPPORTED_MESSAGE), (USES_TUPLE_RETURNS, XFAIL, UNSUPPORTED_MESSAGE), (USES_ZERO_DIMENSIONAL_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_INDEX_BUILTIN, XFAIL, UNSUPPORTED_MESSAGE), + (STARTS_FROM_GTIR_PROGRAM, SKIP, UNSUPPORTED_MESSAGE), ] GTIR_DACE_SKIP_TEST_LIST = [ (ALL, SKIP, UNSUPPORTED_MESSAGE), diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py index 4eab7502e7..db1c2a42aa 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py @@ -10,13 +10,22 @@ import pytest import gt4py.next as gtx -from gt4py.next.iterator.builtins import as_fieldop, cartesian_domain, deref, named_range +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.builtins import ( + as_fieldop, + cartesian_domain, + deref, + index, + named_range, + shift, +) from gt4py.next.iterator.runtime import fendef, fundef, set_at from next_tests.unit_tests.conftest import program_processor, run_processor I = gtx.Dimension("I") +Ioff = gtx.FieldOffset("Ioff", source=I, target=(I,)) @fundef @@ -44,3 +53,48 @@ def test_prog(program_processor): run_processor(copy_program, program_processor, inp, out, isize, offset_provider={}) if validate: assert np.allclose(inp.asnumpy(), out.asnumpy()) + + +@fendef +def index_program_simple(out, size): + set_at( + as_fieldop(lambda i: deref(i), cartesian_domain(named_range(I, 0, size)))(index(I)), + cartesian_domain(named_range(I, 0, size)), + out, + ) + + +@pytest.mark.starts_from_gtir_program +@pytest.mark.uses_index_builtin +def test_index_builtin(program_processor): + program_processor, validate = program_processor + + isize = 10 + out = gtx.as_field([I], np.zeros(shape=(isize,)), dtype=getattr(np, itir.INTEGER_INDEX_BUILTIN)) + + run_processor(index_program_simple, program_processor, out, isize, offset_provider={}) + if validate: + assert np.allclose(np.arange(10), out.asnumpy()) + + +@fendef +def index_program_shift(out, size): + set_at( + as_fieldop( + lambda i: deref(i) + deref(shift(Ioff, 1)(i)), cartesian_domain(named_range(I, 0, size)) + )(index(I)), + cartesian_domain(named_range(I, 0, size)), + out, + ) + + +@pytest.mark.uses_index_builtin +def test_index_builtin_shift(program_processor): + program_processor, validate = program_processor + + isize = 10 + out = gtx.as_field([I], np.zeros(shape=(isize,)), dtype=getattr(np, itir.INTEGER_INDEX_BUILTIN)) + + run_processor(index_program_shift, program_processor, out, isize, offset_provider={"Ioff": I}) + if validate: + assert np.allclose(np.arange(10) + np.arange(1, 11), out.asnumpy())