From 6606406a61cc81cb9dda52dc192dc7d0ca2e1ef5 Mon Sep 17 00:00:00 2001 From: atmorling Date: Thu, 28 Nov 2024 22:03:43 +0200 Subject: [PATCH] use patrol_id as groupby in get_patrol_obs (#340) --- ecoscope/io/async_earthranger.py | 8 ++++---- ecoscope/io/earthranger.py | 1 + tests/test_asyncearthranger_io.py | 1 + tests/test_earthranger_io.py | 1 + 4 files changed, 7 insertions(+), 4 deletions(-) diff --git a/ecoscope/io/async_earthranger.py b/ecoscope/io/async_earthranger.py index b78c1d53..4c02fc1e 100644 --- a/ecoscope/io/async_earthranger.py +++ b/ecoscope/io/async_earthranger.py @@ -419,10 +419,7 @@ async def get_patrol_observations_with_patrol_filter( relocations : ecoscope.base.Relocations """ observations = ecoscope.base.Relocations() - df_pt = None - - if include_patrol_details: - df_pt = await self.get_patrol_types_dataframe() + df_pt = await self.get_patrol_types_dataframe() if include_patrol_details else None tasks = [] async for patrol in self.get_patrols(since=since, until=until, patrol_type=patrol_type, status=status): @@ -432,6 +429,9 @@ async def get_patrol_observations_with_patrol_filter( observations = await asyncio.gather(*tasks) observations = pd.concat(observations) + if include_patrol_details: + observations["groupby_col"] = observations["patrol_id"] + return observations async def _get_observations_by_patrol(self, patrol, relocations=True, tz="UTC", patrol_types=None, **kwargs): diff --git a/ecoscope/io/earthranger.py b/ecoscope/io/earthranger.py index 4149e321..fa30d8d3 100644 --- a/ecoscope/io/earthranger.py +++ b/ecoscope/io/earthranger.py @@ -895,6 +895,7 @@ def get_patrol_observations(self, patrols_df, include_patrol_details=False, **kw ) ) if len(observation) > 0: + observation["groupby_col"] = patrol["id"] observations.append(observation) except Exception as e: print( diff --git a/tests/test_asyncearthranger_io.py b/tests/test_asyncearthranger_io.py index c0961d1f..caa06937 100644 --- a/tests/test_asyncearthranger_io.py +++ b/tests/test_asyncearthranger_io.py @@ -241,6 +241,7 @@ async def test_get_patrol_observations_with_patrol_details( assert not observations.empty assert set(observations.columns) == set(get_patrol_observations_fields).union(get_patrol_details_fields) assert type(observations["fixtime"] == pd.Timestamp) + pd.testing.assert_series_equal(observations["patrol_id"], observations["groupby_col"], check_names=False) @pytest.mark.asyncio diff --git a/tests/test_earthranger_io.py b/tests/test_earthranger_io.py index 13cb1fc2..e04d0acb 100644 --- a/tests/test_earthranger_io.py +++ b/tests/test_earthranger_io.py @@ -297,6 +297,7 @@ def test_get_patrol_observations_with_patrol_details(er_io): assert not observations.empty assert "patrol_id" in observations.columns assert "patrol_title" in observations.columns + pd.testing.assert_series_equal(observations["patrol_id"], observations["groupby_col"], check_names=False) def test_users(er_io):