diff --git a/src/gt4py/eve/trees.py b/src/gt4py/eve/trees.py index cd7e71588f..74c5bd41bb 100644 --- a/src/gt4py/eve/trees.py +++ b/src/gt4py/eve/trees.py @@ -133,7 +133,7 @@ def _pre_walk_items( yield from _pre_walk_items(child, __key__=key) -def _pre_walk_values(node: TreeLike) -> Iterable[Tuple[Any]]: +def _pre_walk_values(node: TreeLike) -> Iterable: """Create a pre-order tree traversal iterator of values.""" yield node for child in iter_children_values(node): @@ -153,7 +153,7 @@ def _post_walk_items( yield __key__, node -def _post_walk_values(node: TreeLike) -> Iterable[Tuple[Any]]: +def _post_walk_values(node: TreeLike) -> Iterable: """Create a post-order tree traversal iterator of values.""" if (iter_children_values := getattr(node, "iter_children_values", None)) is not None: for child in iter_children_values(): diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index d9d3d18213..0033f36cab 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -22,6 +22,7 @@ from gt4py.eve import Coerced, NodeTranslator from gt4py.eve.traits import SymbolTableTrait from gt4py.eve.utils import UIDGenerator +from gt4py.next import common from gt4py.next.iterator import ir, type_inference from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift @@ -437,9 +438,12 @@ def _group_offsets( return zip(tags, offsets, strict=True) # type: ignore[return-value] # mypy doesn't infer literal correctly -def update_domains(node: FencilWithTemporaries, offset_provider: Mapping[str, Any]): +def update_domains( + node: FencilWithTemporaries, + offset_provider: Mapping[str, Any], + symbolic_sizes: Optional[dict[str, str]], +): horizontal_sizes = _max_domain_sizes_by_location_type(offset_provider) - closures: list[ir.StencilClosure] = [] domains = dict[str, ir.FunCall]() for closure in reversed(node.fencil.closures): @@ -479,16 +483,29 @@ def update_domains(node: FencilWithTemporaries, offset_provider: Mapping[str, An # cartesian shift dim = offset_provider[offset_name].value consumed_domain.ranges[dim] = consumed_domain.ranges[dim].translate(offset) - elif isinstance(offset_provider[offset_name], gtx.NeighborTableOffsetProvider): + elif isinstance(offset_provider[offset_name], common.Connectivity): # unstructured shift nbt_provider = offset_provider[offset_name] old_axis = nbt_provider.origin_axis.value new_axis = nbt_provider.neighbor_axis.value - consumed_domain.ranges.pop(old_axis) - assert new_axis not in consumed_domain.ranges - consumed_domain.ranges[new_axis] = SymbolicRange( - im.literal("0", ir.INTEGER_INDEX_BUILTIN), - im.literal(str(horizontal_sizes[new_axis]), ir.INTEGER_INDEX_BUILTIN), + + assert new_axis not in consumed_domain.ranges or old_axis == new_axis + + if symbolic_sizes is None: + new_range = SymbolicRange( + im.literal("0", ir.INTEGER_INDEX_BUILTIN), + im.literal( + str(horizontal_sizes[new_axis]), ir.INTEGER_INDEX_BUILTIN + ), + ) + else: + new_range = SymbolicRange( + im.literal("0", ir.INTEGER_INDEX_BUILTIN), + im.ref(symbolic_sizes[new_axis]), + ) + consumed_domain.ranges = dict( + (axis, range_) if axis != old_axis else (new_axis, new_range) + for axis, range_ in consumed_domain.ranges.items() ) else: raise NotImplementedError @@ -570,7 +587,11 @@ class CreateGlobalTmps(NodeTranslator): """ def visit_FencilDefinition( - self, node: ir.FencilDefinition, *, offset_provider: Mapping[str, Any] + self, + node: ir.FencilDefinition, + *, + offset_provider: Mapping[str, Any], + symbolic_sizes: Optional[dict[str, str]], ) -> FencilWithTemporaries: # Split closures on lifted function calls and introduce temporaries res = split_closures(node, offset_provider=offset_provider) @@ -581,6 +602,6 @@ def visit_FencilDefinition( # Perform an eta-reduction which should put all calls at the highest level of a closure res = EtaReduction().visit(res) # Perform a naive extent analysis to compute domain sizes of closures and temporaries - res = update_domains(res, offset_provider) + res = update_domains(res, offset_provider, symbolic_sizes) # Use type inference to determine the data type of the temporaries return collect_tmps_info(res, offset_provider=offset_provider) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 2e05391634..08897861c2 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -13,6 +13,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later import enum +from typing import Optional from gt4py.next.iterator import ir from gt4py.next.iterator.transforms import simple_inline_heuristic @@ -81,6 +82,7 @@ def apply_common_transforms( common_subexpression_elimination=True, force_inline_lambda_args=False, unconditionally_collapse_tuples=False, + symbolic_domain_sizes: Optional[dict[str, str]] = None, ): if lift_mode is None: lift_mode = LiftMode.FORCE_INLINE @@ -147,7 +149,9 @@ def apply_common_transforms( if lift_mode != LiftMode.FORCE_INLINE: assert offset_provider is not None - ir = CreateGlobalTmps().visit(ir, offset_provider=offset_provider) + ir = CreateGlobalTmps().visit( + ir, offset_provider=offset_provider, symbolic_sizes=symbolic_domain_sizes + ) ir = InlineLifts().visit(ir) # If after creating temporaries, the scan is not at the top, we inline. # The following example doesn't have a lift around the shift, i.e. temporary pass will not extract it. diff --git a/src/gt4py/next/iterator/type_inference.py b/src/gt4py/next/iterator/type_inference.py index 68627cfd89..d65f67b266 100644 --- a/src/gt4py/next/iterator/type_inference.py +++ b/src/gt4py/next/iterator/type_inference.py @@ -567,9 +567,7 @@ def _infer_shift_location_types(shift_args, offset_provider, constraints): axis = offset_provider[offset] if isinstance(axis, gtx.Dimension): continue # Cartesian shifts don’t change the location type - elif isinstance( - axis, (gtx.NeighborTableOffsetProvider, gtx.StridedNeighborOffsetProvider) - ): + elif isinstance(axis, Connectivity): assert ( axis.origin_axis.kind == axis.neighbor_axis.kind @@ -964,7 +962,7 @@ def visit_FencilDefinition( def _save_types_to_annex(node: ir.Node, types: dict[int, Type]) -> None: for child_node in node.pre_walk_values().if_isinstance(*TYPED_IR_NODES): try: - child_node.annex.type = types[id(child_node)] # type: ignore[attr-defined] + child_node.annex.type = types[id(child_node)] except KeyError: if not ( isinstance(child_node, ir.SymRef) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_backend.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_backend.py deleted file mode 100644 index 4183f52550..0000000000 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_backend.py +++ /dev/null @@ -1,77 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later - -from typing import Any - -import gt4py.next.iterator.ir as itir -from gt4py.eve import codegen -from gt4py.eve.exceptions import EveValueError -from gt4py.next.iterator.transforms.pass_manager import apply_common_transforms -from gt4py.next.program_processors.codegens.gtfn.codegen import GTFNCodegen, GTFNIMCodegen -from gt4py.next.program_processors.codegens.gtfn.gtfn_ir_to_gtfn_im_ir import GTFN_IM_lowering -from gt4py.next.program_processors.codegens.gtfn.itir_to_gtfn_ir import GTFN_lowering - - -def _lower( - program: itir.FencilDefinition, enable_itir_transforms: bool, do_unroll: bool, **kwargs: Any -): - offset_provider = kwargs.get("offset_provider") - assert isinstance(offset_provider, dict) - if enable_itir_transforms: - program = apply_common_transforms( - program, - lift_mode=kwargs.get("lift_mode"), - offset_provider=offset_provider, - unroll_reduce=do_unroll, - unconditionally_collapse_tuples=True, # sid::composite (via hymap) supports assigning from tuple with more elements to tuple with fewer elements - ) - gtfn_ir = GTFN_lowering.apply( - program, - offset_provider=offset_provider, - column_axis=kwargs.get("column_axis"), - ) - return gtfn_ir - - -def generate( - program: itir.FencilDefinition, enable_itir_transforms: bool = True, **kwargs: Any -) -> str: - if kwargs.get("imperative", False): - try: - gtfn_ir = _lower( - program=program, - enable_itir_transforms=enable_itir_transforms, - do_unroll=False, - **kwargs, - ) - except EveValueError: - # if we don't unroll, there may be lifts left in the itir which can't be lowered to - # gtfn. In this case, just retry with unrolled reductions. - gtfn_ir = _lower( - program=program, - enable_itir_transforms=enable_itir_transforms, - do_unroll=True, - **kwargs, - ) - gtfn_im_ir = GTFN_IM_lowering().visit(node=gtfn_ir, **kwargs) - generated_code = GTFNIMCodegen.apply(gtfn_im_ir, **kwargs) - else: - gtfn_ir = _lower( - program=program, - enable_itir_transforms=enable_itir_transforms, - do_unroll=True, - **kwargs, - ) - generated_code = GTFNCodegen.apply(gtfn_ir, **kwargs) - return codegen.format_source("cpp", generated_code, style="LLVM") diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index 4abdaa6eea..718fef72af 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -15,21 +15,24 @@ from __future__ import annotations import dataclasses +import functools import warnings from typing import Any, Final, Optional import numpy as np from gt4py._core import definitions as core_defs -from gt4py.eve import trees, utils +from gt4py.eve import codegen, trees, utils from gt4py.next import common from gt4py.next.common import Connectivity, Dimension from gt4py.next.ffront import fbuiltins from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.transforms import LiftMode +from gt4py.next.iterator.transforms import LiftMode, pass_manager from gt4py.next.otf import languages, stages, step_types, workflow from gt4py.next.otf.binding import cpp_interface, interface -from gt4py.next.program_processors.codegens.gtfn import gtfn_backend +from gt4py.next.program_processors.codegens.gtfn.codegen import GTFNCodegen, GTFNIMCodegen +from gt4py.next.program_processors.codegens.gtfn.gtfn_ir_to_gtfn_im_ir import GTFN_IM_lowering +from gt4py.next.program_processors.codegens.gtfn.itir_to_gtfn_ir import GTFN_lowering from gt4py.next.type_system import type_specifications as ts, type_translation @@ -54,6 +57,7 @@ class GTFNTranslationStep( use_imperative_backend: bool = False lift_mode: Optional[LiftMode] = None device_type: core_defs.DeviceType = core_defs.DeviceType.CPU + symbolic_domain_sizes: Optional[dict[str, str]] = None def _default_language_settings(self) -> languages.LanguageWithHeaderFilesSettings: match self.device_type: @@ -171,6 +175,70 @@ def _process_connectivity_args( return parameters, arg_exprs + def _preprocess_program( + self, + program: itir.FencilDefinition, + offset_provider: dict[str, Connectivity | Dimension], + runtime_lift_mode: Optional[LiftMode] = None, + ) -> itir.FencilDefinition: + # TODO(tehrengruber): Remove `lift_mode` from call interface. It has been implicitly added + # to the interface of all (or at least all of concern) backends, but instead should be + # configured in the backend itself (like it is here), until then we respect the argument + # here and warn the user if it differs from the one configured. + lift_mode = runtime_lift_mode or self.lift_mode + if lift_mode != self.lift_mode: + warnings.warn( + f"GTFN Backend was configured for LiftMode `{str(self.lift_mode)}`, but " + f"overriden to be {str(runtime_lift_mode)} at runtime." + ) + + if not self.enable_itir_transforms: + return program + + apply_common_transforms = functools.partial( + pass_manager.apply_common_transforms, + lift_mode=lift_mode, + offset_provider=offset_provider, + # sid::composite (via hymap) supports assigning from tuple with more elements to tuple with fewer elements + unconditionally_collapse_tuples=True, + symbolic_domain_sizes=self.symbolic_domain_sizes, + ) + + new_program = apply_common_transforms( + program, unroll_reduce=not self.use_imperative_backend + ) + + if self.use_imperative_backend and any( + node.id == "neighbors" + for node in new_program.pre_walk_values().if_isinstance(itir.SymRef) + ): + # if we don't unroll, there may be lifts left in the itir which can't be lowered to + # gtfn. In this case, just retry with unrolled reductions. + new_program = apply_common_transforms(program, unroll_reduce=True) + + return new_program + + def generate_stencil_source( + self, + program: itir.FencilDefinition, + offset_provider: dict[str, Connectivity | Dimension], + column_axis: Optional[common.Dimension], + runtime_lift_mode: Optional[LiftMode] = None, + ) -> str: + new_program = self._preprocess_program(program, offset_provider, runtime_lift_mode) + gtfn_ir = GTFN_lowering.apply( + new_program, + offset_provider=offset_provider, + column_axis=column_axis, + ) + + if self.use_imperative_backend: + gtfn_im_ir = GTFN_IM_lowering().visit(node=gtfn_ir) + generated_code = GTFNIMCodegen.apply(gtfn_im_ir) + else: + generated_code = GTFNCodegen.apply(gtfn_ir) + return codegen.format_source("cpp", generated_code, style="LLVM") + def __call__( self, inp: stages.ProgramCall, @@ -190,18 +258,6 @@ def __call__( inp.kwargs["offset_provider"] ) - # TODO(tehrengruber): Remove `lift_mode` from call interface. It has been implicitly added - # to the interface of all (or at least all of concern) backends, but instead should be - # configured in the backend itself (like it is here), until then we respect the argument - # here and warn the user if it differs from the one configured. - runtime_lift_mode = inp.kwargs.pop("lift_mode", None) - lift_mode = runtime_lift_mode or self.lift_mode - if runtime_lift_mode != self.lift_mode: - warnings.warn( - f"GTFN Backend was configured for LiftMode `{str(self.lift_mode)}`, but " - "overriden to be {str(runtime_lift_mode)} at runtime." - ) - # combine into a format that is aligned with what the backend expects parameters: list[interface.Parameter] = regular_parameters + connectivity_parameters backend_arg = self._backend_type() @@ -213,12 +269,11 @@ def __call__( f"{', '.join(connectivity_args_expr)})({', '.join(args_expr)});" ) decl_src = cpp_interface.render_function_declaration(function, body=decl_body) - stencil_src = gtfn_backend.generate( + stencil_src = self.generate_stencil_source( program, - enable_itir_transforms=self.enable_itir_transforms, - lift_mode=lift_mode, - imperative=self.use_imperative_backend, - **inp.kwargs, + inp.kwargs["offset_provider"], + inp.kwargs.get("column_axis", None), + inp.kwargs.get("lift_mode", None), ) source_code = interface.format_source( self._language_settings(), diff --git a/src/gt4py/next/program_processors/formatters/gtfn.py b/src/gt4py/next/program_processors/formatters/gtfn.py index f9fa154641..27dec77ed1 100644 --- a/src/gt4py/next/program_processors/formatters/gtfn.py +++ b/src/gt4py/next/program_processors/formatters/gtfn.py @@ -15,10 +15,19 @@ from typing import Any from gt4py.next.iterator import ir as itir -from gt4py.next.program_processors.codegens.gtfn.gtfn_backend import generate +from gt4py.next.program_processors.codegens.gtfn.gtfn_module import GTFNTranslationStep from gt4py.next.program_processors.processor_interface import program_formatter +from gt4py.next.program_processors.runners.gtfn import gtfn_executor @program_formatter def format_cpp(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> str: - return generate(program, **kwargs) + # TODO(tehrengruber): This is a little ugly. Revisit. + gtfn_translation = gtfn_executor.otf_workflow.translation + assert isinstance(gtfn_translation, GTFNTranslationStep) + return gtfn_translation.generate_stencil_source( + program, + offset_provider=kwargs.get("offset_provider", None), + column_axis=kwargs.get("column_axis", None), + runtime_lift_mode=kwargs.get("lift_mode", None), + ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py new file mode 100644 index 0000000000..da0945fe96 --- /dev/null +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py @@ -0,0 +1,113 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import pytest +from numpy import int32, int64 + +from gt4py import next as gtx +from gt4py.next import common +from gt4py.next.iterator.transforms import LiftMode, apply_common_transforms +from gt4py.next.program_processors import otf_compile_executor +from gt4py.next.program_processors.runners.gtfn import run_gtfn_with_temporaries +from tests.next_tests.integration_tests.cases import Case +from tests.next_tests.toy_connectivity import Cell, Edge + +from next_tests.integration_tests import cases +from next_tests.integration_tests.cases import E2V, KDim, Vertex, cartesian_case, unstructured_case +from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( + reduction_setup, +) + + +@pytest.fixture +def run_gtfn_with_temporaries_and_symbolic_sizes(): + return otf_compile_executor.OTFBackend( + executor=otf_compile_executor.OTFCompileExecutor( + name="run_gtfn_with_temporaries_and_sizes", + otf_workflow=run_gtfn_with_temporaries.executor.otf_workflow.replace( + translation=run_gtfn_with_temporaries.executor.otf_workflow.translation.replace( + symbolic_domain_sizes={ + "Cell": "num_cells", + "Edge": "num_edges", + "Vertex": "num_vertices", + }, + ), + ), + ), + allocator=run_gtfn_with_temporaries.allocator, + ) + + +@pytest.fixture +def testee(): + @gtx.field_operator + def testee_op(a: cases.VField) -> cases.EField: + amul = a * 2 + return amul(E2V[0]) + amul(E2V[1]) + + @gtx.program + def prog( + a: cases.VField, + out: cases.EField, + num_vertices: int32, + num_edges: int64, + num_cells: int32, + ): + testee_op(a, out=out) + + return prog + + +def test_verification(testee, run_gtfn_with_temporaries_and_symbolic_sizes, reduction_setup): + unstructured_case = Case( + run_gtfn_with_temporaries_and_symbolic_sizes, + offset_provider=reduction_setup.offset_provider, + default_sizes={ + Vertex: reduction_setup.num_vertices, + Edge: reduction_setup.num_edges, + Cell: reduction_setup.num_cells, + KDim: reduction_setup.k_levels, + }, + grid_type=common.GridType.UNSTRUCTURED, + ) + + a = cases.allocate(unstructured_case, testee, "a")() + out = cases.allocate(unstructured_case, testee, "out")() + + first_nbs, second_nbs = (reduction_setup.offset_provider["E2V"].table[:, i] for i in [0, 1]) + ref = (a.ndarray * 2)[first_nbs] + (a.ndarray * 2)[second_nbs] + + cases.verify( + unstructured_case, + testee, + a, + out, + reduction_setup.num_vertices, + reduction_setup.num_edges, + reduction_setup.num_cells, + inout=out, + ref=ref, + ) + + +def test_temporary_symbols(testee, reduction_setup): + itir_with_tmp = apply_common_transforms( + testee.itir, + lift_mode=LiftMode.FORCE_TEMPORARIES, + offset_provider=reduction_setup.offset_provider, + ) + + params = ["num_vertices", "num_edges", "num_cells"] + for param in params: + assert any([param == str(p) for p in itir_with_tmp.fencil.params]) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/anton_lap.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/anton_lap.py index e851e7b130..5af4605988 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/anton_lap.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/anton_lap.py @@ -18,7 +18,7 @@ from gt4py.next.iterator.builtins import * from gt4py.next.iterator.runtime import closure, fundef, offset from gt4py.next.iterator.tracing import trace_fencil_definition -from gt4py.next.program_processors.codegens.gtfn.gtfn_backend import generate +from gt4py.next.program_processors.runners.gtfn import run_gtfn @fundef @@ -69,7 +69,9 @@ def lap_fencil(i_size, j_size, k_size, i_off, j_off, k_off, out, inp): output_file = sys.argv[1] prog = trace_fencil_definition(lap_fencil, [None] * 8, use_arg_types=False) - generated_code = generate(prog, offset_provider={"i": IDim, "j": JDim}) + generated_code = run_gtfn.executor.otf_workflow.translation.generate_stencil_source( + prog, offset_provider={"i": IDim, "j": JDim}, column_axis=None + ) with open(output_file, "w+") as output: output.write(generated_code) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil.py index 33c7d5baa7..3e8b88ac66 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil.py @@ -18,7 +18,7 @@ from gt4py.next.iterator.builtins import * from gt4py.next.iterator.runtime import closure, fundef from gt4py.next.iterator.tracing import trace_fencil_definition -from gt4py.next.program_processors.codegens.gtfn.gtfn_backend import generate +from gt4py.next.program_processors.runners.gtfn import run_gtfn IDim = gtx.Dimension("IDim") @@ -48,7 +48,9 @@ def copy_fencil(isize, jsize, ksize, inp, out): output_file = sys.argv[1] prog = trace_fencil_definition(copy_fencil, [None] * 5, use_arg_types=False) - generated_code = generate(prog, offset_provider={}) + generated_code = run_gtfn.executor.otf_workflow.translation.generate_stencil_source( + prog, offset_provider={}, column_axis=None + ) with open(output_file, "w+") as output: output.write(generated_code) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_field_view.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_field_view.py index f7472d4ac3..fdc57449ee 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_field_view.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_field_view.py @@ -18,7 +18,7 @@ import gt4py.next as gtx from gt4py.next import Field, field_operator, program -from gt4py.next.program_processors.codegens.gtfn.gtfn_backend import generate +from gt4py.next.program_processors.runners.gtfn import run_gtfn IDim = gtx.Dimension("IDim") @@ -47,7 +47,9 @@ def copy_program( output_file = sys.argv[1] prog = copy_program.itir - generated_code = generate(prog, offset_provider={}) + generated_code = run_gtfn.executor.otf_workflow.translation.generate_stencil_source( + prog, offset_provider={}, column_axis=None + ) with open(output_file, "w+") as output: output.write(generated_code) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla.py index 1dfd74baca..abc3755dca 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla.py @@ -19,7 +19,7 @@ from gt4py.next.iterator.builtins import * from gt4py.next.iterator.runtime import closure, fundef, offset from gt4py.next.iterator.tracing import trace_fencil_definition -from gt4py.next.program_processors.codegens.gtfn.gtfn_backend import generate +from gt4py.next.program_processors.runners.gtfn import run_gtfn, run_gtfn_imperative E2V = offset("E2V") @@ -92,13 +92,20 @@ def mapped_index(_, __) -> int: output_file = sys.argv[1] imperative = sys.argv[2].lower() == "true" + if imperative: + backend = run_gtfn_imperative + else: + backend = run_gtfn + # prog = trace(zavgS_fencil, [None] * 4) # TODO allow generating of 2 fencils prog = trace_fencil_definition(nabla_fencil, [None] * 7, use_arg_types=False) offset_provider = { "V2E": DummyConnectivity(max_neighbors=6, has_skip_values=True), "E2V": DummyConnectivity(max_neighbors=2, has_skip_values=False), } - generated_code = generate(prog, offset_provider=offset_provider, imperative=imperative) + generated_code = backend.executor.otf_workflow.translation.generate_stencil_source( + prog, offset_provider=offset_provider, column_axis=None + ) with open(output_file, "w+") as output: output.write(generated_code) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve.py index 578a19faab..9755774fd0 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve.py @@ -19,7 +19,7 @@ from gt4py.next.iterator.runtime import closure, fundef from gt4py.next.iterator.tracing import trace_fencil_definition from gt4py.next.iterator.transforms import LiftMode -from gt4py.next.program_processors.codegens.gtfn.gtfn_backend import generate +from gt4py.next.program_processors.runners.gtfn import run_gtfn IDim = gtx.Dimension("IDim") @@ -67,10 +67,10 @@ def tridiagonal_solve_fencil(isize, jsize, ksize, a, b, c, d, x): prog = trace_fencil_definition(tridiagonal_solve_fencil, [None] * 8, use_arg_types=False) offset_provider = {"I": gtx.Dimension("IDim"), "J": gtx.Dimension("JDim")} - generated_code = generate( + generated_code = run_gtfn.executor.otf_workflow.translation.generate_stencil_source( prog, offset_provider=offset_provider, - lift_mode=LiftMode.SIMPLE_HEURISTIC, + runtime_lift_mode=LiftMode.SIMPLE_HEURISTIC, column_axis=KDim, ) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py index 86c3c98c62..5c2802f90c 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py @@ -323,7 +323,7 @@ def test_update_cartesian_domains(): for a, s in (("JDim", "j"), ("KDim", "k")) ], ) - actual = update_domains(testee, {"I": gtx.Dimension("IDim")}) + actual = update_domains(testee, {"I": gtx.Dimension("IDim")}, symbolic_sizes=None) assert actual == expected