Skip to content

Commit

Permalink
Remove wrapper for pystac_client queries
Browse files Browse the repository at this point in the history
  • Loading branch information
GregoryPetrochenkov-NOAA committed Feb 29, 2024
1 parent b48ec7d commit a790a75
Show file tree
Hide file tree
Showing 5 changed files with 362 additions and 339 deletions.
327 changes: 100 additions & 227 deletions src/gval/utils/loading_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@
import rioxarray as rxr
import xarray as xr
import numpy as np
from shapely.geometry import MultiPoint, shape
from tempfile import NamedTemporaryFile
from rio_cogeo.cogeo import cog_translate
from rio_cogeo.profiles import cog_profiles
import pystac_client
from pystac.item_collection import ItemCollection

import stackstac

Expand Down Expand Up @@ -324,93 +325,23 @@ def _set_crs(stack: xr.DataArray, band_metadata: list = None) -> Number:
return stack.rio.write_crs(f"EPSG:{band_metadata['epsg'].values}")


def query_stac(
url: str,
collections: str,
time: str,
query: str = None,
max_items: int = None,
intersects: dict = None,
bbox: list = None,
) -> list:
"""Return items from a stac query
Parameters
----------
url : str
Address hosting the STAC API
collections : Union[str, list]
Name of collection to get (currently limited to one)
time : str
Single or range of values to query in the time dimension
bands: list, default = None
Bands to retrieve from service
query : str, default = None
String command to filter data
max_items : int, default = None
The maximum amount of records to retrieve
intersects : dict, default = None
Dictionary representing the type of geometry and its respective coordinates
bbox : list, default = None
Coordinates to filter the spatial range of request
Returns
-------
list
An iterable of STAC items
"""

if not isinstance(collections, Iterable):
collections = [collections]

catalog = pystac_client.Client.open(url)

return catalog.search(
datetime=time,
collections=[collections],
max_items=max_items,
intersects=intersects,
bbox=bbox,
query=query,
).item_collection()


