diff --git a/mosartwmpy/utilities/bil_to_parquet.py b/mosartwmpy/utilities/bil_to_parquet.py index 44ea93e..701571d 100644 --- a/mosartwmpy/utilities/bil_to_parquet.py +++ b/mosartwmpy/utilities/bil_to_parquet.py @@ -1,22 +1,18 @@ import click -import numpy as np import geopandas as gpd import json -from matplotlib import pyplot -import matplotlib as plt import pandas as pd -import pyarrow as pa -import pyarrow.parquet as pq import rasterio from rasterio.io import MemoryFile from rasterio.mask import mask +from rasterio.merge import merge from shapely.geometry import box import xarray as xr @click.command() @click.option( '--grid-path', - default='../../input/domains/mosart.nc', + default='input/domains/mosart.nc', type=click.Path( file_okay=True, dir_okay=False, @@ -28,7 +24,8 @@ ) @click.option( '--bil-elevation-path', - default='../../input/elevation/na_dem_30s_bil/na_dem_30s.bil', + default=['input/elevation/na_dem_30s_bil/na_dem_30s.bil', 'input/elevation/ca_dem_30s_bil/ca_dem_30s.bil'], + multiple=True, type=click.Path( file_okay=True, dir_okay=False, @@ -36,11 +33,11 @@ resolve_path=True, ), prompt='What is the path to the .bil elevation file?', - help="""Path to the .bil elevation file.""", + help="""Path to one or more .bil elevation file(s).""", ) @click.option( '--parquet-elevation-path', - default='../../input/elevation/na_dem_30s.parquet', + default='input/elevation/na_dem_30s.parquet', type=click.Path( file_okay=True, dir_okay=False, @@ -58,19 +55,34 @@ def bil_to_parquet( grid_longitude_key='lon', grid_latitude_key='lat', ): - """Convert a bil file into a parquet file.""" + """Convert one or more bil file(s) into a parquet file.""" domain = xr.open_dataset(grid_path) grid_resolution = domain[grid_latitude_key][1] - domain[grid_latitude_key][0] - ID = domain['ID'].to_numpy().flatten() - - # Import bil elevation file and trim to domain. - bil = rasterio.open(bil_elevation_path) - bil = cropToDomain(bil, domain, grid_longitude_key, grid_latitude_key, grid_resolution, bil_elevation_path[:-4] + '_cropped.bil') + ID = domain['ID'].values.flatten() - # Resample data to same resolution as grid. - scale_factor = bil.res[0]/grid_resolution - avg_downsampled_bil = bil.read( + # Combine all input bil files. + merged_bil = None + for bil in bil_elevation_path: + if merged_bil is None: + merged_bil = rasterio.open(bil) + continue + + bil = rasterio.open(bil) + merged_bil, transform = merge([bil, merged_bil]) + merged_bil = return_in_memory(merged_bil, bil.crs, transform) + + merged_bil = avg_resample(merged_bil, grid_resolution) + merged_bil = crop_to_domain(merged_bil, domain, grid_longitude_key, grid_latitude_key, grid_resolution) + + df = pd.DataFrame(merged_bil.read(1).flatten()) + df.columns = df.columns.astype(str) + df.to_parquet(parquet_elevation_path) + +def avg_resample(bil, grid_resolution): + scale_factor = bil.res[0] / grid_resolution + + avg_sampled_bil = bil.read( out_shape=( bil.count, int(bil.height * scale_factor), @@ -78,35 +90,27 @@ def bil_to_parquet( ), resampling=rasterio.enums.Resampling.average ) + transform = bil.transform * bil.transform.scale( + (bil.width / avg_sampled_bil.shape[-1]), + (bil.height / avg_sampled_bil.shape[-2]) + ) + return return_in_memory(avg_sampled_bil, bil.crs, transform) - # Write as parquet file. - df = pd.DataFrame(avg_downsampled_bil.flatten(), ID) - df.columns = df.columns.astype(str) - df.to_parquet(parquet_elevation_path) - - -def cropToDomain(bil, domain, grid_latitude_key, grid_longitude_key, grid_resolution, cropped_output_path): +def crop_to_domain(bil, domain, grid_latitude_key, grid_longitude_key, grid_resolution): xmin, ymin, xmax, ymax = domain[grid_latitude_key].min().min().item(0), domain[grid_longitude_key].min().min().item(0), domain[grid_latitude_key].max().max().item(0), domain[grid_longitude_key].max().max().item(0) bbox = box(xmin, ymin, xmax + grid_resolution, ymax + grid_resolution) - if bbox == bil.bounds: - return bil - geo = gpd.GeoDataFrame({'geometry': bbox}, index=[0], crs=bil.crs) coords = [json.loads(geo.to_json())['features'][0]['geometry']] - out_img, out_transform = mask(dataset=bil, shapes=coords, crop=True) - out_meta = bil.meta.copy() - out_meta.update({"driver": "GTiff", - "height": out_img.shape[1], - "width": out_img.shape[2], - "transform": out_transform, - "crs": bil.crs}) - - with MemoryFile() as memfile: - with memfile.open(**out_meta) as dataset: # Open as DatasetWriter - dataset.write(out_img) - del out_img - return memfile.open() + cropped, transform = mask(dataset=bil, shapes=coords, crop=True, nodata=-999) + + return return_in_memory(cropped, bil.crs, transform) + +def return_in_memory(array, crs, transform): + memfile = MemoryFile() + dataset = memfile.open(driver='GTiff', height=array.shape[-2], width=array.shape[-1], count=1, crs=crs, transform=transform, dtype=array.dtype) + dataset.write(array) + return dataset if __name__ == '__main__': bil_to_parquet() diff --git a/mosartwmpy/utilities/test_bil_to_parquet.py b/mosartwmpy/utilities/test_bil_to_parquet.py new file mode 100644 index 0000000..5b4ef55 --- /dev/null +++ b/mosartwmpy/utilities/test_bil_to_parquet.py @@ -0,0 +1,75 @@ +from affine import Affine +from dataclasses import dataclass +from mosartwmpy.utilities.bil_to_parquet import avg_resample, crop_to_domain, return_in_memory +import numpy as np +from rasterio.crs import CRS +from shapely.geometry import box +import unittest +import xarray as xr + +class TestAvgResample(unittest.TestCase): + def setUp(self): + self.crs = CRS.from_epsg(4326) + self.transform = Affine.identity() + + def test_avg_resample(self): + @dataclass + class TestCase: + name: str + scale: float + input: np.ndarray + expected: np.ndarray + + testcases = [ + TestCase(name="rescale_to_same_size", scale=1, input=np.zeros((1, 4, 4)), expected=np.zeros((1, 4, 4))), + TestCase(name="rescale_double", scale=.5, input=np.zeros((1, 4, 4)), expected=np.zeros((1, 8, 8))), + TestCase(name="rescale_half", scale=2, input=np.zeros((1, 4, 4)), expected=np.zeros((1, 2, 2))), + TestCase(name="rescale_test_average_positive", scale=2, input=np.array([[[1, 2, 1, 2, 1, 2], [1, 2, 1, 2, 1, 2]]], dtype=float), expected=np.array([[[1.5, 1.5, 1.5]]])), + TestCase(name="rescale_test_average_negative", scale=2, input=np.array([[[-1, -2, -1, -2, -1, -2], [-1, -2, -1, -2, -1, -2]]], dtype=float), expected=np.array([[[-1.5, -1.5, -1.5]]])), + TestCase(name="rescale_test_average_zero", scale=2, input=np.array([[[-1, -2, -1, -2, -1, -2], [1, 2, 1, 2, 1, 2]]], dtype=float), expected=np.array([[[0, 0, 0]]])), + ] + + for case in testcases: + bil = return_in_memory(case.input, self.crs, self.transform) + actual = avg_resample(bil, case.scale).read() + self.assertTrue(np.array_equiv(actual, case.expected)) + +class TestCropToDomain(unittest.TestCase): + def setUp(self): + self.crs = CRS.from_epsg(4326) + self.transform = Affine.identity() + self.grid_latitude_key = 'lat' + self.grid_longitude_key = 'lon' + self.grid_resolution = .125 + + def test_crop_to_domain(self): + @dataclass + class TestCase: + name: str + resolution: float + input_bounds: list + crop_bounds: list + expected: list + + testcases = [ + TestCase(name='identity', resolution=1, input_bounds=[0, 0, 4, 4], crop_bounds=[0, 0, 4, 4], expected=[0, 4, 4 ,0]), + TestCase(name='crop_smaller', resolution=1, input_bounds=[0, 0, 20, 20], crop_bounds=[5, 5, 10, 10], expected=[5, 11, 11, 5]), + TestCase(name='smaller_than_crop', resolution=1, input_bounds=[0, 0, 1, 1], crop_bounds=[0, 0, 5, 5], expected=[0, 1, 1, 0]), + ] + + for case in testcases: + bil = return_in_memory(np.zeros((1, case.input_bounds[2], case.input_bounds[3])), self.crs, self.transform) + lon = [case.crop_bounds[0], case.crop_bounds[2]] + lat = [case.crop_bounds[1], case.crop_bounds[3]] + domain = xr.Dataset( + coords=dict( + lon=(np.linspace(min(lon), max(lon), int((max(lon)-min(lon))/case.resolution))), + lat=(np.linspace(min(lat), max(lat), int((max(lat)-min(lat))/case.resolution))), + ) + ) + + actual = crop_to_domain(bil, domain, self.grid_latitude_key, self.grid_longitude_key, case.resolution) + self.assertEqual([actual.bounds[0], actual.bounds[1], actual.bounds[2], actual.bounds[3]], case.expected) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file