Skip to content

Commit

Permalink
add test_to_rio_dataset_nodata_none() and test mask cacheing in test_…
Browse files Browse the repository at this point in the history
…nodata_mask()
  • Loading branch information
dugalh committed Feb 20, 2024
1 parent 54c9d20 commit 989129c
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 19 deletions.
4 changes: 3 additions & 1 deletion tests/test_fuse_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def test_basic_fusion(
)
),
dict(
driver='GTiff', dtype='uint8', nodata=0,
driver='GTiff', dtype='uint8', nodata=None,
creation_options=dict(
tiled=True, blockxsize=64, blockysize=64, compress='jpeg', interleave='pixel', photometric='ycbcr'
)
Expand All @@ -139,8 +139,10 @@ def test_out_profile(file_rgb_100cm_float, tmp_path: Path, out_profile: Dict):
with raster_fuse:
raster_fuse.process(corr_filename, Model.gain_blk_offset, (3, 3), out_profile=out_profile)
assert (corr_filename.exists())

out_profile.update(**out_profile['creation_options'])
out_profile.pop('creation_options')

with rio.open(file_rgb_100cm_float, 'r') as src_ds, rio.open(corr_filename, 'r') as fuse_ds:
# test output image has been set with out_profile properties
for k, v in out_profile.items():
Expand Down
64 changes: 46 additions & 18 deletions tests/test_raster_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import pytest
import rasterio as rio
from rasterio.crs import CRS
from rasterio.enums import Resampling
from rasterio.enums import Resampling, MaskFlags
from rasterio.transform import Affine, from_bounds
from rasterio.windows import Window
from rasterio.warp import transform_bounds
Expand Down Expand Up @@ -72,12 +72,19 @@ def test_nodata_mask(ra_byte):
assert (ra_byte.mask_ra.array == mask).all()
assert ra_byte.mask_ra.transform == ra_byte.transform

# test mask changed after setting altered masked array
array = ra_byte.array.copy()
array[np.divide(mask.shape, 2).astype('int')] = ra_byte.nodata
mask[np.divide(mask.shape, 2).astype('int')] = False
ra_byte.array = array
assert (ra_byte.mask == mask).all()

# test mask unchanged after setting stacked array
ra_byte.array = np.stack((ra_byte.array, ra_byte.array), axis=0)
assert (ra_byte.mask == mask).all()

# test altering the mask
mask[np.divide(mask.shape, 2).astype('int')] = False
mask[(np.divide(mask.shape, 2) + 1).astype('int')] = False
ra_byte.mask = mask
assert (ra_byte.mask == mask).all()

Expand Down Expand Up @@ -192,6 +199,23 @@ def test_to_rio_dataset(ra_byte, tmp_path: Path):
assert (test_array == ra_byte.array).all()


def test_to_rio_dataset_nodata_none(ra_byte, tmp_path: Path):
""" Test writing raster array to dataset with nodata=None writes an internal mask. """
ds_filename = tmp_path.joinpath('temp.tif')
profile = ra_byte.profile
profile.update(nodata=None)
with rio.open(ds_filename, 'w', driver='GTiff', **profile) as ds:
ra_byte.to_rio_dataset(ds)

with rio.open(ds_filename, 'r') as ds:
assert ds.nodata is None
assert ds.mask_flag_enums[0] == [MaskFlags.per_dataset]
test_mask = ds.dataset_mask().astype('bool')
test_array = ds.read(indexes=1)
assert (test_mask == ra_byte.mask).all()
assert (test_array[test_mask] == ra_byte.array[ra_byte.mask]).all()


def test_to_rio_dataset_crop(ra_rgb_byte, tmp_path: Path):
""" Test writing a raster array to a dataset where the dataset & raster array sizes differ. """
ds_filename = tmp_path.joinpath('temp.tif')
Expand Down Expand Up @@ -290,7 +314,7 @@ def test_reprojection(ra_rgb_byte):
('int64', 1, 'float64', float('nan')),
('float32', float('nan'), 'float32', None), # nodata unchanged
])
def test_convert_dtype(profile_100cm_float: dict, src_dtype: str, src_nodata: float, dst_dtype: str, dst_nodata: float):
def test_convert_array_dtype(profile_100cm_float: dict, src_dtype: str, src_nodata: float, dst_dtype: str, dst_nodata: float):
""" Test dtype conversion with combinations covering rounding, clipping (with and w/o type promotion) and
re-masking.
"""
Expand All @@ -307,32 +331,36 @@ def test_convert_dtype(profile_100cm_float: dict, src_dtype: str, src_nodata: fl
array, crs=profile_100cm_float['crs'], transform=profile_100cm_float['transform'], nodata=src_nodata
)

# convert src_ra to dtype
test_ra = src_ra.copy()
test_ra.convert_dtype(dst_dtype, nodata=dst_nodata)
# convert to dtype
src_copy_ra = src_ra.copy()
test_array = src_copy_ra._convert_array_dtype(dst_dtype, nodata=dst_nodata)

# test converting did not change src_copy_ra
assert utils.nan_equals(src_copy_ra.array, src_ra.array).all()
assert (src_copy_ra.mask == src_ra.mask).all()

# create rounded & clipped array in src_dtype to test against
test_array = array
ref_array = array
if np.issubdtype(dst_dtype, np.integer):
test_array = np.clip(np.round(test_array), dst_dtype_info.min, dst_dtype_info.max)
ref_array = np.clip(np.round(ref_array), dst_dtype_info.min, dst_dtype_info.max)
elif np.issubdtype(src_dtype, np.floating):
# don't clip float but set out of range vals to +-inf (as np.astype does)
test_array[test_array < dst_dtype_info.min] = float('-inf')
test_array[test_array > dst_dtype_info.max] = float('inf')
assert np.any(test_array[src_ra.mask] % 1 != 0) # check contains decimals
ref_array[ref_array < dst_dtype_info.min] = float('-inf')
ref_array[ref_array > dst_dtype_info.max] = float('inf')
assert np.any(ref_array[src_ra.mask] % 1 != 0) # check contains decimals

assert test_ra.dtype == dst_dtype
assert test_array.dtype == dst_dtype
if dst_nodata:
assert utils.nan_equals(test_ra.nodata, dst_nodata)
assert np.any(test_ra.mask)
assert np.all(test_ra.mask == src_ra.mask)
test_mask = ~utils.nan_equals(test_array, dst_nodata)
assert np.any(test_mask)
assert (test_mask == src_ra.mask).all()
# use approx test for case of (expected) precision loss e.g. float64->float32 or int64->float32
assert test_ra.array[test_ra.mask] == pytest.approx(test_array[src_ra.mask], rel=1e-6)
assert test_array[src_ra.mask] == pytest.approx(ref_array[src_ra.mask], rel=1e-6)


def test_convert_dtype_error(ra_100cm_float: RasterArray):
def test_convert_array_dtype_error(ra_100cm_float: RasterArray):
""" Test dtype conversion raises an error when the nodata value cannot be cast to the conversion dtype. """
test_ra = ra_100cm_float.copy()
with pytest.raises(ValueError) as ex:
test_ra.convert_dtype('uint8', nodata=float('nan'))
test_ra._convert_array_dtype('uint8', nodata=float('nan'))
assert 'cast' in str(ex.value)

0 comments on commit 989129c

Please sign in to comment.