Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make NetCDF file cache handling compatible with dask distributed #2822

Open
wants to merge 25 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
7f6a8d4
Add test to reproduce GH 2815
gerritholl Jun 14, 2024
6d31c20
make sure distributed client is local
gerritholl Jun 14, 2024
1e26d1a
Start utility function for distributed friendly
gerritholl Jun 14, 2024
be40c5b
Parameterise test and simplify implementation
gerritholl Jun 14, 2024
cbd00f0
Force shape and dtype. First working prototype.
gerritholl Jun 14, 2024
af4ee66
Add group support and speed up tests
gerritholl Jun 20, 2024
dad3b14
Add partial backward-compatibility fol file handle
gerritholl Jun 20, 2024
fc58ca4
Respect auto_maskandscale with new caching
gerritholl Jun 20, 2024
09c821a
Remove needless except block
gerritholl Jun 20, 2024
4f9c5ed
Test refactoring
gerritholl Jun 20, 2024
ec76fa6
Broaden test match string for test_filenotfound
gerritholl Jun 20, 2024
06d8811
fix docstring example spelling
gerritholl Jul 24, 2024
aaf91b9
Prevent unexpected type promotion in unit test
gerritholl Jul 24, 2024
a2ad42f
Use block info getting a dd-friendly da
gerritholl Jul 24, 2024
9126bbe
Rename to serialisable and remove group argument
gerritholl Jul 25, 2024
5e576f9
Use wrapper class for auto_maskandscale
gerritholl Jul 25, 2024
63e7507
GB -> US spelling
gerritholl Jul 25, 2024
ea04595
Ensure meta dtype
gerritholl Jul 25, 2024
523671a
Merge branch 'main' into bugfix-2815
gerritholl Jul 25, 2024
fde3896
Fix spelling in test
gerritholl Jul 25, 2024
5b137e8
Clarify docstring
gerritholl Jul 26, 2024
c2b1533
Use cache already in scene creation
gerritholl Jul 26, 2024
9fce5a7
Use helper function rather than subclass
gerritholl Jul 26, 2024
4993b65
restore non-cached group retrieval
gerritholl Jul 26, 2024
7c173e7
Merge branch 'main' into bugfix-2815
gerritholl Aug 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 61 additions & 17 deletions satpy/readers/netcdf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,17 @@
# satpy. If not, see <http://www.gnu.org/licenses/>.
"""Helpers for reading netcdf-based files."""

import functools
import logging
import warnings

import dask.array as da
import netCDF4
import numpy as np
import xarray as xr

from satpy.readers import open_file_or_filename
from satpy.readers.file_handlers import BaseFileHandler
from satpy.readers.utils import np2str
from satpy.readers.utils import get_serializable_dask_array, np2str
from satpy.utils import get_legacy_chunk_size

LOG = logging.getLogger(__name__)
Expand Down Expand Up @@ -85,10 +86,12 @@ class NetCDF4FileHandler(BaseFileHandler):
xarray_kwargs (dict): Addition arguments to `xarray.open_dataset`
cache_var_size (int): Cache variables smaller than this size.
cache_handle (bool): Keep files open for lifetime of filehandler.
Uses xarray.backends.CachingFileManager, which uses a least
recently used cache.
djhoese marked this conversation as resolved.
Show resolved Hide resolved

