Adds max_num_cases, skip_reduce_data params
apbassett committed Nov 21, 2024
1 parent cd77b90 commit bfc4a4e
Showing 2 changed files with 37 additions and 11 deletions.
33 changes: 25 additions & 8 deletions howso/client/base.py
@@ -491,9 +491,10 @@ def train( # noqa: C901
progress_callback: t.Optional[Callable] = None,
series: t.Optional[str] = None,
skip_auto_analyze: bool = False,
skip_reduce_data: bool = False,
train_weights_only: bool = False,
validate: bool = True,
):
) -> dict:
"""
Train one or more cases into a Trainee.
@@ -552,7 +553,12 @@ def train( # noqa: C901
in cases is applied in order to each of the cases in the series.
skip_auto_analyze : bool, default False
When true, the Trainee will not auto-analyze when appropriate.
Instead, the boolean response will be True if an analyze is needed.
Instead, the return dict will have a `status` key set to
"analyze" if an analyze is needed.
skip_reduce_data : bool, default False
When true, the Trainee will not call `reduce_data` when appropriate.
Instead, the return dict will have a `status` key set to
"reduce_data" if a call to `reduce_data` is recommended.
train_weights_only : bool, default False
When true, and accumulate_weight_feature is provided,
will accumulate all of the cases' neighbor weights instead of
@@ -564,9 +570,9 @@ def train( # noqa: C901
Returns
-------
bool
Flag indicating if the Trainee needs to analyze. Only true if
auto-analyze is enabled and the conditions are met.
dict
A dict containing `status` and `details` if there are important
messages to share from the Engine. Otherwise, an empty dict.
"""
trainee_id = self._resolve_trainee(trainee_id).id
feature_attributes = self.resolve_feature_attributes(trainee_id)
@@ -613,7 +619,7 @@ def train( # noqa: C901
features = internals.get_features_from_data(cases)
serialized_cases = serialize_cases(cases, features, feature_attributes, warn=True) or []

needs_analyze = False
status = {}

if self.configuration.verbose:
print(f'Training session(s) on Trainee with id: {trainee_id}')
@@ -645,10 +651,17 @@ def train( # noqa: C901
"series": series,
"session": self.active_session.id,
"skip_auto_analyze": skip_auto_analyze,
"skip_reduce_data": skip_reduce_data,
"train_weights_only": train_weights_only,
})

if response and response.get('status') == 'analyze':
needs_analyze = True
status['status'] = 'analyze'
status['details'] = 'An analyze is strongly recommended for this trainee.'
elif response and response.get('status') == 'reduce_data':
status['status'] = 'reduce_data'
status['details'] = 'Data reduction via `reduce_data` is recommended for this trainee.'

if batch_scaler is None or gen_batch_size is None:
progress.update(batch_size)
else:
@@ -663,7 +676,7 @@ def train( # noqa: C901
self._store_session(trainee_id, self.active_session)
self._auto_persist_trainee(trainee_id)

return needs_analyze
return status

def impute(
self,
@@ -4064,6 +4077,7 @@ def set_auto_ablation_params(
delta_threshold_map: AblationThresholdMap = None,
exact_prediction_features: t.Optional[Collection[str]] = None,
min_num_cases: int = 1_000,
max_num_cases: int = 500_000,
reduce_data_influence_weight_entropy_threshold: float = 0.6,
rel_threshold_map: AblationThresholdMap = None,
relative_prediction_threshold_map: t.Optional[Mapping[str, float]] = None,
@@ -4099,6 +4113,8 @@ def set_auto_ablation_params(
to recompute influence weight entropy.
min_num_cases : int, default 1,000
The threshold of the minimum number of cases at which the model should auto-ablate.
max_num_cases : int, default 500,000
The threshold of the maximum number of cases at which the model should auto-reduce.
exact_prediction_features : Optional[List[str]], optional
For each of the features specified, will ablate a case if the prediction matches exactly.
residual_prediction_features : Optional[List[str]], optional
@@ -4148,6 +4164,7 @@ def set_auto_ablation_params(
delta_threshold_map=delta_threshold_map,
exact_prediction_features=exact_prediction_features,
min_num_cases=min_num_cases,
max_num_cases=max_num_cases,
reduce_data_influence_weight_entropy_threshold=reduce_data_influence_weight_entropy_threshold,
rel_threshold_map=rel_threshold_map,
relative_prediction_threshold_map=relative_prediction_threshold_map,
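With these base-client changes, `train` reports recommended follow-up work through its return dict rather than a bare boolean. Below is a minimal sketch of how a caller might interpret that dict; it relies only on the `status` and `details` keys documented above, and the `handle_train_status` helper is a hypothetical illustration, not part of this commit or the Howso API.

```python
# Hedged sketch: interpreting the dict now returned by the base client's train().
# Only the "status"/"details" keys come from this commit; the helper name and the
# literal dicts below are hypothetical illustrations, not part of the Howso API.
from typing import Optional


def handle_train_status(status: dict) -> Optional[str]:
    """Return the recommended follow-up action ("analyze" or "reduce_data"), if any."""
    action = status.get("status")
    if action in ("analyze", "reduce_data"):
        # The Engine's human-readable explanation, e.g.
        # "An analyze is strongly recommended for this trainee."
        print(status.get("details", ""))
        return action
    return None


if __name__ == "__main__":
    # A dict shaped like the one built in AbstractHowsoClient.train() above.
    result = {
        "status": "reduce_data",
        "details": "Data reduction via `reduce_data` is recommended for this trainee.",
    }
    assert handle_train_status(result) == "reduce_data"
    assert handle_train_status({}) is None  # empty dict: no follow-up needed
```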
15 changes: 12 additions & 3 deletions howso/engine/trainee.py
@@ -628,6 +628,7 @@ def train(
progress_callback: t.Optional[Callable] = None,
series: t.Optional[str] = None,
skip_auto_analyze: bool = False,
skip_reduce_data: bool = False,
train_weights_only: bool = False,
validate: bool = True,
):
@@ -684,6 +685,10 @@ def train(
When true, the Trainee will not auto-analyze when appropriate.
Instead, the 'needs_analyze' property of the Trainee will be
updated.
skip_reduce_data : bool, default False
When true, the Trainee will not call `reduce_data` when appropriate.
Instead, the underlying client's `train` call reports that a call to
`reduce_data` is recommended.
train_weights_only: bool, default False
When true, and accumulate_weight_feature is provided,
will accumulate all of the cases' neighbor weights instead of
@@ -694,8 +699,7 @@ def train(
the data and the features dictionary.
"""
if isinstance(self.client, AbstractHowsoClient):
self._needs_analyze = False
needs_analyze = self.client.train(
status = self.client.train(
trainee_id=self.id,
accumulate_weight_feature=accumulate_weight_feature,
batch_size=batch_size,
@@ -707,10 +711,11 @@ def train(
progress_callback=progress_callback,
series=series,
skip_auto_analyze=skip_auto_analyze,
skip_reduce_data=skip_reduce_data,
train_weights_only=train_weights_only,
validate=validate,
)
self._needs_analyze = needs_analyze
self._needs_analyze = status.get('status') == 'analyze'
else:
raise AssertionError("Client must have the 'train' method.")

@@ -755,6 +760,7 @@ def set_auto_ablation_params(
delta_threshold_map: AblationThresholdMap = None,
exact_prediction_features: t.Optional[Collection[str]] = None,
min_num_cases: int = 1_000,
max_num_cases: int = 500_000,
reduce_data_influence_weight_entropy_threshold: float = 0.6,
rel_threshold_map: AblationThresholdMap = None,
relative_prediction_threshold_map: t.Optional[Mapping[str, float]] = None,
@@ -788,6 +794,8 @@ def set_auto_ablation_params(
to recompute influence weight entropy.
min_num_cases : int, default 1,000
The threshold of the minimum number of cases at which the model should auto-ablate.
max_num_cases : int, default 500,000
The threshold of the maximum number of cases at which the model should auto-reduce.
exact_prediction_features : Collection of str, optional
For each of the features specified, will ablate a case if the prediction matches exactly.
residual_prediction_features : Collection of str, optional
@@ -838,6 +846,7 @@ def set_auto_ablation_params(
delta_threshold_map=delta_threshold_map,
exact_prediction_features=exact_prediction_features,
min_num_cases=min_num_cases,
max_num_cases=max_num_cases,
reduce_data_influence_weight_entropy_threshold=reduce_data_influence_weight_entropy_threshold,
rel_threshold_map=rel_threshold_map,
relative_prediction_threshold_map=relative_prediction_threshold_map,
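At the engine level, the new arguments slot into the usual Trainee workflow. The sketch below is illustrative only: the `Trainee` construction, the `infer_feature_attributes` helper, and the DataFrame-based `train` call are assumptions about the broader Howso Engine API rather than part of this commit; only the `min_num_cases`/`max_num_cases` thresholds and the `skip_auto_analyze`/`skip_reduce_data` flags come from the diff itself.

```python
# Hedged sketch of engine-level usage, assuming the broader Howso Engine API
# (Trainee, infer_feature_attributes, pandas input); only the new keyword
# arguments shown in this diff are taken from the commit.
import pandas as pd

from howso.engine import Trainee
from howso.utilities import infer_feature_attributes

df = pd.DataFrame({"x": [1, 2, 3, 4], "y": [0.5, 1.5, 2.5, 3.5]})
features = infer_feature_attributes(df)

trainee = Trainee(features=features)

# New in this commit: max_num_cases caps the model size at which auto-reduce
# should kick in, alongside the existing min_num_cases ablation threshold.
# The values here are the documented defaults, passed explicitly for illustration.
trainee.set_auto_ablation_params(
    min_num_cases=1_000,
    max_num_cases=500_000,
)

# Defer both auto-analyze and automatic reduce_data; per the docstrings above,
# the Trainee's needs_analyze property is updated from the client's status dict.
trainee.train(
    df,
    skip_auto_analyze=True,
    skip_reduce_data=True,
)
```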
