Skip to content

Commit

Permalink
feat: working test suite for s2s python client
Browse files Browse the repository at this point in the history
  • Loading branch information
Gabe-Levin committed Jan 19, 2025
1 parent 573e585 commit 689eb18
Show file tree
Hide file tree
Showing 2 changed files with 208 additions and 0 deletions.
107 changes: 107 additions & 0 deletions space2stats_client/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import json

import geopandas as gpd
import pandas as pd
import pytest
from pystac import Catalog, Collection, Item
from shapely.geometry import Polygon

MOCK_GEOMETRY = Polygon([[0, 0], [0, 1], [1, 1], [1, 0], [0, 0]])
MOCK_GDF = gpd.GeoDataFrame(
{"name": ["Test Area"]}, geometry=[MOCK_GEOMETRY], crs="EPSG:4326"
)


@pytest.fixture
def sample_geodataframe():
"""Create a sample GeoDataFrame for testing."""
geometry = Polygon([[0, 0], [0, 1], [1, 1], [1, 0], [0, 0]])
return gpd.GeoDataFrame({"id": [1], "geometry": [geometry]}, crs="EPSG:4326")


@pytest.fixture
def mock_catalog(mocker):
"""Mock STAC catalog responses."""
mock_item = mocker.Mock(spec=Item)
mock_item.id = "test_dataset"
mock_item.properties = {
"name": "Test Dataset",
"description": "Test Description",
"source_data": "Test Source",
"table:columns": [
{"name": "sum_pop_2020", "description": "Population count 2020"},
{
"name": "sum_pop_f_10_2020",
"description": "Female population count 2020",
},
],
}

mock_collection = mocker.Mock(spec=Collection)
mock_collection.id = "test_dataset"
mock_collection.get_item.return_value = mock_item

mock_catalog = mocker.Mock(spec=Catalog)
mock_catalog.get_all_items.return_value = iter([mock_item])
mock_catalog.get_collections.return_value = iter([mock_collection])

return mock_catalog


@pytest.fixture
def mock_api_response(mocker, mock_catalog):
"""Mock API responses for testing."""

def mock_response(*args, **kwargs):
mock = mocker.Mock()
mock.status_code = 200

if "geoboundaries.org" in str(args[0]):
mock.json.return_value = {
"gjDownloadURL": "https://example.com/boundary.geojson"
}
elif "topics" in str(args[0]):
mock.json.return_value = [
{
"id": "test_dataset",
"name": "Test Dataset",
"description": "Test Description",
"source_data": "Test Source",
"variables": {"sum_pop_2020": "Population count 2020"},
}
]
elif "properties" in str(args[0]):
mock.json.return_value = {
"name": "Test Dataset",
"description": "Test Description",
"variables": {
"sum_pop_2020": "Population count 2020",
"sum_pop_f_10_2020": "Female population count 2020",
},
}
elif "fields" in str(args[0]):
mock.json.return_value = ["sum_pop_2020", "sum_pop_f_10_2020"]
elif "summary" in str(args[0]):
mock.json.return_value = [
{
"hex_id": "862a1070fffffff",
"sum_pop_2020": 1000,
"sum_pop_f_10_2020": 500,
}
]
elif "aggregate" in str(args[0]):
mock.json.return_value = {"sum_pop_2020": 5000, "sum_pop_f_10_2020": 2500}

return mock

# Mock requests
mocker.patch("requests.get", side_effect=mock_response)
mocker.patch("requests.post", side_effect=mock_response)

# Use pre-created GeoDataFrame
mocker.patch("geopandas.read_file", return_value=MOCK_GDF)

# Mock STAC catalog
mocker.patch("pystac.Catalog.from_file", return_value=mock_catalog)

return mock_response
101 changes: 101 additions & 0 deletions space2stats_client/tests/test_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import geopandas as gpd
import pandas as pd
import pytest

from space2stats_client import Space2StatsClient


def test_client_initialization():
"""Test that the client initializes with correct endpoints."""
client = Space2StatsClient()
assert client.base_url == "https://space2stats.ds.io"
assert client.summary_endpoint == f"{client.base_url}/summary"
assert client.aggregation_endpoint == f"{client.base_url}/aggregate"
assert client.fields_endpoint == f"{client.base_url}/fields"


def test_get_topics(mock_api_response):
"""Test that get_topics returns expected DataFrame."""
client = Space2StatsClient()
topics = client.get_topics()
assert isinstance(topics, pd.DataFrame)
assert "name" in topics.columns
assert "description" in topics.columns
assert "source_data" in topics.columns


def test_get_fields(mock_api_response):
"""Test that get_fields returns list of available fields."""
client = Space2StatsClient()
fields = client.get_fields()
assert isinstance(fields, list)
assert "sum_pop_2020" in fields
assert "sum_pop_f_10_2020" in fields


def test_get_properties(mock_api_response):
"""Test that get_properties returns DataFrame with variable descriptions."""
client = Space2StatsClient()
properties = client.get_properties("test_dataset")
assert "name" in properties
assert "description" in properties
assert "sum_pop_2020" in properties["name"].values


def test_fetch_admin_boundaries(mock_api_response):
"""Test fetching admin boundaries."""
client = Space2StatsClient()
boundaries = client.fetch_admin_boundaries("USA", "ADM1")
assert isinstance(boundaries, gpd.GeoDataFrame)
assert "geometry" in boundaries.columns


def test_get_summary(mock_api_response, sample_geodataframe):
"""Test get_summary with sample data."""
client = Space2StatsClient()
result = client.get_summary(
gdf=sample_geodataframe, spatial_join_method="centroid", fields=["sum_pop_2020"]
)
assert isinstance(result, pd.DataFrame)
assert "hex_id" in result.columns
assert "sum_pop_2020" in result.columns


def test_get_aggregate(mock_api_response, sample_geodataframe):
"""Test get_aggregate with sample data."""
client = Space2StatsClient()
result = client.get_aggregate(
gdf=sample_geodataframe,
spatial_join_method="centroid",
fields=["sum_pop_2020"],
aggregation_type="sum",
)
assert isinstance(result, pd.DataFrame)
assert "sum_pop_2020" in result


def test_invalid_spatial_join_method(sample_geodataframe):
"""Test that invalid spatial join method raises ValueError."""
client = Space2StatsClient()
with pytest.raises(
Exception, match="Input should be 'touches', 'centroid' or 'within'"
):
client.get_summary(
gdf=sample_geodataframe,
spatial_join_method="invalid",
fields=["population"],
)


def test_invalid_aggregation_type(sample_geodataframe):
"""Test that invalid aggregation type raises ValueError."""
client = Space2StatsClient()
with pytest.raises(
Exception, match="Input should be 'sum', 'avg', 'count', 'max' or 'min'"
):
client.get_aggregate(
gdf=sample_geodataframe,
spatial_join_method="centroid",
fields=["population"],
aggregation_type="invalid",
)

0 comments on commit 689eb18

Please sign in to comment.