Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Grid manager with backend #609

Merged
merged 8 commits into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,7 @@ def _construct_minimal_decomposition_info(grid: icon.IconGrid):
return decomposition_info

if not grid_functionality[experiment].get(name):
on_gpu = helpers.is_gpu(backend)
gm = grid_utils.get_icon_grid_from_gridfile(experiment, on_gpu)
gm()
gm = grid_utils.get_icon_grid_from_gridfile(experiment, backend)
grid = gm.grid
decomposition_info = _construct_minimal_decomposition_info(grid)
geometry_ = geometry.GridGeometry(
Expand Down
148 changes: 87 additions & 61 deletions model/common/src/icon4py/model/common/grid/grid_manager.py

Large diffs are not rendered by default.

19 changes: 10 additions & 9 deletions model/common/src/icon4py/model/common/test_utils/grid_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

import functools

import gt4py.next.backend as gtx_backend
import pytest

import icon4py.model.common.grid.grid_manager as gm
Expand All @@ -24,23 +24,24 @@
MCH_CH_R04B09_LEVELS = 65


@functools.cache
def get_icon_grid_from_gridfile(experiment: str, on_gpu: bool = False) -> gm.GridManager:
def get_icon_grid_from_gridfile(
experiment: str, backend: gtx_backend.Backend = None
) -> gm.GridManager:
if experiment == dt_utils.GLOBAL_EXPERIMENT:
return _download_and_load_from_gridfile(
dt_utils.R02B04_GLOBAL,
GLOBAL_GRIDFILE,
num_levels=GLOBAL_NUM_LEVELS,
on_gpu=on_gpu,
limited_area=False,
backend=backend,
)
elif experiment == dt_utils.REGIONAL_EXPERIMENT:
return _download_and_load_from_gridfile(
dt_utils.REGIONAL_EXPERIMENT,
REGIONAL_GRIDFILE,
num_levels=MCH_CH_R04B09_LEVELS,
on_gpu=on_gpu,
limited_area=True,
backend=backend,
)
else:
raise ValueError(f"Unknown experiment: {experiment}")
Expand All @@ -58,23 +59,23 @@ def download_grid_file(file_path: str, filename: str):


def load_grid_from_file(
grid_file: str, num_levels: int, on_gpu: bool, limited_area: bool
grid_file: str, num_levels: int, backend: gtx_backend.Backend, limited_area: bool
) -> gm.GridManager:
manager = gm.GridManager(
gm.ToZeroBasedIndexTransformation(),
str(grid_file),
v_grid.VerticalGridConfig(num_levels=num_levels),
)
manager(on_gpu=on_gpu, limited_area=limited_area)
manager(backend=backend, limited_area=limited_area)
return manager


def _download_and_load_from_gridfile(
file_path: str, filename: str, num_levels: int, on_gpu: bool, limited_area: bool
file_path: str, filename: str, num_levels: int, backend: gtx_backend.Backend, limited_area: bool
) -> gm.GridManager:
grid_file = download_grid_file(file_path, filename)

gm = load_grid_from_file(grid_file, num_levels, on_gpu, limited_area)
gm = load_grid_from_file(grid_file, num_levels, backend, limited_area)
return gm


Expand Down
4 changes: 0 additions & 4 deletions model/common/src/icon4py/model/common/test_utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,6 @@ def is_embedded(backend) -> bool:
return backend is None


def is_gpu(backend) -> bool:
return "gpu" in backend.name if backend else False


def is_roundtrip(backend) -> bool:
return backend.name == "roundtrip" if backend else False

Expand Down
22 changes: 12 additions & 10 deletions model/common/src/icon4py/model/common/test_utils/pytest_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
)


DEFAULT_BACKEND = "roundtrip"

