diff --git a/package/samplers/catcma/catcma.py b/package/samplers/catcma/catcma.py index dd335ef8..371ee9ea 100644 --- a/package/samplers/catcma/catcma.py +++ b/package/samplers/catcma/catcma.py @@ -199,23 +199,36 @@ def sample_relative( solution_trials = self._get_solution_trials(completed_trials, optimizer.generation) if len(solution_trials) >= popsize: - solutions: List[Tuple[np.ndarray, float]] = [] + # Calculate the number of categorical variables and maximum number of choices + num_categorical_vars = len(categorical_search_space) + max_num_choices = max( + len(space.choices) for space in categorical_search_space.values() + ) + + # Prepare solutions list + solutions: List[Tuple[Tuple[np.ndarray, np.ndarray], float]] = [] + for t in solution_trials[:popsize]: assert t.value is not None, "completed trials must have a value" # Convert Optuna's representation to cmaes.CatCma's internal representation. + + # Convert numerical parameters x = trans.transform({k: t.params[k] for k in numerical_search_space.keys()}) + # Convert categorial values to one-hot vectors. # Example: # choices = ['a', 'b', 'c'] # value = 'b' # one_hot_vec = [False, True, False] - c = np.asarray( - [ - [c == v for c in categorical_search_space[k].choices] - for k, v in t.params.items() - if k in categorical_search_space.keys() - ] - ) + c = np.zeros((num_categorical_vars, max_num_choices)) + + for idx, k in enumerate(categorical_search_space.keys()): + choices = categorical_search_space[k].choices + v = t.params.get(k) + if v is not None: + index = choices.index(v) + c[idx, index] = 1 + y = t.value if study.direction == StudyDirection.MINIMIZE else -t.value solutions.append(((x, c), y)) # type: ignore