Skip to content

Commit

Permalink
[dace] Enable origin on field arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
edopao committed Sep 11, 2023
1 parent b64fdab commit e277d69
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import numpy as np

import gt4py.next.iterator.ir as itir
from gt4py.next.iterator.embedded import LocatedField, NeighborTableOffsetProvider
from gt4py.next.iterator.embedded import LocatedFieldImpl, NeighborTableOffsetProvider
from gt4py.next.iterator.transforms import LiftMode, apply_common_transforms
from gt4py.next.otf.compilation import cache
from gt4py.next.program_processors.processor_interface import program_executor
Expand All @@ -29,11 +29,17 @@


def convert_arg(arg: Any):
if isinstance(arg, LocatedField):
if isinstance(arg, LocatedFieldImpl):
sorted_dims = sorted(enumerate(arg.__gt_dims__), key=lambda v: v[1].value)
ndim = len(sorted_dims)
dim_indices = [dim[0] for dim in sorted_dims]
return np.moveaxis(np.asarray(arg), range(ndim), dim_indices)
field = np.moveaxis(np.asarray(arg), range(ndim), dim_indices)
if arg.origin:
sorted_origins = [
-arg.origin[dim] if dim in arg.origin else 0 for idx, dim in sorted_dims
]
field = np.roll(field, tuple(sorted_origins), tuple(dim_indices))
return field
return arg


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,6 @@ def basic_stencils(request):
def test_basic_column_stencils(program_processor, lift_mode, basic_stencils):
program_processor, validate = program_processor
stencil, ref_fun, inp_fun = basic_stencils
if program_processor == run_dace_iterator and inp_fun:
pytest.xfail("Not supported in DaCe backend: origin")

shape = [5, 7]
inp = (
Expand All @@ -95,13 +93,6 @@ def test_basic_column_stencils(program_processor, lift_mode, basic_stencils):

ref = ref_fun(inp)

if (
program_processor == run_dace_iterator
and stencil.__name__ == "shift_stencil"
and inp.origin
):
pytest.xfail("Not supported in DaCe backend: origin")

run_processor(
stencil[{IDim: range(0, shape[0]), KDim: range(0, shape[1])}],
program_processor,
Expand Down Expand Up @@ -330,8 +321,6 @@ def sum_fencil(out, inp0, inp1, k_size):

def test_different_vertical_sizes_with_origin(program_processor):
program_processor, validate = program_processor
if program_processor == run_dace_iterator:
pytest.xfail("Not supported in DaCe backend: origin")

k_size = 10
inp0 = gtx.np_as_located_field(KDim)(np.arange(0, k_size))
Expand Down

0 comments on commit e277d69

Please sign in to comment.