Skip to content

Commit

Permalink
feat[next]: Pass sizes to temporaries from gt4py program (#1359)
Browse files Browse the repository at this point in the history
  • Loading branch information
samkellerhals authored Jan 18, 2024
1 parent ba36856 commit 49db7ef
Show file tree
Hide file tree
Showing 14 changed files with 264 additions and 128 deletions.
4 changes: 2 additions & 2 deletions src/gt4py/eve/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def _pre_walk_items(
yield from _pre_walk_items(child, __key__=key)


def _pre_walk_values(node: TreeLike) -> Iterable[Tuple[Any]]:
def _pre_walk_values(node: TreeLike) -> Iterable:
"""Create a pre-order tree traversal iterator of values."""
yield node
for child in iter_children_values(node):
Expand All @@ -153,7 +153,7 @@ def _post_walk_items(
yield __key__, node


def _post_walk_values(node: TreeLike) -> Iterable[Tuple[Any]]:
def _post_walk_values(node: TreeLike) -> Iterable:
"""Create a post-order tree traversal iterator of values."""
if (iter_children_values := getattr(node, "iter_children_values", None)) is not None:
for child in iter_children_values():
Expand Down
41 changes: 31 additions & 10 deletions src/gt4py/next/iterator/transforms/global_tmps.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from gt4py.eve import Coerced, NodeTranslator
from gt4py.eve.traits import SymbolTableTrait
from gt4py.eve.utils import UIDGenerator
from gt4py.next import common
from gt4py.next.iterator import ir, type_inference
from gt4py.next.iterator.ir_utils import ir_makers as im
from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift
Expand Down Expand Up @@ -437,9 +438,12 @@ def _group_offsets(
return zip(tags, offsets, strict=True) # type: ignore[return-value] # mypy doesn't infer literal correctly


def update_domains(node: FencilWithTemporaries, offset_provider: Mapping[str, Any]):
def update_domains(
node: FencilWithTemporaries,
offset_provider: Mapping[str, Any],
symbolic_sizes: Optional[dict[str, str]],
):
horizontal_sizes = _max_domain_sizes_by_location_type(offset_provider)

closures: list[ir.StencilClosure] = []
domains = dict[str, ir.FunCall]()
for closure in reversed(node.fencil.closures):
Expand Down Expand Up @@ -479,16 +483,29 @@ def update_domains(node: FencilWithTemporaries, offset_provider: Mapping[str, An
# cartesian shift
dim = offset_provider[offset_name].value
consumed_domain.ranges[dim] = consumed_domain.ranges[dim].translate(offset)
elif isinstance(offset_provider[offset_name], gtx.NeighborTableOffsetProvider):
elif isinstance(offset_provider[offset_name], common.Connectivity):
# unstructured shift
nbt_provider = offset_provider[offset_name]
old_axis = nbt_provider.origin_axis.value
new_axis = nbt_provider.neighbor_axis.value
consumed_domain.ranges.pop(old_axis)
assert new_axis not in consumed_domain.ranges
consumed_domain.ranges[new_axis] = SymbolicRange(
im.literal("0", ir.INTEGER_INDEX_BUILTIN),
im.literal(str(horizontal_sizes[new_axis]), ir.INTEGER_INDEX_BUILTIN),

assert new_axis not in consumed_domain.ranges or old_axis == new_axis

if symbolic_sizes is None:
new_range = SymbolicRange(
im.literal("0", ir.INTEGER_INDEX_BUILTIN),
im.literal(
str(horizontal_sizes[new_axis]), ir.INTEGER_INDEX_BUILTIN
),
)
else:
new_range = SymbolicRange(
im.literal("0", ir.INTEGER_INDEX_BUILTIN),
im.ref(symbolic_sizes[new_axis]),
)
consumed_domain.ranges = dict(
(axis, range_) if axis != old_axis else (new_axis, new_range)
for axis, range_ in consumed_domain.ranges.items()
)
else:
raise NotImplementedError
Expand Down Expand Up @@ -570,7 +587,11 @@ class CreateGlobalTmps(NodeTranslator):
"""

def visit_FencilDefinition(
self, node: ir.FencilDefinition, *, offset_provider: Mapping[str, Any]
self,
node: ir.FencilDefinition,
*,
offset_provider: Mapping[str, Any],
symbolic_sizes: Optional[dict[str, str]],
) -> FencilWithTemporaries:
# Split closures on lifted function calls and introduce temporaries
res = split_closures(node, offset_provider=offset_provider)
Expand All @@ -581,6 +602,6 @@ def visit_FencilDefinition(
# Perform an eta-reduction which should put all calls at the highest level of a closure
res = EtaReduction().visit(res)
# Perform a naive extent analysis to compute domain sizes of closures and temporaries
res = update_domains(res, offset_provider)
res = update_domains(res, offset_provider, symbolic_sizes)
# Use type inference to determine the data type of the temporaries
return collect_tmps_info(res, offset_provider=offset_provider)
6 changes: 5 additions & 1 deletion src/gt4py/next/iterator/transforms/pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# SPDX-License-Identifier: GPL-3.0-or-later

import enum
from typing import Optional

from gt4py.next.iterator import ir
from gt4py.next.iterator.transforms import simple_inline_heuristic
Expand Down Expand Up @@ -81,6 +82,7 @@ def apply_common_transforms(
common_subexpression_elimination=True,
force_inline_lambda_args=False,
unconditionally_collapse_tuples=False,
symbolic_domain_sizes: Optional[dict[str, str]] = None,
):
if lift_mode is None:
lift_mode = LiftMode.FORCE_INLINE
Expand Down Expand Up @@ -147,7 +149,9 @@ def apply_common_transforms(

if lift_mode != LiftMode.FORCE_INLINE:
assert offset_provider is not None
ir = CreateGlobalTmps().visit(ir, offset_provider=offset_provider)
ir = CreateGlobalTmps().visit(
ir, offset_provider=offset_provider, symbolic_sizes=symbolic_domain_sizes
)
ir = InlineLifts().visit(ir)
# If after creating temporaries, the scan is not at the top, we inline.
# The following example doesn't have a lift around the shift, i.e. temporary pass will not extract it.
Expand Down
6 changes: 2 additions & 4 deletions src/gt4py/next/iterator/type_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,9 +567,7 @@ def _infer_shift_location_types(shift_args, offset_provider, constraints):
axis = offset_provider[offset]
if isinstance(axis, gtx.Dimension):
continue # Cartesian shifts don’t change the location type
elif isinstance(
axis, (gtx.NeighborTableOffsetProvider, gtx.StridedNeighborOffsetProvider)
):
elif isinstance(axis, Connectivity):
assert (
axis.origin_axis.kind
== axis.neighbor_axis.kind
Expand Down Expand Up @@ -964,7 +962,7 @@ def visit_FencilDefinition(
def _save_types_to_annex(node: ir.Node, types: dict[int, Type]) -> None:
for child_node in node.pre_walk_values().if_isinstance(*TYPED_IR_NODES):
try:
child_node.annex.type = types[id(child_node)] # type: ignore[attr-defined]
child_node.annex.type = types[id(child_node)]
except KeyError:
if not (
isinstance(child_node, ir.SymRef)
Expand Down
77 changes: 0 additions & 77 deletions src/gt4py/next/program_processors/codegens/gtfn/gtfn_backend.py

This file was deleted.

95 changes: 75 additions & 20 deletions src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,24 @@
from __future__ import annotations

import dataclasses
import functools
import warnings
from typing import Any, Final, Optional

import numpy as np

from gt4py._core import definitions as core_defs
from gt4py.eve import trees, utils
from gt4py.eve import codegen, trees, utils
from gt4py.next import common
from gt4py.next.common import Connectivity, Dimension
from gt4py.next.ffront import fbuiltins
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.transforms import LiftMode
from gt4py.next.iterator.transforms import LiftMode, pass_manager
from gt4py.next.otf import languages, stages, step_types, workflow
from gt4py.next.otf.binding import cpp_interface, interface
from gt4py.next.program_processors.codegens.gtfn import gtfn_backend
from gt4py.next.program_processors.codegens.gtfn.codegen import GTFNCodegen, GTFNIMCodegen
from gt4py.next.program_processors.codegens.gtfn.gtfn_ir_to_gtfn_im_ir import GTFN_IM_lowering
from gt4py.next.program_processors.codegens.gtfn.itir_to_gtfn_ir import GTFN_lowering
from gt4py.next.type_system import type_specifications as ts, type_translation


Expand All @@ -54,6 +57,7 @@ class GTFNTranslationStep(
use_imperative_backend: bool = False
lift_mode: Optional[LiftMode] = None
device_type: core_defs.DeviceType = core_defs.DeviceType.CPU
symbolic_domain_sizes: Optional[dict[str, str]] = None

def _default_language_settings(self) -> languages.LanguageWithHeaderFilesSettings:
match self.device_type:
Expand Down Expand Up @@ -171,6 +175,70 @@ def _process_connectivity_args(

return parameters, arg_exprs

def _preprocess_program(
    self,
    program: itir.FencilDefinition,
    offset_provider: dict[str, Connectivity | Dimension],
    runtime_lift_mode: Optional[LiftMode] = None,
) -> itir.FencilDefinition:
    """Run the common ITIR transformations on ``program`` prior to lowering.

    Args:
        program: The fencil definition to transform.
        offset_provider: Forwarded to ``pass_manager.apply_common_transforms``.
        runtime_lift_mode: Lift mode requested at call time. When it is falsy
            (e.g. ``None``) the configured ``self.lift_mode`` is used instead.
            A warning is emitted when the effective lift mode differs from the
            configured one.

    Returns:
        ``program`` itself when ``self.enable_itir_transforms`` is false,
        otherwise the result of ``pass_manager.apply_common_transforms``.
    """
    # TODO(tehrengruber): Remove `lift_mode` from call interface. It has been implicitly added
    #  to the interface of all (or at least all of concern) backends, but instead should be
    #  configured in the backend itself (like it is here), until then we respect the argument
    #  here and warn the user if it differs from the one configured.
    lift_mode = runtime_lift_mode or self.lift_mode
    if lift_mode != self.lift_mode:
        warnings.warn(
            f"GTFN Backend was configured for LiftMode `{self.lift_mode!s}`, but "
            f"overridden to be {runtime_lift_mode!s} at runtime."
        )

    if not self.enable_itir_transforms:
        return program

    # All settings except `unroll_reduce` are shared between the first
    # attempt and the potential retry below.
    apply_common_transforms = functools.partial(
        pass_manager.apply_common_transforms,
        lift_mode=lift_mode,
        offset_provider=offset_provider,
        # sid::composite (via hymap) supports assigning from tuple with more elements to tuple with fewer elements
        unconditionally_collapse_tuples=True,
        symbolic_domain_sizes=self.symbolic_domain_sizes,
    )

    # The imperative backend first tries without unrolling reductions.
    new_program = apply_common_transforms(
        program, unroll_reduce=not self.use_imperative_backend
    )

    if self.use_imperative_backend and any(
        node.id == "neighbors"
        for node in new_program.pre_walk_values().if_isinstance(itir.SymRef)
    ):
        # if we don't unroll, there may be lifts left in the itir which can't be lowered to
        # gtfn. In this case, just retry with unrolled reductions.
        new_program = apply_common_transforms(program, unroll_reduce=True)

    return new_program

def generate_stencil_source(
    self,
    program: itir.FencilDefinition,
    offset_provider: dict[str, Connectivity | Dimension],
    column_axis: Optional[common.Dimension],
    runtime_lift_mode: Optional[LiftMode] = None,
) -> str:
    """Produce formatted C++ stencil source for ``program``.

    The program is first passed through ``self._preprocess_program``, then
    lowered with ``GTFN_lowering.apply``. Depending on
    ``self.use_imperative_backend`` the code is rendered either through the
    imperative path (``GTFN_IM_lowering`` followed by ``GTFNIMCodegen``) or
    directly with ``GTFNCodegen``.

    Args:
        program: The fencil definition to generate code for.
        offset_provider: Forwarded to preprocessing and to the GTFN lowering.
        column_axis: Forwarded to the GTFN lowering.
        runtime_lift_mode: Forwarded to ``self._preprocess_program``.

    Returns:
        The generated source, formatted as "cpp" with style "LLVM".
    """
    preprocessed = self._preprocess_program(program, offset_provider, runtime_lift_mode)
    lowered = GTFN_lowering.apply(
        preprocessed,
        offset_provider=offset_provider,
        column_axis=column_axis,
    )

    if not self.use_imperative_backend:
        cpp_source = GTFNCodegen.apply(lowered)
    else:
        # The imperative backend needs one more lowering step before codegen.
        imperative_ir = GTFN_IM_lowering().visit(node=lowered)
        cpp_source = GTFNIMCodegen.apply(imperative_ir)

    return codegen.format_source("cpp", cpp_source, style="LLVM")

def __call__(
self,
inp: stages.ProgramCall,
Expand All @@ -190,18 +258,6 @@ def __call__(
inp.kwargs["offset_provider"]
)

# TODO(tehrengruber): Remove `lift_mode` from call interface. It has been implicitly added
# to the interface of all (or at least all of concern) backends, but instead should be
# configured in the backend itself (like it is here), until then we respect the argument
# here and warn the user if it differs from the one configured.
runtime_lift_mode = inp.kwargs.pop("lift_mode", None)
lift_mode = runtime_lift_mode or self.lift_mode
if runtime_lift_mode != self.lift_mode:
warnings.warn(
f"GTFN Backend was configured for LiftMode `{str(self.lift_mode)}`, but "
"overriden to be {str(runtime_lift_mode)} at runtime."
)

# combine into a format that is aligned with what the backend expects
parameters: list[interface.Parameter] = regular_parameters + connectivity_parameters
backend_arg = self._backend_type()
Expand All @@ -213,12 +269,11 @@ def __call__(
f"{', '.join(connectivity_args_expr)})({', '.join(args_expr)});"
)
decl_src = cpp_interface.render_function_declaration(function, body=decl_body)
stencil_src = gtfn_backend.generate(
stencil_src = self.generate_stencil_source(
program,
enable_itir_transforms=self.enable_itir_transforms,
lift_mode=lift_mode,
imperative=self.use_imperative_backend,
**inp.kwargs,
inp.kwargs["offset_provider"],
inp.kwargs.get("column_axis", None),
inp.kwargs.get("lift_mode", None),
)
source_code = interface.format_source(
self._language_settings(),
Expand Down
13 changes: 11 additions & 2 deletions src/gt4py/next/program_processors/formatters/gtfn.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,19 @@
from typing import Any

from gt4py.next.iterator import ir as itir
from gt4py.next.program_processors.codegens.gtfn.gtfn_backend import generate
from gt4py.next.program_processors.codegens.gtfn.gtfn_module import GTFNTranslationStep
from gt4py.next.program_processors.processor_interface import program_formatter
from gt4py.next.program_processors.runners.gtfn import gtfn_executor


@program_formatter
def format_cpp(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> str:
    """Render ``program`` as C++ stencil source using the GTFN translation step.

    Only the ``offset_provider``, ``column_axis`` and ``lift_mode`` keyword
    arguments are read (each defaulting to ``None`` when absent); positional
    ``args`` are not used.

    Returns:
        The string returned by ``GTFNTranslationStep.generate_stencil_source``.
    """
    # TODO(tehrengruber): This is a little ugly. Revisit.
    gtfn_translation = gtfn_executor.otf_workflow.translation
    assert isinstance(gtfn_translation, GTFNTranslationStep)
    return gtfn_translation.generate_stencil_source(
        program,
        offset_provider=kwargs.get("offset_provider", None),
        column_axis=kwargs.get("column_axis", None),
        runtime_lift_mode=kwargs.get("lift_mode", None),
    )
Loading

0 comments on commit 49db7ef

Please sign in to comment.