diff --git a/xee/ext.py b/xee/ext.py index 2b291d9..b6122e0 100644 --- a/xee/ext.py +++ b/xee/ext.py @@ -21,6 +21,7 @@ import concurrent.futures import functools import importlib +import itertools import math import os import sys @@ -234,7 +235,7 @@ def __init__( x_min_0, y_min_0, x_max_0, y_max_0 = _ee_bounds_to_bounds( self.get_info['bounds'] ) - # TODO(#40): Investigate data discrepancy (off-by-one) issue. + x_min, y_min = self.transform(x_min_0, y_min_0) x_max, y_max = self.transform(x_max_0, y_max_0) self.bounds = x_min, y_min, x_max, y_max @@ -538,6 +539,45 @@ def _get_primary_coordinates(self) -> List[Any]: ] return primary_coords + def _get_tile_from_ee( + self, tile_index: Tuple[Any, Union[str, int]] + ) -> Tuple[slice, np.ndarray]: + """Get a numpy array from EE for a specific bounding box (a 'tile').""" + tile_index, band_id = tile_index + bbox = self.project( + (tile_index[0], 0, tile_index[1], 1) + if band_id == 'longitude' + else (0, tile_index[0], 1, tile_index[1]) + ) + tile_idx = slice(tile_index[0], tile_index[1]) + target_image = ee.Image.pixelLonLat() + return tile_idx, self.image_to_array( + target_image, grid=bbox, dtype=np.float32, bandIds=[band_id] + ) + + def _process_coordinate_data( + self, + tile_count: int, + tile_size: int, + end_point: int, + coordinate_type: str, + ) -> np.ndarray: + """Process coordinate data using multithreading for longitude or latitude.""" + data = [ + (tile_size * i, min(tile_size * (i + 1), end_point)) + for i in range(tile_count) + ] + tiles = [None] * tile_count + with concurrent.futures.ThreadPoolExecutor() as pool: + for i, arr in pool.map( + self._get_tile_from_ee, + list(zip(data, itertools.cycle([coordinate_type]))), + ): + tiles[i] = ( + arr.tolist() if coordinate_type == 'longitude' else arr.tolist()[0] + ) + return np.concatenate(tiles) + def get_variables(self) -> utils.Frozen[str, xarray.Variable]: vars_ = [(name, self.open_store_variable(name)) for name in self._bands()] @@ -553,15 +593,24 @@ def get_variables(self) -> utils.Frozen[str, xarray.Variable]: f'ImageCollection due to: {e}.' ) - lnglat_img = ee.Image.pixelLonLat() - lon_grid = self.project((0, 0, v0.shape[1], 1)) - lat_grid = self.project((0, 0, 1, v0.shape[2])) - lon = self.image_to_array( - lnglat_img, grid=lon_grid, dtype=np.float32, bandIds=['longitude'] + if isinstance(self.chunks, dict): + # when the value of self.chunks = 'auto' or user-defined. + width_chunk = self.chunks['width'] + height_chunk = self.chunks['height'] + else: + # when the value of self.chunks = -1. + width_chunk = v0.shape[1] + height_chunk = v0.shape[2] + + lon_total_tile = math.ceil(v0.shape[1] / width_chunk) + lon = self._process_coordinate_data( + lon_total_tile, width_chunk, v0.shape[1], 'longitude' ) - lat = self.image_to_array( - lnglat_img, grid=lat_grid, dtype=np.float32, bandIds=['latitude'] + lat_total_tile = math.ceil(v0.shape[2] / height_chunk) + lat = self._process_coordinate_data( + lat_total_tile, height_chunk, v0.shape[2], 'latitude' ) + width_coord = np.squeeze(lon) height_coord = np.squeeze(lat)