backends = {
"embedded": None,
"roundtrip": itir_python,
Expand Down Expand Up @@ -96,7 +98,7 @@ def pytest_addoption(parser):
parser.addoption(
"--backend",
action="store",
default="roundtrip",
default=DEFAULT_BACKEND,
help="GT4Py backend to use when executing stencils. Defaults to roundtrip backend, other options include gtfn_cpu, gtfn_gpu, and embedded",
)
except ValueError:
Expand Down Expand Up @@ -140,19 +142,15 @@ def pytest_runtest_setup(item):


def pytest_generate_tests(metafunc):
on_gpu = False
selected_backend = backends[DEFAULT_BACKEND]

# parametrise backend
if "backend" in metafunc.fixturenames:
backend_option = metafunc.config.getoption("backend")
check_backend_validity(backend_option)

if backend_option in gpu_backends:
on_gpu = True

metafunc.parametrize(
"backend", [backends[backend_option]], ids=[f"backend={backend_option}"]
)
selected_backend = backends[backend_option]
metafunc.parametrize("backend", [selected_backend], ids=[f"backend={backend_option}"])

# parametrise grid
if "grid" in metafunc.fixturenames:
Expand All @@ -168,13 +166,17 @@ def pytest_generate_tests(metafunc):
get_icon_grid_from_gridfile,
)

grid_instance = get_icon_grid_from_gridfile(REGIONAL_EXPERIMENT, on_gpu).grid
grid_instance = get_icon_grid_from_gridfile(
REGIONAL_EXPERIMENT, backend=selected_backend
).grid
elif selected_grid_type == "icon_grid_global":
from icon4py.model.common.test_utils.grid_utils import (
get_icon_grid_from_gridfile,
)

grid_instance = get_icon_grid_from_gridfile(GLOBAL_EXPERIMENT, on_gpu).grid
grid_instance = get_icon_grid_from_gridfile(
GLOBAL_EXPERIMENT, backend=selected_backend
).grid
else:
raise ValueError(f"Unknown grid type: {selected_grid_type}")
metafunc.parametrize("grid", [grid_instance], ids=[f"grid={selected_grid_type}"])
Expand Down
4 changes: 2 additions & 2 deletions model/common/tests/grid_tests/test_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def construct_decomposition_info(grid: icon.IconGrid) -> definitions.Decompositi
return decomposition_info

def construct_grid_geometry(grid_file: str):
gm = utils.run_grid_manager(grid_file)
gm = utils.run_grid_manager(grid_file, backend=backend)
grid = gm.grid
decomposition_info = construct_decomposition_info(grid)
geometry_source = geometry.GridGeometry(
Expand Down Expand Up @@ -357,7 +357,7 @@ def test_sparse_fields_creator():
],
)
def test_create_auxiliary_orientation_coordinates(backend, grid_savepoint, grid_file):
gm = utils.run_grid_manager(grid_file)
gm = utils.run_grid_manager(grid_file, backend=backend)
grid = gm.grid
coordinates = gm.coordinates

