Skip to content

Commit

Permalink
20936: Updates set_auto_ablation_params, MAJOR (#256)
Browse files Browse the repository at this point in the history
  • Loading branch information
howsoRes authored Aug 7, 2024
1 parent 592730c commit 889f9a4
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 17 deletions.
20 changes: 10 additions & 10 deletions howso/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3697,7 +3697,6 @@ def set_auto_analyze_params(
analyze_threshold: t.Optional[int] = None,
*,
analysis_sub_model_size: t.Optional[int] = None,
auto_analyze_limit_size: t.Optional[int] = None,
analyze_growth_factor: t.Optional[float] = None,
action_features: t.Optional[Collection[str]] = None,
bypass_calculate_feature_residuals: t.Optional[bool] = None,
Expand Down Expand Up @@ -3730,9 +3729,6 @@ def set_auto_analyze_params(
analyze_threshold : int, optional
The threshold for the number of cases at which the model should be
re-analyzed.
auto_analyze_limit_size : int, optional
The size of the model at which to stop doing auto-analysis.
Value of 0 means no limit.
analyze_growth_factor : float, optional
The factor by which to increase the analyze threshold every time
the model grows to the current threshold size.
Expand Down Expand Up @@ -3799,7 +3795,6 @@ def set_auto_analyze_params(
'auto_optimize_enabled': 'auto_analyze_enabled',
'optimize_threshold': 'analyze_threshold',
'optimize_growth_factor': 'analyze_growth_factor',
'auto_optimize_limit_size': 'auto_analyze_limit_size',
}
analyze_deprecated_params = {
'bypass_hyperparameter_optimization': 'bypass_hyperparameter_analysis',
Expand All @@ -3818,8 +3813,6 @@ def set_auto_analyze_params(
analyze_threshold = kwargs[old_param]
elif old_param == 'optimize_growth_factor':
analyze_growth_factor = kwargs[old_param]
elif old_param == 'auto_optimize_limit_size':
auto_analyze_limit_size = kwargs[old_param]

del kwargs[old_param]
warnings.warn(
Expand Down Expand Up @@ -3859,7 +3852,6 @@ def set_auto_analyze_params(
self.execute(trainee_id, "set_auto_analyze_params", {
"auto_analyze_enabled": auto_analyze_enabled,
"analyze_threshold": analyze_threshold,
"auto_analyze_limit_size": auto_analyze_limit_size,
"analyze_growth_factor": analyze_growth_factor,
"action_features": action_features,
"context_features": context_features,
Expand Down Expand Up @@ -3904,7 +3896,9 @@ def set_auto_ablation_params(
trainee_id: str,
auto_ablation_enabled: bool = False,
*,
ablated_cases_distribution_batch_size: int = 100,
auto_ablation_weight_feature: str = ".case_weight",
batch_size: int = 2_000,
conviction_lower_threshold: t.Optional[float] = None,
conviction_upper_threshold: t.Optional[float] = None,
exact_prediction_features: t.Optional[Collection[str]] = None,
Expand Down Expand Up @@ -3932,8 +3926,13 @@ def set_auto_ablation_params(
The ID of the Trainee to set auto ablation parameters for.
auto_ablation_enabled : bool, default False
When True, the :meth:`train` method will ablate cases that meet the set criteria.
ablated_cases_distribution_batch_size : int, default 100
Number of cases in a batch to distribute ablated cases' influence weights.
auto_ablation_weight_feature : str, default ".case_weight"
The weight feature that should be accumulated to when cases are ablated.
batch_size : int, default 2,000
Number of cases in a batch to consider for ablation prior to training and
to recompute influence weight entropy.
minimum_model_size : int, default 1,000
The threshold of the minimum number of cases at which the model should auto-ablate.
influence_weight_entropy_threshold : float, default 0.6
Expand All @@ -3956,8 +3955,10 @@ def set_auto_ablation_params(
"""
trainee_id = self._resolve_trainee(trainee_id).id
params = dict(
ablated_cases_distribution_batch_size=ablated_cases_distribution_batch_size,
auto_ablation_enabled=auto_ablation_enabled,
auto_ablation_weight_feature=auto_ablation_weight_feature,
batch_size=batch_size,
minimum_model_size=minimum_model_size,
influence_weight_entropy_threshold=influence_weight_entropy_threshold,
exact_prediction_features=exact_prediction_features,
Expand Down Expand Up @@ -4869,8 +4870,7 @@ def set_params(self, trainee_id: str, params: Mapping):
deprecated_params = {
'auto_optimize_enabled': 'auto_analyze_enabled',
'optimize_threshold': 'analyze_threshold',
'optimize_growth_factor': 'analyze_growth_factor',
'auto_optimize_limit_size': 'auto_analyze_limit_size',
'optimize_growth_factor': 'analyze_growth_factor'
}

# replace any old params with new params and remove old param
Expand Down
17 changes: 10 additions & 7 deletions howso/engine/trainee.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,9 @@ def set_auto_ablation_params(
self,
auto_ablation_enabled: bool = False,
*,
ablated_cases_distribution_batch_size: int = 100,
auto_ablation_weight_feature: str = ".case_weight",
batch_size: int = 2_000,
conviction_lower_threshold: t.Optional[float] = None,
conviction_upper_threshold: t.Optional[float] = None,
exact_prediction_features: t.Optional[Collection[str]] = None,
Expand All @@ -728,8 +730,13 @@ def set_auto_ablation_params(
----------
auto_ablation_enabled : bool, default False
When True, the :meth:`train` method will ablate cases that meet the set criteria.
ablated_cases_distribution_batch_size : int, default 100
Number of cases in a batch to distribute ablated cases' influence weights.
auto_ablation_weight_feature : str, default ".case_weight"
The weight feature that should be accumulated to when cases are ablated.
batch_size : int, default 2,000
Number of cases in a batch to consider for ablation prior to training and
to recompute influence weight entropy.
minimum_model_size : int, default 1,000
The threshold for the minimum number of cases at which the model should auto-ablate.
influence_weight_entropy_threshold : float, default 0.6
Expand All @@ -753,8 +760,10 @@ def set_auto_ablation_params(
if isinstance(self.client, AbstractHowsoClient):
self.client.set_auto_ablation_params(
trainee_id=self.id,
ablated_cases_distribution_batch_size=ablated_cases_distribution_batch_size,
auto_ablation_enabled=auto_ablation_enabled,
auto_ablation_weight_feature=auto_ablation_weight_feature,
batch_size=batch_size,
minimum_model_size=minimum_model_size,
influence_weight_entropy_threshold=influence_weight_entropy_threshold,
exact_prediction_features=exact_prediction_features,
Expand Down Expand Up @@ -826,7 +835,6 @@ def set_auto_analyze_params(
auto_analyze_enabled: bool = False,
analyze_threshold: t.Optional[int] = None,
*,
auto_analyze_limit_size: t.Optional[int] = None,
analyze_growth_factor: t.Optional[float] = None,
**kwargs,
) -> None:
Expand All @@ -847,9 +855,6 @@ def set_auto_analyze_params(
analyze_threshold : int, optional
The threshold for the number of cases at which the model should be
re-analyzed.
auto_analyze_limit_size : int, optional
The size of the model at which to stop doing auto-analysis. Value of
0 means no limit.
analyze_growth_factor : float, optional
The factor by which to increase the analysis threshold every
time the model grows to the current threshold size.
Expand All @@ -860,7 +865,6 @@ def set_auto_analyze_params(
self.client.set_auto_analyze_params(
trainee_id=self.id,
auto_analyze_enabled=auto_analyze_enabled,
auto_analyze_limit_size=auto_analyze_limit_size,
analyze_growth_factor=analyze_growth_factor,
analyze_threshold=analyze_threshold,
**kwargs,
Expand Down Expand Up @@ -3163,8 +3167,7 @@ def set_params(self, params: Mapping[str, t.Any]):
},
"auto_analyze_enabled": False,
"analyze_threshold": 100,
"analyze_growth_factor": 7.389,
"auto_analyze_limit_size": 100000
"analyze_growth_factor": 7.389
}
"""
if isinstance(self.client, AbstractHowsoClient):
Expand Down

0 comments on commit 889f9a4

Please sign in to comment.