diff --git a/constraints.txt b/constraints.txt index e84fb5de3..e81402d1a 100644 --- a/constraints.txt +++ b/constraints.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile +# This file is autogenerated by pip-compile with python 3.8 # To update, run: # # pip-compile --output-file=constraints.txt driver/requirements.txt dsl/requirements.txt external/gt4py/setup.cfg fv3core/requirements.txt fv3gfs-physics/requirements.txt pace-util/requirements.txt requirements_dev.txt requirements_docs.txt requirements_lint.txt @@ -34,7 +34,7 @@ attrs==21.2.0 # pytest babel==2.9.1 # via sphinx -backports.entry-points-selectable==1.1.1 +backports-entry-points-selectable==1.1.1 # via virtualenv black==22.3.0 # via @@ -76,8 +76,6 @@ click==8.0.1 # pip-tools cloudpickle==2.0.0 # via dask -cmake==3.22.4 - # via dace commonmark==0.9.1 # via recommonmark coverage==5.5 @@ -88,6 +86,13 @@ cytoolz==0.11.2 # via # gt4py # gt4py (external/gt4py/setup.cfg) +dace==0.14 + # via + # -r driver/requirements.txt + # -r dsl/requirements.txt + # -r fv3core/requirements/requirements_dace.txt + # -r requirements_dev.txt + # pace-dsl dacite==1.6.0 # via # -r driver/requirements.txt @@ -105,8 +110,6 @@ dill==0.3.5.1 # via dace distlib==0.3.2 # via virtualenv -distro==1.7.0 - # via scikit-build docutils==0.16 # via # recommonmark @@ -163,8 +166,6 @@ google-api-core==2.0.0 # via # google-cloud-core # google-cloud-storage -google-auth-oauthlib==0.4.5 - # via gcsfs google-auth==2.0.1 # via # gcsfs @@ -172,6 +173,8 @@ google-auth==2.0.1 # google-auth-oauthlib # google-cloud-core # google-cloud-storage +google-auth-oauthlib==0.4.5 + # via gcsfs google-cloud-core==2.0.0 # via google-cloud-storage google-cloud-storage==1.42.0 @@ -238,15 +241,15 @@ multidict==5.1.0 # via # aiohttp # yarl +mypy==0.790 + # via + # -r fv3gfs-physics/requirements.txt + # -r pace-util/requirements.txt mypy-extensions==0.4.3 # via # black # mypy # typing-inspect -mypy==0.790 - # via - # -r fv3gfs-physics/requirements.txt - # -r pace-util/requirements.txt netcdf4==1.5.7 # via # -r driver/requirements.txt @@ -292,7 +295,6 @@ packaging==21.0 # gt4py # gt4py (external/gt4py/setup.cfg) # pytest - # scikit-build # sphinx # tox pandas==1.3.2 @@ -326,12 +328,12 @@ py==1.10.0 # pytest # pytest-forked # tox -pyasn1-modules==0.2.8 - # via google-auth pyasn1==0.4.8 # via # pyasn1-modules # rsa +pyasn1-modules==0.2.8 + # via google-auth pybind11==2.8.1 # via # gt4py @@ -350,6 +352,21 @@ pygments==2.10.0 # via sphinx pyparsing==2.4.7 # via packaging +pytest==6.2.4 + # via + # -r driver/requirements.txt + # -r fv3core/requirements/requirements_base.txt + # -r requirements_dev.txt + # pytest-cache + # pytest-cov + # pytest-datadir + # pytest-dependency + # pytest-factoryboy + # pytest-forked + # pytest-profiling + # pytest-regressions + # pytest-subtests + # pytest-xdist pytest-cache==1.0 # via -r fv3core/requirements/requirements_base.txt pytest-cov==2.12.1 @@ -378,21 +395,6 @@ pytest-subtests==0.5.0 # -r requirements_dev.txt pytest-xdist==2.3.0 # via -r fv3core/requirements/requirements_base.txt -pytest==6.2.4 - # via - # -r driver/requirements.txt - # -r fv3core/requirements/requirements_base.txt - # -r requirements_dev.txt - # pytest-cache - # pytest-cov - # pytest-datadir - # pytest-dependency - # pytest-factoryboy - # pytest-forked - # pytest-profiling - # pytest-regressions - # pytest-subtests - # pytest-xdist python-dateutil==2.8.2 # via # faker @@ -411,8 +413,6 @@ pyyaml==5.4.1 # pytest-regressions recommonmark==0.7.1 # via -r requirements_docs.txt -requests-oauthlib==1.3.0 - # via google-auth-oauthlib requests==2.26.0 # via # dace @@ -421,10 +421,10 @@ requests==2.26.0 # google-cloud-storage # requests-oauthlib # sphinx +requests-oauthlib==1.3.0 + # via google-auth-oauthlib rsa==4.7.2 # via google-auth -scikit-build==0.15.0 - # via dace scipy==1.7.1 # via # -r fv3core/requirements/requirements_base.txt @@ -447,6 +447,13 @@ snowballstemmer==2.1.0 # via sphinx sortedcontainers==2.4.0 # via hypothesis +sphinx==4.1.2 + # via + # -r requirements_docs.txt + # recommonmark + # sphinx-argparse + # sphinx-gallery + # sphinx-rtd-theme sphinx-argparse==0.3.1 # via -r requirements_docs.txt sphinx-gallery==0.10.1 @@ -455,13 +462,6 @@ sphinx-rtd-theme==0.5.2 # via # -r pace-util/requirements.txt # -r requirements_docs.txt -sphinx==4.1.2 - # via - # -r requirements_docs.txt - # recommonmark - # sphinx-argparse - # sphinx-gallery - # sphinx-rtd-theme sphinxcontrib-applehelp==1.0.2 # via sphinx sphinxcontrib-devhelp==1.0.2 @@ -533,7 +533,6 @@ wheel==0.37.0 # -r pace-util/requirements.txt # astunparse # pip-tools - # scikit-build xarray==0.19.0 # via # -r driver/requirements.txt diff --git a/driver/pace/driver/tools.py b/driver/pace/driver/tools.py new file mode 100644 index 000000000..149baac43 --- /dev/null +++ b/driver/pace/driver/tools.py @@ -0,0 +1,35 @@ +import os +from typing import Optional + +import click + +from pace.dsl.dace.utils import count_memory_from_path + + +# Count the memory from a given SDFG +ACTION_SDFG_MEMORY_COUNT = "sdfg_memory_count" + + +@click.command() +@click.argument( + "action", + required=True, + type=click.Choice([ACTION_SDFG_MEMORY_COUNT]), +) +@click.option( + "--sdfg_path", + type=click.STRING, +) +@click.option("--report_detail", is_flag=True, type=click.BOOL, default=False) +def command_line(action: str, sdfg_path: Optional[str], report_detail: Optional[bool]): + """ + Run tooling. + """ + if action == ACTION_SDFG_MEMORY_COUNT: + if sdfg_path is None or not os.path.exists(sdfg_path): + raise RuntimeError(f"Can't load SDFG {sdfg_path}") + print(count_memory_from_path(sdfg_path, detail_report=report_detail)) + + +if __name__ == "__main__": + command_line() diff --git a/driver/requirements.txt b/driver/requirements.txt index da6e67e40..3833c6e06 100644 --- a/driver/requirements.txt +++ b/driver/requirements.txt @@ -7,4 +7,4 @@ numpy netCDF4 xarray zarr -git+https://github.com/spcl/dace.git@v0.14rc2 +dace>=0.14 diff --git a/dsl/pace/dsl/dace/build.py b/dsl/pace/dsl/dace/build.py index c803d71dd..15b56cf21 100644 --- a/dsl/pace/dsl/dace/build.py +++ b/dsl/pace/dsl/dace/build.py @@ -110,6 +110,7 @@ def get_sdfg_path( """ import os + # TODO: check DaceConfig for cache.strategy == name # Guarding against bad usage of this function if config.get_orchestrate() != DaCeOrchestration.Run: return None diff --git a/dsl/pace/dsl/dace/dace_config.py b/dsl/pace/dsl/dace/dace_config.py index 3fde36388..0d7040c01 100644 --- a/dsl/pace/dsl/dace/dace_config.py +++ b/dsl/pace/dsl/dace/dace_config.py @@ -7,6 +7,12 @@ from pace.util.communicator import CubedSphereCommunicator +# TODO (floriand): Temporary deactivate the distributed compiled +# until we deal with the Grid data inlining during orchestration +# See github issue #301 +DEACTIVATE_DISTRIBUTED_DACE_COMPILE = True + + class DaCeOrchestration(enum.Enum): """ Orchestration mode for DaCe @@ -139,7 +145,12 @@ def __init__( if communicator: self.my_rank = communicator.rank self.rank_size = communicator.comm.Get_size() - self.target_rank = get_target_rank(self.my_rank, communicator.partitioner) + if DEACTIVATE_DISTRIBUTED_DACE_COMPILE: + self.target_rank = communicator.rank + else: + self.target_rank = get_target_rank( + self.my_rank, communicator.partitioner + ) self.layout = communicator.partitioner.layout else: self.my_rank = 0 @@ -170,6 +181,9 @@ def get_backend(self) -> str: def get_orchestrate(self) -> DaCeOrchestration: return self._orchestrate + def get_sync_debug(self) -> bool: + return dace.config.Config.get("compiler", "cuda", "syncdebug") + def as_dict(self) -> Dict[str, Any]: return { "_orchestrate": str(self._orchestrate.name), diff --git a/dsl/pace/dsl/dace/orchestration.py b/dsl/pace/dsl/dace/orchestration.py index 437d21d22..1c058f938 100644 --- a/dsl/pace/dsl/dace/orchestration.py +++ b/dsl/pace/dsl/dace/orchestration.py @@ -17,9 +17,13 @@ unblock_waiting_tiles, write_build_info, ) -from pace.dsl.dace.dace_config import DaceConfig, DaCeOrchestration +from pace.dsl.dace.dace_config import ( + DEACTIVATE_DISTRIBUTED_DACE_COMPILE, + DaceConfig, + DaCeOrchestration, +) from pace.dsl.dace.sdfg_opt_passes import splittable_region_expansion -from pace.dsl.dace.utils import DaCeProgress +from pace.dsl.dace.utils import DaCeProgress, count_memory, sdfg_nan_checker from pace.util.mpi import MPI @@ -107,7 +111,10 @@ def _build_sdfg( daceprog: DaceProgram, sdfg: dace.SDFG, config: DaceConfig, args, kwargs ): """Build the .so out of the SDFG on the top tile ranks only""" - is_compiling = determine_compiling_ranks(config) + if DEACTIVATE_DISTRIBUTED_DACE_COMPILE: + is_compiling = True + else: + is_compiling = determine_compiling_ranks(config) if is_compiling: # Make the transients array persistents if config.is_gpu_backend(): @@ -146,11 +153,39 @@ def _build_sdfg( with DaCeProgress(config, "Simplify (2/2)"): sdfg.simplify(validate=False, verbose=True) + # Move all memory that can be into a pool to lower memory pressure. + # Change Persistent memory (sub-SDFG) into Scope and flag it. + with DaCeProgress(config, "Turn Persistents into pooled Scope"): + memory_pooled = 0.0 + for _sd, _aname, arr in sdfg.arrays_recursive(): + if arr.lifetime == dace.AllocationLifetime.Persistent: + arr.pool = True + memory_pooled += arr.total_size * arr.dtype.bytes + arr.lifetime = dace.AllocationLifetime.Scope + memory_pooled = float(memory_pooled) / (1024 * 1024) + DaCeProgress.log( + DaCeProgress.default_prefix(config), + f"Pooled {memory_pooled} mb", + ) + # Compile with DaCeProgress(config, "Codegen & compile"): sdfg.compile() write_build_info(sdfg, config.layout, config.tile_resolution, config._backend) + # Set of debug tools inserted in the SDFG when dace.conf "syncdebug" + # is turned on. + if config.get_sync_debug(): + with DaCeProgress(config, "Debug tooling (NaNChecker)"): + sdfg_nan_checker(sdfg) + + # Printing analysis of the compiled SDFG + with DaCeProgress(config, "Build finished. Running memory static analysis"): + DaCeProgress.log( + DaCeProgress.default_prefix(config), + count_memory(sdfg), + ) + # Compilation done, either exit or scatter/gather and run # DEV NOTE: we explicitly use MPI.COMM_WORLD here because it is # a true multi-machine sync, outside of our own communicator class. @@ -159,13 +194,19 @@ def _build_sdfg( # against scattering when no other ranks are present. if config.get_orchestrate() == DaCeOrchestration.Build: MPI.COMM_WORLD.Barrier() # Protect against early exist which kill SLURM jobs - DaCeProgress.log(config, "Compilation finished and saved, exiting.") + DaCeProgress.log( + DaCeProgress.default_prefix(config), + "Build only, exiting.", + ) exit(0) elif config.get_orchestrate() == DaCeOrchestration.BuildAndRun: MPI.COMM_WORLD.Barrier() if is_compiling: - unblock_waiting_tiles(MPI.COMM_WORLD, sdfg.build_folder) - DaCeProgress.log(config, "Build folder exchanged.") + if not DEACTIVATE_DISTRIBUTED_DACE_COMPILE: + unblock_waiting_tiles(MPI.COMM_WORLD, sdfg.build_folder) + DaCeProgress.log( + DaCeProgress.default_prefix(config), "Build folder exchanged." + ) with DaCeProgress(config, "Run"): res = sdfg(**sdfg_kwargs) res = _download_results_from_dace( @@ -174,9 +215,12 @@ def _build_sdfg( else: source_rank = config.target_rank # wait for compilation to be done - DaCeProgress.log(config, "Rank is not compiling. Waiting for build dir...") + DaCeProgress.log( + DaCeProgress.default_prefix(config), + "Rank is not compiling. Waiting for build dir...", + ) sdfg_path = MPI.COMM_WORLD.recv(source=source_rank) - DaCeProgress.log(config, "Build dir received.") + DaCeProgress.log(DaCeProgress.default_prefix(config), "Build dir received.") daceprog.load_precompiled_sdfg(sdfg_path, *args, **kwargs) with DaCeProgress(config, "Run"): res = _run_sdfg(daceprog, config, args, kwargs) @@ -216,7 +260,10 @@ def _parse_sdfg( """ sdfg_path = get_sdfg_path(daceprog.name, config) if sdfg_path is None: - is_compiling = determine_compiling_ranks(config) + if DEACTIVATE_DISTRIBUTED_DACE_COMPILE: + is_compiling = True + else: + is_compiling = determine_compiling_ranks(config) if not is_compiling: # We can not parse the SDFG since we will load the proper # compiled SDFG from the compiling rank diff --git a/dsl/pace/dsl/dace/sdfg_opt_passes.py b/dsl/pace/dsl/dace/sdfg_opt_passes.py index e1da15cff..08cf4e1fc 100644 --- a/dsl/pace/dsl/dace/sdfg_opt_passes.py +++ b/dsl/pace/dsl/dace/sdfg_opt_passes.py @@ -24,4 +24,4 @@ def splittable_region_expansion(sdfg: dace.SDFG, verbose: bool = False): "K", ] if verbose: - logger.info("Reordered schedule for", node.label) + logger.info(f"Reordered schedule for {node.label}") diff --git a/dsl/pace/dsl/dace/utils.py b/dsl/pace/dsl/dace/utils.py index 3a5efb20f..5fbe18b54 100644 --- a/dsl/pace/dsl/dace/utils.py +++ b/dsl/pace/dsl/dace/utils.py @@ -1,6 +1,7 @@ import logging import time -from typing import List, Tuple +from dataclasses import dataclass, field +from typing import Dict, List, Tuple import dace from dace.transformation.helpers import get_parent_map @@ -12,7 +13,10 @@ # Rough timer & log for major operations of DaCe build stack class DaCeProgress: + """Timer and log to track build progress""" + def __init__(self, config: DaceConfig, label: str): + self.prefix = DaCeProgress.default_prefix(config) self.prefix = f"[{config.get_orchestrate()}]" self.label = label @@ -20,6 +24,10 @@ def __init__(self, config: DaceConfig, label: str): def log(cls, prefix: str, message: str): logger.info(f"{prefix} {message}") + @classmethod + def default_prefix(cls, config: DaceConfig) -> str: + return f"[{config.get_orchestrate()}]" + def __enter__(self): DaCeProgress.log(self.prefix, f"{self.label}...") self.start = time.time() @@ -120,4 +128,119 @@ def evaluate(expr): output_nodes={newnode.data: newnode}, external_edges=True, ) - print(f"Added {len(checks)} NaN checks") + logger.info(f"Added {len(checks)} NaN checks") + + +def _is_ref(sd: dace.sdfg.SDFG, aname: str): + found = False + for node, state in sd.all_nodes_recursive(): + if not isinstance(state, dace.sdfg.SDFGState): + continue + if state.parent is sd: + if isinstance(node, dace.nodes.AccessNode) and aname == node.data: + found = True + break + + return found + + +@dataclass +class ArrayReport: + name: str = "" + total_size_in_bytes: int = 0 + referenced: bool = False + transient: bool = False + pool: bool = False + top_level: bool = False + + +@dataclass +class StorageReport: + name: str = "" + referenced_in_bytes: int = 0 + unreferenced_in_bytes: int = 0 + in_pooled_in_bytes: int = 0 + top_level_in_bytes: int = 0 + details: List[ArrayReport] = field(default_factory=list) + + +def count_memory(sdfg: dace.sdfg.SDFG, detail_report=False) -> str: + allocations: Dict[dace.StorageType, StorageReport] = {} + for storage_type in dace.StorageType: + allocations[storage_type] = StorageReport(name=storage_type) + + for sd, aname, arr in sdfg.arrays_recursive(): + array_size_in_bytes = arr.total_size * arr.dtype.bytes + ref = _is_ref(sd, aname) + + if sd is not sdfg and arr.transient: + if arr.pool: + allocations[arr.storage].in_pooled_in_bytes += array_size_in_bytes + allocations[arr.storage].details.append( + ArrayReport( + name=aname, + total_size_in_bytes=array_size_in_bytes, + referenced=ref, + transient=arr.transient, + pool=arr.pool, + top_level=False, + ) + ) + if ref: + allocations[arr.storage].referenced_in_bytes += array_size_in_bytes + else: + allocations[arr.storage].unreferenced_in_bytes += array_size_in_bytes + + elif sd is sdfg: + if arr.pool: + allocations[arr.storage].in_pooled_in_bytes += array_size_in_bytes + allocations[arr.storage].details.append( + ArrayReport( + name=aname, + total_size_in_bytes=array_size_in_bytes, + referenced=ref, + transient=arr.transient, + pool=arr.pool, + top_level=True, + ) + ) + allocations[arr.storage].top_level_in_bytes += array_size_in_bytes + if ref: + allocations[arr.storage].referenced_in_bytes += array_size_in_bytes + else: + allocations[arr.storage].unreferenced_in_bytes += array_size_in_bytes + + report = f"{sdfg.name}:\n" + for storage, allocs in allocations.items(): + alloc_in_mb = float(allocs.referenced_in_bytes / (1024 * 1024)) + unref_alloc_in_mb = float(allocs.unreferenced_in_bytes / (1024 * 1024)) + in_pooled_in_mb = float(allocs.in_pooled_in_bytes / (1024 * 1024)) + toplvlalloc_in_mb = float(allocs.top_level_in_bytes / (1024 * 1024)) + if alloc_in_mb or toplvlalloc_in_mb > 0: + report += ( + f"{storage}:\n" + f" Alloc ref {alloc_in_mb:.2f} mb\n" + f" Alloc unref {unref_alloc_in_mb:.2f} mb\n" + f" Pooled {in_pooled_in_mb:.2f} mb\n" + f" Top lvl alloc: {toplvlalloc_in_mb:.2f}mb\n" + ) + if detail_report: + report += "\n" + report += " Referenced\tTransient \tPooled\tTotal size(mb)\tName\n" + for detail in allocs.details: + size_in_mb = float(detail.total_size_in_bytes / (1024 * 1024)) + ref_str = " X " if detail.referenced else " " + transient_str = " X " if detail.transient else " " + pooled_str = " X " if detail.pool else " " + report += ( + f" {ref_str}\t{transient_str}" + f"\t {pooled_str}" + f"\t {size_in_mb:.2f}" + f"\t {detail.name}\n" + ) + + return report + + +def count_memory_from_path(sdfg_path: str, detail_report=False) -> str: + return count_memory(dace.SDFG.from_file(sdfg_path), detail_report=detail_report) diff --git a/dsl/pace/dsl/gt4py_utils.py b/dsl/pace/dsl/gt4py_utils.py index f031a6fcb..12b010ffc 100644 --- a/dsl/pace/dsl/gt4py_utils.py +++ b/dsl/pace/dsl/gt4py_utils.py @@ -235,9 +235,6 @@ def make_storage_from_shape( backend: str, dtype: DTypes = np.float64, mask: Optional[Tuple[bool, bool, bool]] = None, - # [TODO]: temporary storage should be lowered properly to DaCe - # and added elsewhere (e.g., remapping) - is_temporary: bool = False, ) -> Field: """Create a new gt4py storage of a given shape filled with zeros. @@ -272,8 +269,6 @@ def make_storage_from_shape( mask=mask, managed_memory=managed_memory, ) - if is_temporary: - storage._istransient = True return storage diff --git a/dsl/requirements.txt b/dsl/requirements.txt index 9cf8f16c6..89186a9fa 100644 --- a/dsl/requirements.txt +++ b/dsl/requirements.txt @@ -1,2 +1,2 @@ # We can only pip-compile external dependencies, and indirectly depend on dace for development -git+https://github.com/spcl/dace.git@v0.14rc2 +dace>=0.14 diff --git a/fv3core/fv3core/stencils/c_sw.py b/fv3core/fv3core/stencils/c_sw.py index d2d1c0064..8927d1630 100644 --- a/fv3core/fv3core/stencils/c_sw.py +++ b/fv3core/fv3core/stencils/c_sw.py @@ -525,7 +525,6 @@ def make_storage(): return utils.make_storage_from_shape( grid_indexing.max_shape, backend=stencil_factory.backend, - is_temporary=False, ) self._tmp_ke = make_storage() diff --git a/fv3core/fv3core/stencils/d_sw.py b/fv3core/fv3core/stencils/d_sw.py index 64f807ea0..4ec15d920 100644 --- a/fv3core/fv3core/stencils/d_sw.py +++ b/fv3core/fv3core/stencils/d_sw.py @@ -710,7 +710,6 @@ def make_storage(): return utils.make_storage_from_shape( self.grid_indexing.max_shape, backend=stencil_factory.backend, - is_temporary=False, ) self._tmp_heat_s = make_storage() @@ -735,7 +734,6 @@ def make_storage(): self._tmp_damp_3d = utils.make_storage_from_shape( (1, 1, self.grid_indexing.domain[2]), backend=stencil_factory.backend, - is_temporary=False, ) self._column_namelist = column_namelist diff --git a/fv3core/fv3core/stencils/delnflux.py b/fv3core/fv3core/stencils/delnflux.py index a46c1f65a..79f5b8028 100644 --- a/fv3core/fv3core/stencils/delnflux.py +++ b/fv3core/fv3core/stencils/delnflux.py @@ -971,19 +971,18 @@ def __init__( k_shape = (1, 1, nk) self._damp_3d = utils.make_storage_from_shape( - k_shape, backend=stencil_factory.backend, is_temporary=False + k_shape, backend=stencil_factory.backend ) # fields must be 3d to assign to them self._fx2 = utils.make_storage_from_shape( - shape, backend=stencil_factory.backend, is_temporary=False + shape, backend=stencil_factory.backend ) self._fy2 = utils.make_storage_from_shape( - shape, backend=stencil_factory.backend, is_temporary=False + shape, backend=stencil_factory.backend ) self._d2 = utils.make_storage_from_shape( grid_indexing.domain_full(), backend=stencil_factory.backend, - is_temporary=False, ) damping_factor_calculation = stencil_factory.from_origin_domain( @@ -1193,12 +1192,6 @@ def __init__( corner_domain = corner_domain[:2] + (nk,) corner_axis_offsets = grid_indexing.axis_offsets(corner_origin, corner_domain) - self._corner_tmp = utils.make_storage_from_shape( - corner_domain, - origin=corner_origin, - backend=stencil_factory.backend, - is_temporary=False, - ) self._copy_corners_x_nord = stencil_factory.from_origin_domain( copy_corners_x_nord, externals={**corner_axis_offsets, **nord_dictionary}, diff --git a/fv3core/fv3core/stencils/fvtp2d.py b/fv3core/fv3core/stencils/fvtp2d.py index 7d2576568..00e5e515b 100644 --- a/fv3core/fv3core/stencils/fvtp2d.py +++ b/fv3core/fv3core/stencils/fvtp2d.py @@ -150,7 +150,6 @@ def make_storage(): idx.max_shape, origin=origin, backend=stencil_factory.backend, - is_temporary=False, ) self._q_advected_y = make_storage() diff --git a/fv3core/fv3core/stencils/pk3_halo.py b/fv3core/fv3core/stencils/pk3_halo.py index 1aa5e5cc1..19510078a 100644 --- a/fv3core/fv3core/stencils/pk3_halo.py +++ b/fv3core/fv3core/stencils/pk3_halo.py @@ -55,7 +55,6 @@ def __init__(self, stencil_factory: StencilFactory): shape_2D, grid_indexing.origin_full(), backend=stencil_factory.backend, - is_temporary=False, ) def __call__(self, pk3: FloatField, delp: FloatField, ptop: float, akap: float): diff --git a/fv3core/fv3core/stencils/riem_solver3.py b/fv3core/fv3core/stencils/riem_solver3.py index a395d7970..fa6e4cbd0 100644 --- a/fv3core/fv3core/stencils/riem_solver3.py +++ b/fv3core/fv3core/stencils/riem_solver3.py @@ -164,7 +164,6 @@ def make_storage(): grid_indexing.max_shape, origin=grid_indexing.origin_compute(), backend=stencil_factory.backend, - is_temporary=False, ) self._tmp_dm = make_storage() diff --git a/fv3core/fv3core/stencils/riem_solver_c.py b/fv3core/fv3core/stencils/riem_solver_c.py index 1cd2e0296..1f47099d2 100644 --- a/fv3core/fv3core/stencils/riem_solver_c.py +++ b/fv3core/fv3core/stencils/riem_solver_c.py @@ -100,7 +100,7 @@ def __init__(self, stencil_factory: StencilFactory, p_fac): def make_storage(): return utils.make_storage_from_shape( - shape, origin, backend=stencil_factory.backend, is_temporary=True + shape, origin, backend=stencil_factory.backend ) self._dm = make_storage() diff --git a/fv3core/requirements/requirements_dace.txt b/fv3core/requirements/requirements_dace.txt index 439b06bf8..9b45fdb87 100644 --- a/fv3core/requirements/requirements_dace.txt +++ b/fv3core/requirements/requirements_dace.txt @@ -1,2 +1,2 @@ # DaCe -git+https://github.com/spcl/dace.git@v0.14rc2 +dace>=0.14 diff --git a/pace-util/pace/util/_timing.py b/pace-util/pace/util/_timing.py index 9519bc25a..d4d08f7f0 100644 --- a/pace-util/pace/util/_timing.py +++ b/pace-util/pace/util/_timing.py @@ -2,6 +2,9 @@ from timeit import default_timer as time from typing import Mapping +from ._optional_imports import cupy as cp +from .utils import GPU_AVAILABLE + class Timer: """Class to accumulate timings for named operations.""" @@ -11,9 +14,15 @@ def __init__(self): self._accumulated_time = {} self._hit_count = {} self._enabled = True + # Check if we have CUDA device and it's ready to + # perform tasks + self._can_time_CUDA = GPU_AVAILABLE def start(self, name: str): """Start timing a given named operation.""" + if self._can_time_CUDA: + cp.cuda.Device(0).synchronize() + cp.cuda.nvtx.RangePush(name) if self._enabled: if name in self._clock_starts: raise ValueError(f"clock already started for '{name}'") @@ -24,6 +33,9 @@ def stop(self, name: str): """Stop timing a given named operation, add the time elapsed to accumulated timing and increase the hit count. """ + if self._can_time_CUDA: + cp.cuda.Device(0).synchronize() + cp.cuda.nvtx.RangePop() if self._enabled: if name not in self._accumulated_time: self._accumulated_time[name] = time() - self._clock_starts.pop(name) diff --git a/requirements_dev.txt b/requirements_dev.txt index 93fa9225b..23187b6ad 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -2,7 +2,7 @@ pytest pytest-subtests pytest-cov mpi4py -git+https://github.com/spcl/dace.git@v0.14rc2 +dace>=0.14 -e external/gt4py -e pace-util -e stencils