Make NetCDF file cache handling compatible with dask distributed #2822

Open: wants to merge 25 commits into base: main

Changes from 11 commits

Commits (25)
7f6a8d4
Add test to reproduce GH 2815
gerritholl Jun 14, 2024
6d31c20
make sure distributed client is local
gerritholl Jun 14, 2024
1e26d1a
Start utility function for distributed friendly
gerritholl Jun 14, 2024
be40c5b
Parameterise test and simplify implementation
gerritholl Jun 14, 2024
cbd00f0
Force shape and dtype. First working prototype.
gerritholl Jun 14, 2024
af4ee66
Add group support and speed up tests
gerritholl Jun 20, 2024
dad3b14
Add partial backward-compatibility for file handle
gerritholl Jun 20, 2024
fc58ca4
Respect auto_maskandscale with new caching
gerritholl Jun 20, 2024
09c821a
Remove needless except block
gerritholl Jun 20, 2024
4f9c5ed
Test refactoring
gerritholl Jun 20, 2024
ec76fa6
Broaden test match string for test_filenotfound
gerritholl Jun 20, 2024
06d8811
fix docstring example spelling
gerritholl Jul 24, 2024
aaf91b9
Prevent unexpected type promotion in unit test
gerritholl Jul 24, 2024
a2ad42f
Use block info getting a dd-friendly da
gerritholl Jul 24, 2024
9126bbe
Rename to serialisable and remove group argument
gerritholl Jul 25, 2024
5e576f9
Use wrapper class for auto_maskandscale
gerritholl Jul 25, 2024
63e7507
GB -> US spelling
gerritholl Jul 25, 2024
ea04595
Ensure meta dtype
gerritholl Jul 25, 2024
523671a
Merge branch 'main' into bugfix-2815
gerritholl Jul 25, 2024
fde3896
Fix spelling in test
gerritholl Jul 25, 2024
5b137e8
Clarify docstring
gerritholl Jul 26, 2024
c2b1533
Use cache already in scene creation
gerritholl Jul 26, 2024
9fce5a7
Use helper function rather than subclass
gerritholl Jul 26, 2024
4993b65
restore non-cached group retrieval
gerritholl Jul 26, 2024
7c173e7
Merge branch 'main' into bugfix-2815
gerritholl Aug 23, 2024
56 changes: 39 additions & 17 deletions satpy/readers/netcdf_utils.py
@@ -18,15 +18,15 @@
"""Helpers for reading netcdf-based files."""

import logging
import warnings

import dask.array as da
import netCDF4
import numpy as np
import xarray as xr

from satpy.readers import open_file_or_filename
from satpy.readers.file_handlers import BaseFileHandler
from satpy.readers.utils import np2str
from satpy.readers.utils import get_distributed_friendly_dask_array, np2str
from satpy.utils import get_legacy_chunk_size

LOG = logging.getLogger(__name__)
@@ -85,10 +85,12 @@ class NetCDF4FileHandler(BaseFileHandler):
xarray_kwargs (dict): Additional arguments to `xarray.open_dataset`
cache_var_size (int): Cache variables smaller than this size.
cache_handle (bool): Keep files open for lifetime of filehandler.
Uses xarray.backends.CachingFileManager, which uses a least
recently used cache.
djhoese marked this conversation as resolved.

"""

file_handle = None
manager = None

def __init__(self, filename, filename_info, filetype_info,
auto_maskandscale=False, xarray_kwargs=None,
@@ -99,6 +101,7 @@ def __init__(self, filename, filename_info, filetype_info,
self.file_content = {}
self.cached_file_content = {}
self._use_h5netcdf = False
self._auto_maskandscale = auto_maskandscale
try:
file_handle = self._get_file_handle()
except IOError:
@@ -118,13 +121,24 @@ def __init__(self, filename, filename_info, filetype_info,
self.collect_cache_vars(cache_var_size)

if cache_handle:
self.file_handle = file_handle
self.manager = xr.backends.CachingFileManager(
netCDF4.Dataset, self.filename, mode="r")
else:
file_handle.close()

def _get_file_handle(self):
return netCDF4.Dataset(self.filename, "r")

@property
def file_handle(self):
"""Backward-compatible way for file handle caching."""
warnings.warn(
"attribute .file_handle is deprecated, use .manager instead",
DeprecationWarning)
if self.manager is None:
return None
return self.manager.acquire()

@staticmethod
def _set_file_handle_auto_maskandscale(file_handle, auto_maskandscale):
if hasattr(file_handle, "set_auto_maskandscale"):
Reviewer comment (Member):

Not that this has to be handled in your PR, but if I remember correctly this Dataset-level set_auto_maskandscale was added to netcdf4-python quite a while ago. It seems error prone and confusing to silently call the method only if it exists and to not log/inform the user that it wasn't used when it was expected. Maybe we should remove this method on the file handler class and always call file_handle.set_auto_maskandscale no matter what. Your wrapper does it already.
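
A minimal sketch of this suggestion, assuming the CachingFileManager-based handle used elsewhere in the PR (the helper name and the file name in the usage comment are illustrative, not part of the diff):

import netCDF4
from xarray.backends import CachingFileManager

def acquire_with_maskandscale(manager, auto_maskandscale):
    """Acquire a dataset handle and always forward auto_maskandscale."""
    handle = manager.acquire()
    # netCDF4.Dataset has provided set_auto_maskandscale for a long time, so the
    # call is made unconditionally rather than guarded by hasattr().
    handle.set_auto_maskandscale(auto_maskandscale)
    return handle

# Hypothetical usage:
# cfm = CachingFileManager(netCDF4.Dataset, "example.nc", mode="r")
# nc = acquire_with_maskandscale(cfm, True)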

@@ -196,11 +210,8 @@ def _get_required_variable_names(listed_variables, variable_name_replacements):

def __del__(self):
"""Delete the file handler."""
if self.file_handle is not None:
try:
self.file_handle.close()
except RuntimeError: # presumably closed already
pass
if self.manager is not None:
self.manager.close()
djhoese marked this conversation as resolved.

def _collect_global_attrs(self, obj):
"""Collect all the global attributes for the provided file object."""
@@ -289,8 +300,8 @@ def _get_variable(self, key, val):
group, key = parts
else:
group = None
if self.file_handle is not None:
val = self._get_var_from_filehandle(group, key)
if self.manager is not None:
val = self._get_var_from_manager(group, key)
else:
val = self._get_var_from_xr(group, key)
return val
@@ -319,18 +330,29 @@ def _get_var_from_xr(self, group, key):
val.load()
return val

def _get_var_from_filehandle(self, group, key):
def _get_var_from_manager(self, group, key):
# Not getting coordinates as this is more work, therefore more
# overhead, and those are not used downstream.

with self.manager.acquire_context() as ds:
if group is not None:
v = ds[group][key]
else:
v = ds[key]
if group is None:
g = self.file_handle
dv = get_distributed_friendly_dask_array(
self.manager, key,
chunks=v.shape, dtype=v.dtype,
auto_maskandscale=self._auto_maskandscale)
else:
g = self.file_handle[group]
v = g[key]
dv = get_distributed_friendly_dask_array(
self.manager, key, group=group,
chunks=v.shape, dtype=v.dtype,
auto_maskandscale=self._auto_maskandscale)
attrs = self._get_object_attrs(v)
x = xr.DataArray(
da.from_array(v), dims=v.dimensions, attrs=attrs,
name=v.name)
dv,
dims=v.dimensions, attrs=attrs, name=v.name)
return x

def __contains__(self, item):
(remaining unchanged lines not shown)
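
Taken together, the netcdf_utils.py changes above replace the raw file_handle attribute with an xarray CachingFileManager. A minimal migration sketch for downstream code, assuming a hypothetical file "example.nc" with a variable "some_var" (neither is taken from the PR):

from satpy.readers.netcdf_utils import NetCDF4FileHandler

fh = NetCDF4FileHandler("example.nc", {}, {}, cache_handle=True)

# Old style: still works for now, but emits a DeprecationWarning.
# nc = fh.file_handle
# data = nc["some_var"][:]

# New style: acquire the handle through the CachingFileManager in a context.
with fh.manager.acquire_context() as nc:
    data = nc["some_var"][:]
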
51 changes: 51 additions & 0 deletions satpy/readers/utils.py
@@ -29,6 +29,7 @@
from shutil import which
from subprocess import PIPE, Popen # nosec

import dask.array as da
import numpy as np
import pyproj
import xarray as xr
@@ -474,3 +475,53 @@
with xr.set_options(keep_attrs=True):
reflectance = reflectance / reflectance.dtype.type(sun_earth_dist * sun_earth_dist)
return reflectance


def get_distributed_friendly_dask_array(manager, varname, chunks, dtype,
group="/", auto_maskandscale=None):
Reviewer comment (Member):

I'm not sure how I feel about this function name. Obviously it makes sense in this PR because it solves this specific problem, but it feels like there is a (shorter) more generic name that gets the point across. Another thing is that distributed_friendly is mentioned here, but that friendliness is a side effect of the "serializable" nature of the way you're accessing the data here, right? get_serializable_dask_array?

I don't feel super strongly about this, but the name was distracting to me so I thought I'd say something.

Reply from the author (Member Author):

Renamed to get_serializable_dask_array.

"""Construct a dask array from a variable for dask distributed.

When we construct a dask array using da.array and use that to create an
xarray dataarray, the result is not serializable and dask graphs using
this dataarray cannot be computed when the dask distributed scheduler
is in use. To circumvent this problem, xarray provides the
CachingFileManager. See GH#2815 for more information.

The variable should have at least one dimension.

Example::

>>> import netCDF4
>>> from xarray.backends import CachingFileManager
>>> cfm = CachingFileManager(netCDF4.Dataset, fn, mode="r")
>>> arr = get_distributed_friendly_dask_array(cfm, "my_var", chunks=(10,), dtype="float64")
gerritholl marked this conversation as resolved.

Args:
manager (xarray.backends.CachingFileManager):
Instance of xarray.backends.CachingFileManager encapsulating the
dataset to be read.
Reviewer comment (Member):

We should check how the docs render this. If the argument type isn't "clickable" to go directly to the xarray docs for the CFM then we could wrap the mention of it in the description with:

:class:`xarray.backends.CachingFileManager`

Reply from the author (Member Author):

The argument type was already clickable, but in the description it was not. I have now made it clickable in both cases (screenshot from local doc production):

[Screenshot omitted: Bildschirmfoto_2024-07-25_11-10-06]

varname (str):
Name of the variable.
chunks (tuple):
Chunks to use when creating the dask array.
dtype (dtype):
What dtype to use.
group (str):
What group to read the variable from.
auto_maskandscale (bool, optional):
Apply automatic masking and scaling. This will only
work if CachingFileManager.acquire returns a handle with a
set_auto_maskandscale method, as is the case for
netCDF4.Dataset.
"""
def get_chunk():
Reviewer comment (Member):

The chunks is never used here. The current calling from the file handler is accessing the full shape of the variable so this is fine, but only for now. I mean that map_blocks will only ever call this function once. However, if you added a block_info kwarg to the function signature or whatever the map_blocks special keyword argument is, then you could change [:] to access a specific subset of the NetCDF file variable and only do a partial load. This should improve performance a lot (I think 🤞) if it was actually used in the file handler.

Reply from the author (Member Author):

> The chunks is never used here.

Hm? I'm passing chunks=chunks when I call da.map_blocks. What do you mean, it is never used? Do you mean I could be using chunk-location and num-chunks from a block_info dictionary passed to get_chunk?

> The current calling from the file handler is accessing the full shape of the variable so this is fine, but only for now. I mean that map_blocks will only ever call this function once. However, if you added a block_info kwarg to the function signature or whatever the map_blocks special keyword argument is, then you could change [:] to access a specific subset of the NetCDF file variable and only do a partial load. This should improve performance a lot (I think 🤞) if it was actually used in the file handler.

I will try to wrap my head around this ☺

Reviewer comment (Member):

Yes I think that's what I'm saying. I think the result of get_chunk() right now is broken for any chunk size other than the full shape of the array because you never do any slicing of the NetCDF variable inside get_chunk(). So, if you had a full array of 100x100 and a chunk size of 50x50, then map_blocks would call this function 4 times ((0-50, 0-50), (0-50, 50-100), (50-100, 0-50), (50-100, 50-100)). BUT each call would return the full variable 100x100. So I think this would be a case where the dask array would say "yeah, I have shape 100x100", but then once you computed it you'd get a 200x200 array back.

Reply from the author (Member Author):

Fixed it now, I think.

with manager.acquire_context() as nc:
if auto_maskandscale is not None:
nc.set_auto_maskandscale(auto_maskandscale)
return nc["/".join([group, varname])][:]
djhoese marked this conversation as resolved.

return da.map_blocks(
get_chunk,
chunks=chunks,
dtype=dtype,
meta=np.array([]))

Check warning on line 527 in satpy/readers/utils.py
CodeScene Delta Analysis / CodeScene Cloud Delta Analysis (main):
❌ New issue: Excess Number of Function Arguments
get_distributed_friendly_dask_array has 6 arguments, threshold = 4. This function has too many arguments, indicating a lack of encapsulation. Avoid adding more arguments.
gerritholl marked this conversation as resolved.
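
Following up on the block_info review thread above, a minimal sketch of what a per-chunk reader could look like, assuming the renamed get_serializable_dask_array helper; the slicing logic and the usage values are illustrative, not taken verbatim from the PR:

import dask.array as da
import netCDF4
import numpy as np
from xarray.backends import CachingFileManager


def get_serializable_dask_array(manager, varname, chunks, dtype, group=None):
    """Sketch: read each dask block of a NetCDF variable separately."""
    def get_chunk(block_info=None):
        # map_blocks fills block_info; "array-location" holds the (start, stop)
        # range covered by this block in every dimension of the output array.
        location = block_info[None]["array-location"]
        slices = tuple(slice(start, stop) for start, stop in location)
        path = varname if not group else "/".join([group, varname])
        with manager.acquire_context() as nc:
            return nc[path][slices]

    return da.map_blocks(get_chunk, chunks=chunks, dtype=dtype,
                         meta=np.array([], dtype=dtype))


# Hypothetical usage: a file "example.nc" with a 100x100 variable "my_var",
# read as four 50x50 blocks.
# cfm = CachingFileManager(netCDF4.Dataset, "example.nc", mode="r")
# arr = get_serializable_dask_array(cfm, "my_var",
#                                   chunks=((50, 50), (50, 50)), dtype="f8")
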
80 changes: 56 additions & 24 deletions satpy/tests/reader_tests/test_netcdf_utils.py
@@ -18,7 +18,6 @@
"""Module for testing the satpy.readers.netcdf_utils module."""

import os
import unittest

import numpy as np
import pytest
@@ -71,13 +70,15 @@ def get_test_content(self, filename, filename_info, filetype_info):
raise NotImplementedError("Fake File Handler subclass must implement 'get_test_content'")


class TestNetCDF4FileHandler(unittest.TestCase):
class TestNetCDF4FileHandler:
"""Test NetCDF4 File Handler Utility class."""

def setUp(self):
@pytest.fixture()
def dummy_nc_file(self, tmp_path):
"""Create a test NetCDF4 file."""
from netCDF4 import Dataset
with Dataset("test.nc", "w") as nc:
fn = tmp_path / "test.nc"
with Dataset(fn, "w") as nc:
# Create dimensions
nc.createDimension("rows", 10)
nc.createDimension("cols", 100)
@@ -116,17 +117,14 @@ def setUp(self):
d.test_attr_str = "test_string"
d.test_attr_int = 0
d.test_attr_float = 1.2
return fn

def tearDown(self):
"""Remove the previously created test file."""
os.remove("test.nc")

def test_all_basic(self):
def test_all_basic(self, dummy_nc_file):
"""Test everything about the NetCDF4 class."""
import xarray as xr

from satpy.readers.netcdf_utils import NetCDF4FileHandler
file_handler = NetCDF4FileHandler("test.nc", {}, {})
file_handler = NetCDF4FileHandler(dummy_nc_file, {}, {})

assert file_handler["/dimension/rows"] == 10
assert file_handler["/dimension/cols"] == 100
@@ -165,7 +163,7 @@ def test_all_basic(self):
assert file_handler.file_handle is None
assert file_handler["ds2_sc"] == 42

def test_listed_variables(self):
def test_listed_variables(self, dummy_nc_file):
"""Test that only listed variables/attributes area collected."""
from satpy.readers.netcdf_utils import NetCDF4FileHandler

@@ -175,12 +173,12 @@ def test_listed_variables(self):
"attr/test_attr_str",
]
}
file_handler = NetCDF4FileHandler("test.nc", {}, filetype_info)
file_handler = NetCDF4FileHandler(dummy_nc_file, {}, filetype_info)
assert len(file_handler.file_content) == 2
assert "test_group/attr/test_attr_str" in file_handler.file_content
assert "attr/test_attr_str" in file_handler.file_content

def test_listed_variables_with_composing(self):
def test_listed_variables_with_composing(self, dummy_nc_file):
"""Test that composing for listed variables is performed."""
from satpy.readers.netcdf_utils import NetCDF4FileHandler

@@ -199,7 +197,7 @@ def test_listed_variables_with_composing(self):
],
}
}
file_handler = NetCDF4FileHandler("test.nc", {}, filetype_info)
file_handler = NetCDF4FileHandler(dummy_nc_file, {}, filetype_info)
assert len(file_handler.file_content) == 3
assert "test_group/ds1_f/attr/test_attr_str" in file_handler.file_content
assert "test_group/ds1_i/attr/test_attr_str" in file_handler.file_content
@@ -208,10 +206,10 @@ def test_listed_variables_with_composing(self):
assert not any("another_parameter" in var for var in file_handler.file_content)
assert "test_group/attr/test_attr_str" in file_handler.file_content

def test_caching(self):
def test_caching(self, dummy_nc_file):
"""Test that caching works as intended."""
from satpy.readers.netcdf_utils import NetCDF4FileHandler
h = NetCDF4FileHandler("test.nc", {}, {}, cache_var_size=1000,
h = NetCDF4FileHandler(dummy_nc_file, {}, {}, cache_var_size=1000,
cache_handle=True)
assert h.file_handle is not None
assert h.file_handle.isopen()
@@ -226,31 +224,29 @@ def test_caching(self):
np.testing.assert_array_equal(
h["ds2_f"],
np.arange(10. * 100).reshape((10, 100)))
h.__del__()
assert not h.file_handle.isopen()

def test_filenotfound(self):
"""Test that error is raised when file not found."""
from satpy.readers.netcdf_utils import NetCDF4FileHandler

with pytest.raises(IOError, match=".*No such file or directory.*"):
with pytest.raises(IOError, match=".* file .*"):
NetCDF4FileHandler("/thisfiledoesnotexist.nc", {}, {})

def test_get_and_cache_npxr_is_xr(self):
def test_get_and_cache_npxr_is_xr(self, dummy_nc_file):
"""Test that get_and_cache_npxr() returns xr.DataArray."""
import xarray as xr

from satpy.readers.netcdf_utils import NetCDF4FileHandler
file_handler = NetCDF4FileHandler("test.nc", {}, {}, cache_handle=True)
file_handler = NetCDF4FileHandler(dummy_nc_file, {}, {}, cache_handle=True)

data = file_handler.get_and_cache_npxr("test_group/ds1_f")
assert isinstance(data, xr.DataArray)

def test_get_and_cache_npxr_data_is_cached(self):
def test_get_and_cache_npxr_data_is_cached(self, dummy_nc_file):
"""Test that the data are cached when get_and_cache_npxr() is called."""
from satpy.readers.netcdf_utils import NetCDF4FileHandler

file_handler = NetCDF4FileHandler("test.nc", {}, {}, cache_handle=True)
file_handler = NetCDF4FileHandler(dummy_nc_file, {}, {}, cache_handle=True)
data = file_handler.get_and_cache_npxr("test_group/ds1_f")

# Delete the dataset from the file content dict, it should be available from the cache
@@ -264,7 +260,6 @@ class TestNetCDF4FsspecFileHandler:

def test_default_to_netcdf4_lib(self):
"""Test that the NetCDF4 backend is used by default."""
import os
import tempfile

import h5py
@@ -392,3 +387,40 @@ def test_get_data_as_xarray_scalar_h5netcdf(tmp_path):
res = get_data_as_xarray(fid["test_data"])
np.testing.assert_equal(res.data, np.array(data))
assert res.attrs == NC_ATTRS


@pytest.fixture()
def dummy_nc(tmp_path):
"""Fixture to create a dummy NetCDF file and return its path."""
import xarray as xr

fn = tmp_path / "sjaunja.nc"
ds = xr.Dataset(data_vars={"kaitum": (["x"], np.arange(10))})
ds.to_netcdf(fn)
return fn


def test_caching_distributed(dummy_nc):
"""Test that the distributed scheduler works with file handle caching.

This is a test for GitHub issue 2815.
"""
from dask.distributed import Client

from satpy.readers.netcdf_utils import NetCDF4FileHandler

fh = NetCDF4FileHandler(dummy_nc, {}, {}, cache_handle=True)

def doubler(x):
return x * 2

# As documented in GH issue 2815, using dask distributed with the file
# handle cacher might fail in non-trivial ways, such as giving incorrect
# results. Testing map_blocks is one way to reproduce the problem
# reliably, even though the problem also manifests itself (in different
# ways) without map_blocks.


with Client():
dask_doubler = fh["kaitum"].map_blocks(doubler)
dask_doubler.compute()