Skip to content

Commit

Permalink
remove redundant choice_idx parameter from recursion
Browse files Browse the repository at this point in the history
  • Loading branch information
kdziedzic68 committed Oct 31, 2024
1 parent 6c1a62c commit 6eba894
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions packages/ragbits-evaluate/src/ragbits/evaluate/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,7 @@ def _score(self, pipeline: EvaluationPipeline, dataloader: DataLoader, metrics:
results = event_loop.run_until_complete(evaluator.compute(pipeline=pipeline, dataloader=dataloader, metrics=metrics))
return results["metrics"]

def _set_values_for_optimized_params(self, cfg: DictConfig, trial: optuna.Trial,
ancestors: list[str], choice_idx: int | None = None) -> None:
def _set_values_for_optimized_params(self, cfg: DictConfig, trial: optuna.Trial, ancestors: list[str]) -> None:
"""
Modifies the original dictionary in place, replacing values for keys that contain
'opt_params_range' with random numbers between the specified range [A, B] or for
Expand All @@ -82,7 +81,7 @@ def _set_values_for_optimized_params(self, cfg: DictConfig, trial: optuna.Trial,
for key, value in cfg.items():
if isinstance(value, DictConfig):
if value.get("optimize"):
param_id = f"{'.'.join(ancestors)}.{key}.{str(choice_idx)}"
param_id = f"{'.'.join(ancestors)}.{key}"
choices = value.get("choices")
values_range = value.get("range")
assert not (choices and values_range), "Choices and range cannot be defined in couple"
Expand All @@ -100,15 +99,15 @@ def _set_values_for_optimized_params(self, cfg: DictConfig, trial: optuna.Trial,
choice_idx = trial.suggest_categorical(name=param_id, choices=choices_index)
choice = choices[choice_idx]
if isinstance(choice, DictConfig):
self._set_values_for_optimized_params(choice, trial, ancestors + [key], choice_idx)
self._set_values_for_optimized_params(choice, trial, ancestors + [key, choice_idx])
cfg[key] = choice
choice_idx = None
else:
self._set_values_for_optimized_params(value, trial, ancestors + [key], choice_idx)
self._set_values_for_optimized_params(value, trial, ancestors + [key])
elif isinstance(value, ListConfig):
for param in value:
if isinstance(param, DictConfig):
self._set_values_for_optimized_params(param, trial, ancestors + [key], choice_idx)
self._set_values_for_optimized_params(param, trial, ancestors + [key])



0 comments on commit 6eba894

Please sign in to comment.