Skip to content

Commit

Permalink
feat[next]: Index builtin (#1699)
Browse files Browse the repository at this point in the history
Adds index builtin for embedded and gtfn backends.
  • Loading branch information
SF-N authored Nov 4, 2024
1 parent 44d6224 commit 9f3b0a7
Show file tree
Hide file tree
Showing 12 changed files with 117 additions and 16 deletions.
6 changes: 6 additions & 0 deletions src/gt4py/next/iterator/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ def as_fieldop(*args):
raise BackendNotSelectedError()


@builtin_dispatch
def index(*args):
raise BackendNotSelectedError()


@builtin_dispatch
def deref(*args):
raise BackendNotSelectedError()
Expand Down Expand Up @@ -430,6 +435,7 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing]
"unstructured_domain",
"named_range",
"as_fieldop",
"index",
*MATH_BUILTINS,
}

Expand Down
9 changes: 8 additions & 1 deletion src/gt4py/next/iterator/embedded.py
Original file line number Diff line number Diff line change
Expand Up @@ -1204,12 +1204,14 @@ def premap(

def restrict(self, item: common.AnyIndexSpec) -> Self:
if isinstance(item, Sequence) and all(isinstance(e, common.NamedIndex) for e in item):
assert len(item) == 1
assert isinstance(item[0], common.NamedIndex) # for mypy errors on multiple lines below
d, r = item[0]
assert d == self._dimension
assert isinstance(r, core_defs.INTEGRAL_TYPES)
# TODO(tehrengruber): Use a regular zero dimensional field instead.
return self.__class__(self._dimension, r)
# TODO set a domain...
# TODO: set a domain...
raise NotImplementedError()

__call__ = premap
Expand Down Expand Up @@ -1793,6 +1795,11 @@ def impl(*args):
return impl


@builtins.index.register(EMBEDDED)
def index(axis: common.Dimension) -> common.Field:
return IndexField(axis)


@runtime.closure.register(EMBEDDED)
def closure(
domain_: Domain,
Expand Down
8 changes: 7 additions & 1 deletion src/gt4py/next/iterator/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,18 @@ class StencilClosure(Node):
domain: FunCall
stencil: Expr
output: Union[SymRef, FunCall]
inputs: List[SymRef]
inputs: List[Union[SymRef, FunCall]]

@datamodels.validator("output")
def _output_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribute, value):
if isinstance(value, FunCall) and value.fun != SymRef(id="make_tuple"):
raise ValueError("Only FunCall to 'make_tuple' allowed.")

@datamodels.validator("inputs")
def _input_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribute, value):
if any(isinstance(v, FunCall) and v.fun != SymRef(id="index") for v in value):
raise ValueError("Only FunCall to 'index' allowed.")


