diff --git a/grudge/discretization.py b/grudge/discretization.py index dc08cd11f..0a9703e3c 100644 --- a/grudge/discretization.py +++ b/grudge/discretization.py @@ -7,13 +7,7 @@ .. autofunction:: make_discretization_collection .. currentmodule:: grudge.discretization - -.. autofunction:: relabel_partitions - -Internal things that are visble due to type annotations -------------------------------------------------------- - -.. class:: _InterPartitionConnectionPair +.. autoclass:: PartID """ __copyright__ = """ @@ -41,13 +35,14 @@ THE SOFTWARE. """ -from typing import Mapping, Optional, Union, Tuple, TYPE_CHECKING, Any +from typing import Sequence, Mapping, Optional, Union, Tuple, TYPE_CHECKING, Any from pytools import memoize_method, single_valued +from dataclasses import dataclass, replace + from grudge.dof_desc import ( VTAG_ALL, - BTAG_MULTIVOL_PARTITION, DD_VOLUME_ALL, DISCR_TAG_BASE, DISCR_TAG_MODAL, @@ -70,8 +65,7 @@ make_face_restriction, DiscretizationConnection ) -from meshmode.mesh import ( - InterPartitionAdjacencyGroup, Mesh, BTAG_PARTITION, BoundaryTag) +from meshmode.mesh import Mesh, BTAG_PARTITION from meshmode.dof_array import DOFArray from warnings import warn @@ -80,6 +74,89 @@ import mpi4py.MPI +@dataclass(frozen=True) +class PartID: + """Unique identifier for a piece of a partitioned mesh. + + .. attribute:: volume_tag + + The volume of the part. + + .. attribute:: rank + + The (optional) MPI rank of the part. + + """ + volume_tag: VolumeTag + rank: Optional[int] = None + + +# {{{ part ID normalization + +def _normalize_mesh_part_ids( + mesh: Mesh, + volume_tags: Sequence[VolumeTag], + mpi_communicator: Optional["mpi4py.MPI.Intracomm"] = None): + """Convert a mesh's configuration-dependent "part ID" into a fixed type.""" + from numbers import Integral + if VTAG_ALL not in volume_tags: + # Multi-volume + if mpi_communicator is not None: + # Accept PartID + def as_part_id(mesh_part_id): + if isinstance(mesh_part_id, PartID): + return mesh_part_id + else: + raise TypeError(f"Unable to convert {mesh_part_id} to PartID.") + else: + # Accept PartID or volume tag + def as_part_id(mesh_part_id): + if isinstance(mesh_part_id, PartID): + return mesh_part_id + elif mesh_part_id in volume_tags: + return PartID(mesh_part_id) + else: + raise TypeError(f"Unable to convert {mesh_part_id} to PartID.") + else: + # Single-volume + if mpi_communicator is not None: + # Accept PartID or rank + def as_part_id(mesh_part_id): + if isinstance(mesh_part_id, PartID): + return mesh_part_id + elif isinstance(mesh_part_id, Integral): + return PartID(VTAG_ALL, int(mesh_part_id)) + else: + raise TypeError(f"Unable to convert {mesh_part_id} to PartID.") + else: + # Shouldn't be called + def as_part_id(mesh_part_id): + raise TypeError(f"Unable to convert {mesh_part_id} to PartID.") + + facial_adjacency_groups = mesh.facial_adjacency_groups + + new_facial_adjacency_groups = [] + + from meshmode.mesh import InterPartAdjacencyGroup + for grp_list in facial_adjacency_groups: + new_grp_list = [] + for fagrp in grp_list: + if isinstance(fagrp, InterPartAdjacencyGroup): + part_id = as_part_id(fagrp.part_id) + new_fagrp = replace( + fagrp, + boundary_tag=BTAG_PARTITION(part_id), + part_id=part_id) + else: + new_fagrp = fagrp + new_grp_list.append(new_fagrp) + new_facial_adjacency_groups.append(new_grp_list) + + return mesh.copy(facial_adjacency_groups=new_facial_adjacency_groups) + +# }}} + + # {{{ discr_tag_to_group_factory normalization def _normalize_discr_tag_to_group_factory( @@ -133,6 +210,11 @@ class 
DiscretizationCollection: (volume, interior facets, boundaries) and associated element groups. + .. note:: + + Do not call the constructor directly. Use + :func:`make_discretization_collection` instead. + .. autoattribute:: dim .. autoattribute:: ambient_dim .. autoattribute:: real_dtype @@ -160,8 +242,9 @@ def __init__(self, array_context: ArrayContext, discr_tag_to_group_factory: Optional[ Mapping[DiscretizationTag, ElementGroupFactory]] = None, mpi_communicator: Optional["mpi4py.MPI.Intracomm"] = None, - inter_partition_connections: Optional[ - Mapping[BoundaryDomainTag, DiscretizationConnection]] = None + inter_part_connections: Optional[ + Mapping[Tuple[PartID, PartID], + DiscretizationConnection]] = None, ) -> None: """ :arg discr_tag_to_group_factory: A mapping from discretization tags @@ -202,15 +285,19 @@ def __init__(self, array_context: ArrayContext, from meshmode.discretization import Discretization - # {{{ deprecated backward compatibility yuck - if isinstance(volume_discrs, Mesh): + # {{{ deprecated backward compatibility yuck + warn("Calling the DiscretizationCollection constructor directly " "is deprecated, call make_discretization_collection " "instead. This will stop working in 2023.", DeprecationWarning, stacklevel=2) mesh = volume_discrs + + mesh = _normalize_mesh_part_ids( + mesh, [VTAG_ALL], mpi_communicator=mpi_communicator) + discr_tag_to_group_factory = _normalize_discr_tag_to_group_factory( dim=mesh.dim, discr_tag_to_group_factory=discr_tag_to_group_factory, @@ -224,30 +311,30 @@ def __init__(self, array_context: ArrayContext, del mesh - if inter_partition_connections is not None: - raise TypeError("may not pass inter_partition_connections when " + if inter_part_connections is not None: + raise TypeError("may not pass inter_part_connections when " "DiscretizationCollection constructor is called in " "legacy mode") - self._inter_partition_connections = \ - _set_up_inter_partition_connections( + self._inter_part_connections = \ + _set_up_inter_part_connections( array_context=self._setup_actx, mpi_communicator=mpi_communicator, volume_discrs=volume_discrs, base_group_factory=( discr_tag_to_group_factory[DISCR_TAG_BASE])) - else: - if inter_partition_connections is None: - raise TypeError("inter_partition_connections must be passed when " - "DiscretizationCollection constructor is called in " - "'modern' mode") - - self._inter_partition_connections = inter_partition_connections + # }}} + else: assert discr_tag_to_group_factory is not None self._discr_tag_to_group_factory = discr_tag_to_group_factory - # }}} + if inter_part_connections is None: + raise TypeError("inter_part_connections must be passed when " + "DiscretizationCollection constructor is called in " + "'modern' mode") + + self._inter_part_connections = inter_part_connections self._volume_discrs = volume_discrs @@ -729,104 +816,67 @@ def normal(self, dd): # {{{ distributed/multi-volume setup -def _check_btag(tag: BoundaryTag) -> Union[BTAG_MULTIVOL_PARTITION, BTAG_PARTITION]: - if isinstance(tag, BTAG_MULTIVOL_PARTITION): - return tag - - elif isinstance(tag, BTAG_PARTITION): - return tag - - else: - raise TypeError("unexpected type of inter-partition boundary tag " - f"'{type(tag)}'") - - -def _remote_rank_from_btag(btag: BoundaryTag) -> Optional[int]: - if isinstance(btag, BTAG_PARTITION): - return btag.part_nr - - elif isinstance(btag, BTAG_MULTIVOL_PARTITION): - return btag.other_rank - - else: - raise TypeError("unexpected type of inter-partition boundary tag " - f"'{type(btag)}'") - - -def _flip_dtag( 
- self_rank: Optional[int], - domain_tag: BoundaryDomainTag, - ) -> BoundaryDomainTag: - if isinstance(domain_tag.tag, BTAG_PARTITION): - assert self_rank is not None - return BoundaryDomainTag( - BTAG_PARTITION(self_rank), domain_tag.volume_tag) - - elif isinstance(domain_tag.tag, BTAG_MULTIVOL_PARTITION): - return BoundaryDomainTag( - BTAG_MULTIVOL_PARTITION( - other_rank=None if domain_tag.tag.other_rank is None else self_rank, - other_volume_tag=domain_tag.volume_tag), - domain_tag.tag.other_volume_tag) - - else: - raise TypeError("unexpected type of inter-partition boundary tag " - f"'{type(domain_tag.tag)}'") - - -def _set_up_inter_partition_connections( +def _set_up_inter_part_connections( array_context: ArrayContext, mpi_communicator: Optional["mpi4py.MPI.Intracomm"], volume_discrs: Mapping[VolumeTag, Discretization], - base_group_factory: ElementGroupFactory, + base_group_factory: ElementGroupFactory, ) -> Mapping[ - BoundaryDomainTag, + Tuple[PartID, PartID], DiscretizationConnection]: - from meshmode.distributed import (get_inter_partition_tags, + from meshmode.distributed import (get_connected_parts, make_remote_group_infos, InterRankBoundaryInfo, MPIBoundaryCommSetupHelper) - inter_part_tags = { - BoundaryDomainTag(_check_btag(btag), discr_vol_tag) - for discr_vol_tag, volume_discr in volume_discrs.items() - for btag in get_inter_partition_tags(volume_discr.mesh)} + rank = mpi_communicator.Get_rank() if mpi_communicator is not None else None + + # Save boundary restrictions as they're created to avoid potentially creating + # them twice in the loop below + cached_part_bdry_restrictions: Mapping[ + Tuple[PartID, PartID], + DiscretizationConnection] = {} + + def get_part_bdry_restriction(self_part_id, other_part_id): + cached_result = cached_part_bdry_restrictions.get( + (self_part_id, other_part_id), None) + if cached_result is not None: + return cached_result + return cached_part_bdry_restrictions.setdefault( + (self_part_id, other_part_id), + make_face_restriction( + array_context, volume_discrs[self_part_id.volume_tag], + base_group_factory, + boundary_tag=BTAG_PARTITION(other_part_id))) inter_part_conns: Mapping[ - BoundaryDomainTag, + Tuple[PartID, PartID], DiscretizationConnection] = {} - if inter_part_tags: - local_boundary_restrictions = { - domain_tag: make_face_restriction( - array_context, volume_discrs[domain_tag.volume_tag], - base_group_factory, boundary_tag=domain_tag.tag) - for domain_tag in inter_part_tags} + irbis = [] - irbis = [] + for vtag, volume_discr in volume_discrs.items(): + part_id = PartID(vtag, rank) + connected_part_ids = get_connected_parts(volume_discr.mesh) + for connected_part_id in connected_part_ids: + bdry_restr = get_part_bdry_restriction( + self_part_id=part_id, other_part_id=connected_part_id) - for domain_tag in inter_part_tags: - assert isinstance( - domain_tag.tag, (BTAG_PARTITION, BTAG_MULTIVOL_PARTITION)) - - other_rank = _remote_rank_from_btag(domain_tag.tag) - btag_restr = local_boundary_restrictions[domain_tag] - - if other_rank is None: + if connected_part_id.rank == rank: # {{{ rank-local interface between multiple volumes - assert isinstance(domain_tag.tag, BTAG_MULTIVOL_PARTITION) + connected_bdry_restr = get_part_bdry_restriction( + self_part_id=connected_part_id, other_part_id=part_id) from meshmode.discretization.connection import \ make_partition_connection - remote_dtag = _flip_dtag(None, domain_tag) - inter_part_conns[domain_tag] = make_partition_connection( + inter_part_conns[connected_part_id, part_id] = \ + 
make_partition_connection( array_context, - local_bdry_conn=btag_restr, - remote_bdry_discr=( - local_boundary_restrictions[remote_dtag].to_discr), + local_bdry_conn=bdry_restr, + remote_bdry_discr=connected_bdry_restr.to_discr, remote_group_infos=make_remote_group_infos( - array_context, remote_dtag.tag, btag_restr)) + array_context, part_id, connected_bdry_restr)) # }}} else: @@ -838,27 +888,25 @@ def _set_up_inter_partition_connections( irbis.append( InterRankBoundaryInfo( - local_btag=domain_tag.tag, - local_part_id=domain_tag, - remote_part_id=_flip_dtag( - mpi_communicator.rank, domain_tag), - remote_rank=other_rank, - local_boundary_connection=btag_restr)) + local_part_id=part_id, + remote_part_id=connected_part_id, + remote_rank=connected_part_id.rank, + local_boundary_connection=bdry_restr)) # }}} - if irbis: - assert mpi_communicator is not None + if irbis: + assert mpi_communicator is not None - with MPIBoundaryCommSetupHelper(mpi_communicator, array_context, - irbis, base_group_factory) as bdry_setup_helper: - while True: - conns = bdry_setup_helper.complete_some() - if not conns: - # We're done. - break + with MPIBoundaryCommSetupHelper(mpi_communicator, array_context, + irbis, base_group_factory) as bdry_setup_helper: + while True: + conns = bdry_setup_helper.complete_some() + if not conns: + # We're done. + break - inter_part_conns.update(conns) + inter_part_conns.update(conns) return inter_part_conns @@ -942,6 +990,7 @@ def make_discretization_collection( from pytools import single_valued, is_single_valued + assert len(volumes) > 0 assert is_single_valued(mesh_or_discr.ambient_dim for mesh_or_discr in volumes.values()) @@ -953,62 +1002,30 @@ def make_discretization_collection( del order + mpi_communicator = getattr(array_context, "mpi_communicator", None) + + if any( + isinstance(mesh_or_discr, Discretization) + for mesh_or_discr in volumes.values()): + raise NotImplementedError("Doesn't work at the moment") + volume_discrs = { - vtag: ( - Discretization( - array_context, mesh_or_discr, - discr_tag_to_group_factory[DISCR_TAG_BASE]) - if isinstance(mesh_or_discr, Mesh) else mesh_or_discr) - for vtag, mesh_or_discr in volumes.items()} + vtag: Discretization( + array_context, + _normalize_mesh_part_ids( + mesh, volumes.keys(), mpi_communicator=mpi_communicator), + discr_tag_to_group_factory[DISCR_TAG_BASE]) + for vtag, mesh in volumes.items()} return DiscretizationCollection( array_context=array_context, volume_discrs=volume_discrs, discr_tag_to_group_factory=discr_tag_to_group_factory, - inter_partition_connections=_set_up_inter_partition_connections( + inter_part_connections=_set_up_inter_part_connections( array_context=array_context, - mpi_communicator=getattr( - array_context, "mpi_communicator", None), + mpi_communicator=mpi_communicator, volume_discrs=volume_discrs, - base_group_factory=discr_tag_to_group_factory[DISCR_TAG_BASE], - )) - -# }}} - - -# {{{ relabel_partitions - -def relabel_partitions(mesh: Mesh, - self_rank: int, - part_nr_to_rank_and_vol_tag: Mapping[int, Tuple[int, VolumeTag]]) -> Mesh: - """Given a partitioned mesh (which includes :class:`meshmode.mesh.BTAG_PARTITION` - boundary tags), relabel those boundary tags into - :class:`grudge.dof_desc.BTAG_MULTIVOL_PARTITION` tags, which map each - of the incoming partitions onto a combination of rank and volume tag, - given by *part_nr_to_rank_and_vol_tag*. 
- """ - - def _new_btag(btag: BoundaryTag) -> BTAG_MULTIVOL_PARTITION: - if not isinstance(btag, BTAG_PARTITION): - raise TypeError("unexpected inter-partition boundary tags of type " - f"'{type(btag)}', expected BTAG_PARTITION") - - rank, vol_tag = part_nr_to_rank_and_vol_tag[btag.part_nr] - return BTAG_MULTIVOL_PARTITION( - other_rank=(None if rank == self_rank else rank), - other_volume_tag=vol_tag) - - assert mesh.facial_adjacency_groups is not None - - from dataclasses import replace - return mesh.copy(facial_adjacency_groups=[ - [ - replace(fagrp, - boundary_tag=_new_btag(fagrp.boundary_tag)) - if isinstance(fagrp, InterPartitionAdjacencyGroup) - else fagrp - for fagrp in grp_fagrp_list] - for grp_fagrp_list in mesh.facial_adjacency_groups]) + base_group_factory=discr_tag_to_group_factory[DISCR_TAG_BASE])) # }}} diff --git a/grudge/dof_desc.py b/grudge/dof_desc.py index e3015d1d8..cf285a30e 100644 --- a/grudge/dof_desc.py +++ b/grudge/dof_desc.py @@ -8,8 +8,6 @@ :mod:`grudge`-specific boundary tags ------------------------------------ -.. autoclass:: BTAG_MULTIVOL_PARTITION - Domain tags ----------- @@ -111,24 +109,6 @@ class VTAG_ALL: # noqa: N801 # }}} -# {{{ partition identifier - -@dataclass(init=True, eq=True, frozen=True) -class BTAG_MULTIVOL_PARTITION: # noqa: N801 - """ - .. attribute:: other_rank - - An integer, or *None*. If *None*, this marks a partition boundary - to another volume on the same rank. - - .. attribute:: other_volume_tag - """ - other_rank: Optional[int] - other_volume_tag: "VolumeTag" - -# }}} - - # {{{ domain tag @dataclass(frozen=True, eq=True) @@ -401,17 +381,17 @@ def _normalize_domain_and_discr_tag( if _contextual_volume_tag is None: _contextual_volume_tag = VTAG_ALL - if domain in [DTAG_SCALAR, "scalar"]: + if domain == "scalar": domain = DTAG_SCALAR - elif isinstance(domain, (BoundaryDomainTag, VolumeDomainTag)): + elif isinstance(domain, (ScalarDomainTag, BoundaryDomainTag, VolumeDomainTag)): pass - elif domain == "vol": + elif domain in [VTAG_ALL, "vol"]: domain = DTAG_VOLUME_ALL elif domain in [FACE_RESTR_ALL, "all_faces"]: - domain = BoundaryDomainTag(FACE_RESTR_ALL) + domain = BoundaryDomainTag(FACE_RESTR_ALL, _contextual_volume_tag) elif domain in [FACE_RESTR_INTERIOR, "int_faces"]: - domain = BoundaryDomainTag(FACE_RESTR_INTERIOR) - elif isinstance(domain, (BTAG_PARTITION, BTAG_MULTIVOL_PARTITION)): + domain = BoundaryDomainTag(FACE_RESTR_INTERIOR, _contextual_volume_tag) + elif isinstance(domain, BTAG_PARTITION): domain = BoundaryDomainTag(domain, _contextual_volume_tag) elif domain in [BTAG_ALL, BTAG_REALLY_ALL, BTAG_NONE]: domain = BoundaryDomainTag(domain, _contextual_volume_tag) diff --git a/grudge/dt_utils.py b/grudge/dt_utils.py index a108cddbd..817828234 100644 --- a/grudge/dt_utils.py +++ b/grudge/dt_utils.py @@ -82,7 +82,7 @@ def characteristic_lengthscales( node distance on the reference cell (see :func:`dt_non_geometric_factors`), and :math:`r_D` is the inradius of the cell (see :func:`dt_geometric_factors`). - :returns: a frozen :class:`~meshmode.dof_array.DOFArray` containing a + :returns: a :class:`~meshmode.dof_array.DOFArray` containing a characteristic lengthscale for each element, at each nodal location. .. note:: @@ -94,7 +94,7 @@ def characteristic_lengthscales( methods has been used as a guide. Any concrete time integrator will likely require scaling of the values returned by this routine. 
""" - @memoize_in(dcoll, (characteristic_lengthscales, + @memoize_in(dcoll, (characteristic_lengthscales, dd, "compute_characteristic_lengthscales")) def _compute_characteristic_lengthscales(): return actx.freeze( diff --git a/grudge/eager.py b/grudge/eager.py index 400fee355..2175592d4 100644 --- a/grudge/eager.py +++ b/grudge/eager.py @@ -47,14 +47,14 @@ def __init__(self, *args, **kwargs): def project(self, src, tgt, vec): return op.project(self, src, tgt, vec) - def grad(self, vec): - return op.local_grad(self, vec) + def grad(self, *args): + return op.local_grad(self, *args) - def d_dx(self, xyz_axis, vec): - return op.local_d_dx(self, xyz_axis, vec) + def d_dx(self, xyz_axis, *args): + return op.local_d_dx(self, xyz_axis, *args) - def div(self, vecs): - return op.local_div(self, vecs) + def div(self, *args): + return op.local_div(self, *args) def weak_grad(self, *args): return op.weak_local_grad(self, *args) @@ -68,8 +68,8 @@ def weak_div(self, *args): def mass(self, *args): return op.mass(self, *args) - def inverse_mass(self, vec): - return op.inverse_mass(self, vec) + def inverse_mass(self, *args): + return op.inverse_mass(self, *args) def face_mass(self, *args): return op.face_mass(self, *args) @@ -89,5 +89,6 @@ def nodal_max(self, dd, vec): interior_trace_pair = op.interior_trace_pair cross_rank_trace_pairs = op.cross_rank_trace_pairs +inter_volume_trace_pairs = op.inter_volume_trace_pairs # vim: foldmethod=marker diff --git a/grudge/op.py b/grudge/op.py index 967a66ac3..862b201dd 100644 --- a/grudge/op.py +++ b/grudge/op.py @@ -81,6 +81,7 @@ DiscretizationFaceAxisTag) from grudge.discretization import DiscretizationCollection +from grudge.dof_desc import as_dofdesc from pytools import keyed_memoize_in from pytools.obj_array import make_obj_array @@ -88,7 +89,10 @@ import numpy as np import grudge.dof_desc as dof_desc -from grudge.dof_desc import DD_VOLUME_ALL, FACE_RESTR_ALL +from grudge.dof_desc import ( + DD_VOLUME_ALL, FACE_RESTR_ALL, DISCR_TAG_BASE, + DOFDesc, VolumeDomainTag +) from grudge.interpolation import interp from grudge.projection import project @@ -249,14 +253,14 @@ def get_ref_derivative_mats(grp): def _strong_scalar_grad(dcoll, dd_in, vec): - assert dd_in == dof_desc.as_dofdesc(DD_VOLUME_ALL) + assert isinstance(dd_in.domain_tag, VolumeDomainTag) from grudge.geometry import inverse_surface_metric_derivative_mat - discr = dcoll.discr_from_dd(DD_VOLUME_ALL) + discr = dcoll.discr_from_dd(dd_in) actx = vec.array_context - inverse_jac_mat = inverse_surface_metric_derivative_mat(actx, dcoll, + inverse_jac_mat = inverse_surface_metric_derivative_mat(actx, dcoll, dd=dd_in, _use_geoderiv_connection=actx.supports_nonscalar_broadcasting) return _gradient_kernel(actx, discr, discr, _reference_derivative_matrices, inverse_jac_mat, vec, @@ -264,7 +268,7 @@ def _strong_scalar_grad(dcoll, dd_in, vec): def local_grad( - dcoll: DiscretizationCollection, vec, *, nested=False) -> ArrayOrContainer: + dcoll: DiscretizationCollection, *args, nested=False) -> ArrayOrContainer: r"""Return the element-local gradient of a function :math:`f` represented by *vec*: @@ -273,15 +277,26 @@ def local_grad( \nabla|_E f = \left( \partial_x|_E f, \partial_y|_E f, \partial_z|_E f \right) + May be called with ``(vec)`` or ``(dd_in, vec)``. + :arg vec: a :class:`~meshmode.dof_array.DOFArray` or an :class:`~arraycontext.ArrayContainer` of them. + :arg dd_in: a :class:`~grudge.dof_desc.DOFDesc`, or a value convertible to one. + Defaults to the base volume discretization if not provided. 
:arg nested: return nested object arrays instead of a single multidimensional array if *vec* is non-scalar. :returns: an object array (possibly nested) of :class:`~meshmode.dof_array.DOFArray`\ s or :class:`~arraycontext.ArrayContainer` of object arrays. """ - dd_in = DD_VOLUME_ALL + if len(args) == 1: + vec, = args + dd_in = DD_VOLUME_ALL + elif len(args) == 2: + dd_in, vec = args + else: + raise TypeError("invalid number of arguments") + from grudge.tools import rec_map_subarrays return rec_map_subarrays( partial(_strong_scalar_grad, dcoll, dd_in), @@ -290,7 +305,7 @@ def local_grad( def local_d_dx( - dcoll: DiscretizationCollection, xyz_axis, vec) -> ArrayOrContainer: + dcoll: DiscretizationCollection, xyz_axis, *args) -> ArrayOrContainer: r"""Return the element-local derivative along axis *xyz_axis* of a function :math:`f` represented by *vec*: @@ -298,22 +313,34 @@ def local_d_dx( \frac{\partial f}{\partial \lbrace x,y,z\rbrace}\Big|_E + May be called with ``(vec)`` or ``(dd, vec)``. + :arg xyz_axis: an integer indicating the axis along which the derivative is taken. + :arg dd: a :class:`~grudge.dof_desc.DOFDesc`, or a value convertible to one. + Defaults to the base volume discretization if not provided. :arg vec: a :class:`~meshmode.dof_array.DOFArray` or an :class:`~arraycontext.ArrayContainer` of them. :returns: a :class:`~meshmode.dof_array.DOFArray` or an :class:`~arraycontext.ArrayContainer` of them. """ + if len(args) == 1: + vec, = args + dd = DD_VOLUME_ALL + elif len(args) == 2: + dd, vec = args + else: + raise TypeError("invalid number of arguments") + if not isinstance(vec, DOFArray): - return map_array_container(partial(local_d_dx, dcoll, xyz_axis), vec) + return map_array_container(partial(local_d_dx, dcoll, xyz_axis, dd), vec) - discr = dcoll.discr_from_dd(DD_VOLUME_ALL) + discr = dcoll.discr_from_dd(dd) actx = vec.array_context from grudge.geometry import inverse_surface_metric_derivative_mat - inverse_jac_mat = inverse_surface_metric_derivative_mat(actx, dcoll, - _use_geoderiv_connection=actx.supports_nonscalar_broadcasting) + inverse_jac_mat = inverse_surface_metric_derivative_mat(actx, dcoll, dd=dd, + _use_geoderiv_connection=actx.supports_nonscalar_broadcasting) return _single_axis_derivative_kernel( actx, discr, discr, @@ -321,7 +348,7 @@ def local_d_dx( metric_in_matvec=False) -def local_div(dcoll: DiscretizationCollection, vecs) -> ArrayOrContainer: +def local_div(dcoll: DiscretizationCollection, *args) -> ArrayOrContainer: r"""Return the element-local divergence of the vector function :math:`\mathbf{f}` represented by *vecs*: @@ -329,6 +356,10 @@ def local_div(dcoll: DiscretizationCollection, vecs) -> ArrayOrContainer: \nabla|_E \cdot \mathbf{f} = \sum_{i=1}^d \partial_{x_i}|_E \mathbf{f}_i + May be called with ``(vec)`` or ``(dd, vec)``. + + :arg dd: a :class:`~grudge.dof_desc.DOFDesc`, or a value convertible to one. + Defaults to the base volume discretization if not provided. :arg vecs: an object array of :class:`~meshmode.dof_array.DOFArray`\s or an :class:`~arraycontext.ArrayContainer` object @@ -337,13 +368,21 @@ def local_div(dcoll: DiscretizationCollection, vecs) -> ArrayOrContainer: :returns: a :class:`~meshmode.dof_array.DOFArray` or an :class:`~arraycontext.ArrayContainer` of them. 
""" + if len(args) == 1: + vec, = args + dd = DD_VOLUME_ALL + elif len(args) == 2: + dd, vec = args + else: + raise TypeError("invalid number of arguments") + from grudge.tools import rec_map_subarrays return rec_map_subarrays( lambda vec: sum( - local_d_dx(dcoll, i, vec_i) + local_d_dx(dcoll, i, dd, vec_i) for i, vec_i in enumerate(vec)), (dcoll.ambient_dim,), (), - vecs, scalar_cls=DOFArray) + vec, scalar_cls=DOFArray) # }}} @@ -396,8 +435,9 @@ def get_ref_stiffness_transpose_mat(out_grp, in_grp): def _weak_scalar_grad(dcoll, dd_in, vec): from grudge.geometry import inverse_surface_metric_derivative_mat + dd_in = as_dofdesc(dd_in) in_discr = dcoll.discr_from_dd(dd_in) - out_discr = dcoll.discr_from_dd(DD_VOLUME_ALL) + out_discr = dcoll.discr_from_dd(dd_in.with_discr_tag(DISCR_TAG_BASE)) actx = vec.array_context inverse_jac_mat = inverse_surface_metric_derivative_mat(actx, dcoll, dd=dd_in, @@ -493,8 +533,9 @@ def weak_local_d_dx(dcoll: DiscretizationCollection, *args) -> ArrayOrContainer: from grudge.geometry import inverse_surface_metric_derivative_mat + dd_in = as_dofdesc(dd_in) in_discr = dcoll.discr_from_dd(dd_in) - out_discr = dcoll.discr_from_dd(DD_VOLUME_ALL) + out_discr = dcoll.discr_from_dd(dd_in.with_discr_tag(DISCR_TAG_BASE)) actx = vec.array_context inverse_jac_mat = inverse_surface_metric_derivative_mat(actx, dcoll, dd=dd_in, @@ -633,7 +674,7 @@ def mass(dcoll: DiscretizationCollection, *args) -> ArrayOrContainer: *vec* being an :class:`~arraycontext.ArrayContainer`, the mass operator is applied component-wise. - May be called with ``(vec)`` or ``(dd, vec)``. + May be called with ``(vec)`` or ``(dd_in, vec)``. Specifically, this function applies the mass matrix elementwise on a vector of coefficients :math:`\mathbf{f}` via: @@ -645,7 +686,7 @@ def mass(dcoll: DiscretizationCollection, *args) -> ArrayOrContainer: where :math:`\phi_i` are local polynomial basis functions on :math:`E`. - :arg dd: a :class:`~grudge.dof_desc.DOFDesc`, or a value convertible to one. + :arg dd_in: a :class:`~grudge.dof_desc.DOFDesc`, or a value convertible to one. Defaults to the base volume discretization if not provided. :arg vec: a :class:`~meshmode.dof_array.DOFArray` or an :class:`~arraycontext.ArrayContainer` of them. @@ -655,13 +696,15 @@ def mass(dcoll: DiscretizationCollection, *args) -> ArrayOrContainer: if len(args) == 1: vec, = args - dd = dof_desc.DD_VOLUME_ALL + dd_in = dof_desc.DD_VOLUME_ALL elif len(args) == 2: - dd, vec = args + dd_in, vec = args else: raise TypeError("invalid number of arguments") - return _apply_mass_operator(dcoll, DD_VOLUME_ALL, dd, vec) + dd_out = dd_in.with_discr_tag(DISCR_TAG_BASE) + + return _apply_mass_operator(dcoll, dd_out, dd_in, vec) # }}} @@ -719,7 +762,7 @@ def _apply_inverse_mass_operator( return DOFArray(actx, data=tuple(group_data)) -def inverse_mass(dcoll: DiscretizationCollection, vec) -> ArrayOrContainer: +def inverse_mass(dcoll: DiscretizationCollection, *args) -> ArrayOrContainer: r"""Return the action of the DG mass matrix inverse on a vector (or vectors) of :class:`~meshmode.dof_array.DOFArray`\ s, *vec*. In the case of *vec* being an :class:`~arraycontext.ArrayContainer`, @@ -749,15 +792,24 @@ def inverse_mass(dcoll: DiscretizationCollection, vec) -> ArrayOrContainer: where :math:`\widehat{\mathbf{M}}` is the reference mass matrix on :math:`\widehat{E}`. + May be called with ``(vec)`` or ``(dd, vec)``. + :arg vec: a :class:`~meshmode.dof_array.DOFArray` or an :class:`~arraycontext.ArrayContainer` of them. 
+ :arg dd: a :class:`~grudge.dof_desc.DOFDesc`, or a value convertible to one. + Defaults to the base volume discretization if not provided. :returns: a :class:`~meshmode.dof_array.DOFArray` or an :class:`~arraycontext.ArrayContainer` like *vec*. """ + if len(args) == 1: + vec, = args + dd = DD_VOLUME_ALL + elif len(args) == 2: + dd, vec = args + else: + raise TypeError("invalid number of arguments") - return _apply_inverse_mass_operator( - dcoll, DD_VOLUME_ALL, DD_VOLUME_ALL, vec - ) + return _apply_inverse_mass_operator(dcoll, dd, dd, vec) # }}} @@ -855,21 +907,25 @@ def get_ref_face_mass_mat(face_grp, vol_grp): return get_ref_face_mass_mat(face_element_group, vol_element_group) -def _apply_face_mass_operator(dcoll: DiscretizationCollection, dd, vec): +def _apply_face_mass_operator(dcoll: DiscretizationCollection, dd_in, vec): if not isinstance(vec, DOFArray): return map_array_container( - partial(_apply_face_mass_operator, dcoll, dd), vec + partial(_apply_face_mass_operator, dcoll, dd_in), vec ) from grudge.geometry import area_element - volm_discr = dcoll.discr_from_dd(DD_VOLUME_ALL) - face_discr = dcoll.discr_from_dd(dd) + dd_out = DOFDesc( + VolumeDomainTag(dd_in.domain_tag.volume_tag), + DISCR_TAG_BASE) + + volm_discr = dcoll.discr_from_dd(dd_out) + face_discr = dcoll.discr_from_dd(dd_in) dtype = vec.entry_dtype actx = vec.array_context assert len(face_discr.groups) == len(volm_discr.groups) - surf_area_elements = area_element(actx, dcoll, dd=dd, + surf_area_elements = area_element(actx, dcoll, dd=dd_in, _use_geoderiv_connection=actx.supports_nonscalar_broadcasting) return DOFArray( @@ -906,7 +962,7 @@ def face_mass(dcoll: DiscretizationCollection, *args) -> ArrayOrContainer: *vec* being an arbitrary :class:`~arraycontext.ArrayContainer`, the face mass operator is applied component-wise. - May be called with ``(vec)`` or ``(dd, vec)``. + May be called with ``(vec)`` or ``(dd_in, vec)``. 
Specifically, this function applies the face mass matrix elementwise on a vector of coefficients :math:`\mathbf{f}` as the sum of contributions for @@ -937,13 +993,13 @@ def face_mass(dcoll: DiscretizationCollection, *args) -> ArrayOrContainer: if len(args) == 1: vec, = args - dd = DD_VOLUME_ALL.trace(FACE_RESTR_ALL) + dd_in = DD_VOLUME_ALL.trace(FACE_RESTR_ALL) elif len(args) == 2: - dd, vec = args + dd_in, vec = args else: raise TypeError("invalid number of arguments") - return _apply_face_mass_operator(dcoll, dd, vec) + return _apply_face_mass_operator(dcoll, dd_in, vec) # }}} diff --git a/grudge/projection.py b/grudge/projection.py index ba8d4bc3c..e21e02295 100644 --- a/grudge/projection.py +++ b/grudge/projection.py @@ -37,7 +37,11 @@ from arraycontext import ArrayOrContainer from grudge.discretization import DiscretizationCollection -from grudge.dof_desc import as_dofdesc, VolumeDomainTag, ConvertibleToDOFDesc +from grudge.dof_desc import ( + as_dofdesc, + VolumeDomainTag, + BoundaryDomainTag, + ConvertibleToDOFDesc) from numbers import Number @@ -64,6 +68,8 @@ def project( contextual_volume_tag = None if isinstance(src_dofdesc.domain_tag, VolumeDomainTag): contextual_volume_tag = src_dofdesc.domain_tag.tag + elif isinstance(src_dofdesc.domain_tag, BoundaryDomainTag): + contextual_volume_tag = src_dofdesc.domain_tag.volume_tag tgt_dofdesc = as_dofdesc(tgt, _contextual_volume_tag=contextual_volume_tag) diff --git a/grudge/reductions.py b/grudge/reductions.py index ab106c8c4..6087b5725 100644 --- a/grudge/reductions.py +++ b/grudge/reductions.py @@ -344,7 +344,7 @@ def _apply_elementwise_reduction( ) ) else: - @memoize_in(actx, (_apply_elementwise_reduction, + @memoize_in(actx, (_apply_elementwise_reduction, dd, "elementwise_%s_prg" % op_name)) def elementwise_prg(): # FIXME: This computes the reduction value redundantly for each diff --git a/grudge/shortcuts.py b/grudge/shortcuts.py index 0aca64a58..e6e62cc55 100644 --- a/grudge/shortcuts.py +++ b/grudge/shortcuts.py @@ -20,6 +20,8 @@ THE SOFTWARE. 
""" +from grudge.dof_desc import DD_VOLUME_ALL + from pytools import memoize_in @@ -76,11 +78,14 @@ def set_up_rk4(field_var_name, dt, fields, rhs, t_start=0.0): return dt_stepper -def make_visualizer(dcoll, vis_order=None, **kwargs): +def make_visualizer(dcoll, vis_order=None, volume_dd=None, **kwargs): from meshmode.discretization.visualization import make_visualizer + if volume_dd is None: + volume_dd = DD_VOLUME_ALL + return make_visualizer( dcoll._setup_actx, - dcoll.discr_from_dd("vol"), vis_order, **kwargs) + dcoll.discr_from_dd(volume_dd), vis_order, **kwargs) def make_boundary_visualizer(dcoll, vis_order=None, **kwargs): diff --git a/grudge/trace_pair.py b/grudge/trace_pair.py index 3db3517a6..a20a17643 100644 --- a/grudge/trace_pair.py +++ b/grudge/trace_pair.py @@ -55,7 +55,7 @@ from warnings import warn -from typing import List, Hashable, Optional, Type, Any, Sequence +from typing import List, Hashable, Optional, Tuple, Type, Any, Sequence, Mapping from pytools.persistent_dict import KeyBuilder @@ -64,10 +64,9 @@ ArrayContext, with_container_arithmetic, dataclass_array_container, - get_container_context_recursively, - flatten, to_numpy, - unflatten, from_numpy, - flat_size_and_dtype, + get_container_context_recursively_opt, + to_numpy, + from_numpy, ArrayOrContainer ) @@ -76,9 +75,8 @@ from numbers import Number from pytools import memoize_on_first_arg -from pytools.obj_array import obj_array_vectorize -from grudge.discretization import DiscretizationCollection, _remote_rank_from_btag +from grudge.discretization import DiscretizationCollection, PartID from grudge.projection import project from meshmode.mesh import BTAG_PARTITION @@ -88,7 +86,7 @@ import grudge.dof_desc as dof_desc from grudge.dof_desc import ( DOFDesc, DD_VOLUME_ALL, FACE_RESTR_INTERIOR, DISCR_TAG_BASE, - VolumeTag, VolumeDomainTag, BoundaryDomainTag, BTAG_MULTIVOL_PARTITION, + VolumeTag, VolumeDomainTag, BoundaryDomainTag, ConvertibleToDOFDesc, ) @@ -262,7 +260,7 @@ def bv_trace_pair( DeprecationWarning, stacklevel=2) dd = dof_desc.as_dofdesc(dd) return bdry_trace_pair( - dcoll, dd, project(dcoll, "vol", dd, interior), exterior) + dcoll, dd, project(dcoll, dd.domain_tag.volume_tag, dd, interior), exterior) # }}} @@ -297,16 +295,22 @@ def local_interior_trace_pair( interior = project(dcoll, volume_dd, trace_dd, vec) - def get_opposite_trace(el): - if isinstance(el, Number): - return el + opposite_face_conn = dcoll.opposite_face_connection(trace_dd.domain_tag) + + def get_opposite_trace(ary): + if isinstance(ary, Number): + return ary else: - assert isinstance(trace_dd.domain_tag, BoundaryDomainTag) - return dcoll.opposite_face_connection(trace_dd.domain_tag)(el) + return opposite_face_conn(ary) - e = obj_array_vectorize(get_opposite_trace, interior) + from arraycontext import rec_map_array_container + from meshmode.dof_array import DOFArray + exterior = rec_map_array_container( + get_opposite_trace, + interior, + leaf_class=DOFArray) - return TracePair(trace_dd, interior=interior, exterior=e) + return TracePair(trace_dd, interior=interior, exterior=exterior) def interior_trace_pair(dcoll: DiscretizationCollection, vec) -> TracePair: @@ -364,57 +368,93 @@ def interior_trace_pairs(dcoll: DiscretizationCollection, vec, *, def local_inter_volume_trace_pairs( dcoll: DiscretizationCollection, - self_volume_dd: DOFDesc, self_ary: ArrayOrContainer, - other_volume_dd: DOFDesc, other_ary: ArrayOrContainer, - ) -> ArrayOrContainer: - if not isinstance(self_volume_dd.domain_tag, VolumeDomainTag): - raise 
ValueError("self_volume_dd must describe a volume") - if not isinstance(other_volume_dd.domain_tag, VolumeDomainTag): - raise ValueError("other_volume_dd must describe a volume") - if self_volume_dd.discretization_tag != DISCR_TAG_BASE: - raise TypeError( - f"expected a base-discretized self DOFDesc, got '{self_volume_dd}'") - if other_volume_dd.discretization_tag != DISCR_TAG_BASE: - raise TypeError( - f"expected a base-discretized other DOFDesc, got '{other_volume_dd}'") - - self_btag = BTAG_MULTIVOL_PARTITION( - other_rank=None, - other_volume_tag=other_volume_dd.domain_tag.tag) - other_btag = BTAG_MULTIVOL_PARTITION( - other_rank=None, - other_volume_tag=self_volume_dd.domain_tag.tag) - - self_trace_dd = self_volume_dd.trace(self_btag) - other_trace_dd = other_volume_dd.trace(other_btag) - - # FIXME: In all likelihood, these traces will be reevaluated from - # the other side, which is hard to prevent given the interface we - # have. Lazy eval will hopefully collapse those redundant evaluations... - self_trace = project( - dcoll, self_volume_dd, self_trace_dd, self_ary) - other_trace = project( - dcoll, other_volume_dd, other_trace_dd, other_ary) - - other_to_self = dcoll._inter_partition_connections[ - BoundaryDomainTag(self_btag, self_volume_dd.domain_tag.tag)] - - def get_opposite_trace(el): - if isinstance(el, Number): - return el - else: - return other_to_self(el) + pairwise_volume_data: Mapping[ + Tuple[DOFDesc, DOFDesc], + Tuple[ArrayOrContainer, ArrayOrContainer]] + ) -> Mapping[Tuple[DOFDesc, DOFDesc], TracePair]: + for vol_dd_pair in pairwise_volume_data.keys(): + for vol_dd in vol_dd_pair: + if not isinstance(vol_dd.domain_tag, VolumeDomainTag): + raise ValueError( + "pairwise_volume_data keys must describe volumes, " + f"got '{vol_dd}'") + if vol_dd.discretization_tag != DISCR_TAG_BASE: + raise ValueError( + "expected base-discretized DOFDesc in pairwise_volume_data, " + f"got '{vol_dd}'") + + rank = ( + dcoll.mpi_communicator.Get_rank() + if dcoll.mpi_communicator is not None + else None) + + result: Mapping[Tuple[DOFDesc, DOFDesc], TracePair] = {} + + for vol_dd_pair, vol_data_pair in pairwise_volume_data.items(): + from meshmode.mesh import mesh_has_boundary + if not mesh_has_boundary( + dcoll.discr_from_dd(vol_dd_pair[0]).mesh, + BTAG_PARTITION(PartID(vol_dd_pair[1].domain_tag.tag, rank))): + continue + + directional_vol_dd_pairs = [ + (vol_dd_pair[1], vol_dd_pair[0]), + (vol_dd_pair[0], vol_dd_pair[1])] + + trace_dd_pair = tuple( + self_vol_dd.trace( + BTAG_PARTITION( + PartID(other_vol_dd.domain_tag.tag, rank))) + for other_vol_dd, self_vol_dd in directional_vol_dd_pairs) + + # Pre-compute the projections out here to avoid doing it twice inside + # the loop below + trace_data = { + trace_dd: project(dcoll, vol_dd, trace_dd, vol_data) + for vol_dd, trace_dd, vol_data in zip( + vol_dd_pair, trace_dd_pair, vol_data_pair)} + + for other_vol_dd, self_vol_dd in directional_vol_dd_pairs: + self_part_id = PartID(self_vol_dd.domain_tag.tag, rank) + other_part_id = PartID(other_vol_dd.domain_tag.tag, rank) + + self_trace_dd = self_vol_dd.trace(BTAG_PARTITION(other_part_id)) + other_trace_dd = other_vol_dd.trace(BTAG_PARTITION(self_part_id)) + + self_trace_data = trace_data[self_trace_dd] + unswapped_other_trace_data = trace_data[other_trace_dd] + + other_to_self = dcoll._inter_part_connections[ + other_part_id, self_part_id] + + def get_opposite_trace(ary): + if isinstance(ary, Number): + return ary + else: + return other_to_self(ary) # noqa: B023 + + from arraycontext import 
rec_map_array_container + from meshmode.dof_array import DOFArray + other_trace_data = rec_map_array_container( + get_opposite_trace, + unswapped_other_trace_data, + leaf_class=DOFArray) + + result[other_vol_dd, self_vol_dd] = TracePair( + self_trace_dd, + interior=self_trace_data, + exterior=other_trace_data) - return TracePair( - self_trace_dd, - interior=self_trace, - exterior=obj_array_vectorize(get_opposite_trace, other_trace)) + return result def inter_volume_trace_pairs(dcoll: DiscretizationCollection, - self_volume_dd: DOFDesc, self_ary: ArrayOrContainer, - other_volume_dd: DOFDesc, other_ary: ArrayOrContainer, - comm_tag: Hashable = None) -> List[ArrayOrContainer]: + pairwise_volume_data: Mapping[ + Tuple[DOFDesc, DOFDesc], + Tuple[ArrayOrContainer, ArrayOrContainer]], + comm_tag: Hashable = None) -> Mapping[ + Tuple[DOFDesc, DOFDesc], + List[TracePair]]: """ Note that :func:`local_inter_volume_trace_pairs` provides the rank-local contributions if those are needed in isolation. Similarly, @@ -423,13 +463,21 @@ def inter_volume_trace_pairs(dcoll: DiscretizationCollection, """ # TODO documentation - return ( - [local_inter_volume_trace_pairs(dcoll, - self_volume_dd, self_ary, other_volume_dd, other_ary)] - + cross_rank_inter_volume_trace_pairs(dcoll, - self_volume_dd, self_ary, other_volume_dd, other_ary, - comm_tag=comm_tag) - ) + result: Mapping[ + Tuple[DOFDesc, DOFDesc], + List[TracePair]] = {} + + local_tpairs = local_inter_volume_trace_pairs(dcoll, pairwise_volume_data) + cross_rank_tpairs = cross_rank_inter_volume_trace_pairs( + dcoll, pairwise_volume_data, comm_tag=comm_tag) + + for directional_vol_dd_pair, tpair in local_tpairs.items(): + result[directional_vol_dd_pair] = [tpair] + + for directional_vol_dd_pair, tpairs in cross_rank_tpairs.items(): + result.setdefault(directional_vol_dd_pair, []).extend(tpairs) + + return result # }}} @@ -442,29 +490,17 @@ def update_for_type(self, key_hash, key: Type[Any]): @memoize_on_first_arg -def _remote_inter_partition_tags( +def _connected_parts( dcoll: DiscretizationCollection, self_volume_tag: VolumeTag, - other_volume_tag: Optional[VolumeTag] = None - ) -> Sequence[BoundaryDomainTag]: - if other_volume_tag is None: - other_volume_tag = self_volume_tag - - result: List[BoundaryDomainTag] = [] - for domain_tag in dcoll._inter_partition_connections: - if isinstance(domain_tag.tag, BTAG_PARTITION): - if domain_tag.volume_tag == self_volume_tag: - result.append(domain_tag) - - elif isinstance(domain_tag.tag, BTAG_MULTIVOL_PARTITION): - if (domain_tag.tag.other_rank is not None - and domain_tag.volume_tag == self_volume_tag - and domain_tag.tag.other_volume_tag == other_volume_tag): - result.append(domain_tag) - - else: - raise AssertionError("unexpected inter-partition tag type encountered: " - f"'{domain_tag.tag}'") + other_volume_tag: VolumeTag + ) -> Sequence[PartID]: + result: List[PartID] = [ + connected_part_id + for connected_part_id, part_id in dcoll._inter_part_connections.keys() + if ( + part_id.volume_tag == self_volume_tag + and connected_part_id.volume_tag == other_volume_tag)] return result @@ -506,22 +542,31 @@ class _RankBoundaryCommunicationEager: def __init__(self, actx: ArrayContext, dcoll: DiscretizationCollection, - domain_tag: BoundaryDomainTag, - *, local_bdry_data: ArrayOrContainer, - send_data: ArrayOrContainer, + *, + local_part_id: PartID, + remote_part_id: PartID, + local_bdry_data: ArrayOrContainer, + remote_bdry_data_template: ArrayOrContainer, comm_tag: Optional[Hashable] = None): comm = 
dcoll.mpi_communicator assert comm is not None - remote_rank = _remote_rank_from_btag(domain_tag.tag) + remote_rank = remote_part_id.rank assert remote_rank is not None self.dcoll = dcoll self.array_context = actx - self.domain_tag = domain_tag - self.bdry_discr = dcoll.discr_from_dd(domain_tag) + self.local_part_id = local_part_id + self.remote_part_id = remote_part_id + self.local_bdry_dd = DOFDesc( + BoundaryDomainTag( + BTAG_PARTITION(remote_part_id), + volume_tag=local_part_id.volume_tag), + DISCR_TAG_BASE) + self.bdry_discr = dcoll.discr_from_dd(self.local_bdry_dd) self.local_bdry_data = local_bdry_data + self.remote_bdry_data_template = remote_bdry_data_template self.comm_tag = self.base_comm_tag comm_tag = _sym_tag_to_num_tag(comm_tag) @@ -534,33 +579,73 @@ def __init__(self, # requests is complete, however it is not clear that this is documented # behavior. We hold on to the buffer (via the instance attribute) # as well, just in case. - self.send_data_np = to_numpy(flatten(send_data, actx), actx) - self.send_req = comm.Isend(self.send_data_np, - remote_rank, - tag=self.comm_tag) + self.send_reqs = [] + self.send_data = [] + + def send_single_array(key, local_subary): + if not isinstance(local_subary, Number): + local_subary_np = to_numpy(local_subary, actx) + self.send_reqs.append( + comm.Isend(local_subary_np, remote_rank, tag=self.comm_tag)) + self.send_data.append(local_subary_np) + return local_subary + + self.recv_reqs = [] + self.recv_data = {} + + def recv_single_array(key, remote_subary_template): + if not isinstance(remote_subary_template, Number): + remote_subary_np = np.empty( + remote_subary_template.shape, + remote_subary_template.dtype) + self.recv_reqs.append( + comm.Irecv(remote_subary_np, remote_rank, tag=self.comm_tag)) + self.recv_data[key] = remote_subary_np + return remote_subary_template - recv_size, recv_dtype = flat_size_and_dtype(local_bdry_data) - self.recv_data_np = np.empty(recv_size, recv_dtype) - self.recv_req = comm.Irecv(self.recv_data_np, remote_rank, tag=self.comm_tag) + from arraycontext.container.traversal import rec_keyed_map_array_container + rec_keyed_map_array_container(send_single_array, local_bdry_data) + rec_keyed_map_array_container(recv_single_array, remote_bdry_data_template) def finish(self): - # Wait for the nonblocking receive request to complete before + from mpi4py import MPI + + # Wait for the nonblocking receive requests to complete before # accessing the data - self.recv_req.Wait() + MPI.Request.waitall(self.recv_reqs) + + def finish_single_array(key, remote_subary_template): + if isinstance(remote_subary_template, Number): + # NOTE: Assumes that the same number is passed on every rank + return remote_subary_template + else: + return from_numpy(self.recv_data[key], self.array_context) + + from arraycontext.container.traversal import rec_keyed_map_array_container + unswapped_remote_bdry_data = rec_keyed_map_array_container( + finish_single_array, self.remote_bdry_data_template) + + remote_to_local = self.dcoll._inter_part_connections[ + self.remote_part_id, self.local_part_id] - recv_data_flat = from_numpy( - self.recv_data_np, self.array_context) - unswapped_remote_bdry_data = unflatten(self.local_bdry_data, - recv_data_flat, self.array_context) - bdry_conn = self.dcoll._inter_partition_connections[self.domain_tag] - remote_bdry_data = bdry_conn(unswapped_remote_bdry_data) + def get_opposite_trace(ary): + if isinstance(ary, Number): + return ary + else: + return remote_to_local(ary) - # Complete the nonblocking send request 
associated with communicating - # `self.local_bdry_data_np` - self.send_req.Wait() + from arraycontext import rec_map_array_container + from meshmode.dof_array import DOFArray + remote_bdry_data = rec_map_array_container( + get_opposite_trace, + unswapped_remote_bdry_data, + leaf_class=DOFArray) + + # Complete the nonblocking send requests + MPI.Request.waitall(self.send_reqs) return TracePair( - DOFDesc(self.domain_tag, DISCR_TAG_BASE), + self.local_bdry_dd, interior=self.local_bdry_data, exterior=remote_bdry_data) @@ -573,62 +658,110 @@ class _RankBoundaryCommunicationLazy: def __init__(self, actx: ArrayContext, dcoll: DiscretizationCollection, - domain_tag: BoundaryDomainTag, *, + local_part_id: PartID, + remote_part_id: PartID, local_bdry_data: ArrayOrContainer, - send_data: ArrayOrContainer, + remote_bdry_data_template: ArrayOrContainer, comm_tag: Optional[Hashable] = None) -> None: if comm_tag is None: raise ValueError("lazy communication requires 'comm_tag' to be supplied") - self.dcoll = dcoll - self.array_context = actx - self.bdry_discr = dcoll.discr_from_dd(domain_tag) - self.domain_tag = domain_tag - - remote_rank = _remote_rank_from_btag(domain_tag.tag) + remote_rank = remote_part_id.rank assert remote_rank is not None - self.local_bdry_data = local_bdry_data + self.dcoll = dcoll + self.array_context = actx + self.local_bdry_dd = DOFDesc( + BoundaryDomainTag( + BTAG_PARTITION(remote_part_id), + volume_tag=local_part_id.volume_tag), + DISCR_TAG_BASE) + self.bdry_discr = dcoll.discr_from_dd(self.local_bdry_dd) + self.local_part_id = local_part_id + self.remote_part_id = remote_part_id + + from pytato import ( + make_distributed_recv, + make_distributed_send, + DistributedSendRefHolder) + + # TODO: This currently assumes that local_bdry_data and + # remote_bdry_data_template have the same structure. This is not true + # in general. 
Find a way to staple the sends appropriately when the number + # of recvs is not equal to the number of sends + # FIXME: Overly restrictive (just needs to be the same structure) + assert type(local_bdry_data) == type(remote_bdry_data_template) + + sends = {} + + def send_single_array(key, local_subary): + if isinstance(local_subary, Number): + return + else: + ary_tag = (comm_tag, key) + sends[key] = make_distributed_send( + local_subary, dest_rank=remote_rank, comm_tag=ary_tag) + + def recv_single_array(key, remote_subary_template): + if isinstance(remote_subary_template, Number): + # NOTE: Assumes that the same number is passed on every rank + return remote_subary_template + else: + ary_tag = (comm_tag, key) + return DistributedSendRefHolder( + sends[key], + make_distributed_recv( + src_rank=remote_rank, comm_tag=ary_tag, + shape=remote_subary_template.shape, + dtype=remote_subary_template.dtype, + axes=remote_subary_template.axes)) from arraycontext.container.traversal import rec_keyed_map_array_container - key_to_send_subary = {} - - def store_send_subary(key, send_subary): - key_to_send_subary[key] = send_subary - return send_subary - rec_keyed_map_array_container(store_send_subary, send_data) - - from pytato import make_distributed_recv, staple_distributed_send - - def communicate_single_array(key, local_bdry_subary): - ary_tag = (comm_tag, key) - return staple_distributed_send( - key_to_send_subary[key], dest_rank=remote_rank, comm_tag=ary_tag, - stapled_to=make_distributed_recv( - src_rank=remote_rank, comm_tag=ary_tag, - shape=local_bdry_subary.shape, - dtype=local_bdry_subary.dtype, - axes=local_bdry_subary.axes)) + rec_keyed_map_array_container(send_single_array, local_bdry_data) + self.local_bdry_data = local_bdry_data - self.remote_data = rec_keyed_map_array_container( - communicate_single_array, self.local_bdry_data) + self.unswapped_remote_bdry_data = rec_keyed_map_array_container( + recv_single_array, remote_bdry_data_template) def finish(self): - bdry_conn = self.dcoll._inter_partition_connections[self.domain_tag] + remote_to_local = self.dcoll._inter_part_connections[ + self.remote_part_id, self.local_part_id] + + def get_opposite_trace(ary): + if isinstance(ary, Number): + return ary + else: + return remote_to_local(ary) + + from arraycontext import rec_map_array_container + from meshmode.dof_array import DOFArray + remote_bdry_data = rec_map_array_container( + get_opposite_trace, + self.unswapped_remote_bdry_data, + leaf_class=DOFArray) return TracePair( - DOFDesc(self.domain_tag, DISCR_TAG_BASE), + self.local_bdry_dd, interior=self.local_bdry_data, - exterior=bdry_conn(self.remote_data)) + exterior=remote_bdry_data) # }}} # {{{ cross_rank_trace_pairs +def _replace_dof_arrays(array_container, dof_array): + from arraycontext import rec_map_array_container + from meshmode.dof_array import DOFArray + return rec_map_array_container( + lambda x: dof_array if isinstance(x, DOFArray) else x, + array_container, + leaf_class=DOFArray) + + def cross_rank_trace_pairs( dcoll: DiscretizationCollection, ary: ArrayOrContainer, tag: Hashable = None, @@ -637,9 +770,9 @@ def cross_rank_trace_pairs( r"""Get a :class:`list` of *ary* trace pairs for each partition boundary. For each partition boundary, the field data values in *ary* are - communicated to/from the neighboring partition. Presumably, this - communication is MPI (but strictly speaking, may not be, and this - routine is agnostic to the underlying communication). + communicated to/from the neighboring part. 
Presumably, this communication + is MPI (but strictly speaking, may not be, and this routine is agnostic to + the underlying communication). For each face on each partition boundary, a :class:`TracePair` is created with the locally, and @@ -684,44 +817,74 @@ def cross_rank_trace_pairs( # }}} - comm_bdtags = _remote_inter_partition_tags( - dcoll, self_volume_tag=volume_dd.domain_tag.tag) + if dcoll.mpi_communicator is None: + return [] + + rank = dcoll.mpi_communicator.Get_rank() + + local_part_id = PartID(volume_dd.domain_tag.tag, rank) + + connected_part_ids = _connected_parts( + dcoll, self_volume_tag=volume_dd.domain_tag.tag, + other_volume_tag=volume_dd.domain_tag.tag) + + remote_part_ids = [ + part_id + for part_id in connected_part_ids + if part_id.rank != rank] # This asserts that there is only one data exchange per rank, so that # there is no risk of mismatched data reaching the wrong recipient. # (Since we have only a single tag.) - assert len(comm_bdtags) == len( - {_remote_rank_from_btag(bdtag.tag) for bdtag in comm_bdtags}) + assert len(remote_part_ids) == len({part_id.rank for part_id in remote_part_ids}) - if isinstance(ary, Number): - # NOTE: Assumes that the same number is passed on every rank - return [TracePair(DOFDesc(bdtag, DISCR_TAG_BASE), interior=ary, exterior=ary) - for bdtag in comm_bdtags] + actx = get_container_context_recursively_opt(ary) - actx = get_container_context_recursively(ary) - assert actx is not None + if actx is None: + # NOTE: Assumes that the same number is passed on every rank + return [ + TracePair( + volume_dd.trace(BTAG_PARTITION(remote_part_id)), + interior=ary, exterior=ary) + for remote_part_id in remote_part_ids] from grudge.array_context import MPIPytatoArrayContextBase if isinstance(actx, MPIPytatoArrayContextBase): - rbc = _RankBoundaryCommunicationLazy + rbc_class = _RankBoundaryCommunicationLazy else: - rbc = _RankBoundaryCommunicationEager + rbc_class = _RankBoundaryCommunicationEager + + rank_bdry_communicators = [] - def start_comm(bdtag): - local_bdry_data = project( - dcoll, - DOFDesc(VolumeDomainTag(bdtag.volume_tag), DISCR_TAG_BASE), - DOFDesc(bdtag, DISCR_TAG_BASE), - ary) + for remote_part_id in remote_part_ids: + bdry_dd = volume_dd.trace(BTAG_PARTITION(remote_part_id)) - return rbc(actx, dcoll, bdtag, - local_bdry_data=local_bdry_data, - send_data=local_bdry_data, - comm_tag=comm_tag) + local_bdry_data = project(dcoll, volume_dd, bdry_dd, ary) - rank_bdry_communcators = [start_comm(bdtag) for bdtag in comm_bdtags] - return [rc.finish() for rc in rank_bdry_communcators] + from arraycontext import tag_axes + from meshmode.transform_metadata import ( + DiscretizationElementAxisTag, + DiscretizationDOFAxisTag) + remote_bdry_zeros = tag_axes( + actx, { + 0: DiscretizationElementAxisTag(), + 1: DiscretizationDOFAxisTag()}, + dcoll._inter_part_connections[ + remote_part_id, local_part_id].from_discr.zeros(actx)) + + remote_bdry_data_template = _replace_dof_arrays( + local_bdry_data, remote_bdry_zeros) + + rank_bdry_communicators.append( + rbc_class(actx, dcoll, + local_part_id=local_part_id, + remote_part_id=remote_part_id, + local_bdry_data=local_bdry_data, + remote_bdry_data_template=remote_bdry_data_template, + comm_tag=comm_tag)) + + return [rbc.finish() for rbc in rank_bdry_communicators] # }}} @@ -730,10 +893,13 @@ def start_comm(bdtag): def cross_rank_inter_volume_trace_pairs( dcoll: DiscretizationCollection, - self_volume_dd: DOFDesc, self_ary: ArrayOrContainer, - other_volume_dd: DOFDesc, other_ary: ArrayOrContainer, + 
pairwise_volume_data: Mapping[ + Tuple[DOFDesc, DOFDesc], + Tuple[ArrayOrContainer, ArrayOrContainer]], *, comm_tag: Hashable = None, - ) -> List[TracePair]: + ) -> Mapping[ + Tuple[DOFDesc, DOFDesc], + List[TracePair]]: # FIXME: Should this interface take in boundary data instead? # TODO: Docs r"""Get a :class:`list` of *ary* trace pairs for each partition boundary. @@ -747,60 +913,109 @@ def cross_rank_inter_volume_trace_pairs( """ # {{{ process arguments - if not isinstance(self_volume_dd.domain_tag, VolumeDomainTag): - raise ValueError("self_volume_dd must describe a volume") - if not isinstance(other_volume_dd.domain_tag, VolumeDomainTag): - raise ValueError("other_volume_dd must describe a volume") - if self_volume_dd.discretization_tag != DISCR_TAG_BASE: - raise TypeError( - f"expected a base-discretized self DOFDesc, got '{self_volume_dd}'") - if other_volume_dd.discretization_tag != DISCR_TAG_BASE: - raise TypeError( - f"expected a base-discretized other DOFDesc, got '{other_volume_dd}'") + for vol_dd_pair in pairwise_volume_data.keys(): + for vol_dd in vol_dd_pair: + if not isinstance(vol_dd.domain_tag, VolumeDomainTag): + raise ValueError( + "pairwise_volume_data keys must describe volumes, " + f"got '{vol_dd}'") + if vol_dd.discretization_tag != DISCR_TAG_BASE: + raise ValueError( + "expected base-discretized DOFDesc in pairwise_volume_data, " + f"got '{vol_dd}'") # }}} - comm_bdtags = _remote_inter_partition_tags( - dcoll, - self_volume_tag=self_volume_dd.domain_tag.tag, - other_volume_tag=other_volume_dd.domain_tag.tag) - - # This asserts that there is only one data exchange per rank, so that - # there is no risk of mismatched data reaching the wrong recipient. - # (Since we have only a single tag.) - assert len(comm_bdtags) == len( - {_remote_rank_from_btag(bdtag.tag) for bdtag in comm_bdtags}) - - actx = get_container_context_recursively(self_ary) - assert actx is not None + if dcoll.mpi_communicator is None: + return {} + + rank = dcoll.mpi_communicator.Get_rank() + + for vol_data_pair in pairwise_volume_data.values(): + for vol_data in vol_data_pair: + actx = get_container_context_recursively_opt(vol_data) + if actx is not None: + break + if actx is not None: + break + + def get_remote_connected_parts(local_vol_dd, remote_vol_dd): + connected_part_ids = _connected_parts( + dcoll, self_volume_tag=local_vol_dd.domain_tag.tag, + other_volume_tag=remote_vol_dd.domain_tag.tag) + return [ + part_id + for part_id in connected_part_ids + if part_id.rank != rank] + + if actx is None: + # NOTE: Assumes that the same number is passed on every rank for a + # given volume + return { + (remote_vol_dd, local_vol_dd): [ + TracePair( + local_vol_dd.trace(BTAG_PARTITION(remote_part_id)), + interior=local_vol_ary, exterior=remote_vol_ary) + for remote_part_id in get_remote_connected_parts( + local_vol_dd, remote_vol_dd)] + for (remote_vol_dd, local_vol_dd), (remote_vol_ary, local_vol_ary) + in pairwise_volume_data.items()} from grudge.array_context import MPIPytatoArrayContextBase if isinstance(actx, MPIPytatoArrayContextBase): - rbc = _RankBoundaryCommunicationLazy + rbc_class = _RankBoundaryCommunicationLazy else: - rbc = _RankBoundaryCommunicationEager - - def start_comm(bdtag): - assert isinstance(bdtag.tag, BTAG_MULTIVOL_PARTITION) - self_volume_dd = DOFDesc( - VolumeDomainTag(bdtag.volume_tag), DISCR_TAG_BASE) - other_volume_dd = DOFDesc( - VolumeDomainTag(bdtag.tag.other_volume_tag), DISCR_TAG_BASE) - - local_bdry_data = project(dcoll, self_volume_dd, bdtag, self_ary) - send_data = 
project(dcoll, other_volume_dd, - BTAG_MULTIVOL_PARTITION( - other_rank=bdtag.tag.other_rank, - other_volume_tag=bdtag.volume_tag), other_ary) - - return rbc(actx, dcoll, bdtag, - local_bdry_data=local_bdry_data, - send_data=send_data, - comm_tag=comm_tag) - - rank_bdry_communcators = [start_comm(bdtag) for bdtag in comm_bdtags] - return [rc.finish() for rc in rank_bdry_communcators] + rbc_class = _RankBoundaryCommunicationEager + + rank_bdry_communicators = {} + + for vol_dd_pair, vol_data_pair in pairwise_volume_data.items(): + directional_volume_data = { + (vol_dd_pair[0], vol_dd_pair[1]): (vol_data_pair[0], vol_data_pair[1]), + (vol_dd_pair[1], vol_dd_pair[0]): (vol_data_pair[1], vol_data_pair[0])} + + for dd_pair, data_pair in directional_volume_data.items(): + other_vol_dd, self_vol_dd = dd_pair + other_vol_data, self_vol_data = data_pair + + self_part_id = PartID(self_vol_dd.domain_tag.tag, rank) + other_part_ids = get_remote_connected_parts(self_vol_dd, other_vol_dd) + + rbcs = [] + + for other_part_id in other_part_ids: + self_bdry_dd = self_vol_dd.trace(BTAG_PARTITION(other_part_id)) + self_bdry_data = project( + dcoll, self_vol_dd, self_bdry_dd, self_vol_data) + + from arraycontext import tag_axes + from meshmode.transform_metadata import ( + DiscretizationElementAxisTag, + DiscretizationDOFAxisTag) + other_bdry_zeros = tag_axes( + actx, { + 0: DiscretizationElementAxisTag(), + 1: DiscretizationDOFAxisTag()}, + dcoll._inter_part_connections[ + other_part_id, self_part_id].from_discr.zeros(actx)) + + other_bdry_data_template = _replace_dof_arrays( + other_vol_data, other_bdry_zeros) + + rbcs.append( + rbc_class(actx, dcoll, + local_part_id=self_part_id, + remote_part_id=other_part_id, + local_bdry_data=self_bdry_data, + remote_bdry_data_template=other_bdry_data_template, + comm_tag=comm_tag)) + + rank_bdry_communicators[other_vol_dd, self_vol_dd] = rbcs + + return { + directional_vol_dd_pair: [rbc.finish() for rbc in rbcs] + for directional_vol_dd_pair, rbcs in rank_bdry_communicators.items()} # }}} diff --git a/requirements.txt b/requirements.txt index 6d8841e9a..2107e5aeb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,7 +10,7 @@ git+https://github.com/inducer/leap.git#egg=leap git+https://github.com/inducer/meshpy.git#egg=meshpy git+https://github.com/inducer/modepy.git#egg=modepy git+https://github.com/inducer/arraycontext.git#egg=arraycontext -git+https://github.com/inducer/meshmode.git@generic-part-bdry#egg=meshmode +git+https://github.com/inducer/meshmode.git#egg=meshmode git+https://github.com/inducer/pyvisfile.git#egg=pyvisfile git+https://github.com/inducer/pymetis.git#egg=pymetis git+https://github.com/illinois-ceesd/logpyle.git#egg=logpyle diff --git a/test/test_grudge.py b/test/test_grudge.py index a0bb9ac18..819752098 100644 --- a/test/test_grudge.py +++ b/test/test_grudge.py @@ -1084,22 +1084,16 @@ def test_multi_volume(actx_factory): nelements_per_axis=(8,)*dim, order=4) meg, = mesh.groups - part_per_element = ( - mesh.vertices[0, meg.vertex_indices[:, 0]] > 0).astype(np.int32) + x = mesh.vertices[0, meg.vertex_indices] + x_elem_avg = np.sum(x, axis=1)/x.shape[1] + volume_to_elements = { + 0: np.where(x_elem_avg <= 0)[0], + 1: np.where(x_elem_avg > 0)[0]} from meshmode.mesh.processing import partition_mesh - from grudge.discretization import relabel_partitions - parts = { - i: relabel_partitions( - partition_mesh(mesh, part_per_element, i)[0], - self_rank=0, - part_nr_to_rank_and_vol_tag={ - 0: (0, 0), - 1: (0, 1), - }) - for i in range(2)} - - 
make_discretization_collection(actx, parts, order=4) + volume_to_mesh = partition_mesh(mesh, volume_to_elements) + + make_discretization_collection(actx, volume_to_mesh, order=4) # }}}
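
For orientation, here is a minimal sketch of how the updated multi-volume setup path is driven, closely mirroring the revised `test_multi_volume` above. The mesh parameters are taken from that test; the array context `actx` is assumed to come from an existing array-context factory and is not shown.

```python
# Hedged sketch of the new multi-volume setup, modeled on the revised
# test_multi_volume; "actx" is an assumed, pre-existing array context.
import numpy as np
from meshmode.mesh.generation import generate_regular_rect_mesh
from meshmode.mesh.processing import partition_mesh
from grudge.discretization import make_discretization_collection

dim = 2
mesh = generate_regular_rect_mesh(
    a=(-1,)*dim, b=(1,)*dim, nelements_per_axis=(8,)*dim, order=4)
meg, = mesh.groups

# Assign each element to volume 0 or 1 based on the x coordinate of its
# vertex centroid.
x = mesh.vertices[0, meg.vertex_indices]
x_elem_avg = np.sum(x, axis=1)/x.shape[1]
volume_to_elements = {
    0: np.where(x_elem_avg <= 0)[0],
    1: np.where(x_elem_avg > 0)[0]}

# partition_mesh returns a mapping from part identifier (used here as the
# volume tag) to mesh. make_discretization_collection normalizes the
# per-volume part IDs internally, so the removed relabel_partitions step
# is no longer needed.
volume_to_mesh = partition_mesh(mesh, volume_to_elements)
dcoll = make_discretization_collection(actx, volume_to_mesh, order=4)
```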
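
The inter-volume exchange now takes a single `pairwise_volume_data` mapping instead of a `(self, other)` pair of volume descriptors and arrays, and returns a mapping from directional volume-descriptor pairs to lists of `TracePair`s. A minimal sketch follows, assuming the two-volume `dcoll` from the previous snippet and per-volume fields `u0` and `u1` (both assumptions for illustration).

```python
# Hedged sketch of the pairwise inter-volume trace-pair interface;
# dcoll, u0 and u1 (DOFArrays on volumes 0 and 1) are assumed to exist.
import grudge.op as op
from grudge.dof_desc import DOFDesc, VolumeDomainTag, DISCR_TAG_BASE

dd0 = DOFDesc(VolumeDomainTag(0), DISCR_TAG_BASE)
dd1 = DOFDesc(VolumeDomainTag(1), DISCR_TAG_BASE)

# Keys pair up the two volumes; values carry the corresponding volume data.
pairwise_volume_data = {(dd0, dd1): (u0, u1)}

# The result maps directional (other_vol_dd, self_vol_dd) pairs to lists of
# TracePairs: one rank-local pair plus any cross-rank pairs. A comm_tag is
# required for lazy (distributed pytato) array contexts.
tpairs = op.inter_volume_trace_pairs(dcoll, pairwise_volume_data)

for (other_dd, self_dd), pair_list in tpairs.items():
    for tpair in pair_list:
        jump = tpair.ext - tpair.int  # traces seen from self_dd's side
```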
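
Several element-local operators (`local_grad`, `local_d_dx`, `local_div`, `mass`, `inverse_mass`, and the corresponding eager wrappers) now accept an optional leading `DOFDesc` so they can act on a volume other than the default. A short sketch of the calling convention, assuming `dcoll` and `dd0` as above, a scalar `DOFArray` `u`, and an object array of `DOFArray`s `v` (all assumptions for illustration):

```python
# Hedged sketch of the optional (dd, vec) calling convention added here;
# dcoll, dd0, u and v are assumed to exist on volume 0.
import grudge.op as op

grad_u = op.local_grad(dcoll, dd0, u)    # same as op.local_grad(dcoll, u) on the default volume
du_dx = op.local_d_dx(dcoll, 0, dd0, u)  # x-derivative on volume dd0
div_v = op.local_div(dcoll, dd0, v)      # divergence of a vector field on dd0

# mass/inverse_mass follow the same pattern; the output lives on the base
# discretization of the same volume.
u_roundtrip = op.inverse_mass(dcoll, dd0, op.mass(dcoll, dd0, u))
```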