Skip to content

Commit

Permalink
feat[next]: Embedded field scan (#1365)
Browse files Browse the repository at this point in the history
Adds the scalar scan operator for embedded field view.
  • Loading branch information
havogt authored Dec 12, 2023
1 parent 8e64458 commit a14ad09
Show file tree
Hide file tree
Showing 15 changed files with 372 additions and 98 deletions.
2 changes: 1 addition & 1 deletion .gitpod.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ image:
tasks:
- name: Setup venv and dev tools
init: |
ln -s /workspace/gt4py/.gitpod/.vscode /workspace/gt4py/.vscode
ln -sfn /workspace/gt4py/.gitpod/.vscode /workspace/gt4py/.vscode
python -m venv .venv
source .venv/bin/activate
pip install --upgrade pip setuptools wheel
Expand Down
17 changes: 17 additions & 0 deletions src/gt4py/next/embedded/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@

from __future__ import annotations

import functools
import itertools
import operator

from gt4py.eve.extended_typing import Any, Optional, Sequence, cast
from gt4py.next import common
from gt4py.next.embedded import exceptions as embedded_exceptions
Expand Down Expand Up @@ -90,6 +94,19 @@ def _absolute_sub_domain(
return common.Domain(*named_ranges)


def intersect_domains(*domains: common.Domain) -> common.Domain:
return functools.reduce(
operator.and_,
domains,
common.Domain(dims=tuple(), ranges=tuple()),
)


def iterate_domain(domain: common.Domain):
for i in itertools.product(*[list(r) for r in domain.ranges]):
yield tuple(zip(domain.dims, i))


def _expand_ellipsis(
indices: common.RelativeIndexSequence, target_size: int
) -> tuple[common.IntIndex | slice, ...]:
Expand Down
4 changes: 2 additions & 2 deletions src/gt4py/next/embedded/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

#: Column range used in column mode (`column_axis != None`) in the current embedded iterator
#: closure execution context.
closure_column_range: cvars.ContextVar[range] = cvars.ContextVar("column_range")
closure_column_range: cvars.ContextVar[common.NamedRange] = cvars.ContextVar("column_range")

_undefined_offset_provider: common.OffsetProvider = {}

Expand All @@ -37,7 +37,7 @@
@contextlib.contextmanager
def new_context(
*,
closure_column_range: range | eve.NothingType = eve.NOTHING,
closure_column_range: common.NamedRange | eve.NothingType = eve.NOTHING,
offset_provider: common.OffsetProvider | eve.NothingType = eve.NOTHING,
):
import gt4py.next.embedded.context as this_module
Expand Down
8 changes: 3 additions & 5 deletions src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import dataclasses
import functools
import operator
from collections.abc import Callable, Sequence
from types import ModuleType
from typing import ClassVar
Expand Down Expand Up @@ -49,11 +48,10 @@ def _builtin_op(*fields: common.Field | core_defs.Scalar) -> NdArrayField:
xp = first.__class__.array_ns
op = getattr(xp, array_builtin_name)

domain_intersection = functools.reduce(
operator.and_,
[f.domain for f in fields if common.is_field(f)],
common.Domain(dims=tuple(), ranges=tuple()),
domain_intersection = embedded_common.intersect_domains(
*[f.domain for f in fields if common.is_field(f)]
)

transformed: list[core_defs.NDArrayObject | core_defs.Scalar] = []
for f in fields:
if common.is_field(f):
Expand Down
168 changes: 168 additions & 0 deletions src/gt4py/next/embedded/operators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2023, ETH Zurich
# All rights reserved.
#
# This file is part of the GT4Py project and the GridTools framework.
# GT4Py is free software: you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the
# Free Software Foundation, either version 3 of the License, or any later
# version. See the LICENSE.txt file at the top-level directory of this
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later

import dataclasses
from typing import Any, Callable, Generic, ParamSpec, Sequence, TypeVar

from gt4py import eve
from gt4py._core import definitions as core_defs
from gt4py.next import common, constructors, utils
from gt4py.next.embedded import common as embedded_common, context as embedded_context


_P = ParamSpec("_P")
_R = TypeVar("_R")


@dataclasses.dataclass(frozen=True)
class EmbeddedOperator(Generic[_R, _P]):
fun: Callable[_P, _R]

def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
return self.fun(*args, **kwargs)


@dataclasses.dataclass(frozen=True)
class ScanOperator(EmbeddedOperator[_R, _P]):
forward: bool
init: core_defs.Scalar | tuple[core_defs.Scalar | tuple, ...]
axis: common.Dimension

def __call__(self, *args: common.Field | core_defs.Scalar, **kwargs: common.Field | core_defs.Scalar) -> common.Field: # type: ignore[override] # we cannot properly type annotate relative to self.fun
scan_range = embedded_context.closure_column_range.get()
assert self.axis == scan_range[0]
scan_axis = scan_range[0]
domain_intersection = _intersect_scan_args(*args, *kwargs.values())
non_scan_domain = common.Domain(*[nr for nr in domain_intersection if nr[0] != scan_axis])

out_domain = common.Domain(
*[scan_range if nr[0] == scan_axis else nr for nr in domain_intersection]
)
if scan_axis not in out_domain.dims:
# even if the scan dimension is not in the input, we can scan over it
out_domain = common.Domain(*out_domain, (scan_range))

res = _construct_scan_array(out_domain)(self.init)

def scan_loop(hpos):
acc = self.init
for k in scan_range[1] if self.forward else reversed(scan_range[1]):
pos = (*hpos, (scan_axis, k))
new_args = [_tuple_at(pos, arg) for arg in args]
new_kwargs = {k: _tuple_at(pos, v) for k, v in kwargs.items()}
acc = self.fun(acc, *new_args, **new_kwargs)
_tuple_assign_value(pos, res, acc)

if len(non_scan_domain) == 0:
# if we don't have any dimension orthogonal to scan_axis, we need to do one scan_loop
scan_loop(())
else:
for hpos in embedded_common.iterate_domain(non_scan_domain):
scan_loop(hpos)

return res


def field_operator_call(op: EmbeddedOperator, args: Any, kwargs: Any):
if "out" in kwargs:
# called from program or direct field_operator as program
offset_provider = kwargs.pop("offset_provider", None)

new_context_kwargs = {}
if embedded_context.within_context():
# called from program
assert offset_provider is None
else:
# field_operator as program
new_context_kwargs["offset_provider"] = offset_provider

out = kwargs.pop("out")
domain = kwargs.pop("domain", None)

flattened_out: tuple[common.Field, ...] = utils.flatten_nested_tuple((out,))
assert all(f.domain == flattened_out[0].domain for f in flattened_out)

out_domain = common.domain(domain) if domain is not None else flattened_out[0].domain

new_context_kwargs["closure_column_range"] = _get_vertical_range(out_domain)

with embedded_context.new_context(**new_context_kwargs) as ctx:
res = ctx.run(op, *args, **kwargs)
_tuple_assign_field(
out,
res,
domain=out_domain,
)
else:
# called from other field_operator
return op(*args, **kwargs)


def _get_vertical_range(domain: common.Domain) -> common.NamedRange | eve.NothingType:
vertical_dim_filtered = [nr for nr in domain if nr[0].kind == common.DimensionKind.VERTICAL]
assert len(vertical_dim_filtered) <= 1
return vertical_dim_filtered[0] if vertical_dim_filtered else eve.NOTHING


def _tuple_assign_field(
target: tuple[common.MutableField | tuple, ...] | common.MutableField,
source: tuple[common.Field | tuple, ...] | common.Field,
domain: common.Domain,
):
@utils.tree_map
def impl(target: common.MutableField, source: common.Field):
target[domain] = source[domain]

impl(target, source)


def _intersect_scan_args(
*args: core_defs.Scalar | common.Field | tuple[core_defs.Scalar | common.Field | tuple, ...]
) -> common.Domain:
return embedded_common.intersect_domains(
*[arg.domain for arg in utils.flatten_nested_tuple(args) if common.is_field(arg)]
)


def _construct_scan_array(domain: common.Domain):
@utils.tree_map
def impl(init: core_defs.Scalar) -> common.Field:
return constructors.empty(domain, dtype=type(init))

return impl


def _tuple_assign_value(
pos: Sequence[common.NamedIndex],
target: common.MutableField | tuple[common.MutableField | tuple, ...],
source: core_defs.Scalar | tuple[core_defs.Scalar | tuple, ...],
) -> None:
@utils.tree_map
def impl(target: common.MutableField, source: core_defs.Scalar):
target[pos] = source

impl(target, source)


def _tuple_at(
pos: Sequence[common.NamedIndex],
field: common.Field | core_defs.Scalar | tuple[common.Field | core_defs.Scalar | tuple, ...],
) -> core_defs.Scalar | tuple[core_defs.ScalarT | tuple, ...]:
@utils.tree_map
def impl(field: common.Field | core_defs.Scalar) -> core_defs.Scalar:
res = field[pos] if common.is_field(field) else field
assert core_defs.is_scalar_type(res)
return res

return impl(field) # type: ignore[return-value]
95 changes: 34 additions & 61 deletions src/gt4py/next/ffront/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,9 @@
from gt4py._core import definitions as core_defs
from gt4py.eve import utils as eve_utils
from gt4py.eve.extended_typing import Any, Optional
from gt4py.next import allocators as next_allocators, common, embedded as next_embedded
from gt4py.next import allocators as next_allocators, embedded as next_embedded
from gt4py.next.common import Dimension, DimensionKind, GridType
from gt4py.next.embedded import operators as embedded_operators
from gt4py.next.ffront import (
dialect_ast_enums,
field_operator_ast as foast,
Expand Down Expand Up @@ -550,6 +551,7 @@ class FieldOperator(GTCallable, Generic[OperatorNodeT]):
definition: Optional[types.FunctionType] = None
backend: Optional[ppi.ProgramExecutor] = DEFAULT_BACKEND
grid_type: Optional[GridType] = None
operator_attributes: Optional[dict[str, Any]] = None
_program_cache: dict = dataclasses.field(default_factory=dict)

@classmethod
Expand Down Expand Up @@ -586,6 +588,7 @@ def from_function(
definition=definition,
backend=backend,
grid_type=grid_type,
operator_attributes=operator_attributes,
)

def __gt_type__(self) -> ts.CallableType:
Expand Down Expand Up @@ -692,68 +695,38 @@ def __call__(
*args,
**kwargs,
) -> None:
# TODO(havogt): Don't select mode based on existence of kwargs,
# because now we cannot provide nice error messages. E.g. set context var
# if we are reaching this from a program call.
if "out" in kwargs:
out = kwargs.pop("out")
if not next_embedded.context.within_context() and self.backend is not None:
# non embedded execution
offset_provider = kwargs.pop("offset_provider", None)
if self.backend is not None:
# "out" and "offset_provider" -> field_operator as program
# When backend is None, we are in embedded execution and for now
# we disable the program generation since it would involve generating
# Python source code from a PAST node.
args, kwargs = type_info.canonicalize_arguments(self.foast_node.type, args, kwargs)
# TODO(tehrengruber): check all offset providers are given
# deduce argument types
arg_types = []
for arg in args:
arg_types.append(type_translation.from_value(arg))
kwarg_types = {}
for name, arg in kwargs.items():
kwarg_types[name] = type_translation.from_value(arg)

return self.as_program(arg_types, kwarg_types)(
*args, out, offset_provider=offset_provider, **kwargs
)
else:
# "out" -> field_operator called from program in embedded execution or
# field_operator called directly from Python in embedded execution
domain = kwargs.pop("domain", None)
if not next_embedded.context.within_context():
# field_operator from Python in embedded execution
with next_embedded.context.new_context(offset_provider=offset_provider) as ctx:
res = ctx.run(self.definition, *args, **kwargs)
else:
# field_operator from program in embedded execution (offset_provicer is already set)
assert (
offset_provider is None
or next_embedded.context.offset_provider.get() is offset_provider
)
res = self.definition(*args, **kwargs)
_tuple_assign_field(
out, res, domain=None if domain is None else common.domain(domain)
)
return
out = kwargs.pop("out")
args, kwargs = type_info.canonicalize_arguments(self.foast_node.type, args, kwargs)
# TODO(tehrengruber): check all offset providers are given
# deduce argument types
arg_types = []
for arg in args:
arg_types.append(type_translation.from_value(arg))
kwarg_types = {}
for name, arg in kwargs.items():
kwarg_types[name] = type_translation.from_value(arg)

return self.as_program(arg_types, kwarg_types)(
*args, out, offset_provider=offset_provider, **kwargs
)
else:
# field_operator called from other field_operator in embedded execution
assert self.backend is None
return self.definition(*args, **kwargs)


def _tuple_assign_field(
target: tuple[common.Field | tuple, ...] | common.Field,
source: tuple[common.Field | tuple, ...] | common.Field,
domain: Optional[common.Domain],
):
if isinstance(target, tuple):
if not isinstance(source, tuple):
raise RuntimeError(f"Cannot assign {source} to {target}.")
for t, s in zip(target, source):
_tuple_assign_field(t, s, domain)
else:
domain = domain or target.domain
target[domain] = source[domain]
if self.operator_attributes is not None and any(
has_scan_op_attribute := [
attribute in self.operator_attributes
for attribute in ["init", "axis", "forward"]
]
):
assert all(has_scan_op_attribute)
forward = self.operator_attributes["forward"]
init = self.operator_attributes["init"]
axis = self.operator_attributes["axis"]
op = embedded_operators.ScanOperator(self.definition, forward, init, axis)
else:
op = embedded_operators.EmbeddedOperator(self.definition)
return embedded_operators.field_operator_call(op, args, kwargs)


@typing.overload
Expand Down
22 changes: 22 additions & 0 deletions src/gt4py/next/field_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2023, ETH Zurich
# All rights reserved.
#
# This file is part of the GT4Py project and the GridTools framework.
# GT4Py is free software: you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the
# Free Software Foundation, either version 3 of the License, or any later
# version. See the LICENSE.txt file at the top-level directory of this
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later

import numpy as np

from gt4py.next import common, utils


@utils.tree_map
def asnumpy(field: common.Field | np.ndarray) -> np.ndarray:
return field.asnumpy() if common.is_field(field) else field # type: ignore[return-value] # mypy doesn't understand the condition
Loading

0 comments on commit a14ad09

Please sign in to comment.