diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 718537713e..a497519dd6 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -18,7 +18,7 @@ import numpy as np import gt4py.next.iterator.ir as itir -from gt4py.next.iterator.embedded import LocatedField, NeighborTableOffsetProvider +from gt4py.next.iterator.embedded import LocatedFieldImpl, NeighborTableOffsetProvider from gt4py.next.iterator.transforms import LiftMode, apply_common_transforms from gt4py.next.otf.compilation import cache from gt4py.next.program_processors.processor_interface import program_executor @@ -29,11 +29,17 @@ def convert_arg(arg: Any): - if isinstance(arg, LocatedField): + if isinstance(arg, LocatedFieldImpl): sorted_dims = sorted(enumerate(arg.__gt_dims__), key=lambda v: v[1].value) ndim = len(sorted_dims) dim_indices = [dim[0] for dim in sorted_dims] - return np.moveaxis(np.asarray(arg), range(ndim), dim_indices) + field = np.moveaxis(np.asarray(arg), range(ndim), dim_indices) + if arg.origin: + sorted_origins = [ + -arg.origin[dim] if dim in arg.origin else 0 for idx, dim in sorted_dims + ] + field = np.roll(field, tuple(sorted_origins), tuple(dim_indices)) + return field return arg diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py index 5970b9a2a9..68060c719d 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py @@ -82,8 +82,6 @@ def basic_stencils(request): def test_basic_column_stencils(program_processor, lift_mode, basic_stencils): program_processor, validate = program_processor stencil, ref_fun, inp_fun = basic_stencils - if program_processor == run_dace_iterator and inp_fun: - pytest.xfail("Not supported in DaCe backend: origin") shape = [5, 7] inp = ( @@ -95,13 +93,6 @@ def test_basic_column_stencils(program_processor, lift_mode, basic_stencils): ref = ref_fun(inp) - if ( - program_processor == run_dace_iterator - and stencil.__name__ == "shift_stencil" - and inp.origin - ): - pytest.xfail("Not supported in DaCe backend: origin") - run_processor( stencil[{IDim: range(0, shape[0]), KDim: range(0, shape[1])}], program_processor, @@ -330,8 +321,6 @@ def sum_fencil(out, inp0, inp1, k_size): def test_different_vertical_sizes_with_origin(program_processor): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: origin") k_size = 10 inp0 = gtx.np_as_located_field(KDim)(np.arange(0, k_size))