UNARY_MATH_NUMBER_BUILTINS = {"abs"}
UNARY_LOGICAL_BUILTINS = {"not_"}
Expand Down Expand Up @@ -183,6 +188,7 @@ def _output_validator(self: datamodels.DataModelTP, attribute: datamodels.Attrib
"can_deref",
"scan",
"if_",
"index", # `index(dim)` creates a dim-field that has the current index at each point
*ARITHMETIC_BUILTINS,
*TYPEBUILTINS,
}
Expand Down
8 changes: 4 additions & 4 deletions src/gt4py/next/iterator/pretty_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@
INT_LITERAL: SIGNED_INT
FLOAT_LITERAL: SIGNED_FLOAT
OFFSET_LITERAL: ( INT_LITERAL | CNAME ) "ₒ"
_literal: INT_LITERAL | FLOAT_LITERAL | OFFSET_LITERAL
AXIS_LITERAL: CNAME ("ᵥ" | "ₕ")
_literal: INT_LITERAL | FLOAT_LITERAL | OFFSET_LITERAL | AXIS_LITERAL
ID_NAME: CNAME
AXIS_NAME: CNAME ("ᵥ" | "ₕ")
?prec0: prec1
| "λ(" ( SYM "," )* SYM? ")" "→" prec0 -> lam
Expand Down Expand Up @@ -84,7 +84,7 @@
else_branch_seperator: "else"
if_stmt: "if" "(" prec0 ")" "{" ( stmt )* "}" else_branch_seperator "{" ( stmt )* "}"
named_range: AXIS_NAME ":" "[" prec0 "," prec0 ")"
named_range: AXIS_LITERAL ":" "[" prec0 "," prec0 ")"
function_definition: ID_NAME "=" "λ(" ( SYM "," )* SYM? ")" "→" prec0 ";"
declaration: ID_NAME "=" "temporary(" "domain=" prec0 "," "dtype=" TYPE_LITERAL ")" ";"
stencil_closure: prec0 "←" "(" prec0 ")" "(" ( SYM_REF ", " )* SYM_REF ")" "@" prec0 ";"
Expand Down Expand Up @@ -128,7 +128,7 @@ def OFFSET_LITERAL(self, value: lark_lexer.Token) -> ir.OffsetLiteral:
def ID_NAME(self, value: lark_lexer.Token) -> str:
return value.value

def AXIS_NAME(self, value: lark_lexer.Token) -> ir.AxisLiteral:
def AXIS_LITERAL(self, value: lark_lexer.Token) -> ir.AxisLiteral:
name = value.value[:-1]
kind = ir.DimensionKind.HORIZONTAL if value.value[-1] == "ₕ" else ir.DimensionKind.VERTICAL
return ir.AxisLiteral(value=name, kind=kind)
Expand Down
8 changes: 8 additions & 0 deletions src/gt4py/next/iterator/type_system/type_synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,14 @@ def make_tuple(*args: ts.DataType) -> ts.TupleType:
return ts.TupleType(types=list(args))


@_register_builtin_type_synthesizer
def index(arg: ts.DimensionType) -> ts.FieldType:
return ts.FieldType(
dims=[arg.dim],
dtype=ts.ScalarType(kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper())),
)


