Skip to content

Commit

Permalink
Merge pull request #171 from Furkan-rgb/patch-1
Browse files Browse the repository at this point in the history
Fix shape mismatch error in `CatCMASampler` for categorical problems
  • Loading branch information
HideakiImamura authored Nov 11, 2024
2 parents a243c56 + 14f6d6a commit 2b139de
Showing 1 changed file with 21 additions and 8 deletions.
29 changes: 21 additions & 8 deletions package/samplers/catcma/catcma.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,23 +199,36 @@ def sample_relative(
solution_trials = self._get_solution_trials(completed_trials, optimizer.generation)

if len(solution_trials) >= popsize:
solutions: List[Tuple[np.ndarray, float]] = []
# Calculate the number of categorical variables and maximum number of choices
num_categorical_vars = len(categorical_search_space)
max_num_choices = max(
len(space.choices) for space in categorical_search_space.values()
)

# Prepare solutions list
solutions: List[Tuple[Tuple[np.ndarray, np.ndarray], float]] = []

for t in solution_trials[:popsize]:
assert t.value is not None, "completed trials must have a value"
# Convert Optuna's representation to cmaes.CatCma's internal representation.

# Convert numerical parameters
x = trans.transform({k: t.params[k] for k in numerical_search_space.keys()})

# Convert categorial values to one-hot vectors.
# Example:
# choices = ['a', 'b', 'c']
# value = 'b'
# one_hot_vec = [False, True, False]
c = np.asarray(
[
[c == v for c in categorical_search_space[k].choices]
for k, v in t.params.items()
if k in categorical_search_space.keys()
]
)
c = np.zeros((num_categorical_vars, max_num_choices))

for idx, k in enumerate(categorical_search_space.keys()):
choices = categorical_search_space[k].choices
v = t.params.get(k)
if v is not None:
index = choices.index(v)
c[idx, index] = 1

y = t.value if study.direction == StudyDirection.MINIMIZE else -t.value
solutions.append(((x, c), y)) # type: ignore

Expand Down

0 comments on commit 2b139de

Please sign in to comment.