Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next]: Index builtin #1699

Merged
merged 12 commits into from
Nov 4, 2024
6 changes: 6 additions & 0 deletions src/gt4py/next/iterator/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ def as_fieldop(*args):
raise BackendNotSelectedError()


@builtin_dispatch
def index(*args):
raise BackendNotSelectedError()


@builtin_dispatch
def deref(*args):
raise BackendNotSelectedError()
Expand Down Expand Up @@ -430,6 +435,7 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing]
"unstructured_domain",
"named_range",
"as_fieldop",
"index",
*MATH_BUILTINS,
}

Expand Down
9 changes: 8 additions & 1 deletion src/gt4py/next/iterator/embedded.py
Original file line number Diff line number Diff line change
Expand Up @@ -1204,12 +1204,14 @@ def premap(

def restrict(self, item: common.AnyIndexSpec) -> Self:
if isinstance(item, Sequence) and all(isinstance(e, common.NamedIndex) for e in item):
assert len(item) == 1
assert isinstance(item[0], common.NamedIndex) # for mypy errors on multiple lines below
d, r = item[0]
assert d == self._dimension
assert isinstance(r, core_defs.INTEGRAL_TYPES)
# TODO(tehrengruber): Use a regular zero dimensional field instead.
return self.__class__(self._dimension, r)
# TODO set a domain...
# TODO: set a domain...
raise NotImplementedError()

__call__ = premap
Expand Down Expand Up @@ -1793,6 +1795,11 @@ def impl(*args):
return impl


@builtins.index.register(EMBEDDED)
def index(axis: common.Dimension) -> common.Field:
return IndexField(axis)


@runtime.closure.register(EMBEDDED)
def closure(
domain_: Domain,
Expand Down
8 changes: 7 additions & 1 deletion src/gt4py/next/iterator/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,18 @@ class StencilClosure(Node):
domain: FunCall
stencil: Expr
output: Union[SymRef, FunCall]
inputs: List[SymRef]
inputs: List[Union[SymRef, FunCall]]

@datamodels.validator("output")
def _output_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribute, value):
if isinstance(value, FunCall) and value.fun != SymRef(id="make_tuple"):
raise ValueError("Only FunCall to 'make_tuple' allowed.")

@datamodels.validator("inputs")
def _input_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribute, value):
if any(isinstance(v, FunCall) and v.fun != SymRef(id="index") for v in value):
raise ValueError("Only FunCall to 'index' allowed.")


UNARY_MATH_NUMBER_BUILTINS = {"abs"}
UNARY_LOGICAL_BUILTINS = {"not_"}
Expand Down Expand Up @@ -183,6 +188,7 @@ def _output_validator(self: datamodels.DataModelTP, attribute: datamodels.Attrib
"can_deref",
"scan",
"if_",
"index", # `index(dim)` creates a dim-field that has the current index at each point
*ARITHMETIC_BUILTINS,
*TYPEBUILTINS,
}
Expand Down
8 changes: 4 additions & 4 deletions src/gt4py/next/iterator/pretty_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@
INT_LITERAL: SIGNED_INT
FLOAT_LITERAL: SIGNED_FLOAT
OFFSET_LITERAL: ( INT_LITERAL | CNAME ) "ₒ"
_literal: INT_LITERAL | FLOAT_LITERAL | OFFSET_LITERAL
AXIS_LITERAL: CNAME ("ᵥ" | "ₕ")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Renamed for consistency.

_literal: INT_LITERAL | FLOAT_LITERAL | OFFSET_LITERAL | AXIS_LITERAL
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Otherwise index(OFFSET_LITERAL) isn't parsed properly.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean index(AXIS_LITERAL)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes

ID_NAME: CNAME
AXIS_NAME: CNAME ("ᵥ" | "ₕ")

?prec0: prec1
| "λ(" ( SYM "," )* SYM? ")" "→" prec0 -> lam
Expand Down Expand Up @@ -84,7 +84,7 @@
else_branch_seperator: "else"
if_stmt: "if" "(" prec0 ")" "{" ( stmt )* "}" else_branch_seperator "{" ( stmt )* "}"

named_range: AXIS_NAME ":" "[" prec0 "," prec0 ")"
named_range: AXIS_LITERAL ":" "[" prec0 "," prec0 ")"
function_definition: ID_NAME "=" "λ(" ( SYM "," )* SYM? ")" "→" prec0 ";"
declaration: ID_NAME "=" "temporary(" "domain=" prec0 "," "dtype=" TYPE_LITERAL ")" ";"
stencil_closure: prec0 "←" "(" prec0 ")" "(" ( SYM_REF ", " )* SYM_REF ")" "@" prec0 ";"
Expand Down Expand Up @@ -128,7 +128,7 @@ def OFFSET_LITERAL(self, value: lark_lexer.Token) -> ir.OffsetLiteral:
def ID_NAME(self, value: lark_lexer.Token) -> str:
return value.value

