diff --git a/howso/client/base.py b/howso/client/base.py index 0e60c0cc..771d2131 100644 --- a/howso/client/base.py +++ b/howso/client/base.py @@ -553,12 +553,13 @@ def train( # noqa: C901 in cases is applied in order to each of the cases in the series. skip_auto_analyze : bool, default False When true, the Trainee will not auto-analyze when appropriate. - Instead, the return dict will have a `status` key set to - "analyze" if an analyze is needed. + Instead, the return dict will have a "needs_analyze" flag if an + `analyze` is needed. skip_reduce_data : bool, default False When true, the Trainee will not call `reduce_data` when - appropriate. Instead, the return dict will have a `status` key set - to "reduce_data" if a call to `reduce_data` is recommended. + appropriate. Instead, the return dict will have a + "needs_data_reduction" flag if a call to `reduce_data` is + recommended. train_weights_only : bool, default False When true, and accumulate_weight_feature is provided, will accumulate all of the cases' neighbor weights instead of @@ -571,8 +572,9 @@ def train( # noqa: C901 Returns ------- dict - A dict containing `status` and `details` if there are important - messages to share from the Engine. Otherwise, an empty dict. + A dict containing variable keys if there are important messages to + share from the Engine, such as 'needs_analyze` and + 'needs_data_reduction'. Otherwise, an empty dict. """ trainee_id = self._resolve_trainee(trainee_id).id feature_attributes = self.resolve_feature_attributes(trainee_id) @@ -656,11 +658,9 @@ def train( # noqa: C901 }) if response and response.get('status') == 'analyze': - status['status'] = 'analyze' - status['details'] = 'An analyze is strongly recommended for this trainee.' - elif response and response.get('status') == 'reduce_data': - status['status'] = 'reduce_data' - status['details'] = 'Data reduction via `reduce_data` is recommended for this trainee.' + status['needs_analyze'] = True + if response and response.get('status') == 'reduce_data': + status['needs_data_reduction'] = True if batch_scaler is None or gen_batch_size is None: progress.update(batch_size) diff --git a/howso/engine/trainee.py b/howso/engine/trainee.py index b5bdc0bb..6af27434 100644 --- a/howso/engine/trainee.py +++ b/howso/engine/trainee.py @@ -150,6 +150,7 @@ def __init__( self._custom_save_path = None self._calculated_matrices = {} self._needs_analyze: bool = False + self._needs_data_reduction: bool = False # Allow passing project id or the project instance if isinstance(project, BaseProject): @@ -311,6 +312,18 @@ def needs_analyze(self) -> bool: """ return self._needs_analyze + @property + def needs_data_reduction(self) -> bool: + """ + The flag indicating if the Trainee needs its data reduced. + + Returns + ------- + bool + A flag indicating if a call to `reduce_data` is recommended. + """ + return self._needs_data_reduction + @property def calculated_matrices(self) -> dict[str, DataFrame] | None: """ @@ -683,12 +696,9 @@ def train( applied in order to each of the cases in the series. skip_auto_analyze : bool, default False When true, the Trainee will not auto-analyze when appropriate. - Instead, the return dict will have a `status` key set to - "analyze" if an analyze is needed. skip_reduce_data : bool, default False When true, the Trainee will not call `reduce_data` when - appropriate. Instead, the return dict will have a `status` key set - to "reduce_data" if a call to `reduce_data` is recommended. + appropriate. train_weights_only: bool, default False When true, and accumulate_weight_feature is provided, will accumulate all of the cases' neighbor weights instead of @@ -715,7 +725,8 @@ def train( train_weights_only=train_weights_only, validate=validate, ) - self._needs_analyze = True if status.get('status') == 'analyze' else False + self._needs_analyze = status.get('needs_analyze', False) + self._needs_data_reduction = status.get('needs_data_reduction', False) else: raise AssertionError("Client must have the 'train' method.")