"""

file_handle = None
manager = None

def __init__(self, filename, filename_info, filetype_info,
auto_maskandscale=False, xarray_kwargs=None,
Expand All @@ -99,6 +102,7 @@ def __init__(self, filename, filename_info, filetype_info,
self.file_content = {}
self.cached_file_content = {}
self._use_h5netcdf = False
self._auto_maskandscale = auto_maskandscale
try:
file_handle = self._get_file_handle()
except IOError:
Expand All @@ -118,13 +122,26 @@ def __init__(self, filename, filename_info, filetype_info,
self.collect_cache_vars(cache_var_size)

if cache_handle:
self.file_handle = file_handle
self.manager = xr.backends.CachingFileManager(
functools.partial(_NCDatasetWrapper,
auto_maskandscale=auto_maskandscale),
self.filename, mode="r")
else:
file_handle.close()

def _get_file_handle(self):
return netCDF4.Dataset(self.filename, "r")

@property
def file_handle(self):
"""Backward-compatible way for file handle caching."""
warnings.warn(
"attribute .file_handle is deprecated, use .manager instead",
DeprecationWarning)
if self.manager is None:
return None
return self.manager.acquire()

@staticmethod
def _set_file_handle_auto_maskandscale(file_handle, auto_maskandscale):
if hasattr(file_handle, "set_auto_maskandscale"):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not that this has to be handled in your PR, but if I remember correctly this Dataset-level set_auto_maskandscale was added to netcdf4-python quite a while ago. It seems error prone and confusing to silently call the method only if it exists and to not log/inform the user that it wasn't used when it was expected. Maybe we should remove this method on the file handler class and always call file_handle.set_auto_maskandscale no matter what. Your wrapper does it already.

Expand Down Expand Up @@ -196,11 +213,8 @@ def _get_required_variable_names(listed_variables, variable_name_replacements):

def __del__(self):
"""Delete the file handler."""
if self.file_handle is not None:
try:
self.file_handle.close()
except RuntimeError: # presumably closed already
pass
if self.manager is not None:
self.manager.close()
djhoese marked this conversation as resolved.
Show resolved Hide resolved

def _collect_global_attrs(self, obj):
"""Collect all the global attributes for the provided file object."""
Expand Down Expand Up @@ -289,8 +303,8 @@ def _get_variable(self, key, val):
group, key = parts
else:
group = None
if self.file_handle is not None:
val = self._get_var_from_filehandle(group, key)
if self.manager is not None:
val = self._get_var_from_manager(group, key)
else:
val = self._get_var_from_xr(group, key)
return val
Expand Down Expand Up @@ -319,18 +333,27 @@ def _get_var_from_xr(self, group, key):
val.load()
return val

def _get_var_from_filehandle(self, group, key):
def _get_var_from_manager(self, group, key):
# Not getting coordinates as this is more work, therefore more
# overhead, and those are not used downstream.

with self.manager.acquire_context() as ds:
if group is not None:
v = ds[group][key]
else:
v = ds[key]
if group is None:
g = self.file_handle
dv = get_serializable_dask_array(
self.manager, key,
chunks=v.shape, dtype=v.dtype)
else:
g = self.file_handle[group]
v = g[key]
dv = get_serializable_dask_array(
self.manager, "/".join([group, key]),
chunks=v.shape, dtype=v.dtype)
attrs = self._get_object_attrs(v)
x = xr.DataArray(
da.from_array(v), dims=v.dimensions, attrs=attrs,
name=v.name)
dv,
dims=v.dimensions, attrs=attrs, name=v.name)
return x

def __contains__(self, item):
Expand Down Expand Up @@ -443,3 +466,24 @@ def _get_attr(self, obj, key):
if self._use_h5netcdf:
return obj.attrs[key]
return super()._get_attr(obj, key)

class _NCDatasetWrapper(netCDF4.Dataset):
"""Wrap netcdf4.Dataset setting auto_maskandscale globally.

Helper class that wraps netcdf4.Dataset while setting extra parameters.
By encapsulating this in a helper class, we can
pass it to CachingFileManager directly. Currently sets
auto_maskandscale globally (for all variables).
"""

def __init__(self, *args, auto_maskandscale=False, **kwargs):
"""Initialise object."""
super().__init__(*args, **kwargs)
self._set_extra_settings(auto_maskandscale=auto_maskandscale)

def _set_extra_settings(self, auto_maskandscale):
"""Set our own custom settings.

Currently only applies set_auto_maskandscale.
"""
self.set_auto_maskandscale(auto_maskandscale)
43 changes: 43 additions & 0 deletions satpy/readers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from shutil import which
from subprocess import PIPE, Popen # nosec

import dask.array as da
import numpy as np
import pyproj
import xarray as xr
Expand Down Expand Up @@ -476,6 +477,48 @@ def remove_earthsun_distance_correction(reflectance, utc_date=None):
return reflectance


def get_serializable_dask_array(manager, varname, chunks, dtype):
"""Construct a serializable dask array from a variable.

When we construct a dask array using da.array and use that to create an
gerritholl marked this conversation as resolved.
Show resolved Hide resolved
xarray dataarray, the result is not serializable and dask graphs using
this dataarray cannot be computed when the dask distributed scheduler
is in use. To circumvent this problem, xarray provides the
CachingFileManager. See GH#2815 for more information.

