Skip to content

Commit

Permalink
feat[next]: GTIR embedded backend (not active in tests) (#1702)
Browse files Browse the repository at this point in the history
with features from #1648
  • Loading branch information
havogt authored Oct 23, 2024
1 parent b802010 commit c1106fc
Show file tree
Hide file tree
Showing 10 changed files with 102 additions and 41 deletions.
4 changes: 2 additions & 2 deletions docs/user/next/advanced/ToolchainWalkthrough.md
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ pprint.pprint(jit_args)
```

```python
gtx.program_processors.runners.roundtrip.executor(pitir)(*jit_args.args, **jit_args.kwargs)
gtx.program_processors.runners.roundtrip.Roundtrip()(pitir)(*jit_args.args, **jit_args.kwargs)
```

```python
Expand Down Expand Up @@ -290,7 +290,7 @@ assert pitir2 == pitir
#### Pass The result to the compile workflow and execute

```python
example_compiled = gtx.program_processors.runners.roundtrip.executor(pitir2)
example_compiled = gtx.program_processors.runners.roundtrip.Roundtrip()(pitir2)
```

```python
Expand Down
13 changes: 10 additions & 3 deletions src/gt4py/next/ffront/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
from gt4py.next.embedded import operators as embedded_operators
from gt4py.next.ffront import (
field_operator_ast as foast,
foast_to_gtir,
foast_to_itir,
past_process_args,
signature,
stages as ffront_stages,
Expand Down Expand Up @@ -560,10 +562,15 @@ def with_grid_type(self, grid_type: GridType) -> FieldOperator:
self, definition_stage=dataclasses.replace(self.definition_stage, grid_type=grid_type)
)

# TODO(tehrengruber): We can not use transforms from `self.backend` since this can be
# a different backend than the one of the program that calls this field operator. Just use
# the hard-coded lowering until this is cleaned up.
def __gt_itir__(self) -> itir.FunctionDefinition:
return self._frontend_transforms.foast_to_itir(
toolchain.CompilableProgram(self.foast_stage, arguments.CompileTimeArgs.empty())
)
return foast_to_itir.foast_to_itir(self.foast_stage)

# FIXME[#1582](tehrengruber): remove after refactoring to GTIR
def __gt_gtir__(self) -> itir.FunctionDefinition:
return foast_to_gtir.foast_to_gtir(self.foast_stage)

def __gt_closure_vars__(self) -> dict[str, Any]:
return self.foast_stage.closure_vars
Expand Down
5 changes: 5 additions & 0 deletions src/gt4py/next/ffront/foast_to_past.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ def __gt_type__(self) -> ts.CallableType:
def __gt_itir__(self) -> itir.Expr:
return self.foast_to_itir(self.definition)

# FIXME[#1582](tehrengruber): remove after refactoring to GTIR
def __gt_gtir__(self) -> itir.Expr:
# backend should have self.foast_to_itir set to foast_to_gtir
return self.foast_to_itir(self.definition)


@dataclasses.dataclass(frozen=True)
class OperatorToProgram(workflow.Workflow[AOT_FOP, AOT_PRG]):
Expand Down
10 changes: 10 additions & 0 deletions src/gt4py/next/ffront/gtcallable.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,16 @@ def __gt_itir__(self) -> itir.FunctionDefinition:
"""
...

# FIXME[#1582](tehrengruber): remove after refactoring to GTIR
@abc.abstractmethod
def __gt_gtir__(self) -> itir.FunctionDefinition:
"""
Return the function definition representing the callable, lowered to GTIR.
Used internally by the Program decorator to populate the function
definitions of the program when lowering to GTIR.
"""
...

# TODO(tehrengruber): For embedded execution a `__call__` method and for
# "truly" embedded execution arguably also a `from_function` method is
# required. Since field operators currently have a `__gt_type__` with a
Expand Down
15 changes: 11 additions & 4 deletions src/gt4py/next/ffront/past_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,18 @@ def past_to_itir(inp: AOT_PRG, to_gtir: bool = False) -> stages.CompilableProgra
gt_callables = transform_utils._filter_closure_vars_by_type(
all_closure_vars, gtcallable.GTCallable
).values()

# FIXME[#1582](tehrengruber): remove after refactoring to GTIR
# TODO(ricoh): The following calls to .__gt_itir__, which will use whatever
# backend is set for each of these field operators (GTCallables). Instead
# we should use the current toolchain to lower these to ITIR. This will require
# making this step aware of the toolchain it is called by (it can be part of multiple).
lowered_funcs = [gt_callable.__gt_itir__() for gt_callable in gt_callables]
# backend is set for each of these field operators (GTCallables). Instead
# we should use the current toolchain to lower these to ITIR. This will require
# making this step aware of the toolchain it is called by (it can be part of multiple).
lowered_funcs = []
for gt_callable in gt_callables:
if to_gtir:
lowered_funcs.append(gt_callable.__gt_gtir__())
else:
lowered_funcs.append(gt_callable.__gt_itir__())

itir_program = ProgramLowering.apply(
inp.data.past_node, function_definitions=lowered_funcs, grid_type=grid_type, to_gtir=to_gtir
Expand Down
9 changes: 7 additions & 2 deletions src/gt4py/next/iterator/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from gt4py.next.iterator.transforms.pass_manager import LiftMode, apply_common_transforms
from gt4py.next.iterator.transforms.pass_manager import (
ITIRTransform,
LiftMode,
apply_common_transforms,
apply_fieldview_transforms,
)


__all__ = ["apply_common_transforms", "LiftMode"]
__all__ = ["apply_common_transforms", "apply_fieldview_transforms", "LiftMode", "ITIRTransform"]
37 changes: 29 additions & 8 deletions src/gt4py/next/iterator/transforms/pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
# SPDX-License-Identifier: BSD-3-Clause

import enum
from typing import Callable, Optional
from typing import Callable, Optional, Protocol

from gt4py.eve import utils as eve_utils
from gt4py.next import common
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.transforms import fencil_to_program, inline_fundefs
from gt4py.next.iterator.transforms import fencil_to_program, infer_domain, inline_fundefs
from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet
from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple
from gt4py.next.iterator.transforms.constant_folding import ConstantFolding
Expand All @@ -29,6 +30,12 @@
from gt4py.next.iterator.transforms.unroll_reduce import UnrollReduce


class ITIRTransform(Protocol):
    """
    Structural type of a callable that transforms an ITIR program.

    The callable receives an `itir.Program` or `itir.FencilDefinition` as its
    first argument and a keyword-only `offset_provider`, and returns an
    `itir.Program`.
    """

    def __call__(
        self,
        _: itir.Program | itir.FencilDefinition,
        *,
        offset_provider: common.OffsetProvider,
    ) -> itir.Program: ...


@enum.unique
class LiftMode(enum.Enum):
FORCE_INLINE = enum.auto()
Expand Down Expand Up @@ -65,7 +72,7 @@ def _inline_into_scan(ir, *, max_iter=10):
# TODO(tehrengruber): Revisit interface to configure temporary extraction. We currently forward
# `lift_mode` and `temporary_extraction_heuristics` which is inconvenient.
def apply_common_transforms(
ir: itir.Node,
ir: itir.Program | itir.FencilDefinition,
*,
lift_mode=None,
offset_provider=None,
Expand Down Expand Up @@ -115,10 +122,10 @@ def apply_common_transforms(
# other cases we want it anyway.
force_inline_trivial_lift_args=True,
)
inlined = ConstantFolding.apply(inlined)
inlined = ConstantFolding.apply(inlined) # type: ignore[assignment] # still a `itir.Program`
# This pass is required to be in the loop such that when an `if_` call with tuple arguments
# is constant-folded the surrounding tuple_get calls can be removed.
inlined = CollapseTuple.apply(
inlined = CollapseTuple.apply( # type: ignore[assignment] # still a `itir.Program`
inlined,
offset_provider=offset_provider,
# TODO(tehrengruber): disabled since it increases compile-time too much right now
Expand Down Expand Up @@ -167,7 +174,7 @@ def apply_common_transforms(
# larger than the number of closure outputs as given by the unconditional collapse, we can
# only run the unconditional version here instead of in the loop above.
if unconditionally_collapse_tuples:
ir = CollapseTuple.apply(
ir = CollapseTuple.apply( # type: ignore[assignment] # still a `itir.Program`
ir,
ignore_tuple_size=True,
offset_provider=offset_provider,
Expand All @@ -188,7 +195,7 @@ def apply_common_transforms(
unrolled = UnrollReduce.apply(ir, offset_provider=offset_provider)
if unrolled == ir:
break
ir = unrolled
ir = unrolled # type: ignore[assignment] # still a `itir.Program`
ir = CollapseListGet().visit(ir)
ir = NormalizeShifts().visit(ir)
ir = _inline_lifts(ir, LiftMode.FORCE_INLINE)
Expand All @@ -200,7 +207,7 @@ def apply_common_transforms(
ir = ScanEtaReduction().visit(ir)

if common_subexpression_elimination:
ir = CommonSubexpressionElimination.apply(ir, offset_provider=offset_provider) # type: ignore[type-var] # always an itir.Program
ir = CommonSubexpressionElimination.apply(ir, offset_provider=offset_provider)
ir = MergeLet().visit(ir)

ir = InlineLambdas.apply(
Expand All @@ -209,3 +216,17 @@ def apply_common_transforms(

assert isinstance(ir, itir.Program)
return ir


def apply_fieldview_transforms(
    ir: itir.Program, *, offset_provider: common.OffsetProvider
) -> itir.Program:
    """
    Apply the fixed sequence of passes used for the field-view (GTIR) lowering.

    The passes run in this order: `InlineFundefs`, `prune_unreferenced_fundefs`,
    `InlineLambdas` (with `opcount_preserving=True`), `infer_domain.infer_program`
    and finally `CollapseTuple`.

    Arguments:
        ir: The program to transform.
        offset_provider: Forwarded to the domain inference and to `CollapseTuple`.

    Returns:
        The transformed program.
    """
    # Inline the function definitions first, then drop those no longer referenced.
    ir = inline_fundefs.prune_unreferenced_fundefs(inline_fundefs.InlineFundefs().visit(ir))
    ir = InlineLambdas.apply(ir, opcount_preserving=True)
    ir = infer_domain.infer_program(ir, offset_provider=offset_provider)
    ir = CollapseTuple.apply(ir, offset_provider=offset_provider)  # type: ignore[assignment] # type is still `itir.Program`
    return ir
2 changes: 1 addition & 1 deletion src/gt4py/next/program_processors/formatters/lisp.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class ToLispLike(TemplatedGenerator):
)

@classmethod
def apply(cls, root: itir.Node, **kwargs: Any) -> str: # type: ignore[override]
def apply(cls, root: itir.FencilDefinition, **kwargs: Any) -> str: # type: ignore[override]
transformed = apply_common_transforms(
root, lift_mode=kwargs.get("lift_mode"), offset_provider=kwargs["offset_provider"]
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from gt4py._core import definitions as core_defs
from gt4py.next import common, config
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator import ir as itir, transforms as itir_transforms
from gt4py.next.otf import languages, recipes, stages, step_types, workflow
from gt4py.next.otf.binding import interface
from gt4py.next.otf.languages import LanguageSettings
Expand Down Expand Up @@ -46,10 +46,8 @@ def generate_sdfg(
offset_provider: common.OffsetProvider,
column_axis: Optional[common.Dimension],
) -> dace.SDFG:
# TODO(edopao): Call IR transformations and domain inference, finally lower IR to SDFG
raise NotImplementedError

return gtir_sdfg.build_sdfg_from_gtir(program=ir, offset_provider=offset_provider)
ir = itir_transforms.apply_fieldview_transforms(ir, offset_provider=offset_provider)
return gtir_sdfg.build_sdfg_from_gtir(ir=ir, offset_provider=offset_provider)

def __call__(
self, inp: stages.CompilableProgram
Expand Down
40 changes: 24 additions & 16 deletions src/gt4py/next/program_processors/runners/roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from __future__ import annotations

import dataclasses
import functools
import importlib.util
import pathlib
import tempfile
Expand All @@ -20,7 +21,7 @@
from gt4py.eve import codegen
from gt4py.eve.codegen import FormatTemplate as as_fmt, MakoTemplate as as_mako
from gt4py.next import allocators as next_allocators, backend as next_backend, common, config
from gt4py.next.ffront import foast_to_gtir, past_to_itir
from gt4py.next.ffront import foast_to_gtir, foast_to_past, past_to_itir
from gt4py.next.iterator import ir as itir, transforms as itir_transforms
from gt4py.next.otf import stages, workflow
from gt4py.next.type_system import type_specifications as ts
Expand Down Expand Up @@ -90,35 +91,32 @@ def visit_Temporary(self, node: itir.Temporary, **kwargs: Any) -> str:


def fencil_generator(
ir: itir.Node,
ir: itir.Program | itir.FencilDefinition,
debug: bool,
lift_mode: itir_transforms.LiftMode,
use_embedded: bool,
offset_provider: dict[str, common.Connectivity | common.Dimension],
transforms: itir_transforms.ITIRTransform,
) -> stages.CompiledProgram:
"""
Generate a directly executable fencil from an ITIR node.
Arguments:
ir: The iterator IR (ITIR) node.
debug: Keep module source containing fencil implementation.
lift_mode: Change the way lifted function calls are evaluated.
use_embedded: Directly use builtins from embedded backend instead of
generic dispatcher. Gives faster performance and is easier
to debug.
offset_provider: A mapping from offset names to offset providers.
"""
# TODO(tehrengruber): just a temporary solution until we have a proper generic
# caching mechanism
cache_key = hash((ir, lift_mode, debug, use_embedded, tuple(offset_provider.items())))
cache_key = hash((ir, transforms, debug, use_embedded, tuple(offset_provider.items())))
if cache_key in _FENCIL_CACHE:
if debug:
print(f"Using cached fencil for key {cache_key}")
return typing.cast(stages.CompiledProgram, _FENCIL_CACHE[cache_key])

ir = itir_transforms.apply_common_transforms(
ir, lift_mode=lift_mode, offset_provider=offset_provider
)
ir = transforms(ir, offset_provider=offset_provider)

program = EmbeddedDSL.apply(ir)

Expand Down Expand Up @@ -187,9 +185,9 @@ def fencil_generator(
@dataclasses.dataclass(frozen=True)
class Roundtrip(workflow.Workflow[stages.CompilableProgram, stages.CompiledProgram]):
debug: Optional[bool] = None
lift_mode: itir_transforms.LiftMode = itir_transforms.LiftMode.FORCE_INLINE
use_embedded: bool = True
dispatch_backend: Optional[next_backend.Backend] = None
transforms: itir_transforms.ITIRTransform = itir_transforms.apply_common_transforms # type: ignore[assignment] # TODO(havogt): cleanup interface of `apply_common_transforms`

def __call__(self, inp: stages.CompilableProgram) -> stages.CompiledProgram:
debug = config.DEBUG if self.debug is None else self.debug
Expand All @@ -198,8 +196,8 @@ def __call__(self, inp: stages.CompilableProgram) -> stages.CompiledProgram:
inp.data,
offset_provider=inp.args.offset_provider,
debug=debug,
lift_mode=self.lift_mode,
use_embedded=self.use_embedded,
transforms=self.transforms,
)

def decorated_fencil(
Expand All @@ -224,28 +222,38 @@ def decorated_fencil(
return decorated_fencil


executor = Roundtrip()
executor_with_temporaries = Roundtrip(lift_mode=itir_transforms.LiftMode.USE_TEMPORARIES)

default = next_backend.Backend(
name="roundtrip",
executor=executor,
executor=Roundtrip(
transforms=functools.partial(
itir_transforms.apply_common_transforms, lift_mode=itir_transforms.LiftMode.FORCE_INLINE
)
),
allocator=next_allocators.StandardCPUFieldBufferAllocator(),
transforms=next_backend.DEFAULT_TRANSFORMS,
)
with_temporaries = next_backend.Backend(
name="roundtrip_with_temporaries",
executor=executor_with_temporaries,
executor=Roundtrip(
transforms=functools.partial(
itir_transforms.apply_common_transforms,
lift_mode=itir_transforms.LiftMode.USE_TEMPORARIES,
)
),
allocator=next_allocators.StandardCPUFieldBufferAllocator(),
transforms=next_backend.DEFAULT_TRANSFORMS,
)


gtir = next_backend.Backend(
name="roundtrip_gtir",
executor=executor,
executor=Roundtrip(transforms=itir_transforms.apply_fieldview_transforms), # type: ignore[arg-type] # on purpose doesn't support `FencilDefinition`, will resolve itself later...
allocator=next_allocators.StandardCPUFieldBufferAllocator(),
transforms=next_backend.Transforms(
past_to_itir=past_to_itir.past_to_itir_factory(to_gtir=True),
foast_to_itir=foast_to_gtir.adapted_foast_to_gtir_factory(cached=True),
field_view_op_to_prog=foast_to_past.operator_to_program_factory(
foast_to_itir_step=foast_to_gtir.adapted_foast_to_gtir_factory()
),
),
)

0 comments on commit c1106fc

Please sign in to comment.