diff --git a/docs/user/next/advanced/ToolchainWalkthrough.md b/docs/user/next/advanced/ToolchainWalkthrough.md index b82dea1a2f..a5a63cb56c 100644 --- a/docs/user/next/advanced/ToolchainWalkthrough.md +++ b/docs/user/next/advanced/ToolchainWalkthrough.md @@ -247,7 +247,7 @@ pprint.pprint(jit_args) ``` ```python -gtx.program_processors.runners.roundtrip.executor(pitir)(*jit_args.args, **jit_args.kwargs) +gtx.program_processors.runners.roundtrip.Roundtrip()(pitir)(*jit_args.args, **jit_args.kwargs) ``` ```python @@ -290,7 +290,7 @@ assert pitir2 == pitir #### Pass The result to the compile workflow and execute ```python -example_compiled = gtx.program_processors.runners.roundtrip.executor(pitir2) +example_compiled = gtx.program_processors.runners.roundtrip.Roundtrip()(pitir2) ``` ```python diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 52fe8d8116..dc2421e1d2 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -34,6 +34,8 @@ from gt4py.next.embedded import operators as embedded_operators from gt4py.next.ffront import ( field_operator_ast as foast, + foast_to_gtir, + foast_to_itir, past_process_args, signature, stages as ffront_stages, @@ -560,10 +562,15 @@ def with_grid_type(self, grid_type: GridType) -> FieldOperator: self, definition_stage=dataclasses.replace(self.definition_stage, grid_type=grid_type) ) + # TODO(tehrengruber): We can not use transforms from `self.backend` since this can be + # a different backend than the one of the program that calls this field operator. Just use + # the hard-coded lowering until this is cleaned up. def __gt_itir__(self) -> itir.FunctionDefinition: - return self._frontend_transforms.foast_to_itir( - toolchain.CompilableProgram(self.foast_stage, arguments.CompileTimeArgs.empty()) - ) + return foast_to_itir.foast_to_itir(self.foast_stage) + + # FIXME[#1582](tehrengruber): remove after refactoring to GTIR + def __gt_gtir__(self) -> itir.FunctionDefinition: + return foast_to_gtir.foast_to_gtir(self.foast_stage) def __gt_closure_vars__(self) -> dict[str, Any]: return self.foast_stage.closure_vars diff --git a/src/gt4py/next/ffront/foast_to_past.py b/src/gt4py/next/ffront/foast_to_past.py index 0844f63286..312ac686a2 100644 --- a/src/gt4py/next/ffront/foast_to_past.py +++ b/src/gt4py/next/ffront/foast_to_past.py @@ -45,6 +45,11 @@ def __gt_type__(self) -> ts.CallableType: def __gt_itir__(self) -> itir.Expr: return self.foast_to_itir(self.definition) + # FIXME[#1582](tehrengruber): remove after refactoring to GTIR + def __gt_gtir__(self) -> itir.Expr: + # backend should have self.foast_to_itir set to foast_to_gtir + return self.foast_to_itir(self.definition) + @dataclasses.dataclass(frozen=True) class OperatorToProgram(workflow.Workflow[AOT_FOP, AOT_PRG]): diff --git a/src/gt4py/next/ffront/gtcallable.py b/src/gt4py/next/ffront/gtcallable.py index beaebb3a5a..cdfb23910e 100644 --- a/src/gt4py/next/ffront/gtcallable.py +++ b/src/gt4py/next/ffront/gtcallable.py @@ -52,6 +52,16 @@ def __gt_itir__(self) -> itir.FunctionDefinition: """ ... + # FIXME[#1582](tehrengruber): remove after refactoring to GTIR + @abc.abstractmethod + def __gt_gtir__(self) -> itir.FunctionDefinition: + """ + Return iterator IR function definition representing the callable. + Used internally by the Program decorator to populate the function + definitions of the iterator IR. + """ + ... 
+ # TODO(tehrengruber): For embedded execution a `__call__` method and for # "truly" embedded execution arguably also a `from_function` method is # required. Since field operators currently have a `__gt_type__` with a diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index a20c517cce..14d705576e 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -80,11 +80,18 @@ def past_to_itir(inp: AOT_PRG, to_gtir: bool = False) -> stages.CompilableProgra gt_callables = transform_utils._filter_closure_vars_by_type( all_closure_vars, gtcallable.GTCallable ).values() + + # FIXME[#1582](tehrengruber): remove after refactoring to GTIR # TODO(ricoh): The following calls to .__gt_itir__, which will use whatever - # backend is set for each of these field operators (GTCallables). Instead - # we should use the current toolchain to lower these to ITIR. This will require - # making this step aware of the toolchain it is called by (it can be part of multiple). - lowered_funcs = [gt_callable.__gt_itir__() for gt_callable in gt_callables] + # backend is set for each of these field operators (GTCallables). Instead + # we should use the current toolchain to lower these to ITIR. This will require + # making this step aware of the toolchain it is called by (it can be part of multiple). + lowered_funcs = [] + for gt_callable in gt_callables: + if to_gtir: + lowered_funcs.append(gt_callable.__gt_gtir__()) + else: + lowered_funcs.append(gt_callable.__gt_itir__()) itir_program = ProgramLowering.apply( inp.data.past_node, function_definitions=lowered_funcs, grid_type=grid_type, to_gtir=to_gtir diff --git a/src/gt4py/next/iterator/transforms/__init__.py b/src/gt4py/next/iterator/transforms/__init__.py index 58678cfc9c..6f9651a397 100644 --- a/src/gt4py/next/iterator/transforms/__init__.py +++ b/src/gt4py/next/iterator/transforms/__init__.py @@ -6,7 +6,12 @@ # Please, refer to the LICENSE file in the root directory. 
# SPDX-License-Identifier: BSD-3-Clause -from gt4py.next.iterator.transforms.pass_manager import LiftMode, apply_common_transforms +from gt4py.next.iterator.transforms.pass_manager import ( + ITIRTransform, + LiftMode, + apply_common_transforms, + apply_fieldview_transforms, +) -__all__ = ["apply_common_transforms", "LiftMode"] +__all__ = ["apply_common_transforms", "apply_fieldview_transforms", "LiftMode", "ITIRTransform"] diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index b3bb7bc6e1..7c35d552dc 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -7,11 +7,12 @@ # SPDX-License-Identifier: BSD-3-Clause import enum -from typing import Callable, Optional +from typing import Callable, Optional, Protocol from gt4py.eve import utils as eve_utils +from gt4py.next import common from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.transforms import fencil_to_program, inline_fundefs +from gt4py.next.iterator.transforms import fencil_to_program, infer_domain, inline_fundefs from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple from gt4py.next.iterator.transforms.constant_folding import ConstantFolding @@ -29,6 +30,12 @@ from gt4py.next.iterator.transforms.unroll_reduce import UnrollReduce +class ITIRTransform(Protocol): + def __call__( + self, _: itir.Program | itir.FencilDefinition, *, offset_provider: common.OffsetProvider + ) -> itir.Program: ... + + @enum.unique class LiftMode(enum.Enum): FORCE_INLINE = enum.auto() @@ -65,7 +72,7 @@ def _inline_into_scan(ir, *, max_iter=10): # TODO(tehrengruber): Revisit interface to configure temporary extraction. We currently forward # `lift_mode` and `temporary_extraction_heuristics` which is inconvenient. def apply_common_transforms( - ir: itir.Node, + ir: itir.Program | itir.FencilDefinition, *, lift_mode=None, offset_provider=None, @@ -115,10 +122,10 @@ def apply_common_transforms( # other cases we want it anyway. force_inline_trivial_lift_args=True, ) - inlined = ConstantFolding.apply(inlined) + inlined = ConstantFolding.apply(inlined) # type: ignore[assignment] # still a `itir.Program` # This pass is required to be in the loop such that when an `if_` call with tuple arguments # is constant-folded the surrounding tuple_get calls can be removed. - inlined = CollapseTuple.apply( + inlined = CollapseTuple.apply( # type: ignore[assignment] # still a `itir.Program` inlined, offset_provider=offset_provider, # TODO(tehrengruber): disabled since it increases compile-time too much right now @@ -167,7 +174,7 @@ def apply_common_transforms( # larger than the number of closure outputs as given by the unconditional collapse, we can # only run the unconditional version here instead of in the loop above. 
if unconditionally_collapse_tuples: - ir = CollapseTuple.apply( + ir = CollapseTuple.apply( # type: ignore[assignment] # still a `itir.Program` ir, ignore_tuple_size=True, offset_provider=offset_provider, @@ -188,7 +195,7 @@ def apply_common_transforms( unrolled = UnrollReduce.apply(ir, offset_provider=offset_provider) if unrolled == ir: break - ir = unrolled + ir = unrolled # type: ignore[assignment] # still a `itir.Program` ir = CollapseListGet().visit(ir) ir = NormalizeShifts().visit(ir) ir = _inline_lifts(ir, LiftMode.FORCE_INLINE) @@ -200,7 +207,7 @@ def apply_common_transforms( ir = ScanEtaReduction().visit(ir) if common_subexpression_elimination: - ir = CommonSubexpressionElimination.apply(ir, offset_provider=offset_provider) # type: ignore[type-var] # always an itir.Program + ir = CommonSubexpressionElimination.apply(ir, offset_provider=offset_provider) ir = MergeLet().visit(ir) ir = InlineLambdas.apply( @@ -209,3 +216,17 @@ def apply_common_transforms( assert isinstance(ir, itir.Program) return ir + + +def apply_fieldview_transforms( + ir: itir.Program, *, offset_provider: common.OffsetProvider +) -> itir.Program: + ir = inline_fundefs.InlineFundefs().visit(ir) + ir = inline_fundefs.prune_unreferenced_fundefs(ir) + ir = InlineLambdas.apply(ir, opcount_preserving=True) + ir = infer_domain.infer_program( + ir, + offset_provider=offset_provider, + ) + ir = CollapseTuple.apply(ir, offset_provider=offset_provider) # type: ignore[assignment] # type is still `itir.Program` + return ir diff --git a/src/gt4py/next/program_processors/formatters/lisp.py b/src/gt4py/next/program_processors/formatters/lisp.py index c477795c34..7b722a7c1a 100644 --- a/src/gt4py/next/program_processors/formatters/lisp.py +++ b/src/gt4py/next/program_processors/formatters/lisp.py @@ -50,7 +50,7 @@ class ToLispLike(TemplatedGenerator): ) @classmethod - def apply(cls, root: itir.Node, **kwargs: Any) -> str: # type: ignore[override] + def apply(cls, root: itir.FencilDefinition, **kwargs: Any) -> str: # type: ignore[override] transformed = apply_common_transforms( root, lift_mode=kwargs.get("lift_mode"), offset_provider=kwargs["offset_provider"] ) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py index ffc33a9f25..f2953eb05f 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py @@ -17,7 +17,7 @@ from gt4py._core import definitions as core_defs from gt4py.next import common, config -from gt4py.next.iterator import ir as itir +from gt4py.next.iterator import ir as itir, transforms as itir_transforms from gt4py.next.otf import languages, recipes, stages, step_types, workflow from gt4py.next.otf.binding import interface from gt4py.next.otf.languages import LanguageSettings @@ -46,10 +46,8 @@ def generate_sdfg( offset_provider: common.OffsetProvider, column_axis: Optional[common.Dimension], ) -> dace.SDFG: - # TODO(edopao): Call IR transformations and domain inference, finally lower IR to SDFG - raise NotImplementedError - - return gtir_sdfg.build_sdfg_from_gtir(program=ir, offset_provider=offset_provider) + ir = itir_transforms.apply_fieldview_transforms(ir, offset_provider=offset_provider) + return gtir_sdfg.build_sdfg_from_gtir(ir=ir, offset_provider=offset_provider) def __call__( self, inp: stages.CompilableProgram diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py 
b/src/gt4py/next/program_processors/runners/roundtrip.py index 93e6d09c5b..57785ceb33 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -9,6 +9,7 @@ from __future__ import annotations import dataclasses +import functools import importlib.util import pathlib import tempfile @@ -20,7 +21,7 @@ from gt4py.eve import codegen from gt4py.eve.codegen import FormatTemplate as as_fmt, MakoTemplate as as_mako from gt4py.next import allocators as next_allocators, backend as next_backend, common, config -from gt4py.next.ffront import foast_to_gtir, past_to_itir +from gt4py.next.ffront import foast_to_gtir, foast_to_past, past_to_itir from gt4py.next.iterator import ir as itir, transforms as itir_transforms from gt4py.next.otf import stages, workflow from gt4py.next.type_system import type_specifications as ts @@ -90,11 +91,11 @@ def visit_Temporary(self, node: itir.Temporary, **kwargs: Any) -> str: def fencil_generator( - ir: itir.Node, + ir: itir.Program | itir.FencilDefinition, debug: bool, - lift_mode: itir_transforms.LiftMode, use_embedded: bool, offset_provider: dict[str, common.Connectivity | common.Dimension], + transforms: itir_transforms.ITIRTransform, ) -> stages.CompiledProgram: """ Generate a directly executable fencil from an ITIR node. @@ -102,7 +103,6 @@ def fencil_generator( Arguments: ir: The iterator IR (ITIR) node. debug: Keep module source containing fencil implementation. - lift_mode: Change the way lifted function calls are evaluated. use_embedded: Directly use builtins from embedded backend instead of generic dispatcher. Gives faster performance and is easier to debug. @@ -110,15 +110,13 @@ def fencil_generator( """ # TODO(tehrengruber): just a temporary solution until we have a proper generic # caching mechanism - cache_key = hash((ir, lift_mode, debug, use_embedded, tuple(offset_provider.items()))) + cache_key = hash((ir, transforms, debug, use_embedded, tuple(offset_provider.items()))) if cache_key in _FENCIL_CACHE: if debug: print(f"Using cached fencil for key {cache_key}") return typing.cast(stages.CompiledProgram, _FENCIL_CACHE[cache_key]) - ir = itir_transforms.apply_common_transforms( - ir, lift_mode=lift_mode, offset_provider=offset_provider - ) + ir = transforms(ir, offset_provider=offset_provider) program = EmbeddedDSL.apply(ir) @@ -187,9 +185,9 @@ def fencil_generator( @dataclasses.dataclass(frozen=True) class Roundtrip(workflow.Workflow[stages.CompilableProgram, stages.CompiledProgram]): debug: Optional[bool] = None - lift_mode: itir_transforms.LiftMode = itir_transforms.LiftMode.FORCE_INLINE use_embedded: bool = True dispatch_backend: Optional[next_backend.Backend] = None + transforms: itir_transforms.ITIRTransform = itir_transforms.apply_common_transforms # type: ignore[assignment] # TODO(havogt): cleanup interface of `apply_common_transforms` def __call__(self, inp: stages.CompilableProgram) -> stages.CompiledProgram: debug = config.DEBUG if self.debug is None else self.debug @@ -198,8 +196,8 @@ def __call__(self, inp: stages.CompilableProgram) -> stages.CompiledProgram: inp.data, offset_provider=inp.args.offset_provider, debug=debug, - lift_mode=self.lift_mode, use_embedded=self.use_embedded, + transforms=self.transforms, ) def decorated_fencil( @@ -224,28 +222,38 @@ def decorated_fencil( return decorated_fencil -executor = Roundtrip() -executor_with_temporaries = Roundtrip(lift_mode=itir_transforms.LiftMode.USE_TEMPORARIES) - default = next_backend.Backend( name="roundtrip", - 
executor=executor, + executor=Roundtrip( + transforms=functools.partial( + itir_transforms.apply_common_transforms, lift_mode=itir_transforms.LiftMode.FORCE_INLINE + ) + ), allocator=next_allocators.StandardCPUFieldBufferAllocator(), transforms=next_backend.DEFAULT_TRANSFORMS, ) with_temporaries = next_backend.Backend( name="roundtrip_with_temporaries", - executor=executor_with_temporaries, + executor=Roundtrip( + transforms=functools.partial( + itir_transforms.apply_common_transforms, + lift_mode=itir_transforms.LiftMode.USE_TEMPORARIES, + ) + ), allocator=next_allocators.StandardCPUFieldBufferAllocator(), transforms=next_backend.DEFAULT_TRANSFORMS, ) + gtir = next_backend.Backend( name="roundtrip_gtir", - executor=executor, + executor=Roundtrip(transforms=itir_transforms.apply_fieldview_transforms), # type: ignore[arg-type] # on purpose doesn't support `FencilDefintion` will resolve itself later... allocator=next_allocators.StandardCPUFieldBufferAllocator(), transforms=next_backend.Transforms( past_to_itir=past_to_itir.past_to_itir_factory(to_gtir=True), foast_to_itir=foast_to_gtir.adapted_foast_to_gtir_factory(cached=True), + field_view_op_to_prog=foast_to_past.operator_to_program_factory( + foast_to_itir_step=foast_to_gtir.adapted_foast_to_gtir_factory() + ), ), )
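
With this refactor, `Roundtrip` no longer carries a `lift_mode`; instead, any callable satisfying the new `ITIRTransform` protocol can be injected through its `transforms` field. Below is a minimal sketch of how an executor/backend pair could be assembled against this revision. The names `inline_executor`, `fieldview_executor`, and `custom_backend` are illustrative placeholders; all other identifiers come from the definitions in the diff above.

```python
# Sketch only: mirrors the wiring introduced in this patch, not an official API example.
import functools

from gt4py.next import allocators as next_allocators, backend as next_backend
from gt4py.next.iterator import transforms as itir_transforms
from gt4py.next.program_processors.runners import roundtrip

# Equivalent to the old module-level `executor`: apply the common ITIR
# transforms with all lifted calls force-inlined.
inline_executor = roundtrip.Roundtrip(
    transforms=functools.partial(
        itir_transforms.apply_common_transforms,
        lift_mode=itir_transforms.LiftMode.FORCE_INLINE,
    )
)

# Any callable matching `ITIRTransform` works, e.g. the field-view pipeline
# used by the `roundtrip_gtir` backend (it only accepts `itir.Program` inputs,
# hence the `type: ignore` in the patch).
fieldview_executor = roundtrip.Roundtrip(
    transforms=itir_transforms.apply_fieldview_transforms
)

# A backend assembled from one of these executors, following the pattern of
# `roundtrip.default` in the patch.
custom_backend = next_backend.Backend(
    name="roundtrip_custom",
    executor=inline_executor,
    allocator=next_allocators.StandardCPUFieldBufferAllocator(),
    transforms=next_backend.DEFAULT_TRANSFORMS,
)
```

Either executor consumes a `stages.CompilableProgram` and returns a compiled program, as the updated walkthrough shows with `Roundtrip()(pitir)(*jit_args.args, **jit_args.kwargs)`.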