Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next]: Index builtin #1699

Merged
merged 12 commits into from
Nov 4, 2024
6 changes: 6 additions & 0 deletions src/gt4py/next/iterator/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ def as_fieldop(*args):
raise BackendNotSelectedError()


@builtin_dispatch
def index(*args):
raise BackendNotSelectedError()


@builtin_dispatch
def deref(*args):
raise BackendNotSelectedError()
Expand Down Expand Up @@ -430,6 +435,7 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing]
"unstructured_domain",
"named_range",
"as_fieldop",
"index",
*MATH_BUILTINS,
}

Expand Down
9 changes: 8 additions & 1 deletion src/gt4py/next/iterator/embedded.py
Original file line number Diff line number Diff line change
Expand Up @@ -1156,6 +1156,7 @@ def ndarray(self) -> core_defs.NDArrayObject:
def asnumpy(self) -> np.ndarray:
raise NotImplementedError()

# TODO(tehrengruber): Use a regular zero dimensional field instead.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand the comment.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We moved the comment to a better place. With respect to the comment itself: The IndexField class is rather strange. It has two modes of operation: Either it is a field with (conceptually) field(domain) == domain or it is a zero-dimensional field. Both modes don't share any implementation similarities, but are mushed into the same class. The way the class behaves is then controlled using _cur_index. It would be much simpler to just make field[index] return a zero-dimensional field which is exactly what we want instead of re-implementing it here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand what you are saying...

def as_scalar(self) -> core_defs.IntegralScalar:
if self.domain.ndim != 0:
raise ValueError(
Expand All @@ -1174,12 +1175,13 @@ def premap(

def restrict(self, item: common.AnyIndexSpec) -> Self:
if isinstance(item, Sequence) and all(isinstance(e, common.NamedIndex) for e in item):
assert len(item) == 1
assert isinstance(item[0], common.NamedIndex) # for mypy errors on multiple lines below
d, r = item[0]
assert d == self._dimension
assert isinstance(r, core_defs.INTEGRAL_TYPES)
return self.__class__(self._dimension, r)
# TODO set a domain...
# TODO: set a domain...
raise NotImplementedError()

__call__ = premap
Expand Down Expand Up @@ -1701,6 +1703,11 @@ def impl(*args):
return impl


@builtins.index.register(EMBEDDED)
def index(axis: common.Dimension) -> common.Field:
return IndexField(axis)


@runtime.closure.register(EMBEDDED)
def closure(
domain_: Domain,
Expand Down
8 changes: 7 additions & 1 deletion src/gt4py/next/iterator/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,18 @@ class StencilClosure(Node):
domain: FunCall
stencil: Expr
output: Union[SymRef, FunCall]
inputs: List[SymRef]
inputs: List[Union[SymRef, FunCall]]

@datamodels.validator("output")
def _output_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribute, value):
if isinstance(value, FunCall) and value.fun != SymRef(id="make_tuple"):
raise ValueError("Only FunCall to 'make_tuple' allowed.")

@datamodels.validator("inputs")
def _input_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribute, value):
if any(isinstance(v, FunCall) and v.fun != SymRef(id="index") for v in value):
raise ValueError("Only FunCall to 'index' allowed.")


UNARY_MATH_NUMBER_BUILTINS = {"abs"}
UNARY_LOGICAL_BUILTINS = {"not_"}
Expand Down Expand Up @@ -183,6 +188,7 @@ def _output_validator(self: datamodels.DataModelTP, attribute: datamodels.Attrib
"can_deref",
"scan",
"if_",
"index", # `index(dim)` creates a dim-field that has the current index at each point
*ARITHMETIC_BUILTINS,
*TYPEBUILTINS,
}
Expand Down
8 changes: 4 additions & 4 deletions src/gt4py/next/iterator/pretty_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@
INT_LITERAL: SIGNED_INT
FLOAT_LITERAL: SIGNED_FLOAT
OFFSET_LITERAL: ( INT_LITERAL | CNAME ) "ₒ"
_literal: INT_LITERAL | FLOAT_LITERAL | OFFSET_LITERAL
AXIS_LITERAL: CNAME ("ᵥ" | "ₕ")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Renamed for consistency.

_literal: INT_LITERAL | FLOAT_LITERAL | OFFSET_LITERAL | AXIS_LITERAL
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Otherwise index(OFFSET_LITERAL) isn't parsed properly.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean index(AXIS_LITERAL)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes

ID_NAME: CNAME
AXIS_NAME: CNAME ("ᵥ" | "ₕ")

?prec0: prec1
| "λ(" ( SYM "," )* SYM? ")" "→" prec0 -> lam
Expand Down Expand Up @@ -84,7 +84,7 @@
else_branch_seperator: "else"
if_stmt: "if" "(" prec0 ")" "{" ( stmt )* "}" else_branch_seperator "{" ( stmt )* "}"

named_range: AXIS_NAME ":" "[" prec0 "," prec0 ")"
named_range: AXIS_LITERAL ":" "[" prec0 "," prec0 ")"
function_definition: ID_NAME "=" "λ(" ( SYM "," )* SYM? ")" "→" prec0 ";"
declaration: ID_NAME "=" "temporary(" "domain=" prec0 "," "dtype=" TYPE_LITERAL ")" ";"
stencil_closure: prec0 "←" "(" prec0 ")" "(" ( SYM_REF ", " )* SYM_REF ")" "@" prec0 ";"
Expand Down Expand Up @@ -128,7 +128,7 @@ def OFFSET_LITERAL(self, value: lark_lexer.Token) -> ir.OffsetLiteral:
def ID_NAME(self, value: lark_lexer.Token) -> str:
return value.value

def AXIS_NAME(self, value: lark_lexer.Token) -> ir.AxisLiteral:
def AXIS_LITERAL(self, value: lark_lexer.Token) -> ir.AxisLiteral:
name = value.value[:-1]
kind = ir.DimensionKind.HORIZONTAL if value.value[-1] == "ₕ" else ir.DimensionKind.VERTICAL
return ir.AxisLiteral(value=name, kind=kind)
Expand Down
8 changes: 8 additions & 0 deletions src/gt4py/next/iterator/type_system/type_synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,14 @@ def make_tuple(*args: ts.DataType) -> ts.TupleType:
return ts.TupleType(types=list(args))


@_register_builtin_type_synthesizer
def index(arg: ts.DimensionType) -> ts.FieldType:
return ts.FieldType(
dims=[arg.dim],
dtype=ts.ScalarType(kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper())),
)


