Skip to content

Commit

Permalink
Adjust base train() return values
Browse files Browse the repository at this point in the history
  • Loading branch information
apbassett committed Nov 21, 2024
1 parent 5248f96 commit 77baf92
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 16 deletions.
22 changes: 11 additions & 11 deletions howso/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,12 +553,13 @@ def train( # noqa: C901
in cases is applied in order to each of the cases in the series.
skip_auto_analyze : bool, default False
When true, the Trainee will not auto-analyze when appropriate.
Instead, the return dict will have a `status` key set to
"analyze" if an analyze is needed.
Instead, the return dict will have a "needs_analyze" flag if an
`analyze` is needed.
skip_reduce_data : bool, default False
When true, the Trainee will not call `reduce_data` when
appropriate. Instead, the return dict will have a `status` key set
to "reduce_data" if a call to `reduce_data` is recommended.
appropriate. Instead, the return dict will have a
"needs_data_reduction" flag if a call to `reduce_data` is
recommended.
train_weights_only : bool, default False
When true, and accumulate_weight_feature is provided,
will accumulate all of the cases' neighbor weights instead of
Expand All @@ -571,8 +572,9 @@ def train( # noqa: C901
Returns
-------
dict
A dict containing `status` and `details` if there are important
messages to share from the Engine. Otherwise, an empty dict.
A dict containing variable keys if there are important messages to
share from the Engine, such as 'needs_analyze` and
'needs_data_reduction'. Otherwise, an empty dict.
"""
trainee_id = self._resolve_trainee(trainee_id).id
feature_attributes = self.resolve_feature_attributes(trainee_id)
Expand Down Expand Up @@ -656,11 +658,9 @@ def train( # noqa: C901
})

if response and response.get('status') == 'analyze':
status['status'] = 'analyze'
status['details'] = 'An analyze is strongly recommended for this trainee.'
elif response and response.get('status') == 'reduce_data':
status['status'] = 'reduce_data'
status['details'] = 'Data reduction via `reduce_data` is recommended for this trainee.'
status['needs_analyze'] = True
if response and response.get('status') == 'reduce_data':
status['needs_data_reduction'] = True

if batch_scaler is None or gen_batch_size is None:
progress.update(batch_size)
Expand Down
21 changes: 16 additions & 5 deletions howso/engine/trainee.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def __init__(
self._custom_save_path = None
self._calculated_matrices = {}
self._needs_analyze: bool = False
self._needs_data_reduction: bool = False

# Allow passing project id or the project instance
if isinstance(project, BaseProject):
Expand Down Expand Up @@ -311,6 +312,18 @@ def needs_analyze(self) -> bool:
"""
return self._needs_analyze

@property
def needs_data_reduction(self) -> bool:
"""
The flag indicating if the Trainee needs its data reduced.
Returns
-------
bool
A flag indicating if a call to `reduce_data` is recommended.
"""
return self._needs_data_reduction

@property
def calculated_matrices(self) -> dict[str, DataFrame] | None:
"""
Expand Down Expand Up @@ -683,12 +696,9 @@ def train(
applied in order to each of the cases in the series.
skip_auto_analyze : bool, default False
When true, the Trainee will not auto-analyze when appropriate.
Instead, the return dict will have a `status` key set to
"analyze" if an analyze is needed.
skip_reduce_data : bool, default False
When true, the Trainee will not call `reduce_data` when
appropriate. Instead, the return dict will have a `status` key set
to "reduce_data" if a call to `reduce_data` is recommended.
appropriate.
train_weights_only: bool, default False
When true, and accumulate_weight_feature is provided,
will accumulate all of the cases' neighbor weights instead of
Expand All @@ -715,7 +725,8 @@ def train(
train_weights_only=train_weights_only,
validate=validate,
)
self._needs_analyze = True if status.get('status') == 'analyze' else False
self._needs_analyze = status.get('needs_analyze', False)
self._needs_data_reduction = status.get('needs_data_reduction', False)
else:
raise AssertionError("Client must have the 'train' method.")

Expand Down

0 comments on commit 77baf92

Please sign in to comment.