diff --git a/tests/fixtures/catalog.json b/tests/fixtures/catalog.json index 292a6e6..6d3c3a4 100644 --- a/tests/fixtures/catalog.json +++ b/tests/fixtures/catalog.json @@ -321,6 +321,23 @@ "rel": "self", "type": "application/json", "href": "https://stac.endpoint.io/collections" + }, + { + "rel": "data", + "type": "application/json", + "href": "https://stac.endpoint.io/collections" + }, + { + "rel": "aggregate", + "type": "application/json", + "title": "Aggregate", + "href": "https://stac.endpoint.io/aggregate" + }, + { + "rel": "aggregations", + "type": "application/json", + "title": "Aggregations", + "href": "https://stac.endpoint.io/aggregations" } ] } diff --git a/tests/test_advanced_pystac_client.py b/tests/test_advanced_pystac_client.py new file mode 100644 index 0000000..29af8b5 --- /dev/null +++ b/tests/test_advanced_pystac_client.py @@ -0,0 +1,119 @@ +"""Test Advanced PySTAC client.""" +import json +import os +from unittest.mock import MagicMock, patch + +import pytest + +from titiler.pystac import AdvancedClient + +catalog_json = os.path.join(os.path.dirname(__file__), "fixtures", "catalog.json") + + +@pytest.fixture +def mock_stac_io(): + """STAC IO mock""" + return MagicMock() + + +@pytest.fixture +def client(mock_stac_io): + """STAC client mock""" + client = AdvancedClient(id="pystac-client", description="pystac-client") + + with open(catalog_json, "r") as f: + catalog = json.loads(f.read()) + client.open = MagicMock() + client.open.return_value = catalog + client._collections_href = MagicMock() + client._collections_href.return_value = "http://example.com/collections" + + client._stac_io = mock_stac_io + return client + + +def test_get_supported_aggregations(client, mock_stac_io): + """Test supported STAC aggregation methods""" + mock_stac_io.read_json.return_value = { + "aggregations": [{"name": "aggregation1"}, {"name": "aggregation2"}] + } + supported_aggregations = client.get_supported_aggregations() + assert supported_aggregations == ["aggregation1", "aggregation2"] + + +@patch( + "titiler.pystac.advanced_client.AdvancedClient.get_supported_aggregations", + return_value=["datetime_frequency"], +) +def test_get_aggregation_unsupported(supported_aggregations, client): + """Test handling of unsupported aggregation types""" + collection_id = "sentinel-2-l2a" + aggregation = "unsupported-aggregation" + + with pytest.warns( + UserWarning, match="Aggregation type unsupported-aggregation is not supported" + ): + aggregation_data = client.get_aggregation(collection_id, aggregation) + assert aggregation_data == [] + + +@patch( + "titiler.pystac.advanced_client.AdvancedClient.get_supported_aggregations", + return_value=["datetime_frequency"], +) +def test_get_aggregation(supported_aggregations, client, mock_stac_io): + """Test handling aggregation response""" + collection_id = "sentinel-2-l2a" + aggregation = "datetime_frequency" + aggregation_params = {"datetime_frequency_interval": "day"} + + mock_stac_io.read_json.return_value = { + "aggregations": [ + { + "name": "datetime_frequency", + "buckets": [ + { + "key": "2023-12-11T00:00:00.000Z", + "data_type": "frequency_distribution", + "frequency": 1, + "to": None, + "from": None, + } + ], + }, + { + "name": "unusable_aggregation", + "buckets": [ + { + "key": "2023-12-11T00:00:00.000Z", + } + ], + }, + ] + } + + aggregation_data = client.get_aggregation( + collection_id, aggregation, aggregation_params + ) + assert aggregation_data[0]["key"] == "2023-12-11T00:00:00.000Z" + assert aggregation_data[0]["data_type"] == "frequency_distribution" + assert aggregation_data[0]["frequency"] == 1 + assert len(aggregation_data) == 1 + + +@patch( + "titiler.pystac.advanced_client.AdvancedClient.get_supported_aggregations", + return_value=["datetime_frequency"], +) +def test_get_aggregation_no_response(supported_aggregations, client, mock_stac_io): + """Test handling of no aggregation response""" + collection_id = "sentinel-2-l2a" + aggregation = "datetime_frequency" + aggregation_params = {"datetime_frequency_interval": "day"} + + mock_stac_io.read_json.return_value = [] + + aggregation_data = client.get_aggregation( + collection_id, aggregation, aggregation_params + ) + assert aggregation_data == [] diff --git a/titiler/pystac/__init__.py b/titiler/pystac/__init__.py new file mode 100644 index 0000000..f398c4b --- /dev/null +++ b/titiler/pystac/__init__.py @@ -0,0 +1,7 @@ +"""titiler.pystac""" + +__all__ = [ + "AdvancedClient", +] + +from titiler.pystac.advanced_client import AdvancedClient diff --git a/titiler/pystac/advanced_client.py b/titiler/pystac/advanced_client.py new file mode 100644 index 0000000..24f77f5 --- /dev/null +++ b/titiler/pystac/advanced_client.py @@ -0,0 +1,87 @@ +""" +This module provides an advanced client for interacting with STAC (SpatioTemporal Asset Catalog) APIs. + +The `AdvancedClient` class extends the basic functionality of the `pystac.Client` to include +methods for retrieving and aggregating data from STAC collections. +""" + +import warnings +from typing import Optional +from urllib.parse import urlencode + +import pystac +from pystac_client import Client + + +class AdvancedClient(Client): + """AdvancedClient extends the basic functionality of the pystac.Client class.""" + + def get_aggregation( + self, + collection_id: str, + aggregation: str, + aggregation_params: Optional[dict] = None, + ) -> list[dict]: + """Perform an aggregation on a STAC collection. + + Args: + collection_id (str): The ID of the collection to aggregate. + aggregation (str): The aggregation type to perform. + aggregation_params (Optional[dict], optional): Additional parameters for the aggregation. Defaults to None. + Returns: + List[str]: The aggregation response. + """ + assert self._stac_io is not None + + if aggregation not in self.get_supported_aggregations(): + warnings.warn( + f"Aggregation type {aggregation} is not supported", stacklevel=1 + ) + return [] + + # Construct the URL for aggregation + url = ( + self._collections_href(collection_id) + + f"/aggregate?aggregations={aggregation}" + ) + if aggregation_params: + params = urlencode(aggregation_params) + url += f"&{params}" + + aggregation_response = self._stac_io.read_json(url) + + if not aggregation_response: + return [] + + aggregation_data = [] + for agg in aggregation_response["aggregations"]: + if agg["name"] == aggregation: + aggregation_data = agg["buckets"] + + return aggregation_data + + def get_supported_aggregations(self) -> list[str]: + """Get the supported aggregation types. + + Returns: + List[str]: The supported aggregations. + """ + response = self._stac_io.read_json(self.get_aggregations_link()) + aggregations = response.get("aggregations", []) + return [agg["name"] for agg in aggregations] + + def get_aggregations_link(self) -> Optional[pystac.Link]: + """Returns this client's aggregations link. + + Returns: + Optional[pystac.Link]: The aggregations link, or None if there is not one found. + """ + return next( + ( + link + for link in self.links + if link.rel == "aggregations" + and link.media_type == pystac.MediaType.JSON + ), + None, + ) diff --git a/titiler/stacapi/factory.py b/titiler/stacapi/factory.py index 854c25f..d9a569e 100644 --- a/titiler/stacapi/factory.py +++ b/titiler/stacapi/factory.py @@ -19,7 +19,6 @@ from morecantile import tms as morecantile_tms from morecantile.defaults import TileMatrixSets from pydantic import conint -from pystac_client import Client from pystac_client.stac_api_io import StacApiIO from rasterio.transform import xy as rowcol_to_coords from rasterio.warp import transform as transform_points @@ -45,6 +44,7 @@ from titiler.core.resources.responses import GeoJSONResponse, XMLResponse from titiler.core.utils import render_image from titiler.mosaic.factory import PixelSelectionParams +from titiler.pystac import AdvancedClient from titiler.stacapi.backend import STACAPIBackend from titiler.stacapi.dependencies import APIParams, STACApiParams, STACSearchParams from titiler.stacapi.models import FeatureInfo, LayerDict @@ -568,7 +568,7 @@ def get_layer_from_collections( # noqa: C901 ), headers=headers, ) - catalog = Client.open(url, stac_io=stac_api_io) + catalog = AdvancedClient.open(url, stac_io=stac_api_io) layers: Dict[str, LayerDict] = {} for collection in catalog.get_collections(): @@ -580,6 +580,7 @@ def get_layer_from_collections( # noqa: C901 tilematrixsets = render.pop("tilematrixsets", None) output_format = render.pop("format", None) + aggregation = render.pop("aggregation", None) _ = render.pop("minmax_zoom", None) # Not Used _ = render.pop("title", None) # Not Used @@ -643,6 +644,20 @@ def get_layer_from_collections( # noqa: C901 "values" ] ] + elif aggregation and aggregation["name"] == "datetime_frequency": + datetime_aggregation = catalog.get_aggregation( + collection_id=collection.id, + aggregation="datetime_frequency", + aggregation_params=aggregation["params"], + ) + layer["time"] = [ + python_datetime.datetime.strptime( + t["key"], + "%Y-%m-%dT%H:%M:%S.000Z", + ).strftime("%Y-%m-%d") + for t in datetime_aggregation + if t["frequency"] > 0 + ] elif intervals := temporal_extent.intervals: start_date = intervals[0][0] end_date = (