Skip to content

Commit

Permalink
Merge branch 'main' into optimize_program
Browse files Browse the repository at this point in the history
  • Loading branch information
SF-N authored Nov 4, 2024
2 parents f3dd7ab + 827d40c commit e0db9fd
Show file tree
Hide file tree
Showing 10 changed files with 296 additions and 79 deletions.
30 changes: 25 additions & 5 deletions src/gt4py/next/ffront/foast_to_gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,31 @@ def visit_FieldOperator(
def visit_ScanOperator(
    self, node: foast.ScanOperator, **kwargs: Any
) -> itir.FunctionDefinition:
    """Lower a FOAST scan operator to an ITIR function definition.

    The result is a function whose body is
    `as_fieldop(scan(definition, forward, init))(*stencil_args)`. It takes every
    parameter of the lowered definition except the first one
    (presumably the scan carry -- confirm against the scan operator signature).
    """
    # note: we don't need the axis here as this is handled by the program
    #  decorator
    assert isinstance(node.type, ts_ffront.ScanOperatorType)

    # We are lowering node.forward and node.init to iterators, but here we expect values -> `deref`.
    # In iterator IR we didn't properly specify if this is legal,
    # however after lift-inlining the expressions are transformed back to literals.
    forward = self.visit(node.forward, **kwargs)
    init = self.visit(node.init, **kwargs)

    # lower definition function
    func_definition: itir.FunctionDefinition = self.visit(node.definition, **kwargs)
    new_body = func_definition.expr

    stencil_args: list[itir.Expr] = []
    assert not node.type.definition.pos_only_args and not node.type.definition.kw_only_args
    for param in func_definition.params[1:]:
        # Inside the stencil every non-first parameter is rebound to its dereferenced value;
        # outside, the same name is passed on as a field argument.
        new_body = im.let(param.id, im.deref(param.id))(new_body)
        stencil_args.append(im.ref(param.id))

    definition = itir.Lambda(params=func_definition.params, expr=new_body)

    body = im.as_fieldop(im.call("scan")(definition, forward, init))(*stencil_args)

    return itir.FunctionDefinition(id=node.id, params=definition.params[1:], expr=body)

def visit_Stmt(self, node: foast.Stmt, **kwargs: Any) -> Never:
raise AssertionError("Statements must always be visited in the context of a function.")
Expand Down Expand Up @@ -324,10 +348,6 @@ def visit_Call(self, node: foast.Call, **kwargs: Any) -> itir.Expr:
*lowered_args, *lowered_kwargs.values()
)

# scan operators return an iterator of tuples, transform into tuples of iterator again
if isinstance(node.func.type, ts_ffront.ScanOperatorType):
raise NotImplementedError("TODO")

return result

raise AssertionError(
Expand Down
168 changes: 130 additions & 38 deletions src/gt4py/next/iterator/embedded.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
)
from gt4py.next.ffront import fbuiltins
from gt4py.next.iterator import builtins, runtime
from gt4py.next.iterator.type_system import type_specifications as itir_ts
from gt4py.next.otf import arguments
from gt4py.next.type_system import type_specifications as ts, type_translation

