Skip to content

Commit

Permalink
Fix support for computed images by removing a hack to do fast slicing.
Browse files Browse the repository at this point in the history
Add a new `fast_time_slicing` parameter. If True, Xee performs an optimization that makes slicing an ImageCollection across time faster. This optimization loads EE images in a slice by ID, so any modifications to images in a computed ImageCollection will not be reflected.

For those familiar with the previous code: the `else` branch in `_slice_collection` was entered only when the images in the collection lacked IDs, so clearing the image IDs was the hack used to force execution of that branch.

Also adds several new warnings:

- if a user enables `fast_time_slicing` but there are no image IDs, and
- if a user is indexing into a very large ImageCollection.

Fixes #88 and #145.

PiperOrigin-RevId: 623815209
  • Loading branch information
naschmitz authored and Xee authors committed Apr 11, 2024
1 parent 7fe930c commit c5bd12e
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 5 deletions.
39 changes: 34 additions & 5 deletions xee/ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import functools
import importlib
import itertools
import logging
import math
import os
import sys
Expand Down Expand Up @@ -72,6 +73,12 @@
# trial & error.
REQUEST_BYTE_LIMIT = 2**20 * 48 # 48 MBs

# Xee uses the ee.ImageCollection.toList function for slicing into an
# ImageCollection. This function isn't optimized for large collections. If the
# end index of the slice is beyond 10k, display a warning to the user. This
# value was chosen by trial and error.
_TO_LIST_WARNING_LIMIT = 10000


def _check_request_limit(chunks: Dict[str, int], dtype_size: int, limit: int):
"""Checks that the actual number of bytes exceeds the limit."""
Expand Down Expand Up @@ -153,6 +160,7 @@ def open(
ee_init_if_necessary: bool = False,
executor_kwargs: Optional[Dict[str, Any]] = None,
getitem_kwargs: Optional[Dict[str, int]] = None,
fast_time_slicing: bool = False,
) -> 'EarthEngineStore':
if mode != 'r':
raise ValueError(
Expand All @@ -175,6 +183,7 @@ def open(
ee_init_if_necessary=ee_init_if_necessary,
executor_kwargs=executor_kwargs,
getitem_kwargs=getitem_kwargs,
fast_time_slicing=fast_time_slicing,
)

def __init__(
Expand All @@ -194,9 +203,11 @@ def __init__(
ee_init_if_necessary: bool = False,
executor_kwargs: Optional[Dict[str, Any]] = None,
getitem_kwargs: Optional[Dict[str, int]] = None,
fast_time_slicing: bool = False,
):
self.ee_init_kwargs = ee_init_kwargs
self.ee_init_if_necessary = ee_init_if_necessary
self.fast_time_slicing = fast_time_slicing

# Initialize executor_kwargs
if executor_kwargs is None:
Expand Down Expand Up @@ -834,15 +845,27 @@ def _slice_collection(self, image_slice: slice) -> ee.Image:
self._ee_init_check()
start, stop, stride = image_slice.indices(self.shape[0])

# If the input images have IDs, just slice them. Otherwise, we need to do
# an expensive `toList()` operation.
if self.store.image_ids:
if self.store.fast_time_slicing and self.store.image_ids:
imgs = self.store.image_ids[start:stop:stride]
else:
if self.store.fast_time_slicing:
logging.warning(
"fast_time_slicing is enabled but ImageCollection images don't have"
' IDs. Reverting to default behavior.'
)
if stop > _TO_LIST_WARNING_LIMIT:
logging.warning(
'Xee is indexing into the ImageCollection beyond %s images. This'
' operation can be slow. To improve performance, consider filtering'
' the ImageCollection prior to using Xee or enabling'
' fast_time_slicing.',
_TO_LIST_WARNING_LIMIT,
)
# TODO(alxr, mahrsee): Find a way to make this case more efficient.
list_range = stop - start
col0 = self.store.image_collection
imgs = col0.toList(list_range, offset=start).slice(0, list_range, stride)
imgs = self.store.image_collection.toList(list_range, offset=start).slice(
0, list_range, stride
)

col = ee.ImageCollection(imgs)

Expand Down Expand Up @@ -1006,6 +1029,7 @@ def open_dataset(
ee_init_kwargs: Optional[Dict[str, Any]] = None,
executor_kwargs: Optional[Dict[str, Any]] = None,
getitem_kwargs: Optional[Dict[str, int]] = None,
fast_time_slicing: bool = False,
) -> xarray.Dataset: # type: ignore
"""Open an Earth Engine ImageCollection as an Xarray Dataset.
Expand Down Expand Up @@ -1084,6 +1108,10 @@ def open_dataset(
- 'max_retries', the maximum number of retry attempts. Defaults to 6.
- 'initial_delay', the initial delay in milliseconds before the first
retry. Defaults to 500.
fast_time_slicing (optional): Whether to perform an optimization that
makes slicing an ImageCollection across time faster. This optimization
loads EE images in a slice by ID, so any modifications to images in a
computed ImageCollection will not be reflected.
Returns:
An xarray.Dataset that streams in remote data from Earth Engine.
"""
Expand Down Expand Up @@ -1114,6 +1142,7 @@ def open_dataset(
ee_init_if_necessary=ee_init_if_necessary,
executor_kwargs=executor_kwargs,
getitem_kwargs=getitem_kwargs,
fast_time_slicing=fast_time_slicing,
)

store_entrypoint = backends_store.StoreBackendEntrypoint()
Expand Down
32 changes: 32 additions & 0 deletions xee/ext_integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,38 @@ def test_validate_band_attrs(self):
for _, value in variable.attrs.items():
self.assertIsInstance(value, valid_types)

def test_fast_time_slicing(self):
  """Verifies fast slicing loads by image ID, bypassing computed changes."""
  band = 'temperature_2m'
  source = (
      ee.ImageCollection('ECMWF/ERA5_LAND/HOURLY')
      .filterDate('2024-01-01', '2024-01-02')
      .select(band)
  )
  original_first = source.first()
  # Build a "computed" collection: swap the first image for an all-zero
  # image that keeps the original's ID and timestamp.
  zero_image = (
      ee.Image(0)
      .rename(band)
      .copyProperties(original_first, ['system:id', 'system:time_start'])
  )
  modified_collection = ee.ImageCollection(
      source.toList(count=source.size()).replace(original_first, zero_image)
  )

  open_kwargs = {
      'filename_or_obj': modified_collection,
      'engine': xee.EarthEngineBackendEntrypoint,
      'geometry': ee.Geometry.BBox(-83.86, 41.13, -76.83, 46.15),
      'projection': original_first.projection().atScale(100000),
  }

  # Default (slow) slicing evaluates the computed collection, so the
  # all-zero replacement image is what comes back.
  slow = xr.open_dataset(**open_kwargs)
  slow_data = getattr(slow[dict(time=0)], band).as_numpy()
  self.assertTrue(np.all(slow_data == 0))

  # Fast slicing fetches by image ID, so the original (nonzero) pixels are
  # returned and the modification is not reflected.
  fast = xr.open_dataset(**open_kwargs, fast_time_slicing=True)
  fast_data = getattr(fast[dict(time=0)], band).as_numpy()
  self.assertTrue(np.all(fast_data > 0))

@absltest.skipIf(_SKIP_RASTERIO_TESTS, 'rioxarray module not loaded')
def test_write_projected_dataset_to_raster(self):
# ensure that a projected dataset written to a raster intersects with the
Expand Down

0 comments on commit c5bd12e

Please sign in to comment.