def get_stac_data(
url: str,
collection: str,
time: str,
stac_items: ItemCollection,
bands: list = None,
query: str = None,
time_aggregate: str = None,
max_items: int = None,
intersects: dict = None,
bbox: list = None,
resolution: int = None,
nodata_fill: Number = None,
) -> xr.Dataset:
"""
"""Transform STAC Items in to an xarray object
Parameters
----------
url : str
Address hosting the STAC API
collection : str
Name of collection to get (currently limited to one)
time : str
Single or range of values to query in the time dimension
stac_items : ItemCollection
STAC Item Collection returned from pystac client
bands: list, default = None
Bands to retrieve from service
query : str, default = None
String command to filter data
time_aggregate : str, default = None
Method to aggregate multiple time stamps
max_items : int, default = None
The maximum amount of records to retrieve
intersects : dict, default = None
Dictionary representing the type of geometry and its respective coordinates
bbox : list, default = None
Coordinates to filter the spatial range of request
resolution : int, default = 10
Resolution to get data from
nodata_fill : Number, default = None
Expand All @@ -421,84 +352,77 @@ def get_stac_data(
xr.Dataset
Xarray object with resepective STAC API data
Raises
------
ValueError
A valid aggregate must be used for time ranges
"""

with warnings.catch_warnings():
warnings.simplefilter("ignore")

# Call cataloging url, search, and convert to xarray
stac_items = query_stac(
url=url,
time=time,
collections=collection,
max_items=max_items,
intersects=intersects,
bbox=bbox,
query=query,
)

stack = stackstac.stack(stac_items, resolution=resolution)

# Only get unique time indices in case there are duplicates
_, idxs = np.unique(stack.coords["time"], return_index=True)
stack = stack[idxs]

# Aggregate if there is more than one time
if stack.coords["time"].shape[0] > 1:
crs = stack.rio.crs
if time_aggregate == "mean":
stack = stack.mean(dim="time")
stack.attrs["time_aggregate"] = "mean"
elif time_aggregate == "min":
stack = stack.min(dim="time")
stack.attrs["time_aggregate"] = "min"
elif time_aggregate == "max":
stack = stack.max(dim="time")
stack.attrs["time_aggregate"] = "max"
else:
raise ValueError("A valid aggregate must be used for time ranges")

stack.rio.write_crs(crs, inplace=True)
# Only get unique time indices in case there are duplicates
_, idxs = np.unique(stack.coords["time"], return_index=True)
stack = stack[idxs]

# Aggregate if there is more than one time
if stack.coords["time"].shape[0] > 1:
crs = stack.rio.crs
if time_aggregate == "mean":
stack = stack.mean(dim="time")
stack.attrs["time_aggregate"] = "mean"
elif time_aggregate == "min":
stack = stack.min(dim="time")
stack.attrs["time_aggregate"] = "min"
elif time_aggregate == "max":
stack = stack.max(dim="time")
stack.attrs["time_aggregate"] = "max"
else:
stack = stack[0]
stack.attrs["time_aggregate"] = "none"
raise ValueError("A valid aggregate must be used for time ranges")

# Select specific bands
if bands is not None:
bands = [bands] if isinstance(bands, str) else bands
stack = stack.sel({"band": bands})
stack.rio.write_crs(crs, inplace=True)
else:
stack = stack[0]
stack.attrs["time_aggregate"] = "none"

band_metadata = (
stack.coords["raster:bands"] if "raster:bands" in stack.coords else None
)
if "band" in stack.dims:
og_names = [name for name in stack.coords["band"]]
names = [f"band_{x + 1}" for x in range(len(stack.coords["band"]))]
stack = stack.assign_coords({"band": names}).to_dataset(dim="band")
# Select specific bands
if bands is not None:
bands = [bands] if isinstance(bands, str) else bands
stack = stack.sel({"band": bands})

for metadata, var, og_var in zip(band_metadata, stack.data_vars, og_names):
_set_nodata(stack[var], metadata, nodata_fill)
stack[var] = _set_crs(stack[var], band_metadata)
stack[var].attrs["original_name"] = og_var
band_metadata = (
stack.coords["raster:bands"] if "raster:bands" in stack.coords else None
)
if "band" in stack.dims:
og_names = [name for name in stack.coords["band"]]
names = [f"band_{x + 1}" for x in range(len(stack.coords["band"]))]
stack = stack.assign_coords({"band": names}).to_dataset(dim="band")

else:
stack = stack.to_dataset(name="band_1")
_set_nodata(stack["band_1"], band_metadata, nodata_fill)
stack["band_1"] = _set_crs(stack["band_1"])
stack["band_1"].attrs["original_name"] = (
bands[0] if isinstance(bands, list) else bands
)
for metadata, var, og_var in zip(band_metadata, stack.data_vars, og_names):
_set_nodata(stack[var], metadata, nodata_fill)
stack[var] = _set_crs(stack[var], band_metadata)
stack[var].attrs["original_name"] = og_var

else:
stack = stack.to_dataset(name="band_1")
_set_nodata(stack["band_1"], band_metadata, nodata_fill)
stack["band_1"] = _set_crs(stack["band_1"])
stack["band_1"].attrs["original_name"] = (
bands[0] if isinstance(bands, list) else bands
)

return stack
return stack


def _stac_to_df(stac_items: list, assets: list = None) -> pd.DataFrame:
"""Convert a list of stac items to a DataFrame
def stac_to_df(stac_items: ItemCollection, assets: list = None) -> pd.DataFrame:
"""Convert STAC Items in to a DataFrame
Parameters
----------
stac_items: list
List of stac items to create a catalog with
stac_items: ItemCollection
STAC Item Collection returned from pystac client
assets : list, default = None
Assets to keep, (keep all if None)
Expand All @@ -514,101 +438,50 @@ def _stac_to_df(stac_items: list, assets: list = None) -> pd.DataFrame:
"""

dfs, compare_idx = [], 1
for stac_item in stac_items:
map_name, map_id, compare_id = [], [], []
for key, item in stac_item.assets.items():
if assets is None or key in assets:
map_name.append(key)
map_id.append(item.href)
compare_id.append(compare_idx)
item_dfs, compare_idx = [], 1

# Iterate through each STAC Item
for item in stac_items:
item_dict = item.to_dict()
item_columns = {}

# Get columns for all collection level and item level properties
for key, val in item_dict["properties"].items():
if not isinstance(val, list):
if isinstance(val, dict):
for k, v in val.items():
item_columns[k] = [v]
else:
item_columns[key] = [val]

item_columns["bbox"] = MultiPoint(np.array(item_dict["bbox"]).reshape(2, 2)).wkt
item_columns["geometry"] = shape(item_dict["geometry"]).wkt

unique_keys = []
for k, v in item_dict["assets"].items():
for key in v.keys():
if key not in unique_keys:
unique_keys.append(key)

# Create new row for each asset with and assign compare_id and map_id
asset_dfs = []
for k, v in item_dict["assets"].items():
if assets is None or k in assets:
asset_columns = item_columns.copy()

asset_columns[key] = [str(v.get(key, "N/a"))]
asset_columns["compare_id"] = compare_idx
asset_columns["map_id"] = v["href"]
compare_idx += 1
asset_columns["asset"] = [k]
for key in unique_keys:
asset_columns[key] = [str(v.get(key, "N/a"))]

len_assets = len(map_name)

df_contents = {
"collection_id": [stac_item.collection_id] * len_assets,
"item_id": [stac_item.id] * len_assets,
"item_time": [stac_item.get_datetime()] * len_assets,
"create_time": [stac_item.properties["created"]] * len_assets,
"map_id": map_id,
"map_name": map_name,
"compare_id": compare_id,
"coverage_geometry_type": [stac_item.geometry["type"]] * len_assets,
"coverage_geometry_coords": [stac_item.geometry["coordinates"]]
* len_assets,
"coverage_epsg": ["4326"] * len_assets,
"asset_epsg": [stac_item.properties["proj:epsg"]] * len_assets,
}

dfs.append(pd.DataFrame(df_contents))

combined_df = pd.concat(dfs)
if combined_df.empty:
raise ValueError("No entries in DataFrame due to nonexistent asset")

return combined_df


def stac_catalog(
url: str,
collections: Union[str, list],
time: str,
query: str = None,
max_items: int = None,
intersects: dict = None,
bbox: list = None,
assets: list = None,
) -> pd.DataFrame:
"""Create a STAC Catalog from a STAC query
Parameters
----------
url : str
Address hosting the STAC API
collections : Union[str, list]
Name of collection/s to get
time : str
Single or range of values to query in the time dimension
query : str, default = None
String command to filter data
max_items : int, default = None
The maximum amount of records to retrieve
intersects : dict, default = None
Dictionary representing the type of geometry and its respective coordinates
bbox : list, default = None
Coordinates to filter the spatial range of request
assets : list, default = None
Assets to keep, (keep all if None)
Returns
-------
pd.DataFrame
DataFrame representing a catalog based on STAC query
Raises
------
JSONDecodeError
Unable to make STAC query
ValueError
No items returned from query
"""

stac_items = query_stac(
url=url,
time=time,
collections=collections,
max_items=max_items,
intersects=intersects,
bbox=bbox,
query=query,
)
asset_dfs.append(pd.DataFrame(asset_columns))

if len(stac_items) == 0:
raise ValueError("No items returned from query")
item_dfs.append(pd.concat(asset_dfs))

return _stac_to_df(stac_items, assets).reset_index(drop=True)
return pd.concat(item_dfs, ignore_index=True)


def _create_circle_mask(
Expand Down
Loading

0 comments on commit a790a75

Please sign in to comment.