From b24812ff74d1353a2d56d3cffb86952298836f04 Mon Sep 17 00:00:00 2001 From: Aaron Defazio Date: Thu, 5 Sep 2024 15:20:04 +0000 Subject: [PATCH 01/37] BN Fixes --- algorithmic_efficiency/pytorch_utils.py | 14 ++++++++------ .../librispeech_pytorch/models.py | 11 ++++++----- .../librispeech_pytorch/models.py | 10 +++++----- 3 files changed, 19 insertions(+), 16 deletions(-) diff --git a/algorithmic_efficiency/pytorch_utils.py b/algorithmic_efficiency/pytorch_utils.py index 4f6c254bd..2e5828912 100644 --- a/algorithmic_efficiency/pytorch_utils.py +++ b/algorithmic_efficiency/pytorch_utils.py @@ -57,7 +57,6 @@ def sync_ddp_time(time: float, device: torch.device) -> float: dist.all_reduce(time_tensor, op=dist.ReduceOp.MAX) return time_tensor.item() - def update_batch_norm_fn(module: spec.ParameterContainer, update_batch_norm: bool) -> None: bn_layers = ( @@ -67,10 +66,13 @@ def update_batch_norm_fn(module: spec.ParameterContainer, ) if isinstance(module, bn_layers): if not update_batch_norm: - module.eval() - module.momentum_backup = module.momentum + if not hasattr(module, 'momentum_backup'): + module.momentum_backup = module.momentum + # module.momentum can be float or torch.Tensor. - module.momentum = 0. * module.momentum_backup + if torch.is_tensor(module.momentum_backup): + module.momentum = torch.zeros_like(module.momentum_backup) + else: + module.momentum = 0.0 elif hasattr(module, 'momentum_backup'): - module.momentum = module.momentum_backup - module.track_running_stats = update_batch_norm + module.momentum = module.momentum_backup \ No newline at end of file diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py index 502cb093e..cab73df4a 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py @@ -40,7 +40,7 @@ class ConformerConfig: time_masks_per_frame: float = 0.0 use_dynamic_time_mask_max_frames: bool = True input_dropout_rate: float = 0.1 - batch_norm_momentum: float = 0.999 + batch_norm_momentum: float = 1 - 0.999 batch_norm_epsilon: float = 0.001 use_specaug: bool = True attention_temperature: float = 1.0 @@ -369,10 +369,11 @@ def forward(self, inputs, input_paddings): mean = (masked_inp).sum(dim=(0, 1)) / count var = (torch.square(masked_inp - mean) * mask).sum(dim=(0, 1)) / count - self.running_mean = self.momentum * self.running_mean + ( - 1 - self.momentum) * mean.detach() - self.running_var = self.momentum * self.running_var + ( - 1 - self.momentum) * var.detach() + self.running_mean = (1 - self.momentum) * self.running_mean + ( + self.momentum) * mean.detach() + self.running_var = (1 - self.momentum) * self.running_var + ( + self.momentum) * var.detach() + else: mean = self.running_mean var = self.running_var diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py index a5ee3fa0a..bdf556f1c 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py @@ -36,7 +36,7 @@ class DeepspeechConfig: time_mask_max_ratio: float = 0.05 time_masks_per_frame: float = 0.0 use_dynamic_time_mask_max_frames: bool = True - batch_norm_momentum: float = 0.999 + 
batch_norm_momentum: float = 1 - 0.999 batch_norm_epsilon: float = 0.001 # If None, defaults to 0.1. input_dropout_rate: Optional[float] = 0.1 @@ -264,10 +264,10 @@ def forward(self, inputs, input_paddings): sum_ = dist_nn.all_reduce(sum_) var = sum_ / count - self.running_mean = self.momentum * self.running_mean + ( - 1 - self.momentum) * mean.detach() - self.running_var = self.momentum * self.running_var + ( - 1 - self.momentum) * var.detach() + self.running_mean = (1 - self.momentum) * self.running_mean + ( + self.momentum) * mean.detach() + self.running_var = (1 - self.momentum) * self.running_var + ( + self.momentum) * var.detach() else: mean = self.running_mean var = self.running_var From e6c2106c2460d0149235dd4eccfd4017b0952734 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 12 Sep 2024 15:30:15 +0200 Subject: [PATCH 02/37] added prepare_for_eval, eval only if is_time_remaining --- algorithmic_efficiency/spec.py | 14 +- submission_runner.py | 203 ++++++++++++++++------------- submissions/template/submission.py | 20 +++ 3 files changed, 149 insertions(+), 88 deletions(-) diff --git a/algorithmic_efficiency/spec.py b/algorithmic_efficiency/spec.py index 285983957..792093a2e 100644 --- a/algorithmic_efficiency/spec.py +++ b/algorithmic_efficiency/spec.py @@ -406,7 +406,19 @@ def init_optimizer_state(workload: Workload, RandomState ], UpdateReturn] - +PrepareForEvalFn = Callable[[ + Workload, + ParameterContainer, + ParameterTypeTree, + ModelAuxiliaryState, + Hyperparameters, + LossType, + OptimizerState, + List[Tuple[int, float]], + int, + RandomState +], + UpdateReturn] # Each call to this function is considered a "step". # Can raise a TrainingCompleteError if it believes it has achieved the goal and diff --git a/submission_runner.py b/submission_runner.py index 551173bf5..5df1f05ff 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -200,6 +200,7 @@ def train_once( init_optimizer_state: spec.InitOptimizerFn, update_params: spec.UpdateParamsFn, data_selection: spec.DataSelectionFn, + prepare_for_eval: spec.PrepareForEvalFn, hyperparameters: Optional[spec.Hyperparameters], rng_seed: int, rng: spec.RandomState, @@ -335,7 +336,7 @@ def train_once( not train_state['training_complete']: step_rng = prng.fold_in(rng, global_step) - data_select_rng, update_rng, eval_rng = prng.split(step_rng, 3) + data_select_rng, update_rng, prep_eval_rng, eval_rng = prng.split(step_rng, 4) with profiler.profile('Data selection'): batch = data_selection(workload, @@ -370,101 +371,128 @@ def train_once( train_state['accumulated_submission_time'] += ( train_step_end_time - train_state['last_step_end_time']) - # Use 3x the runtime budget for the self-tuning ruleset. - max_allowed_runtime_sec = ( - workload.max_allowed_runtime_sec if FLAGS.tuning_ruleset == 'external' - else 3 * workload.max_allowed_runtime_sec) - train_state['is_time_remaining'] = ( - train_state['accumulated_submission_time'] < max_allowed_runtime_sec) + # Check if submission is eligible for an untimed eval. if ((train_step_end_time - train_state['last_eval_time']) >= workload.eval_period_time_sec or train_state['training_complete']): - with profiler.profile('Evaluation'): + + # Prepare for evaluation (timed). + with profiler.profile('Prepare for eval'): del batch - _reset_cuda_mem() - - try: - eval_start_time = get_time() - latest_eval_result = workload.eval_model(global_eval_batch_size, - model_params, - model_state, - eval_rng, - data_dir, - imagenet_v2_data_dir, - global_step) - # Check if targets reached. 
- # Note that this is one of the stopping conditions for the length of - # a training run. To score the run we only consider the time - # to validation target retrospectively. - train_state['validation_goal_reached'] = ( - workload.has_reached_validation_target(latest_eval_result) or - train_state['validation_goal_reached']) - train_state['test_goal_reached'] = ( - workload.has_reached_test_target(latest_eval_result) or - train_state['test_goal_reached']) - goals_reached = ( - train_state['validation_goal_reached'] and - train_state['test_goal_reached']) - # Save last eval time. - eval_end_time = get_time() - train_state['last_eval_time'] = eval_end_time - - # Accumulate eval time. - train_state[ - 'accumulated_eval_time'] += eval_end_time - eval_start_time - - # Add times to eval results for logging. - latest_eval_result['score'] = ( - train_state['accumulated_submission_time']) - latest_eval_result[ - 'total_duration'] = eval_end_time - global_start_time - latest_eval_result['accumulated_submission_time'] = train_state[ - 'accumulated_submission_time'] - latest_eval_result['accumulated_eval_time'] = train_state[ - 'accumulated_eval_time'] - latest_eval_result['accumulated_logging_time'] = train_state[ - 'accumulated_logging_time'] - time_since_start = latest_eval_result['total_duration'] - logging.info(f'Time since start: {time_since_start:.2f}s, ' - f'\tStep: {global_step}, \t{latest_eval_result}') - eval_results.append((global_step, latest_eval_result)) - - logging_start_time = get_time() - - if log_dir is not None and RANK == 0: - metrics_logger.append_scalar_metrics( - latest_eval_result, - global_step=global_step, - preemption_count=preemption_count, - is_eval=True, - ) - if save_checkpoints: - checkpoint_utils.save_checkpoint( - framework=FLAGS.framework, - optimizer_state=optimizer_state, - model_params=model_params, - model_state=model_state, - train_state=train_state, - eval_results=eval_results, - global_step=global_step, - preemption_count=preemption_count, - checkpoint_dir=log_dir, - save_intermediate_checkpoints=FLAGS - .save_intermediate_checkpoints) + prepare_for_eval_start_time = get_time() + optimizer_state, model_params, model_state = prepare_for_eval( + workload=workload, + current_param_container=model_params, + current_params_types=workload.model_params_types, + model_state=model_state, + hyperparameters=hyperparameters, + loss_type=workload.loss_type, + optimizer_state=optimizer_state, + eval_results=eval_results, + global_step=global_step, + rng=prep_eval_rng) + prepare_for_eval_end_time = get_time() + + # Update sumbission time. + train_state['accumulated_submission_time'] += ( + prepare_for_eval_end_time - prepare_for_eval_start_time) + + # Check if time is remaining, + # use 3x the runtime budget for the self-tuning ruleset. + max_allowed_runtime_sec = ( + workload.max_allowed_runtime_sec if FLAGS.tuning_ruleset == 'external' + else 3 * workload.max_allowed_runtime_sec) + train_state['is_time_remaining'] = ( + train_state['accumulated_submission_time'] < max_allowed_runtime_sec) - logging_end_time = get_time() - train_state['accumulated_logging_time'] += ( - logging_end_time - logging_start_time) + # Eval if time is remaining (untimed). 
+ if train_state['is_time_remaining']: + with profiler.profile('Evaluation'): _reset_cuda_mem() - except RuntimeError as e: - logging.exception(f'Eval step {global_step} error.\n') - if 'out of memory' in str(e): - logging.warning('Error: GPU out of memory during eval during step ' - f'{global_step}, error : {str(e)}.') + try: + eval_start_time = get_time() + latest_eval_result = workload.eval_model(global_eval_batch_size, + model_params, + model_state, + eval_rng, + data_dir, + imagenet_v2_data_dir, + global_step) + # Check if targets reached. + # Note that this is one of the stopping conditions for the length of + # a training run. To score the run we only consider the time + # to validation target retrospectively. + train_state['validation_goal_reached'] = ( + workload.has_reached_validation_target(latest_eval_result) or + train_state['validation_goal_reached']) + train_state['test_goal_reached'] = ( + workload.has_reached_test_target(latest_eval_result) or + train_state['test_goal_reached']) + goals_reached = ( + train_state['validation_goal_reached'] and + train_state['test_goal_reached']) + # Save last eval time. + eval_end_time = get_time() + train_state['last_eval_time'] = eval_end_time + + # Accumulate eval time. + train_state[ + 'accumulated_eval_time'] += eval_end_time - eval_start_time + + # Add times to eval results for logging. + latest_eval_result['score'] = ( + train_state['accumulated_submission_time']) + latest_eval_result[ + 'total_duration'] = eval_end_time - global_start_time + latest_eval_result['accumulated_submission_time'] = train_state[ + 'accumulated_submission_time'] + latest_eval_result['accumulated_eval_time'] = train_state[ + 'accumulated_eval_time'] + latest_eval_result['accumulated_logging_time'] = train_state[ + 'accumulated_logging_time'] + time_since_start = latest_eval_result['total_duration'] + logging.info(f'Time since start: {time_since_start:.2f}s, ' + f'\tStep: {global_step}, \t{latest_eval_result}') + eval_results.append((global_step, latest_eval_result)) + + logging_start_time = get_time() + + if log_dir is not None and RANK == 0: + metrics_logger.append_scalar_metrics( + latest_eval_result, + global_step=global_step, + preemption_count=preemption_count, + is_eval=True, + ) + if save_checkpoints: + checkpoint_utils.save_checkpoint( + framework=FLAGS.framework, + optimizer_state=optimizer_state, + model_params=model_params, + model_state=model_state, + train_state=train_state, + eval_results=eval_results, + global_step=global_step, + preemption_count=preemption_count, + checkpoint_dir=log_dir, + save_intermediate_checkpoints=FLAGS + .save_intermediate_checkpoints) + + logging_end_time = get_time() + train_state['accumulated_logging_time'] += ( + logging_end_time - logging_start_time) + _reset_cuda_mem() + except RuntimeError as e: + logging.exception(f'Eval step {global_step} error.\n') + if 'out of memory' in str(e): + logging.warning('Error: GPU out of memory during eval during step ' + f'{global_step}, error : {str(e)}.') + _reset_cuda_mem() + train_state['last_step_end_time'] = get_time() metrics = {'eval_results': eval_results, 'global_step': global_step} @@ -518,6 +546,7 @@ def score_submission_on_workload(workload: spec.Workload, init_optimizer_state = submission_module.init_optimizer_state update_params = submission_module.update_params data_selection = submission_module.data_selection + prepare_for_eval = submission_module.prepare_for_eval try: global_batch_size = submission_module.get_batch_size(workload_name) except ValueError: @@ 
-589,7 +618,7 @@ def score_submission_on_workload(workload: spec.Workload, global_eval_batch_size, data_dir, imagenet_v2_data_dir, init_optimizer_state, - update_params, data_selection, + update_params, data_selection, prepare_for_eval, hyperparameters, rng_seed, rng, diff --git a/submissions/template/submission.py b/submissions/template/submission.py index 5ef195db5..848d8af44 100644 --- a/submissions/template/submission.py +++ b/submissions/template/submission.py @@ -42,6 +42,26 @@ def update_params(workload: spec.Workload, pass +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + # batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """ + Returns: + new_optimizer_state + new_params + new_model_state + """ + pass + + def get_batch_size(workload_name): """ Gets batch size for workload. From 8bad99d663f34ce4b6b6c4a2a40b828e19fc3a5b Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Sat, 14 Sep 2024 18:11:43 +0200 Subject: [PATCH 03/37] added prepare_for_eval to all submissions --- .../external_tuning/jax_nadamw_full_budget.py | 21 ++++++++++++++++ .../jax_nadamw_target_setting.py | 21 ++++++++++++++++ .../pytorch_nadamw_full_budget.py | 21 ++++++++++++++++ .../pytorch_nadamw_target_setting.py | 21 ++++++++++++++++ .../self_tuning/jax_nadamw_full_budget.py | 21 ++++++++++++++++ .../self_tuning/jax_nadamw_target_setting.py | 21 ++++++++++++++++ .../self_tuning/pytorch_nadamw_full_budget.py | 21 ++++++++++++++++ .../pytorch_nadamw_target_setting.py | 21 ++++++++++++++++ .../cifar/cifar_jax/submission.py | 25 +++++++++++++++++-- .../cifar/cifar_pytorch/submission.py | 21 ++++++++++++++++ .../mnist/mnist_jax/submission.py | 21 ++++++++++++++++ .../mnist/mnist_pytorch/submission.py | 21 ++++++++++++++++ .../adafactor/jax/submission.py | 21 ++++++++++++++++ .../adafactor/pytorch/submission.py | 21 ++++++++++++++++ .../paper_baselines/adamw/jax/submission.py | 21 ++++++++++++++++ .../adamw/pytorch/submission.py | 21 ++++++++++++++++ .../paper_baselines/lamb/jax/submission.py | 21 ++++++++++++++++ .../lamb/pytorch/submission.py | 21 ++++++++++++++++ .../momentum/jax/submission.py | 21 ++++++++++++++++ .../momentum/pytorch/submission.py | 21 ++++++++++++++++ .../paper_baselines/nadamw/jax/submission.py | 21 ++++++++++++++++ .../nadamw/pytorch/submission.py | 21 ++++++++++++++++ .../nesterov/jax/submission.py | 21 ++++++++++++++++ .../nesterov/pytorch/submission.py | 21 ++++++++++++++++ .../paper_baselines/sam/jax/submission.py | 21 ++++++++++++++++ .../paper_baselines/sam/pytorch/submission.py | 21 ++++++++++++++++ .../paper_baselines/shampoo/jax/submission.py | 21 ++++++++++++++++ .../jax_submission_base.py | 21 ++++++++++++++++ .../pytorch_submission_base.py | 21 ++++++++++++++++ 29 files changed, 611 insertions(+), 2 deletions(-) diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py index 98193f01f..5f203c5c6 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py @@ -299,6 +299,27 @@ def update_params(workload: spec.Workload, return (new_optimizer_state, 
opt_update_fn), new_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py index 66fdc4ebb..32f4e830e 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py @@ -299,6 +299,27 @@ def update_params(workload: spec.Workload, return (new_optimizer_state, opt_update_fn), new_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py index ebc49d428..ba56cd99f 100644 --- a/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py +++ b/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py @@ -301,6 +301,27 @@ def update_params(workload: spec.Workload, return (optimizer_state, current_param_container, new_model_state) +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. 
if workload_name == 'criteo1tb': diff --git a/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py index 524bc20af..e2c44d9c1 100644 --- a/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py +++ b/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py @@ -301,6 +301,27 @@ def update_params(workload: spec.Workload, return (optimizer_state, current_param_container, new_model_state) +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py index 4f53afb56..502b7e5b4 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py @@ -314,6 +314,27 @@ def update_params(workload: spec.Workload, return (new_optimizer_state, opt_update_fn), new_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. 
if workload_name == 'criteo1tb': diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py index 60a1f784d..8bc2eed95 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py @@ -314,6 +314,27 @@ def update_params(workload: spec.Workload, return (new_optimizer_state, opt_update_fn), new_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py index f8e87ec2a..bbf548ccb 100644 --- a/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py +++ b/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py @@ -316,6 +316,27 @@ def update_params(workload: spec.Workload, return (optimizer_state, current_param_container, new_model_state) +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. 
if workload_name == 'criteo1tb': diff --git a/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py index 1de26417f..992f769f3 100644 --- a/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py +++ b/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py @@ -316,6 +316,27 @@ def update_params(workload: spec.Workload, return (optimizer_state, current_param_container, new_model_state) +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py b/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py index 2971efe9a..b2256fc5a 100644 --- a/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py +++ b/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py @@ -108,8 +108,6 @@ def _loss_fn(params): return new_optimizer_state, updated_params, new_model_state -# Not allowed to update the model parameters, hyperparameters, global step, or -# optimzier state. def update_params(workload: spec.Workload, current_param_container: spec.ParameterContainer, current_params_types: spec.ParameterTypeTree, @@ -134,6 +132,29 @@ def update_params(workload: spec.Workload, return (new_optimizer_state, opt_update_fn), new_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + +# Not allowed to update the model parameters, hyperparameters, global step, or +# optimzier state. 
def data_selection(workload: spec.Workload, input_queue: Iterator[Dict[str, spec.Tensor]], optimizer_state: spec.OptimizerState, diff --git a/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py b/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py index 358c6bffc..b55c31afc 100644 --- a/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py +++ b/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py @@ -96,6 +96,27 @@ def update_params(workload: spec.Workload, return (optimizer_state, current_param_container, new_model_state) +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + # Not allowed to update the model parameters, hyperparameters, global step, or # optimzier state. def data_selection(workload: spec.Workload, diff --git a/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py b/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py index 896609d51..f09886215 100644 --- a/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py +++ b/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py @@ -106,6 +106,27 @@ def update_params(workload: spec.Workload, return (new_optimizer_state, opt_update_fn), updated_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + # Not allowed to update the model parameters, hyperparameters, global step, or # optimzier state. 
def data_selection(workload: spec.Workload, diff --git a/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py b/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py index f1601e606..8b5151c77 100644 --- a/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py +++ b/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py @@ -72,6 +72,27 @@ def update_params(workload: spec.Workload, return (optimizer_state, current_param_container, new_model_state) +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + # Not allowed to update the model parameters, hyperparameters, global step, or # optimzier state. def data_selection(workload: spec.Workload, diff --git a/reference_algorithms/paper_baselines/adafactor/jax/submission.py b/reference_algorithms/paper_baselines/adafactor/jax/submission.py index 2dd85c29b..ed2ee371f 100644 --- a/reference_algorithms/paper_baselines/adafactor/jax/submission.py +++ b/reference_algorithms/paper_baselines/adafactor/jax/submission.py @@ -157,6 +157,27 @@ def update_params(workload: spec.Workload, return (new_optimizer_state, opt_update_fn), new_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. 
if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py b/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py index e6fef17dc..5f6540020 100644 --- a/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py @@ -265,6 +265,27 @@ def update_params(workload: spec.Workload, return (optimizer_state, current_param_container, new_model_state) +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index 80a963600..5d2107ba6 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -157,6 +157,27 @@ def update_params(workload: spec.Workload, return (new_optimizer_state, opt_update_fn), new_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. 
if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py index 32353e5b4..2b42bb5a4 100644 --- a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py @@ -125,6 +125,27 @@ def update_params(workload: spec.Workload, return (optimizer_state, current_param_container, new_model_state) +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/lamb/jax/submission.py b/reference_algorithms/paper_baselines/lamb/jax/submission.py index 27d635ee9..e08d5b433 100644 --- a/reference_algorithms/paper_baselines/lamb/jax/submission.py +++ b/reference_algorithms/paper_baselines/lamb/jax/submission.py @@ -165,6 +165,27 @@ def update_params(workload: spec.Workload, return (new_optimizer_state, opt_update_fn), new_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. 
if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/lamb/pytorch/submission.py b/reference_algorithms/paper_baselines/lamb/pytorch/submission.py index 7d0d8763e..da5865087 100644 --- a/reference_algorithms/paper_baselines/lamb/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/lamb/pytorch/submission.py @@ -258,6 +258,27 @@ def update_params(workload: spec.Workload, return (optimizer_state, current_param_container, new_model_state) +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/momentum/jax/submission.py b/reference_algorithms/paper_baselines/momentum/jax/submission.py index cccb3c1b5..1ab362dd6 100644 --- a/reference_algorithms/paper_baselines/momentum/jax/submission.py +++ b/reference_algorithms/paper_baselines/momentum/jax/submission.py @@ -191,6 +191,27 @@ def update_params(workload: spec.Workload, return (new_optimizer_state, opt_update_fn), new_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. 
if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/momentum/pytorch/submission.py b/reference_algorithms/paper_baselines/momentum/pytorch/submission.py index ec5c0b31c..999321bd5 100644 --- a/reference_algorithms/paper_baselines/momentum/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/momentum/pytorch/submission.py @@ -144,6 +144,27 @@ def update_params(workload: spec.Workload, return (optimizer_state, current_param_container, new_model_state) +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/nadamw/jax/submission.py b/reference_algorithms/paper_baselines/nadamw/jax/submission.py index 98193f01f..5f203c5c6 100644 --- a/reference_algorithms/paper_baselines/nadamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/nadamw/jax/submission.py @@ -299,6 +299,27 @@ def update_params(workload: spec.Workload, return (new_optimizer_state, opt_update_fn), new_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. 
if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py b/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py index ebc49d428..ba56cd99f 100644 --- a/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py @@ -301,6 +301,27 @@ def update_params(workload: spec.Workload, return (optimizer_state, current_param_container, new_model_state) +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py index f3b0aeed4..20109a9e3 100644 --- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py @@ -191,6 +191,27 @@ def update_params(workload: spec.Workload, return (new_optimizer_state, opt_update_fn), new_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. 
if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py b/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py index fe9154934..b4b8b77af 100644 --- a/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py @@ -144,6 +144,27 @@ def update_params(workload: spec.Workload, return (optimizer_state, current_param_container, new_model_state) +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/sam/jax/submission.py b/reference_algorithms/paper_baselines/sam/jax/submission.py index 85b3d7441..9f12c4f3f 100644 --- a/reference_algorithms/paper_baselines/sam/jax/submission.py +++ b/reference_algorithms/paper_baselines/sam/jax/submission.py @@ -244,6 +244,27 @@ def update_params(workload: spec.Workload, return (new_optimizer_state, opt_update_fn), new_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. 
if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/sam/pytorch/submission.py b/reference_algorithms/paper_baselines/sam/pytorch/submission.py index 2cab75972..cf5e49f4f 100644 --- a/reference_algorithms/paper_baselines/sam/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/sam/pytorch/submission.py @@ -216,6 +216,27 @@ def _loss_fn(params, update_batch_norm=True): return (optimizer_state, current_param_container, new_model_state) +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/shampoo/jax/submission.py b/reference_algorithms/paper_baselines/shampoo/jax/submission.py index 9c6b66b7f..b596f0bdc 100644 --- a/reference_algorithms/paper_baselines/shampoo/jax/submission.py +++ b/reference_algorithms/paper_baselines/shampoo/jax/submission.py @@ -160,6 +160,27 @@ def update_params(workload: spec.Workload, return (new_optimizer_state, opt_update_fn), new_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. 
if workload_name == 'criteo1tb': diff --git a/reference_algorithms/target_setting_algorithms/jax_submission_base.py b/reference_algorithms/target_setting_algorithms/jax_submission_base.py index 2a641b520..31e8a8850 100644 --- a/reference_algorithms/target_setting_algorithms/jax_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/jax_submission_base.py @@ -109,3 +109,24 @@ def update_params(workload: spec.Workload, 'grad_norm': grad_norm[0], }, global_step) return (new_optimizer_state, opt_update_fn), new_params, new_model_state + + +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) diff --git a/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py b/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py index f9e40212b..549d2dc58 100644 --- a/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py @@ -89,3 +89,24 @@ def update_params(workload: spec.Workload, grad_norm.item()) return (optimizer_state, current_param_container, new_model_state) + + +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) From 1c7d51c0eb2cf64295f030b6ef0566bcd24b01cf Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Sun, 15 Sep 2024 11:08:48 +0200 Subject: [PATCH 04/37] fix formatting --- submission_runner.py | 25 ++++++++++++++----------- submissions/template/submission.py | 1 - 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index 5df1f05ff..a711be9ac 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -336,7 +336,8 @@ def train_once( not train_state['training_complete']: step_rng = prng.fold_in(rng, global_step) - data_select_rng, update_rng, prep_eval_rng, eval_rng = prng.split(step_rng, 4) + data_select_rng, update_rng, prep_eval_rng, eval_rng = \ + prng.split(step_rng, 4) with profiler.profile('Data selection'): batch = data_selection(workload, @@ -414,12 +415,12 @@ def train_once( try: eval_start_time = get_time() latest_eval_result = workload.eval_model(global_eval_batch_size, - model_params, - model_state, - eval_rng, - data_dir, - imagenet_v2_data_dir, - global_step) + model_params, + model_state, + eval_rng, + data_dir, + imagenet_v2_data_dir, + global_step) # Check if targets reached. 
# Note that this is one of the stopping conditions for the length of # a training run. To score the run we only consider the time @@ -454,7 +455,7 @@ def train_once( 'accumulated_logging_time'] time_since_start = latest_eval_result['total_duration'] logging.info(f'Time since start: {time_since_start:.2f}s, ' - f'\tStep: {global_step}, \t{latest_eval_result}') + f'\tStep: {global_step}, \t{latest_eval_result}') eval_results.append((global_step, latest_eval_result)) logging_start_time = get_time() @@ -489,8 +490,9 @@ def train_once( except RuntimeError as e: logging.exception(f'Eval step {global_step} error.\n') if 'out of memory' in str(e): - logging.warning('Error: GPU out of memory during eval during step ' - f'{global_step}, error : {str(e)}.') + logging.warning( + 'Error: GPU out of memory during eval during step ' + f'{global_step}, error : {str(e)}.') _reset_cuda_mem() train_state['last_step_end_time'] = get_time() @@ -618,7 +620,8 @@ def score_submission_on_workload(workload: spec.Workload, global_eval_batch_size, data_dir, imagenet_v2_data_dir, init_optimizer_state, - update_params, data_selection, prepare_for_eval, + update_params, data_selection, + prepare_for_eval, hyperparameters, rng_seed, rng, diff --git a/submissions/template/submission.py b/submissions/template/submission.py index 848d8af44..445e1f7cd 100644 --- a/submissions/template/submission.py +++ b/submissions/template/submission.py @@ -47,7 +47,6 @@ def prepare_for_eval(workload: spec.Workload, current_params_types: spec.ParameterTypeTree, model_state: spec.ModelAuxiliaryState, hyperparameters: spec.Hyperparameters, - # batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, eval_results: List[Tuple[int, float]], From 21a580b56c1f19cff11b13b62d4fceb1dc003f29 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Sun, 15 Sep 2024 11:18:03 +0200 Subject: [PATCH 05/37] fix formatting --- algorithmic_efficiency/spec.py | 3 ++- submission_runner.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/algorithmic_efficiency/spec.py b/algorithmic_efficiency/spec.py index 792093a2e..25bd7b6d0 100644 --- a/algorithmic_efficiency/spec.py +++ b/algorithmic_efficiency/spec.py @@ -418,7 +418,8 @@ def init_optimizer_state(workload: Workload, int, RandomState ], - UpdateReturn] + UpdateReturn] + # Each call to this function is considered a "step". # Can raise a TrainingCompleteError if it believes it has achieved the goal and diff --git a/submission_runner.py b/submission_runner.py index a711be9ac..632cb450b 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -620,7 +620,7 @@ def score_submission_on_workload(workload: spec.Workload, global_eval_batch_size, data_dir, imagenet_v2_data_dir, init_optimizer_state, - update_params, data_selection, + update_params, data_selection, prepare_for_eval, hyperparameters, rng_seed, From 420b583f8bd60ca13b6b7cf9a7d0b8211d5c904b Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Sun, 15 Sep 2024 12:48:13 +0200 Subject: [PATCH 06/37] updated documentation --- DOCUMENTATION.md | 33 ++++++++++++++++++++++++++++++--- 1 file changed, 30 insertions(+), 3 deletions(-) diff --git a/DOCUMENTATION.md b/DOCUMENTATION.md index 607f47ead..586e03d8c 100644 --- a/DOCUMENTATION.md +++ b/DOCUMENTATION.md @@ -80,7 +80,7 @@ In principle, submissions are allowed to use the available hardware systems in a Submissions provide a [per-workload batch size](#batch-size-getter) to use. 
Specification of the batch size for each workload is necessary to avoid running out of memory for different workloads. Therefore, submitters can determine this batch size in advance and specify it as part of the submission. Submitters may also provide per-workload batch sizes for all [randomized workloads](#randomized-workloads). If no such batch size is provided for a randomized workload, by default, submissions will then use the batch size of the most similar [fixed workload](#fixed-workloads) (for example, if there is an ImageNet fixed workload and also a randomized workload with a similarly sized model on similarly sized images, the ImageNet batch size will be used for held-out workloads generated from this randomized workload). Note that submitters are *not* allowed to modify the *evaluation batch size*, which is set by the benchmarking codebase. However, you can file an issue if you believe that the evaluation batch size of a particular workload is set inappropriately. The working group will review this request and consider adjusting the evaluation batch size in the benchmarking codebase, thus affecting all submitters equally. -The **submission functions** are the *batch size getter*, *optimizer state initializer*, *variable update*, and *data selection functions*. The *fixed functions* are the *data augmentation/preprocessing*, *model initialization*, *forward pass*, and *loss function*. The trained model will be evaluated in a separate step that does not call any of the submitted code. +The **submission functions** are the *batch size getter*, *optimizer state initializer*, *variable update*, *prepare for evaluation function*, and *data selection functions*. The *fixed functions* are the *data augmentation/preprocessing*, *model initialization*, *forward pass*, and *loss function*. The trained model will be evaluated in a separate step that does not call any of the submitted code. ##### Fixed functions @@ -218,9 +218,35 @@ def update_params( - Cannot modify the given hyperparameters in a workload-conditional way (please see the [Valid submission](#valid-submissions) section). This rule is intended to prohibit circumventing the tuning rules by looking up a pre-tuned optimal set of hyperparameters for each workload. It is not intended to prohibit line searches and other similar techniques. - The fixed `init_model_fn` can optionally be called during training, for example, to reinitialize the model after a failed training effort. - Cannot replace the model parameters with pre-trained ones. -- This API supports Polyak averaging and similar methods that implement moving averages of model parameters. - Batch norm should work here because the `model_fn` will return updated batch norm moving averages when it is told to with `update_batch_norm`. + +###### Prepare for evaluation function + +```python +def prepare_for_eval( + workload: Workload, + current_param_container: ParameterContainer, + current_params_types: ParameterTypeTree, + model_state: ModelAuxiliaryState, + hyperparameters: Hyperparameters, + loss_type: LossType, + optimizer_state: OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: RandomState +) -> (updated_optimizer_state, updated_variables, updated_model_state) +``` + +- Arguments are the same as those of `update_params`, with the only exception of `batch`. +- This function is called when a submission is deemed eligible for an evaluation (see [Evaluation during training](#evaluation-during-training) section).
+ - The call to `prepare_for_eval` is timed and its runtime accumulates to the overall submission time. + - The returned model parameters are evaluated on the validation and test sets, provided that the accumulated submission time does not exceed the maximum runtime after this function call. +- This API supports Polyak averaging and similar methods that implement moving averages of model parameters. +- Allowed to update model state and model parameters. +- Allowed to update state for the optimizer. +- Cannot replace the model parameters with pre-trained ones. + ###### Data selection ```python def data_selection( @@ -250,7 +276,8 @@ def data_selection( In general, with noisy, non-deterministic training, evaluation frequency can affect training time measurements as more "bites of the apple" potentially allows the training code to exploit instability. We also want to discourage submissions from complicated and unrealistic logic that attempts to guess when training is close to complete and increases the evaluation rate, while not producing a well-sampled training curve at the start of training. Simply allowing submissions complete freedom over evaluation frequency encourages competitors to work to minimize the number of evaluations, which distracts from the primary goal of finding better training algorithms. -Submissions are eligible for an untimed eval every `eval_period` seconds, run as soon as the current call of `update_params` completes. Any additional evaluations performed by the submission code count against the runtime for scoring. The harness that runs the submission code will attempt to eval every `eval_period` seconds by checking between each submission step (call of `update_params`) whether it has been at least `eval_period` seconds since that last eval and, if so, pausing the clock and running an eval. This means that if calls to `update_params` typically take a lot more than `eval_period` seconds, such submissions will not receive as many untimed evals as a submission that had an `update_params` function that took less time. However, for appropriate settings of `eval_period`, we expect this to be quite rare. Submissions are always free to restructure their `update_params` code to split work into two subsequent steps to regain the potential benefits of these untimed model evaluations. For each workload, the `eval_period` will be set such that the total evaluation time is roughly between 10% and 20% of the total training time for the target-setting runs. +Submissions are eligible for an untimed eval every `eval_period` seconds. Before proceeding to evaluation, the submission can prepare the model through a call to `prepare_for_eval`, potentially modifying the model parameters and state as well as the optimizer state. Any additional evaluations performed by the submission code count against the runtime for scoring. +The harness that runs the submission code will attempt to eval every `eval_period` seconds by checking between each submission step (call of `update_params`) whether it has been at least `eval_period` seconds since that last eval; if so, the submission is given the opportunity to prepare for evaluation (through a timed call to `prepare_for_eval`). If the accumulated runtime does not exceed the maximum allowed runtime after the preparation step, the clock is paused, and the submission is evaluated.
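As a concrete illustration of what the new hook enables, a minimal Polyak-averaging `prepare_for_eval` could look like the sketch below. It assumes, hypothetically, that the submission keeps its optimizer state as a dict and that `update_params` has been maintaining an exponential moving average of the weights under an `ema_params` key; those averaged weights are then the ones handed to the untimed eval. Since the call itself is timed, any heavier preparation work done here still counts toward the submission's runtime.

```python
from typing import List, Tuple

from algorithmic_efficiency import spec


def prepare_for_eval(workload: spec.Workload,
                     current_param_container: spec.ParameterContainer,
                     current_params_types: spec.ParameterTypeTree,
                     model_state: spec.ModelAuxiliaryState,
                     hyperparameters: spec.Hyperparameters,
                     loss_type: spec.LossType,
                     optimizer_state: spec.OptimizerState,
                     eval_results: List[Tuple[int, float]],
                     global_step: int,
                     rng: spec.RandomState) -> spec.UpdateReturn:
  """Sketch only: swap in averaged weights for the untimed evaluation."""
  del workload, current_params_types, hyperparameters, loss_type
  del eval_results, global_step, rng
  # 'ema_params' is a hypothetical key: this assumes update_params stored an
  # exponential moving average of the model weights in a dict-valued
  # optimizer_state. Fall back to the current weights if it is absent.
  eval_params = optimizer_state.get('ema_params', current_param_container)
  return optimizer_state, eval_params, model_state
```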
This means that if calls to `update_params` typically take a lot more than `eval_period` seconds, such submissions will not receive as many untimed evals as a submission that had an `update_params` function that took less time. However, for appropriate settings of `eval_period`, we expect this to be quite rare. Submissions are always free to restructure their `update_params` code to split work into two subsequent steps to regain the potential benefits of these untimed model evaluations. For each workload, the `eval_period` will be set such that the total evaluation time is roughly between 10% and 20% of the total training time for the target-setting runs. #### Valid submissions From e09bbf594150dae74b186ee354daa23d3f29de25 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 3 Oct 2024 17:39:20 +0200 Subject: [PATCH 07/37] add `train_state` to all instances of `update_params', passing it by (shallow) copy in submission_runner --- DOCUMENTATION.md | 1 + algorithmic_efficiency/spec.py | 2 ++ .../external_tuning/jax_nadamw_full_budget.py | 2 ++ .../external_tuning/jax_nadamw_target_setting.py | 2 ++ .../external_tuning/pytorch_nadamw_full_budget.py | 4 +++- .../external_tuning/pytorch_nadamw_target_setting.py | 4 +++- .../self_tuning/jax_nadamw_full_budget.py | 2 ++ .../self_tuning/jax_nadamw_target_setting.py | 2 ++ .../self_tuning/pytorch_nadamw_full_budget.py | 4 +++- .../self_tuning/pytorch_nadamw_target_setting.py | 4 +++- .../development_algorithms/cifar/cifar_jax/submission.py | 4 +++- .../development_algorithms/cifar/cifar_pytorch/submission.py | 4 +++- .../development_algorithms/mnist/mnist_jax/submission.py | 4 +++- .../development_algorithms/mnist/mnist_pytorch/submission.py | 4 +++- .../paper_baselines/adafactor/jax/submission.py | 4 +++- .../paper_baselines/adafactor/pytorch/submission.py | 4 +++- reference_algorithms/paper_baselines/adamw/jax/submission.py | 4 +++- .../paper_baselines/adamw/pytorch/submission.py | 4 +++- reference_algorithms/paper_baselines/lamb/jax/submission.py | 4 +++- .../paper_baselines/lamb/pytorch/submission.py | 4 +++- .../paper_baselines/momentum/jax/submission.py | 4 +++- .../paper_baselines/momentum/pytorch/submission.py | 4 +++- reference_algorithms/paper_baselines/nadamw/jax/submission.py | 2 ++ .../paper_baselines/nadamw/pytorch/submission.py | 4 +++- .../paper_baselines/nesterov/jax/submission.py | 4 +++- .../paper_baselines/nesterov/pytorch/submission.py | 4 +++- reference_algorithms/paper_baselines/sam/jax/submission.py | 4 +++- .../paper_baselines/sam/pytorch/submission.py | 4 +++- .../paper_baselines/shampoo/jax/submission.py | 4 +++- .../target_setting_algorithms/jax_submission_base.py | 4 +++- .../target_setting_algorithms/pytorch_submission_base.py | 4 +++- submission_runner.py | 1 + submissions/template/submission.py | 3 ++- 33 files changed, 88 insertions(+), 25 deletions(-) diff --git a/DOCUMENTATION.md b/DOCUMENTATION.md index 607f47ead..8207691d6 100644 --- a/DOCUMENTATION.md +++ b/DOCUMENTATION.md @@ -199,6 +199,7 @@ def update_params( batch: Dict[str, Tensor], loss_type: LossType, optimizer_state: OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: RandomState diff --git a/algorithmic_efficiency/spec.py b/algorithmic_efficiency/spec.py index 285983957..7a16f0040 100644 --- a/algorithmic_efficiency/spec.py +++ b/algorithmic_efficiency/spec.py @@ -401,6 +401,7 @@ def init_optimizer_state(workload: Workload, Dict[str, Tensor], LossType, OptimizerState, + Dict[str, Any], 
List[Tuple[int, float]], int, RandomState @@ -422,6 +423,7 @@ def update_params(workload: Workload, batch: Dict[str, Tensor], loss_type: LossType, optimizer_state: OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: RandomState) -> UpdateReturn: diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py index 98193f01f..63cf25fe5 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py @@ -260,12 +260,14 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results optimizer_state, opt_update_fn = optimizer_state diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py index 66fdc4ebb..ab0ee82b1 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py @@ -260,12 +260,14 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results optimizer_state, opt_update_fn = optimizer_state diff --git a/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py index ebc49d428..c85cc6dd3 100644 --- a/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py +++ b/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py @@ -1,7 +1,7 @@ """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import Dict, Iterator, List, Tuple +from typing import Dict, Iterator, List, Tuple, Any from absl import logging import torch @@ -232,12 +232,14 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results current_model = current_param_container diff --git a/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py index 524bc20af..bb1278911 100644 --- a/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py +++ b/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py @@ -1,7 +1,7 @@ """Submission file for an NAdamW optimizer with warmup+cosine LR 
in PyTorch.""" import math -from typing import Dict, Iterator, List, Tuple +from typing import Dict, Iterator, List, Tuple, Any from absl import logging import torch @@ -232,12 +232,14 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results current_model = current_param_container diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py index 4f53afb56..f6ada3c8e 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py @@ -272,12 +272,14 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results del hyperparameters diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py index 60a1f784d..9c7f66c43 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py @@ -272,12 +272,14 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results del hyperparameters diff --git a/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py index f8e87ec2a..2af6d548a 100644 --- a/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py +++ b/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py @@ -1,7 +1,7 @@ """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import Dict, Iterator, List, Tuple +from typing import Dict, Iterator, List, Tuple, Any from absl import logging import torch @@ -244,12 +244,14 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results del hyperparameters diff --git a/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py index 1de26417f..2e2385e29 100644 --- 
a/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py +++ b/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py @@ -1,7 +1,7 @@ """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import Dict, Iterator, List, Tuple +from typing import Dict, Iterator, List, Tuple, Any from absl import logging import torch @@ -244,12 +244,14 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results del hyperparameters diff --git a/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py b/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py index 2971efe9a..89aeac238 100644 --- a/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py +++ b/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py @@ -1,7 +1,7 @@ """Training algorithm track submission functions for CIFAR10.""" import functools -from typing import Dict, Iterator, List, Tuple +from typing import Dict, Iterator, List, Tuple, Any from flax import jax_utils import jax @@ -118,6 +118,7 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: @@ -125,6 +126,7 @@ def update_params(workload: spec.Workload, del current_params_types del loss_type del global_step + del train_state del eval_results optimizer_state, opt_update_fn = optimizer_state per_device_rngs = jax.random.split(rng, jax.local_device_count()) diff --git a/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py b/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py index 358c6bffc..bcdab6fc3 100644 --- a/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py +++ b/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py @@ -1,6 +1,6 @@ """Training algorithm track submission functions for CIFAR10.""" -from typing import Dict, Iterator, List, Tuple +from typing import Dict, Iterator, List, Tuple, Any import torch from torch.optim.lr_scheduler import CosineAnnealingLR @@ -61,6 +61,7 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: @@ -68,6 +69,7 @@ def update_params(workload: spec.Workload, del current_params_types del hyperparameters del loss_type + del train_state del eval_results current_model = current_param_container diff --git a/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py b/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py index 896609d51..01a266eaf 100644 --- a/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py +++ b/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py @@ -1,7 +1,7 @@ """Training algorithm track submission functions for 
MNIST.""" import functools -from typing import Dict, Iterator, List, Tuple +from typing import Dict, Iterator, List, Tuple, Any from flax import jax_utils import jax @@ -83,12 +83,14 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results del global_step diff --git a/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py b/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py index f1601e606..e72a9d823 100644 --- a/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py +++ b/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py @@ -1,6 +1,6 @@ """Training algorithm track submission functions for MNIST.""" -from typing import Dict, Iterator, List, Tuple +from typing import Dict, Iterator, List, Tuple, Any import torch @@ -40,6 +40,7 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: @@ -47,6 +48,7 @@ def update_params(workload: spec.Workload, del hyperparameters del loss_type del current_params_types + del train_state del eval_results del global_step diff --git a/reference_algorithms/paper_baselines/adafactor/jax/submission.py b/reference_algorithms/paper_baselines/adafactor/jax/submission.py index 2dd85c29b..ea440cce7 100644 --- a/reference_algorithms/paper_baselines/adafactor/jax/submission.py +++ b/reference_algorithms/paper_baselines/adafactor/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for an Adafactor optimizer with warmup+cosine LR in Jax.""" import functools -from typing import Dict, Iterator, List, Tuple +from typing import Dict, Iterator, List, Tuple, Any from flax import jax_utils import jax @@ -118,12 +118,14 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results optimizer_state, opt_update_fn = optimizer_state diff --git a/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py b/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py index e6fef17dc..30d6942e4 100644 --- a/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py @@ -1,7 +1,7 @@ """Submission file for Adafactor in PyTorch.""" from functools import partial -from typing import Dict, Iterator, List, Tuple +from typing import Dict, Iterator, List, Tuple, Any from absl import logging import torch @@ -198,12 +198,14 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: 
spec.RandomState) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results current_model = current_param_container diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index 80a963600..935d0d0ca 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for an AdamW optimizer with warmup+cosine LR in Jax.""" import functools -from typing import Dict, Iterator, List, Tuple +from typing import Dict, Iterator, List, Tuple, Any from flax import jax_utils import jax @@ -118,12 +118,14 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results optimizer_state, opt_update_fn = optimizer_state diff --git a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py index 32353e5b4..ddd17b3b2 100644 --- a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py @@ -1,6 +1,6 @@ """Submission file for an AdamW optimizer with warmup+cosine LR in PyTorch.""" -from typing import Dict, Iterator, List, Tuple +from typing import Dict, Iterator, List, Tuple, Any from absl import logging import torch @@ -59,12 +59,14 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results current_model = current_param_container diff --git a/reference_algorithms/paper_baselines/lamb/jax/submission.py b/reference_algorithms/paper_baselines/lamb/jax/submission.py index 27d635ee9..3944d6483 100644 --- a/reference_algorithms/paper_baselines/lamb/jax/submission.py +++ b/reference_algorithms/paper_baselines/lamb/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for a LAMB optimizer with warmup+cosine LR in Jax.""" import functools -from typing import Dict, Iterator, List, Tuple +from typing import Dict, Iterator, List, Tuple, Any from flax import jax_utils import jax @@ -126,12 +126,14 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results optimizer_state, opt_update_fn = optimizer_state diff --git a/reference_algorithms/paper_baselines/lamb/pytorch/submission.py b/reference_algorithms/paper_baselines/lamb/pytorch/submission.py index 7d0d8763e..20bc80d23 100644 --- 
a/reference_algorithms/paper_baselines/lamb/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/lamb/pytorch/submission.py @@ -1,7 +1,7 @@ """Submission file for a LAMB optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import Dict, Iterator, List, Tuple +from typing import Dict, Iterator, List, Tuple, Any from absl import logging import torch @@ -197,12 +197,14 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results current_model = current_param_container diff --git a/reference_algorithms/paper_baselines/momentum/jax/submission.py b/reference_algorithms/paper_baselines/momentum/jax/submission.py index cccb3c1b5..b2db0c728 100644 --- a/reference_algorithms/paper_baselines/momentum/jax/submission.py +++ b/reference_algorithms/paper_baselines/momentum/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for a SGD with HeavyBall momentum optimizer in Jax.""" import functools -from typing import Callable, Dict, Iterator, List, Tuple +from typing import Callable, Dict, Iterator, List, Tuple, Any from flax import jax_utils import jax @@ -152,12 +152,14 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results optimizer_state, opt_update_fn = optimizer_state diff --git a/reference_algorithms/paper_baselines/momentum/pytorch/submission.py b/reference_algorithms/paper_baselines/momentum/pytorch/submission.py index ec5c0b31c..533e9fed4 100644 --- a/reference_algorithms/paper_baselines/momentum/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/momentum/pytorch/submission.py @@ -1,6 +1,6 @@ """Submission file for a SGD with HeavyBall momentum optimizer in PyTorch.""" -from typing import Callable, Dict, Iterator, List, Tuple +from typing import Callable, Dict, Iterator, List, Tuple, Any from absl import logging import optax @@ -75,12 +75,14 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results current_model = current_param_container diff --git a/reference_algorithms/paper_baselines/nadamw/jax/submission.py b/reference_algorithms/paper_baselines/nadamw/jax/submission.py index 98193f01f..63cf25fe5 100644 --- a/reference_algorithms/paper_baselines/nadamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/nadamw/jax/submission.py @@ -260,12 +260,14 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], 
global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results optimizer_state, opt_update_fn = optimizer_state diff --git a/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py b/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py index ebc49d428..c85cc6dd3 100644 --- a/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py @@ -1,7 +1,7 @@ """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import Dict, Iterator, List, Tuple +from typing import Dict, Iterator, List, Tuple, Any from absl import logging import torch @@ -232,12 +232,14 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results current_model = current_param_container diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py index f3b0aeed4..f79bc34b4 100644 --- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for a SGD with Nesterov momentum optimizer in Jax.""" import functools -from typing import Callable, Dict, Iterator, List, Tuple +from typing import Callable, Dict, Iterator, List, Tuple, Any from flax import jax_utils import jax @@ -152,12 +152,14 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results optimizer_state, opt_update_fn = optimizer_state diff --git a/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py b/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py index fe9154934..330e344c1 100644 --- a/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py @@ -1,6 +1,6 @@ """Submission file for a SGD with Nesterov momentum optimizer in PyTorch.""" -from typing import Callable, Dict, Iterator, List, Tuple +from typing import Callable, Dict, Iterator, List, Tuple, Any from absl import logging import optax @@ -75,12 +75,14 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results current_model = current_param_container diff --git a/reference_algorithms/paper_baselines/sam/jax/submission.py 
b/reference_algorithms/paper_baselines/sam/jax/submission.py index 85b3d7441..5448ff1f2 100644 --- a/reference_algorithms/paper_baselines/sam/jax/submission.py +++ b/reference_algorithms/paper_baselines/sam/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for a SAM optimizer with warmup+cosine LR in Jax.""" import functools -from typing import Dict, Iterator, List, Optional, Tuple +from typing import Dict, Iterator, List, Optional, Tuple, Any from flax import jax_utils import jax @@ -205,12 +205,14 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results optimizer_state, opt_update_fn = optimizer_state diff --git a/reference_algorithms/paper_baselines/sam/pytorch/submission.py b/reference_algorithms/paper_baselines/sam/pytorch/submission.py index 2cab75972..967d53549 100644 --- a/reference_algorithms/paper_baselines/sam/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/sam/pytorch/submission.py @@ -1,6 +1,6 @@ """Submission file for a SAM optimizer with warmup+cosine LR in PyTorch.""" -from typing import Callable, Dict, Iterator, List, Tuple +from typing import Callable, Dict, Iterator, List, Tuple, Any from absl import logging import torch @@ -139,12 +139,14 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results current_model = current_param_container diff --git a/reference_algorithms/paper_baselines/shampoo/jax/submission.py b/reference_algorithms/paper_baselines/shampoo/jax/submission.py index 9c6b66b7f..104ae3ce3 100644 --- a/reference_algorithms/paper_baselines/shampoo/jax/submission.py +++ b/reference_algorithms/paper_baselines/shampoo/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for a Shampoo optimizer with warmup+cosine LR in Jax.""" import functools -from typing import Dict, Iterator, List, Tuple +from typing import Dict, Iterator, List, Tuple, Any from flax import jax_utils import jax @@ -121,12 +121,14 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results optimizer_state, opt_update_fn = optimizer_state diff --git a/reference_algorithms/target_setting_algorithms/jax_submission_base.py b/reference_algorithms/target_setting_algorithms/jax_submission_base.py index 2a641b520..e66b1ab23 100644 --- a/reference_algorithms/target_setting_algorithms/jax_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/jax_submission_base.py @@ -1,6 +1,6 @@ """Update submission function in Jax.""" import functools -from typing import Dict, List, Tuple +from typing 
import Dict, List, Tuple, Any import jax from jax import lax @@ -77,12 +77,14 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results optimizer_state, opt_update_fn = optimizer_state diff --git a/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py b/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py index f9e40212b..c031f3ac4 100644 --- a/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py @@ -1,6 +1,6 @@ """Batch size and update submission functions in PyTorch.""" -from typing import Dict, List, Tuple +from typing import Dict, List, Tuple, Any from absl import logging import torch @@ -20,12 +20,14 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results current_model = current_param_container diff --git a/submission_runner.py b/submission_runner.py index 551173bf5..aef7fafb0 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -357,6 +357,7 @@ def train_once( batch=batch, loss_type=workload.loss_type, optimizer_state=optimizer_state, + train_state=train_state.copy(), eval_results=eval_results, global_step=global_step, rng=update_rng) diff --git a/submissions/template/submission.py b/submissions/template/submission.py index 5ef195db5..fb9b1cad1 100644 --- a/submissions/template/submission.py +++ b/submissions/template/submission.py @@ -4,7 +4,7 @@ and https://github.com/mlcommons/algorithmic-efficiency/blob/main/DOCUMENTATION.md#disallowed-submissions for guidelines. """ -from typing import Dict, Iterator, List, Tuple +from typing import Dict, Iterator, List, Tuple, Any from algorithmic_efficiency import spec @@ -30,6 +30,7 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: From f15e227ee96ca181d670d6dd06bada647986c9ee Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 3 Oct 2024 17:44:41 +0200 Subject: [PATCH 08/37] update DOCS --- DOCUMENTATION.md | 1 + 1 file changed, 1 insertion(+) diff --git a/DOCUMENTATION.md b/DOCUMENTATION.md index 8207691d6..8722a441e 100644 --- a/DOCUMENTATION.md +++ b/DOCUMENTATION.md @@ -213,6 +213,7 @@ def update_params( - The `loss_fn` produces a loss per example and a summed loss (both only for one device), which both can be used. - Allowed to update state for the optimizer. - Uses the `model_fn` of the `workload` in order to decouple the loss from the model so that model outputs (forward passes) can be reused (by storing them in the optimizer state). 
+- The submission can access the elapsed training time and get further information about the evaluation through `train_state`. - The submission can access the target evaluation metric via the `workload` variable. - **A call to this function will be considered a step** - The time between a call to this function and the next call to this function will be considered the per-step time. From 107c6b6e2ad3312d77ad4e99034f17a31f2967c6 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 3 Oct 2024 18:05:01 +0200 Subject: [PATCH 09/37] update test --- tests/reference_algorithm_tests.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/reference_algorithm_tests.py b/tests/reference_algorithm_tests.py index 74c06e180..938c4fa11 100644 --- a/tests/reference_algorithm_tests.py +++ b/tests/reference_algorithm_tests.py @@ -471,6 +471,7 @@ def _test_submission(workload_name, batch=batch, loss_type=workload.loss_type, optimizer_state=optimizer_state, + train_state={}, eval_results=[], global_step=global_step, rng=update_rng) From d4ad0eb06df5f323a4d383e70715eeb99181d294 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 3 Oct 2024 19:02:05 +0200 Subject: [PATCH 10/37] fix linting --- reference_algorithms/paper_baselines/momentum/jax/submission.py | 2 +- .../paper_baselines/momentum/pytorch/submission.py | 2 +- reference_algorithms/paper_baselines/nesterov/jax/submission.py | 2 +- .../paper_baselines/nesterov/pytorch/submission.py | 2 +- reference_algorithms/paper_baselines/sam/pytorch/submission.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/reference_algorithms/paper_baselines/momentum/jax/submission.py b/reference_algorithms/paper_baselines/momentum/jax/submission.py index b2db0c728..dc101896b 100644 --- a/reference_algorithms/paper_baselines/momentum/jax/submission.py +++ b/reference_algorithms/paper_baselines/momentum/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for a SGD with HeavyBall momentum optimizer in Jax.""" import functools -from typing import Callable, Dict, Iterator, List, Tuple, Any +from typing import Any, Callable, Dict, Iterator, List, Tuple from flax import jax_utils import jax diff --git a/reference_algorithms/paper_baselines/momentum/pytorch/submission.py b/reference_algorithms/paper_baselines/momentum/pytorch/submission.py index 533e9fed4..52aba82bf 100644 --- a/reference_algorithms/paper_baselines/momentum/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/momentum/pytorch/submission.py @@ -1,6 +1,6 @@ """Submission file for a SGD with HeavyBall momentum optimizer in PyTorch.""" -from typing import Callable, Dict, Iterator, List, Tuple, Any +from typing import Any, Callable, Dict, Iterator, List, Tuple from absl import logging import optax diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py index f79bc34b4..e47c7fa0c 100644 --- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for a SGD with Nesterov momentum optimizer in Jax.""" import functools -from typing import Callable, Dict, Iterator, List, Tuple, Any +from typing import Any, Callable, Dict, Iterator, List, Tuple from flax import jax_utils import jax diff --git a/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py b/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py index 330e344c1..442949866 100644 --- 
a/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py @@ -1,6 +1,6 @@ """Submission file for a SGD with Nesterov momentum optimizer in PyTorch.""" -from typing import Callable, Dict, Iterator, List, Tuple, Any +from typing import Any, Callable, Dict, Iterator, List, Tuple from absl import logging import optax diff --git a/reference_algorithms/paper_baselines/sam/pytorch/submission.py b/reference_algorithms/paper_baselines/sam/pytorch/submission.py index 967d53549..15b6b6858 100644 --- a/reference_algorithms/paper_baselines/sam/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/sam/pytorch/submission.py @@ -1,6 +1,6 @@ """Submission file for a SAM optimizer with warmup+cosine LR in PyTorch.""" -from typing import Callable, Dict, Iterator, List, Tuple, Any +from typing import Any, Callable, Dict, Iterator, List, Tuple from absl import logging import torch From cb7e162230d6ca3849a183435d05c6f802498de1 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 3 Oct 2024 19:03:42 +0200 Subject: [PATCH 11/37] fix isort --- .../external_tuning/pytorch_nadamw_full_budget.py | 2 +- .../external_tuning/pytorch_nadamw_target_setting.py | 2 +- .../self_tuning/pytorch_nadamw_full_budget.py | 2 +- .../self_tuning/pytorch_nadamw_target_setting.py | 2 +- .../development_algorithms/cifar/cifar_jax/submission.py | 2 +- .../development_algorithms/cifar/cifar_pytorch/submission.py | 2 +- .../development_algorithms/mnist/mnist_jax/submission.py | 2 +- .../development_algorithms/mnist/mnist_pytorch/submission.py | 2 +- .../paper_baselines/adafactor/jax/submission.py | 2 +- .../paper_baselines/adafactor/pytorch/submission.py | 2 +- reference_algorithms/paper_baselines/adamw/jax/submission.py | 2 +- .../paper_baselines/adamw/pytorch/submission.py | 2 +- reference_algorithms/paper_baselines/lamb/jax/submission.py | 2 +- reference_algorithms/paper_baselines/lamb/pytorch/submission.py | 2 +- .../paper_baselines/nadamw/pytorch/submission.py | 2 +- reference_algorithms/paper_baselines/shampoo/jax/submission.py | 2 +- submissions/template/submission.py | 2 +- 17 files changed, 17 insertions(+), 17 deletions(-) diff --git a/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py index c85cc6dd3..72a3bf289 100644 --- a/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py +++ b/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py @@ -1,7 +1,7 @@ """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import Dict, Iterator, List, Tuple, Any +from typing import Any, Dict, Iterator, List, Tuple from absl import logging import torch diff --git a/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py index bb1278911..934538b63 100644 --- a/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py +++ b/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py @@ -1,7 +1,7 @@ """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import Dict, Iterator, List, Tuple, Any +from typing import Any, Dict, Iterator, List, Tuple from absl import logging import torch diff --git a/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py 
b/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py index 2af6d548a..f968d4abf 100644 --- a/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py +++ b/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py @@ -1,7 +1,7 @@ """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import Dict, Iterator, List, Tuple, Any +from typing import Any, Dict, Iterator, List, Tuple from absl import logging import torch diff --git a/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py index 2e2385e29..14c22141c 100644 --- a/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py +++ b/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py @@ -1,7 +1,7 @@ """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import Dict, Iterator, List, Tuple, Any +from typing import Any, Dict, Iterator, List, Tuple from absl import logging import torch diff --git a/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py b/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py index 89aeac238..7e41e9fd7 100644 --- a/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py +++ b/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py @@ -1,7 +1,7 @@ """Training algorithm track submission functions for CIFAR10.""" import functools -from typing import Dict, Iterator, List, Tuple, Any +from typing import Any, Dict, Iterator, List, Tuple from flax import jax_utils import jax diff --git a/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py b/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py index bcdab6fc3..81110bae6 100644 --- a/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py +++ b/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py @@ -1,6 +1,6 @@ """Training algorithm track submission functions for CIFAR10.""" -from typing import Dict, Iterator, List, Tuple, Any +from typing import Any, Dict, Iterator, List, Tuple import torch from torch.optim.lr_scheduler import CosineAnnealingLR diff --git a/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py b/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py index 01a266eaf..3f75c9904 100644 --- a/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py +++ b/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py @@ -1,7 +1,7 @@ """Training algorithm track submission functions for MNIST.""" import functools -from typing import Dict, Iterator, List, Tuple, Any +from typing import Any, Dict, Iterator, List, Tuple from flax import jax_utils import jax diff --git a/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py b/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py index e72a9d823..d326f4035 100644 --- a/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py +++ b/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py @@ -1,6 +1,6 @@ """Training algorithm track submission functions for MNIST.""" -from typing import Dict, Iterator, List, Tuple, Any +from typing import Any, Dict, Iterator, List, Tuple import torch diff --git 
a/reference_algorithms/paper_baselines/adafactor/jax/submission.py b/reference_algorithms/paper_baselines/adafactor/jax/submission.py index ea440cce7..39cf3d4f9 100644 --- a/reference_algorithms/paper_baselines/adafactor/jax/submission.py +++ b/reference_algorithms/paper_baselines/adafactor/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for an Adafactor optimizer with warmup+cosine LR in Jax.""" import functools -from typing import Dict, Iterator, List, Tuple, Any +from typing import Any, Dict, Iterator, List, Tuple from flax import jax_utils import jax diff --git a/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py b/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py index 30d6942e4..880f9168d 100644 --- a/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py @@ -1,7 +1,7 @@ """Submission file for Adafactor in PyTorch.""" from functools import partial -from typing import Dict, Iterator, List, Tuple, Any +from typing import Any, Dict, Iterator, List, Tuple from absl import logging import torch diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index 935d0d0ca..06eeacb39 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for an AdamW optimizer with warmup+cosine LR in Jax.""" import functools -from typing import Dict, Iterator, List, Tuple, Any +from typing import Any, Dict, Iterator, List, Tuple from flax import jax_utils import jax diff --git a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py index ddd17b3b2..0710fb9a0 100644 --- a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py @@ -1,6 +1,6 @@ """Submission file for an AdamW optimizer with warmup+cosine LR in PyTorch.""" -from typing import Dict, Iterator, List, Tuple, Any +from typing import Any, Dict, Iterator, List, Tuple from absl import logging import torch diff --git a/reference_algorithms/paper_baselines/lamb/jax/submission.py b/reference_algorithms/paper_baselines/lamb/jax/submission.py index 3944d6483..891da63be 100644 --- a/reference_algorithms/paper_baselines/lamb/jax/submission.py +++ b/reference_algorithms/paper_baselines/lamb/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for a LAMB optimizer with warmup+cosine LR in Jax.""" import functools -from typing import Dict, Iterator, List, Tuple, Any +from typing import Any, Dict, Iterator, List, Tuple from flax import jax_utils import jax diff --git a/reference_algorithms/paper_baselines/lamb/pytorch/submission.py b/reference_algorithms/paper_baselines/lamb/pytorch/submission.py index 20bc80d23..7886dc75d 100644 --- a/reference_algorithms/paper_baselines/lamb/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/lamb/pytorch/submission.py @@ -1,7 +1,7 @@ """Submission file for a LAMB optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import Dict, Iterator, List, Tuple, Any +from typing import Any, Dict, Iterator, List, Tuple from absl import logging import torch diff --git a/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py b/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py index c85cc6dd3..72a3bf289 
100644 --- a/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py @@ -1,7 +1,7 @@ """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import Dict, Iterator, List, Tuple, Any +from typing import Any, Dict, Iterator, List, Tuple from absl import logging import torch diff --git a/reference_algorithms/paper_baselines/shampoo/jax/submission.py b/reference_algorithms/paper_baselines/shampoo/jax/submission.py index 104ae3ce3..e853a821b 100644 --- a/reference_algorithms/paper_baselines/shampoo/jax/submission.py +++ b/reference_algorithms/paper_baselines/shampoo/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for a Shampoo optimizer with warmup+cosine LR in Jax.""" import functools -from typing import Dict, Iterator, List, Tuple, Any +from typing import Any, Dict, Iterator, List, Tuple from flax import jax_utils import jax diff --git a/submissions/template/submission.py b/submissions/template/submission.py index fb9b1cad1..9bfb23367 100644 --- a/submissions/template/submission.py +++ b/submissions/template/submission.py @@ -4,7 +4,7 @@ and https://github.com/mlcommons/algorithmic-efficiency/blob/main/DOCUMENTATION.md#disallowed-submissions for guidelines. """ -from typing import Dict, Iterator, List, Tuple, Any +from typing import Any, Dict, Iterator, List, Tuple from algorithmic_efficiency import spec From 1f59285fa1ae8eb8e4cce10cba4db486bf49f8e8 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 3 Oct 2024 19:05:43 +0200 Subject: [PATCH 12/37] fix import sort --- reference_algorithms/paper_baselines/sam/jax/submission.py | 2 +- .../target_setting_algorithms/jax_submission_base.py | 2 +- .../target_setting_algorithms/pytorch_submission_base.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/reference_algorithms/paper_baselines/sam/jax/submission.py b/reference_algorithms/paper_baselines/sam/jax/submission.py index 5448ff1f2..95bea68aa 100644 --- a/reference_algorithms/paper_baselines/sam/jax/submission.py +++ b/reference_algorithms/paper_baselines/sam/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for a SAM optimizer with warmup+cosine LR in Jax.""" import functools -from typing import Dict, Iterator, List, Optional, Tuple, Any +from typing import Any, Dict, Iterator, List, Optional, Tuple from flax import jax_utils import jax diff --git a/reference_algorithms/target_setting_algorithms/jax_submission_base.py b/reference_algorithms/target_setting_algorithms/jax_submission_base.py index e66b1ab23..a98d134fc 100644 --- a/reference_algorithms/target_setting_algorithms/jax_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/jax_submission_base.py @@ -1,6 +1,6 @@ """Update submission function in Jax.""" import functools -from typing import Dict, List, Tuple, Any +from typing import Any, Dict, List, Tuple import jax from jax import lax diff --git a/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py b/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py index c031f3ac4..586429e37 100644 --- a/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py @@ -1,6 +1,6 @@ """Batch size and update submission functions in PyTorch.""" -from typing import Dict, List, Tuple, Any +from typing import Any, Dict, List, Tuple from absl import logging import torch From 
f574bf04dda725f790bb6ffaf2ca62b260b132d8 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 17 Oct 2024 01:01:14 +0000 Subject: [PATCH 13/37] add use_running_average_bn arg for jax --- .../workloads/cifar/cifar_jax/models.py | 9 +++- .../workloads/cifar/cifar_jax/workload.py | 9 ++-- .../imagenet_resnet/imagenet_jax/models.py | 9 +++- .../imagenet_resnet/imagenet_jax/workload.py | 9 ++-- .../librispeech_jax/models.py | 49 ++++++++++++------- .../librispeech_jax/workload.py | 9 ++-- 6 files changed, 63 insertions(+), 31 deletions(-) diff --git a/algorithmic_efficiency/workloads/cifar/cifar_jax/models.py b/algorithmic_efficiency/workloads/cifar/cifar_jax/models.py index 834c93b7a..09338ca82 100644 --- a/algorithmic_efficiency/workloads/cifar/cifar_jax/models.py +++ b/algorithmic_efficiency/workloads/cifar/cifar_jax/models.py @@ -28,11 +28,16 @@ class ResNet(nn.Module): @nn.compact def __call__(self, x: spec.Tensor, - update_batch_norm: bool = True) -> spec.Tensor: + update_batch_norm: bool = True, + use_running_average_bn: bool = None) -> spec.Tensor: conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype) + + # Preserve default behavior for backwards compatibility + if use_running_average_bn is None: + use_running_average_bn = not update_batch_norm norm = functools.partial( nn.BatchNorm, - use_running_average=not update_batch_norm, + use_running_average=use_running_average_bn, momentum=0.9, epsilon=1e-5, dtype=self.dtype) diff --git a/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py b/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py index b019d1cee..019dde38c 100644 --- a/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py +++ b/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py @@ -110,7 +110,8 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + use_running_average_bn: Optional[bool] = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del mode del rng variables = {'params': params, **model_state} @@ -119,14 +120,16 @@ def model_fn( variables, augmented_and_preprocessed_input_batch['inputs'], update_batch_norm=update_batch_norm, - mutable=['batch_stats']) + mutable=['batch_stats'], + use_running_average_bn=use_running_average_bn) return logits, new_model_state else: logits = self._model.apply( variables, augmented_and_preprocessed_input_batch['inputs'], update_batch_norm=update_batch_norm, - mutable=False) + mutable=False, + use_running_average_bn=use_running_average_bn) return logits, model_state # Does NOT apply regularization, which is left to the submitter to do in diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/models.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/models.py index 99a9b0513..2e680cbd9 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/models.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/models.py @@ -84,11 +84,16 @@ class ResNet(nn.Module): @nn.compact def __call__(self, x: spec.Tensor, - update_batch_norm: bool = True) -> spec.Tensor: + update_batch_norm: bool = True, + use_running_average_bn: Optional[bool] = None) -> spec.Tensor: conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype) + + # Preserve default behavior for backwards compatibility + if use_running_average_bn is None: + use_running_average_bn = not update_batch_norm norm = 
functools.partial( nn.BatchNorm, - use_running_average=not update_batch_norm, + use_running_average=use_running_average_bn, momentum=0.9, epsilon=1e-5, dtype=self.dtype) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py index d8de214f5..46168c2a0 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py @@ -148,7 +148,8 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + use_running_average_bn: Optional[bool] = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del mode del rng variables = {'params': params, **model_state} @@ -157,14 +158,16 @@ def model_fn( variables, augmented_and_preprocessed_input_batch['inputs'], update_batch_norm=update_batch_norm, - mutable=['batch_stats']) + mutable=['batch_stats'], + use_running_average_bn=use_running_average_bn) return logits, new_model_state else: logits = self._model.apply( variables, augmented_and_preprocessed_input_batch['inputs'], update_batch_norm=update_batch_norm, - mutable=False) + mutable=False, + use_running_average_bn=use_running_average_bn) return logits, model_state # Does NOT apply regularization, which is left to the submitter to do in diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py index ed05f4335..077ff0f89 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py @@ -454,7 +454,7 @@ def setup(self): self.beta = self.param('bias', nn.initializers.zeros, dim, dtype) @nn.compact - def __call__(self, inputs, input_paddings, train): + def __call__(self, inputs, input_paddings, update_batch_norm, use_running_average_bn): rank = inputs.ndim reduce_over_dims = list(range(0, rank - 1)) @@ -462,7 +462,12 @@ def __call__(self, inputs, input_paddings, train): momentum = self.config.batch_norm_momentum epsilon = self.config.batch_norm_epsilon - if train: + if use_running_average_bn: + mean = self.ra_mean.value + var = self.ra_var.value + + else: + # compute batch statistics mask = 1.0 - padding sum_v = jnp.sum(inputs * mask, axis=reduce_over_dims, keepdims=True) count_v = jnp.sum( @@ -477,17 +482,14 @@ def __call__(self, inputs, input_paddings, train): keepdims=True) var = sum_vv / count_v - - self.ra_mean.value = momentum * \ - self.ra_mean.value + (1 - momentum) * mean - self.ra_var.value = momentum * \ - self.ra_var.value + (1 - momentum) * var - else: - mean = self.ra_mean.value - var = self.ra_var.value - + + if update_batch_norm: + self.ra_mean.value = momentum * \ + self.ra_mean.value + (1 - momentum) * mean + self.ra_var.value = momentum * \ + self.ra_var.value + (1 - momentum) * var + inv = (1 + self.gamma) / jnp.sqrt(var + epsilon) - bn_output = (inputs - mean) * inv + self.beta bn_output *= 1.0 - padding @@ -517,7 +519,7 @@ class ConvolutionBlock(nn.Module): config: ConformerConfig @nn.compact - def __call__(self, inputs, input_paddings, train): + def __call__(self, inputs, input_paddings, train, update_batch_norm, use_running_average_bn): config = self.config inputs = 
LayerNorm(dim=config.encoder_dim)(inputs) @@ -546,7 +548,7 @@ def __call__(self, inputs, input_paddings, train): kernel_init=nn.initializers.xavier_uniform())( inputs) - inputs = BatchNorm(config)(inputs, input_paddings, train) + inputs = BatchNorm(config)(inputs, input_paddings, update_batch_norm, use_running_average_bn) if config.activation_function_name == 'swish': activation_fn = nn.swish elif config.activation_function_name == 'gelu': @@ -586,7 +588,7 @@ class ConformerBlock(nn.Module): config: ConformerConfig @nn.compact - def __call__(self, inputs, input_paddings, train): + def __call__(self, inputs, input_paddings, train, update_batch_norm, use_running_average): config = self.config padding_mask = jnp.expand_dims(1 - input_paddings, -1) @@ -597,7 +599,7 @@ def __call__(self, inputs, input_paddings, train): inputs, input_paddings, train) inputs = inputs + \ - ConvolutionBlock(config)(inputs, input_paddings, train) + ConvolutionBlock(config)(inputs, input_paddings, train, update_batch_norm, use_running_average) inputs = inputs + 0.5 * FeedForwardModule(config=self.config)( inputs, padding_mask, train) @@ -629,12 +631,23 @@ def setup(self): .use_dynamic_time_mask_max_frames) @nn.compact - def __call__(self, inputs, input_paddings, train): + def __call__(self, + inputs, + input_paddings, + train, + update_batch_norm: Optional[bool] = None, + use_running_average_bn: Optional[bool] = None): config = self.config outputs = inputs output_paddings = input_paddings + # Set BN args if not supplied for backwards compatibility + if update_batch_norm is None: + update_batch_norm = train + if use_running_average_bn is None: + use_running_average_bn = not train + # Compute normalized log mel spectrograms from input audio signal. preprocessing_config = preprocessor.LibrispeechPreprocessingConfig() outputs, output_paddings = preprocessor.MelFilterbankFrontend( @@ -660,7 +673,7 @@ def __call__(self, inputs, input_paddings, train): # Run the conformer encoder layers. for _ in range(config.num_encoder_layers): - outputs = ConformerBlock(config)(outputs, output_paddings, train) + outputs = ConformerBlock(config)(outputs, output_paddings, train, update_batch_norm, use_running_average_bn) outputs = LayerNorm(config.encoder_dim)(outputs) # Run the decoder which in this case is a trivial projection layer. 
diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py index f4d1ab0f3..6c55acfb0 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -107,7 +107,8 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + use_running_average_bn: Optional[bool]=None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: variables = {'params': params, **model_state} inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs'] is_train_mode = mode == spec.ForwardPassMode.TRAIN @@ -118,7 +119,8 @@ def model_fn( input_paddings, train=True, rngs={'dropout' : rng}, - mutable=['batch_stats']) + mutable=['batch_stats'], + use_running_average_bn=use_running_average_bn) return (logits, logit_paddings), new_model_state else: logits, logit_paddings = self._model.apply( @@ -126,7 +128,8 @@ def model_fn( inputs, input_paddings, train=False, - mutable=False) + mutable=False, + use_running_average_bn=use_running_average_bn) return (logits, logit_paddings), model_state def _build_input_queue( From 7ca8365a7aba462181736b8d39382162c9bb1ad6 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 17 Oct 2024 22:01:58 +0000 Subject: [PATCH 14/37] formatting --- algorithmic_efficiency/pytorch_utils.py | 3 +- .../workloads/cifar/cifar_jax/models.py | 6 +-- .../workloads/cifar/cifar_jax/workload.py | 3 +- .../imagenet_resnet/imagenet_jax/models.py | 4 +- .../imagenet_resnet/imagenet_jax/workload.py | 3 +- .../librispeech_jax/models.py | 49 +++++++++++++------ .../librispeech_jax/workload.py | 3 +- .../librispeech_pytorch/models.py | 2 +- 8 files changed, 49 insertions(+), 24 deletions(-) diff --git a/algorithmic_efficiency/pytorch_utils.py b/algorithmic_efficiency/pytorch_utils.py index 2e5828912..590f500fa 100644 --- a/algorithmic_efficiency/pytorch_utils.py +++ b/algorithmic_efficiency/pytorch_utils.py @@ -57,6 +57,7 @@ def sync_ddp_time(time: float, device: torch.device) -> float: dist.all_reduce(time_tensor, op=dist.ReduceOp.MAX) return time_tensor.item() + def update_batch_norm_fn(module: spec.ParameterContainer, update_batch_norm: bool) -> None: bn_layers = ( @@ -75,4 +76,4 @@ def update_batch_norm_fn(module: spec.ParameterContainer, else: module.momentum = 0.0 elif hasattr(module, 'momentum_backup'): - module.momentum = module.momentum_backup \ No newline at end of file + module.momentum = module.momentum_backup diff --git a/algorithmic_efficiency/workloads/cifar/cifar_jax/models.py b/algorithmic_efficiency/workloads/cifar/cifar_jax/models.py index 09338ca82..059352fb6 100644 --- a/algorithmic_efficiency/workloads/cifar/cifar_jax/models.py +++ b/algorithmic_efficiency/workloads/cifar/cifar_jax/models.py @@ -31,10 +31,10 @@ def __call__(self, update_batch_norm: bool = True, use_running_average_bn: bool = None) -> spec.Tensor: conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype) - - # Preserve default behavior for backwards compatibility + + # Preserve default behavior for backwards compatibility if use_running_average_bn is None: - use_running_average_bn = not update_batch_norm + use_running_average_bn = not update_batch_norm norm = functools.partial( nn.BatchNorm, 
use_running_average=use_running_average_bn, diff --git a/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py b/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py index 019dde38c..8268c6ca3 100644 --- a/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py +++ b/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py @@ -111,7 +111,8 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - use_running_average_bn: Optional[bool] = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + use_running_average_bn: Optional[bool] = None + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del mode del rng variables = {'params': params, **model_state} diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/models.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/models.py index 2e680cbd9..34cd17440 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/models.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/models.py @@ -88,9 +88,9 @@ def __call__(self, use_running_average_bn: Optional[bool] = None) -> spec.Tensor: conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype) - # Preserve default behavior for backwards compatibility + # Preserve default behavior for backwards compatibility if use_running_average_bn is None: - use_running_average_bn = not update_batch_norm + use_running_average_bn = not update_batch_norm norm = functools.partial( nn.BatchNorm, use_running_average=use_running_average_bn, diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py index 46168c2a0..2747fc2db 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py @@ -149,7 +149,8 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - use_running_average_bn: Optional[bool] = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + use_running_average_bn: Optional[bool] = None + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del mode del rng variables = {'params': params, **model_state} diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py index 077ff0f89..db92f56d4 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py @@ -454,7 +454,11 @@ def setup(self): self.beta = self.param('bias', nn.initializers.zeros, dim, dtype) @nn.compact - def __call__(self, inputs, input_paddings, update_batch_norm, use_running_average_bn): + def __call__(self, + inputs, + input_paddings, + update_batch_norm, + use_running_average_bn): rank = inputs.ndim reduce_over_dims = list(range(0, rank - 1)) @@ -462,7 +466,7 @@ def __call__(self, inputs, input_paddings, update_batch_norm, use_running_averag momentum = self.config.batch_norm_momentum epsilon = self.config.batch_norm_epsilon - if use_running_average_bn: + if use_running_average_bn: mean = self.ra_mean.value var = self.ra_var.value @@ -482,13 +486,13 @@ def __call__(self, inputs, input_paddings, update_batch_norm, use_running_averag keepdims=True) var = sum_vv / count_v - + if update_batch_norm: self.ra_mean.value = momentum * \ 
self.ra_mean.value + (1 - momentum) * mean self.ra_var.value = momentum * \ self.ra_var.value + (1 - momentum) * var - + inv = (1 + self.gamma) / jnp.sqrt(var + epsilon) bn_output = (inputs - mean) * inv + self.beta bn_output *= 1.0 - padding @@ -519,7 +523,12 @@ class ConvolutionBlock(nn.Module): config: ConformerConfig @nn.compact - def __call__(self, inputs, input_paddings, train, update_batch_norm, use_running_average_bn): + def __call__(self, + inputs, + input_paddings, + train, + update_batch_norm, + use_running_average_bn): config = self.config inputs = LayerNorm(dim=config.encoder_dim)(inputs) @@ -548,7 +557,10 @@ def __call__(self, inputs, input_paddings, train, update_batch_norm, use_running kernel_init=nn.initializers.xavier_uniform())( inputs) - inputs = BatchNorm(config)(inputs, input_paddings, update_batch_norm, use_running_average_bn) + inputs = BatchNorm(config)(inputs, + input_paddings, + update_batch_norm, + use_running_average_bn) if config.activation_function_name == 'swish': activation_fn = nn.swish elif config.activation_function_name == 'gelu': @@ -588,7 +600,12 @@ class ConformerBlock(nn.Module): config: ConformerConfig @nn.compact - def __call__(self, inputs, input_paddings, train, update_batch_norm, use_running_average): + def __call__(self, + inputs, + input_paddings, + train, + update_batch_norm, + use_running_average): config = self.config padding_mask = jnp.expand_dims(1 - input_paddings, -1) @@ -631,12 +648,12 @@ def setup(self): .use_dynamic_time_mask_max_frames) @nn.compact - def __call__(self, - inputs, - input_paddings, - train, - update_batch_norm: Optional[bool] = None, - use_running_average_bn: Optional[bool] = None): + def __call__(self, + inputs, + input_paddings, + train, + update_batch_norm: Optional[bool] = None, + use_running_average_bn: Optional[bool] = None): config = self.config outputs = inputs @@ -673,7 +690,11 @@ def __call__(self, # Run the conformer encoder layers. for _ in range(config.num_encoder_layers): - outputs = ConformerBlock(config)(outputs, output_paddings, train, update_batch_norm, use_running_average_bn) + outputs = ConformerBlock(config)(outputs, + output_paddings, + train, + update_batch_norm, + use_running_average_bn) outputs = LayerNorm(config.encoder_dim)(outputs) # Run the decoder which in this case is a trivial projection layer. 
diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py index 6c55acfb0..e362f973b 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -108,7 +108,8 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - use_running_average_bn: Optional[bool]=None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + use_running_average_bn: Optional[bool] = None + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: variables = {'params': params, **model_state} inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs'] is_train_mode = mode == spec.ForwardPassMode.TRAIN diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py index cab73df4a..61400806a 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py @@ -373,7 +373,7 @@ def forward(self, inputs, input_paddings): self.momentum) * mean.detach() self.running_var = (1 - self.momentum) * self.running_var + ( self.momentum) * var.detach() - + else: mean = self.running_mean var = self.running_var From 39132387e411a5e869a98c5b57fbd1e7b0d12194 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 17 Oct 2024 22:37:04 +0000 Subject: [PATCH 15/37] formatting --- .../librispeech_conformer/librispeech_jax/models.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py index db92f56d4..a7f786c32 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py @@ -616,7 +616,12 @@ def __call__(self, inputs, input_paddings, train) inputs = inputs + \ - ConvolutionBlock(config)(inputs, input_paddings, train, update_batch_norm, use_running_average) + ConvolutionBlock(config)(inputs, + input_paddings, + train, + update_batch_norm, + use_running_average + ) inputs = inputs + 0.5 * FeedForwardModule(config=self.config)( inputs, padding_mask, train) From baac0a452871ef5a940c07dbfd64f6d3b9c5427d Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 17 Oct 2024 23:26:42 +0000 Subject: [PATCH 16/37] formatting --- .../librispeech_conformer/librispeech_jax/models.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py index a7f786c32..cb6287c5e 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py @@ -616,10 +616,10 @@ def __call__(self, inputs, input_paddings, train) inputs = inputs + \ - ConvolutionBlock(config)(inputs, - input_paddings, - train, - update_batch_norm, + ConvolutionBlock(config)(inputs, + input_paddings, + train, + update_batch_norm, use_running_average ) From 
087fd5c1e8400d1ab162cbf79e0fad6a828dae5f Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 17 Oct 2024 23:58:10 +0000 Subject: [PATCH 17/37] debugging --- .../workloads/librispeech_conformer/librispeech_jax/workload.py | 1 + 1 file changed, 1 insertion(+) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py index e362f973b..3caf151ab 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -113,6 +113,7 @@ def model_fn( variables = {'params': params, **model_state} inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs'] is_train_mode = mode == spec.ForwardPassMode.TRAIN + print(type(use_running_average_bn)) if update_batch_norm or is_train_mode: (logits, logit_paddings), new_model_state = self._model.apply( variables, From c5c36c291f2c2a5a21bc0b60961a7016039e93ae Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 18 Oct 2024 00:30:40 +0000 Subject: [PATCH 18/37] add seperate model_fn for deepspeech jax without use_running_average_bn --- .../librispeech_jax/workload.py | 1 - .../librispeech_jax/workload.py | 31 +++++++++++++++++++ tests/reference_algorithm_tests.py | 1 + 3 files changed, 32 insertions(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py index 3caf151ab..e362f973b 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -113,7 +113,6 @@ def model_fn( variables = {'params': params, **model_state} inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs'] is_train_mode = mode == spec.ForwardPassMode.TRAIN - print(type(use_running_average_bn)) if update_batch_norm or is_train_mode: (logits, logit_paddings), new_model_state = self._model.apply( variables, diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 8473fac0f..c81b1b0b4 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -55,6 +55,37 @@ def init_model_fn( model_state = jax_utils.replicate(model_state) params = jax_utils.replicate(params) return params, model_state + + def model_fn( + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + use_running_average_bn: Optional[bool] = None + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + variables = {'params': params, **model_state} + inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs'] + is_train_mode = mode == spec.ForwardPassMode.TRAIN + if update_batch_norm or is_train_mode: + (logits, logit_paddings), new_model_state = self._model.apply( + variables, + inputs, + input_paddings, + train=True, + rngs={'dropout' : rng}, + mutable=['batch_stats']) + return (logits, logit_paddings), new_model_state + else: + logits, logit_paddings = 
self._model.apply( + variables, + inputs, + input_paddings, + train=False, + mutable=False) + return (logits, logit_paddings), model_state def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'Dense_0' diff --git a/tests/reference_algorithm_tests.py b/tests/reference_algorithm_tests.py index 74c06e180..5e563d2f9 100644 --- a/tests/reference_algorithm_tests.py +++ b/tests/reference_algorithm_tests.py @@ -408,6 +408,7 @@ def _test_submission(workload_name, workload_path=workload_metadata['workload_path'], workload_class_name=workload_metadata['workload_class_name'], return_class=True) + print(f'Workload class for {workload_name} is {workload_class}') submission_module_path = workloads.convert_filepath_to_module(submission_path) submission_module = importlib.import_module(submission_module_path) From 783aab4a3c2952823290c3e3881b0e423231a2ae Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 18 Oct 2024 00:32:36 +0000 Subject: [PATCH 19/37] fix syntax error --- .../librispeech_deepspeech/librispeech_jax/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py index c81b1b0b4..05fdf90e7 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -56,7 +56,7 @@ def init_model_fn( params = jax_utils.replicate(params) return params, model_state - def model_fn( + def model_fn( self, params: spec.ParameterContainer, augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], From 28e7e21a334001f3e62ce15875f3126af0affbd6 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 18 Oct 2024 00:37:29 +0000 Subject: [PATCH 20/37] fix --- .../librispeech_deepspeech/librispeech_jax/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 05fdf90e7..2d46960ed 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -1,5 +1,5 @@ import functools -from typing import Optional +from typing import Optional, Dict, Tuple from flax import jax_utils import jax From b063f9f7fa5288736c006ff111454e77869ada8f Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 18 Oct 2024 00:43:44 +0000 Subject: [PATCH 21/37] fix import order --- .../librispeech_deepspeech/librispeech_jax/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 2d46960ed..e5030f426 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -1,5 +1,5 @@ import functools -from typing import Optional, Dict, Tuple +from typing import Dict, Optional, Tuple from flax import jax_utils import jax From 894cd872aa07d97e2a4fe0ce9e5a0dcdb790bfb8 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 18 Oct 2024 00:53:11 +0000 Subject: [PATCH 22/37] 
formatting --- .../librispeech_deepspeech/librispeech_jax/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py index e5030f426..a0db6d607 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -55,7 +55,7 @@ def init_model_fn( model_state = jax_utils.replicate(model_state) params = jax_utils.replicate(params) return params, model_state - + def model_fn( self, params: spec.ParameterContainer, From d9c4ee9d3a85f55e069db21b39feaf216ee9d42d Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Fri, 18 Oct 2024 17:19:41 +0200 Subject: [PATCH 23/37] add prepare_for_eval to spec.py --- algorithmic_efficiency/spec.py | 43 ++++++++++++++++++++++++---------- 1 file changed, 30 insertions(+), 13 deletions(-) diff --git a/algorithmic_efficiency/spec.py b/algorithmic_efficiency/spec.py index 25bd7b6d0..b8be5fcaa 100644 --- a/algorithmic_efficiency/spec.py +++ b/algorithmic_efficiency/spec.py @@ -406,19 +406,6 @@ def init_optimizer_state(workload: Workload, RandomState ], UpdateReturn] -PrepareForEvalFn = Callable[[ - Workload, - ParameterContainer, - ParameterTypeTree, - ModelAuxiliaryState, - Hyperparameters, - LossType, - OptimizerState, - List[Tuple[int, float]], - int, - RandomState -], - UpdateReturn] # Each call to this function is considered a "step". @@ -442,6 +429,36 @@ def update_params(workload: Workload, pass +PrepareForEvalFn = Callable[[ + Workload, + ParameterContainer, + ParameterTypeTree, + ModelAuxiliaryState, + Hyperparameters, + LossType, + OptimizerState, + List[Tuple[int, float]], + int, + RandomState +], + UpdateReturn] + + +# Prepare model and optimizer for evaluation. +def prepare_for_eval(workload: Workload, + current_param_container: ParameterContainer, + current_params_types: ParameterTypeTree, + model_state: ModelAuxiliaryState, + hyperparameters: Hyperparameters, + loss_type: LossType, + optimizer_state: OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: RandomState) -> UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + pass + + DataSelectionFn = Callable[[ Workload, Iterator[Dict[str, Any]], From 9caedc5570550708aba7d2695e15b2480ca7cf0f Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Mon, 21 Oct 2024 11:48:35 +0200 Subject: [PATCH 24/37] make prepare_for_eval backward compatible --- submission_runner.py | 42 ++++++++++++++++++++++-------------------- 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index 632cb450b..3ef30ffba 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -378,25 +378,27 @@ def train_once( workload.eval_period_time_sec or train_state['training_complete']): # Prepare for evaluation (timed). 
- with profiler.profile('Prepare for eval'): - del batch - prepare_for_eval_start_time = get_time() - optimizer_state, model_params, model_state = prepare_for_eval( - workload=workload, - current_param_container=model_params, - current_params_types=workload.model_params_types, - model_state=model_state, - hyperparameters=hyperparameters, - loss_type=workload.loss_type, - optimizer_state=optimizer_state, - eval_results=eval_results, - global_step=global_step, - rng=prep_eval_rng) - prepare_for_eval_end_time = get_time() - - # Update sumbission time. - train_state['accumulated_submission_time'] += ( - prepare_for_eval_end_time - prepare_for_eval_start_time) + if prepare_for_eval is not None: + + with profiler.profile('Prepare for eval'): + del batch + prepare_for_eval_start_time = get_time() + optimizer_state, model_params, model_state = prepare_for_eval( + workload=workload, + current_param_container=model_params, + current_params_types=workload.model_params_types, + model_state=model_state, + hyperparameters=hyperparameters, + loss_type=workload.loss_type, + optimizer_state=optimizer_state, + eval_results=eval_results, + global_step=global_step, + rng=prep_eval_rng) + prepare_for_eval_end_time = get_time() + + # Update sumbission time. + train_state['accumulated_submission_time'] += ( + prepare_for_eval_end_time - prepare_for_eval_start_time) # Check if time is remaining, # use 3x the runtime budget for the self-tuning ruleset. @@ -548,7 +550,7 @@ def score_submission_on_workload(workload: spec.Workload, init_optimizer_state = submission_module.init_optimizer_state update_params = submission_module.update_params data_selection = submission_module.data_selection - prepare_for_eval = submission_module.prepare_for_eval + prepare_for_eval = getattr(submission_module, 'prepare_for_eval', None) try: global_batch_size = submission_module.get_batch_size(workload_name) except ValueError: From 4d74d2ccee73ae6096a9fceff6a7b60c80f8f5a7 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Mon, 21 Oct 2024 12:00:29 +0200 Subject: [PATCH 25/37] optional prepare_for_eval arg --- submission_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/submission_runner.py b/submission_runner.py index 3ef30ffba..c396cb027 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -200,7 +200,7 @@ def train_once( init_optimizer_state: spec.InitOptimizerFn, update_params: spec.UpdateParamsFn, data_selection: spec.DataSelectionFn, - prepare_for_eval: spec.PrepareForEvalFn, + prepare_for_eval: Optional[spec.PrepareForEvalFn], hyperparameters: Optional[spec.Hyperparameters], rng_seed: int, rng: spec.RandomState, From ce8eb182043258fc2d7823d84bcc4591441dc159 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Fri, 25 Oct 2024 10:50:06 +0200 Subject: [PATCH 26/37] ensure backward compatibility --- algorithmic_efficiency/spec.py | 8 ++++---- submission_runner.py | 11 +++++++++-- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/algorithmic_efficiency/spec.py b/algorithmic_efficiency/spec.py index 7a16f0040..7bc86b505 100644 --- a/algorithmic_efficiency/spec.py +++ b/algorithmic_efficiency/spec.py @@ -401,10 +401,10 @@ def init_optimizer_state(workload: Workload, Dict[str, Tensor], LossType, OptimizerState, - Dict[str, Any], List[Tuple[int, float]], int, - RandomState + RandomState, + Optional[Dict[str, Any]] ], UpdateReturn] @@ -423,10 +423,10 @@ def update_params(workload: Workload, batch: Dict[str, Tensor], loss_type: LossType, optimizer_state: OptimizerState, - train_state: 
Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: RandomState) -> UpdateReturn: + rng: RandomState, + train_state: Optional[Dict[str, Any]] = None) -> UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" pass diff --git a/submission_runner.py b/submission_runner.py index aef7fafb0..1a66acc58 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -17,11 +17,13 @@ import datetime import gc import importlib +from inspect import signature import itertools import json import os import struct import time +from types import MappingProxyType from typing import Any, Dict, Optional, Tuple from absl import app @@ -273,6 +275,10 @@ def train_once( hyperparameters, opt_init_rng) logging.info('Initializing metrics bundle.') + + # Check if 'train_state' is in the function signature + needs_train_state = 'train_state' in signature(update_params).parameters + # Bookkeeping. train_state = { 'validation_goal_reached': False, @@ -357,10 +363,11 @@ def train_once( batch=batch, loss_type=workload.loss_type, optimizer_state=optimizer_state, - train_state=train_state.copy(), eval_results=eval_results, global_step=global_step, - rng=update_rng) + rng=update_rng, + **({'train_state': MappingProxyType(train_state)} + if needs_train_state else {})) except spec.TrainingCompleteError: train_state['training_complete'] = True global_step += 1 From 5a06a0dc670014db1eeebb1ec1960e984e379db0 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Fri, 25 Oct 2024 11:38:05 +0200 Subject: [PATCH 27/37] adding train_state to all submissions --- .../external_tuning/jax_nadamw_full_budget.py | 5 +++-- .../external_tuning/jax_nadamw_target_setting.py | 5 +++-- .../external_tuning/pytorch_nadamw_full_budget.py | 7 ++++--- .../external_tuning/pytorch_nadamw_target_setting.py | 7 ++++--- .../self_tuning/jax_nadamw_full_budget.py | 5 +++-- .../self_tuning/jax_nadamw_target_setting.py | 5 +++-- .../self_tuning/pytorch_nadamw_full_budget.py | 7 ++++--- .../self_tuning/pytorch_nadamw_target_setting.py | 7 ++++--- .../development_algorithms/cifar/cifar_jax/submission.py | 7 ++++--- .../cifar/cifar_pytorch/submission.py | 7 ++++--- .../development_algorithms/mnist/mnist_jax/submission.py | 7 ++++--- .../mnist/mnist_pytorch/submission.py | 7 ++++--- .../paper_baselines/adafactor/jax/submission.py | 7 ++++--- .../paper_baselines/adafactor/pytorch/submission.py | 7 ++++--- .../paper_baselines/adamw/jax/submission.py | 7 ++++--- .../paper_baselines/adamw/pytorch/submission.py | 7 ++++--- .../paper_baselines/lamb/jax/submission.py | 7 ++++--- .../paper_baselines/lamb/pytorch/submission.py | 7 ++++--- .../paper_baselines/momentum/jax/submission.py | 5 +++-- .../paper_baselines/momentum/pytorch/submission.py | 5 +++-- .../paper_baselines/nadamw/jax/submission.py | 5 +++-- .../paper_baselines/nadamw/pytorch/submission.py | 7 ++++--- .../paper_baselines/nesterov/jax/submission.py | 5 +++-- .../paper_baselines/nesterov/pytorch/submission.py | 5 +++-- reference_algorithms/paper_baselines/sam/jax/submission.py | 5 +++-- .../paper_baselines/sam/pytorch/submission.py | 5 +++-- .../paper_baselines/shampoo/jax/submission.py | 7 ++++--- .../target_setting_algorithms/jax_submission_base.py | 5 +++-- .../target_setting_algorithms/pytorch_submission_base.py | 5 +++-- submissions/template/submission.py | 7 ++++--- 30 files changed, 107 insertions(+), 77 deletions(-) diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py 
b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py index 63cf25fe5..b390639f3 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py @@ -260,10 +260,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py index ab0ee82b1..88725d5c3 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py @@ -260,10 +260,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py index 72a3bf289..3fc054984 100644 --- a/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py +++ b/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py @@ -1,7 +1,7 @@ """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from absl import logging import torch @@ -232,10 +232,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py index 934538b63..f218184d7 100644 --- a/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py +++ b/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py @@ -1,7 +1,7 @@ """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from absl import logging import torch @@ -232,10 +232,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, 
optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py index f6ada3c8e..14bca5730 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py @@ -272,10 +272,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py index 9c7f66c43..4e1e523a2 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py @@ -272,10 +272,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py index f968d4abf..076658093 100644 --- a/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py +++ b/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py @@ -1,7 +1,7 @@ """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from absl import logging import torch @@ -244,10 +244,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py index 14c22141c..d9dde586e 100644 --- a/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py +++ 
b/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py @@ -1,7 +1,7 @@ """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from absl import logging import torch @@ -244,10 +244,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py b/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py index 7e41e9fd7..abb598fd4 100644 --- a/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py +++ b/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py @@ -1,7 +1,7 @@ """Training algorithm track submission functions for CIFAR10.""" import functools -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from flax import jax_utils import jax @@ -118,10 +118,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py b/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py index 81110bae6..def94296b 100644 --- a/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py +++ b/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py @@ -1,6 +1,6 @@ """Training algorithm track submission functions for CIFAR10.""" -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple import torch from torch.optim.lr_scheduler import CosineAnnealingLR @@ -61,10 +61,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del current_params_types del hyperparameters diff --git a/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py b/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py index 3f75c9904..4fd7d2212 100644 --- a/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py +++ b/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py @@ -1,7 +1,7 @@ """Training algorithm track submission functions for MNIST.""" import functools -from 
typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from flax import jax_utils import jax @@ -83,10 +83,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py b/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py index d326f4035..c14de49ab 100644 --- a/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py +++ b/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py @@ -1,6 +1,6 @@ """Training algorithm track submission functions for MNIST.""" -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple import torch @@ -40,10 +40,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del hyperparameters del loss_type diff --git a/reference_algorithms/paper_baselines/adafactor/jax/submission.py b/reference_algorithms/paper_baselines/adafactor/jax/submission.py index 39cf3d4f9..ce4bfebb0 100644 --- a/reference_algorithms/paper_baselines/adafactor/jax/submission.py +++ b/reference_algorithms/paper_baselines/adafactor/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for an Adafactor optimizer with warmup+cosine LR in Jax.""" import functools -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from flax import jax_utils import jax @@ -118,10 +118,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py b/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py index 880f9168d..17c5d8a03 100644 --- a/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py @@ -1,7 +1,7 @@ """Submission file for Adafactor in PyTorch.""" from functools import partial -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from absl import logging import torch @@ -198,10 +198,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: 
spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index 06eeacb39..793a3f1de 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for an AdamW optimizer with warmup+cosine LR in Jax.""" import functools -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from flax import jax_utils import jax @@ -118,10 +118,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py index 0710fb9a0..225924b98 100644 --- a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py @@ -1,6 +1,6 @@ """Submission file for an AdamW optimizer with warmup+cosine LR in PyTorch.""" -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from absl import logging import torch @@ -59,10 +59,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/lamb/jax/submission.py b/reference_algorithms/paper_baselines/lamb/jax/submission.py index 891da63be..63b0cb219 100644 --- a/reference_algorithms/paper_baselines/lamb/jax/submission.py +++ b/reference_algorithms/paper_baselines/lamb/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for a LAMB optimizer with warmup+cosine LR in Jax.""" import functools -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from flax import jax_utils import jax @@ -126,10 +126,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" 
del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/lamb/pytorch/submission.py b/reference_algorithms/paper_baselines/lamb/pytorch/submission.py index 7886dc75d..7c545d7ab 100644 --- a/reference_algorithms/paper_baselines/lamb/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/lamb/pytorch/submission.py @@ -1,7 +1,7 @@ """Submission file for a LAMB optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from absl import logging import torch @@ -197,10 +197,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/momentum/jax/submission.py b/reference_algorithms/paper_baselines/momentum/jax/submission.py index dc101896b..b173ba8ba 100644 --- a/reference_algorithms/paper_baselines/momentum/jax/submission.py +++ b/reference_algorithms/paper_baselines/momentum/jax/submission.py @@ -152,10 +152,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/momentum/pytorch/submission.py b/reference_algorithms/paper_baselines/momentum/pytorch/submission.py index 52aba82bf..c063f0a64 100644 --- a/reference_algorithms/paper_baselines/momentum/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/momentum/pytorch/submission.py @@ -75,10 +75,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/nadamw/jax/submission.py b/reference_algorithms/paper_baselines/nadamw/jax/submission.py index 63cf25fe5..b390639f3 100644 --- a/reference_algorithms/paper_baselines/nadamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/nadamw/jax/submission.py @@ -260,10 +260,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, 
updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py b/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py index 72a3bf289..3fc054984 100644 --- a/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py @@ -1,7 +1,7 @@ """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from absl import logging import torch @@ -232,10 +232,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py index e47c7fa0c..35ef2bfa8 100644 --- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py @@ -152,10 +152,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py b/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py index 442949866..0b7cc570b 100644 --- a/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py @@ -75,10 +75,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/sam/jax/submission.py b/reference_algorithms/paper_baselines/sam/jax/submission.py index 95bea68aa..da2208519 100644 --- a/reference_algorithms/paper_baselines/sam/jax/submission.py +++ b/reference_algorithms/paper_baselines/sam/jax/submission.py @@ -205,10 +205,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, 
updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/sam/pytorch/submission.py b/reference_algorithms/paper_baselines/sam/pytorch/submission.py index 15b6b6858..a793673f9 100644 --- a/reference_algorithms/paper_baselines/sam/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/sam/pytorch/submission.py @@ -139,10 +139,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/shampoo/jax/submission.py b/reference_algorithms/paper_baselines/shampoo/jax/submission.py index e853a821b..504dff0d1 100644 --- a/reference_algorithms/paper_baselines/shampoo/jax/submission.py +++ b/reference_algorithms/paper_baselines/shampoo/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for a Shampoo optimizer with warmup+cosine LR in Jax.""" import functools -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from flax import jax_utils import jax @@ -121,10 +121,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/target_setting_algorithms/jax_submission_base.py b/reference_algorithms/target_setting_algorithms/jax_submission_base.py index a98d134fc..999422fb0 100644 --- a/reference_algorithms/target_setting_algorithms/jax_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/jax_submission_base.py @@ -77,10 +77,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py b/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py index 586429e37..92f222a18 100644 --- a/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py @@ -20,10 +20,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = 
None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/submissions/template/submission.py b/submissions/template/submission.py index 9bfb23367..b8a394322 100644 --- a/submissions/template/submission.py +++ b/submissions/template/submission.py @@ -4,7 +4,7 @@ and https://github.com/mlcommons/algorithmic-efficiency/blob/main/DOCUMENTATION.md#disallowed-submissions for guidelines. """ -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from algorithmic_efficiency import spec @@ -30,10 +30,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """ Returns: (new_optimizer_state, update_fn) From 86114ef970c832ea5b8ed15c47856e6d6d325df3 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Fri, 25 Oct 2024 11:47:47 +0200 Subject: [PATCH 28/37] fix missing import Optional --- reference_algorithms/paper_baselines/momentum/jax/submission.py | 2 +- .../paper_baselines/momentum/pytorch/submission.py | 2 +- reference_algorithms/paper_baselines/nesterov/jax/submission.py | 2 +- .../paper_baselines/nesterov/pytorch/submission.py | 2 +- reference_algorithms/paper_baselines/sam/pytorch/submission.py | 2 +- .../target_setting_algorithms/jax_submission_base.py | 2 +- .../target_setting_algorithms/pytorch_submission_base.py | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/reference_algorithms/paper_baselines/momentum/jax/submission.py b/reference_algorithms/paper_baselines/momentum/jax/submission.py index b173ba8ba..346abe652 100644 --- a/reference_algorithms/paper_baselines/momentum/jax/submission.py +++ b/reference_algorithms/paper_baselines/momentum/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for a SGD with HeavyBall momentum optimizer in Jax.""" import functools -from typing import Any, Callable, Dict, Iterator, List, Tuple +from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple from flax import jax_utils import jax diff --git a/reference_algorithms/paper_baselines/momentum/pytorch/submission.py b/reference_algorithms/paper_baselines/momentum/pytorch/submission.py index c063f0a64..090a8bc01 100644 --- a/reference_algorithms/paper_baselines/momentum/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/momentum/pytorch/submission.py @@ -1,6 +1,6 @@ """Submission file for a SGD with HeavyBall momentum optimizer in PyTorch.""" -from typing import Any, Callable, Dict, Iterator, List, Tuple +from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple from absl import logging import optax diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py index 35ef2bfa8..fa5329778 100644 --- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for a SGD with Nesterov momentum optimizer in Jax.""" import functools -from typing import Any, Callable, Dict, Iterator, List, Tuple +from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple from flax import jax_utils 
import jax diff --git a/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py b/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py index 0b7cc570b..ce0854f7d 100644 --- a/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py @@ -1,6 +1,6 @@ """Submission file for a SGD with Nesterov momentum optimizer in PyTorch.""" -from typing import Any, Callable, Dict, Iterator, List, Tuple +from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple from absl import logging import optax diff --git a/reference_algorithms/paper_baselines/sam/pytorch/submission.py b/reference_algorithms/paper_baselines/sam/pytorch/submission.py index a793673f9..e9c9c9bc4 100644 --- a/reference_algorithms/paper_baselines/sam/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/sam/pytorch/submission.py @@ -1,6 +1,6 @@ """Submission file for a SAM optimizer with warmup+cosine LR in PyTorch.""" -from typing import Any, Callable, Dict, Iterator, List, Tuple +from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple from absl import logging import torch diff --git a/reference_algorithms/target_setting_algorithms/jax_submission_base.py b/reference_algorithms/target_setting_algorithms/jax_submission_base.py index 999422fb0..6914da94e 100644 --- a/reference_algorithms/target_setting_algorithms/jax_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/jax_submission_base.py @@ -1,6 +1,6 @@ """Update submission function in Jax.""" import functools -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Optional, Tuple import jax from jax import lax diff --git a/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py b/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py index 92f222a18..606253e32 100644 --- a/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py @@ -1,6 +1,6 @@ """Batch size and update submission functions in PyTorch.""" -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Optional, Tuple from absl import logging import torch From 1965241b5bc1995206569ca7f786f8f8e098a7ed Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Fri, 25 Oct 2024 11:52:50 +0200 Subject: [PATCH 29/37] fix yapf --- .../external_tuning/jax_nadamw_full_budget.py | 26 +++++++++---------- .../jax_nadamw_target_setting.py | 26 +++++++++---------- .../pytorch_nadamw_full_budget.py | 26 +++++++++---------- .../pytorch_nadamw_target_setting.py | 26 +++++++++---------- .../self_tuning/jax_nadamw_full_budget.py | 26 +++++++++---------- .../self_tuning/jax_nadamw_target_setting.py | 26 +++++++++---------- .../self_tuning/pytorch_nadamw_full_budget.py | 26 +++++++++---------- .../pytorch_nadamw_target_setting.py | 26 +++++++++---------- .../cifar/cifar_jax/submission.py | 26 +++++++++---------- .../cifar/cifar_pytorch/submission.py | 26 +++++++++---------- .../mnist/mnist_jax/submission.py | 26 +++++++++---------- .../mnist/mnist_pytorch/submission.py | 26 +++++++++---------- .../adafactor/jax/submission.py | 26 +++++++++---------- .../adafactor/pytorch/submission.py | 26 +++++++++---------- .../paper_baselines/adamw/jax/submission.py | 26 +++++++++---------- .../adamw/pytorch/submission.py | 26 +++++++++---------- .../paper_baselines/lamb/jax/submission.py | 26 +++++++++---------- 
.../lamb/pytorch/submission.py | 26 +++++++++---------- .../momentum/jax/submission.py | 26 +++++++++---------- .../momentum/pytorch/submission.py | 26 +++++++++---------- .../paper_baselines/nadamw/jax/submission.py | 26 +++++++++---------- .../nadamw/pytorch/submission.py | 26 +++++++++---------- .../nesterov/jax/submission.py | 26 +++++++++---------- .../nesterov/pytorch/submission.py | 26 +++++++++---------- .../paper_baselines/sam/jax/submission.py | 26 +++++++++---------- .../paper_baselines/sam/pytorch/submission.py | 26 +++++++++---------- .../paper_baselines/shampoo/jax/submission.py | 26 +++++++++---------- .../jax_submission_base.py | 26 +++++++++---------- .../pytorch_submission_base.py | 26 +++++++++---------- submissions/template/submission.py | 26 +++++++++---------- 30 files changed, 390 insertions(+), 390 deletions(-) diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py index b390639f3..a235c50cd 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py @@ -252,19 +252,19 @@ def _loss_fn(params): return new_optimizer_state, updated_params, new_model_state, loss, grad_norm -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py index 88725d5c3..06413f681 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py @@ -252,19 +252,19 @@ def _loss_fn(params): return new_optimizer_state, updated_params, new_model_state, loss, grad_norm -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: 
spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py index 3fc054984..0e654d43c 100644 --- a/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py +++ b/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py @@ -224,19 +224,19 @@ def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): return optimizer_state -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py index f218184d7..dd0b8b076 100644 --- a/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py +++ b/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py @@ -224,19 +224,19 @@ def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): return optimizer_state -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del 
current_params_types del loss_type diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py index 14bca5730..a9f048f03 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py @@ -264,19 +264,19 @@ def _loss_fn(params): return new_optimizer_state, updated_params, new_model_state, loss, grad_norm -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py index 4e1e523a2..4d3d2b341 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py @@ -264,19 +264,19 @@ def _loss_fn(params): return new_optimizer_state, updated_params, new_model_state, loss, grad_norm -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py index 076658093..5a5319957 100644 --- a/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py +++ b/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py @@ -236,19 +236,19 @@ def pytorch_cosine_warmup(step_hint: int, 
hyperparameters, optimizer): return optimizer_state -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py index d9dde586e..699b11268 100644 --- a/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py +++ b/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py @@ -236,19 +236,19 @@ def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): return optimizer_state -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py b/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py index abb598fd4..97d6df9f1 100644 --- a/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py +++ b/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py @@ -110,19 +110,19 @@ def _loss_fn(params): # Not allowed to update the model parameters, hyperparameters, global step, or # optimzier state. 
-def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py b/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py index def94296b..853064957 100644 --- a/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py +++ b/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py @@ -53,19 +53,19 @@ def init_optimizer_state(workload: spec.Workload, return optimizer_state -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del current_params_types del hyperparameters diff --git a/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py b/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py index 4fd7d2212..6d05954a1 100644 --- a/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py +++ b/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py @@ -75,19 +75,19 @@ def loss_fn(params): return new_optimizer_state, updated_params, new_model_state -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> 
spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py b/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py index c14de49ab..d27d7f742 100644 --- a/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py +++ b/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py @@ -32,19 +32,19 @@ def init_optimizer_state(workload: spec.Workload, return optimizer_state -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del hyperparameters del loss_type diff --git a/reference_algorithms/paper_baselines/adafactor/jax/submission.py b/reference_algorithms/paper_baselines/adafactor/jax/submission.py index ce4bfebb0..efe238f26 100644 --- a/reference_algorithms/paper_baselines/adafactor/jax/submission.py +++ b/reference_algorithms/paper_baselines/adafactor/jax/submission.py @@ -110,19 +110,19 @@ def _loss_fn(params): return new_optimizer_state, updated_params, new_model_state, loss, grad_norm -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> 
spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py b/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py index 17c5d8a03..377468612 100644 --- a/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py @@ -190,19 +190,19 @@ def step(self, closure=None): return loss -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index 793a3f1de..31e0a6801 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -110,19 +110,19 @@ def _loss_fn(params): return new_optimizer_state, updated_params, new_model_state, loss, grad_norm -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py index 225924b98..27ceaeef7 100644 --- a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py @@ -51,19 +51,19 @@ def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): 
return optimizer_state -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/lamb/jax/submission.py b/reference_algorithms/paper_baselines/lamb/jax/submission.py index 63b0cb219..be13ab540 100644 --- a/reference_algorithms/paper_baselines/lamb/jax/submission.py +++ b/reference_algorithms/paper_baselines/lamb/jax/submission.py @@ -118,19 +118,19 @@ def _loss_fn(params): return new_optimizer_state, updated_params, new_model_state, loss, grad_norm -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/lamb/pytorch/submission.py b/reference_algorithms/paper_baselines/lamb/pytorch/submission.py index 7c545d7ab..d3b491e75 100644 --- a/reference_algorithms/paper_baselines/lamb/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/lamb/pytorch/submission.py @@ -189,19 +189,19 @@ def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): return optimizer_state -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def 
update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/momentum/jax/submission.py b/reference_algorithms/paper_baselines/momentum/jax/submission.py index 346abe652..3eef23942 100644 --- a/reference_algorithms/paper_baselines/momentum/jax/submission.py +++ b/reference_algorithms/paper_baselines/momentum/jax/submission.py @@ -144,19 +144,19 @@ def _loss_fn(params): return new_optimizer_state, updated_params, new_model_state, loss, grad_norm -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/momentum/pytorch/submission.py b/reference_algorithms/paper_baselines/momentum/pytorch/submission.py index 090a8bc01..cf474ebdd 100644 --- a/reference_algorithms/paper_baselines/momentum/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/momentum/pytorch/submission.py @@ -67,19 +67,19 @@ def create_lr_schedule_fn( return lr_schedule_fn -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, 
updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/nadamw/jax/submission.py b/reference_algorithms/paper_baselines/nadamw/jax/submission.py index b390639f3..a235c50cd 100644 --- a/reference_algorithms/paper_baselines/nadamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/nadamw/jax/submission.py @@ -252,19 +252,19 @@ def _loss_fn(params): return new_optimizer_state, updated_params, new_model_state, loss, grad_norm -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py b/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py index 3fc054984..0e654d43c 100644 --- a/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py @@ -224,19 +224,19 @@ def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): return optimizer_state -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py index fa5329778..553b3e478 100644 --- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py @@ -144,19 +144,19 @@ def _loss_fn(params): return new_optimizer_state, updated_params, new_model_state, loss, grad_norm -def 
update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py b/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py index ce0854f7d..ba8c69e6c 100644 --- a/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py @@ -67,19 +67,19 @@ def create_lr_schedule_fn( return lr_schedule_fn -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/sam/jax/submission.py b/reference_algorithms/paper_baselines/sam/jax/submission.py index da2208519..b5c7069cb 100644 --- a/reference_algorithms/paper_baselines/sam/jax/submission.py +++ b/reference_algorithms/paper_baselines/sam/jax/submission.py @@ -197,19 +197,19 @@ def _loss_fn(params, update_batch_norm=True): return new_optimizer_state, updated_params, new_model_state, loss, grad_norm -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + 
current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/sam/pytorch/submission.py b/reference_algorithms/paper_baselines/sam/pytorch/submission.py index e9c9c9bc4..b69945d51 100644 --- a/reference_algorithms/paper_baselines/sam/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/sam/pytorch/submission.py @@ -131,19 +131,19 @@ def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): return optimizer_state -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/shampoo/jax/submission.py b/reference_algorithms/paper_baselines/shampoo/jax/submission.py index 504dff0d1..8f0b311a0 100644 --- a/reference_algorithms/paper_baselines/shampoo/jax/submission.py +++ b/reference_algorithms/paper_baselines/shampoo/jax/submission.py @@ -113,19 +113,19 @@ def _loss_fn(params): return new_optimizer_state, updated_params, new_model_state, loss, grad_norm -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" 
del current_params_types del loss_type diff --git a/reference_algorithms/target_setting_algorithms/jax_submission_base.py b/reference_algorithms/target_setting_algorithms/jax_submission_base.py index 6914da94e..51b20181b 100644 --- a/reference_algorithms/target_setting_algorithms/jax_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/jax_submission_base.py @@ -69,19 +69,19 @@ def _loss_fn(params): return new_optimizer_state, updated_params, new_model_state, loss, grad_norm -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py b/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py index 606253e32..6203c58b3 100644 --- a/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py @@ -12,19 +12,19 @@ USE_PYTORCH_DDP = pytorch_setup()[0] -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/submissions/template/submission.py b/submissions/template/submission.py index b8a394322..ab98c9958 100644 --- a/submissions/template/submission.py +++ b/submissions/template/submission.py @@ -22,19 +22,19 @@ def init_optimizer_state(workload: spec.Workload, pass -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - 
hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """ Returns: (new_optimizer_state, update_fn) From 8cc4f4a0278406fb3b2ad5a3f9f5d4b5fd329daf Mon Sep 17 00:00:00 2001 From: init-22 Date: Thu, 31 Oct 2024 22:26:53 +0530 Subject: [PATCH 30/37] default dropout rates for workloads are added --- DOCUMENTATION.md | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/DOCUMENTATION.md b/DOCUMENTATION.md index 607f47ead..851d85dbc 100644 --- a/DOCUMENTATION.md +++ b/DOCUMENTATION.md @@ -400,6 +400,22 @@ Submissions will be scored based on their performance on the [fixed workload](#f Furthermore, a less computationally expensive subset of the fixed workloads is collected with the [qualification set](#qualification-set). Submitters without enough compute resources to self-report on the full set of fixed and held-out workloads can instead self-report on this smaller qualification set. Well-performing submissions can thereby qualify for computational resources provided by sponsors of the benchmark to be scored on the full benchmark set. +#### Default Dropout Values for Different Workloads: + +| Workload | Dropout Values | +|------------------------|------------------------------------------------------------------------------------------------------| +| cifar | dropout not used | +| criteo 1tb | dropout_rate: 0.0 | +| fastmri | dropout_rate: 0.0 | +| imagenet_resnet | dropout not used | +| imagenet_vit | dropout_rate: 0.0 | +| librispeech_conformer | attention_dropout_rate: 0.0
attention_residual_dropout_rate: 0.1<br>conv_residual_dropout_rate: 0.0<br>feed_forward_dropout_rate: 0.0<br>feed_forward_residual_dropout_rate: 0.1<br>input_dropout_rate: 0.1 |
+| librispeech_deepspeech | input_dropout_rate: 0.1<br>feed_forward_dropout_rate: 0.1<br>(Only for JAX - dropout_rate in CudnnLSTM class: 0.0) |
+| mnist | dropout not used |
+| ogbg | dropout_rate: 0.1 |
+| wmt | dropout_rate: 0.1<br>
attention_dropout_rate: 0.1 | + + NOTE: Submitters are no longer required to self-report results for AlgoPerf competition v0.5. #### Fixed workloads From a6fc879e119cc805bafd98ecd086b1243b9a42c7 Mon Sep 17 00:00:00 2001 From: init-22 Date: Thu, 31 Oct 2024 22:45:21 +0530 Subject: [PATCH 31/37] adding the dropout info in fixed workload section --- DOCUMENTATION.md | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/DOCUMENTATION.md b/DOCUMENTATION.md index 851d85dbc..2decbcb46 100644 --- a/DOCUMENTATION.md +++ b/DOCUMENTATION.md @@ -400,22 +400,6 @@ Submissions will be scored based on their performance on the [fixed workload](#f Furthermore, a less computationally expensive subset of the fixed workloads is collected with the [qualification set](#qualification-set). Submitters without enough compute resources to self-report on the full set of fixed and held-out workloads can instead self-report on this smaller qualification set. Well-performing submissions can thereby qualify for computational resources provided by sponsors of the benchmark to be scored on the full benchmark set. -#### Default Dropout Values for Different Workloads: - -| Workload | Dropout Values | -|------------------------|------------------------------------------------------------------------------------------------------| -| cifar | dropout not used | -| criteo 1tb | dropout_rate: 0.0 | -| fastmri | dropout_rate: 0.0 | -| imagenet_resnet | dropout not used | -| imagenet_vit | dropout_rate: 0.0 | -| librispeech_conformer | attention_dropout_rate: 0.0
attention_residual_dropout_rate: 0.1<br>conv_residual_dropout_rate: 0.0<br>feed_forward_dropout_rate: 0.0<br>feed_forward_residual_dropout_rate: 0.1<br>input_dropout_rate: 0.1 |
-| librispeech_deepspeech | input_dropout_rate: 0.1<br>feed_forward_dropout_rate: 0.1<br>(Only for JAX - dropout_rate in CudnnLSTM class: 0.0) |
-| mnist | dropout not used |
-| ogbg | dropout_rate: 0.1 |
-| wmt | dropout_rate: 0.1<br>
attention_dropout_rate: 0.1 | - - NOTE: Submitters are no longer required to self-report results for AlgoPerf competition v0.5. #### Fixed workloads @@ -433,6 +417,23 @@ The currently eight fixed workloads are: | **7** | Molecular property prediction | OGBG | GNN | CE | mAP | 0.28098 | 0.268729 | 18,477 | | **8** | Translation | WMT | Transformer | CE | BLEU | 30.8491 | 30.7219 | 48,151 | +#### Default Dropout Values for Different Workloads: + +| Workload | Dropout Values | +|------------------------|------------------------------------------------------------------------------------------------------| +| cifar | dropout not used | +| criteo 1tb | dropout_rate: 0.0 | +| fastmri | dropout_rate: 0.0 | +| imagenet_resnet | dropout not used | +| imagenet_vit | dropout_rate: 0.0 | +| librispeech_conformer | attention_dropout_rate: 0.0
attention_residual_dropout_rate: 0.1<br>conv_residual_dropout_rate: 0.0<br>feed_forward_dropout_rate: 0.0<br>feed_forward_residual_dropout_rate: 0.1<br>input_dropout_rate: 0.1 |
+| librispeech_deepspeech | input_dropout_rate: 0.1<br>feed_forward_dropout_rate: 0.1<br>(Only for JAX - dropout_rate in CudnnLSTM class: 0.0) |
+| mnist | dropout not used |
+| ogbg | dropout_rate: 0.1 |
+| wmt | dropout_rate: 0.1<br>
attention_dropout_rate: 0.1 | + + + #### Randomized workloads In addition to the [fixed and known workloads](#fixed-workloads), there will also be randomized workloads in our benchmark. These randomized workloads will introduce minor modifications to a fixed workload (e.g. small model changes). The exact instances of these randomized workloads will only be created after the submission deadline and are thus unknown to both the submitters as well as the benchmark organizers. The instructions for creating them, i.e. providing a set or distribution of workloads to sample from, will be defined by this working group and made public with the call for submissions, to allow the members of this working group to submit as well as ensure that they do not possess any additional information compared to other submitters. We will refer to the unspecific workloads as *randomized workloads*, e.g. the set or distribution. The specific instance of such a randomized workload we call a *held-out workload*. That is, a held-out workload is a specific sample of a randomized workload that is used for one iteration of the benchmark. While we may reuse randomized workloads between iterations of the benchmark, new held-out workloads will be sampled for each new benchmark iteration. From 19838992f8edb766860f655670215c037ddcc834 Mon Sep 17 00:00:00 2001 From: init-22 Date: Thu, 31 Oct 2024 22:47:07 +0530 Subject: [PATCH 32/37] removing bold headings --- DOCUMENTATION.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DOCUMENTATION.md b/DOCUMENTATION.md index 2decbcb46..0c9c429c6 100644 --- a/DOCUMENTATION.md +++ b/DOCUMENTATION.md @@ -417,7 +417,7 @@ The currently eight fixed workloads are: | **7** | Molecular property prediction | OGBG | GNN | CE | mAP | 0.28098 | 0.268729 | 18,477 | | **8** | Translation | WMT | Transformer | CE | BLEU | 30.8491 | 30.7219 | 48,151 | -#### Default Dropout Values for Different Workloads: +Default Dropout Values for Different Workloads: | Workload | Dropout Values | |------------------------|------------------------------------------------------------------------------------------------------| From 76b084b556af6bd58d1fbf40d5215cce510146b9 Mon Sep 17 00:00:00 2001 From: init-22 Date: Fri, 15 Nov 2024 00:55:52 +0530 Subject: [PATCH 33/37] fix: removed cifar10 and mnist --- DOCUMENTATION.md | 4 ---- 1 file changed, 4 deletions(-) diff --git a/DOCUMENTATION.md b/DOCUMENTATION.md index 0c9c429c6..990656d38 100644 --- a/DOCUMENTATION.md +++ b/DOCUMENTATION.md @@ -421,19 +421,15 @@ Default Dropout Values for Different Workloads: | Workload | Dropout Values | |------------------------|------------------------------------------------------------------------------------------------------| -| cifar | dropout not used | | criteo 1tb | dropout_rate: 0.0 | | fastmri | dropout_rate: 0.0 | | imagenet_resnet | dropout not used | | imagenet_vit | dropout_rate: 0.0 | | librispeech_conformer | attention_dropout_rate: 0.0
attention_residual_dropout_rate: 0.1<br>conv_residual_dropout_rate: 0.0<br>feed_forward_dropout_rate: 0.0<br>feed_forward_residual_dropout_rate: 0.1<br>input_dropout_rate: 0.1 |
 | librispeech_deepspeech | input_dropout_rate: 0.1<br>feed_forward_dropout_rate: 0.1<br>(Only for JAX - dropout_rate in CudnnLSTM class: 0.0) |
-| mnist | dropout not used |
 | ogbg | dropout_rate: 0.1 |
 | wmt | dropout_rate: 0.1<br>
attention_dropout_rate: 0.1 | - - #### Randomized workloads In addition to the [fixed and known workloads](#fixed-workloads), there will also be randomized workloads in our benchmark. These randomized workloads will introduce minor modifications to a fixed workload (e.g. small model changes). The exact instances of these randomized workloads will only be created after the submission deadline and are thus unknown to both the submitters as well as the benchmark organizers. The instructions for creating them, i.e. providing a set or distribution of workloads to sample from, will be defined by this working group and made public with the call for submissions, to allow the members of this working group to submit as well as ensure that they do not possess any additional information compared to other submitters. We will refer to the unspecific workloads as *randomized workloads*, e.g. the set or distribution. The specific instance of such a randomized workload we call a *held-out workload*. That is, a held-out workload is a specific sample of a randomized workload that is used for one iteration of the benchmark. While we may reuse randomized workloads between iterations of the benchmark, new held-out workloads will be sampled for each new benchmark iteration. From f72028f1236ab562a6d1d50f43ec099de9325bbb Mon Sep 17 00:00:00 2001 From: init-22 Date: Tue, 19 Nov 2024 21:07:46 +0530 Subject: [PATCH 34/37] fix: ran yapf for passing the checks --- scoring/score_submissions.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scoring/score_submissions.py b/scoring/score_submissions.py index 59295b686..8cc06b15f 100644 --- a/scoring/score_submissions.py +++ b/scoring/score_submissions.py @@ -211,7 +211,8 @@ def main(_): verbosity=0, self_tuning_ruleset=FLAGS.self_tuning_ruleset, strict=FLAGS.strict, - output_dir=FLAGS.output_dir,) + output_dir=FLAGS.output_dir, + ) if not os.path.exists(FLAGS.output_dir): os.mkdir(FLAGS.output_dir) performance_profile.plot_performance_profiles( From f4c17f0223c2c50c8d4fbdf16cb6271c31d9c989 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 21 Nov 2024 09:33:38 -0800 Subject: [PATCH 35/37] Update README.md fix pytorch command --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 5a1f10a33..516c8eb1b 100644 --- a/README.md +++ b/README.md @@ -89,7 +89,7 @@ python3 submission_runner.py \ --workload=mnist \ --experiment_dir=$HOME/experiments \ --experiment_name=my_first_experiment \ - --submission_path=reference_algorithms/paper_baselines/adamw/jax/submission.py \ + --submission_path=reference_algorithms/paper_baselines/adamw/pytorch/submission.py \ --tuning_search_space=reference_algorithms/paper_baselines/adamw/tuning_search_space.json ``` From bb2b361c3bda4085c45134de79a4ebf25f1f64f3 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 21 Nov 2024 23:38:41 +0000 Subject: [PATCH 36/37] formatting --- scoring/performance_profile.py | 2 +- scoring/score_submissions.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/scoring/performance_profile.py b/scoring/performance_profile.py index 0d5ca9770..f4f2d5679 100644 --- a/scoring/performance_profile.py +++ b/scoring/performance_profile.py @@ -321,7 +321,7 @@ def compute_performance_profiles(submissions, df = df[BASE_WORKLOADS + HELDOUT_WORKLOADS] # Sort workloads alphabetically (for better display) df = df.reindex(sorted(df.columns), axis=1) - + # Save time to target dataframe df.to_csv(os.path.join(output_dir, 
'time_to_targets.csv')) # For each held-out workload set to inf if the base workload is inf or nan diff --git a/scoring/score_submissions.py b/scoring/score_submissions.py index 59295b686..8cc06b15f 100644 --- a/scoring/score_submissions.py +++ b/scoring/score_submissions.py @@ -211,7 +211,8 @@ def main(_): verbosity=0, self_tuning_ruleset=FLAGS.self_tuning_ruleset, strict=FLAGS.strict, - output_dir=FLAGS.output_dir,) + output_dir=FLAGS.output_dir, + ) if not os.path.exists(FLAGS.output_dir): os.mkdir(FLAGS.output_dir) performance_profile.plot_performance_profiles( From d8f07b73c7a6b0d049c513e4b846696ed7df1da8 Mon Sep 17 00:00:00 2001 From: init-22 Date: Fri, 22 Nov 2024 22:15:01 +0530 Subject: [PATCH 37/37] fix: triggering the checks again --- scoring/performance_profile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scoring/performance_profile.py b/scoring/performance_profile.py index 0d5ca9770..f4f2d5679 100644 --- a/scoring/performance_profile.py +++ b/scoring/performance_profile.py @@ -321,7 +321,7 @@ def compute_performance_profiles(submissions, df = df[BASE_WORKLOADS + HELDOUT_WORKLOADS] # Sort workloads alphabetically (for better display) df = df.reindex(sorted(df.columns), axis=1) - + # Save time to target dataframe df.to_csv(os.path.join(output_dir, 'time_to_targets.csv')) # For each held-out workload set to inf if the base workload is inf or nan