Skip to content

Commit

Permalink
Merge branch 'main' into hjafari/MIC-5494_drop_python_3.9
Browse files Browse the repository at this point in the history
  • Loading branch information
hussain-jafari authored Nov 12, 2024
2 parents e7e816b + 2d98598 commit becd461
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 36 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
**3.0.18 - 11/06/24**

- Fix mypy errors in vivarium/framework/logging/manager.py
- Fix mypy errors in vivarium/framework/results/observations.py

**3.0.17 - 11/04/24**

Expand Down
1 change: 1 addition & 0 deletions docs/nitpick-exceptions
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ py:class Timedelta
py:class VectorMapper
py:class ScalarMapper
py:class PandasObject
py:class DataFrameGroupBy
py:exc ResultsConfigurationError
py:exc vivarium.framework.results.exceptions.ResultsConfigurationError

Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ exclude = [
'src/vivarium/framework/results/context.py',
'src/vivarium/framework/results/interface.py',
'src/vivarium/framework/results/manager.py',
'src/vivarium/framework/results/observation.py',
'src/vivarium/framework/results/observer.py',
'src/vivarium/framework/state_machine.py',
'src/vivarium/framework/time.py',
Expand Down
79 changes: 44 additions & 35 deletions src/vivarium/framework/results/observation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# mypy: ignore-errors
"""
============
Observations
Expand All @@ -19,10 +18,13 @@
"""

from __future__ import annotations

import itertools
from abc import ABC
from collections.abc import Callable
from dataclasses import dataclass
from typing import Callable, Iterable, Optional, Sequence, Tuple, Union
from typing import Any

import pandas as pd
from pandas.api.types import CategoricalDtype
Expand Down Expand Up @@ -52,32 +54,31 @@ class BaseObservation(ABC):
when: str
"""Name of the lifecycle phase the observation should happen. Valid values are:
"time_step__prepare", "time_step", "time_step__cleanup", or "collect_metrics"."""
results_initializer: Callable[[Iterable[str], Iterable[Stratification]], pd.DataFrame]
results_initializer: Callable[[set[str], list[Stratification]], pd.DataFrame]
"""Method or function that initializes the raw observation results
prior to starting the simulation. This could return, for example, an empty
DataFrame or one with a complete set of stratifications as the index and
all values set to 0.0."""
results_gatherer: Union[
Callable[[pd.DataFrame, Sequence[str]], pd.DataFrame],
Callable[[pd.DataFrame], pd.DataFrame],
results_gatherer: Callable[
[pd.DataFrame | DataFrameGroupBy[str], tuple[str, ...] | None], pd.DataFrame
]
"""Method or function that gathers the new observation results."""
results_updater: Callable[[pd.DataFrame, pd.DataFrame], pd.DataFrame]
"""Method or function that updates existing raw observation results with newly
gathered results."""
results_formatter: Callable[[str, pd.DataFrame], pd.DataFrame]
"""Method or function that formats the raw observation results."""
stratifications: Optional[Tuple[str]]
stratifications: tuple[str, ...] | None
"""Optional tuple of column names for the observation to stratify by."""
to_observe: Callable[[Event], bool]
"""Method or function that determines whether to perform an observation on this Event."""

def observe(
self,
event: Event,
df: Union[pd.DataFrame, DataFrameGroupBy],
stratifications: Optional[tuple[str, ...]],
) -> Optional[pd.DataFrame]:
df: pd.DataFrame | DataFrameGroupBy[str],
stratifications: tuple[str, ...] | None,
) -> pd.DataFrame | None:
"""Determine whether to observe the given event, and if so, gather the results.
Parameters
Expand All @@ -96,10 +97,7 @@ def observe(
if not self.to_observe(event):
return None
else:
if stratifications is None:
return self.results_gatherer(df)
else:
return self.results_gatherer(df, stratifications)
return self.results_gatherer(df, stratifications)


class UnstratifiedObservation(BaseObservation):
Expand Down Expand Up @@ -140,12 +138,22 @@ def __init__(
results_formatter: Callable[[str, pd.DataFrame], pd.DataFrame],
to_observe: Callable[[Event], bool] = lambda event: True,
):
def _wrap_results_gatherer(
df: pd.DataFrame | DataFrameGroupBy[str], _: tuple[str, ...] | None
) -> pd.DataFrame:
if isinstance(df, DataFrameGroupBy):
raise TypeError(
"Must provide a dataframe to an UnstratifiedObservation. "
f"Provided DataFrameGroupBy instead."
)
return results_gatherer(df)

super().__init__(
name=name,
pop_filter=pop_filter,
when=when,
results_initializer=self.create_empty_df,
results_gatherer=results_gatherer,
results_gatherer=_wrap_results_gatherer,
results_updater=results_updater,
results_formatter=results_formatter,
stratifications=None,
Expand Down Expand Up @@ -214,17 +222,17 @@ def __init__(
when: str,
results_updater: Callable[[pd.DataFrame, pd.DataFrame], pd.DataFrame],
results_formatter: Callable[[str, pd.DataFrame], pd.DataFrame],
stratifications: Tuple[str, ...],
aggregator_sources: Optional[list[str]],
aggregator: Callable[[pd.DataFrame], Union[float, pd.Series]],
stratifications: tuple[str, ...],
aggregator_sources: list[str] | None,
aggregator: Callable[[pd.DataFrame], float | pd.Series[float]],
to_observe: Callable[[Event], bool] = lambda event: True,
):
super().__init__(
name=name,
pop_filter=pop_filter,
when=when,
results_initializer=self.create_expanded_df,
results_gatherer=self.get_complete_stratified_results,
results_gatherer=self.get_complete_stratified_results, # type: ignore [arg-type]
results_updater=results_updater,
results_formatter=results_formatter,
stratifications=stratifications,
Expand Down Expand Up @@ -288,8 +296,8 @@ def create_expanded_df(

def get_complete_stratified_results(
self,
pop_groups: DataFrameGroupBy,
stratifications: Tuple[str, ...],
pop_groups: DataFrameGroupBy[str],
stratifications: tuple[str, ...],
) -> pd.DataFrame:
"""Gather results for this observation.
Expand All @@ -313,22 +321,22 @@ def get_complete_stratified_results(

@staticmethod
def _aggregate(
pop_groups: DataFrameGroupBy,
aggregator_sources: Optional[list[str]],
aggregator: Callable[[pd.DataFrame], Union[float, pd.Series]],
) -> Union[pd.Series, pd.DataFrame]:
pop_groups: DataFrameGroupBy[str],
aggregator_sources: list[str] | None,
aggregator: Callable[[pd.DataFrame], float | pd.Series[float]],
) -> pd.Series[float] | pd.DataFrame:
"""Apply the `aggregator` to the population groups and their
`aggregator_sources` columns.
"""
aggregates = (
pop_groups[aggregator_sources].apply(aggregator).fillna(0.0)
pop_groups[aggregator_sources].apply(aggregator).fillna(0.0) # type: ignore [arg-type]
if aggregator_sources
else pop_groups.apply(aggregator)
else pop_groups.apply(aggregator) # type: ignore [arg-type]
).astype(float)
return aggregates

@staticmethod
def _format(aggregates: Union[pd.Series, pd.DataFrame]) -> pd.DataFrame:
def _format(aggregates: pd.Series[float] | pd.DataFrame) -> pd.DataFrame:
"""Convert the results to a pandas DataFrame if necessary and ensure the
results column name is 'value'.
"""
Expand All @@ -340,10 +348,11 @@ def _format(aggregates: Union[pd.Series, pd.DataFrame]) -> pd.DataFrame:
@staticmethod
def _expand_index(aggregates: pd.DataFrame) -> pd.DataFrame:
"""Include all stratifications in the results by filling missing values with 0."""
if isinstance(aggregates.index, pd.MultiIndex):
full_idx = pd.MultiIndex.from_product(aggregates.index.levels)
else:
full_idx = aggregates.index
full_idx = (
pd.MultiIndex.from_product(aggregates.index.levels)
if isinstance(aggregates.index, pd.MultiIndex)
else aggregates.index
)
aggregates = aggregates.reindex(full_idx).fillna(0.0)
return aggregates

Expand Down Expand Up @@ -384,9 +393,9 @@ def __init__(
pop_filter: str,
when: str,
results_formatter: Callable[[str, pd.DataFrame], pd.DataFrame],
stratifications: Tuple[str, ...],
aggregator_sources: Optional[list[str]],
aggregator: Callable[[pd.DataFrame], Union[float, pd.Series]],
stratifications: tuple[str, ...],
aggregator_sources: list[str] | None,
aggregator: Callable[[pd.DataFrame], float | pd.Series[float]],
to_observe: Callable[[Event], bool] = lambda event: True,
):
super().__init__(
Expand Down

0 comments on commit becd461

Please sign in to comment.