Expand Down Expand Up @@ -186,6 +187,12 @@ def mapped_index(
NamedFieldIndices: TypeAlias = Mapping[Tag, FieldIndex | SparsePositionEntry]


# Magic local dimension for the result of a `make_const_list`.
# A clean implementation will probably involve tagging the `make_const_list`
# with the neighborhood it is meant to be used with.
_CONST_DIM = common.Dimension(value="_CONST_DIM", kind=common.DimensionKind.LOCAL)


@runtime_checkable
class ItIterator(Protocol):
"""
Expand Down Expand Up @@ -227,6 +234,12 @@ class MutableLocatedField(LocatedField, Protocol):
def field_setitem(self, indices: NamedFieldIndices, value: Any) -> None: ...


def _numpy_structured_value_to_tuples(value: Any) -> Any:
    """Recursively convert a structured value into nested tuples.

    If the element dtype reported by `_elem_dtype` has named fields, the value is
    replaced by a tuple of its recursively converted components; otherwise it is
    returned unchanged.
    """
    field_names = _elem_dtype(value).names
    if field_names is None:
        return value
    return tuple(_numpy_structured_value_to_tuples(component) for component in value)


class Column(np.lib.mixins.NDArrayOperatorsMixin):
"""Represents a column when executed in column mode (`column_axis != None`).
Expand All @@ -247,6 +260,10 @@ def dtype(self) -> np.dtype:
# not directly dtype of `self.data` as that might be a structured type containing `None`
return _elem_dtype(self.data[self.kstart])

def __gt_type__(self) -> ts.TypeSpec:
    """Return the type spec of a column element, with structured values viewed as tuples."""
    # NOTE(review): the sample element is read at `self.data[self.kstart]`, while
    # `__getitem__` indexes `self.data[i - self.kstart]` -- confirm this index is intended.
    sample = self.data[self.kstart]
    as_tuples = _numpy_structured_value_to_tuples(sample)
    return type_translation.from_value(as_tuples)

def __getitem__(self, i: int) -> Any:
result = self.data[i - self.kstart]
# numpy type
Expand Down Expand Up @@ -576,17 +593,20 @@ def execute_shift(
for i, p in reversed(list(enumerate(new_entry))):
# first shift applies to the last sparse dimensions of that axis type
if p is None:
offset_implementation = offset_provider[tag]
assert isinstance(offset_implementation, common.Connectivity)
cur_index = pos[offset_implementation.origin_axis.value]
assert common.is_int_index(cur_index)
if offset_implementation.mapped_index(cur_index, index) in [
None,
common._DEFAULT_SKIP_VALUE,
]:
return None

new_entry[i] = index
if tag == _CONST_DIM.value:
new_entry[i] = 0
else:
offset_implementation = offset_provider[tag]
assert isinstance(offset_implementation, common.Connectivity)
cur_index = pos[offset_implementation.origin_axis.value]
assert common.is_int_index(cur_index)
if offset_implementation.mapped_index(cur_index, index) in [
None,
common._DEFAULT_SKIP_VALUE,
]:
return None

new_entry[i] = index
break
# the assertions above confirm pos is incomplete casting here to avoid duplicating work in a type guard
return cast(IncompletePosition, pos) | {tag: new_entry}
Expand Down Expand Up @@ -920,9 +940,9 @@ def deref(self) -> Any:
return _make_tuple(self.field, position, column_axis=self.column_axis)


def _get_sparse_dimensions(axes: Sequence[common.Dimension]) -> list[common.Dimension]:
    """Return the local (sparse) dimensions among `axes`.

    Order and repetitions of `axes` are preserved. Entries that are not
    `common.Dimension` instances are ignored.
    """
    return [
        axis
        for axis in axes
        if isinstance(axis, common.Dimension) and axis.kind == common.DimensionKind.LOCAL
    ]
Expand All @@ -945,7 +965,7 @@ def make_in_iterator(
new_pos: Position = pos.copy()
for sparse_dim in set(sparse_dimensions):
init = [None] * sparse_dimensions.count(sparse_dim)
new_pos[sparse_dim] = init # type: ignore[assignment] # looks like mypy is confused
new_pos[sparse_dim.value] = init # type: ignore[assignment] # looks like mypy is confused
if column_dimension is not None:
column_range = embedded_context.closure_column_range.get().unit_range
# if we deal with column stencil the column position is just an offset by which the whole column needs to be shifted
Expand All @@ -956,7 +976,7 @@ def make_in_iterator(
)
if len(sparse_dimensions) >= 1:
if len(sparse_dimensions) == 1:
return SparseListIterator(it, sparse_dimensions[0])
return SparseListIterator(it, sparse_dimensions[0].value)
else:
raise NotImplementedError(
f"More than one local dimension is currently not supported, got {sparse_dimensions}."
Expand Down Expand Up @@ -1004,7 +1024,17 @@ def field_getitem(self, named_indices: NamedFieldIndices) -> Any:

def field_setitem(self, named_indices: NamedFieldIndices, value: Any):
    """Store `value` in the wrapped field at `named_indices`.

    List values are spread over an additional index:
    - a `_List` writes element `i` at `{**named_indices, value.offset.value: i}`;
    - a `_ConstList` writes its single value at `{**named_indices, _CONST_DIM.value: 0}`;
    - any other value is written at `named_indices` directly.

    Raises:
        RuntimeError: if the wrapped field is not a `common.MutableField`.
    """
    if not isinstance(self._ndarrayfield, common.MutableField):
        raise RuntimeError("Assigment into a non-mutable Field is not allowed.")

    if isinstance(value, _List):
        # Iterate the stored tuple explicitly instead of relying on `__getitem__`-based iteration.
        for i, v in enumerate(value.values):
            self._ndarrayfield[
                self._translate_named_indices({**named_indices, value.offset.value: i})  # type: ignore[dict-item]
            ] = v
    elif isinstance(value, _ConstList):
        self._ndarrayfield[
            self._translate_named_indices({**named_indices, _CONST_DIM.value: 0})
        ] = value.value
    else:
        self._ndarrayfield[self._translate_named_indices(named_indices)] = value

Expand Down Expand Up @@ -1383,7 +1413,23 @@ def impl(it: ItIterator) -> ItIterator:
DT = TypeVar("DT")


class _List(tuple, Generic[DT]): ...
@dataclasses.dataclass(frozen=True)
class _List(Generic[DT]):
    """Immutable list value: a tuple of `values` plus the `offset` it belongs to.

    The offset's string value names the local dimension reported by `__gt_type__`.
    """

    values: tuple[DT, ...]
    offset: runtime.Offset

    def __getitem__(self, i: int):
        """Return the `i`-th stored value."""
        return self.values[i]

    def __gt_type__(self) -> itir_ts.ListType:
        """Return the list type, using the first value for the element type."""
        tag = self.offset.value
        assert isinstance(tag, str)
        elem_type = type_translation.from_value(self.values[0])
        assert isinstance(elem_type, ts.DataType)
        local_dim = common.Dimension(value=tag, kind=common.DimensionKind.LOCAL)
        return itir_ts.ListType(element_type=elem_type, offset_type=local_dim)


@dataclasses.dataclass(frozen=True)
Expand All @@ -1393,6 +1439,14 @@ class _ConstList(Generic[DT]):
def __getitem__(self, _):
    """Return the constant value; the index argument is ignored."""
    constant = self.value
    return constant

def __gt_type__(self) -> itir_ts.ListType:
    """Return the list type of this constant list, using `_CONST_DIM` as its offset type."""
    elem_type = type_translation.from_value(self.value)
    assert isinstance(elem_type, ts.DataType)
    return itir_ts.ListType(element_type=elem_type, offset_type=_CONST_DIM)


@builtins.neighbors.register(EMBEDDED)
def neighbors(offset: runtime.Offset, it: ItIterator) -> _List:
Expand All @@ -1403,9 +1457,12 @@ def neighbors(offset: runtime.Offset, it: ItIterator) -> _List:
connectivity = offset_provider[offset_str]
assert isinstance(connectivity, common.Connectivity)
return _List(
shifted.deref()
for i in range(connectivity.max_neighbors)
if (shifted := it.shift(offset_str, i)).can_deref()
values=tuple(
shifted.deref()
for i in range(connectivity.max_neighbors)
if (shifted := it.shift(offset_str, i)).can_deref()
),
offset=offset,
)


Expand All @@ -1414,10 +1471,23 @@ def list_get(i, lst: _List[Optional[DT]]) -> Optional[DT]:
return lst[i]


def _get_offset(*lists: _List | _ConstList) -> Optional[runtime.Offset]:
    """Return the single offset shared by `lists`, or `None` if none of them has one.

    Arguments without an `offset` attribute are ignored.

    Raises:
        AssertionError: if the arguments carry more than one distinct offset.
    """
    offsets = {lst.offset for lst in lists if hasattr(lst, "offset")}
    if not offsets:
        return None
    if len(offsets) == 1:
        return offsets.pop()
    raise AssertionError("All lists must have the same offset.")


@builtins.map_.register(EMBEDDED)
def map_(op):
    """Return a function applying `op` element-wise to its list arguments.

    If none of the arguments carries an offset (see `_get_offset`), `op` is
    applied once to the `.value` of each argument and the result is a
    `_ConstList`. Otherwise the result is a `_List` with the shared offset.
    """

    def impl_(*lists):
        offset = _get_offset(*lists)
        if offset is None:
            return _ConstList(value=op(*[lst.value for lst in lists]))
        return _List(values=tuple(op(*elems) for elems in zip(*lists)), offset=offset)

    return impl_

Expand All @@ -1438,7 +1508,7 @@ def sten(*lists):
break
# we can check a single argument for length,
# because all arguments share the same pattern
n = len(lst)
n = len(lst.values)
res = init
for i in range(n):
res = fun(res, *(lst[i] for lst in lists))
Expand All @@ -1454,14 +1524,23 @@ class SparseListIterator:
offsets: Sequence[OffsetPart] = dataclasses.field(default_factory=list, kw_only=True)

def deref(self) -> Any:
    """Dereference the sparse dimension into a list value.

    For the `_CONST_DIM` list offset, the result is a `_ConstList` holding the
    value at sparse index 0. Otherwise the connectivity registered under
    `self.list_offset` is looked up, and the result is a `_List` of the values at
    each neighbor index in `range(max_neighbors)` whose shifted iterator can be
    dereferenced.
    """
    if self.list_offset == _CONST_DIM.value:
        return _ConstList(
            value=self.it.shift(*self.offsets, SparseTag(self.list_offset), 0).deref()
        )
    offset_provider = embedded_context.offset_provider.get()
    assert offset_provider is not None
    connectivity = offset_provider[self.list_offset]
    assert isinstance(connectivity, common.Connectivity)
    return _List(
        values=tuple(
            shifted.deref()
            for i in range(connectivity.max_neighbors)
            if (
                shifted := self.it.shift(*self.offsets, SparseTag(self.list_offset), i)
            ).can_deref()
        ),
        offset=runtime.Offset(value=self.list_offset),
    )

def can_deref(self) -> bool:
Expand Down Expand Up @@ -1654,16 +1733,6 @@ def _extract_column_range(domain) -> common.NamedRange | eve.NothingType:
return eve.NOTHING


def _structured_dtype_to_typespec(structured_dtype: np.dtype) -> ts.ScalarType | ts.TupleType:
    """Translate a numpy dtype into a type spec.

    A dtype with named fields becomes a `ts.TupleType` of its recursively
    translated field dtypes. Any other dtype is translated via
    `type_translation.from_dtype`.
    """
    field_names = structured_dtype.names
    if field_names is not None:
        member_types = [
            _structured_dtype_to_typespec(structured_dtype[name]) for name in field_names
        ]
        return ts.TupleType(types=member_types)
    return type_translation.from_dtype(core_defs.dtype(structured_dtype))


def _get_output_type(
fun: Callable,
domain_: runtime.CartesianDomain | runtime.UnstructuredDomain,
Expand All @@ -1682,16 +1751,39 @@ def _get_output_type(
with embedded_context.new_context(closure_column_range=col_range) as ctx:
single_pos_result = ctx.run(_compute_at_position, fun, args, pos_in_domain, col_dim)
assert single_pos_result is not _UNDEFINED, "Stencil contains an Out-Of-Bound access."
dtype = _elem_dtype(single_pos_result)
return _structured_dtype_to_typespec(dtype)
return type_translation.from_value(single_pos_result)


def _fieldspec_list_to_value(
    domain: common.Domain, type_: ts.TypeSpec
) -> tuple[common.Domain, ts.TypeSpec]:
    """Translate the list element type into the domain.

    For a non-list `type_` the inputs are returned unchanged. For a list type the
    returned domain has one extra range appended for the list's local dimension,
    and the returned type is the list's element type. The extent of that range is
    1 for `_CONST_DIM`, and `max_neighbors` of the connectivity registered under
    the dimension's value otherwise.
    """
    if not isinstance(type_, itir_ts.ListType):
        return domain, type_

    local_dim: common.Dimension
    if type_.offset_type == _CONST_DIM:
        local_dim, extent = _CONST_DIM, 1
    else:
        offset_provider = embedded_context.offset_provider.get()
        # Same precondition as in `SparseListIterator.deref`: fail with a clear assertion
        # instead of a subscript error when no offset provider is set.
        assert offset_provider is not None
        offset_type = type_.offset_type
        assert isinstance(offset_type, common.Dimension)
        connectivity = offset_provider[offset_type.value]
        assert isinstance(connectivity, common.Connectivity)
        local_dim, extent = offset_type, connectivity.max_neighbors

    extended_domain = domain.insert(len(domain), common.named_range((local_dim, extent)))
    return extended_domain, type_.element_type


@builtins.as_fieldop.register(EMBEDDED)
def as_fieldop(fun: Callable, domain: runtime.CartesianDomain | runtime.UnstructuredDomain):
def impl(*args):
xp = field_utils.get_array_ns(*args)
type_ = _get_output_type(fun, domain, [promote_scalars(arg) for arg in args])
out = field_utils.field_from_typespec(type_, common.domain(domain), xp)

new_domain, type_ = _fieldspec_list_to_value(common.domain(domain), type_)
out = field_utils.field_from_typespec(type_, new_domain, xp)

# TODO(havogt): after updating all tests to use the new program,
# we should get rid of closure and move the implementation to this function
Expand Down
7 changes: 2 additions & 5 deletions src/gt4py/next/iterator/transforms/pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,10 +223,7 @@ def apply_fieldview_transforms(
) -> itir.Program:
ir = inline_fundefs.InlineFundefs().visit(ir)
ir = inline_fundefs.prune_unreferenced_fundefs(ir)
ir = InlineLambdas.apply(ir, opcount_preserving=True)
ir = infer_domain.infer_program(
ir,
offset_provider=offset_provider,
)
ir = InlineLambdas.apply(ir, opcount_preserving=True, force_inline_lambda_args=True)
ir = CollapseTuple.apply(ir, offset_provider=offset_provider) # type: ignore[assignment] # type is still `itir.Program`
ir = infer_domain.infer_program(ir, offset_provider=offset_provider)
return ir
5 changes: 4 additions & 1 deletion src/gt4py/next/iterator/type_system/type_specifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# SPDX-License-Identifier: BSD-3-Clause

import dataclasses
from typing import Literal
from typing import Literal, Optional

from gt4py.next import common
from gt4py.next.type_system import type_specifications as ts
Expand All @@ -31,6 +31,9 @@ class OffsetLiteralType(ts.TypeSpec):
@dataclasses.dataclass(frozen=True)
class ListType(ts.DataType):
    """Type of a list value with elements of type `element_type`."""

    element_type: ts.DataType
    # TODO(havogt): the `offset_type` is not yet used in type_inference,
    # it is meant to describe the neighborhood (via the local dimension)
    offset_type: Optional[common.Dimension] = None


@dataclasses.dataclass(frozen=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -366,9 +366,6 @@ def stencil(

@pytest.mark.parametrize("backend", ALL_BACKENDS)
def test_variable_offsets(backend):
if backend == "dace:cpu":
pytest.skip("Internal compiler error in GitHub action container")

@gtscript.stencil(backend=backend)
def stencil_ij(
in_field: gtscript.Field[np.float_],
Expand All @@ -391,9 +388,6 @@ def stencil_ijk(

@pytest.mark.parametrize("backend", ALL_BACKENDS)
def test_variable_offsets_and_while_loop(backend):
if backend == "dace:cpu":
pytest.skip("Internal compiler error in GitHub action container")

@gtscript.stencil(backend=backend)
def stencil(
pe1: gtscript.Field[np.float_],
Expand Down
Loading

0 comments on commit e0db9fd

Please sign in to comment.