Skip to content

Commit

Permalink
updated mpi doptuna
Browse files Browse the repository at this point in the history
  • Loading branch information
Deathn0t committed Feb 21, 2024
1 parent d1a8349 commit 7b60adf
Showing 1 changed file with 26 additions and 2 deletions.
28 changes: 26 additions & 2 deletions deephyper_benchmark/search/_mpi_doptuna.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from deephyper.core.utils._timeout import terminate_on_timeout # noqa: E402
from deephyper.evaluator import RunningJob # noqa: E402
from deephyper.search import Search # noqa: E402
from deephyper.skopt.moo import non_dominated_set


def optuna_suggest_from_hp(trial, cs_hp):
Expand Down Expand Up @@ -284,7 +285,7 @@ def objective_wrapper(trial):
constraints = []
for i, lbi in enumerate(self._moo_lower_bounds):
if lbi is not None and type(output["objective"][i]) is not str:
ci = -(output["objective"][i] - lbi) # <= 0
ci = -(output["objective"][i] - lbi) # <= 0
constraints.append(ci)
trial.set_user_attr("constraints", tuple(constraints))

Expand Down Expand Up @@ -356,8 +357,31 @@ def optimize_wrapper(duration):

df_results = pd.DataFrame([t.user_attrs["results"] for t in all_trials])
df_path = os.path.join(self._log_dir, "results.csv")

# Check if Multi-Objective Optimization was performed to save the pareto front
objective_columns = [
col for col in df_results.columns if col.startswith("objective")
]

if len(objective_columns) > 1:
if pd.api.types.is_string_dtype(df_results[objective_columns[0]]):
mask_no_failures = ~df_results[objective_columns[0]].str.startswith(
"F"
)
else:
mask_no_failures = np.ones(len(df_results), dtype=bool)
objectives = -df_results.loc[
mask_no_failures, objective_columns
].values.astype(float)
mask_pareto_front = non_dominated_set(objectives)
df_results["pareto_efficient"] = False
df_results.loc[mask_no_failures, "pareto_efficient"] = mask_pareto_front

df_results.to_csv(df_path, index=False)

self.extend_results_with_pareto_efficient(df_path)
if self.comm:
self.comm.Barrier()

df_results = pd.read_csv(df_path)

return df_results

0 comments on commit 7b60adf

Please sign in to comment.