diff --git a/howso/client/base.py b/howso/client/base.py index 0a9c2ff7..d4a7527b 100644 --- a/howso/client/base.py +++ b/howso/client/base.py @@ -55,6 +55,7 @@ Persistence, Precision, SeriesIDTracking, TabularData2D, TabularData3D, TargetedModel, + TrainStatus, @@ -491,9 +492,10 @@ def train( # noqa: C901 progress_callback: t.Optional[Callable] = None, series: t.Optional[str] = None, skip_auto_analyze: bool = False, + skip_reduce_data: bool = False, train_weights_only: bool = False, validate: bool = True, - ): + ) -> TrainStatus: """ Train one or more cases into a Trainee. @@ -552,7 +554,13 @@ def train( # noqa: C901 in cases is applied in order to each of the cases in the series. skip_auto_analyze : bool, default False When true, the Trainee will not auto-analyze when appropriate. - Instead, the boolean response will be True if an analyze is needed. + Instead, the return dict will have a "needs_analyze" flag if an + `analyze` is needed. + skip_reduce_data : bool, default False + When true, the Trainee will not call `reduce_data` when + appropriate. Instead, the return dict will have a + "needs_data_reduction" flag if a call to `reduce_data` is + recommended. train_weights_only : bool, default False When true, and accumulate_weight_feature is provided, will accumulate all of the cases' neighbor weights instead of @@ -564,9 +572,10 @@ def train( # noqa: C901 Returns ------- - bool - Flag indicating if the Trainee needs to analyze. Only true if - auto-analyze is enabled and the conditions are met. + dict + A dict containing variable keys if there are important messages to + share from the Engine, such as 'needs_analyze' and + 'needs_data_reduction'. Otherwise, an empty dict. 
""" trainee_id = self._resolve_trainee(trainee_id).id feature_attributes = self.resolve_feature_attributes(trainee_id) @@ -613,7 +622,7 @@ def train( # noqa: C901 features = internals.get_features_from_data(cases) serialized_cases = serialize_cases(cases, features, feature_attributes, warn=True) or [] - needs_analyze = False + status = {} if self.configuration.verbose: print(f'Training session(s) on Trainee with id: {trainee_id}') @@ -645,10 +654,15 @@ def train( # noqa: C901 "series": series, "session": self.active_session.id, "skip_auto_analyze": skip_auto_analyze, + "skip_reduce_data": skip_reduce_data, "train_weights_only": train_weights_only, }) + if response and response.get('status') == 'analyze': - needs_analyze = True + status['needs_analyze'] = True + if response and response.get('status') == 'reduce_data': + status['needs_data_reduction'] = True + if batch_scaler is None or gen_batch_size is None: progress.update(batch_size) else: @@ -663,7 +677,7 @@ def train( # noqa: C901 self._store_session(trainee_id, self.active_session) self._auto_persist_trainee(trainee_id) - return needs_analyze + return status def impute( self, @@ -4064,6 +4078,7 @@ def set_auto_ablation_params( delta_threshold_map: AblationThresholdMap = None, exact_prediction_features: t.Optional[Collection[str]] = None, min_num_cases: int = 1_000, + max_num_cases: int = 500_000, reduce_data_influence_weight_entropy_threshold: float = 0.6, rel_threshold_map: AblationThresholdMap = None, relative_prediction_threshold_map: t.Optional[Mapping[str, float]] = None, @@ -4099,6 +4114,8 @@ def set_auto_ablation_params( to recompute influence weight entropy. min_num_cases : int, default 1,000 The threshold of the minimum number of cases at which the model should auto-ablate. 
+ max_num_cases : int, default 500,000 + The threshold of the maximum number of cases at which the model should auto-reduce. exact_prediction_features : Optional[List[str]], optional For each of the features specified, will ablate a case if the prediction matches exactly. residual_prediction_features : Optional[List[str]], optional @@ -4148,6 +4165,7 @@ def set_auto_ablation_params( delta_threshold_map=delta_threshold_map, exact_prediction_features=exact_prediction_features, min_num_cases=min_num_cases, + max_num_cases=max_num_cases, reduce_data_influence_weight_entropy_threshold=reduce_data_influence_weight_entropy_threshold, rel_threshold_map=rel_threshold_map, relative_prediction_threshold_map=relative_prediction_threshold_map, diff --git a/howso/client/typing.py b/howso/client/typing.py index 797231d3..1135f72f 100644 --- a/howso/client/typing.py +++ b/howso/client/typing.py @@ -4,7 +4,7 @@ from typing import Any, Literal, Union from pandas import DataFrame -from typing_extensions import Sequence, TypeAlias, TypedDict +from typing_extensions import NotRequired, Sequence, TypeAlias, TypedDict class Cases(TypedDict): @@ -37,6 +37,16 @@ class Evaluation(TypedDict): """A mapping of feature names to lists of values.""" +class TrainStatus(TypedDict): + """Representation of a status output from AbstractHowsoClient.train.""" + + needs_analyze: NotRequired[bool] + """Indicates whether the Trainee needs an analyze.""" + + needs_data_reduction: NotRequired[bool] + """Indicates whether the Trainee recommends a call to `reduce_data`.""" + + CaseIndices: TypeAlias = Sequence[tuple[str, int]] """Sequence of ``case_indices`` tuples.""" diff --git a/howso/engine/trainee.py b/howso/engine/trainee.py index 5ce3ba74..6af27434 100644 --- a/howso/engine/trainee.py +++ b/howso/engine/trainee.py @@ -150,6 +150,7 @@ def __init__( self._custom_save_path = None self._calculated_matrices = {} self._needs_analyze: bool = False + self._needs_data_reduction: bool = False # Allow passing 
project id or the project instance if isinstance(project, BaseProject): @@ -311,6 +312,18 @@ def needs_analyze(self) -> bool: """ return self._needs_analyze + @property + def needs_data_reduction(self) -> bool: + """ + The flag indicating if the Trainee needs its data reduced. + + Returns + ------- + bool + A flag indicating if a call to `reduce_data` is recommended. + """ + return self._needs_data_reduction + @property def calculated_matrices(self) -> dict[str, DataFrame] | None: """ @@ -628,6 +641,7 @@ def train( progress_callback: t.Optional[Callable] = None, series: t.Optional[str] = None, skip_auto_analyze: bool = False, + skip_reduce_data: bool = False, train_weights_only: bool = False, validate: bool = True, ): @@ -682,8 +696,9 @@ def train( applied in order to each of the cases in the series. skip_auto_analyze : bool, default False When true, the Trainee will not auto-analyze when appropriate. - Instead, the 'needs_analyze' property of the Trainee will be - updated. + skip_reduce_data : bool, default False + When true, the Trainee will not call `reduce_data` when + appropriate. train_weights_only: bool, default False When true, and accumulate_weight_feature is provided, will accumulate all of the cases' neighbor weights instead of @@ -694,8 +709,7 @@ def train( the data and the features dictionary. 
""" if isinstance(self.client, AbstractHowsoClient): - self._needs_analyze = False - needs_analyze = self.client.train( + status = self.client.train( trainee_id=self.id, accumulate_weight_feature=accumulate_weight_feature, batch_size=batch_size, @@ -707,10 +721,12 @@ def train( progress_callback=progress_callback, series=series, skip_auto_analyze=skip_auto_analyze, + skip_reduce_data=skip_reduce_data, train_weights_only=train_weights_only, validate=validate, ) - self._needs_analyze = needs_analyze + self._needs_analyze = status.get('needs_analyze', False) + self._needs_data_reduction = status.get('needs_data_reduction', False) else: raise AssertionError("Client must have the 'train' method.") @@ -755,6 +771,7 @@ def set_auto_ablation_params( delta_threshold_map: AblationThresholdMap = None, exact_prediction_features: t.Optional[Collection[str]] = None, min_num_cases: int = 1_000, + max_num_cases: int = 500_000, reduce_data_influence_weight_entropy_threshold: float = 0.6, rel_threshold_map: AblationThresholdMap = None, relative_prediction_threshold_map: t.Optional[Mapping[str, float]] = None, @@ -788,6 +805,8 @@ def set_auto_ablation_params( to recompute influence weight entropy. min_num_cases : int, default 1,000 - The threshold ofr the minimum number of cases at which the model should auto-ablate. + The threshold of the minimum number of cases at which the model should auto-ablate. + max_num_cases : int, default 500,000 + The threshold of the maximum number of cases at which the model should auto-reduce. exact_prediction_features : Collection of str, optional For each of the features specified, will ablate a case if the prediction matches exactly. 
residual_prediction_features : Collection of str, optional @@ -838,6 +857,7 @@ def set_auto_ablation_params( delta_threshold_map=delta_threshold_map, exact_prediction_features=exact_prediction_features, min_num_cases=min_num_cases, + max_num_cases=max_num_cases, reduce_data_influence_weight_entropy_threshold=reduce_data_influence_weight_entropy_threshold, rel_threshold_map=rel_threshold_map, relative_prediction_threshold_map=relative_prediction_threshold_map,