def AXIS_NAME(self, value: lark_lexer.Token) -> ir.AxisLiteral:
def AXIS_LITERAL(self, value: lark_lexer.Token) -> ir.AxisLiteral:
name = value.value[:-1]
kind = ir.DimensionKind.HORIZONTAL if value.value[-1] == "ₕ" else ir.DimensionKind.VERTICAL
return ir.AxisLiteral(value=name, kind=kind)
Expand Down
8 changes: 8 additions & 0 deletions src/gt4py/next/iterator/type_system/type_synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,14 @@ def make_tuple(*args: ts.DataType) -> ts.TupleType:
return ts.TupleType(types=list(args))


@_register_builtin_type_synthesizer
def index(arg: ts.DimensionType) -> ts.FieldType:
return ts.FieldType(
dims=[arg.dim],
dtype=ts.ScalarType(kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper())),
)


@_register_builtin_type_synthesizer
def neighbors(offset_literal: it_ts.OffsetLiteralType, it: it_ts.IteratorType) -> it_ts.ListType:
assert (
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/program_processors/codegens/gtfn/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def visit_Program(self, node: gtfn_ir.Program, **kwargs: Any) -> Union[str, Coll
#include <gridtools/fn/${grid_type_str}.hpp>
#include <gridtools/fn/sid_neighbor_table.hpp>
#include <gridtools/stencil/global_parameter.hpp>

namespace generated{

namespace gtfn = ::gridtools::fn;
Expand Down
3 changes: 2 additions & 1 deletion src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ class StencilExecution(Stmt):
backend: Backend
stencil: SymRef
output: Union[SymRef, SidComposite]
inputs: list[Union[SymRef, SidComposite, SidFromScalar]]
inputs: list[Union[SymRef, SidComposite, SidFromScalar, FunCall]]


class Scan(Node):
Expand Down Expand Up @@ -192,6 +192,7 @@ class TemporaryAllocation(Node):
"unstructured_domain",
"named_range",
"reduce",
"index",
]
ARITHMETIC_BUILTINS = itir.ARITHMETIC_BUILTINS
TYPEBUILTINS = itir.TYPEBUILTINS
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,6 @@ def convert_el_to_sid(el_expr: Expr, el_type: ts.ScalarType | ts.FieldType) -> E
tuple_constructor=lambda *elements: SidComposite(values=list(elements)),
)

assert isinstance(lowered_input_as_sid, (SidComposite, SidFromScalar, SymRef))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check was superfluous from the beginning since the format is checked anyway when the SetAt is constructed. We just removed it instead of allowing an index FunCall.

lowered_inputs.append(lowered_input_as_sid)

backend = Backend(domain=self.visit(domain, stencil=stencil, **kwargs))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from gt4py.next import common
from gt4py.next.ffront import decorator
from gt4py.next.iterator import transforms as itir_transforms
from gt4py.next.iterator.ir import SymRef
from gt4py.next.iterator.transforms import program_to_fencil
from gt4py.next.iterator.type_system import inference as itir_type_inference
from gt4py.next.program_processors.runners.dace_common import utility as dace_utils
Expand Down Expand Up @@ -197,11 +198,16 @@ def __sdfg__(self, *args, **kwargs) -> dace.sdfg.sdfg.SDFG:
# Halo exchange related metadata, i.e. gt4py_program_input_fields, gt4py_program_output_fields, offset_providers_per_input_field
# Add them as dynamic properties to the SDFG

assert all(
isinstance(in_field, SymRef)
for closure in self.itir.closures
for in_field in closure.inputs
) # backend only supports SymRef inputs, not `index` calls
input_fields = [
str(in_field.id)
str(in_field.id) # type: ignore[union-attr] # ensured by assert
for closure in self.itir.closures
for in_field in closure.inputs
if str(in_field.id) in fields
if str(in_field.id) in fields # type: ignore[union-attr] # ensured by assert
]
sdfg.gt4py_program_input_fields = {
in_field: dim
Expand Down Expand Up @@ -237,6 +243,9 @@ def __sdfg__(self, *args, **kwargs) -> dace.sdfg.sdfg.SDFG:
closure.stencil, num_args=len(closure.inputs)
)
for param, shifts in zip(closure.inputs, params_shifts):
assert isinstance(
param, SymRef
) # backend only supports SymRef inputs, not `index` calls
if not isinstance(param.id, str):
continue
if param.id not in sdfg.gt4py_program_input_fields:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,10 @@ def visit_StencilClosure(
closure_state = closure_sdfg.add_state("closure_entry")
closure_init_state = closure_sdfg.add_state_before(closure_state, "closure_init", True)

input_names = [str(inp.id) for inp in node.inputs]
assert all(
isinstance(inp, SymRef) for inp in node.inputs
) # backend only supports SymRef inputs, not `index` calls
input_names = [str(inp.id) for inp in node.inputs] # type: ignore[union-attr] # ensured by assert
neighbor_tables = get_used_connectivities(node, self.offset_provider)
connectivity_names = [
dace_utils.connectivity_identifier(offset) for offset in neighbor_tables.keys()
Expand Down Expand Up @@ -565,7 +568,10 @@ def _visit_scan_stencil_closure(

assert isinstance(node.output, SymRef)
neighbor_tables = get_used_connectivities(node, self.offset_provider)
input_names = [str(inp.id) for inp in node.inputs]
assert all(
isinstance(inp, SymRef) for inp in node.inputs
) # backend only supports SymRef inputs, not `index` calls
input_names = [str(inp.id) for inp in node.inputs] # type: ignore[union-attr] # ensured by assert
connectivity_names = [
dace_utils.connectivity_identifier(offset) for offset in neighbor_tables.keys()
]
Expand Down Expand Up @@ -732,7 +738,10 @@ def _visit_parallel_stencil_closure(
],
) -> tuple[dace.SDFG, dict[str, str | dace.subsets.Subset], list[str]]:
neighbor_tables = get_used_connectivities(node, self.offset_provider)
input_names = [str(inp.id) for inp in node.inputs]
assert all(
isinstance(inp, SymRef) for inp in node.inputs
) # backend only supports SymRef inputs, not `index` calls
input_names = [str(inp.id) for inp in node.inputs] # type: ignore[union-attr] # ensured by assert
connectivity_names = [
dace_utils.connectivity_identifier(offset) for offset in neighbor_tables.keys()
]
Expand Down
4 changes: 3 additions & 1 deletion tests/next_tests/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
USES_MESH_WITH_SKIP_VALUES = "uses_mesh_with_skip_values"
USES_SCALAR_IN_DOMAIN_AND_FO = "uses_scalar_in_domain_and_fo"
CHECKS_SPECIFIC_ERROR = "checks_specific_error"
USES_INDEX_BUILTIN = "uses_index_builtin"

# Skip messages (available format keys: 'marker', 'backend')
UNSUPPORTED_MESSAGE = "'{marker}' tests not supported by '{backend}' backend"
Expand All @@ -127,7 +128,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
# Common list of feature markers to skip
COMMON_SKIP_TEST_LIST = [
(REQUIRES_ATLAS, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE),
(STARTS_FROM_GTIR_PROGRAM, SKIP, UNSUPPORTED_MESSAGE),
(USES_APPLIED_SHIFTS, XFAIL, UNSUPPORTED_MESSAGE),
(USES_IF_STMTS, XFAIL, UNSUPPORTED_MESSAGE),
(USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE),
Expand All @@ -145,6 +145,8 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
(USES_TUPLE_ARGS, XFAIL, UNSUPPORTED_MESSAGE),
(USES_TUPLE_RETURNS, XFAIL, UNSUPPORTED_MESSAGE),
(USES_ZERO_DIMENSIONAL_FIELDS, XFAIL, UNSUPPORTED_MESSAGE),
(USES_INDEX_BUILTIN, XFAIL, UNSUPPORTED_MESSAGE),
(STARTS_FROM_GTIR_PROGRAM, SKIP, UNSUPPORTED_MESSAGE),
]
GTIR_DACE_SKIP_TEST_LIST = [
(ALL, SKIP, UNSUPPORTED_MESSAGE),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,22 @@
import pytest

import gt4py.next as gtx
from gt4py.next.iterator.builtins import as_fieldop, cartesian_domain, deref, named_range
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.builtins import (
as_fieldop,
cartesian_domain,
deref,
index,
named_range,
shift,
)
from gt4py.next.iterator.runtime import fendef, fundef, set_at

from next_tests.unit_tests.conftest import program_processor, run_processor


I = gtx.Dimension("I")
Ioff = gtx.FieldOffset("Ioff", source=I, target=(I,))


@fundef
Expand Down Expand Up @@ -44,3 +53,48 @@ def test_prog(program_processor):
run_processor(copy_program, program_processor, inp, out, isize, offset_provider={})
if validate:
assert np.allclose(inp.asnumpy(), out.asnumpy())


@fendef
def index_program_simple(out, size):
set_at(
as_fieldop(lambda i: deref(i), cartesian_domain(named_range(I, 0, size)))(index(I)),
cartesian_domain(named_range(I, 0, size)),
out,
)


@pytest.mark.starts_from_gtir_program
@pytest.mark.uses_index_builtin
def test_index_builtin(program_processor):
program_processor, validate = program_processor

isize = 10
out = gtx.as_field([I], np.zeros(shape=(isize,)), dtype=getattr(np, itir.INTEGER_INDEX_BUILTIN))

run_processor(index_program_simple, program_processor, out, isize, offset_provider={})
if validate:
assert np.allclose(np.arange(10), out.asnumpy())


@fendef
def index_program_shift(out, size):
set_at(
as_fieldop(
lambda i: deref(i) + deref(shift(Ioff, 1)(i)), cartesian_domain(named_range(I, 0, size))
)(index(I)),
cartesian_domain(named_range(I, 0, size)),
out,
)


@pytest.mark.uses_index_builtin
def test_index_builtin_shift(program_processor):
program_processor, validate = program_processor

isize = 10
out = gtx.as_field([I], np.zeros(shape=(isize,)), dtype=getattr(np, itir.INTEGER_INDEX_BUILTIN))

run_processor(index_program_shift, program_processor, out, isize, offset_provider={"Ioff": I})
if validate:
assert np.allclose(np.arange(10) + np.arange(1, 11), out.asnumpy())
Loading