Skip to content

Commit

Permalink
feat: select never selected arm for reward initialization.
Browse files Browse the repository at this point in the history
  • Loading branch information
ryota717 committed Sep 30, 2024
1 parent a1c1784 commit 371556f
Showing 1 changed file with 18 additions and 9 deletions.
27 changes: 18 additions & 9 deletions package/samplers/multi_armed_bandit/multi_armed_bandit.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,27 @@ def sample_independent(
param_name: str,
param_distribution: BaseDistribution,
) -> Any:
states = (TrialState.COMPLETE, TrialState.PRUNED)
trials = study._get_trials(deepcopy=False, states=states, use_cache=True)

rewards_by_choice: defaultdict = defaultdict(float)
cnt_by_choice: defaultdict = defaultdict(int)
for t in trials:
rewards_by_choice[t.params[param_name]] += t.value
cnt_by_choice[t.params[param_name]] += 1

# Use never selected arm for initialization like UCB1 algorithm.
# ref. https://github.com/optuna/optunahub-registry/pull/155#discussion_r1780446062
never_selected = [
arm for arm in param_distribution.choices if arm not in rewards_by_choice
]
if never_selected:
return self._rng.rng.choice(never_selected)

# If all arms are selected at least once, select arm by epsilon-greedy.
if self._rng.rng.rand() < self._epsilon:
return self._rng.rng.choice(param_distribution.choices)
else:
states = (TrialState.COMPLETE, TrialState.PRUNED)
trials = study._get_trials(deepcopy=False, states=states, use_cache=True)

rewards_by_choice: defaultdict = defaultdict(float)
cnt_by_choice: defaultdict = defaultdict(int)
for t in trials:
rewards_by_choice[t.params[param_name]] += t.value
cnt_by_choice[t.params[param_name]] += 1

if study.direction == StudyDirection.MINIMIZE:
return min(
param_distribution.choices,
Expand Down

0 comments on commit 371556f

Please sign in to comment.