Skip to content

Commit

Permalink
move to StencilTest
Browse files Browse the repository at this point in the history
  • Loading branch information
jcanton committed Nov 20, 2024
1 parent 45c16e0 commit 7e7f174
Showing 1 changed file with 56 additions and 54 deletions.
110 changes: 56 additions & 54 deletions model/common/tests/math_tests/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,71 +6,73 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

import numpy as np
import pytest

from icon4py.model.common import dimension as dims
from icon4py.model.common.math.stencils.compute_nabla2_on_cell import compute_nabla2_on_cell
from icon4py.model.common.math.stencils.compute_nabla2_on_cell_k import compute_nabla2_on_cell_k
from icon4py.model.common.test_utils import helpers as test_helpers, reference_funcs
from icon4py.model.common.test_utils.helpers import constant_field, zero_field
from icon4py.model.common.test_utils.helpers import StencilTest, constant_field, zero_field


def test_nabla2_on_cell(
grid,
backend,
):
psi_c = constant_field(
grid,
1.0,
dims.CellDim,
)
geofac_n2s = constant_field(grid, 2.0, dims.CellDim, dims.C2E2CODim)
nabla2_psi_c = zero_field(
grid,
dims.CellDim,
)

compute_nabla2_on_cell.with_backend(backend)(
psi_c=psi_c,
geofac_n2s=geofac_n2s,
nabla2_psi_c=nabla2_psi_c,
horizontal_start=0,
horizontal_end=grid.num_cells,
offset_provider={
"C2E2CO": grid.get_offset_provider("C2E2CO"),
},
)
class TestNabla2OnCell(StencilTest):
PROGRAM = compute_nabla2_on_cell
OUTPUTS = ("nabla2_psi_c",)

nabla2_psi_c_np = reference_funcs.nabla2_on_cell_numpy(
grid, psi_c.asnumpy(), geofac_n2s.asnumpy()
)

assert test_helpers.dallclose(nabla2_psi_c.asnumpy(), nabla2_psi_c_np)
@staticmethod
def reference(
grid,
psi_c: np.array,
geofac_n2s: np.array,
**kwargs,
) -> dict:
nabla2_psi_c_np = reference_funcs.nabla2_on_cell_numpy(
grid, psi_c, geofac_n2s
)
return dict(nabla2_psi_c=nabla2_psi_c_np)

@pytest.fixture
def input_data(self, grid):
psi_c = constant_field(grid, 1.0, dims.CellDim)
geofac_n2s = constant_field(grid, 2.0, dims.CellDim, dims.C2E2CODim)
nabla2_psi_c = zero_field(grid, dims.CellDim)
return dict(
psi_c=psi_c,
geofac_n2s=geofac_n2s,
nabla2_psi_c=nabla2_psi_c,
horizontal_start=0,
horizontal_end=grid.num_cells,
)

def test_nabla2_on_cell_k(
grid,
backend,
):
psi_c = constant_field(grid, 1.0, dims.CellDim, dims.KDim)
geofac_n2s = constant_field(grid, 2.0, dims.CellDim, dims.C2E2CODim)
nabla2_psi_c = zero_field(grid, dims.CellDim, dims.KDim)
class TestNabla2OnCellK(StencilTest):
PROGRAM = compute_nabla2_on_cell_k
OUTPUTS = ("nabla2_psi_c",)

compute_nabla2_on_cell_k.with_backend(backend)(
psi_c=psi_c,
geofac_n2s=geofac_n2s,
nabla2_psi_c=nabla2_psi_c,
horizontal_start=0,
horizontal_end=grid.num_cells,
vertical_start=0,
vertical_end=grid.num_levels,
offset_provider={
"C2E2CO": grid.get_offset_provider("C2E2CO"),
},
)
@staticmethod
def reference(
grid,
psi_c: np.array,
geofac_n2s: np.array,
**kwargs,
) -> dict:
nabla2_psi_c_np = reference_funcs.nabla2_on_cell_k_numpy(
grid, psi_c, geofac_n2s
)
return dict(nabla2_psi_c=nabla2_psi_c_np)

nabla2_psi_c_np = reference_funcs.nabla2_on_cell_k_numpy(
grid, psi_c.asnumpy(), geofac_n2s.asnumpy()
)
@pytest.fixture
def input_data(self, grid):
psi_c = constant_field(grid, 1.0, dims.CellDim, dims.KDim)
geofac_n2s = constant_field(grid, 2.0, dims.CellDim, dims.C2E2CODim)
nabla2_psi_c = zero_field(grid, dims.CellDim, dims.KDim)
return dict(
psi_c=psi_c,
geofac_n2s=geofac_n2s,
nabla2_psi_c=nabla2_psi_c,
horizontal_start=0,
horizontal_end=grid.num_cells,
vertical_start=0,
vertical_end=grid.num_levels,
)

assert test_helpers.dallclose(nabla2_psi_c.asnumpy(), nabla2_psi_c_np)

0 comments on commit 7e7f174

Please sign in to comment.