Skip to content

Commit

Permalink
Adding a function to log or warn if GDAL's cache exceeds node memory.
Browse files Browse the repository at this point in the history
  • Loading branch information
phargogh committed Jan 17, 2024
1 parent d34286d commit 75633cb
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 1 deletion.
8 changes: 7 additions & 1 deletion src/pygeoprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from .geoprocessing import zonal_statistics
from .geoprocessing_core import calculate_slope
from .geoprocessing_core import raster_band_percentile
from .slurm_utils import log_warning_if_gdal_will_exhaust_slurm_memory

try:
__version__ = version('pygeoprocessing')
Expand All @@ -61,9 +62,11 @@
# Thus, the imports are the source of truth for __all__.
__all__ = ('calculate_slope', 'raster_band_percentile',
'ReclassificationMissingValuesError')
exclude_set = {'log_warning_if_gdal_will_exhaust_slurm_memory'}
for attrname in [k for k in locals().keys()]:
try:
if isinstance(getattr(geoprocessing, attrname), types.FunctionType):
if (isinstance(getattr(geoprocessing, attrname), types.FunctionType)
and attrname not in exclude_set):
__all__ += (attrname,)
except AttributeError:
pass
Expand All @@ -75,3 +78,6 @@
UNKNOWN_TYPE = 0
RASTER_TYPE = 1
VECTOR_TYPE = 2

# Check GDAL's cache max vs SLURM memory if we're on slurm.
log_warning_if_gdal_will_exhaust_slurm_memory()
40 changes: 40 additions & 0 deletions src/pygeoprocessing/slurm_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import logging
import os
import warnings

from osgeo import gdal

LOGGER = logging.getLogger(__name__)


def log_warning_if_gdal_will_exhaust_slurm_memory():
slurm_env_vars = set(k for k in os.environ.keys() if k.startswith('SLURM'))
if slurm_env_vars:
gdal_cache_size = gdal.GetCacheMax()
if gdal_cache_size < 100000:
# If the cache size is 100,000 or greater, it's assumed to be in
# bytes. Otherwise, units are interpreted as megabytes.
# See gcore/gdalrasterblock.cpp for reference.
gdal_cache_size_mb = gdal_cache_size
else:
gdal_cache_size_mb = gdal_cache_size * 1024 * 1024

slurm_mem_per_node = os.environ['SLURM_MEM_PER_NODE']
if gdal_cache_size_mb > int(os.environ['SLURM_MEM_PER_NODE']):
message = (
"GDAL's cache max exceeds the memory SLURM has "
"allocated for this node. The process will probably be "
"killed by the kernel's oom-killer. "
f"GDAL_CACHEMAX={gdal_cache_size} (interpreted as "
f"{gdal_cache_size_mb} MB), "
f"SLURM_MEM_PER_NODE={slurm_mem_per_node}")

# If logging is not configured to capture warnings, send the output
# to the usual warnings stream. If logging is configured to
# capture warnings, log the warning as normal.
# This appears to be the easiest way to identify whether we're in a
# logging.captureWarnings(True) block.
if logging._warnings_showwarning is None:
warnings.warn(message)
else:
LOGGER.warning(message)
113 changes: 113 additions & 0 deletions tests/test_slurm_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import contextlib
import logging
import logging.handlers
import os
import queue
import unittest
import unittest.mock
import warnings

from osgeo import gdal
from pygeoprocessing import slurm_utils


def mock_env_var(varname, value):
try:
prior_value = os.environ[varname]
except KeyError:
prior_value = None

os.environ[varname] = value
yield
os.environ[varname] = prior_value


class SLURMUtilsTest(unittest.TestCase):
@unittest.mock.patch.dict(os.environ, {"SLURM_MEM_PER_NODE": "128"})
def test_warning_gdal_cachemax_unset_on_slurm(self):
"""PGP.slurm_utils: test warning when GDAL cache not set on slurm."""
for gdal_cachesize in [123456789, # big number of bytes
256]: # megabytes, exceeds slurm
with unittest.mock.patch('osgeo.gdal.GetCacheMax',
lambda: gdal_cachesize):
with warnings.catch_warnings(record=True) as caught_warnings:
slurm_utils.log_warning_if_gdal_will_exhaust_slurm_memory()

self.assertEqual(len(caught_warnings), 1)
caught_message = caught_warnings[0].message.args[0]
self.assertIn("exceeds the memory SLURM has", caught_message)
self.assertIn(f"GDAL_CACHEMAX={gdal_cachesize}",
caught_message)
self.assertIn("SLURM_MEM_PER_NODE=128", caught_message)

@unittest.mock.patch.dict(os.environ, {"SLURM_MEM_PER_NODE": "128"})
def test_logging_gdal_cachemax_unset_on_slurm(self):
"""PGP.slurm_utils: test logs when GDAL cache not set on slurm."""
logging_queue = queue.Queue()
queuehandler = logging.handlers.QueueHandler(logging_queue)
slurm_logger = logging.getLogger('pygeoprocessing.slurm_utils')
slurm_logger.addHandler(queuehandler)

for gdal_cachesize in [123456789, # big number of bytes
256]: # megabytes, exceeds slurm
with unittest.mock.patch('osgeo.gdal.GetCacheMax',
lambda: gdal_cachesize):
try:
logging.captureWarnings(True) # needed for this test
slurm_utils.log_warning_if_gdal_will_exhaust_slurm_memory()
finally:
# Always reset captureWarnings in case of failure so other
# tests don't misbehave.
logging.captureWarnings(False)

caught_warnings = []
while True:
try:
caught_warnings.append(logging_queue.get_nowait())
except queue.Empty:
break

self.assertEqual(len(caught_warnings), 1)
caught_message = caught_warnings[0].msg
self.assertIn("exceeds the memory SLURM has", caught_message)
self.assertIn(f"GDAL_CACHEMAX={gdal_cachesize}",
caught_message)
self.assertIn("SLURM_MEM_PER_NODE=128", caught_message)

slurm_logger.removeHandler(queuehandler)

@unittest.mock.patch.dict(os.environ, {}, clear=True) # clear all env vars
def test_not_on_slurm_no_warnings(self):
"""PGP.slurm_utils: verify no warnings when not on slurm."""
with unittest.mock.patch('osgeo.gdal.GetCacheMax',
lambda: 123456789): # big memory value
with warnings.catch_warnings(record=True) as caught_warnings:
slurm_utils.log_warning_if_gdal_will_exhaust_slurm_memory()

self.assertEqual(caught_warnings, [])

@unittest.mock.patch.dict(os.environ, {}, clear=True) # clear all env vars
def test_not_on_slurm_no_logging(self):
"""PGP.slurm_utils: verify no logging when not on slurm."""
logging_queue = queue.Queue()
queuehandler = logging.handlers.QueueHandler(logging_queue)
slurm_logger = logging.getLogger('pygeoprocessing.slurm_utils')
slurm_logger.addHandler(queuehandler)

with unittest.mock.patch('osgeo.gdal.GetCacheMax',
lambda: 123456789): # big memory value
try:
logging.captureWarnings(True)
slurm_utils.log_warning_if_gdal_will_exhaust_slurm_memory()
finally:
logging.captureWarnings(False)
slurm_logger.removeHandler(queuehandler)

caught_warnings = []
while True:
try:
caught_warnings.append(logging_queue.get_nowait())
except queue.Empty:
break

self.assertEqual(caught_warnings, [])

0 comments on commit 75633cb

Please sign in to comment.