Should have at least one dimension.

Example::

>>> import netCDF4
>>> from xarray.backends import CachingFileManager
>>> cfm = CachingFileManager(netCDF4.Dataset, fn, mode="r")
>>> arr = get_serializable_dask_array(cfm, "my_var")

Args:
manager (xarray.backends.CachingFileManager):
Instance of :class:`~xarray.backends.CachingFileManager` encapsulating the
dataset to be read.
varname (str):
Name of the variable (possibly including a group path).
chunks (tuple):
Chunks to use when creating the dask array.
dtype (dtype):
What dtype to use.
"""
def get_chunk(block_info=None):
arrloc = block_info[None]["array-location"]
with manager.acquire_context() as nc:
var = nc[varname]
return var[tuple(slice(*x) for x in arrloc)]

return da.map_blocks(
get_chunk,
chunks=chunks,
dtype=dtype,
meta=np.array([], dtype=dtype))


class _CalibrationCoefficientParser:
"""Parse user-defined calibration coefficients."""

Expand Down
78 changes: 55 additions & 23 deletions satpy/tests/reader_tests/test_netcdf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
"""Module for testing the satpy.readers.netcdf_utils module."""

import os
import unittest

import numpy as np
import pytest
Expand Down Expand Up @@ -71,13 +70,15 @@ def get_test_content(self, filename, filename_info, filetype_info):
raise NotImplementedError("Fake File Handler subclass must implement 'get_test_content'")


class TestNetCDF4FileHandler(unittest.TestCase):
class TestNetCDF4FileHandler:
"""Test NetCDF4 File Handler Utility class."""

def setUp(self):
@pytest.fixture()
def dummy_nc_file(self, tmp_path):
"""Create a test NetCDF4 file."""
from netCDF4 import Dataset
with Dataset("test.nc", "w") as nc:
fn = tmp_path / "test.nc"
with Dataset(fn, "w") as nc:
# Create dimensions
nc.createDimension("rows", 10)
nc.createDimension("cols", 100)
Expand Down Expand Up @@ -116,17 +117,14 @@ def setUp(self):
d.test_attr_str = "test_string"
d.test_attr_int = 0
d.test_attr_float = 1.2
return fn

def tearDown(self):
"""Remove the previously created test file."""
os.remove("test.nc")

def test_all_basic(self):
def test_all_basic(self, dummy_nc_file):
"""Test everything about the NetCDF4 class."""
import xarray as xr

from satpy.readers.netcdf_utils import NetCDF4FileHandler
file_handler = NetCDF4FileHandler("test.nc", {}, {})
file_handler = NetCDF4FileHandler(dummy_nc_file, {}, {})

assert file_handler["/dimension/rows"] == 10
assert file_handler["/dimension/cols"] == 100
Expand Down Expand Up @@ -165,7 +163,7 @@ def test_all_basic(self):
assert file_handler.file_handle is None
assert file_handler["ds2_sc"] == 42

def test_listed_variables(self):
def test_listed_variables(self, dummy_nc_file):
"""Test that only listed variables/attributes area collected."""
from satpy.readers.netcdf_utils import NetCDF4FileHandler

Expand All @@ -175,12 +173,12 @@ def test_listed_variables(self):
"attr/test_attr_str",
]
}
file_handler = NetCDF4FileHandler("test.nc", {}, filetype_info)
file_handler = NetCDF4FileHandler(dummy_nc_file, {}, filetype_info)
assert len(file_handler.file_content) == 2
assert "test_group/attr/test_attr_str" in file_handler.file_content
assert "attr/test_attr_str" in file_handler.file_content

def test_listed_variables_with_composing(self):
def test_listed_variables_with_composing(self, dummy_nc_file):
"""Test that composing for listed variables is performed."""
from satpy.readers.netcdf_utils import NetCDF4FileHandler

Expand All @@ -199,7 +197,7 @@ def test_listed_variables_with_composing(self):
],
}
}
file_handler = NetCDF4FileHandler("test.nc", {}, filetype_info)
file_handler = NetCDF4FileHandler(dummy_nc_file, {}, filetype_info)
assert len(file_handler.file_content) == 3
assert "test_group/ds1_f/attr/test_attr_str" in file_handler.file_content
assert "test_group/ds1_i/attr/test_attr_str" in file_handler.file_content
Expand All @@ -208,10 +206,10 @@ def test_listed_variables_with_composing(self):
assert not any("another_parameter" in var for var in file_handler.file_content)
assert "test_group/attr/test_attr_str" in file_handler.file_content

def test_caching(self):
def test_caching(self, dummy_nc_file):
"""Test that caching works as intended."""
from satpy.readers.netcdf_utils import NetCDF4FileHandler
h = NetCDF4FileHandler("test.nc", {}, {}, cache_var_size=1000,
h = NetCDF4FileHandler(dummy_nc_file, {}, {}, cache_var_size=1000,
cache_handle=True)
assert h.file_handle is not None
assert h.file_handle.isopen()
Expand All @@ -226,8 +224,6 @@ def test_caching(self):
np.testing.assert_array_equal(
h["ds2_f"],
np.arange(10. * 100).reshape((10, 100)))
h.__del__()
assert not h.file_handle.isopen()

def test_filenotfound(self):
"""Test that error is raised when file not found."""
Expand All @@ -237,21 +233,21 @@ def test_filenotfound(self):
with pytest.raises(IOError, match=".*(No such file or directory|Unknown file format).*"):
NetCDF4FileHandler("/thisfiledoesnotexist.nc", {}, {})

def test_get_and_cache_npxr_is_xr(self):
def test_get_and_cache_npxr_is_xr(self, dummy_nc_file):
"""Test that get_and_cache_npxr() returns xr.DataArray."""
import xarray as xr

from satpy.readers.netcdf_utils import NetCDF4FileHandler
file_handler = NetCDF4FileHandler("test.nc", {}, {}, cache_handle=True)
file_handler = NetCDF4FileHandler(dummy_nc_file, {}, {}, cache_handle=True)

data = file_handler.get_and_cache_npxr("test_group/ds1_f")
assert isinstance(data, xr.DataArray)

def test_get_and_cache_npxr_data_is_cached(self):
def test_get_and_cache_npxr_data_is_cached(self, dummy_nc_file):
"""Test that the data are cached when get_and_cache_npxr() is called."""
from satpy.readers.netcdf_utils import NetCDF4FileHandler

file_handler = NetCDF4FileHandler("test.nc", {}, {}, cache_handle=True)
file_handler = NetCDF4FileHandler(dummy_nc_file, {}, {}, cache_handle=True)
data = file_handler.get_and_cache_npxr("test_group/ds1_f")

# Delete the dataset from the file content dict, it should be available from the cache
Expand All @@ -265,7 +261,6 @@ class TestNetCDF4FsspecFileHandler:

def test_default_to_netcdf4_lib(self):
"""Test that the NetCDF4 backend is used by default."""
import os
import tempfile

import h5py
Expand Down Expand Up @@ -393,3 +388,40 @@ def test_get_data_as_xarray_scalar_h5netcdf(tmp_path):
res = get_data_as_xarray(fid["test_data"])
np.testing.assert_equal(res.data, np.array(data))
assert res.attrs == NC_ATTRS


@pytest.fixture()
def dummy_nc(tmp_path):
"""Fixture to create a dummy NetCDF file and return its path."""
import xarray as xr

fn = tmp_path / "sjaunja.nc"
ds = xr.Dataset(data_vars={"kaitum": (["x"], np.arange(10))})
ds.to_netcdf(fn)
return fn


def test_caching_distributed(dummy_nc):
"""Test that the distributed scheduler works with file handle caching.

This is a test for GitHub issue 2815.
"""
from dask.distributed import Client

from satpy.readers.netcdf_utils import NetCDF4FileHandler

fh = NetCDF4FileHandler(dummy_nc, {}, {}, cache_handle=True)

def doubler(x):
return x * 2

# As documented in GH issue 2815, using dask distributed with the file
# handle cacher might fail in non-trivial ways, such as giving incorrect
# results. Testing map_blocks is one way to reproduce the problem
# reliably, even though the problem also manifests itself (in different
# ways) without map_blocks.


with Client():
dask_doubler = fh["kaitum"].map_blocks(doubler)
dask_doubler.compute()
Loading