Skip to content

Commit

Permalink
21929: Adds max_num_cases, skip_reduce_data parameters to set_auto_ablation_params/train, MAJOR (#327)
Browse files Browse the repository at this point in the history

Also adjusts the return value of `train` to accommodate multiple
possible status messages.
  • Loading branch information
apbassett authored Nov 21, 2024
1 parent cd77b90 commit f436c0c
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 14 deletions.
34 changes: 26 additions & 8 deletions howso/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
Persistence,
Precision,
SeriesIDTracking,
TrainStatus,
TabularData2D,
TabularData3D,
TargetedModel,
Expand Down Expand Up @@ -491,9 +492,10 @@ def train( # noqa: C901
progress_callback: t.Optional[Callable] = None,
series: t.Optional[str] = None,
skip_auto_analyze: bool = False,
skip_reduce_data: bool = False,
train_weights_only: bool = False,
validate: bool = True,
):
) -> TrainStatus:
"""
Train one or more cases into a Trainee.
Expand Down Expand Up @@ -552,7 +554,13 @@ def train( # noqa: C901
in cases is applied in order to each of the cases in the series.
skip_auto_analyze : bool, default False
When true, the Trainee will not auto-analyze when appropriate.
Instead, the boolean response will be True if an analyze is needed.
Instead, the return dict will have a "needs_analyze" flag if an
`analyze` is needed.
skip_reduce_data : bool, default False
When true, the Trainee will not call `reduce_data` when
appropriate. Instead, the return dict will have a
"needs_data_reduction" flag if a call to `reduce_data` is
recommended.
train_weights_only : bool, default False
When true, and accumulate_weight_feature is provided,
will accumulate all of the cases' neighbor weights instead of
Expand All @@ -564,9 +572,10 @@ def train( # noqa: C901
Returns
-------
bool
Flag indicating if the Trainee needs to analyze. Only true if
auto-analyze is enabled and the conditions are met.
dict
    A dict containing variable keys if there are important messages to
    share from the Engine, such as `needs_analyze` and
    `needs_data_reduction`. Otherwise, an empty dict.
"""
trainee_id = self._resolve_trainee(trainee_id).id
feature_attributes = self.resolve_feature_attributes(trainee_id)
Expand Down Expand Up @@ -613,7 +622,7 @@ def train( # noqa: C901
features = internals.get_features_from_data(cases)
serialized_cases = serialize_cases(cases, features, feature_attributes, warn=True) or []

needs_analyze = False
status = {}

if self.configuration.verbose:
print(f'Training session(s) on Trainee with id: {trainee_id}')
Expand Down Expand Up @@ -645,10 +654,15 @@ def train( # noqa: C901
"series": series,
"session": self.active_session.id,
"skip_auto_analyze": skip_auto_analyze,
"skip_reduce_data": skip_reduce_data,
"train_weights_only": train_weights_only,
})

if response and response.get('status') == 'analyze':
needs_analyze = True
status['needs_analyze'] = True
if response and response.get('status') == 'reduce_data':
status['needs_data_reduction'] = True

