diff --git a/src/gt4py/eve/trees.py b/src/gt4py/eve/trees.py
index cd7e71588f..74c5bd41bb 100644
--- a/src/gt4py/eve/trees.py
+++ b/src/gt4py/eve/trees.py
@@ -133,7 +133,7 @@ def _pre_walk_items(
yield from _pre_walk_items(child, __key__=key)
-def _pre_walk_values(node: TreeLike) -> Iterable[Tuple[Any]]:
+def _pre_walk_values(node: TreeLike) -> Iterable:
"""Create a pre-order tree traversal iterator of values."""
yield node
for child in iter_children_values(node):
@@ -153,7 +153,7 @@ def _post_walk_items(
yield __key__, node
-def _post_walk_values(node: TreeLike) -> Iterable[Tuple[Any]]:
+def _post_walk_values(node: TreeLike) -> Iterable:
"""Create a post-order tree traversal iterator of values."""
if (iter_children_values := getattr(node, "iter_children_values", None)) is not None:
for child in iter_children_values():
diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py
index d9d3d18213..0033f36cab 100644
--- a/src/gt4py/next/iterator/transforms/global_tmps.py
+++ b/src/gt4py/next/iterator/transforms/global_tmps.py
@@ -22,6 +22,7 @@
from gt4py.eve import Coerced, NodeTranslator
from gt4py.eve.traits import SymbolTableTrait
from gt4py.eve.utils import UIDGenerator
+from gt4py.next import common
from gt4py.next.iterator import ir, type_inference
from gt4py.next.iterator.ir_utils import ir_makers as im
from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift
@@ -437,9 +438,12 @@ def _group_offsets(
return zip(tags, offsets, strict=True) # type: ignore[return-value] # mypy doesn't infer literal correctly
-def update_domains(node: FencilWithTemporaries, offset_provider: Mapping[str, Any]):
+def update_domains(
+ node: FencilWithTemporaries,
+ offset_provider: Mapping[str, Any],
+ symbolic_sizes: Optional[dict[str, str]],
+):
horizontal_sizes = _max_domain_sizes_by_location_type(offset_provider)
-
closures: list[ir.StencilClosure] = []
domains = dict[str, ir.FunCall]()
for closure in reversed(node.fencil.closures):
@@ -479,16 +483,29 @@ def update_domains(node: FencilWithTemporaries, offset_provider: Mapping[str, An
# cartesian shift
dim = offset_provider[offset_name].value
consumed_domain.ranges[dim] = consumed_domain.ranges[dim].translate(offset)
- elif isinstance(offset_provider[offset_name], gtx.NeighborTableOffsetProvider):
+ elif isinstance(offset_provider[offset_name], common.Connectivity):
# unstructured shift
nbt_provider = offset_provider[offset_name]
old_axis = nbt_provider.origin_axis.value
new_axis = nbt_provider.neighbor_axis.value
- consumed_domain.ranges.pop(old_axis)
- assert new_axis not in consumed_domain.ranges
- consumed_domain.ranges[new_axis] = SymbolicRange(
- im.literal("0", ir.INTEGER_INDEX_BUILTIN),
- im.literal(str(horizontal_sizes[new_axis]), ir.INTEGER_INDEX_BUILTIN),
+
+ assert new_axis not in consumed_domain.ranges or old_axis == new_axis
+
+ if symbolic_sizes is None:
+ new_range = SymbolicRange(
+ im.literal("0", ir.INTEGER_INDEX_BUILTIN),
+ im.literal(
+ str(horizontal_sizes[new_axis]), ir.INTEGER_INDEX_BUILTIN
+ ),
+ )
+ else:
+ new_range = SymbolicRange(
+ im.literal("0", ir.INTEGER_INDEX_BUILTIN),
+ im.ref(symbolic_sizes[new_axis]),
+ )
+ consumed_domain.ranges = dict(
+ (axis, range_) if axis != old_axis else (new_axis, new_range)
+ for axis, range_ in consumed_domain.ranges.items()
)
else:
raise NotImplementedError
@@ -570,7 +587,11 @@ class CreateGlobalTmps(NodeTranslator):
"""
def visit_FencilDefinition(
- self, node: ir.FencilDefinition, *, offset_provider: Mapping[str, Any]
+ self,
+ node: ir.FencilDefinition,
+ *,
+ offset_provider: Mapping[str, Any],
+ symbolic_sizes: Optional[dict[str, str]],
) -> FencilWithTemporaries:
# Split closures on lifted function calls and introduce temporaries
res = split_closures(node, offset_provider=offset_provider)
@@ -581,6 +602,6 @@ def visit_FencilDefinition(
# Perform an eta-reduction which should put all calls at the highest level of a closure
res = EtaReduction().visit(res)
# Perform a naive extent analysis to compute domain sizes of closures and temporaries
- res = update_domains(res, offset_provider)
+ res = update_domains(res, offset_provider, symbolic_sizes)
# Use type inference to determine the data type of the temporaries
return collect_tmps_info(res, offset_provider=offset_provider)
diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py
index 2e05391634..08897861c2 100644
--- a/src/gt4py/next/iterator/transforms/pass_manager.py
+++ b/src/gt4py/next/iterator/transforms/pass_manager.py
@@ -13,6 +13,7 @@
# SPDX-License-Identifier: GPL-3.0-or-later
import enum
+from typing import Optional
from gt4py.next.iterator import ir
from gt4py.next.iterator.transforms import simple_inline_heuristic
@@ -81,6 +82,7 @@ def apply_common_transforms(
common_subexpression_elimination=True,
force_inline_lambda_args=False,
unconditionally_collapse_tuples=False,
+ symbolic_domain_sizes: Optional[dict[str, str]] = None,
):
if lift_mode is None:
lift_mode = LiftMode.FORCE_INLINE
@@ -147,7 +149,9 @@ def apply_common_transforms(
if lift_mode != LiftMode.FORCE_INLINE:
assert offset_provider is not None
- ir = CreateGlobalTmps().visit(ir, offset_provider=offset_provider)
+ ir = CreateGlobalTmps().visit(
+ ir, offset_provider=offset_provider, symbolic_sizes=symbolic_domain_sizes
+ )
ir = InlineLifts().visit(ir)
# If after creating temporaries, the scan is not at the top, we inline.
# The following example doesn't have a lift around the shift, i.e. temporary pass will not extract it.
diff --git a/src/gt4py/next/iterator/type_inference.py b/src/gt4py/next/iterator/type_inference.py
index 68627cfd89..d65f67b266 100644
--- a/src/gt4py/next/iterator/type_inference.py
+++ b/src/gt4py/next/iterator/type_inference.py
@@ -567,9 +567,7 @@ def _infer_shift_location_types(shift_args, offset_provider, constraints):
axis = offset_provider[offset]
if isinstance(axis, gtx.Dimension):
continue # Cartesian shifts don’t change the location type
- elif isinstance(
- axis, (gtx.NeighborTableOffsetProvider, gtx.StridedNeighborOffsetProvider)
- ):
+ elif isinstance(axis, Connectivity):
assert (
axis.origin_axis.kind
== axis.neighbor_axis.kind
@@ -964,7 +962,7 @@ def visit_FencilDefinition(
def _save_types_to_annex(node: ir.Node, types: dict[int, Type]) -> None:
for child_node in node.pre_walk_values().if_isinstance(*TYPED_IR_NODES):
try:
- child_node.annex.type = types[id(child_node)] # type: ignore[attr-defined]
+ child_node.annex.type = types[id(child_node)]
except KeyError:
if not (
isinstance(child_node, ir.SymRef)
diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_backend.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_backend.py
deleted file mode 100644
index 4183f52550..0000000000
--- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_backend.py
+++ /dev/null
@@ -1,77 +0,0 @@
-# GT4Py - GridTools Framework
-#
-# Copyright (c) 2014-2023, ETH Zurich
-# All rights reserved.
-#
-# This file is part of the GT4Py project and the GridTools framework.
-# GT4Py is free software: you can redistribute it and/or modify it under
-# the terms of the GNU General Public License as published by the
-# Free Software Foundation, either version 3 of the License, or any later
-# version. See the LICENSE.txt file at the top-level directory of this
-# distribution for a copy of the license or check .
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-from typing import Any
-
-import gt4py.next.iterator.ir as itir
-from gt4py.eve import codegen
-from gt4py.eve.exceptions import EveValueError
-from gt4py.next.iterator.transforms.pass_manager import apply_common_transforms
-from gt4py.next.program_processors.codegens.gtfn.codegen import GTFNCodegen, GTFNIMCodegen
-from gt4py.next.program_processors.codegens.gtfn.gtfn_ir_to_gtfn_im_ir import GTFN_IM_lowering
-from gt4py.next.program_processors.codegens.gtfn.itir_to_gtfn_ir import GTFN_lowering
-
-
-def _lower(
- program: itir.FencilDefinition, enable_itir_transforms: bool, do_unroll: bool, **kwargs: Any
-):
- offset_provider = kwargs.get("offset_provider")
- assert isinstance(offset_provider, dict)
- if enable_itir_transforms:
- program = apply_common_transforms(
- program,
- lift_mode=kwargs.get("lift_mode"),
- offset_provider=offset_provider,
- unroll_reduce=do_unroll,
- unconditionally_collapse_tuples=True, # sid::composite (via hymap) supports assigning from tuple with more elements to tuple with fewer elements
- )
- gtfn_ir = GTFN_lowering.apply(
- program,
- offset_provider=offset_provider,
- column_axis=kwargs.get("column_axis"),
- )
- return gtfn_ir
-
-
-def generate(
- program: itir.FencilDefinition, enable_itir_transforms: bool = True, **kwargs: Any
-) -> str:
- if kwargs.get("imperative", False):
- try:
- gtfn_ir = _lower(
- program=program,
- enable_itir_transforms=enable_itir_transforms,
- do_unroll=False,
- **kwargs,
- )
- except EveValueError:
- # if we don't unroll, there may be lifts left in the itir which can't be lowered to
- # gtfn. In this case, just retry with unrolled reductions.
- gtfn_ir = _lower(
- program=program,
- enable_itir_transforms=enable_itir_transforms,
- do_unroll=True,
- **kwargs,
- )
- gtfn_im_ir = GTFN_IM_lowering().visit(node=gtfn_ir, **kwargs)
- generated_code = GTFNIMCodegen.apply(gtfn_im_ir, **kwargs)
- else:
- gtfn_ir = _lower(
- program=program,
- enable_itir_transforms=enable_itir_transforms,
- do_unroll=True,
- **kwargs,
- )
- generated_code = GTFNCodegen.apply(gtfn_ir, **kwargs)
- return codegen.format_source("cpp", generated_code, style="LLVM")
diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py
index 4abdaa6eea..718fef72af 100644
--- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py
+++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py
@@ -15,21 +15,24 @@
from __future__ import annotations
import dataclasses
+import functools
import warnings
from typing import Any, Final, Optional
import numpy as np
from gt4py._core import definitions as core_defs
-from gt4py.eve import trees, utils
+from gt4py.eve import codegen, trees, utils
from gt4py.next import common
from gt4py.next.common import Connectivity, Dimension
from gt4py.next.ffront import fbuiltins
from gt4py.next.iterator import ir as itir
-from gt4py.next.iterator.transforms import LiftMode
+from gt4py.next.iterator.transforms import LiftMode, pass_manager
from gt4py.next.otf import languages, stages, step_types, workflow
from gt4py.next.otf.binding import cpp_interface, interface
-from gt4py.next.program_processors.codegens.gtfn import gtfn_backend
+from gt4py.next.program_processors.codegens.gtfn.codegen import GTFNCodegen, GTFNIMCodegen
+from gt4py.next.program_processors.codegens.gtfn.gtfn_ir_to_gtfn_im_ir import GTFN_IM_lowering
+from gt4py.next.program_processors.codegens.gtfn.itir_to_gtfn_ir import GTFN_lowering
from gt4py.next.type_system import type_specifications as ts, type_translation
@@ -54,6 +57,7 @@ class GTFNTranslationStep(
use_imperative_backend: bool = False
lift_mode: Optional[LiftMode] = None
device_type: core_defs.DeviceType = core_defs.DeviceType.CPU
+ symbolic_domain_sizes: Optional[dict[str, str]] = None
def _default_language_settings(self) -> languages.LanguageWithHeaderFilesSettings:
match self.device_type:
@@ -171,6 +175,70 @@ def _process_connectivity_args(
return parameters, arg_exprs
+ def _preprocess_program(
+ self,
+ program: itir.FencilDefinition,
+ offset_provider: dict[str, Connectivity | Dimension],
+ runtime_lift_mode: Optional[LiftMode] = None,
+ ) -> itir.FencilDefinition:
+ # TODO(tehrengruber): Remove `lift_mode` from call interface. It has been implicitly added
+ # to the interface of all (or at least all of concern) backends, but instead should be
+ # configured in the backend itself (like it is here), until then we respect the argument
+ # here and warn the user if it differs from the one configured.
+ lift_mode = runtime_lift_mode or self.lift_mode
+ if lift_mode != self.lift_mode:
+ warnings.warn(
+ f"GTFN Backend was configured for LiftMode `{str(self.lift_mode)}`, but "
+ f"overriden to be {str(runtime_lift_mode)} at runtime."
+ )
+
+ if not self.enable_itir_transforms:
+ return program
+
+ apply_common_transforms = functools.partial(
+ pass_manager.apply_common_transforms,
+ lift_mode=lift_mode,
+ offset_provider=offset_provider,
+ # sid::composite (via hymap) supports assigning from tuple with more elements to tuple with fewer elements
+ unconditionally_collapse_tuples=True,
+ symbolic_domain_sizes=self.symbolic_domain_sizes,
+ )
+
+ new_program = apply_common_transforms(
+ program, unroll_reduce=not self.use_imperative_backend
+ )
+
+ if self.use_imperative_backend and any(
+ node.id == "neighbors"
+ for node in new_program.pre_walk_values().if_isinstance(itir.SymRef)
+ ):
+ # if we don't unroll, there may be lifts left in the itir which can't be lowered to
+ # gtfn. In this case, just retry with unrolled reductions.
+ new_program = apply_common_transforms(program, unroll_reduce=True)
+
+ return new_program
+
+ def generate_stencil_source(
+ self,
+ program: itir.FencilDefinition,
+ offset_provider: dict[str, Connectivity | Dimension],
+ column_axis: Optional[common.Dimension],
+ runtime_lift_mode: Optional[LiftMode] = None,
+ ) -> str:
+ new_program = self._preprocess_program(program, offset_provider, runtime_lift_mode)
+ gtfn_ir = GTFN_lowering.apply(
+ new_program,
+ offset_provider=offset_provider,
+ column_axis=column_axis,
+ )
+
+ if self.use_imperative_backend:
+ gtfn_im_ir = GTFN_IM_lowering().visit(node=gtfn_ir)
+ generated_code = GTFNIMCodegen.apply(gtfn_im_ir)
+ else:
+ generated_code = GTFNCodegen.apply(gtfn_ir)
+ return codegen.format_source("cpp", generated_code, style="LLVM")
+
def __call__(
self,
inp: stages.ProgramCall,
@@ -190,18 +258,6 @@ def __call__(
inp.kwargs["offset_provider"]
)
- # TODO(tehrengruber): Remove `lift_mode` from call interface. It has been implicitly added
- # to the interface of all (or at least all of concern) backends, but instead should be
- # configured in the backend itself (like it is here), until then we respect the argument
- # here and warn the user if it differs from the one configured.
- runtime_lift_mode = inp.kwargs.pop("lift_mode", None)
- lift_mode = runtime_lift_mode or self.lift_mode
- if runtime_lift_mode != self.lift_mode:
- warnings.warn(
- f"GTFN Backend was configured for LiftMode `{str(self.lift_mode)}`, but "
- "overriden to be {str(runtime_lift_mode)} at runtime."
- )
-
# combine into a format that is aligned with what the backend expects
parameters: list[interface.Parameter] = regular_parameters + connectivity_parameters
backend_arg = self._backend_type()
@@ -213,12 +269,11 @@ def __call__(
f"{', '.join(connectivity_args_expr)})({', '.join(args_expr)});"
)
decl_src = cpp_interface.render_function_declaration(function, body=decl_body)
- stencil_src = gtfn_backend.generate(
+ stencil_src = self.generate_stencil_source(
program,
- enable_itir_transforms=self.enable_itir_transforms,
- lift_mode=lift_mode,
- imperative=self.use_imperative_backend,
- **inp.kwargs,
+ inp.kwargs["offset_provider"],
+ inp.kwargs.get("column_axis", None),
+ inp.kwargs.get("lift_mode", None),
)
source_code = interface.format_source(
self._language_settings(),
diff --git a/src/gt4py/next/program_processors/formatters/gtfn.py b/src/gt4py/next/program_processors/formatters/gtfn.py
index f9fa154641..27dec77ed1 100644
--- a/src/gt4py/next/program_processors/formatters/gtfn.py
+++ b/src/gt4py/next/program_processors/formatters/gtfn.py
@@ -15,10 +15,19 @@
from typing import Any
from gt4py.next.iterator import ir as itir
-from gt4py.next.program_processors.codegens.gtfn.gtfn_backend import generate
+from gt4py.next.program_processors.codegens.gtfn.gtfn_module import GTFNTranslationStep
from gt4py.next.program_processors.processor_interface import program_formatter
+from gt4py.next.program_processors.runners.gtfn import gtfn_executor
@program_formatter
def format_cpp(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> str:
- return generate(program, **kwargs)
+ # TODO(tehrengruber): This is a little ugly. Revisit.
+ gtfn_translation = gtfn_executor.otf_workflow.translation
+ assert isinstance(gtfn_translation, GTFNTranslationStep)
+ return gtfn_translation.generate_stencil_source(
+ program,
+ offset_provider=kwargs.get("offset_provider", None),
+ column_axis=kwargs.get("column_axis", None),
+ runtime_lift_mode=kwargs.get("lift_mode", None),
+ )
diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py
new file mode 100644
index 0000000000..da0945fe96
--- /dev/null
+++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py
@@ -0,0 +1,113 @@
+# GT4Py - GridTools Framework
+#
+# Copyright (c) 2014-2023, ETH Zurich
+# All rights reserved.
+#
+# This file is part of the GT4Py project and the GridTools framework.
+# GT4Py is free software: you can redistribute it and/or modify it under
+# the terms of the GNU General Public License as published by the
+# Free Software Foundation, either version 3 of the License, or any later
+# version. See the LICENSE.txt file at the top-level directory of this
+# distribution for a copy of the license or check .
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import pytest
+from numpy import int32, int64
+
+from gt4py import next as gtx
+from gt4py.next import common
+from gt4py.next.iterator.transforms import LiftMode, apply_common_transforms
+from gt4py.next.program_processors import otf_compile_executor
+from gt4py.next.program_processors.runners.gtfn import run_gtfn_with_temporaries
+from tests.next_tests.integration_tests.cases import Case
+from tests.next_tests.toy_connectivity import Cell, Edge
+
+from next_tests.integration_tests import cases
+from next_tests.integration_tests.cases import E2V, KDim, Vertex, cartesian_case, unstructured_case
+from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import (
+ reduction_setup,
+)
+
+
+@pytest.fixture
+def run_gtfn_with_temporaries_and_symbolic_sizes():
+ return otf_compile_executor.OTFBackend(
+ executor=otf_compile_executor.OTFCompileExecutor(
+ name="run_gtfn_with_temporaries_and_sizes",
+ otf_workflow=run_gtfn_with_temporaries.executor.otf_workflow.replace(
+ translation=run_gtfn_with_temporaries.executor.otf_workflow.translation.replace(
+ symbolic_domain_sizes={
+ "Cell": "num_cells",
+ "Edge": "num_edges",
+ "Vertex": "num_vertices",
+ },
+ ),
+ ),
+ ),
+ allocator=run_gtfn_with_temporaries.allocator,
+ )
+
+
+@pytest.fixture
+def testee():
+ @gtx.field_operator
+ def testee_op(a: cases.VField) -> cases.EField:
+ amul = a * 2
+ return amul(E2V[0]) + amul(E2V[1])
+
+ @gtx.program
+ def prog(
+ a: cases.VField,
+ out: cases.EField,
+ num_vertices: int32,
+ num_edges: int64,
+ num_cells: int32,
+ ):
+ testee_op(a, out=out)
+
+ return prog
+
+
+def test_verification(testee, run_gtfn_with_temporaries_and_symbolic_sizes, reduction_setup):
+ unstructured_case = Case(
+ run_gtfn_with_temporaries_and_symbolic_sizes,
+ offset_provider=reduction_setup.offset_provider,
+ default_sizes={
+ Vertex: reduction_setup.num_vertices,
+ Edge: reduction_setup.num_edges,
+ Cell: reduction_setup.num_cells,
+ KDim: reduction_setup.k_levels,
+ },
+ grid_type=common.GridType.UNSTRUCTURED,
+ )
+
+ a = cases.allocate(unstructured_case, testee, "a")()
+ out = cases.allocate(unstructured_case, testee, "out")()
+
+ first_nbs, second_nbs = (reduction_setup.offset_provider["E2V"].table[:, i] for i in [0, 1])
+ ref = (a.ndarray * 2)[first_nbs] + (a.ndarray * 2)[second_nbs]
+
+ cases.verify(
+ unstructured_case,
+ testee,
+ a,
+ out,
+ reduction_setup.num_vertices,
+ reduction_setup.num_edges,
+ reduction_setup.num_cells,
+ inout=out,
+ ref=ref,
+ )
+
+
+def test_temporary_symbols(testee, reduction_setup):
+ itir_with_tmp = apply_common_transforms(
+ testee.itir,
+ lift_mode=LiftMode.FORCE_TEMPORARIES,
+ offset_provider=reduction_setup.offset_provider,
+ )
+
+ params = ["num_vertices", "num_edges", "num_cells"]
+ for param in params:
+ assert any([param == str(p) for p in itir_with_tmp.fencil.params])
diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/anton_lap.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/anton_lap.py
index e851e7b130..5af4605988 100644
--- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/anton_lap.py
+++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/anton_lap.py
@@ -18,7 +18,7 @@
from gt4py.next.iterator.builtins import *
from gt4py.next.iterator.runtime import closure, fundef, offset
from gt4py.next.iterator.tracing import trace_fencil_definition
-from gt4py.next.program_processors.codegens.gtfn.gtfn_backend import generate
+from gt4py.next.program_processors.runners.gtfn import run_gtfn
@fundef
@@ -69,7 +69,9 @@ def lap_fencil(i_size, j_size, k_size, i_off, j_off, k_off, out, inp):
output_file = sys.argv[1]
prog = trace_fencil_definition(lap_fencil, [None] * 8, use_arg_types=False)
- generated_code = generate(prog, offset_provider={"i": IDim, "j": JDim})
+ generated_code = run_gtfn.executor.otf_workflow.translation.generate_stencil_source(
+ prog, offset_provider={"i": IDim, "j": JDim}, column_axis=None
+ )
with open(output_file, "w+") as output:
output.write(generated_code)
diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil.py
index 33c7d5baa7..3e8b88ac66 100644
--- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil.py
+++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil.py
@@ -18,7 +18,7 @@
from gt4py.next.iterator.builtins import *
from gt4py.next.iterator.runtime import closure, fundef
from gt4py.next.iterator.tracing import trace_fencil_definition
-from gt4py.next.program_processors.codegens.gtfn.gtfn_backend import generate
+from gt4py.next.program_processors.runners.gtfn import run_gtfn
IDim = gtx.Dimension("IDim")
@@ -48,7 +48,9 @@ def copy_fencil(isize, jsize, ksize, inp, out):
output_file = sys.argv[1]
prog = trace_fencil_definition(copy_fencil, [None] * 5, use_arg_types=False)
- generated_code = generate(prog, offset_provider={})
+ generated_code = run_gtfn.executor.otf_workflow.translation.generate_stencil_source(
+ prog, offset_provider={}, column_axis=None
+ )
with open(output_file, "w+") as output:
output.write(generated_code)
diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_field_view.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_field_view.py
index f7472d4ac3..fdc57449ee 100644
--- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_field_view.py
+++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_field_view.py
@@ -18,7 +18,7 @@
import gt4py.next as gtx
from gt4py.next import Field, field_operator, program
-from gt4py.next.program_processors.codegens.gtfn.gtfn_backend import generate
+from gt4py.next.program_processors.runners.gtfn import run_gtfn
IDim = gtx.Dimension("IDim")
@@ -47,7 +47,9 @@ def copy_program(
output_file = sys.argv[1]
prog = copy_program.itir
- generated_code = generate(prog, offset_provider={})
+ generated_code = run_gtfn.executor.otf_workflow.translation.generate_stencil_source(
+ prog, offset_provider={}, column_axis=None
+ )
with open(output_file, "w+") as output:
output.write(generated_code)
diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla.py
index 1dfd74baca..abc3755dca 100644
--- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla.py
+++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla.py
@@ -19,7 +19,7 @@
from gt4py.next.iterator.builtins import *
from gt4py.next.iterator.runtime import closure, fundef, offset
from gt4py.next.iterator.tracing import trace_fencil_definition
-from gt4py.next.program_processors.codegens.gtfn.gtfn_backend import generate
+from gt4py.next.program_processors.runners.gtfn import run_gtfn, run_gtfn_imperative
E2V = offset("E2V")
@@ -92,13 +92,20 @@ def mapped_index(_, __) -> int:
output_file = sys.argv[1]
imperative = sys.argv[2].lower() == "true"
+ if imperative:
+ backend = run_gtfn_imperative
+ else:
+ backend = run_gtfn
+
# prog = trace(zavgS_fencil, [None] * 4) # TODO allow generating of 2 fencils
prog = trace_fencil_definition(nabla_fencil, [None] * 7, use_arg_types=False)
offset_provider = {
"V2E": DummyConnectivity(max_neighbors=6, has_skip_values=True),
"E2V": DummyConnectivity(max_neighbors=2, has_skip_values=False),
}
- generated_code = generate(prog, offset_provider=offset_provider, imperative=imperative)
+ generated_code = backend.executor.otf_workflow.translation.generate_stencil_source(
+ prog, offset_provider=offset_provider, column_axis=None
+ )
with open(output_file, "w+") as output:
output.write(generated_code)
diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve.py
index 578a19faab..9755774fd0 100644
--- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve.py
+++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve.py
@@ -19,7 +19,7 @@
from gt4py.next.iterator.runtime import closure, fundef
from gt4py.next.iterator.tracing import trace_fencil_definition
from gt4py.next.iterator.transforms import LiftMode
-from gt4py.next.program_processors.codegens.gtfn.gtfn_backend import generate
+from gt4py.next.program_processors.runners.gtfn import run_gtfn
IDim = gtx.Dimension("IDim")
@@ -67,10 +67,10 @@ def tridiagonal_solve_fencil(isize, jsize, ksize, a, b, c, d, x):
prog = trace_fencil_definition(tridiagonal_solve_fencil, [None] * 8, use_arg_types=False)
offset_provider = {"I": gtx.Dimension("IDim"), "J": gtx.Dimension("JDim")}
- generated_code = generate(
+ generated_code = run_gtfn.executor.otf_workflow.translation.generate_stencil_source(
prog,
offset_provider=offset_provider,
- lift_mode=LiftMode.SIMPLE_HEURISTIC,
+ runtime_lift_mode=LiftMode.SIMPLE_HEURISTIC,
column_axis=KDim,
)
diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py
index 86c3c98c62..5c2802f90c 100644
--- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py
+++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py
@@ -323,7 +323,7 @@ def test_update_cartesian_domains():
for a, s in (("JDim", "j"), ("KDim", "k"))
],
)
- actual = update_domains(testee, {"I": gtx.Dimension("IDim")})
+ actual = update_domains(testee, {"I": gtx.Dimension("IDim")}, symbolic_sizes=None)
assert actual == expected