diff --git a/package/samplers/multi_armed_bandit/multi_armed_bandit.py b/package/samplers/multi_armed_bandit/multi_armed_bandit.py
index 871fd766..6c6a06be 100644
--- a/package/samplers/multi_armed_bandit/multi_armed_bandit.py
+++ b/package/samplers/multi_armed_bandit/multi_armed_bandit.py
@@ -37,18 +37,27 @@ def sample_independent(
         param_name: str,
         param_distribution: BaseDistribution,
     ) -> Any:
+        states = (TrialState.COMPLETE, TrialState.PRUNED)
+        trials = study._get_trials(deepcopy=False, states=states, use_cache=True)
+
+        rewards_by_choice: defaultdict = defaultdict(float)
+        cnt_by_choice: defaultdict = defaultdict(int)
+        for t in trials:
+            rewards_by_choice[t.params[param_name]] += t.value
+            cnt_by_choice[t.params[param_name]] += 1
+
+        # Use a never-selected arm for initialization, as in the UCB1 algorithm.
+        # ref. https://github.com/optuna/optunahub-registry/pull/155#discussion_r1780446062
+        never_selected = [
+            arm for arm in param_distribution.choices if arm not in rewards_by_choice
+        ]
+        if never_selected:
+            return self._rng.rng.choice(never_selected)
+
+        # Once every arm has been selected at least once, select an arm by epsilon-greedy.
         if self._rng.rng.rand() < self._epsilon:
             return self._rng.rng.choice(param_distribution.choices)
         else:
-            states = (TrialState.COMPLETE, TrialState.PRUNED)
-            trials = study._get_trials(deepcopy=False, states=states, use_cache=True)
-
-            rewards_by_choice: defaultdict = defaultdict(float)
-            cnt_by_choice: defaultdict = defaultdict(int)
-            for t in trials:
-                rewards_by_choice[t.params[param_name]] += t.value
-                cnt_by_choice[t.params[param_name]] += 1
-
            if study.direction == StudyDirection.MINIMIZE:
                 return min(
                     param_distribution.choices,
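
For illustration, here is a minimal standalone sketch of the selection rule the diff implements: try every arm once before falling back to epsilon-greedy over the observed mean rewards. The function name `choose_arm`, its signature, and the `history` format are hypothetical and not part of the sampler's API; the PR's actual code works on Optuna trials and `param_distribution.choices` instead.

```python
import random
from collections import defaultdict


def choose_arm(history, choices, epsilon, minimize=True, rng=random):
    """Pick an arm epsilon-greedily, trying each arm once first.

    `history` is a list of (arm, reward) pairs from finished trials.
    All names here are hypothetical, not the sampler's real API.
    """
    rewards = defaultdict(float)
    counts = defaultdict(int)
    for arm, reward in history:
        rewards[arm] += reward
        counts[arm] += 1

    # Initialization: any arm that has never been pulled is tried first,
    # mirroring the UCB1-style warm-up added in the diff above.
    never_selected = [arm for arm in choices if arm not in counts]
    if never_selected:
        return rng.choice(never_selected)

    # Exploration: with probability epsilon, pick uniformly at random.
    if rng.random() < epsilon:
        return rng.choice(choices)

    # Exploitation: pick the arm with the best mean reward so far
    # (min for minimization, max for maximization).
    key = lambda arm: rewards[arm] / counts[arm]
    return min(choices, key=key) if minimize else max(choices, key=key)


if __name__ == "__main__":
    history = [("a", 0.9), ("b", 0.4), ("a", 0.7)]
    # "c" has never been tried, so the warm-up returns it regardless of epsilon.
    print(choose_arm(history, ["a", "b", "c"], epsilon=0.1))
```

One consequence of this ordering, which the diff's restructuring makes explicit: the warm-up check now runs before the epsilon coin flip, so no arm can remain unexplored no matter how small epsilon is.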