if batch_scaler is None or gen_batch_size is None:
progress.update(batch_size)
else:
Expand All @@ -663,7 +677,7 @@ def train( # noqa: C901
self._store_session(trainee_id, self.active_session)
self._auto_persist_trainee(trainee_id)

return needs_analyze
return status

def impute(
self,
Expand Down Expand Up @@ -4064,6 +4078,7 @@ def set_auto_ablation_params(
delta_threshold_map: AblationThresholdMap = None,
exact_prediction_features: t.Optional[Collection[str]] = None,
min_num_cases: int = 1_000,
max_num_cases: int = 500_000,
reduce_data_influence_weight_entropy_threshold: float = 0.6,
rel_threshold_map: AblationThresholdMap = None,
relative_prediction_threshold_map: t.Optional[Mapping[str, float]] = None,
Expand Down Expand Up @@ -4099,6 +4114,8 @@ def set_auto_ablation_params(
to recompute influence weight entropy.
min_num_cases : int, default 1,000
The threshold of the minimum number of cases at which the model should auto-ablate.
max_num_cases : int, default 500,000
The threshold of the maximum number of cases at which the model should auto-reduce.
exact_prediction_features : Optional[List[str]], optional
For each of the features specified, will ablate a case if the prediction matches exactly.
residual_prediction_features : Optional[List[str]], optional
Expand Down Expand Up @@ -4148,6 +4165,7 @@ def set_auto_ablation_params(
delta_threshold_map=delta_threshold_map,
exact_prediction_features=exact_prediction_features,
min_num_cases=min_num_cases,
max_num_cases=max_num_cases,
reduce_data_influence_weight_entropy_threshold=reduce_data_influence_weight_entropy_threshold,
rel_threshold_map=rel_threshold_map,
relative_prediction_threshold_map=relative_prediction_threshold_map,
Expand Down
12 changes: 11 additions & 1 deletion howso/client/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Any, Literal, Union

from pandas import DataFrame
from typing_extensions import Sequence, TypeAlias, TypedDict
from typing_extensions import NotRequired, Sequence, TypeAlias, TypedDict


class Cases(TypedDict):
Expand Down Expand Up @@ -37,6 +37,16 @@ class Evaluation(TypedDict):
"""A mapping of feature names to lists of values."""


class TrainStatus(TypedDict):
"""Representation of a status output from AbstractHowsoClient.train."""

needs_analyze: NotRequired[bool]
"""Indicates whether the Trainee needs an analyze."""

needs_data_reduction: NotRequired[bool]
"""Indicates whether the Trainee recommends a call to `reduce_data`."""


CaseIndices: TypeAlias = Sequence[tuple[str, int]]
"""Sequence of ``case_indices`` tuples."""

Expand Down
30 changes: 25 additions & 5 deletions howso/engine/trainee.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def __init__(
self._custom_save_path = None
self._calculated_matrices = {}
self._needs_analyze: bool = False
self._needs_data_reduction: bool = False

# Allow passing project id or the project instance
if isinstance(project, BaseProject):
Expand Down Expand Up @@ -311,6 +312,18 @@ def needs_analyze(self) -> bool:
"""
return self._needs_analyze

@property
def needs_data_reduction(self) -> bool:
"""
The flag indicating if the Trainee needs its data reduced.
Returns
-------
bool
A flag indicating if a call to `reduce_data` is recommended.
"""
return self._needs_data_reduction

@property
def calculated_matrices(self) -> dict[str, DataFrame] | None:
"""
Expand Down Expand Up @@ -628,6 +641,7 @@ def train(
progress_callback: t.Optional[Callable] = None,
series: t.Optional[str] = None,
skip_auto_analyze: bool = False,
skip_reduce_data: bool = False,
train_weights_only: bool = False,
validate: bool = True,
):
Expand Down Expand Up @@ -682,8 +696,9 @@ def train(
applied in order to each of the cases in the series.
skip_auto_analyze : bool, default False
When true, the Trainee will not auto-analyze when appropriate.
Instead, the 'needs_analyze' property of the Trainee will be
updated.
skip_reduce_data : bool, default False
When true, the Trainee will not call `reduce_data` when
appropriate.
train_weights_only: bool, default False
When true, and accumulate_weight_feature is provided,
will accumulate all of the cases' neighbor weights instead of
Expand All @@ -694,8 +709,7 @@ def train(
the data and the features dictionary.
"""
if isinstance(self.client, AbstractHowsoClient):
self._needs_analyze = False
needs_analyze = self.client.train(
status = self.client.train(
trainee_id=self.id,
accumulate_weight_feature=accumulate_weight_feature,
batch_size=batch_size,
Expand All @@ -707,10 +721,12 @@ def train(
progress_callback=progress_callback,
series=series,
skip_auto_analyze=skip_auto_analyze,
skip_reduce_data=skip_reduce_data,
train_weights_only=train_weights_only,
validate=validate,
)
self._needs_analyze = needs_analyze
self._needs_analyze = status.get('needs_analyze', False)
self._needs_data_reduction = status.get('needs_data_reduction', False)
else:
raise AssertionError("Client must have the 'train' method.")

Expand Down Expand Up @@ -755,6 +771,7 @@ def set_auto_ablation_params(
delta_threshold_map: AblationThresholdMap = None,
exact_prediction_features: t.Optional[Collection[str]] = None,
min_num_cases: int = 1_000,
max_num_cases: int = 500_000,
reduce_data_influence_weight_entropy_threshold: float = 0.6,
rel_threshold_map: AblationThresholdMap = None,
relative_prediction_threshold_map: t.Optional[Mapping[str, float]] = None,
Expand Down Expand Up @@ -788,6 +805,8 @@ def set_auto_ablation_params(
to recompute influence weight entropy.
min_num_cases : int, default 1,000
The threshold of the minimum number of cases at which the model should auto-ablate.
max_num_cases : int, default 500,000
The threshold of the maximum number of cases at which the model should auto-reduce.
exact_prediction_features : Collection of str, optional
For each of the features specified, will ablate a case if the prediction matches exactly.
residual_prediction_features : Collection of str, optional
Expand Down Expand Up @@ -838,6 +857,7 @@ def set_auto_ablation_params(
delta_threshold_map=delta_threshold_map,
exact_prediction_features=exact_prediction_features,
min_num_cases=min_num_cases,
max_num_cases=max_num_cases,
reduce_data_influence_weight_entropy_threshold=reduce_data_influence_weight_entropy_threshold,
rel_threshold_map=rel_threshold_map,
relative_prediction_threshold_map=relative_prediction_threshold_map,
Expand Down

0 comments on commit f436c0c

Please sign in to comment.