Py2f parallel #457 (Open)

wants to merge 126 commits into base: main

Changes from 116 commits (126 commits in total)

Commits
6deda18
add simple exchange test in test_mpi_decomposition.py, fix parallel d…
halungge Feb 1, 2024
5533ffa
add calculation of mean_cell_area
halungge Feb 2, 2024
90d66a4
Merge branch 'main' into add_dummy_exchange_test
halungge Feb 7, 2024
a5fc37a
WIP
halungge Feb 9, 2024
a4addc2
read grid_root and grid_level from experiment name
halungge Feb 13, 2024
3d92f8a
retract changes in parallel dycore test.
halungge Feb 14, 2024
abededa
formatting
halungge Feb 14, 2024
2508d58
Merge branch 'main' into fix_parallel_tests
halungge Feb 16, 2024
cb7d497
merge main
halungge Feb 21, 2024
58a8d50
merge main
halungge Feb 26, 2024
e08764f
update 2 node data url
halungge Feb 26, 2024
6d4c5b9
fix tests
halungge Feb 26, 2024
d53386f
fix 2 node data url
halungge Feb 27, 2024
6576fa6
Merge branch 'main' into fix_parallel_tests
halungge Feb 27, 2024
53be03b
merge main
halungge Mar 28, 2024
a56b482
pre-commit
halungge Apr 2, 2024
8e74d93
fix mch gridfile name for test
halungge Apr 2, 2024
d6418d8
merge main
halungge Apr 2, 2024
fdd89e9
Merge branch 'main' into fix_parallel_tests
halungge Apr 4, 2024
1e75d55
Merge branch 'main' into fix_parallel_tests
halungge Apr 5, 2024
378c583
updated datasets
halungge Apr 5, 2024
9776040
Merge branch 'main' into fix_parallel_tests
halungge Apr 5, 2024
474ed4d
rename property in icon_grid
halungge Apr 5, 2024
d705d5f
add docstring for parsing for root and levelfrom name
halungge Apr 5, 2024
6d44838
cleanup
halungge Apr 9, 2024
dad2137
Merge branch 'main' into fix_parallel_tests
halungge Apr 9, 2024
0fa8bd0
remove version restrictions for ghex and mpi4py, fix ghex lib name
halungge Apr 9, 2024
d2f5ec3
fix usage of fixtures
halungge Apr 9, 2024
f5293c8
adapt to new ghex interface
halungge Apr 9, 2024
6284fab
Add cProfile to wrapper
samkellerhals Apr 12, 2024
3292041
fix
abishekg7 Apr 12, 2024
0d7e6ae
initial commit
abishekg7 Apr 16, 2024
0fa868e
fix import
abishekg7 Apr 16, 2024
4412b17
Adding dims1,2,3 for indices
abishekg7 Apr 16, 2024
9dd4ad0
replacing Dim names
abishekg7 Apr 16, 2024
306d204
replacing Dim names
abishekg7 Apr 16, 2024
eddd5b4
try
abishekg7 Apr 16, 2024
ec54515
Add CachedProgram
samkellerhals Apr 17, 2024
8bcca18
Conditionally generate host_data
samkellerhals Apr 17, 2024
a5241bf
more acc fixes
samkellerhals Apr 17, 2024
03e4045
Use CachedProgram in diffusion
samkellerhals Apr 17, 2024
7e21089
remove cached program
samkellerhals Apr 18, 2024
09ef7d4
fixes
abishekg7 Apr 18, 2024
ded0d4f
adding halo exchanges to wrapper
abishekg7 Apr 18, 2024
a8aa65e
fixes and running precommit
abishekg7 Apr 18, 2024
e92923c
fixing process props
abishekg7 Apr 18, 2024
e6b453f
Add cached programs and duplicate one program
samkellerhals Apr 18, 2024
ce7169a
exchange
abishekg7 Apr 18, 2024
378e19c
revert changes in processor_props fixture
halungge Apr 19, 2024
a7bf60b
add mpi_skip to non mpi test
halungge Apr 19, 2024
47be422
Load connectivities on device in Grid
samkellerhals Apr 22, 2024
a55c509
use math.prod in unpacking functions
samkellerhals Apr 23, 2024
0b7afb5
Merge branch 'main' of github.com:C2SM/icon4py into py2f-with-optimis…
samkellerhals Apr 23, 2024
74f9b87
remove import for ghex.context, create context from pre-existing comm…
halungge Apr 23, 2024
bc8c6db
merge main and fix adapt construction of CellParams
halungge Apr 23, 2024
17fc611
Py2f cachedprogram fixes (#452)
samkellerhals Apr 24, 2024
ae733b1
Merge branch 'py2f-with-optimisations' of github.com:C2SM/icon4py int…
samkellerhals Apr 24, 2024
b06ee35
Add multi_return profiling test
samkellerhals Apr 24, 2024
53dda46
fix import of roundtrip
halungge Apr 24, 2024
d0995e5
temporarily removing sudo apt update
halungge Apr 24, 2024
e133e75
temporarily removing sudo apt update
halungge Apr 24, 2024
835c1a3
Merge remote-tracking branch 'origin/fix_parallel_tests' into py2f_pa…
abishekg7 Apr 25, 2024
4718d88
local edits
abishekg7 Apr 25, 2024
828f030
Merge branch 'py2f_parallel' of github.com:C2SM/icon4py into py2f_par…
abishekg7 Apr 25, 2024
22563cf
fix imports
abishekg7 Apr 25, 2024
41c7f17
fix
abishekg7 Apr 25, 2024
32c0083
fix
abishekg7 Apr 25, 2024
6c81409
cleanup diffusion.py
samkellerhals Apr 25, 2024
18fab84
cleanup diffusion.py
samkellerhals Apr 25, 2024
fbdcd07
remove conflict
samkellerhals Apr 25, 2024
bfe3b19
undo - temporarily removing sudo apt update
halungge Apr 25, 2024
f105880
Update model/common/src/icon4py/model/common/config.py
halungge Apr 25, 2024
a4c28d3
trying to pass communicator from fortran
abishekg7 Apr 25, 2024
2a1345d
fix import
halungge Apr 25, 2024
c8660c5
fix import
halungge Apr 25, 2024
dffdbc3
fix import
halungge Apr 26, 2024
eb40203
fix import in template.py
samkellerhals Apr 29, 2024
c272075
Merge branch 'fix_roundtrip_import' of github.com:C2SM/icon4py into f…
samkellerhals Apr 29, 2024
6be6136
Merge branch 'fix_roundtrip_import' of github.com:C2SM/icon4py into p…
samkellerhals Apr 29, 2024
82515de
Merge branch 'main' of github.com:C2SM/icon4py into py2f-with-optimis…
samkellerhals Apr 29, 2024
06d2a04
Pass cached connectivities
samkellerhals Apr 29, 2024
ffb27c6
Remove isinstance checks from size passing
samkellerhals Apr 29, 2024
9ea0689
Remove commented out code
samkellerhals Apr 30, 2024
7900cec
Merge branch 'py2f-with-optimisations' of github.com:C2SM/icon4py int…
samkellerhals Apr 30, 2024
07751ec
Remove gridfile
samkellerhals Apr 30, 2024
92e08fc
Adapt CachedProgram to use new workflow interface and don't pass cach…
samkellerhals Apr 30, 2024
8619487
Fix test which uses connecitivities on gpu
samkellerhals Apr 30, 2024
aaa39eb
Move connectivity from device to host in test
samkellerhals May 2, 2024
44cf44e
More connectivity handling from gpu
samkellerhals May 2, 2024
e14ac59
Formatting
samkellerhals May 2, 2024
7ee29e7
rename fixtures
samkellerhals May 2, 2024
e0aca5e
Fix call to extract_connectivities
samkellerhals May 2, 2024
fac70b6
first halo exchange is correct. then Nans
abishekg7 May 3, 2024
0e596fa
add device arg back temporarily
samkellerhals May 3, 2024
f3c49b9
remove comment
samkellerhals May 6, 2024
0fd5f10
Merge branch 'main' into py2f-with-optimisations
samkellerhals May 6, 2024
235dd36
2 core runs seem fine? Need more checks
abishekg7 May 6, 2024
e5c92f2
Merge branch 'py2f-with-optimisations' of github.com:C2SM/icon4py int…
abishekg7 May 6, 2024
9eb2828
fixes running on GPU and CPU
abishekg7 May 17, 2024
c5c868b
cleanup
abishekg7 May 17, 2024
7c0c592
more cleanup + precommit
abishekg7 May 17, 2024
b3bc385
Merge branch 'main' of github.com:C2SM/icon4py into py2f-with-optimis…
samkellerhals May 17, 2024
e4cb5df
fixes
abishekg7 May 17, 2024
29ac395
Merge branch 'py2f-with-optimisations' of github.com:C2SM/icon4py int…
abishekg7 May 17, 2024
b090ebb
precommit
abishekg7 May 17, 2024
567f58d
temp fix
abishekg7 May 17, 2024
382c82d
Merge branch 'main' of github.com:C2SM/icon4py into py2f_parallel
abishekg7 May 30, 2024
f300247
comment try
abishekg7 May 30, 2024
1f92f19
replacing True with limited_area
abishekg7 Jun 7, 2024
7f3c3d8
test
abishekg7 Jun 7, 2024
629f020
temp cpu changes
abishekg7 Jun 10, 2024
2785040
some fixes
abishekg7 Jun 14, 2024
da1ab2c
some cleanup
abishekg7 Jun 14, 2024
6a40ff0
cleanup
abishekg7 Jun 14, 2024
e6be7ba
more tests with cupy indices
abishekg7 Jun 14, 2024
e744c4d
Merge branch 'py2f_parallel' of github.com:C2SM/icon4py into py2f_par…
abishekg7 Jun 14, 2024
e54e2ab
changing indices to xp
abishekg7 Jun 18, 2024
196ae94
Adding xp.asnumpy
abishekg7 Jun 18, 2024
14b187f
using gt4py's asnumpy
abishekg7 Jun 18, 2024
94f9c30
addressing review
abishekg7 Jun 18, 2024
a375271
refactoring utils
abishekg7 Jun 18, 2024
a0df554
concat two dicts in build_array_size_args
abishekg7 Jun 18, 2024
057b79d
concat two dicts in build_array_size_args
abishekg7 Jun 18, 2024
546247b
fix
abishekg7 Jun 18, 2024
45b1d44
cleanup and address review comments
abishekg7 Jun 19, 2024
cdb51fe
refactoring to allow both serial and parallel runs
abishekg7 Jun 20, 2024
@@ -17,10 +17,10 @@
from dataclasses import InitVar, dataclass, field
from enum import Enum
from typing import Final, Optional

from gt4py.next import as_field
from gt4py.next.common import Dimension
from gt4py.next.ffront.fbuiltins import Field, int32
from icon4py.model.common.decomposition.definitions import DecompositionInfo

from icon4py.model.atmosphere.diffusion.diffusion_states import (
DiffusionDiagnosticState,
@@ -297,7 +297,7 @@ def __post_init__(self, config):
object.__setattr__(
self,
"scaled_nudge_max_coeff",
config.nudge_max_coeff * DEFAULT_PHYSICS_DYNAMICS_TIMESTEP_RATIO,
[Review comment (Contributor)]: Can you add a comment here saying that ICON already scales this by 5, and that therefore it is the responsibility of the user to set nudge_max_coeff accordingly.
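A sketch of the kind of inline note the reviewer is asking for (the wording is illustrative, not part of the PR):

```python
# NOTE: ICON already scales the nudging coefficient by the physics/dynamics
# timestep ratio (a factor of 5), so it is the user's responsibility to set
# nudge_max_coeff accordingly; no additional scaling is applied here.
config.nudge_max_coeff,
```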

config.nudge_max_coeff,
)

def _determine_smagorinski_factor(self, config: DiffusionConfig):
@@ -370,6 +370,9 @@ def __init__(self, exchange: ExchangeRuntime = SingleNodeExchange()):
self.cell_params: Optional[CellParams] = None
self._horizontal_start_index_w_diffusion: int32 = 0

def set_exchange(self, exchange):
[Review comment (Contributor)]: Do you really need to set this and make it mutable?

In general, we should discuss whether there is a way for your granule interfacing to work with the distinction between init and __init__ in the Python granule.
self._exchange = exchange

def init(
self,
grid: IconGrid,
@@ -547,9 +550,9 @@ def _sync_cell_fields(self, prognostic_state):
log.debug("communication of prognostic cell fields: theta, w, exner - start")
self._exchange.exchange_and_wait(
CellDim,
prognostic_state.w,
prognostic_state.theta_v,
prognostic_state.exner,
prognostic_state.w.ndarray[0 : self.grid.num_cells, :],
[Review comment (Contributor)]: Did you try to get rid of these slices, and it did not work? Do you know why, or what did not work? We should try to figure this out and handle it differently; it is very error-prone like this.

[Review comment (Contributor)]: If the explicit bounds are needed, we could try to push them inside GHexMultiNodeExchange.exchange. What you use here is the "real" num_cells, not nproma, right?
prognostic_state.theta_v.ndarray[0 : self.grid.num_cells, :],
prognostic_state.exner.ndarray[0 : self.grid.num_cells, :],
)
log.debug("communication of prognostic cell fields: theta, w, exner - done")

@@ -605,7 +608,21 @@ def _do_diffusion_step(
vertex_end_local = self.grid.get_end_index(
VertexDim, HorizontalMarkerIndex.local(VertexDim)
)

vertex_end_halo = self.grid.get_end_index(VertexDim, HorizontalMarkerIndex.halo(VertexDim))

loc_rank = self._exchange.my_rank()
[Review comment (Contributor)]: Either delete the log statements, or keep them but not commented out. You can switch it off globally.
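On switching the debug output off globally instead of commenting the calls out, a minimal sketch (the logger name is illustrative):

```python
import logging

# Keep the log.debug(...) calls in the code and silence them globally:
logging.getLogger("icon4py.model.atmosphere.diffusion").setLevel(logging.INFO)

# ...and re-enable them for a debugging session:
logging.getLogger("icon4py.model.atmosphere.diffusion").setLevel(logging.DEBUG)
```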

# log.debug("cell_start_interior for rank",loc_rank," is ..",cell_start_interior)
# log.debug("cell_start_nudging for rank", loc_rank, " is ..", cell_start_nudging)
# log.debug("cell_end_local for rank", loc_rank, " is ..", cell_end_local)
# log.debug("cell_end_halo for rank", loc_rank, " is ..", cell_end_halo)
# log.debug("edge_start_nudging_plus_one for rank", loc_rank, " is ..", edge_start_nudging_plus_one)
# log.debug("edge_start_lb_plus4 for rank", loc_rank, " is ..", edge_start_lb_plus4)
# log.debug("edge_end_local for rank", loc_rank, " is ..", edge_end_local)
# log.debug("edge_end_local_minus2 for rank", loc_rank, " is ..", edge_end_local_minus2)
# log.debug("edge_end_halo for rank", loc_rank, " is ..", edge_end_halo)
# log.debug("vertex_start_lb_plus1 for rank", loc_rank, " is ..", vertex_start_lb_plus1)
# log.debug("vertex_end_local for rank", loc_rank, " is ..", vertex_end_local)
# log.debug("vertex_end_halo for rank", loc_rank, " is ..", vertex_end_halo)
# dtime dependent: enh_smag_factor,
scale_k(self.enh_smag_fac, dtime, self.diff_multfac_smag, offset_provider={})

@@ -624,10 +641,36 @@
)
log.debug("rbf interpolation 1: end")

# loc_ind_verts=self._exchange._decomposition_info.local_index(VertexDim,DecompositionInfo.EntryType.HALO)
# log.debug("loc_ind_verts rank %s", loc_rank, " loc_ind_verts: %s",loc_ind_verts," shape: %s",loc_ind_verts.shape)
# log.debug("after rbf rank %s", loc_rank, " u_vert max: %s min: %s",
# xp.max(self.u_vert.ndarray[vertex_start_lb_plus1:vertex_end_local, 0:klevels]),
# xp.min(self.u_vert.ndarray[vertex_start_lb_plus1:vertex_end_local, 0:klevels]))
# 2. HALO EXCHANGE -- CALL sync_patch_array_mult u_vert and v_vert
# log.debug("halo....after rbf rank %s", loc_rank, " u_vert max: %s min: %s",
# self.u_vert.ndarray[loc_ind_verts, 0],
# self.u_vert.ndarray[loc_ind_verts, 0])
log.debug("communication rbf extrapolation of vn - start")
self._exchange.exchange_and_wait(VertexDim, self.u_vert, self.v_vert)
log.debug(
"size of u_vert %s v_vert %s", self.u_vert.ndarray.shape, self.v_vert.ndarray.shape
)
log.debug(
"edge_start_lb_plus4 %s edge_end_local_minus2 %s",
edge_start_lb_plus4,
edge_end_local_minus2,
)
self._exchange.exchange_and_wait(
VertexDim,
self.u_vert.ndarray[0 : self.grid.num_vertices, :],
self.v_vert.ndarray[0 : self.grid.num_vertices, :],
)
log.debug("communication rbf extrapolation of vn - end")
# log.debug("after exchange rank %s", loc_rank, " u_vert max: %s min: %s",
# xp.max(self.u_vert.ndarray[vertex_start_lb_plus1:vertex_end_local, 0:klevels]),
# xp.min(self.u_vert.ndarray[vertex_start_lb_plus1:vertex_end_local, 0:klevels]))
# log.debug("halo....after exchange rank %s", loc_rank, " u_vert max: %s min: %s",
# self.u_vert.ndarray[loc_ind_verts, 0],
# self.u_vert.ndarray[loc_ind_verts, 0])

log.debug("running stencil 01(calculate_nabla2_and_smag_coefficients_for_vn): start")
calculate_nabla2_and_smag_coefficients_for_vn(
@@ -682,10 +725,18 @@

# HALO EXCHANGE IF (discr_vn > 1) THEN CALL sync_patch_array
# TODO (magdalena) move this up and do asynchronous exchange
# loc_ind_edges=self._exchange._decomposition_info.local_index(EdgeDim,DecompositionInfo.EntryType.HALO)
# log.debug("loc_ind_edges rank %s", loc_rank, " loc_ind_edges: %s",loc_ind_edges," shape: %s",loc_ind_edges.shape)
# log.debug("halo..z_nabla2_e..before exchange rank %s", loc_rank, " z_nabla2_e: %s",
# self.z_nabla2_e.ndarray[loc_ind_verts, 0])
if self.config.type_vn_diffu > 1:
log.debug("communication rbf extrapolation of z_nable2_e - start")
self._exchange.exchange_and_wait(EdgeDim, self.z_nabla2_e)
self._exchange.exchange_and_wait(
EdgeDim, self.z_nabla2_e.ndarray[0 : self.grid.num_edges, :]
)
log.debug("communication rbf extrapolation of z_nable2_e - end")
# log.debug("halo..z_nabla2_e..after exchange rank %s", loc_rank, " z_nabla2_e: %s",
# self.z_nabla2_e.ndarray[loc_ind_verts, 0])

log.debug("2nd rbf interpolation: start")
mo_intp_rbf_rbf_vec_interpol_vertex(
@@ -703,9 +754,19 @@
log.debug("2nd rbf interpolation: end")

# 6. HALO EXCHANGE -- CALL sync_patch_array_mult (Vertex Fields)
# log.debug("halo....after 2nd... rbf rank %s", loc_rank, " u_vert %s",
# self.u_vert.ndarray[loc_ind_verts, 0], " v_vert:",
# self.u_vert.ndarray[loc_ind_verts, 0])
log.debug("communication rbf extrapolation of z_nable2_e - start")
self._exchange.exchange_and_wait(VertexDim, self.u_vert, self.v_vert)
self._exchange.exchange_and_wait(
VertexDim,
self.u_vert.ndarray[0 : self.grid.num_vertices, :],
self.v_vert.ndarray[0 : self.grid.num_vertices, :],
)
log.debug("communication rbf extrapolation of z_nable2_e - end")
# log.debug("halo....after 2nd... after exchange rank %s", loc_rank, " u_vert %s",
# self.u_vert.ndarray[loc_ind_verts, 0], " v_vert:",
# self.u_vert.ndarray[loc_ind_verts, 0])

log.debug("running stencils 04 05 06 (apply_diffusion_to_vn): start")
apply_diffusion_to_vn(
@@ -734,7 +795,14 @@
)
log.debug("running stencils 04 05 06 (apply_diffusion_to_vn): end")
log.debug("communication of prognistic.vn : start")
handle_edge_comm = self._exchange.exchange(EdgeDim, prognostic_state.vn)
# log.debug("halo..vn..before exchange rank %s", loc_rank, " vn: %s",
# prognostic_state.vn.ndarray[loc_ind_verts, 0])
handle_edge_comm = self._exchange.exchange(
EdgeDim, prognostic_state.vn.ndarray[0 : self.grid.num_edges, :]
)
# handle_edge_comm = self._exchange.exchange_and_wait(EdgeDim, prognostic_state.vn.ndarray[0:self.grid.num_edges,:])
# log.debug("halo..vn..after exchange rank %s", loc_rank, " vn: %s",
# prognostic_state.vn.ndarray[loc_ind_verts, 0])

log.debug(
"running stencils 07 08 09 10 (apply_diffusion_to_w_and_compute_horizontal_gradients_for_turbulence): start"
30 changes: 14 additions & 16 deletions model/common/src/icon4py/model/common/decomposition/definitions.py
@@ -19,10 +19,9 @@
from enum import IntEnum
from typing import Any, Protocol

import numpy as np
import numpy.ma as ma
from gt4py.next import Dimension

from icon4py.model.common.settings import xp
from icon4py.model.common.utils import builder


@@ -72,12 +71,13 @@ class EntryType(IntEnum):
HALO = 2

@builder
def with_dimension(self, dim: Dimension, global_index: np.ndarray, owner_mask: np.ndarray):
masked_global_index = ma.array(global_index, mask=owner_mask)
self._global_index[dim] = masked_global_index
def with_dimension(self, dim: Dimension, global_index: xp.ndarray, owner_mask: xp.ndarray):
self._global_index[dim] = global_index
self._owner_mask[dim] = owner_mask

def __init__(self, klevels: int):
self._global_index = {}
self._owner_mask = {}
self._klevels = klevels

@property
@@ -90,31 +90,29 @@ def local_index(self, dim: Dimension, entry_type: EntryType = EntryType.ALL):
return self._to_local_index(dim)
case DecompositionInfo.EntryType.HALO:
index = self._to_local_index(dim)
mask = self._global_index[dim].mask
mask = self._owner_mask[dim]
return index[~mask]
case DecompositionInfo.EntryType.OWNED:
index = self._to_local_index(dim)
mask = self._global_index[dim].mask
mask = self._owner_mask[dim]
return index[mask]

def _to_local_index(self, dim):
data = ma.getdata(self._global_index[dim], subok=False)
data = self._global_index[dim]
assert data.ndim == 1
return np.arange(data.shape[0])
return xp.arange(data.shape[0])

def owner_mask(self, dim: Dimension) -> np.ndarray:
return self._global_index[dim].mask
def owner_mask(self, dim: Dimension) -> xp.ndarray:
return self._owner_mask[dim]

def global_index(self, dim: Dimension, entry_type: EntryType = EntryType.ALL):
match entry_type:
case DecompositionInfo.EntryType.ALL:
return ma.getdata(self._global_index[dim], subok=False)
return self._global_index[dim]
case DecompositionInfo.EntryType.OWNED:
global_index = self._global_index[dim]
return ma.getdata(global_index[global_index.mask])
return self._global_index[dim][self._owner_mask[dim]]
case DecompositionInfo.EntryType.HALO:
global_index = self._global_index[dim]
return ma.getdata(global_index[~global_index.mask])
return self._global_index[dim][~self._owner_mask[dim]]
case _:
raise NotImplementedError()

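A small self-contained illustration of the owner-mask bookkeeping introduced by this change (values are invented; xp resolves to numpy on CPU):

```python
import numpy as np

global_index = np.array([10, 11, 12, 13])         # global ids of local points
owner_mask = np.array([True, True, False, True])  # False marks halo points

local_index = np.arange(global_index.shape[0])    # _to_local_index
owned_local = local_index[owner_mask]             # EntryType.OWNED -> [0, 1, 3]
halo_local = local_index[~owner_mask]             # EntryType.HALO  -> [2]
owned_global = global_index[owner_mask]           # global ids of owned points -> [10, 11, 13]
```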
@@ -21,27 +21,29 @@
from gt4py.next import Dimension, Field

from icon4py.model.common.decomposition.definitions import SingleNodeExchange


try:
import ghex
import mpi4py
from ghex.context import make_context
from ghex.unstructured import (
DomainDescriptor,
HaloGenerator,
make_communication_object,
make_field_descriptor,
make_pattern,
)

mpi4py.rc.initialize = False
mpi4py.rc.finalize = True

except ImportError:
mpi4py = None
ghex = None
unstructured = None
from icon4py.model.common.settings import device


#try:
[Review comment (Contributor)]: This removes ghex and mpi4py being optional dependencies, i.e. you can run icon4py even if you don't have those libraries installed. We should keep that feature, imho.
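A sketch of one way to keep the optional-dependency behaviour while still selecting the GHEX architecture from the configured device; it only rearranges names already present in this diff.

```python
try:
    import ghex
    import mpi4py
    from ghex.context import make_context
    from ghex.unstructured import (
        DomainDescriptor,
        HaloGenerator,
        make_communication_object,
        make_field_descriptor,
        make_pattern,
    )
    from ghex.util import Architecture

    mpi4py.rc.initialize = False
    mpi4py.rc.finalize = True
except ImportError:
    mpi4py = None
    ghex = None
    Architecture = None

from icon4py.model.common.settings import device

# Only decide on an architecture when ghex is actually available.
ghex_arch = None
if ghex is not None:
    ghex_arch = Architecture.GPU if device.name == "GPU" else Architecture.CPU
```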

import ghex
import mpi4py
from ghex.context import make_context
from ghex.unstructured import (
DomainDescriptor,
HaloGenerator,
make_communication_object,
make_field_descriptor,
make_pattern,
)
from ghex.util import Architecture

mpi4py.rc.initialize = False
mpi4py.rc.finalize = True

# except ImportError:
# mpi4py = None
# ghex = None
# unstructured = None

from icon4py.model.common.decomposition import definitions
from icon4py.model.common.dimension import CellDim, DimensionKind, EdgeDim, VertexDim
@@ -51,6 +53,11 @@
import mpi4py.MPI


if device.name == "GPU":
ghex_arch = Architecture.GPU
else:
ghex_arch = Architecture.CPU

CommId = Union[int, "mpi4py.MPI.Comm", None]
log = logging.getLogger(__name__)

@@ -100,8 +107,10 @@ def filter(self, record: logging.LogRecord) -> bool:


@definitions.get_processor_properties.register(definitions.MultiNodeRun)
def get_multinode_properties(s: definitions.MultiNodeRun) -> definitions.ProcessProperties:
return _get_processor_properties(with_mpi=True)
def get_multinode_properties(
s: definitions.MultiNodeRun, comm_id: CommId = None
) -> definitions.ProcessProperties:
return _get_processor_properties(with_mpi=True, comm_id=comm_id)
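For context on the new comm_id parameter (cf. the "trying to pass communicator from fortran" commit), a hypothetical call path: MPI.Comm.f2py converts a Fortran integer handle into an mpi4py communicator; fortran_comm_handle is an illustrative name, and the direct call to the registered implementation is an assumption.

```python
from mpi4py import MPI

comm = MPI.Comm.f2py(fortran_comm_handle)  # Fortran handle -> mpi4py communicator
props = get_multinode_properties(definitions.MultiNodeRun(), comm_id=comm)
```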


@dataclass(frozen=True)
@@ -202,15 +211,16 @@ def exchange(self, dim: definitions.Dimension, *fields: Sequence[Field]):
domain_descriptor = self._domain_descriptors[dim]
assert domain_descriptor is not None, f"domain descriptor for {dim.value} not found"
applied_patterns = [
pattern(make_field_descriptor(domain_descriptor, f.asnumpy())) for f in fields
pattern(make_field_descriptor(domain_descriptor, f, arch=ghex_arch)) for f in fields
]
handle = self._comm.exchange(applied_patterns)
log.info(f"exchange for {len(fields)} fields of dimension ='{dim.value}' initiated.")
log.debug(f"exchange for {len(fields)} fields of dimension ='{dim.value}' initiated.")
return MultiNodeResult(handle, applied_patterns)

def exchange_and_wait(self, dim: Dimension, *fields: tuple):
res = self.exchange(dim, *fields)
res.wait()
log.debug(f"exchange for {len(fields)} fields of dimension ='{dim.value}' done.")


@dataclass
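A usage sketch of the exchange API shown in this file's diff; the surrounding objects (exchange, grid, prognostic_state, u_vert, v_vert) are assumed to be set up as in diffusion.py above.

```python
# Asynchronous: start the halo exchange, overlap other work, then wait.
handle = exchange.exchange(EdgeDim, prognostic_state.vn.ndarray[0 : grid.num_edges, :])
# ... other computation here ...
handle.wait()

# Synchronous convenience wrapper:
exchange.exchange_and_wait(
    VertexDim,
    u_vert.ndarray[0 : grid.num_vertices, :],
    v_vert.ndarray[0 : grid.num_vertices, :],
)
```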
7 changes: 7 additions & 0 deletions model/common/src/icon4py/model/common/dimension.py
@@ -20,6 +20,13 @@
EdgeDim = Dimension("Edge")
CellDim = Dimension("Cell")
VertexDim = Dimension("Vertex")
SingletonDim = Dimension("Singleton")
[Review comment (Contributor)]: What do you need those for? Could you move them to py2fgen? Whatever has to do only with the interfacing to Fortran should go to tools/py2fgen and not bloat up the model.
SpecialADim = Dimension("SpecialA")
SpecialBDim = Dimension("SpecialB")
SpecialCDim = Dimension("SpecialC")
CellIndexDim = Dimension("CellIndex")
EdgeIndexDim = Dimension("EdgeIndex")
VertexIndexDim = Dimension("VertexIndex")
CEDim = Dimension("CE")
CECDim = Dimension("CEC")
ECDim = Dimension("EC")
7 changes: 0 additions & 7 deletions model/common/src/icon4py/model/common/grid/horizontal.py
@@ -163,13 +163,6 @@ def end(cls, dim: Dimension) -> int:
return cls._end[dim]


@dataclass(frozen=True)
[Review comment (Contributor)]: Thanks for cleaning up...
class HorizontalGridSize:
num_vertices: int
num_edges: int
num_cells: int


class EdgeParams:
def __init__(
self,
4 changes: 3 additions & 1 deletion model/common/src/icon4py/model/common/grid/icon.py
@@ -14,6 +14,7 @@
from functools import cached_property

import numpy as np
from icon4py.model.common.settings import xp
from gt4py.next.common import Dimension, DimensionKind
from gt4py.next.ffront.fbuiltins import int32

@@ -90,7 +91,8 @@ def __init__(self):

@builder
def with_start_end_indices(
self, dim: Dimension, start_indices: np.ndarray, end_indices: np.ndarray
self, dim: Dimension, start_indices: xp.ndarray, end_indices: xp.ndarray
#self, dim: Dimension, start_indices: np.ndarray, end_indices: np.ndarray
):
self.start_indices[dim] = start_indices.astype(int32)
self.end_indices[dim] = end_indices.astype(int32)
4 changes: 2 additions & 2 deletions model/common/src/icon4py/model/common/grid/utils.py
@@ -11,7 +11,6 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later

import numpy as np
from gt4py.next import Dimension, NeighborTableOffsetProvider

from icon4py.model.common.settings import xp
@@ -23,7 +22,8 @@ def neighbortable_offset_provider_for_1d_sparse_fields(
neighbor_axis: Dimension,
has_skip_values: bool,
[Review comment (Contributor)]: What do you need the on_gpu for?
):
table = xp.asarray(np.arange(old_shape[0] * old_shape[1]).reshape(old_shape))

table = xp.asarray(xp.arange(old_shape[0] * old_shape[1]).reshape(old_shape))
return NeighborTableOffsetProvider(
table,
origin_axis,