From 6eba8940381c3134114bd665a927a683bfd889aa Mon Sep 17 00:00:00 2001 From: kdziedzic68 Date: Thu, 31 Oct 2024 13:57:53 +0100 Subject: [PATCH] remove redundant choice_idx parameter from recursion --- .../src/ragbits/evaluate/optimizer.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/packages/ragbits-evaluate/src/ragbits/evaluate/optimizer.py b/packages/ragbits-evaluate/src/ragbits/evaluate/optimizer.py index 0a3fe15a..5bd7e83e 100644 --- a/packages/ragbits-evaluate/src/ragbits/evaluate/optimizer.py +++ b/packages/ragbits-evaluate/src/ragbits/evaluate/optimizer.py @@ -72,8 +72,7 @@ def _score(self, pipeline: EvaluationPipeline, dataloader: DataLoader, metrics: results = event_loop.run_until_complete(evaluator.compute(pipeline=pipeline, dataloader=dataloader, metrics=metrics)) return results["metrics"] - def _set_values_for_optimized_params(self, cfg: DictConfig, trial: optuna.Trial, - ancestors: list[str], choice_idx: int | None = None) -> None: + def _set_values_for_optimized_params(self, cfg: DictConfig, trial: optuna.Trial, ancestors: list[str]) -> None: """ Modifies the original dictionary in place, replacing values for keys that contain 'opt_params_range' with random numbers between the specified range [A, B] or for @@ -82,7 +81,7 @@ def _set_values_for_optimized_params(self, cfg: DictConfig, trial: optuna.Trial, for key, value in cfg.items(): if isinstance(value, DictConfig): if value.get("optimize"): - param_id = f"{'.'.join(ancestors)}.{key}.{str(choice_idx)}" + param_id = f"{'.'.join(ancestors)}.{key}" choices = value.get("choices") values_range = value.get("range") assert not (choices and values_range), "Choices and range cannot be defined in couple" @@ -100,15 +99,15 @@ def _set_values_for_optimized_params(self, cfg: DictConfig, trial: optuna.Trial, choice_idx = trial.suggest_categorical(name=param_id, choices=choices_index) choice = choices[choice_idx] if isinstance(choice, DictConfig): - self._set_values_for_optimized_params(choice, trial, ancestors + [key], choice_idx) + self._set_values_for_optimized_params(choice, trial, ancestors + [key, choice_idx]) cfg[key] = choice choice_idx = None else: - self._set_values_for_optimized_params(value, trial, ancestors + [key], choice_idx) + self._set_values_for_optimized_params(value, trial, ancestors + [key]) elif isinstance(value, ListConfig): for param in value: if isinstance(param, DictConfig): - self._set_values_for_optimized_params(param, trial, ancestors + [key], choice_idx) + self._set_values_for_optimized_params(param, trial, ancestors + [key])