@_register_builtin_type_synthesizer
def neighbors(offset_literal: it_ts.OffsetLiteralType, it: it_ts.IteratorType) -> it_ts.ListType:
assert (
Expand Down
11 changes: 11 additions & 0 deletions src/gt4py/next/program_processors/codegens/gtfn/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,17 @@ def visit_Program(self, node: gtfn_ir.Program, **kwargs: Any) -> Union[str, Coll
#include <gridtools/fn/${grid_type_str}.hpp>
#include <gridtools/fn/sid_neighbor_table.hpp>
#include <gridtools/stencil/global_parameter.hpp>
#include <gridtools/stencil/positional.hpp>

// TODO(havogt): move to gtfn?
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd propose to wait until GridTools/gridtools#1806 is merged next week.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Once #1720 is in, we can remove this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, I removed that code.

namespace gridtools{
namespace fn{
template <class T>
auto index(T){
return stencil::positional<std::decay_t<T>>();}
}
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
#include <gridtools/stencil/positional.hpp>
// TODO(havogt): move to gtfn?
namespace gridtools{
namespace fn{
template <class T>
auto index(T){
return stencil::positional<std::decay_t<T>>();}
}
}


namespace generated{

Expand Down
3 changes: 2 additions & 1 deletion src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ class StencilExecution(Stmt):
backend: Backend
stencil: SymRef
output: Union[SymRef, SidComposite]
inputs: list[Union[SymRef, SidComposite, SidFromScalar]]
inputs: list[Union[SymRef, SidComposite, SidFromScalar, FunCall]]


class Scan(Node):
Expand Down Expand Up @@ -192,6 +192,7 @@ class TemporaryAllocation(Node):
"unstructured_domain",
"named_range",
"reduce",
"index",
]
ARITHMETIC_BUILTINS = itir.ARITHMETIC_BUILTINS
TYPEBUILTINS = itir.TYPEBUILTINS
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,6 @@ def convert_el_to_sid(el_expr: Expr, el_type: ts.ScalarType | ts.FieldType) -> E
tuple_constructor=lambda *elements: SidComposite(values=list(elements)),
)

assert isinstance(lowered_input_as_sid, (SidComposite, SidFromScalar, SymRef))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check was superfluous from the beginning since the format is checked anyway when the SetAt is constructed. We just removed it instead of allowing an index FunCall.

lowered_inputs.append(lowered_input_as_sid)

backend = Backend(domain=self.visit(domain, stencil=stencil, **kwargs))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,10 +198,10 @@ def __sdfg__(self, *args, **kwargs) -> dace.sdfg.sdfg.SDFG:
# Add them as dynamic properties to the SDFG

input_fields = [
str(in_field.id)
str(in_field.id) # type: ignore[union-attr] # backend only supports SymRef inputs, not `index` calls
for closure in self.itir.closures
for in_field in closure.inputs
if str(in_field.id) in fields
if str(in_field.id) in fields # type: ignore[union-attr] # backend only supports SymRef inputs, not `index` calls
]
sdfg.gt4py_program_input_fields = {
in_field: dim
Expand Down Expand Up @@ -237,11 +237,11 @@ def __sdfg__(self, *args, **kwargs) -> dace.sdfg.sdfg.SDFG:
closure.stencil, num_args=len(closure.inputs)
)
for param, shifts in zip(closure.inputs, params_shifts):
if not isinstance(param.id, str):
if not isinstance(param.id, str): # type: ignore[union-attr] # backend only supports SymRef inputs, not `index` calls
continue
if param.id not in sdfg.gt4py_program_input_fields:
if param.id not in sdfg.gt4py_program_input_fields: # type: ignore[union-attr] # backend only supports SymRef inputs, not `index` calls
continue
sdfg.offset_providers_per_input_field.setdefault(param.id, []).extend(list(shifts))
sdfg.offset_providers_per_input_field.setdefault(param.id, []).extend(list(shifts)) # type: ignore[union-attr] # backend only supports SymRef inputs, not `index` calls

return sdfg

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,10 @@ def visit_StencilClosure(
closure_state = closure_sdfg.add_state("closure_entry")
closure_init_state = closure_sdfg.add_state_before(closure_state, "closure_init", True)

input_names = [str(inp.id) for inp in node.inputs]
assert all(
isinstance(inp, SymRef) for inp in node.inputs
) # backend only supports SymRef inputs, not `index` calls
input_names = [str(inp.id) for inp in node.inputs] # type: ignore[union-attr] # backend only supports SymRef inputs, not `index` calls
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
input_names = [str(inp.id) for inp in node.inputs] # type: ignore[union-attr] # backend only supports SymRef inputs, not `index` calls
input_names = [str(inp.id) for inp in node.inputs] # type: ignore[union-attr] # `inp`s are only `SymRef`s

or similar

neighbor_tables = get_used_connectivities(node, self.offset_provider)
connectivity_names = [
dace_utils.connectivity_identifier(offset) for offset in neighbor_tables.keys()
Expand Down Expand Up @@ -565,7 +568,7 @@ def _visit_scan_stencil_closure(

assert isinstance(node.output, SymRef)
neighbor_tables = get_used_connectivities(node, self.offset_provider)
input_names = [str(inp.id) for inp in node.inputs]
input_names = [str(inp.id) for inp in node.inputs] # type: ignore[union-attr] # backend only supports SymRef inputs, not `index` calls
connectivity_names = [
dace_utils.connectivity_identifier(offset) for offset in neighbor_tables.keys()
]
Expand Down Expand Up @@ -732,7 +735,7 @@ def _visit_parallel_stencil_closure(
],
) -> tuple[dace.SDFG, dict[str, str | dace.subsets.Subset], list[str]]:
neighbor_tables = get_used_connectivities(node, self.offset_provider)
input_names = [str(inp.id) for inp in node.inputs]
input_names = [str(inp.id) for inp in node.inputs] # type: ignore[union-attr] # backend only supports SymRef inputs, not `index` calls
connectivity_names = [
dace_utils.connectivity_identifier(offset) for offset in neighbor_tables.keys()
]
Expand Down
4 changes: 3 additions & 1 deletion tests/next_tests/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
USES_MESH_WITH_SKIP_VALUES = "uses_mesh_with_skip_values"
USES_SCALAR_IN_DOMAIN_AND_FO = "uses_scalar_in_domain_and_fo"
CHECKS_SPECIFIC_ERROR = "checks_specific_error"
USES_INDEX_BUILTIN = "uses_index_builtin"

# Skip messages (available format keys: 'marker', 'backend')
UNSUPPORTED_MESSAGE = "'{marker}' tests not supported by '{backend}' backend"
Expand All @@ -127,7 +128,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
# Common list of feature markers to skip
COMMON_SKIP_TEST_LIST = [
(REQUIRES_ATLAS, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE),
(STARTS_FROM_GTIR_PROGRAM, SKIP, UNSUPPORTED_MESSAGE),
(USES_APPLIED_SHIFTS, XFAIL, UNSUPPORTED_MESSAGE),
(USES_IF_STMTS, XFAIL, UNSUPPORTED_MESSAGE),
(USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE),
Expand All @@ -145,6 +145,8 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
(USES_TUPLE_ARGS, XFAIL, UNSUPPORTED_MESSAGE),
(USES_TUPLE_RETURNS, XFAIL, UNSUPPORTED_MESSAGE),
(USES_ZERO_DIMENSIONAL_FIELDS, XFAIL, UNSUPPORTED_MESSAGE),
(USES_INDEX_BUILTIN, XFAIL, UNSUPPORTED_MESSAGE),
(STARTS_FROM_GTIR_PROGRAM, SKIP, UNSUPPORTED_MESSAGE),
]
GTIR_DACE_SKIP_TEST_LIST = [
(ALL, SKIP, UNSUPPORTED_MESSAGE),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,22 @@
import pytest

import gt4py.next as gtx
from gt4py.next.iterator.builtins import as_fieldop, cartesian_domain, deref, named_range
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.builtins import (
as_fieldop,
cartesian_domain,
deref,
index,
named_range,
shift,
)
from gt4py.next.iterator.runtime import fendef, fundef, set_at

from next_tests.unit_tests.conftest import program_processor, run_processor


I = gtx.Dimension("I")
Ioff = gtx.FieldOffset("Ioff", source=I, target=(I,))


@fundef
Expand Down Expand Up @@ -44,3 +53,48 @@ def test_prog(program_processor):
run_processor(copy_program, program_processor, inp, out, isize, offset_provider={})
if validate:
assert np.allclose(inp.asnumpy(), out.asnumpy())


@fendef
def index_program_simple(out, size):
set_at(
as_fieldop(lambda i: deref(i), cartesian_domain(named_range(I, 0, size)))(index(I)),
cartesian_domain(named_range(I, 0, size)),
out,
)


@pytest.mark.starts_from_gtir_program
@pytest.mark.uses_index_builtin
def test_index_builtin(program_processor):
program_processor, validate = program_processor

isize = 10
out = gtx.as_field([I], np.zeros(shape=(isize,)), dtype=getattr(np, itir.INTEGER_INDEX_BUILTIN))

run_processor(index_program_simple, program_processor, out, isize, offset_provider={})
if validate:
assert np.allclose(np.arange(10), out.asnumpy())


@fendef
def index_program_shift(out, size):
set_at(
as_fieldop(
lambda i: deref(i) + deref(shift(Ioff, 1)(i)), cartesian_domain(named_range(I, 0, size))
)(index(I)),
cartesian_domain(named_range(I, 0, size)),
out,
)


@pytest.mark.uses_index_builtin
def test_index_builtin_shift(program_processor):
program_processor, validate = program_processor

isize = 10
out = gtx.as_field([I], np.zeros(shape=(isize,)), dtype=getattr(np, itir.INTEGER_INDEX_BUILTIN))

run_processor(index_program_shift, program_processor, out, isize, offset_provider={"Ioff": I})
if validate:
assert np.allclose(np.arange(10) + np.arange(1, 11), out.asnumpy())
Loading