Skip to content

Commit

Permalink
refactor[next]: itir embedded: cleaner closure run (#1521)
Browse files Browse the repository at this point in the history
This enables to run the closure also outside of the fendef context, e.g.
from field view embedded.
  • Loading branch information
havogt authored Apr 15, 2024
1 parent 705530c commit d5d59d2
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 74 deletions.
12 changes: 5 additions & 7 deletions src/gt4py/next/embedded/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,8 @@
#: closure execution context.
closure_column_range: cvars.ContextVar[common.NamedRange] = cvars.ContextVar("column_range")

_undefined_offset_provider: common.OffsetProvider = {}

#: Offset provider dict in the current embedded execution context.
offset_provider: cvars.ContextVar[common.OffsetProvider] = cvars.ContextVar(
"offset_provider", default=_undefined_offset_provider
)
offset_provider: cvars.ContextVar[common.OffsetProvider] = cvars.ContextVar("offset_provider")


@contextlib.contextmanager
Expand All @@ -41,6 +37,8 @@ def new_context(
closure_column_range: common.NamedRange | eve.NothingType = eve.NOTHING,
offset_provider: common.OffsetProvider | eve.NothingType = eve.NOTHING,
) -> Generator[cvars.Context, None, None]:
"""Create a new context, updating the provided values."""

import gt4py.next.embedded.context as this_module

updates: list[tuple[cvars.ContextVar[Any], Any]] = []
Expand All @@ -61,5 +59,5 @@ def ctx_updater(*args: tuple[cvars.ContextVar[Any], Any]) -> None:
yield ctx


def within_context() -> bool:
return offset_provider.get() is not _undefined_offset_provider
def within_valid_context() -> bool:
return offset_provider.get(eve.NOTHING) is not eve.NOTHING
2 changes: 1 addition & 1 deletion src/gt4py/next/embedded/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def field_operator_call(op: EmbeddedOperator[_R, _P], args: Any, kwargs: Any) ->
if "out" in kwargs:
# called from program or direct field_operator as program
new_context_kwargs = {}
if embedded_context.within_context():
if embedded_context.within_valid_context():
# called from program
assert "offset_provider" not in kwargs
else:
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/ffront/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,7 @@ def as_program(
return self._program_cache[hash_]

def __call__(self, *args, **kwargs) -> None:
if not next_embedded.context.within_context() and self.backend is not None:
if not next_embedded.context.within_valid_context() and self.backend is not None:
# non embedded execution
if "offset_provider" not in kwargs:
raise errors.MissingArgumentError(None, "offset_provider", True)
Expand Down
124 changes: 63 additions & 61 deletions src/gt4py/next/iterator/embedded.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from __future__ import annotations

import abc
import contextvars as cvars
import copy
import dataclasses
import itertools
Expand All @@ -28,6 +27,7 @@
import numpy as np
import numpy.typing as npt

from gt4py import eve
from gt4py._core import definitions as core_defs
from gt4py.eve import extended_typing as xtyping
from gt4py.eve.extended_typing import (
Expand All @@ -52,8 +52,8 @@
overload,
runtime_checkable,
)
from gt4py.next import common, embedded as next_embedded
from gt4py.next.embedded import exceptions as embedded_exceptions
from gt4py.next import common
from gt4py.next.embedded import context as embedded_context, exceptions as embedded_exceptions
from gt4py.next.ffront import fbuiltins
from gt4py.next.iterator import builtins, runtime

Expand Down Expand Up @@ -191,12 +191,6 @@ class MutableLocatedField(LocatedField, Protocol):
def field_setitem(self, indices: NamedFieldIndices, value: Any) -> None: ...


#: Column range used in column mode (`column_axis != None`) in the current closure execution context.
column_range_cvar: cvars.ContextVar[common.NamedRange] = next_embedded.context.closure_column_range
#: Offset provider dict in the current closure execution context.
offset_provider_cvar: cvars.ContextVar[OffsetProvider] = next_embedded.context.offset_provider


class Column(np.lib.mixins.NDArrayOperatorsMixin):
"""Represents a column when executed in column mode (`column_axis != None`).
Expand All @@ -207,7 +201,7 @@ class Column(np.lib.mixins.NDArrayOperatorsMixin):
def __init__(self, kstart: int, data: np.ndarray | Scalar) -> None:
self.kstart = kstart
assert isinstance(data, (np.ndarray, Scalar)) # type: ignore # mypy bug #11673
column_range: common.NamedRange = column_range_cvar.get()
column_range: common.NamedRange = embedded_context.closure_column_range.get()
self.data = (
data if isinstance(data, np.ndarray) else np.full(len(column_range.unit_range), data)
)
Expand Down Expand Up @@ -751,7 +745,7 @@ def _make_tuple(
except embedded_exceptions.IndexOutOfBounds:
return _UNDEFINED
else:
column_range = column_range_cvar.get().unit_range
column_range = embedded_context.closure_column_range.get().unit_range
assert column_range is not None

col: list[
Expand Down Expand Up @@ -796,7 +790,7 @@ class MDIterator:

def shift(self, *offsets: OffsetPart) -> MDIterator:
complete_offsets = group_offsets(*offsets)
offset_provider = offset_provider_cvar.get()
offset_provider = embedded_context.offset_provider.get()
assert offset_provider is not None
return MDIterator(
self.field,
Expand All @@ -821,8 +815,8 @@ def deref(self) -> Any:
if not all(axis.value in shifted_pos.keys() for axis in axes if axis is not None):
raise IndexError("Iterator position doesn't point to valid location for its field.")
slice_column = dict[Tag, range]()
column_range = column_range_cvar.get()
if self.column_axis is not None:
column_range = embedded_context.closure_column_range.get()
assert column_range is not None
k_pos = shifted_pos.pop(self.column_axis)
assert isinstance(k_pos, int)
Expand Down Expand Up @@ -862,7 +856,7 @@ def make_in_iterator(
init = [None] * sparse_dimensions.count(sparse_dim)
new_pos[sparse_dim] = init # type: ignore[assignment] # looks like mypy is confused
if column_axis is not None:
column_range = column_range_cvar.get().unit_range
column_range = embedded_context.closure_column_range.get().unit_range
# if we deal with column stencil the column position is just an offset by which the whole column needs to be shifted
assert column_range is not None
new_pos[column_axis] = column_range.start
Expand Down Expand Up @@ -1303,7 +1297,7 @@ def __getitem__(self, _):
def neighbors(offset: runtime.Offset, it: ItIterator) -> _List:
offset_str = offset.value if isinstance(offset, runtime.Offset) else offset
assert isinstance(offset_str, str)
offset_provider = offset_provider_cvar.get()
offset_provider = embedded_context.offset_provider.get()
assert offset_provider is not None
connectivity = offset_provider[offset_str]
assert isinstance(connectivity, common.Connectivity)
Expand Down Expand Up @@ -1359,7 +1353,7 @@ class SparseListIterator:
offsets: Sequence[OffsetPart] = dataclasses.field(default_factory=list, kw_only=True)

def deref(self) -> Any:
offset_provider = offset_provider_cvar.get()
offset_provider = embedded_context.offset_provider.get()
assert offset_provider is not None
connectivity = offset_provider[self.list_offset]
assert isinstance(connectivity, common.Connectivity)
Expand All @@ -1376,12 +1370,6 @@ def shift(self, *offsets: OffsetPart) -> SparseListIterator:
return SparseListIterator(self.it, self.list_offset, offsets=[*offsets, *self.offsets])


@dataclasses.dataclass(frozen=True)
class ColumnDescriptor:
axis: str
col_range: range # TODO(havogt) introduce range type that doesn't have step


@dataclasses.dataclass(frozen=True)
class ScanArgIterator:
wrapped_iter: ItIterator
Expand Down Expand Up @@ -1480,7 +1468,7 @@ def _column_dtype(elem: Any) -> np.dtype:
@builtins.scan.register(EMBEDDED)
def scan(scan_pass, is_forward: bool, init):
def impl(*iters: ItIterator):
column_range = column_range_cvar.get().unit_range
column_range = embedded_context.closure_column_range.get().unit_range
if column_range is None:
raise RuntimeError("Column range is not defined, cannot scan.")

Expand Down Expand Up @@ -1508,64 +1496,78 @@ def _validate_domain(domain: Domain, offset_provider: OffsetProvider) -> None:
)


def fendef_embedded(fun: Callable[..., None], *args: Any, **kwargs: Any):
if "offset_provider" not in kwargs:
raise RuntimeError("'offset_provider' not provided.")

offset_provider = kwargs["offset_provider"]

@runtime.closure.register(EMBEDDED)
def closure(
domain_: Domain,
sten: Callable[..., Any],
out, #: MutableLocatedField,
ins: list[common.Field],
) -> None:
_validate_domain(domain_, kwargs["offset_provider"])
domain: dict[Tag, range] = _dimension_to_tag(domain_)
if not (isinstance(out, common.Field) or is_tuple_of_field(out)):
raise TypeError("'Out' needs to be a located field.")

column_range = None
column: Optional[ColumnDescriptor] = None
if kwargs.get("column_axis") and kwargs["column_axis"].value in domain:
column_axis = kwargs["column_axis"]
column = ColumnDescriptor(column_axis.value, domain[column_axis.value])
del domain[column_axis.value]

@runtime.closure.register(EMBEDDED)
def closure(
domain_: Domain,
sten: Callable[..., Any],
out, #: MutableLocatedField,
ins: list[common.Field],
) -> None:
assert embedded_context.within_valid_context()
offset_provider = embedded_context.offset_provider.get()
_validate_domain(domain_, offset_provider)
domain: dict[Tag, range] = _dimension_to_tag(domain_)
if not (isinstance(out, common.Field) or is_tuple_of_field(out)):
raise TypeError("'Out' needs to be a located field.")

column_range: common.NamedRange | eve.NothingType = eve.NOTHING
if (col_range_placeholder := embedded_context.closure_column_range.get(None)) is not None:
assert (
col_range_placeholder.unit_range.is_empty()
) # check it's just the placeholder with empty range
column_axis = col_range_placeholder.dim
if column_axis is not None and column_axis.value in domain:
column_range = common.NamedRange(
column_axis, common.UnitRange(column.col_range.start, column.col_range.stop)
column_axis,
common.UnitRange(domain[column_axis.value].start, domain[column_axis.value].stop),
)
del domain[column_axis.value]

out = as_tuple_field(out) if is_tuple_of_field(out) else _wrap_field(out)
out = as_tuple_field(out) if is_tuple_of_field(out) else _wrap_field(out)

def _closure_runner():
# Set context variables before executing the closure
column_range_cvar.set(column_range)
offset_provider_cvar.set(offset_provider)
with embedded_context.new_context(closure_column_range=column_range) as ctx:

def _iterate():
for pos in _domain_iterator(domain):
promoted_ins = [promote_scalars(inp) for inp in ins]
ins_iters = list(
make_in_iterator(inp, pos, column_axis=column.axis if column else None)
make_in_iterator(
inp,
pos,
column_axis=column_range.dim.value
if column_range is not eve.NOTHING
else None,
)
for inp in promoted_ins
)
res = sten(*ins_iters)

if column is None:
if column_range is eve.NOTHING:
assert _is_concrete_position(pos)
out.field_setitem(pos, res)
else:
col_pos = pos.copy()
for k in column.col_range:
col_pos[column.axis] = k
for k in column_range.unit_range:
col_pos[column_range.dim.value] = k
assert _is_concrete_position(col_pos)
out.field_setitem(col_pos, res[k])

ctx = cvars.copy_context()
ctx.run(_closure_runner)
ctx.run(_iterate)


def fendef_embedded(fun: Callable[..., None], *args: Any, **kwargs: Any):
if "offset_provider" not in kwargs:
raise RuntimeError("'offset_provider' not provided.")

context_vars = {"offset_provider": kwargs["offset_provider"]}
if "column_axis" in kwargs:
context_vars["closure_column_range"] = common.NamedRange(
kwargs["column_axis"],
common.UnitRange(0, 0), # empty: indicates column operation, will update later
)

fun(*args)
with embedded_context.new_context(**context_vars) as ctx:
ctx.run(fun, *args)


runtime.fendef_embedded = fendef_embedded
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@
# SPDX-License-Identifier: GPL-3.0-or-later

import contextvars as cvars
import threading
from typing import Any, Callable, Optional

import numpy as np
import pytest

from gt4py.next import common
from gt4py.next.embedded import context as embedded_context
from gt4py.next.iterator import embedded


Expand All @@ -30,8 +30,8 @@ def _run_within_context(
offset_provider: Optional[embedded.OffsetProvider] = None,
) -> Any:
def wrapped_func():
embedded.column_range_cvar.set(column_range)
embedded.offset_provider_cvar.set(offset_provider)
embedded_context.closure_column_range.set(column_range)
embedded_context.offset_provider.set(offset_provider)
func()

cvars.copy_context().run(wrapped_func)
Expand Down Expand Up @@ -59,7 +59,7 @@ def test_func(data_a: int, data_b: int):
assert res.kstart == 1

# Setting an invalid column_range here shouldn't affect other contexts
embedded.column_range_cvar.set(range(2, 999))
embedded_context.closure_column_range.set(range(2, 999))
_run_within_context(
lambda: test_func(2, 3),
column_range=common.NamedRange(
Expand Down

0 comments on commit d5d59d2

Please sign in to comment.