Skip to content

Commit

Permalink
Merge pull request #78 from lukasValentin/dev
Browse files Browse the repository at this point in the history
Fix of nodata, fill value, mask and data type propagation and enforcement in EOdal
  • Loading branch information
lukasValentin authored Oct 18, 2023
2 parents 239fa0d + 63f9092 commit f0bb31f
Show file tree
Hide file tree
Showing 19 changed files with 262 additions and 96 deletions.
2 changes: 1 addition & 1 deletion eodal/__meta__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@
description = "Earth Observation Data Analysis Library" # One-liner
url = "https://github.com/EOA-team/eodal" # your project home-page
license = "GNU General Public License version 3" # See https://choosealicense.com
version = "0.2.2"
version = "0.2.3"
27 changes: 20 additions & 7 deletions eodal/core/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import eodal
import os
import geopandas as gpd
import numpy as np
import uuid

from pathlib import Path
Expand All @@ -38,9 +39,10 @@

def _get_crs_and_attribs(
in_file: Path, **kwargs
) -> Tuple[GeoInfo, List[Dict[str, Any]]]:
) -> Tuple[GeoInfo, List[Dict[str, Any]], str]:
"""
Returns the ``GeoInfo`` from a multi-band raster dataset
Returns the ``GeoInfo``, attributes and data type from
a multi-band raster dataset.
:param in_file:
raster datasets from which to extract the ``GeoInfo`` and
Expand All @@ -55,7 +57,8 @@ def _get_crs_and_attribs(
ds = RasterCollection.from_multi_band_raster(fpath_raster=in_file, **kwargs)
geo_info = ds[ds.band_names[0]].geo_info
attrs = [ds[x].get_attributes() for x in ds.band_names]
return geo_info, attrs
dtype = ds[ds.band_names[0]].values.dtype
return geo_info, attrs, dtype


def merge_datasets(
Expand Down Expand Up @@ -110,7 +113,7 @@ def merge_datasets(
crs_list = []
attrs_list = []
for dataset in datasets:
geo_info, attrs = _get_crs_and_attribs(in_file=dataset)
geo_info, attrs, dtype = _get_crs_and_attribs(in_file=dataset)
crs_list.append(geo_info.epsg)
attrs_list.append(attrs)

Expand All @@ -130,7 +133,9 @@ def merge_datasets(
# use rasterio merge to get a new raster dataset
dst_kwds = {"QUALITY": "100", "REVERSIBLE": "YES"}
try:
res = merge(datasets=datasets, dst_path=out_file, dst_kwds=dst_kwds, **kwargs)
res = merge(
datasets=datasets, dst_path=out_file, dst_kwds=dst_kwds,
dtype=dtype, **kwargs)
if res is not None:
out_ds, out_transform = res[0], res[1]
except Exception as e:
Expand Down Expand Up @@ -188,16 +193,24 @@ def merge_datasets(
else:
band_alias = f"B{idx+1}"

# get the mask from nodata values
nodata_mask = out_ds[idx, :, :] == nodata
band_arr = np.ma.masked_array(
data=out_ds[idx, :, :],
mask=nodata_mask,
fill_value=nodata
)

raster.add_band(
band_constructor=Band,
band_name=band_name,
values=out_ds[idx, :, :],
values=band_arr,
geo_info=geo_info,
is_tiled=is_tiled,
scale=scale,
offset=offset,
band_alias=band_alias,
unit=unit,
unit=unit
)

# clip raster collection if required to vector_features to keep consistency
Expand Down
56 changes: 41 additions & 15 deletions eodal/core/band.py
Original file line number Diff line number Diff line change
Expand Up @@ -861,6 +861,13 @@ def from_rasterio(
nodata, nodata_vals = None, attrs.get("nodatavals", None)
if nodata_vals is not None:
nodata = nodata_vals[band_idx - 1]
# make sure the nodata type matches the datatype of the
# band values
if band_data.dtype.kind not in 'fc' and np.isnan(nodata):
raise TypeError(
f"The datatype of the band data is {band_data.dtype} " +
f"while the nodata value ({nodata}) is float.\n" +
"Please provide an appropriate nodata value")

if masking:
# make sure to set the EPSG code
Expand Down Expand Up @@ -1372,6 +1379,9 @@ def get_meta(self, driver: Optional[str] = "gTiff", **kwargs) -> Dict[str, Any]:
meta["transform"] = self.transform
meta["is_tile"] = self.is_tiled
meta["driver"] = driver
# "compress" as suggested here:
# https://github.com/rasterio/rasterio/discussions/2933#discussioncomment-7208578
meta["compress"] = "DEFLATE"
meta.update(kwargs)

return meta
Expand Down Expand Up @@ -1699,10 +1709,18 @@ def mask(self, mask: np.ndarray, inplace: Optional[bool] = False):
# ignore pixels already masked
if not orig_mask[row, col]:
orig_mask[row, col] = mask[row, col]
# re-use original fill value
fill_value = self.values.fill_value
# update band data array
masked_array = np.ma.MaskedArray(data=self.values.data, mask=orig_mask)
masked_array = np.ma.MaskedArray(
data=self.values.data,
mask=orig_mask,
fill_value=fill_value)
elif self.is_ndarray:
masked_array = np.ma.MaskedArray(data=self.values, mask=mask)
# determine fill value from nodata value
fill_value = self.nodata
masked_array = np.ma.MaskedArray(
data=self.values, mask=mask, fill_value=fill_value)
elif self.is_zarr:
raise NotImplementedError()

Expand Down Expand Up @@ -1890,8 +1908,10 @@ def resample(
out_mask = cv2.resize(in_mask, dim_resampled, cv2.INTER_NEAREST_EXACT)
# convert mask back to boolean array
out_mask = out_mask.astype(bool)
# re-use fill value of the original array
fill_value = self.values.fill_value
# save as masked array
res = np.ma.masked_array(data=res, mask=out_mask)
res = np.ma.masked_array(data=res, mask=out_mask, fill_value=fill_value)

# update the geo_info with new pixel resolution. The upper left x and y
# coordinate must be changed if the pixel coordinates refer to the center
Expand Down Expand Up @@ -1962,6 +1982,7 @@ def reproject(
if self.is_masked_array:
band_data = deepcopy(self.values.data).astype(float)
band_mask = deepcopy(self.values.mask).astype(float)
fill_value = self.values.fill_value
elif self.is_ndarray:
band_data = deepcopy(self.values).astype(float)
elif self.is_zarr:
Expand Down Expand Up @@ -2004,7 +2025,8 @@ def reproject(
# due to the raster alignment
nodata = reprojection_options.get("src_nodata", 0)
out_mask[out_data == nodata] = True
out_data = np.ma.MaskedArray(data=out_data, mask=out_mask)
out_data = np.ma.MaskedArray(
data=out_data, mask=out_mask, fill_value=fill_value)

new_geo_info = GeoInfo.from_affine(affine=out_transform, epsg=target_crs)
if inplace:
Expand Down Expand Up @@ -2192,23 +2214,24 @@ def scale_data(
scale, offset = self.scale, self.offset
if self.is_masked_array:
if pixel_values_to_ignore is None:
scaled_array = scale * (self.values.data + offset)
scaled_array = scale * self.values.data - offset
else:
scaled_array = self.values.data.copy().astype(float)
scaled_array[~np.isin(scaled_array, pixel_values_to_ignore)] = scale * (
scaled_array[~np.isin(scaled_array, pixel_values_to_ignore)]
+ offset
)
scaled_array = np.ma.MaskedArray(data=scaled_array, mask=self.values.mask)
scaled_array[~np.isin(scaled_array, pixel_values_to_ignore)] = scale * \
scaled_array[~np.isin(scaled_array, pixel_values_to_ignore)] \
- offset
# reuse fill value
fill_value = self.values.fill_value
scaled_array = np.ma.MaskedArray(
data=scaled_array, mask=self.values.mask, fill_value=fill_value)
elif self.is_ndarray:
if pixel_values_to_ignore is None:
scaled_array = scale * (self.values + offset)
scaled_array = scale * self.values - offset
else:
scaled_array = self.values.copy().astype(float)
scaled_array[~np.isin(scaled_array, pixel_values_to_ignore)] = scale * (
scaled_array[~np.isin(scaled_array, pixel_values_to_ignore)]
+ offset
)
scaled_array[~np.isin(scaled_array, pixel_values_to_ignore)] = scale * \
scaled_array[~np.isin(scaled_array, pixel_values_to_ignore)] \
- offset
elif self.is_zarr:
raise NotImplementedError()

Expand Down Expand Up @@ -2356,6 +2379,9 @@ def to_rasterio(self, fpath_raster: Path, **kwargs) -> None:
with rio.open(fpath_raster, "w+", **meta) as dst:
# set band name
dst.set_band_description(1, self.band_name)
# set scale and offset
dst._set_all_scales([self.scale])
dst._set_all_offsets([self.offset])
# write band data
if self.is_masked_array:
vals = self.values.data
Expand Down
102 changes: 71 additions & 31 deletions eodal/core/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ class allows thereby to handle ``Band`` instances with different spatial referen
from numbers import Number
from pathlib import Path
from rasterio.drivers import driver_from_extension
from rasterio.io import MemoryFile
from typing import Any
from typing import Callable
from typing import Dict
Expand All @@ -101,6 +102,7 @@ class allows thereby to handle ``Band`` instances with different spatial referen
from eodal.core.band import Band
from eodal.core.operators import Operator
from eodal.core.spectral_indices import SpectralIndices
from eodal.core.utils import get_highest_dtype
from eodal.utils.constants import ProcessingLevels
from eodal.utils.decorators import check_band_names
from eodal.utils.exceptions import BandNotFoundError
Expand Down Expand Up @@ -1675,6 +1677,7 @@ def to_rasterio(
fpath_raster: Path,
band_selection: Optional[List[str]] = None,
use_band_aliases: Optional[bool] = False,
as_cog: Optional[bool] = False
) -> None:
"""
Writes bands in collection to a raster dataset on disk using
Expand All @@ -1689,14 +1692,30 @@ def to_rasterio(
:param use_band_aliases:
use band aliases instead of band names for setting raster
band descriptions to the output dataset
:param as_cog:
write the raster dataset as cloud-optimized GeoTIFF. This
requires the ``rio-cogeo`` package to be installed. Disabled
by default.
"""
# check output file naming and driver
try:
driver = driver_from_extension(fpath_raster)
except Exception as e:
raise ValueError(
f"Could not determine GDAL driver for " f"{fpath_raster.name}: {e}"
)
# check if COG output is enabled
if as_cog:
try:
from rio_cogeo.cogeo import cog_translate
from rio_cogeo.profiles import cog_profiles
except ModuleNotFoundError:
raise ModuleNotFoundError(
"rio-cogeo is required for writing cloud-optimized GeoTIFFs\n" +
"Install it with `pip install rio-cogeo`"
)
driver = "GTiff"
else:
# check output file naming and driver
try:
driver = driver_from_extension(fpath_raster)
except Exception as e:
raise ValueError(
f"Could not determine GDAL driver for " f"{fpath_raster.name}: {e}"
)

# check band_selection, if not provided use all available bands
if band_selection is None:
Expand All @@ -1721,46 +1740,67 @@ def to_rasterio(
# check meta and update it with the selected driver for writing the result
meta = deepcopy(self[band_selection[0]].meta)
dtypes = [self[x].values.dtype for x in band_selection]
if len(set(dtypes)) != 1:
UserWarning(
f"Multiple data types found in arrays to write ({set(dtypes)}). "
f"Casting to highest data type"
)

if len(set(dtypes)) == 1:
dtype_str = str(dtypes[0])
else:
# TODO: determine highest dtype
dtype_str = "float32"
# data type checking. We need to get the highest data type
highest_dtype = get_highest_dtype(dtypes)

# update driver, the number of bands and the metadata value
meta.update(
{
"driver": driver,
"count": len(band_selection),
"dtype": dtype_str,
"dtype": str(highest_dtype),
"nodata": self[band_selection[0]].nodata,
"compression": "DEFLATE"
}
)

# open the result dataset and try to write the bands
with rio.open(fpath_raster, "w+", **meta) as dst:
for idx, band_name in enumerate(band_selection):
# check with band name to set
dst.set_band_description(idx + 1, band_name)
# write band data
band_data = self.get_band(band_name).values.astype(dtype_str)
# set masked pixels to nodata
if self[band_name].is_masked_array:
vals = band_data.data
mask = band_data.mask
vals[mask] = self[band_name].nodata
dst.write(band_data, idx + 1)
if as_cog:
with MemoryFile() as memfile:
with memfile.open(**meta) as mem:
# set scales and offsets
scales = [self[band_name].scale for band_name in band_selection]
offsets = [self[band_name].offset for band_name in band_selection]
mem._set_all_scales(scales)
mem._set_all_offsets(offsets)
# polulate data
for idx, band_name in enumerate(band_selection):
# check with band name to set
mem.set_band_description(idx + 1, band_name)
# write band data. Cast to highest data type if necessary.
band_data = self.get_band(band_name).values.astype(highest_dtype)
mem.write(band_data, idx + 1)

# write the COG
dst_profile = cog_profiles.get("deflate")
cog_translate(
mem,
fpath_raster,
dst_profile,
in_memory=True,
quiet=True
)

else:
with rio.open(fpath_raster, "w+", **meta) as dst:
# set scales and offsets
scales = [self[band_name].scale for band_name in band_selection]
offsets = [self[band_name].offset for band_name in band_selection]
dst._set_all_scales(scales)
dst._set_all_offsets(offsets)
# polulate data
for idx, band_name in enumerate(band_selection):
# check with band name to set
dst.set_band_description(idx + 1, band_name)
# write band data. Cast to highest data type if necessary.
band_data = self.get_band(band_name).values.astype(highest_dtype)
dst.write(band_data, idx + 1)

@check_band_names
def to_xarray(self, band_selection: Optional[List[str]] = None) -> xr.DataArray:
"""
Converts bands in collection a ``xarray.DataArray``
Converts bands in collection into a ``xarray.DataArray``
:param band_selection:
selection of bands to process. If not provided uses all
Expand Down
Loading

0 comments on commit f0bb31f

Please sign in to comment.