Expand Down
8 changes: 4 additions & 4 deletions model/common/tests/grid_tests/test_grid_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def test_grid_file_index_fields(global_grid_file, caplog, icon_grid):
)
def test_grid_manager_eval_v2e(caplog, grid_savepoint, grid_file):
caplog.set_level(logging.DEBUG)
manager = run_grid_manager(grid_file, zero_base)
manager = run_grid_manager(grid_file, transformation=zero_base)
grid = manager.grid
seralized_v2e = grid_savepoint.v2e()
# there are vertices at the boundary of a local domain or at a pentagon point that have less than
Expand All @@ -160,7 +160,7 @@ def test_grid_manager_eval_v2e(caplog, grid_savepoint, grid_file):
)
@pytest.mark.parametrize("dim", [dims.CellDim, dims.EdgeDim, dims.VertexDim])
def test_grid_manager_refin_ctrl(grid_savepoint, grid_file, experiment, dim):
manager = run_grid_manager(grid_file, zero_base)
manager = run_grid_manager(grid_file, transformation=zero_base)
refin_ctrl = manager.refinement
refin_ctrl_serialized = grid_savepoint.refin_ctrl(dim)
assert np.all(
Expand Down Expand Up @@ -450,7 +450,7 @@ def test_gridmanager_given_file_not_found_then_abort():
manager = gm.GridManager(
gm.NoTransformation(), fname, v_grid.VerticalGridConfig(num_levels=80)
)
manager()
manager(None)
assert error.value == 1


Expand Down Expand Up @@ -506,7 +506,7 @@ def test_grid_manager_eval_c2e2c2e(caplog, grid_savepoint, grid_file):
def test_grid_manager_start_end_index(caplog, grid_file, experiment, dim, icon_grid):
caplog.set_level(logging.INFO)
serialized_grid = icon_grid
manager = run_grid_manager(grid_file, zero_base)
manager = run_grid_manager(grid_file, transformation=zero_base)
grid = manager.grid

for domain in utils.global_grid_domains(dim):
Expand Down
2 changes: 1 addition & 1 deletion model/common/tests/grid_tests/test_icon.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def grid_from_file() -> icon.IconGrid:
manager = gm.GridManager(
gm.ToZeroBasedIndexTransformation(), str(file_name), v_grid.VerticalGridConfig(1)
)
manager()
manager(backend=None)
return manager.grid


Expand Down
14 changes: 10 additions & 4 deletions model/common/tests/grid_tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
# SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations

import functools
from pathlib import Path

import gt4py.next as gtx
import gt4py.next.backend as gtx_backend

from icon4py.model.common import dimension as dims
from icon4py.model.common.grid import grid_manager as gm, horizontal as h_grid, vertical as v_grid
from icon4py.model.common.test_utils import datatest_utils as dt_utils
Expand Down Expand Up @@ -97,15 +99,19 @@ def valid_boundary_zones_for_dim(dim: dims.Dimension):
yield from _domain(dim, zones)


@functools.cache
def run_grid_manager(experiment_name: str, num_levels=65, transformation=None) -> gm.GridManager:
def run_grid_manager(
experiment_name: str,
backend: gtx_backend.Backend = gtx.gtfn_cpu,
num_levels=65,
nfarabullini marked this conversation as resolved.
Show resolved Hide resolved
transformation=None,
) -> gm.GridManager:
if transformation is None:
transformation = gm.ToZeroBasedIndexTransformation()
file_name = resolve_file_from_gridfile_name(experiment_name)
with gm.GridManager(
transformation, file_name, v_grid.VerticalGridConfig(num_levels)
) as grid_manager:
grid_manager(limited_area=is_regional(experiment_name))
grid_manager(backend, limited_area=is_regional(experiment_name))
return grid_manager


Expand Down
14 changes: 5 additions & 9 deletions model/common/tests/io_tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,15 @@
from icon4py.model.common.test_utils import datatest_utils, grid_utils, helpers


# fixing backend to fieldview embedded here.
halungge marked this conversation as resolved.
Show resolved Hide resolved
backend = None
UNLIMITED = None
simple_grid = simple.SimpleGrid()

grid_file = datatest_utils.GRIDS_PATH.joinpath(
datatest_utils.R02B04_GLOBAL, grid_utils.GLOBAL_GRIDFILE
)
global_grid = grid_utils.get_icon_grid_from_gridfile(
datatest_utils.GLOBAL_EXPERIMENT, on_gpu=False
).grid
global_grid = grid_utils.get_icon_grid_from_gridfile(datatest_utils.GLOBAL_EXPERIMENT, backend).grid


def model_state(grid: base.BaseGrid) -> dict[str, xr.DataArray]:
Expand Down Expand Up @@ -177,9 +177,7 @@ def test_io_monitor_write_ugrid_file(test_path):
)
def test_io_monitor_write_and_read_ugrid_dataset(test_path, variables):
path_name = test_path.absolute().as_posix() + "/output"
grid = grid_utils.get_icon_grid_from_gridfile(
datatest_utils.GLOBAL_EXPERIMENT, on_gpu=False
).grid
grid = grid_utils.get_icon_grid_from_gridfile(datatest_utils.GLOBAL_EXPERIMENT, backend).grid
vertical_config = v_grid.VerticalGridConfig(num_levels=grid.num_levels)
vertical_params = v_grid.VerticalGrid(
config=vertical_config,
Expand Down Expand Up @@ -229,9 +227,7 @@ def test_io_monitor_write_and_read_ugrid_dataset(test_path, variables):


def test_fieldgroup_monitor_write_dataset_file_roll(test_path):
grid = grid_utils.get_icon_grid_from_gridfile(
datatest_utils.GLOBAL_EXPERIMENT, on_gpu=False
).grid
grid = grid_utils.get_icon_grid_from_gridfile(datatest_utils.GLOBAL_EXPERIMENT, backend).grid
vertical_config = v_grid.VerticalGridConfig(num_levels=grid.num_levels)
vertical_params = v_grid.VerticalGrid(
config=vertical_config,
Expand Down