@_register_builtin_type_synthesizer
def neighbors(offset_literal: it_ts.OffsetLiteralType, it: it_ts.IteratorType) -> it_ts.ListType:
assert (
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/program_processors/codegens/gtfn/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def visit_Program(self, node: gtfn_ir.Program, **kwargs: Any) -> Union[str, Coll
#include <gridtools/fn/${grid_type_str}.hpp>
#include <gridtools/fn/sid_neighbor_table.hpp>
#include <gridtools/stencil/global_parameter.hpp>
namespace generated{
namespace gtfn = ::gridtools::fn;
Expand Down
3 changes: 2 additions & 1 deletion src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ class StencilExecution(Stmt):
backend: Backend
stencil: SymRef
output: Union[SymRef, SidComposite]
inputs: list[Union[SymRef, SidComposite, SidFromScalar]]
inputs: list[Union[SymRef, SidComposite, SidFromScalar, FunCall]]


class Scan(Node):
Expand Down Expand Up @@ -192,6 +192,7 @@ class TemporaryAllocation(Node):
"unstructured_domain",
"named_range",
"reduce",
"index",
]
ARITHMETIC_BUILTINS = itir.ARITHMETIC_BUILTINS
TYPEBUILTINS = itir.TYPEBUILTINS
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,6 @@ def convert_el_to_sid(el_expr: Expr, el_type: ts.ScalarType | ts.FieldType) -> E
tuple_constructor=lambda *elements: SidComposite(values=list(elements)),
)

assert isinstance(lowered_input_as_sid, (SidComposite, SidFromScalar, SymRef))
lowered_inputs.append(lowered_input_as_sid)

backend = Backend(domain=self.visit(domain, stencil=stencil, **kwargs))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from gt4py.next import common
from gt4py.next.ffront import decorator
from gt4py.next.iterator import transforms as itir_transforms
from gt4py.next.iterator.ir import SymRef
from gt4py.next.iterator.transforms import program_to_fencil
from gt4py.next.iterator.type_system import inference as itir_type_inference
from gt4py.next.program_processors.runners.dace_common import utility as dace_utils
Expand Down Expand Up @@ -197,11 +198,16 @@ def __sdfg__(self, *args, **kwargs) -> dace.sdfg.sdfg.SDFG:
# Halo exchange related metadata, i.e. gt4py_program_input_fields, gt4py_program_output_fields, offset_providers_per_input_field
# Add them as dynamic properties to the SDFG

assert all(
isinstance(in_field, SymRef)
for closure in self.itir.closures
for in_field in closure.inputs
) # backend only supports SymRef inputs, not `index` calls
input_fields = [
str(in_field.id)
str(in_field.id) # type: ignore[union-attr] # ensured by assert
for closure in self.itir.closures
for in_field in closure.inputs
if str(in_field.id) in fields
if str(in_field.id) in fields # type: ignore[union-attr] # ensured by assert
]
sdfg.gt4py_program_input_fields = {
in_field: dim
Expand Down Expand Up @@ -237,6 +243,9 @@ def __sdfg__(self, *args, **kwargs) -> dace.sdfg.sdfg.SDFG:
closure.stencil, num_args=len(closure.inputs)
)
for param, shifts in zip(closure.inputs, params_shifts):
assert isinstance(
param, SymRef
) # backend only supports SymRef inputs, not `index` calls
if not isinstance(param.id, str):
continue
if param.id not in sdfg.gt4py_program_input_fields:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,10 @@ def visit_StencilClosure(
closure_state = closure_sdfg.add_state("closure_entry")
closure_init_state = closure_sdfg.add_state_before(closure_state, "closure_init", True)

input_names = [str(inp.id) for inp in node.inputs]
assert all(
isinstance(inp, SymRef) for inp in node.inputs
) # backend only supports SymRef inputs, not `index` calls
input_names = [str(inp.id) for inp in node.inputs] # type: ignore[union-attr] # ensured by assert
neighbor_tables = get_used_connectivities(node, self.offset_provider)
connectivity_names = [
dace_utils.connectivity_identifier(offset) for offset in neighbor_tables.keys()
Expand Down Expand Up @@ -565,7 +568,10 @@ def _visit_scan_stencil_closure(

assert isinstance(node.output, SymRef)
neighbor_tables = get_used_connectivities(node, self.offset_provider)
input_names = [str(inp.id) for inp in node.inputs]
assert all(
isinstance(inp, SymRef) for inp in node.inputs
) # backend only supports SymRef inputs, not `index` calls
input_names = [str(inp.id) for inp in node.inputs] # type: ignore[union-attr] # ensured by assert
connectivity_names = [
dace_utils.connectivity_identifier(offset) for offset in neighbor_tables.keys()
]
Expand Down Expand Up @@ -732,7 +738,10 @@ def _visit_parallel_stencil_closure(
],
) -> tuple[dace.SDFG, dict[str, str | dace.subsets.Subset], list[str]]:
neighbor_tables = get_used_connectivities(node, self.offset_provider)
input_names = [str(inp.id) for inp in node.inputs]
assert all(
isinstance(inp, SymRef) for inp in node.inputs
) # backend only supports SymRef inputs, not `index` calls
input_names = [str(inp.id) for inp in node.inputs] # type: ignore[union-attr] # ensured by assert
connectivity_names = [
dace_utils.connectivity_identifier(offset) for offset in neighbor_tables.keys()
]
Expand Down
4 changes: 3 additions & 1 deletion tests/next_tests/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
USES_MESH_WITH_SKIP_VALUES = "uses_mesh_with_skip_values"
USES_SCALAR_IN_DOMAIN_AND_FO = "uses_scalar_in_domain_and_fo"
CHECKS_SPECIFIC_ERROR = "checks_specific_error"
USES_INDEX_BUILTIN = "uses_index_builtin"

# Skip messages (available format keys: 'marker', 'backend')
UNSUPPORTED_MESSAGE = "'{marker}' tests not supported by '{backend}' backend"
Expand All @@ -127,7 +128,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
# Common list of feature markers to skip
COMMON_SKIP_TEST_LIST = [
(REQUIRES_ATLAS, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE),
(STARTS_FROM_GTIR_PROGRAM, SKIP, UNSUPPORTED_MESSAGE),
(USES_APPLIED_SHIFTS, XFAIL, UNSUPPORTED_MESSAGE),
(USES_IF_STMTS, XFAIL, UNSUPPORTED_MESSAGE),
(USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE),
Expand All @@ -145,6 +145,8 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
(USES_TUPLE_ARGS, XFAIL, UNSUPPORTED_MESSAGE),
(USES_TUPLE_RETURNS, XFAIL, UNSUPPORTED_MESSAGE),
(USES_ZERO_DIMENSIONAL_FIELDS, XFAIL, UNSUPPORTED_MESSAGE),
(USES_INDEX_BUILTIN, XFAIL, UNSUPPORTED_MESSAGE),
(STARTS_FROM_GTIR_PROGRAM, SKIP, UNSUPPORTED_MESSAGE),
]
GTIR_DACE_SKIP_TEST_LIST = [
(ALL, SKIP, UNSUPPORTED_MESSAGE),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,22 @@
import pytest

import gt4py.next as gtx
from gt4py.next.iterator.builtins import as_fieldop, cartesian_domain, deref, named_range
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.builtins import (
as_fieldop,
cartesian_domain,
deref,
index,
named_range,
shift,
)
from gt4py.next.iterator.runtime import fendef, fundef, set_at

from next_tests.unit_tests.conftest import program_processor, run_processor


I = gtx.Dimension("I")
Ioff = gtx.FieldOffset("Ioff", source=I, target=(I,))


@fundef
Expand Down Expand Up @@ -44,3 +53,48 @@ def test_prog(program_processor):
run_processor(copy_program, program_processor, inp, out, isize, offset_provider={})
if validate:
assert np.allclose(inp.asnumpy(), out.asnumpy())


@fendef
def index_program_simple(out, size):
set_at(
as_fieldop(lambda i: deref(i), cartesian_domain(named_range(I, 0, size)))(index(I)),
cartesian_domain(named_range(I, 0, size)),
out,
)


@pytest.mark.starts_from_gtir_program
@pytest.mark.uses_index_builtin
def test_index_builtin(program_processor):
program_processor, validate = program_processor

isize = 10
out = gtx.as_field([I], np.zeros(shape=(isize,)), dtype=getattr(np, itir.INTEGER_INDEX_BUILTIN))

run_processor(index_program_simple, program_processor, out, isize, offset_provider={})
if validate:
assert np.allclose(np.arange(10), out.asnumpy())


@fendef
def index_program_shift(out, size):
set_at(
as_fieldop(
lambda i: deref(i) + deref(shift(Ioff, 1)(i)), cartesian_domain(named_range(I, 0, size))
)(index(I)),
cartesian_domain(named_range(I, 0, size)),
out,
)


@pytest.mark.uses_index_builtin
def test_index_builtin_shift(program_processor):
program_processor, validate = program_processor

isize = 10
out = gtx.as_field([I], np.zeros(shape=(isize,)), dtype=getattr(np, itir.INTEGER_INDEX_BUILTIN))

run_processor(index_program_shift, program_processor, out, isize, offset_provider={"Ioff": I})
if validate:
assert np.allclose(np.arange(10) + np.arange(1, 11), out.asnumpy())

0 comments on commit 9f3b0a7

Please sign in to comment.