From 4ff613453089083094da6d911f2a4081f59d96a2 Mon Sep 17 00:00:00 2001 From: mooscalia project <93492480+mooscaliaproject@users.noreply.github.com> Date: Sun, 16 Oct 2022 21:33:59 -0300 Subject: [PATCH 01/32] Create new metrics for rank and crowding --- pymoo/algorithms/moo/nsga2.py | 110 +---- pymoo/cython/mnn.pyx | 397 ++++++++++++++++++ pymoo/cython/pruning_cd.pyx | 312 ++++++++++++++ .../survival/rank_and_crowding/__init__.py | 1 + .../survival/rank_and_crowding/classes.py | 206 +++++++++ .../survival/rank_and_crowding/metrics.py | 193 +++++++++ pymoo/util/function_loader.py | 12 + pymoo/util/mnn.py | 67 +++ pymoo/util/pruning_cd.py | 89 ++++ tests/algorithms/test_nsga2.py | 2 +- tests/misc/test_crowding_distance.py | 2 +- 11 files changed, 1292 insertions(+), 99 deletions(-) create mode 100644 pymoo/cython/mnn.pyx create mode 100644 pymoo/cython/pruning_cd.pyx create mode 100644 pymoo/operators/survival/rank_and_crowding/__init__.py create mode 100644 pymoo/operators/survival/rank_and_crowding/classes.py create mode 100644 pymoo/operators/survival/rank_and_crowding/metrics.py create mode 100644 pymoo/util/mnn.py create mode 100644 pymoo/util/pruning_cd.py diff --git a/pymoo/algorithms/moo/nsga2.py b/pymoo/algorithms/moo/nsga2.py index 8da9d6715..eb8f7b511 100644 --- a/pymoo/algorithms/moo/nsga2.py +++ b/pymoo/algorithms/moo/nsga2.py @@ -1,18 +1,17 @@ import numpy as np +import warnings from pymoo.algorithms.base.genetic import GeneticAlgorithm -from pymoo.core.survival import Survival from pymoo.docs import parse_doc_string from pymoo.operators.crossover.sbx import SBX from pymoo.operators.mutation.pm import PM +from pymoo.operators.survival.rank_and_crowding import RankAndCrowding from pymoo.operators.sampling.rnd import FloatRandomSampling from pymoo.operators.selection.tournament import compare, TournamentSelection from pymoo.termination.default import DefaultMultiObjectiveTermination from pymoo.util.display.multi import MultiObjectiveOutput from pymoo.util.dominator import Dominator -from pymoo.util.misc import find_duplicates, has_feasible -from pymoo.util.nds.non_dominated_sorting import NonDominatedSorting -from pymoo.util.randomized_argsort import randomized_argsort +from pymoo.util.misc import has_feasible # --------------------------------------------------------------------------------------------------------- @@ -68,47 +67,14 @@ def binary_tournament(pop, P, algorithm, **kwargs): # --------------------------------------------------------------------------------------------------------- -class RankAndCrowdingSurvival(Survival): - - def __init__(self, nds=None) -> None: - super().__init__(filter_infeasible=True) - self.nds = nds if nds is not None else NonDominatedSorting() - - def _do(self, problem, pop, *args, n_survive=None, **kwargs): - - # get the objective space values and objects - F = pop.get("F").astype(float, copy=False) - - # the final indices of surviving individuals - survivors = [] - - # do the non-dominated sorting until splitting front - fronts = self.nds.do(F, n_stop_if_ranked=n_survive) - - for k, front in enumerate(fronts): - - # calculate the crowding distance of the front - crowding_of_front = calc_crowding_distance(F[front, :]) - - # save rank and crowding in the individual class - for j, i in enumerate(front): - pop[i].set("rank", k) - pop[i].set("crowding", crowding_of_front[j]) - - # current front sorted by crowding distance if splitting - if len(survivors) + len(front) > n_survive: - I = randomized_argsort(crowding_of_front, order='descending', method='numpy') - I = I[:(n_survive - len(survivors))] - - # otherwise take the whole front unsorted - else: - I = np.arange(len(front)) - - # extend the survivors by all or selected individuals - survivors.extend(front[I]) - - return pop[survivors] - +class RankAndCrowdingSurvival(RankAndCrowding): + + def __init__(self, nds=None, crowding_func="cd"): + warnings.warn( + "RankAndCrowdingSurvival is deprecated and will be removed in version 0.8.*; use RankAndCrowding operator instead, which supports several and custom crowding diversity metrics.", + DeprecationWarning, 2 + ) + super().__init__(nds, crowding_func) # ========================================================================================================= # Implementation @@ -123,9 +89,10 @@ def __init__(self, selection=TournamentSelection(func_comp=binary_tournament), crossover=SBX(eta=15, prob=0.9), mutation=PM(eta=20), - survival=RankAndCrowdingSurvival(), + survival=RankAndCrowding(), output=MultiObjectiveOutput(), **kwargs): + super().__init__( pop_size=pop_size, sampling=sampling, @@ -147,55 +114,4 @@ def _set_optimum(self, **kwargs): self.opt = self.pop[self.pop.get("rank") == 0] -def calc_crowding_distance(F, filter_out_duplicates=True): - n_points, n_obj = F.shape - - if n_points <= 2: - return np.full(n_points, np.inf) - - else: - - if filter_out_duplicates: - # filter out solutions which are duplicates - duplicates get a zero finally - is_unique = np.where(np.logical_not(find_duplicates(F, epsilon=1e-32)))[0] - else: - # set every point to be unique without checking it - is_unique = np.arange(n_points) - - # index the unique points of the array - _F = F[is_unique] - - # sort each column and get index - I = np.argsort(_F, axis=0, kind='mergesort') - - # sort the objective space values for the whole matrix - _F = _F[I, np.arange(n_obj)] - - # calculate the distance from each point to the last and next - dist = np.row_stack([_F, np.full(n_obj, np.inf)]) - np.row_stack([np.full(n_obj, -np.inf), _F]) - - # calculate the norm for each objective - set to NaN if all values are equal - norm = np.max(_F, axis=0) - np.min(_F, axis=0) - norm[norm == 0] = np.nan - - # prepare the distance to last and next vectors - dist_to_last, dist_to_next = dist, np.copy(dist) - dist_to_last, dist_to_next = dist_to_last[:-1] / norm, dist_to_next[1:] / norm - - # if we divide by zero because all values in one columns are equal replace by none - dist_to_last[np.isnan(dist_to_last)] = 0.0 - dist_to_next[np.isnan(dist_to_next)] = 0.0 - - # sum up the distance to next and last and norm by objectives - also reorder from sorted list - J = np.argsort(I, axis=0) - _cd = np.sum(dist_to_last[J, np.arange(n_obj)] + dist_to_next[J, np.arange(n_obj)], axis=1) / n_obj - - # save the final vector which sets the crowding distance for duplicates to zero to be eliminated - crowding = np.zeros(n_points) - crowding[is_unique] = _cd - - # crowding[np.isinf(crowding)] = 1e+14 - return crowding - - parse_doc_string(NSGA2.__init__) diff --git a/pymoo/cython/mnn.pyx b/pymoo/cython/mnn.pyx new file mode 100644 index 000000000..7ae772e84 --- /dev/null +++ b/pymoo/cython/mnn.pyx @@ -0,0 +1,397 @@ +# distutils: language = c++ +# cython: language_level=2, boundscheck=False, wraparound=False, cdivision=True + +# This was implemented using the full distances matrix +# Other strategies can be more efficient depending on the population size and number of objectives +# This approach was the most promising for N = 3 +# I believe for a large number of objectives M, some strategy based on upper bounds for distances would be helpful +# Those interested in contributing please contact me at bruscalia12@gmail.com + + +import numpy as np + +from libcpp cimport bool +from libcpp.vector cimport vector +from libcpp.set cimport set as cpp_set + + +cdef extern from "math.h": + double HUGE_VAL + + +def calc_mnn(double[:, :] X, int n_remove=0): + + cdef: + int N, M, n + cpp_set[int] extremes + vector[int] extremes_min, extremes_max + + N = X.shape[0] + M = X.shape[1] + + if n_remove <= (N - M): + if n_remove < 0: + n_remove = 0 + else: + pass + else: + n_remove = N - M + + extremes_min = c_get_argmin(X) + extremes_max = c_get_argmax(X) + + extremes = cpp_set[int]() + + for n in extremes_min: + extremes.insert(n) + + for n in extremes_max: + extremes.insert(n) + + X = c_normalize_array(X, extremes_max, extremes_min) + + return c_calc_mnn(X, n_remove, N, M, extremes) + + +def calc_2nn(double[:, :] X, int n_remove=0): + + cdef: + int N, M, n + cpp_set[int] extremes + vector[int] extremes_min, extremes_max + + N = X.shape[0] + M = X.shape[1] + + if n_remove <= (N - M): + if n_remove < 0: + n_remove = 0 + else: + pass + else: + n_remove = N - M + + extremes_min = c_get_argmin(X) + extremes_max = c_get_argmax(X) + + extremes = cpp_set[int]() + + for n in extremes_min: + extremes.insert(n) + + for n in extremes_max: + extremes.insert(n) + + X = c_normalize_array(X, extremes_max, extremes_min) + + M = 2 + + return c_calc_mnn(X, n_remove, N, M, extremes) + + +cdef c_calc_mnn(double[:, :] X, int n_remove, int N, int M, cpp_set[int] extremes): + + cdef: + int n, mm, i, j, n_removed, k, MM + double dij + cpp_set[int] calc_items + cpp_set[int] H + double[:, :] D + double[:] d + int[:, :] Mnn + + #Define items to calculate distances + calc_items = cpp_set[int]() + for n in range(N): + calc_items.insert(n) + for n in extremes: + calc_items.erase(n) + + #Define remaining items to evaluate + H = cpp_set[int]() + for n in range(N): + H.insert(n) + + #Instantiate distances array + _D = np.empty((N, N), dtype=np.double) + D = _D[:, :] + + #Shape of X + MM = X.shape[1] + + #Fill values on D + for i in range(N - 1): + D[i, i] = 0.0 + + for j in range(i + 1, N): + + dij = 0 + for mm in range(MM): + dij = dij + (X[j, mm] - X[i, mm]) * (X[j, mm] - X[i, mm]) + + D[i, j] = dij + D[j, i] = D[i, j] + + D[N-1, N-1] = 0.0 + + #Initialize + n_removed = 0 + + #Initialize neighbors and distances + # _Mnn = np.full((N, M), -1, dtype=np.intc) + _Mnn = np.argpartition(D, range(1, M+1), axis=1)[:, 1:M+1].astype(np.intc) + dd = np.full((N,), HUGE_VAL, dtype=np.double) + + Mnn = _Mnn[:, :] + d = dd[:] + + #Obtain distance metrics + c_calc_d(d, Mnn, D, calc_items, M) + + #While n_remove not acheived (no need to recalculate if only one item should be removed) + while n_removed < (n_remove - 1): + + #Obtain element to drop + k = c_get_drop(d, H) + H.erase(k) + + #Update index + n_removed = n_removed + 1 + + #Get items to be recalculated + calc_items = c_get_calc_items(Mnn, H, k, M) + for n in extremes: + calc_items.erase(n) + + #Fill in neighbors and distance matrix + c_calc_mnn_iter( + X, + Mnn, + D, + N, M, + calc_items, + H + ) + + #Obtain distance metrics + c_calc_d(d, Mnn, D, calc_items, M) + + return dd + + +cdef c_calc_mnn_iter( + double[:, :] X, + int[:, :] Mnn, + double[:, :] D, + int N, int M, + cpp_set[int] calc_items, + cpp_set[int] H + ): + + cdef: + int i, j, m + + #Iterate over items to calculate + for i in calc_items: + + #Iterate over elements in X + for j in H: + + #Go to next if same element + if (j == i): + continue + + #Replace at least the last neighbor + elif (D[i, j] <= D[i, Mnn[i, M-1]]) or (Mnn[i, M-1] == -1): + + #Iterate over current values + for m in range(M): + + #Set to current if unassigned + if (Mnn[i, m] == -1): + + #Set last neighbor to index + Mnn[i, m] = j + break + + #Break if checking already corresponding index + elif (j == Mnn[i, m]): + break + + #Distance satisfies condition + elif (D[i, j] <= D[i, Mnn[i, m]]): + + #Replace higher values + Mnn[i, m + 1:] = Mnn[i, m:-1] + + #Replace current value + Mnn[i, m] = j + break + + +#Calculate crowding metric +cdef c_calc_d(double[:] d, int[:, :] Mnn, double[:, :] D, cpp_set[int] calc_items, int M): + + cdef: + int i, m + + for i in calc_items: + + d[i] = 1 + for m in range(M): + d[i] = d[i] * D[i, Mnn[i, m]] + + +#Returns indexes of items to be recalculated after removal +cdef cpp_set[int] c_get_calc_items( + int[:, :] Mnn, + cpp_set[int] H, + int k, int M): + + cdef: + int i, m + cpp_set[int] calc_items + + calc_items = cpp_set[int]() + + for i in H: + + for m in range(M): + + if Mnn[i, m] == k: + + Mnn[i, m:-1] = Mnn[i, m + 1:] + Mnn[i, M-1] = -1 + + calc_items.insert(i) + + return calc_items + + +#Returns elements to remove based on crowding metric d and heap of remaining elements H +cdef int c_get_drop(double[:] d, cpp_set[int] H): + + cdef: + int i, min_i + double min_d + + min_d = HUGE_VAL + min_i = 0 + + for i in H: + + if d[i] <= min_d: + min_d = d[i] + min_i = i + + return min_i + + +#Elements in condensed matrix +cdef int c_square_to_condensed(int i, int j, int N): + + cdef int _i = i + + if i < j: + i = j + j = _i + + return N * j - j * (j + 1) // 2 + i - 1 - j + + +#Returns vector of positions of minimum values along axis 0 of a 2d memoryview +cdef vector[int] c_get_argmin(double[:, :] X): + + cdef: + int N, M, min_i, n, m + double min_val + vector[int] indexes + + N = X.shape[0] + M = X.shape[1] + + indexes = vector[int]() + + for m in range(M): + + min_i = 0 + min_val = X[0, m] + + for n in range(N): + + if X[n, m] < min_val: + + min_i = n + min_val = X[n, m] + + indexes.push_back(min_i) + + return indexes + + +#Returns vector of positions of maximum values along axis 0 of a 2d memoryview +cdef vector[int] c_get_argmax(double[:, :] X): + + cdef: + int N, M, max_i, n, m + double max_val + vector[int] indexes + + N = X.shape[0] + M = X.shape[1] + + indexes = vector[int]() + + for m in range(M): + + max_i = 0 + max_val = X[0, m] + + for n in range(N): + + if X[n, m] > max_val: + + max_i = n + max_val = X[n, m] + + indexes.push_back(max_i) + + return indexes + + +#Performs normalization of a 2d memoryview +cdef double[:, :] c_normalize_array(double[:, :] X, vector[int] extremes_max, vector[int] extremes_min): + + cdef: + int N = X.shape[0] + int M = X.shape[1] + int n, m, l, u + double l_val, u_val, diff_val + vector[double] min_vals, max_vals + + min_vals = vector[double]() + max_vals = vector[double]() + + m = 0 + for u in extremes_max: + u_val = X[u, m] + max_vals.push_back(u_val) + m = m + 1 + + m = 0 + for l in extremes_min: + l_val = X[l, m] + min_vals.push_back(l_val) + m = m + 1 + + for m in range(M): + + diff_val = max_vals[m] - min_vals[m] + if diff_val == 0.0: + diff_val = 1.0 + + for n in range(N): + + X[n, m] = (X[n, m] - min_vals[m]) / diff_val + + return X \ No newline at end of file diff --git a/pymoo/cython/pruning_cd.pyx b/pymoo/cython/pruning_cd.pyx new file mode 100644 index 000000000..505c29d56 --- /dev/null +++ b/pymoo/cython/pruning_cd.pyx @@ -0,0 +1,312 @@ +# distutils: language = c++ +# cython: language_level=2, boundscheck=False, wraparound=False, cdivision=True + +import numpy as np + +from libcpp cimport bool +from libcpp.vector cimport vector +from libcpp.set cimport set as cpp_set + + +cdef extern from "math.h": + double HUGE_VAL + + +#Python definition +def calc_pcd(double[:, :] X, int n_remove=0): + + cdef: + int N, M, n + cpp_set[int] extremes + vector[int] extremes_min, extremes_max + int[:, :] I + + N = X.shape[0] + M = X.shape[1] + + if n_remove <= (N - M): + if n_remove < 0: + n_remove = 0 + else: + pass + else: + n_remove = N - M + + extremes_min = c_get_argmin(X) + extremes_max = c_get_argmax(X) + + extremes = cpp_set[int]() + + for n in extremes_min: + extremes.insert(n) + + for n in extremes_max: + extremes.insert(n) + + _I = np.argsort(X, axis=0, kind='mergesort').astype(np.intc) + I = _I[:, :] + + X = c_normalize_array(X, extremes_max, extremes_min) + + return c_calc_pcd(X, I, n_remove, N, M, extremes) + + +#Returns crowding metrics with recursive elimination +cdef c_calc_pcd(double[:, :] X, int[:, :] I, int n_remove, int N, int M, cpp_set[int] extremes): + + cdef: + int n, n_removed, k + cpp_set[int] calc_items + cpp_set[int] H + double[:, :] D + double[:] d + + #Define items to calculate distances + calc_items = cpp_set[int]() + for n in range(N): + calc_items.insert(n) + for n in extremes: + calc_items.erase(n) + + #Define remaining items to evaluate + H = cpp_set[int]() + for n in range(N): + H.insert(n) + + #Initialize + n_removed = 0 + + #Initialize neighbors and distances + _D = np.full((N, M), HUGE_VAL, dtype=np.double) + dd = np.full((N,), HUGE_VAL, dtype=np.double) + + D = _D[:, :] + d = dd[:] + + #Fill in neighbors and distance matrix + c_calc_pcd_iter( + X, + I, + D, + N, M, + calc_items, + ) + + #Obtain distance metrics + c_calc_d(d, D, calc_items, M) + + #While n_remove not acheived + while n_removed < (n_remove - 1): + + #Obtain element to drop + k = c_get_drop(d, H) + H.erase(k) + + #Update index + n_removed = n_removed + 1 + + #Get items to be recalculated + calc_items = c_get_calc_items(I, k, M, N) + for n in extremes: + calc_items.erase(n) + + #Fill in neighbors and distance matrix + c_calc_pcd_iter( + X, + I, + D, + N, M, + calc_items, + ) + + #Obtain distance metrics + c_calc_d(d, D, calc_items, M) + + return dd + + +#Iterate +cdef c_calc_pcd_iter( + double[:, :] X, + int[:, :] I, + double[:, :] D, + int N, int M, + cpp_set[int] calc_items, + ): + + cdef: + int i, m, n, l, u + + #Iterate over items to calculate + for i in calc_items: + + #Iterate over elements in X + for m in range(M): + + for n in range(N): + + if i == I[n, m]: + + l = I[n - 1, m] + u = I[n + 1, m] + + D[i, m] = (X[u, m] - X[l, m]) / M + + +#Calculate crowding metric +cdef c_calc_d(double[:] d, double[:, :] D, cpp_set[int] calc_items, int M): + + cdef: + int i, m + + for i in calc_items: + + d[i] = 0 + for m in range(M): + d[i] = d[i] + D[i, m] + + +#Returns indexes of items to be recalculated after removal +cdef cpp_set[int] c_get_calc_items( + int[:, :] I, + int k, int M, int N + ): + + cdef: + int n, m + cpp_set[int] calc_items + + calc_items = cpp_set[int]() + + #Iterate over all elements in I + for m in range(M): + + for n in range(N): + + if I[n, m] == k: + + #Add to set of items to be recalculated + calc_items.insert(I[n - 1, m]) + calc_items.insert(I[n + 1, m]) + + #Remove element from sorted array + I[n:-1, m] = I[n + 1:, m] + + return calc_items + + +#Returns elements to remove based on crowding metric d and heap of remaining elements H +cdef int c_get_drop(double[:] d, cpp_set[int] H): + + cdef: + int i, min_i + double min_d + + min_d = HUGE_VAL + min_i = 0 + + for i in H: + + if d[i] <= min_d: + min_d = d[i] + min_i = i + + return min_i + + +#Returns vector of positions of minimum values along axis 0 of a 2d memoryview +cdef vector[int] c_get_argmin(double[:, :] X): + + cdef: + int N, M, min_i, n, m + double min_val + vector[int] indexes + + N = X.shape[0] + M = X.shape[1] + + indexes = vector[int]() + + for m in range(M): + + min_i = 0 + min_val = X[0, m] + + for n in range(N): + + if X[n, m] < min_val: + + min_i = n + min_val = X[n, m] + + indexes.push_back(min_i) + + return indexes + + +#Returns vector of positions of maximum values along axis 0 of a 2d memoryview +cdef vector[int] c_get_argmax(double[:, :] X): + + cdef: + int N, M, max_i, n, m + double max_val + vector[int] indexes + + N = X.shape[0] + M = X.shape[1] + + indexes = vector[int]() + + for m in range(M): + + max_i = 0 + max_val = X[0, m] + + for n in range(N): + + if X[n, m] > max_val: + + max_i = n + max_val = X[n, m] + + indexes.push_back(max_i) + + return indexes + + +#Performs normalization of a 2d memoryview +cdef double[:, :] c_normalize_array(double[:, :] X, vector[int] extremes_max, vector[int] extremes_min): + + cdef: + int N = X.shape[0] + int M = X.shape[1] + int n, m, l, u + double l_val, u_val, diff_val + vector[double] min_vals, max_vals + + min_vals = vector[double]() + max_vals = vector[double]() + + m = 0 + for u in extremes_max: + u_val = X[u, m] + max_vals.push_back(u_val) + m = m + 1 + + m = 0 + for l in extremes_min: + l_val = X[l, m] + min_vals.push_back(l_val) + m = m + 1 + + for m in range(M): + + diff_val = max_vals[m] - min_vals[m] + if diff_val == 0.0: + diff_val = 1.0 + + for n in range(N): + + X[n, m] = (X[n, m] - min_vals[m]) / diff_val + + return X \ No newline at end of file diff --git a/pymoo/operators/survival/rank_and_crowding/__init__.py b/pymoo/operators/survival/rank_and_crowding/__init__.py new file mode 100644 index 000000000..9a7419410 --- /dev/null +++ b/pymoo/operators/survival/rank_and_crowding/__init__.py @@ -0,0 +1 @@ +from pymoo.operators.survival.rank_and_crowding.classes import RankAndCrowding, ConstrRankAndCrowding \ No newline at end of file diff --git a/pymoo/operators/survival/rank_and_crowding/classes.py b/pymoo/operators/survival/rank_and_crowding/classes.py new file mode 100644 index 000000000..516a0dce3 --- /dev/null +++ b/pymoo/operators/survival/rank_and_crowding/classes.py @@ -0,0 +1,206 @@ +import numpy as np +from pymoo.util.randomized_argsort import randomized_argsort +from pymoo.util.nds.non_dominated_sorting import NonDominatedSorting +from pymoo.core.survival import Survival, split_by_feasibility +from pymoo.core.population import Population +from pymoo.operators.survival.rank_and_crowding.metrics import get_crowding_function + + +class RankAndCrowding(Survival): + + def __init__(self, nds=None, crowding_func="cd"): + """ + A generalization of the NSGA-II survival operator that ranks individuals by dominance criteria + and sorts the last front by some user-specified crowding metric. The default is NSGA-II's crowding distances + although others might be more effective. + + For many-objective problems, try using 'mnn' or '2nn'. + + For Bi-objective problems, 'pcd' is very effective. + + Parameters + ---------- + nds : str or None, optional + Pymoo type of non-dominated sorting. Defaults to None. + + crowding_func : str or callable, optional + Crowding metric. Options are: + + - 'cd': crowding distances + - 'pcd' or 'pruned-cd': pruned crowding distances + - 'ce': crowding entropy + - 'mnn': M-Neaest Neighbors + - '2nn': 2-Neaest Neighbors + + If callable, it has the form ``fun(F, filter_out_duplicates=None, n_remove=None, **kwargs)`` + in which F (n, m) and must return metrics in a (n,) array. + + The options 'pcd', 'cd', and 'ce' are recommended for two-objective problems, whereas 'mnn' and '2nn' for many objective. + When using 'pcd', 'mnn', or '2nn', individuals are already eliminated in a 'single' manner. + Due to Cython implementation, they are as fast as the corresponding 'cd', 'mnn-fast', or '2nn-fast', + although they can singnificantly improve diversity of solutions. + Defaults to 'cd'. + """ + + crowding_func_ = get_crowding_function(crowding_func) + + super().__init__(filter_infeasible=True) + self.nds = nds if nds is not None else NonDominatedSorting() + self.crowding_func = crowding_func_ + + def _do(self, + problem, + pop, + *args, + n_survive=None, + **kwargs): + + # get the objective space values and objects + F = pop.get("F").astype(float, copy=False) + + # the final indices of surviving individuals + survivors = [] + + # do the non-dominated sorting until splitting front + fronts = self.nds.do(F, n_stop_if_ranked=n_survive) + + for k, front in enumerate(fronts): + + # current front sorted by crowding distance if splitting + while len(survivors) + len(front) > n_survive: + + #Define how many will be removed + n_remove = len(survivors) + len(front) - n_survive + + # re-calculate the crowding distance of the front + crowding_of_front = \ + self.crowding_func.do( + F[front, :], + n_remove=n_remove + ) + + I = randomized_argsort(crowding_of_front, order='descending', method='numpy') + + I = I[:-n_remove] + front = front[I] + + # otherwise take the whole front unsorted + else: + # calculate the crowding distance of the front + crowding_of_front = \ + self.crowding_func.do( + F[front, :], + n_remove=0 + ) + + # save rank and crowding in the individual class + for j, i in enumerate(front): + pop[i].set("rank", k) + pop[i].set("crowding", crowding_of_front[j]) + + # extend the survivors by all or selected individuals + survivors.extend(front) + + return pop[survivors] + + +class ConstrRankAndCrowding(Survival): + + def __init__(self, nds=None, crowding_func="cd"): + """ + The Rank and Crowding survival approach for handling constraints proposed on + GDE3 by Kukkonen, S. & Lampinen, J. (2005). + + Parameters + ---------- + nds : str or None, optional + Pymoo type of non-dominated sorting. Defaults to None. + + crowding_func : str or callable, optional + Crowding metric. Options are: + + - 'cd': crowding distances + - 'pcd' or 'pruned-cd': pruned crowding distances + - 'ce': crowding entropy + - 'mnn': M-Neaest Neighbors + - '2nn': 2-Neaest Neighbors + + If callable, it has the form ``fun(F, filter_out_duplicates=None, n_remove=None, **kwargs)`` + in which F (n, m) and must return metrics in a (n,) array. + + The options 'pcd', 'cd', and 'ce' are recommended for two-objective problems, whereas 'mnn' and '2nn' for many objective. + When using 'pcd', 'mnn', or '2nn', individuals are already eliminated in a 'single' manner. + Due to Cython implementation, they are as fast as the corresponding 'cd', 'mnn-fast', or '2nn-fast', + although they can singnificantly improve diversity of solutions. + Defaults to 'cd'. + """ + + super().__init__(filter_infeasible=False) + self.nds = nds if nds is not None else NonDominatedSorting() + self.ranking = RankAndCrowding(nds=nds, crowding_func=crowding_func) + + def _do(self, + problem, + pop, + *args, + n_survive=None, + **kwargs): + + if n_survive is None: + n_survive = len(pop) + + n_survive = min(n_survive, len(pop)) + + #If the split should be done beforehand + if problem.n_constr > 0: + + #Split by feasibility + feas, infeas = split_by_feasibility(pop, eps=0.0, sort_infeasbible_by_cv=True) + + #Obtain len of feasible + n_feas = len(feas) + + #Assure there is at least_one survivor + if n_feas == 0: + survivors = Population() + else: + survivors = self.ranking.do(problem, pop[feas], *args, n_survive=min(len(feas), n_survive), **kwargs) + + #Calculate how many individuals are still remaining to be filled up with infeasible ones + n_remaining = n_survive - len(survivors) + + #If infeasible solutions need to be added + if n_remaining > 0: + + #Constraints to new ranking + G = pop[infeas].get("G") + G = np.maximum(G, 0) + + #Fronts in infeasible population + infeas_fronts = self.nds.do(G, n_stop_if_ranked=n_remaining) + + #Iterate over fronts + for k, front in enumerate(infeas_fronts): + + #Save ranks + pop[infeas][front].set("cv_rank", k) + + #Current front sorted by CV + if len(survivors) + len(front) > n_survive: + + #Obtain CV of front + CV = pop[infeas][front].get("CV").flatten() + I = randomized_argsort(CV, order='ascending', method='numpy') + I = I[:(n_survive - len(survivors))] + + #Otherwise take the whole front unsorted + else: + I = np.arange(len(front)) + + # extend the survivors by all or selected individuals + survivors = Population.merge(survivors, pop[infeas][front[I]]) + + else: + survivors = self.ranking.do(problem, pop, *args, n_survive=n_survive, **kwargs) + + return survivors diff --git a/pymoo/operators/survival/rank_and_crowding/metrics.py b/pymoo/operators/survival/rank_and_crowding/metrics.py new file mode 100644 index 000000000..751f4fe73 --- /dev/null +++ b/pymoo/operators/survival/rank_and_crowding/metrics.py @@ -0,0 +1,193 @@ +import numpy as np +from scipy.spatial.distance import pdist, squareform +from pymoo.util.misc import find_duplicates +from pymoo.util.function_loader import load_function + + +def get_crowding_function(label): + + if label == "cd": + fun = FunctionalDiversity(calc_crowding_distance, filter_out_duplicates=False) + elif (label == "pcd") or (label == "pruning-cd"): + fun = FunctionalDiversity(load_function("calc_pcd"), filter_out_duplicates=True) + elif label == "ce": + fun = FunctionalDiversity(calc_crowding_entropy, filter_out_duplicates=True) + elif label == "mnn": + fun = FunctionalDiversity(load_function("calc_mnn"), filter_out_duplicates=True) + elif label == "2nn": + fun = FunctionalDiversity(load_function("calc_2nn"), filter_out_duplicates=True) + elif hasattr(label, "__call__"): + fun = FunctionalDiversity(label, filter_out_duplicates=True) + else: + raise KeyError("Crwoding function not defined") + return fun + + +class CrowdingDiversity: + + def do(self, F, n_remove=0): + #Converting types Python int to Cython int would fail in some cases converting to long instead + n_remove = np.intc(n_remove) + F = np.array(F, dtype=np.double) + return self._do(F, n_remove=n_remove) + + def _do(self, F, n_remove=None): + pass + + +class FunctionalDiversity(CrowdingDiversity): + + def __init__(self, function=None, filter_out_duplicates=True): + self.function = function + self.filter_out_duplicates = filter_out_duplicates + super().__init__() + + def _do(self, F, **kwargs): + + n_points, n_obj = F.shape + + if n_points <= F.shape[1]: + return np.full(n_points, np.inf) + + else: + + if self.filter_out_duplicates: + # filter out solutions which are duplicates - duplicates get a zero finally + is_unique = np.where(np.logical_not(find_duplicates(F, epsilon=1e-32)))[0] + else: + # set every point to be unique without checking it + is_unique = np.arange(n_points) + + # index the unique points of the array + _F = F[is_unique] + + _d = self.function(_F, **kwargs) + + d = np.zeros(n_points) + d[is_unique] = _d + + return d + + +def calc_crowding_distance(F, **kwargs): + n_points, n_obj = F.shape + + # sort each column and get index + I = np.argsort(F, axis=0, kind='mergesort') + + # sort the objective space values for the whole matrix + F = F[I, np.arange(n_obj)] + + # calculate the distance from each point to the last and next + dist = np.row_stack([F, np.full(n_obj, np.inf)]) - np.row_stack([np.full(n_obj, -np.inf), F]) + + # calculate the norm for each objective - set to NaN if all values are equal + norm = np.max(F, axis=0) - np.min(F, axis=0) + norm[norm == 0] = np.nan + + # prepare the distance to last and next vectors + dist_to_last, dist_to_next = dist, np.copy(dist) + dist_to_last, dist_to_next = dist_to_last[:-1] / norm, dist_to_next[1:] / norm + + # if we divide by zero because all values in one columns are equal replace by none + dist_to_last[np.isnan(dist_to_last)] = 0.0 + dist_to_next[np.isnan(dist_to_next)] = 0.0 + + # sum up the distance to next and last and norm by objectives - also reorder from sorted list + J = np.argsort(I, axis=0) + cd = np.sum(dist_to_last[J, np.arange(n_obj)] + dist_to_next[J, np.arange(n_obj)], axis=1) / n_obj + + return cd + + +def calc_crowding_entropy(F, **kwargs): + """Wang, Y.-N., Wu, L.-H. & Yuan, X.-F., 2010. Multi-objective self-adaptive differential + evolution with elitist archive and crowding entropy-based diversity measure. + Soft Comput., 14(3), pp. 193-209. + + Parameters + ---------- + F : 2d array like + Objective functions. + + Returns + ------- + ce : 1d array + Crowding Entropies + """ + n_points, n_obj = F.shape + + # sort each column and get index + I = np.argsort(F, axis=0, kind='mergesort') + + # sort the objective space values for the whole matrix + F = F[I, np.arange(n_obj)] + + # calculate the distance from each point to the last and next + dist = np.row_stack([F, np.full(n_obj, np.inf)]) - np.row_stack([np.full(n_obj, -np.inf), F]) + + # calculate the norm for each objective - set to NaN if all values are equal + norm = np.max(F, axis=0) - np.min(F, axis=0) + norm[norm == 0] = np.nan + + # prepare the distance to last and next vectors + dl = dist.copy()[:-1] + du = dist.copy()[1:] + + #Fix nan + dl[np.isnan(dl)] = 0.0 + du[np.isnan(du)] = 0.0 + + #Total distance + cd = dl + du + + #Get relative positions + pl = (dl[1:-1] / cd[1:-1]) + pu = (du[1:-1] / cd[1:-1]) + + #Entropy + entropy = np.row_stack([np.full(n_obj, np.inf), + -(pl * np.log2(pl) + pu * np.log2(pu)), + np.full(n_obj, np.inf)]) + + #Crowding entropy + J = np.argsort(I, axis=0) + _cej = cd[J, np.arange(n_obj)] * entropy[J, np.arange(n_obj)] / norm + _cej[np.isnan(_cej)] = 0.0 + ce = _cej.sum(axis=1) + + return ce + + +def calc_mnn_fast(F, **kwargs): + return _calc_mnn_fast(F, F.shape[1], **kwargs) + + +def calc_2nn_fast(F, **kwargs): + return _calc_mnn_fast(F, 2, **kwargs) + + +def _calc_mnn_fast(F, n_neighbors, **kwargs): + + # calculate the norm for each objective - set to NaN if all values are equal + norm = np.max(F, axis=0) - np.min(F, axis=0) + norm[norm == 0] = 1.0 + + # F normalized + F = (F - F.min(axis=0)) / norm + + # Distances pairwise (Inefficient) + D = squareform(pdist(F, metric="sqeuclidean")) + + # M neighbors + M = F.shape[1] + _D = np.partition(D, range(1, M+1), axis=1)[:, 1:M+1] + + # Metric d + d = np.prod(_D, axis=1) + + # Set top performers as np.inf + _extremes = np.concatenate((np.argmin(F, axis=0), np.argmax(F, axis=0))) + d[_extremes] = np.inf + + return d diff --git a/pymoo/util/function_loader.py b/pymoo/util/function_loader.py index 085aa41d9..c68282632 100644 --- a/pymoo/util/function_loader.py +++ b/pymoo/util/function_loader.py @@ -4,6 +4,7 @@ def get_functions(): + from pymoo.util.nds.fast_non_dominated_sort import fast_non_dominated_sort from pymoo.util.nds.efficient_non_dominated_sort import efficient_non_dominated_sort from pymoo.util.nds.tree_based_non_dominated_sort import tree_based_non_dominated_sort @@ -11,6 +12,8 @@ def get_functions(): from pymoo.util.misc import calc_perpendicular_distance from pymoo.util.hv import hv from pymoo.util.stochastic_ranking import stochastic_ranking + from pymoo.util.mnn import calc_mnn, calc_2nn + from pymoo.util.pruning_cd import calc_pcd FUNCTIONS = { "fast_non_dominated_sort": { @@ -34,6 +37,15 @@ def get_functions(): "hv": { "python": hv, "cython": "pymoo.cython.hv" }, + "calc_mnn": { + "python": calc_mnn, "cython": "pymoo.cython.mnn" + }, + "calc_2nn": { + "python": calc_2nn, "cython": "pymoo.cython.mnn" + }, + "calc_pcd": { + "python": calc_pcd, "cython": "pymoo.cython.pruning_cd" + }, } diff --git a/pymoo/util/mnn.py b/pymoo/util/mnn.py new file mode 100644 index 000000000..53a641ace --- /dev/null +++ b/pymoo/util/mnn.py @@ -0,0 +1,67 @@ +import numpy as np +from scipy.spatial.distance import pdist, squareform + +def calc_mnn(X, n_remove=0): + return calc_mnn_base(X, n_remove=n_remove, twonn=False) + +def calc_2nn(X, n_remove=0): + return calc_mnn_base(X, n_remove=n_remove, twonn=True) + +def calc_mnn_base(X, n_remove=0, twonn=False): + + N = X.shape[0] + M = X.shape[1] + + if n_remove <= (N - M): + if n_remove < 0: + n_remove = 0 + else: + pass + else: + n_remove = N - M + + if twonn: + M = 2 + + extremes_min = np.argmin(X, axis=0) + extremes_max = np.argmax(X, axis=0) + + min_vals = np.min(X, axis=0) + max_vals = np.max(X, axis=0) + + extremes = np.concatenate((extremes_min, extremes_max)) + + X = (X - min_vals) / (max_vals - min_vals) + + H = np.arange(N) + + D = squareform(pdist(X, metric="sqeuclidean")) + Dnn = np.partition(D, range(1, M+1), axis=1)[:, 1:M+1] + d = np.product(Dnn, axis=1) + d[extremes] = np.inf + + n_removed = 0 + + #While n_remove not acheived + while n_removed < (n_remove - 1): + + #Obtain element to drop + _d = d[H] + _k = np.argmin(_d) + k = H[_k] + H = H[H != k] + + #Update index + n_removed = n_removed + 1 + if n_removed == n_remove: + break + + else: + + D[:, k] = np.inf + Dnn[H] = np.partition(D[H], range(1, M+1), axis=1)[:, 1:M+1] + d[H] = np.product(Dnn[H], axis=1) + d[extremes] = np.inf + + return d + diff --git a/pymoo/util/pruning_cd.py b/pymoo/util/pruning_cd.py new file mode 100644 index 000000000..4407a1d81 --- /dev/null +++ b/pymoo/util/pruning_cd.py @@ -0,0 +1,89 @@ +import numpy as np + +def calc_pcd(X, n_remove=0): + + N = X.shape[0] + M = X.shape[1] + + if n_remove <= (N - M): + if n_remove < 0: + n_remove = 0 + else: + pass + else: + n_remove = N - M + + extremes_min = np.argmin(X, axis=0) + extremes_max = np.argmax(X, axis=0) + + min_vals = np.min(X, axis=0) + max_vals = np.max(X, axis=0) + + extremes = np.concatenate((extremes_min, extremes_max)) + + X = (X - min_vals) / (max_vals - min_vals) + + H = np.arange(N) + d = np.full(N, np.inf) + + I = np.argsort(X, axis=0, kind='mergesort') + + # sort the objective space values for the whole matrix + _X = X[I, np.arange(M)] + + # calculate the distance from each point to the last and next + dist = np.row_stack([_X, np.full(M, np.inf)]) - np.row_stack([np.full(M, -np.inf), _X]) + + # prepare the distance to last and next vectors + dist_to_last, dist_to_next = dist, np.copy(dist) + dist_to_last, dist_to_next = dist_to_last[:-1], dist_to_next[1:] + + # if we divide by zero because all values in one columns are equal replace by none + dist_to_last[np.isnan(dist_to_last)] = 0.0 + dist_to_next[np.isnan(dist_to_next)] = 0.0 + + # sum up the distance to next and last and norm by objectives - also reorder from sorted list + J = np.argsort(I, axis=0) + _d = np.sum(dist_to_last[J, np.arange(M)] + dist_to_next[J, np.arange(M)], axis=1) + d[H] = _d + d[extremes] = np.inf + + n_removed = 0 + + #While n_remove not acheived + while n_removed < (n_remove - 1): + + #Obtain element to drop + _d = d[H] + _k = np.argmin(_d) + k = H[_k] + + H = H[H != k] + + #Update index + n_removed = n_removed + 1 + + I = np.argsort(X[H].copy(), axis=0, kind='mergesort') + + # sort the objective space values for the whole matrix + _X = X[H].copy()[I, np.arange(M)] + + # calculate the distance from each point to the last and next + dist = np.row_stack([_X, np.full(M, np.inf)]) - np.row_stack([np.full(M, -np.inf), _X]) + + # prepare the distance to last and next vectors + dist_to_last, dist_to_next = dist, np.copy(dist) + dist_to_last, dist_to_next = dist_to_last[:-1], dist_to_next[1:] + + # if we divide by zero because all values in one columns are equal replace by none + dist_to_last[np.isnan(dist_to_last)] = 0.0 + dist_to_next[np.isnan(dist_to_next)] = 0.0 + + # sum up the distance to next and last and norm by objectives - also reorder from sorted list + J = np.argsort(I, axis=0) + _d = np.sum(dist_to_last[J, np.arange(M)] + dist_to_next[J, np.arange(M)], axis=1) + d[H] = _d + d[extremes] = np.inf + + return d + diff --git a/tests/algorithms/test_nsga2.py b/tests/algorithms/test_nsga2.py index ce692b9c2..262dc94fa 100644 --- a/tests/algorithms/test_nsga2.py +++ b/tests/algorithms/test_nsga2.py @@ -3,7 +3,7 @@ import numpy as np import pytest -from pymoo.algorithms.moo.nsga2 import calc_crowding_distance +from pymoo.operators.survival.rank_and_crowding.metrics import calc_crowding_distance from pymoo.util.nds.non_dominated_sorting import NonDominatedSorting from tests.test_util import load_to_test_resource diff --git a/tests/misc/test_crowding_distance.py b/tests/misc/test_crowding_distance.py index 4245f6927..750c67c98 100644 --- a/tests/misc/test_crowding_distance.py +++ b/tests/misc/test_crowding_distance.py @@ -3,7 +3,7 @@ import numpy as np import pytest -from pymoo.algorithms.moo.nsga2 import calc_crowding_distance +from pymoo.operators.survival.rank_and_crowding.metrics import calc_crowding_distance from pymoo.config import get_pymoo From 98c3551f4826767971d6f363750dcd532a7f5a04 Mon Sep 17 00:00:00 2001 From: mooscalia project <93492480+mooscaliaproject@users.noreply.github.com> Date: Mon, 17 Oct 2022 12:46:57 -0300 Subject: [PATCH 02/32] Create new DE based algorithms --- pymoo/algorithms/moo/gde3.py | 141 +++++++++ pymoo/algorithms/moo/nsde.py | 108 +++++++ pymoo/algorithms/moo/nsder.py | 124 ++++++++ pymoo/algorithms/soo/nonconvex/de.py | 388 +++++++++--------------- pymoo/algorithms/soo/nonconvex/de_ep.py | 279 +++++++++++++++++ pymoo/algorithms/soo/nonconvex/pso.py | 3 +- pymoo/operators/crossover/dex.py | 388 ++++++++++++++++++------ pymoo/operators/selection/des.py | 213 +++++++++++++ 8 files changed, 1304 insertions(+), 340 deletions(-) create mode 100644 pymoo/algorithms/moo/gde3.py create mode 100644 pymoo/algorithms/moo/nsde.py create mode 100644 pymoo/algorithms/moo/nsder.py create mode 100644 pymoo/algorithms/soo/nonconvex/de_ep.py create mode 100644 pymoo/operators/selection/des.py diff --git a/pymoo/algorithms/moo/gde3.py b/pymoo/algorithms/moo/gde3.py new file mode 100644 index 000000000..588822cd2 --- /dev/null +++ b/pymoo/algorithms/moo/gde3.py @@ -0,0 +1,141 @@ +from pymoo.algorithms.moo.nsde import NSDE +from pymoo.core.population import Population +from pymoo.util.dominator import get_relation +from pymoo.operators.survival.rank_and_crowding import RankAndCrowding + + +# ========================================================================================================= +# Implementation +# ========================================================================================================= + + +class GDE3(NSDE): + + def __init__(self, + pop_size=100, + variant="DE/rand/1/bin", + CR=0.5, + F=None, + gamma=1e-4, + **kwargs): + """ + GDE3 is an extension of DE to multi-objective problems using a mixed type survival strategy. + It is implemented in this version with the same constraint handling strategy of NSGA-II by default. + + Derived algorithms GDE3-MNN and GDE3-2NN use by default survival RankAndCrowding with metrics 'mnn' and '2nn'. + + For many-objective problems, try using NSDE-R, GDE3-MNN, or GDE3-2NN. + + For Bi-objective problems, survival = RankAndCrowding(crowding_func='pcd') is very effective. + + Kukkonen, S. & Lampinen, J., 2005. GDE3: The third evolution step of generalized differential evolution. 2005 IEEE congress on evolutionary computation, Volume 1, pp. 443-450. + + Parameters + ---------- + pop_size : int, optional + Population size. Defaults to 100. + + sampling : Sampling, optional + Sampling strategy of pymoo. Defaults to LHS(). + + variant : str, optional + Differential evolution strategy. Must be a string in the format: "DE/selection/n/crossover", in which, n in an integer of number of difference vectors, and crossover is either 'bin' or 'exp'. Selection variants are: + + - 'ranked' + - 'rand' + - 'best' + - 'current-to-best' + - 'current-to-best' + - 'current-to-rand' + - 'rand-to-best' + + The selection strategy 'ranked' might be helpful to improve convergence speed without much harm to diversity. Defaults to 'DE/rand/1/bin'. + + CR : float, optional + Crossover parameter. Defined in the range [0, 1] + To reinforce mutation, use higher values. To control convergence speed, use lower values. + + F : iterable of float or float, optional + Scale factor or mutation parameter. Defined in the range (0, 2] + To reinforce exploration, use higher values; for exploitation, use lower values. + + gamma : float, optional + Jitter deviation parameter. Should be in the range (0, 2). Defaults to 1e-4. + + de_repair : str, optional + Repair of DE mutant vectors. Is either callable or one of: + + - 'bounce-back' + - 'midway' + - 'rand-init' + - 'to-bounds' + + If callable, has the form fun(X, Xb, xl, xu) in which X contains mutated vectors including violations, Xb contains reference vectors for repair in feasible space, xl is a 1d vector of lower bounds, and xu a 1d vector of upper bounds. + Defaults to 'bounce-back'. + + mutation : Mutation, optional + Pymoo's mutation operator after crossover. Defaults to NoMutation(). + + repair : Repair, optional + Pymoo's repair operator after mutation. Defaults to NoRepair(). + + survival : Survival, optional + Pymoo's survival strategy. + Defaults to RankAndCrowding() with crowding distances ('cd'). + In GDE3, the survival strategy is applied after a one-to-one comparison between child vector and corresponding parent when both are non-dominated by the other. + """ + + super().__init__(pop_size=pop_size, + variant=variant, + CR=CR, + F=F, + gamma=gamma, + **kwargs) + + def _advance(self, infills=None, **kwargs): + + assert infills is not None, "This algorithms uses the AskAndTell interface thus 'infills' must to be provided." + + #The individuals that are considered for the survival later and final survive + survivors = [] + + # now for each of the infill solutions + for k in range(len(self.pop)): + + #Get the offspring an the parent it is coming from + off, parent = infills[k], self.pop[k] + + #Check whether the new solution dominates the parent or not + rel = get_relation(parent, off) + + #If indifferent we add both + if rel == 0: + survivors.extend([parent, off]) + + #If offspring dominates parent + elif rel == -1: + survivors.append(off) + + #If parent dominates offspring + else: + survivors.append(parent) + + #Create the population + survivors = Population.create(*survivors) + + #Perform a survival to reduce to pop size + self.pop = self.survival.do(self.problem, survivors, n_survive=self.n_offsprings) + + +class GDE3MNN(GDE3): + + def __init__(self, pop_size=100, variant="DE/rand/1/bin", CR=0.5, F=None, gamma=0.0001, **kwargs): + survival = RankAndCrowding(crowding_func="mnn") + super().__init__(pop_size, variant, CR, F, gamma, survival=survival, **kwargs) + + +class GDE32NN(GDE3): + + def __init__(self, pop_size=100, variant="DE/rand/1/bin", CR=0.5, F=None, gamma=0.0001, **kwargs): + survival = RankAndCrowding(crowding_func="2nn") + super().__init__(pop_size, variant, CR, F, gamma, survival=survival, **kwargs) \ No newline at end of file diff --git a/pymoo/algorithms/moo/nsde.py b/pymoo/algorithms/moo/nsde.py new file mode 100644 index 000000000..e730350e6 --- /dev/null +++ b/pymoo/algorithms/moo/nsde.py @@ -0,0 +1,108 @@ +from pymoo.algorithms.moo.nsga2 import NSGA2 +from pymoo.operators.sampling.lhs import LHS +from pymoo.algorithms.soo.nonconvex.de import InfillDE +from pymoo.operators.survival.rank_and_crowding import RankAndCrowding + + +# ========================================================================================================= +# Implementation +# ========================================================================================================= + + +class NSDE(NSGA2): + + def __init__(self, + pop_size=100, + sampling=LHS(), + variant="DE/rand/1/bin", + CR=0.7, + F=None, + gamma=1e-4, + de_repair="bounce-back", + mutation=None, + repair=None, + survival=RankAndCrowding(), + **kwargs): + """ + NSDE is an algorithm that combines that combines NSGA-II sorting and survival strategies + to DE mutation and crossover. + + For many-objective problems, try using NSDE-R, GDE3-MNN, or GDE3-2NN. + + For Bi-objective problems, survival = RankAndCrowding(crowding_func='pcd') is very effective. + + Parameters + ---------- + pop_size : int, optional + Population size. Defaults to 100. + + sampling : Sampling, optional + Sampling strategy of pymoo. Defaults to LHS(). + + variant : str, optional + Differential evolution strategy. Must be a string in the format: "DE/selection/n/crossover", in which, n in an integer of number of difference vectors, and crossover is either 'bin' or 'exp'. Selection variants are: + + - "ranked' + - 'rand' + - 'best' + - 'current-to-best' + - 'current-to-best' + - 'current-to-rand' + - 'rand-to-best' + + The selection strategy 'ranked' might be helpful to improve convergence speed without much harm to diversity. Defaults to 'DE/rand/1/bin'. + + CR : float, optional + Crossover parameter. Defined in the range [0, 1] + To reinforce mutation, use higher values. To control convergence speed, use lower values. + + F : iterable of float or float, optional + Scale factor or mutation parameter. Defined in the range (0, 2] + To reinforce exploration, use higher values; for exploitation, use lower values. + + gamma : float, optional + Jitter deviation parameter. Should be in the range (0, 2). Defaults to 1e-4. + + de_repair : str, optional + Repair of DE mutant vectors. Is either callable or one of: + + - 'bounce-back' + - 'midway' + - 'rand-init' + - 'to-bounds' + + If callable, has the form fun(X, Xb, xl, xu) in which X contains mutated vectors including violations, Xb contains reference vectors for repair in feasible space, xl is a 1d vector of lower bounds, and xu a 1d vector of upper bounds. + Defaults to 'bounce-back'. + + mutation : Mutation, optional + Pymoo's mutation operator after crossover. Defaults to NoMutation(). + + repair : Repair, optional + Pymoo's repair operator after mutation. Defaults to NoRepair(). + + survival : Survival, optional + Pymoo's survival strategy. + Defaults to RankAndCrowding() with crowding distances ('cd'). + In GDE3, the survival strategy is applied after a one-to-one comparison between child vector and corresponding parent when both are non-dominated by the other. + """ + + #Number of offsprings at each generation + n_offsprings = pop_size + + #Mating + mating = InfillDE(variant=variant, + CR=CR, + F=F, + gamma=gamma, + de_repair=de_repair, + mutation=mutation, + repair=repair) + + #Init from pymoo's NSGA2 + super().__init__(pop_size=pop_size, + sampling=sampling, + mating=mating, + survival=survival, + eliminate_duplicates=False, + n_offsprings=n_offsprings, + **kwargs) diff --git a/pymoo/algorithms/moo/nsder.py b/pymoo/algorithms/moo/nsder.py new file mode 100644 index 000000000..e5e0349e1 --- /dev/null +++ b/pymoo/algorithms/moo/nsder.py @@ -0,0 +1,124 @@ +import numpy as np +from pymoo.algorithms.moo.nsga3 import ReferenceDirectionSurvival +from pymoo.operators.sampling.lhs import LHS +from pymoo.util.misc import has_feasible +from pymoo.algorithms.moo.nsde import NSDE + +# ========================================================================================================= +# Implementation +# ========================================================================================================= + +class NSDER(NSDE): + + def __init__(self, + ref_dirs, + pop_size=100, + sampling=LHS(), + variant="DE/rand/1/bin", + CR=0.7, + F=None, + gamma=1e-4, + **kwargs): + """ + NSDE-R is an extension of NSDE to many-objective problems (Reddy & Dulikravich, 2019) using NSGA-III survival. + + S. R. Reddy and G. S. Dulikravich, "Many-objective differential evolution optimization based on reference points: NSDE-R," Struct. Multidisc. Optim., vol. 60, pp. 1455-1473, 2019. + + Parameters + ---------- + ref_dirs : array like + The reference directions that should be used during the optimization. + + pop_size : int, optional + Population size. Defaults to 100. + + sampling : Sampling, optional + Sampling strategy of pymoo. Defaults to LHS(). + + variant : str, optional + Differential evolution strategy. Must be a string in the format: "DE/selection/n/crossover", in which, n in an integer of number of difference vectors, and crossover is either 'bin' or 'exp'. Selection variants are: + + - "ranked' + - 'rand' + - 'best' + - 'current-to-best' + - 'current-to-best' + - 'current-to-rand' + - 'rand-to-best' + + The selection strategy 'ranked' might be helpful to improve convergence speed without much harm to diversity. Defaults to 'DE/rand/1/bin'. + + CR : float, optional + Crossover parameter. Defined in the range [0, 1] + To reinforce mutation, use higher values. To control convergence speed, use lower values. + + F : iterable of float or float, optional + Scale factor or mutation parameter. Defined in the range (0, 2] + To reinforce exploration, use higher values; for exploitation, use lower values. + + gamma : float, optional + Jitter deviation parameter. Should be in the range (0, 2). Defaults to 1e-4. + + de_repair : str, optional + Repair of DE mutant vectors. Is either callable or one of: + + - 'bounce-back' + - 'midway' + - 'rand-init' + - 'to-bounds' + + If callable, has the form fun(X, Xb, xl, xu) in which X contains mutated vectors including violations, Xb contains reference vectors for repair in feasible space, xl is a 1d vector of lower bounds, and xu a 1d vector of upper bounds. + Defaults to 'bounce-back'. + + mutation : Mutation, optional + Pymoo's mutation operator after crossover. Defaults to NoMutation(). + + repair : Repair, optional + Pymoo's repair operator after mutation. Defaults to NoRepair(). + + survival : Survival, optional + Pymoo's survival strategy. + Defaults to ReferenceDirectionSurvival(). + """ + + self.ref_dirs = ref_dirs + + if self.ref_dirs is not None: + + if pop_size is None: + pop_size = len(self.ref_dirs) + + if pop_size < len(self.ref_dirs): + print( + f"WARNING: pop_size={pop_size} is less than the number of reference directions ref_dirs={len(self.ref_dirs)}.\n" + "This might cause unwanted behavior of the algorithm. \n" + "Please make sure pop_size is equal or larger than the number of reference directions. ") + + if 'survival' in kwargs: + survival = kwargs['survival'] + del kwargs['survival'] + else: + survival = ReferenceDirectionSurvival(ref_dirs) + + super().__init__(pop_size=pop_size, + sampling=sampling, + variant=variant, + CR=CR, + F=F, + gamma=gamma, + survival=survival, + **kwargs) + + def _setup(self, problem, **kwargs): + + if self.ref_dirs is not None: + if self.ref_dirs.shape[1] != problem.n_obj: + raise Exception( + "Dimensionality of reference points must be equal to the number of objectives: %s != %s" % + (self.ref_dirs.shape[1], problem.n_obj)) + + def _set_optimum(self, **kwargs): + if not has_feasible(self.pop): + self.opt = self.pop[[np.argmin(self.pop.get("CV"))]] + else: + self.opt = self.survival.opt diff --git a/pymoo/algorithms/soo/nonconvex/de.py b/pymoo/algorithms/soo/nonconvex/de.py index 3a5562f51..a01c39fc3 100755 --- a/pymoo/algorithms/soo/nonconvex/de.py +++ b/pymoo/algorithms/soo/nonconvex/de.py @@ -1,279 +1,183 @@ -""" - -Differential Evolution (DE) - --------------------------------- Description ------------------------------- - - - --------------------------------- References -------------------------------- - -[1] J. Blank and K. Deb, pymoo: Multi-Objective Optimization in Python, in IEEE Access, -vol. 8, pp. 89497-89509, 2020, DOI: 10.1109/ACCESS.2020.2990567 - --------------------------------- License ----------------------------------- - - ----------------------------------------------------------------------------- -""" - import numpy as np - from pymoo.algorithms.base.genetic import GeneticAlgorithm from pymoo.algorithms.soo.nonconvex.ga import FitnessSurvival -from pymoo.core.infill import InfillCriterion -from pymoo.core.population import Population from pymoo.core.replacement import ImprovementReplacement -from pymoo.core.variable import Choice, get -from pymoo.core.variable import Real -from pymoo.docs import parse_doc_string -from pymoo.operators.control import EvolutionaryParameterControl, NoParameterControl -from pymoo.operators.crossover.binx import mut_binomial -from pymoo.operators.crossover.expx import mut_exp -from pymoo.operators.mutation.pm import PM -from pymoo.operators.repair.bounds_repair import repair_random_init -from pymoo.operators.sampling.rnd import FloatRandomSampling -from pymoo.operators.selection.rnd import fast_fill_random +from pymoo.operators.mutation.nom import NoMutation +from pymoo.core.repair import NoRepair +from pymoo.operators.sampling.lhs import LHS from pymoo.termination.default import DefaultSingleObjectiveTermination from pymoo.util.display.single import SingleObjectiveOutput -from pymoo.util.misc import where_is_what +from pymoo.operators.selection.des import DES +from pymoo.operators.crossover.dex import DEX # ========================================================================================================= -# Crossover -# ========================================================================================================= - -def de_differential(X, F, jitter, alpha=0.001): - n_parents, n_matings, n_var = X.shape - assert n_parents % 2 == 1, "For the differential an odd number of values need to be provided" - - # the differentials from each pair - delta = np.zeros((n_matings, n_var)) - - # for each difference of the differences - for i in range(1, n_parents, 2): - # create the weight vectors with jitter to give some variation - _F = F[:, None].repeat(n_var, axis=1) - _F[jitter] *= (1 + alpha * (np.random.random((jitter.sum(), n_var)) - 0.5)) - - # add the difference to the vector - delta += _F * (X[i] - X[i + 1]) - - # now add the differentials to the first parent - Xp = X[0] + delta - - return Xp - - -# ========================================================================================================= -# Variant +# Implementation # ========================================================================================================= -class Variant(InfillCriterion): - +class InfillDE: + def __init__(self, - selection="best", - n_diffs=1, - F=0.5, - crossover="bin", - CR=0.2, - jitter=False, - prob_mut=0.1, - control=EvolutionaryParameterControl, - **kwargs): - - super().__init__(**kwargs) - self.selection = Choice(selection, options=["rand", "best"], all=["rand", "best", "target-to-best"]) - self.n_diffs = Choice(n_diffs, options=[1], all=[1, 2]) - self.F = Real(F, bounds=(0.4, 0.7), strict=(0.0, None)) - self.crossover = Choice(crossover, ["bin"], all=["bin", "exp", "hypercube", "line"]) - self.CR = Real(CR, bounds=(0.2, 0.8), strict=(0.0, 1.0)) - self.jitter = Choice(jitter, options=[False], all=[True, False]) - - self.mutation = PM(at_least_once=True) - self.mutation.eta = 20 - self.mutation.prob = prob_mut - - self.control = control(self) - - def do(self, problem, pop, n_offsprings, algorithm=None, **kwargs): - control = self.control - - # let the parameter control now some information - control.tell(pop=pop) - - # set the controlled parameter for the desired number of offsprings - control.do(n_offsprings) - - # find the different groups of selection schemes and order them by category - sel, n_diffs = get(self.selection, self.n_diffs, size=n_offsprings) - H = where_is_what(zip(sel, n_diffs)) - - # get the parameters used for reproduction during the crossover - F, CR, jitter = get(self.F, self.CR, self.jitter, size=n_offsprings) - - # the `target` vectors which will be recombined - X = pop.get("X") - - # the `donor` vector which will be obtained through the differential equation - donor = np.full((n_offsprings, problem.n_var), np.nan) - - # for each type defined by the type and number of differentials - for (sel_type, n_diffs), targets in H.items(): - - # the number of offsprings created in this run - n_matings, n_parents = len(targets), 1 + 2 * n_diffs - - # create the parents array - P = np.full([n_matings, n_parents], -1) - - itself = np.array(targets)[:, None] - - best = lambda: np.random.choice(np.where(pop.get("rank") == 0)[0], replace=True, size=n_matings) - - if sel_type == "rand": - fast_fill_random(P, len(pop), columns=range(n_parents), Xp=itself) - elif sel_type == "best": - P[:, 0] = best() - fast_fill_random(P, len(pop), columns=range(1, n_parents), Xp=itself) - elif sel_type == "target-to-best": - P[:, 0] = targets - P[:, 1] = best() - fast_fill_random(P, len(pop), columns=range(2, n_parents), Xp=itself) - else: - raise Exception("Unknown selection method.") - - # get the values of the parents in the design space - XX = np.swapaxes(X[P], 0, 1) - - # do the differential crossover to create the donor vector - Xp = de_differential(XX, F[targets], jitter[targets]) - - # make sure everything stays in bounds - if problem.has_bounds(): - Xp = repair_random_init(Xp, XX[0], *problem.bounds()) - - # set the donors (the one we have created in this step) - donor[targets] = Xp - - # the `trial` created by by recombining target and donor - trial = np.full((n_offsprings, problem.n_var), np.nan) - - crossover = get(self.crossover, size=n_offsprings) - for name, K in where_is_what(crossover).items(): - - _target = X[K] - _donor = donor[K] - _CR = CR[K] - - if name == "bin": - M = mut_binomial(len(K), problem.n_var, _CR, at_least_once=True) - _trial = np.copy(_target) - _trial[M] = _donor[M] - elif name == "exp": - M = mut_exp(n_offsprings, problem.n_var, _CR, at_least_once=True) - _trial = np.copy(_target) - _trial[M] = _donor[M] - elif name == "line": - w = np.random.random((len(K), 1)) * _CR[:, None] - _trial = _target + w * (_donor - _target) - elif name == "hypercube": - w = np.random.random((len(K), _target.shape[1])) * _CR[:, None] - _trial = _target + w * (_donor - _target) - else: - raise Exception(f"Unknown crossover variant: {name}") - - trial[K] = _trial - - # create the population - off = Population.new(X=trial) - - # do the mutation which helps to add some more diversity - off = self.mutation(problem, off) - - # repair the individuals if necessary - disabled if repair is NoRepair - off = self.repair(problem, off, **kwargs) - - # advance the parameter control by attaching them to the offsprings - control.advance(off) - + variant="DE/rand/1/bin", + CR=0.7, + F=(0.5, 1.0), + gamma=1e-4, + de_repair="bounce-back", + mutation=None, + repair=None): + + # Parse the information from the string + _, selection_variant, n_diff, crossover_variant, = variant.split("/") + n_diffs = int(n_diff) + + # When "to" in variant there are more than 1 difference vectors + if "-to-" in variant: + n_diffs += 1 + + # Define parent selection operator + self.selection = DES(selection_variant) + + #Default value for F + if F is None: + F = (0.0, 1.0) + + # Define crossover strategy + self.crossover = DEX(variant=crossover_variant, + CR=CR, + F=F, + gamma=gamma, + n_diffs=n_diffs, + at_least_once=True, + repair=repair) + + # Define posterior mutation strategy and repair + self.mutation = mutation if mutation is not None else NoMutation() + self.repair = repair if repair is not None else NoRepair() + + def do(self, problem, pop, n_offsprings, **kwargs): + + # Select parents including donor vector + parents = self.selection.do(problem, pop, n_offsprings, self.crossover.n_parents, + to_pop=False, **kwargs) + + # Perform mutation included in DEX and crossover + off = self.crossover.do(problem, pop, parents, **kwargs) + + # Perform posterior mutation and repair if passed + off = self.mutation.do(problem, off) + off = self.repair.do(problem, off) + return off - - -# ========================================================================================================= -# Implementation -# ========================================================================================================= - + class DE(GeneticAlgorithm): def __init__(self, pop_size=100, - n_offsprings=None, - sampling=FloatRandomSampling(), - variant="DE/best/1/bin", + sampling=LHS(), + variant="DE/rand/1/bin", + CR=0.7, + F=(0.5, 1.0), + gamma=1e-4, + de_repair="bounce-back", + mutation=None, + repair=None, output=SingleObjectiveOutput(), - **kwargs - ): - - if variant is None: - if "control" not in kwargs: - kwargs["control"] = NoParameterControl - variant = Variant(**kwargs) - - elif isinstance(variant, str): - try: - _, selection, n_diffs, crossover = variant.split("/") - if "control" not in kwargs: - kwargs["control"] = NoParameterControl - variant = Variant(selection=selection, n_diffs=int(n_diffs), crossover=crossover, **kwargs) - except: - raise Exception("Please provide a valid variant: DE///") + **kwargs): + """ + Single-objective Differential Evolution proposed by Storn and Price (1997). + + Storn, R. & Price, K., 1997. Differential evolution–a simple and efficient heuristic for global optimization over continuous spaces. J. Glob. Optim., 11(4), pp. 341-359. + + Parameters + ---------- + pop_size : int, optional + Population size. Defaults to 100. + + sampling : Sampling, optional + Sampling strategy of pymoo. Defaults to LHS(). + + variant : str, optional + Differential evolution strategy. Must be a string in the format: "DE/selection/n/crossover", in which, n in an integer of number of difference vectors, and crossover is either 'bin' or 'exp'. Selection variants are: + + - 'ranked' + - 'rand' + - 'best' + - 'current-to-best' + - 'current-to-best' + - 'current-to-rand' + - 'rand-to-best' + + The selection strategy 'ranked' might be helpful to improve convergence speed without much harm to diversity. Defaults to 'DE/rand/1/bin'. + + CR : float, optional + Crossover parameter. Defined in the range [0, 1] + To reinforce mutation, use higher values. To control convergence speed, use lower values. + + F : iterable of float or float, optional + Scale factor or mutation parameter. Defined in the range (0, 2] + To reinforce exploration, use higher values; for exploitation, use lower values. + + gamma : float, optional + Jitter deviation parameter. Should be in the range (0, 2). Defaults to 1e-4. + + de_repair : str, optional + Repair of DE mutant vectors. Is either callable or one of: + + - 'bounce-back' + - 'midway' + - 'rand-init' + - 'to-bounds' + + If callable, has the form fun(X, Xb, xl, xu) in which X contains mutated vectors including violations, Xb contains reference vectors for repair in feasible space, xl is a 1d vector of lower bounds, and xu a 1d vector of upper bounds. + Defaults to 'bounce-back'. + + mutation : Mutation, optional + Pymoo's mutation operator after crossover. Defaults to NoMutation(). + + repair : Repair, optional + Pymoo's repair operator after mutation. Defaults to NoRepair(). + """ + + mating = InfillDE(variant=variant, + CR=CR, + F=F, + gamma=gamma, + de_repair=de_repair, + mutation=mutation, + repair=repair) + + # Number of offsprings at each generation + n_offsprings = pop_size super().__init__(pop_size=pop_size, - n_offsprings=n_offsprings, sampling=sampling, - mating=variant, - survival=None, - output=output, + mating=mating, + n_offsprings=n_offsprings, eliminate_duplicates=False, + output=output, **kwargs) self.termination = DefaultSingleObjectiveTermination() def _initialize_advance(self, infills=None, **kwargs): - FitnessSurvival().do(self.problem, self.pop, return_indices=True) + self.pop = FitnessSurvival().do(self.problem, infills, n_survive=self.pop_size) def _infill(self): + infills = self.mating.do(self.problem, self.pop, self.n_offsprings, algorithm=self) - # tag each individual with an index - if a steady state version is executed - index = np.arange(len(infills)) - - # if number of offsprings is set lower than pop_size - randomly select - if self.n_offsprings < self.pop_size: - index = np.random.permutation(len(infills))[:self.n_offsprings] - infills = infills[index] - - infills.set("index", index) - return infills def _advance(self, infills=None, **kwargs): - assert infills is not None, "This algorithms uses the AskAndTell interface thus infills must to be provided." - - # get the indices where each offspring is originating from - I = infills.get("index") - - # replace the individuals with the corresponding parents from the mating - self.pop[I] = ImprovementReplacement().do(self.problem, self.pop[I], infills) - - # update the information regarding the current population - FitnessSurvival().do(self.problem, self.pop, return_indices=True) - - def _set_optimum(self, **kwargs): - k = self.pop.get("rank") == 0 - self.opt = self.pop[k] + + assert infills is not None, "This algorithms uses the AskAndTell interface thus infills must be provided." + #One-to-one replacement survival + self.pop = ImprovementReplacement().do(self.problem, self.pop, infills) -parse_doc_string(DE.__init__) + #Sort the population by fitness to make the selection simpler for mating (not an actual survival, just sorting) + self.pop = FitnessSurvival().do(self.problem, self.pop) + + #Set ranks + self.pop.set("rank", np.arange(self.pop_size)) diff --git a/pymoo/algorithms/soo/nonconvex/de_ep.py b/pymoo/algorithms/soo/nonconvex/de_ep.py new file mode 100644 index 000000000..690a36d86 --- /dev/null +++ b/pymoo/algorithms/soo/nonconvex/de_ep.py @@ -0,0 +1,279 @@ +""" + +Differential Evolution (DE) + +-------------------------------- Description ------------------------------- + + + +-------------------------------- References -------------------------------- + +[1] J. Blank and K. Deb, pymoo: Multi-Objective Optimization in Python, in IEEE Access, +vol. 8, pp. 89497-89509, 2020, DOI: 10.1109/ACCESS.2020.2990567 + +-------------------------------- License ----------------------------------- + + +---------------------------------------------------------------------------- +""" + +import numpy as np + +from pymoo.algorithms.base.genetic import GeneticAlgorithm +from pymoo.algorithms.soo.nonconvex.ga import FitnessSurvival +from pymoo.core.infill import InfillCriterion +from pymoo.core.population import Population +from pymoo.core.replacement import ImprovementReplacement +from pymoo.core.variable import Choice, get +from pymoo.core.variable import Real +from pymoo.docs import parse_doc_string +from pymoo.operators.control import EvolutionaryParameterControl, NoParameterControl +from pymoo.operators.crossover.binx import mut_binomial +from pymoo.operators.crossover.expx import mut_exp +from pymoo.operators.mutation.pm import PM +from pymoo.operators.repair.bounds_repair import repair_random_init +from pymoo.operators.sampling.rnd import FloatRandomSampling +from pymoo.operators.selection.rnd import fast_fill_random +from pymoo.termination.default import DefaultSingleObjectiveTermination +from pymoo.util.display.single import SingleObjectiveOutput +from pymoo.util.misc import where_is_what + + +# ========================================================================================================= +# Crossover +# ========================================================================================================= + +def de_differential(X, F, jitter, alpha=0.001): + n_parents, n_matings, n_var = X.shape + assert n_parents % 2 == 1, "For the differential an odd number of values need to be provided" + + # the differentials from each pair + delta = np.zeros((n_matings, n_var)) + + # for each difference of the differences + for i in range(1, n_parents, 2): + # create the weight vectors with jitter to give some variation + _F = F[:, None].repeat(n_var, axis=1) + _F[jitter] *= (1 + alpha * (np.random.random((jitter.sum(), n_var)) - 0.5)) + + # add the difference to the vector + delta += _F * (X[i] - X[i + 1]) + + # now add the differentials to the first parent + Xp = X[0] + delta + + return Xp + + +# ========================================================================================================= +# Variant +# ========================================================================================================= + +class Variant(InfillCriterion): + + def __init__(self, + selection="best", + n_diffs=1, + F=0.5, + crossover="bin", + CR=0.2, + jitter=False, + prob_mut=0.1, + control=EvolutionaryParameterControl, + **kwargs): + + super().__init__(**kwargs) + self.selection = Choice(selection, options=["rand", "best"], all=["rand", "best", "target-to-best"]) + self.n_diffs = Choice(n_diffs, options=[1], all=[1, 2]) + self.F = Real(F, bounds=(0.4, 0.7), strict=(0.0, None)) + self.crossover = Choice(crossover, ["bin"], all=["bin", "exp", "hypercube", "line"]) + self.CR = Real(CR, bounds=(0.2, 0.8), strict=(0.0, 1.0)) + self.jitter = Choice(jitter, options=[False], all=[True, False]) + + self.mutation = PM(at_least_once=True) + self.mutation.eta = 20 + self.mutation.prob = prob_mut + + self.control = control(self) + + def do(self, problem, pop, n_offsprings, algorithm=None, **kwargs): + control = self.control + + # let the parameter control now some information + control.tell(pop=pop) + + # set the controlled parameter for the desired number of offsprings + control.do(n_offsprings) + + # find the different groups of selection schemes and order them by category + sel, n_diffs = get(self.selection, self.n_diffs, size=n_offsprings) + H = where_is_what(zip(sel, n_diffs)) + + # get the parameters used for reproduction during the crossover + F, CR, jitter = get(self.F, self.CR, self.jitter, size=n_offsprings) + + # the `target` vectors which will be recombined + X = pop.get("X") + + # the `donor` vector which will be obtained through the differential equation + donor = np.full((n_offsprings, problem.n_var), np.nan) + + # for each type defined by the type and number of differentials + for (sel_type, n_diffs), targets in H.items(): + + # the number of offsprings created in this run + n_matings, n_parents = len(targets), 1 + 2 * n_diffs + + # create the parents array + P = np.full([n_matings, n_parents], -1) + + itself = np.array(targets)[:, None] + + best = lambda: np.random.choice(np.where(pop.get("rank") == 0)[0], replace=True, size=n_matings) + + if sel_type == "rand": + fast_fill_random(P, len(pop), columns=range(n_parents), Xp=itself) + elif sel_type == "best": + P[:, 0] = best() + fast_fill_random(P, len(pop), columns=range(1, n_parents), Xp=itself) + elif sel_type == "target-to-best": + P[:, 0] = targets + P[:, 1] = best() + fast_fill_random(P, len(pop), columns=range(2, n_parents), Xp=itself) + else: + raise Exception("Unknown selection method.") + + # get the values of the parents in the design space + XX = np.swapaxes(X[P], 0, 1) + + # do the differential crossover to create the donor vector + Xp = de_differential(XX, F[targets], jitter[targets]) + + # make sure everything stays in bounds + if problem.has_bounds(): + Xp = repair_random_init(Xp, XX[0], *problem.bounds()) + + # set the donors (the one we have created in this step) + donor[targets] = Xp + + # the `trial` created by by recombining target and donor + trial = np.full((n_offsprings, problem.n_var), np.nan) + + crossover = get(self.crossover, size=n_offsprings) + for name, K in where_is_what(crossover).items(): + + _target = X[K] + _donor = donor[K] + _CR = CR[K] + + if name == "bin": + M = mut_binomial(len(K), problem.n_var, _CR, at_least_once=True) + _trial = np.copy(_target) + _trial[M] = _donor[M] + elif name == "exp": + M = mut_exp(n_offsprings, problem.n_var, _CR, at_least_once=True) + _trial = np.copy(_target) + _trial[M] = _donor[M] + elif name == "line": + w = np.random.random((len(K), 1)) * _CR[:, None] + _trial = _target + w * (_donor - _target) + elif name == "hypercube": + w = np.random.random((len(K), _target.shape[1])) * _CR[:, None] + _trial = _target + w * (_donor - _target) + else: + raise Exception(f"Unknown crossover variant: {name}") + + trial[K] = _trial + + # create the population + off = Population.new(X=trial) + + # do the mutation which helps to add some more diversity + off = self.mutation(problem, off) + + # repair the individuals if necessary - disabled if repair is NoRepair + off = self.repair(problem, off, **kwargs) + + # advance the parameter control by attaching them to the offsprings + control.advance(off) + + return off + + +# ========================================================================================================= +# Implementation +# ========================================================================================================= + + +class EPDE(GeneticAlgorithm): + + def __init__(self, + pop_size=100, + n_offsprings=None, + sampling=FloatRandomSampling(), + variant="DE/best/1/bin", + output=SingleObjectiveOutput(), + **kwargs + ): + + if variant is None: + if "control" not in kwargs: + kwargs["control"] = NoParameterControl + variant = Variant(**kwargs) + + elif isinstance(variant, str): + try: + _, selection, n_diffs, crossover = variant.split("/") + if "control" not in kwargs: + kwargs["control"] = NoParameterControl + variant = Variant(selection=selection, n_diffs=int(n_diffs), crossover=crossover, **kwargs) + except: + raise Exception("Please provide a valid variant: DE///") + + super().__init__(pop_size=pop_size, + n_offsprings=n_offsprings, + sampling=sampling, + mating=variant, + survival=None, + output=output, + eliminate_duplicates=False, + **kwargs) + + self.termination = DefaultSingleObjectiveTermination() + + def _initialize_advance(self, infills=None, **kwargs): + FitnessSurvival().do(self.problem, self.pop, return_indices=True) + + def _infill(self): + infills = self.mating.do(self.problem, self.pop, self.n_offsprings, algorithm=self) + + # tag each individual with an index - if a steady state version is executed + index = np.arange(len(infills)) + + # if number of offsprings is set lower than pop_size - randomly select + if self.n_offsprings < self.pop_size: + index = np.random.permutation(len(infills))[:self.n_offsprings] + infills = infills[index] + + infills.set("index", index) + + return infills + + def _advance(self, infills=None, **kwargs): + assert infills is not None, "This algorithms uses the AskAndTell interface thus infills must to be provided." + + # get the indices where each offspring is originating from + I = infills.get("index") + + # replace the individuals with the corresponding parents from the mating + self.pop[I] = ImprovementReplacement().do(self.problem, self.pop[I], infills) + + # update the information regarding the current population + FitnessSurvival().do(self.problem, self.pop, return_indices=True) + + def _set_optimum(self, **kwargs): + k = self.pop.get("rank") == 0 + self.opt = self.pop[k] + + +parse_doc_string(EPDE.__init__) diff --git a/pymoo/algorithms/soo/nonconvex/pso.py b/pymoo/algorithms/soo/nonconvex/pso.py index 5943eae1b..6973f67e6 100644 --- a/pymoo/algorithms/soo/nonconvex/pso.py +++ b/pymoo/algorithms/soo/nonconvex/pso.py @@ -8,9 +8,8 @@ from pymoo.core.repair import NoRepair from pymoo.core.replacement import ImprovementReplacement from pymoo.docs import parse_doc_string -from pymoo.operators.crossover.dex import repair_random_init from pymoo.operators.mutation.pm import PM -from pymoo.operators.repair.bounds_repair import is_out_of_bounds_by_problem +from pymoo.operators.repair.bounds_repair import is_out_of_bounds_by_problem, repair_random_init from pymoo.operators.repair.to_bound import set_to_bounds_if_outside from pymoo.operators.sampling.lhs import LHS from pymoo.util.display.column import Column diff --git a/pymoo/operators/crossover/dex.py b/pymoo/operators/crossover/dex.py index 157856498..3e2d5135f 100644 --- a/pymoo/operators/crossover/dex.py +++ b/pymoo/operators/crossover/dex.py @@ -1,122 +1,318 @@ import numpy as np - from pymoo.core.crossover import Crossover from pymoo.core.population import Population from pymoo.operators.crossover.binx import mut_binomial from pymoo.operators.crossover.expx import mut_exp -from pymoo.operators.repair.bounds_repair import is_out_of_bounds_by_problem, repair_random_init - - -def de_differential(X, F, dither=None, jitter=True, gamma=0.0001, return_differentials=False): - n_parents, n_matings, n_var = X.shape - assert n_parents % 2 == 1, "For the differential an odd number of values need to be provided" - - # make sure F is a one-dimensional vector - F = np.ones(n_matings) * F - - # build the pairs for the differentials - pairs = (np.arange(n_parents - 1) + 1).reshape(-1, 2) - - # the differentials from each pair subtraction - diffs = np.zeros((n_matings, n_var)) - # for each difference - for i, j in pairs: - if dither == "vector": - F = (F + np.random.random(n_matings) * (1 - F)) - elif dither == "scalar": - F = F + np.random.random() * (1 - F) - - # http://www.cs.ndsu.nodak.edu/~siludwig/Publish/papers/SSCI20141.pdf - if jitter: - F = (F * (1 + gamma * (np.random.random(n_matings) - 0.5))) - - # an add the difference to the first vector - diffs += F[:, None] * (X[i] - X[j]) - - # now add the differentials to the first parent - Xp = X[0] + diffs - - if return_differentials: - return Xp, diffs - else: - return Xp +# ========================================================================================================= +# Implementation +# ========================================================================================================= +class DEM: + + def __init__(self, + F=None, + gamma=1e-4, + de_repair="bounce-back", + **kwargs): + # Default value for F + if F is None: + F = (0.0, 1.0) + + # Define which method will be used to generate F values + if hasattr(F, "__iter__"): + self.scale_factor = self._randomize_scale_factor + else: + self.scale_factor = self._scalar_scale_factor + + # Define which method will be used to generate F values + if not hasattr(de_repair, "__call__"): + try: + de_repair = REPAIRS[de_repair] + except: + raise KeyError("Repair must be either callable or in " + str(list(REPAIRS.keys()))) + + # Define which strategy of rotation will be used + if gamma is None: + self.get_diff = self._diff_simple + else: + self.get_diff = self._diff_jitter + + self.F = F + self.gamma = gamma + self.de_repair = de_repair + + def do(self, problem, pop, parents, **kwargs): + + # Get all X values for mutation parents + Xr = pop.get("X")[parents.T].copy() + assert len(Xr.shape) == 3, "Please provide a three-dimensional matrix n_parents x pop_size x n_vars." + + # Create mutation vectors + V, diffs = self.de_mutation(Xr, return_differentials=True) + + # If the problem has boundaries to be considered + if problem.has_bounds(): + + # Do repair + V = self.de_repair(V, Xr[0], *problem.bounds()) + + return Population.new("X", V) + + def de_mutation(self, Xr, return_differentials=True): + + n_parents, n_matings, n_var = Xr.shape + assert n_parents % 2 == 1, "For the differential an odd number of values need to be provided" + + # Build the pairs for the differentials + pairs = (np.arange(n_parents - 1) + 1).reshape(-1, 2) + + # The differentials from each pair subtraction + diffs = self.get_diffs(Xr, pairs, n_matings, n_var) + + # Add the difference vectors to the base vector + V = Xr[0] + diffs + + if return_differentials: + return V, diffs + else: + return V + + def _randomize_scale_factor(self, n_matings): + return (self.F[0] + np.random.random(n_matings) * (self.F[1] - self.F[0])) + + def _scalar_scale_factor(self, n_matings): + return np.full(n_matings, self.F) + + def _diff_jitter(self, F, Xi, Xj, n_matings, n_var): + F = F[:, None] * (1 + self.gamma * (np.random.random((n_matings, n_var)) - 0.5)) + return F * (Xi - Xj) + + def _diff_simple(self, F, Xi, Xj, n_matings, n_var): + return F[:, None] * (Xi - Xj) + + def get_diffs(self, Xr, pairs, n_matings, n_var): + + # The differentials from each pair subtraction + diffs = np.zeros((n_matings, n_var)) + + # For each difference + for i, j in pairs: + + # Obtain F randomized in range + F = self.scale_factor(n_matings) + + # New difference vector + diff = self.get_diff(F, Xr[i], Xr[j], n_matings, n_var) + + # Add the difference to the first vector + diffs = diffs + diff + + return diffs + + class DEX(Crossover): - + def __init__(self, - F=None, - CR=0.7, variant="bin", - dither=None, - jitter=False, + CR=0.7, + F=None, + gamma=1e-4, n_diffs=1, - n_iter=1, at_least_once=True, + de_repair="bounce-back", **kwargs): - - super().__init__(1 + 2 * n_diffs, 1, **kwargs) - self.n_diffs = n_diffs - self.F = F + + # Default value for F + if F is None: + F = (0.0, 1.0) + + # Create instace for mutation + self.dem = DEM(F=F, + gamma=gamma, + de_repair=de_repair) + self.CR = CR self.variant = variant self.at_least_once = at_least_once - self.dither = dither - self.jitter = jitter - self.n_iter = n_iter - - def do(self, problem, pop, parents=None, **kwargs): - - # if a parents with array with mating indices is provided -> transform the input first - if parents is not None: - pop = [pop[mating] for mating in parents] - - # get the actual values from each of the parents - X = np.swapaxes(np.array([[parent.get("X") for parent in mating] for mating in pop]), 0, 1).copy() - - n_parents, n_matings, n_var = X.shape - - # a mask over matings that need to be repeated - m = np.arange(n_matings) - - # if the user provides directly an F value to use - F = self.F if self.F is not None else rnd_F(m) - - # prepare the out to be set - Xp = de_differential(X[:, m], F) - - # if the problem has boundaries to be considered - if problem.has_bounds(): - - for k in range(self.n_iter): - # find the individuals which are still infeasible - m = is_out_of_bounds_by_problem(problem, Xp) - - F = rnd_F(m) - - # actually execute the differential equation - Xp[m] = de_differential(X[:, m], F) - - # if still infeasible do a random initialization - Xp = repair_random_init(Xp, X[0], *problem.bounds()) - + + super().__init__(2 + 2 * n_diffs, 1, prob=1.0, **kwargs) + + def do(self, problem, pop, parents, **kwargs): + + # Get target vectors + X = pop.get("X")[parents[:, 0]] + + # About Xi + n_matings, n_var = X.shape + + # Obtain mutants + mutants = self.dem.do(problem, pop, parents[:, 1:], **kwargs) + + # Obtain V + V = mutants.get("X") + + # Binomial crossover if self.variant == "bin": M = mut_binomial(n_matings, n_var, self.CR, at_least_once=self.at_least_once) + # Exponential crossover elif self.variant == "exp": M = mut_exp(n_matings, n_var, self.CR, at_least_once=self.at_least_once) else: raise Exception(f"Unknown variant: {self.variant}") - # take the first parents (this is already a copy) - X = X[0] - - # set the corresponding values from the donor vector - X[M] = Xp[M] - - return Population.new("X", X) - + # Add mutated elements in corresponding main parent + X[M] = V[M] + + off = Population.new("X", X) + + return off + + +def bounce_back(X, Xb, xl, xu): + """Repair strategy + + Args: + X (2d array like): Mutated vectors including violations. + Xb (2d array like): Reference vectors for repair in feasible space. + xl (1d array like): Lower-bounds. + xu (1d array like): Upper-bounds. -def rnd_F(m): - return 0.5 * (1 + np.random.uniform(size=len(m))) + Returns: + 2d array like: Repaired vectors. + """ + + XL = xl[None, :].repeat(len(X), axis=0) + XU = xu[None, :].repeat(len(X), axis=0) + + i, j = np.where(X < XL) + if len(i) > 0: + X[i, j] = XL[i, j] + np.random.random(len(i)) * (Xb[i, j] - XL[i, j]) + + i, j = np.where(X > XU) + if len(i) > 0: + X[i, j] = XU[i, j] - np.random.random(len(i)) * (XU[i, j] - Xb[i, j]) + + return X + +def midway(X, Xb, xl, xu): + """Repair strategy + + Args: + X (2d array like): Mutated vectors including violations. + Xb (2d array like): Reference vectors for repair in feasible space. + xl (1d array like): Lower-bounds. + xu (1d array like): Upper-bounds. + + Returns: + 2d array like: Repaired vectors. + """ + + XL = xl[None, :].repeat(len(X), axis=0) + XU = xu[None, :].repeat(len(X), axis=0) + + i, j = np.where(X < XL) + if len(i) > 0: + X[i, j] = XL[i, j] + (Xb[i, j] - XL[i, j]) / 2 + + i, j = np.where(X > XU) + if len(i) > 0: + X[i, j] = XU[i, j] - (XU[i, j] - Xb[i, j]) / 2 + + return X + +def to_bounds(X, Xb, xl, xu): + """Repair strategy + + Args: + X (2d array like): Mutated vectors including violations. + Xb (2d array like): Reference vectors for repair in feasible space. + xl (1d array like): Lower-bounds. + xu (1d array like): Upper-bounds. + + Returns: + 2d array like: Repaired vectors. + """ + + XL = xl[None, :].repeat(len(X), axis=0) + XU = xu[None, :].repeat(len(X), axis=0) + + i, j = np.where(X < XL) + if len(i) > 0: + X[i, j] = XL[i, j] + + i, j = np.where(X > XU) + if len(i) > 0: + X[i, j] = XU[i, j] + + return X + +def rand_init(X, Xb, xl, xu): + """Repair strategy + + Args: + X (2d array like): Mutated vectors including violations. + Xb (2d array like): Reference vectors for repair in feasible space. + xl (1d array like): Lower-bounds. + xu (1d array like): Upper-bounds. + + Returns: + 2d array like: Repaired vectors. + """ + + XL = xl[None, :].repeat(len(X), axis=0) + XU = xu[None, :].repeat(len(X), axis=0) + + i, j = np.where(X < XL) + if len(i) > 0: + X[i, j] = XL[i, j] + np.random.random(len(i)) * (XU[i, j] - XL[i, j]) + + i, j = np.where(X > XU) + if len(i) > 0: + X[i, j] = XU[i, j] - np.random.random(len(i)) * (XU[i, j] - XL[i, j]) + + return X + + +def squared_bounce_back(X, Xb, xl, xu): + """Repair strategy + + Args: + X (2d array like): Mutated vectors including violations. + Xb (2d array like): Reference vectors for repair in feasible space. + xl (1d array like): Lower-bounds. + xu (1d array like): Upper-bounds. + + Returns: + 2d array like: Repaired vectors. + """ + + XL = xl[None, :].repeat(len(X), axis=0) + XU = xu[None, :].repeat(len(X), axis=0) + + i, j = np.where(X < XL) + if len(i) > 0: + X[i, j] = XL[i, j] + np.square(np.random.random(len(i))) * (Xb[i, j] - XL[i, j]) + + i, j = np.where(X > XU) + if len(i) > 0: + X[i, j] = XU[i, j] - np.square(np.random.random(len(i))) * (XU[i, j] - Xb[i, j]) + + return X + +def normalize_fun(fun): + + fmin = fun.min(axis=0) + fmax = fun.max(axis=0) + den = fmax - fmin + + den[den <= 1e-16] = 1.0 + + return (fun - fmin)/den + +REPAIRS = {"bounce-back":bounce_back, + "midway":midway, + "rand-init":rand_init, + "to-bounds":to_bounds} diff --git a/pymoo/operators/selection/des.py b/pymoo/operators/selection/des.py new file mode 100644 index 000000000..897eac412 --- /dev/null +++ b/pymoo/operators/selection/des.py @@ -0,0 +1,213 @@ +import numpy as np +from pymoo.core.selection import Selection + + +# ========================================================================================================= +# Implementation +# ========================================================================================================= + + +# This is the core differential evolution selection class +class DES(Selection): + + def __init__(self, + variant, + **kwargs): + + super().__init__() + self.variant = variant + + def _do(self, problem, pop, n_select, n_parents, **kwargs): + + # Obtain number of elements in population + n_pop = len(pop) + + # For most variants n_select must be equal to len(pop) + variant = self.variant + + if variant == "ranked": + """Proposed by Zhang et al. (2021). doi.org/10.1016/j.asoc.2021.107317""" + P = self._ranked(pop, n_select, n_parents) + + elif variant == "best": + P = self._best(pop, n_select, n_parents) + + elif variant == "current-to-best": + P = self._current_to_best(pop, n_select, n_parents) + + elif variant == "current-to-rand": + P = self._current_to_rand(pop, n_select, n_parents) + + else: + P = self._rand(pop, n_select, n_parents) + + return P + + def _rand(self, pop, n_select, n_parents, **kwargs): + + # len of pop + n_pop = len(pop) + + # Base form + P = np.empty([n_select, n_parents], dtype=int) + + # Fill first column with corresponding parent + P[:, 0] = np.arange(n_pop) + + # Fill next columns in loop + for j in range(1, n_parents): + + P[:, j] = np.random.choice(n_pop, n_select) + reselect = (P[:, j].reshape([-1, 1]) == P[:, :j]).any(axis=1) + + while np.any(reselect): + P[reselect, j] = np.random.choice(n_pop, reselect.sum()) + reselect = (P[:, j].reshape([-1, 1]) == P[:, :j]).any(axis=1) + + return P + + def _best(self, pop, n_select, n_parents, **kwargs): + + # len of pop + n_pop = len(pop) + + # Base form + P = np.empty([n_select, n_parents], dtype=int) + + # Fill first column with corresponding parent + P[:, 0] = np.arange(n_pop) + + # Fill first column with best candidate + P[:, 1] = 0 + + # Fill next columns in loop + for j in range(2, n_parents): + + P[:, j] = np.random.choice(n_pop, n_select) + reselect = (P[:, j].reshape([-1, 1]) == P[:, :j]).any(axis=1) + + while np.any(reselect): + P[reselect, j] = np.random.choice(n_pop, reselect.sum()) + reselect = (P[:, j].reshape([-1, 1]) == P[:, :j]).any(axis=1) + + return P + + def _current_to_best(self, pop, n_select, n_parents, **kwargs): + + # len of pop + n_pop = len(pop) + + # Base form + P = np.empty([n_select, n_parents], dtype=int) + + # Fill first column with corresponding parent + P[:, 0] = np.arange(n_pop) + + # Fill first column with current candidate + P[:, 1] = np.arange(n_pop) + + # Fill first direction from current + P[:, 3] = np.arange(n_pop) + + # Towards best + P[:, 2] = 0 + + # Fill next columns in loop + for j in range(4, n_parents): + + P[:, j] = np.random.choice(n_pop, n_select) + reselect = (P[:, j].reshape([-1, 1]) == P[:, :j]).any(axis=1) + + while np.any(reselect): + P[reselect, j] = np.random.choice(n_pop, reselect.sum()) + reselect = (P[:, j].reshape([-1, 1]) == P[:, :j]).any(axis=1) + + return P + + def _current_to_rand(self, pop, n_select, n_parents, **kwargs): + + # len of pop + n_pop = len(pop) + + # Base form + P = np.empty([n_select, n_parents], dtype=int) + + # Fill first column with corresponding parent + P[:, 0] = np.arange(n_pop) + + # Fill first column with current candidate + P[:, 1] = np.arange(n_pop) + + # Fill first direction from current + P[:, 3] = np.arange(n_pop) + + # Towards random + P[:, 2] = np.random.choice(n_pop, n_select) + reselect = (P[:, 2].reshape([-1, 1]) == P[:, [0, 1, 3]]).any(axis=1) + + while np.any(reselect): + P[reselect, 2] = np.random.choice(n_pop, reselect.sum()) + reselect = (P[:, 2].reshape([-1, 1]) == P[:, [0, 1, 3]]).any(axis=1) + + # Fill next columns in loop + for j in range(4, n_parents): + + P[:, j] = np.random.choice(n_pop, n_select) + reselect = (P[:, j].reshape([-1, 1]) == P[:, :j]).any(axis=1) + + while np.any(reselect): + P[reselect, j] = np.random.choice(n_pop, reselect.sum()) + reselect = (P[:, j].reshape([-1, 1]) == P[:, :j]).any(axis=1) + + return P + + def _ranked(self, pop, n_select, n_parents, **kwargs): + + P = self._rand(pop, n_select, n_parents, **kwargs) + P[:, 1:] = rank_sort(P[:, 1:], pop) + + return P + + +def ranks_from_cv(pop): + + ranks = pop.get("rank") + cv_elements = ranks == None + + if np.any(cv_elements): + ranks[cv_elements] = np.arange(len(pop))[cv_elements] + + return ranks + +def rank_sort(P, pop): + + ranks = ranks_from_cv(pop) + + sorted = np.argsort(ranks[P], axis=1, kind="stable") + S = np.take_along_axis(P, sorted, axis=1) + + P[:, 0] = S[:, 0] + + n_diffs = int((P.shape[1] - 1) / 2) + + for j in range(1, n_diffs + 1): + P[:, 2*j - 1] = S[:, j] + P[:, 2*j] = S[:, -j] + + return P + +def reiforce_directions(P, pop): + + ranks = ranks_from_cv(pop) + + ranks = ranks[P] + S = P.copy() + + n_diffs = int(P.shape[1] / 2) + + for j in range(0, n_diffs): + bad_directions = ranks[:, 2*j] > ranks[:, 2*j + 1] + P[bad_directions, 2*j] = S[bad_directions, 2*j + 1] + P[bad_directions, 2*j + 1] = S[bad_directions, 2*j] + + return P From 8da25a82c1b991e859a37d84364892623ad2b83c Mon Sep 17 00:00:00 2001 From: mooscalia project <93492480+mooscaliaproject@users.noreply.github.com> Date: Mon, 17 Oct 2022 13:43:50 -0300 Subject: [PATCH 03/32] Fix RankAndCrowding docs pcd --- pymoo/operators/survival/rank_and_crowding/classes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymoo/operators/survival/rank_and_crowding/classes.py b/pymoo/operators/survival/rank_and_crowding/classes.py index 516a0dce3..dcf3bc1fd 100644 --- a/pymoo/operators/survival/rank_and_crowding/classes.py +++ b/pymoo/operators/survival/rank_and_crowding/classes.py @@ -27,7 +27,7 @@ def __init__(self, nds=None, crowding_func="cd"): Crowding metric. Options are: - 'cd': crowding distances - - 'pcd' or 'pruned-cd': pruned crowding distances + - 'pcd' or 'pruning-cd': improved pruning based on crowding distances - 'ce': crowding entropy - 'mnn': M-Neaest Neighbors - '2nn': 2-Neaest Neighbors @@ -120,7 +120,7 @@ def __init__(self, nds=None, crowding_func="cd"): Crowding metric. Options are: - 'cd': crowding distances - - 'pcd' or 'pruned-cd': pruned crowding distances + - 'pcd' or 'pruning-cd': improved pruning based on crowding distances - 'ce': crowding entropy - 'mnn': M-Neaest Neighbors - '2nn': 2-Neaest Neighbors From 0ab60bd43c931e35331666b3a6bbe987a89ebb55 Mon Sep 17 00:00:00 2001 From: mooscalia project <93492480+mooscaliaproject@users.noreply.github.com> Date: Mon, 17 Oct 2022 13:47:51 -0300 Subject: [PATCH 04/32] Included GDE3PCD as a class --- pymoo/algorithms/moo/gde3.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pymoo/algorithms/moo/gde3.py b/pymoo/algorithms/moo/gde3.py index 588822cd2..00067f1f4 100644 --- a/pymoo/algorithms/moo/gde3.py +++ b/pymoo/algorithms/moo/gde3.py @@ -138,4 +138,10 @@ class GDE32NN(GDE3): def __init__(self, pop_size=100, variant="DE/rand/1/bin", CR=0.5, F=None, gamma=0.0001, **kwargs): survival = RankAndCrowding(crowding_func="2nn") + super().__init__(pop_size, variant, CR, F, gamma, survival=survival, **kwargs) + +class GDE3PCD(GDE3): + + def __init__(self, pop_size=100, variant="DE/rand/1/bin", CR=0.5, F=None, gamma=0.0001, **kwargs): + survival = RankAndCrowding(crowding_func="pcd") super().__init__(pop_size, variant, CR, F, gamma, survival=survival, **kwargs) \ No newline at end of file From 91c7aad6b61cb9f7d1f78238a5eaaa6da1a1fb30 Mon Sep 17 00:00:00 2001 From: mooscalia project <93492480+mooscaliaproject@users.noreply.github.com> Date: Mon, 31 Oct 2022 21:27:31 -0300 Subject: [PATCH 05/32] Include tests for new crowding metrics --- MANIFEST.in | 2 +- pymoo/cython/mnn.pyx | 131 +-------------------- pymoo/cython/pruning_cd.pyx | 119 +------------------ pymoo/cython/utils.pxd | 129 ++++++++++++++++++++ tests/algorithms/test_rank_and_crowding.py | 117 ++++++++++++++++++ 5 files changed, 251 insertions(+), 247 deletions(-) create mode 100644 pymoo/cython/utils.pxd create mode 100644 tests/algorithms/test_rank_and_crowding.py diff --git a/MANIFEST.in b/MANIFEST.in index 182882a01..f84b358f4 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,4 +1,4 @@ prune . -recursive-include pymoo *.py *.pyx +recursive-include pymoo *.py *.pyx *.pxd recursive-include pymoo/cython/vendor *.cpp *.h include LICENSE Makefile diff --git a/pymoo/cython/mnn.pyx b/pymoo/cython/mnn.pyx index 7ae772e84..a3a8552fd 100644 --- a/pymoo/cython/mnn.pyx +++ b/pymoo/cython/mnn.pyx @@ -10,6 +10,8 @@ import numpy as np +from pymoo.cython.utils cimport c_get_drop, c_get_argmin, c_get_argmax, c_normalize_array + from libcpp cimport bool from libcpp.vector cimport vector from libcpp.set cimport set as cpp_set @@ -266,132 +268,3 @@ cdef cpp_set[int] c_get_calc_items( calc_items.insert(i) return calc_items - - -#Returns elements to remove based on crowding metric d and heap of remaining elements H -cdef int c_get_drop(double[:] d, cpp_set[int] H): - - cdef: - int i, min_i - double min_d - - min_d = HUGE_VAL - min_i = 0 - - for i in H: - - if d[i] <= min_d: - min_d = d[i] - min_i = i - - return min_i - - -#Elements in condensed matrix -cdef int c_square_to_condensed(int i, int j, int N): - - cdef int _i = i - - if i < j: - i = j - j = _i - - return N * j - j * (j + 1) // 2 + i - 1 - j - - -#Returns vector of positions of minimum values along axis 0 of a 2d memoryview -cdef vector[int] c_get_argmin(double[:, :] X): - - cdef: - int N, M, min_i, n, m - double min_val - vector[int] indexes - - N = X.shape[0] - M = X.shape[1] - - indexes = vector[int]() - - for m in range(M): - - min_i = 0 - min_val = X[0, m] - - for n in range(N): - - if X[n, m] < min_val: - - min_i = n - min_val = X[n, m] - - indexes.push_back(min_i) - - return indexes - - -#Returns vector of positions of maximum values along axis 0 of a 2d memoryview -cdef vector[int] c_get_argmax(double[:, :] X): - - cdef: - int N, M, max_i, n, m - double max_val - vector[int] indexes - - N = X.shape[0] - M = X.shape[1] - - indexes = vector[int]() - - for m in range(M): - - max_i = 0 - max_val = X[0, m] - - for n in range(N): - - if X[n, m] > max_val: - - max_i = n - max_val = X[n, m] - - indexes.push_back(max_i) - - return indexes - - -#Performs normalization of a 2d memoryview -cdef double[:, :] c_normalize_array(double[:, :] X, vector[int] extremes_max, vector[int] extremes_min): - - cdef: - int N = X.shape[0] - int M = X.shape[1] - int n, m, l, u - double l_val, u_val, diff_val - vector[double] min_vals, max_vals - - min_vals = vector[double]() - max_vals = vector[double]() - - m = 0 - for u in extremes_max: - u_val = X[u, m] - max_vals.push_back(u_val) - m = m + 1 - - m = 0 - for l in extremes_min: - l_val = X[l, m] - min_vals.push_back(l_val) - m = m + 1 - - for m in range(M): - - diff_val = max_vals[m] - min_vals[m] - if diff_val == 0.0: - diff_val = 1.0 - - for n in range(N): - - X[n, m] = (X[n, m] - min_vals[m]) / diff_val - - return X \ No newline at end of file diff --git a/pymoo/cython/pruning_cd.pyx b/pymoo/cython/pruning_cd.pyx index 505c29d56..a08c07f5a 100644 --- a/pymoo/cython/pruning_cd.pyx +++ b/pymoo/cython/pruning_cd.pyx @@ -3,6 +3,8 @@ import numpy as np +from pymoo.cython.utils cimport c_get_drop, c_get_argmin, c_get_argmax, c_normalize_array + from libcpp cimport bool from libcpp.vector cimport vector from libcpp.set cimport set as cpp_set @@ -193,120 +195,3 @@ cdef cpp_set[int] c_get_calc_items( I[n:-1, m] = I[n + 1:, m] return calc_items - - -#Returns elements to remove based on crowding metric d and heap of remaining elements H -cdef int c_get_drop(double[:] d, cpp_set[int] H): - - cdef: - int i, min_i - double min_d - - min_d = HUGE_VAL - min_i = 0 - - for i in H: - - if d[i] <= min_d: - min_d = d[i] - min_i = i - - return min_i - - -#Returns vector of positions of minimum values along axis 0 of a 2d memoryview -cdef vector[int] c_get_argmin(double[:, :] X): - - cdef: - int N, M, min_i, n, m - double min_val - vector[int] indexes - - N = X.shape[0] - M = X.shape[1] - - indexes = vector[int]() - - for m in range(M): - - min_i = 0 - min_val = X[0, m] - - for n in range(N): - - if X[n, m] < min_val: - - min_i = n - min_val = X[n, m] - - indexes.push_back(min_i) - - return indexes - - -#Returns vector of positions of maximum values along axis 0 of a 2d memoryview -cdef vector[int] c_get_argmax(double[:, :] X): - - cdef: - int N, M, max_i, n, m - double max_val - vector[int] indexes - - N = X.shape[0] - M = X.shape[1] - - indexes = vector[int]() - - for m in range(M): - - max_i = 0 - max_val = X[0, m] - - for n in range(N): - - if X[n, m] > max_val: - - max_i = n - max_val = X[n, m] - - indexes.push_back(max_i) - - return indexes - - -#Performs normalization of a 2d memoryview -cdef double[:, :] c_normalize_array(double[:, :] X, vector[int] extremes_max, vector[int] extremes_min): - - cdef: - int N = X.shape[0] - int M = X.shape[1] - int n, m, l, u - double l_val, u_val, diff_val - vector[double] min_vals, max_vals - - min_vals = vector[double]() - max_vals = vector[double]() - - m = 0 - for u in extremes_max: - u_val = X[u, m] - max_vals.push_back(u_val) - m = m + 1 - - m = 0 - for l in extremes_min: - l_val = X[l, m] - min_vals.push_back(l_val) - m = m + 1 - - for m in range(M): - - diff_val = max_vals[m] - min_vals[m] - if diff_val == 0.0: - diff_val = 1.0 - - for n in range(N): - - X[n, m] = (X[n, m] - min_vals[m]) / diff_val - - return X \ No newline at end of file diff --git a/pymoo/cython/utils.pxd b/pymoo/cython/utils.pxd new file mode 100644 index 000000000..1d15b672c --- /dev/null +++ b/pymoo/cython/utils.pxd @@ -0,0 +1,129 @@ +# distutils: language = c++ +# cython: language_level=2, boundscheck=False, wraparound=False, cdivision=True + +import numpy as np + +from libcpp cimport bool +from libcpp.vector cimport vector +from libcpp.set cimport set as cpp_set + + +cdef extern from "math.h": + double HUGE_VAL + + +#Returns elements to remove based on crowding metric d and heap of remaining elements H +cdef inline int c_get_drop(double[:] d, cpp_set[int] H): + + cdef: + int i, min_i + double min_d + + min_d = HUGE_VAL + min_i = 0 + + for i in H: + + if d[i] <= min_d: + min_d = d[i] + min_i = i + + return min_i + + +#Returns vector of positions of minimum values along axis 0 of a 2d memoryview +cdef inline vector[int] c_get_argmin(double[:, :] X): + + cdef: + int N, M, min_i, n, m + double min_val + vector[int] indexes + + N = X.shape[0] + M = X.shape[1] + + indexes = vector[int]() + + for m in range(M): + + min_i = 0 + min_val = X[0, m] + + for n in range(N): + + if X[n, m] < min_val: + + min_i = n + min_val = X[n, m] + + indexes.push_back(min_i) + + return indexes + + +#Returns vector of positions of maximum values along axis 0 of a 2d memoryview +cdef inline vector[int] c_get_argmax(double[:, :] X): + + cdef: + int N, M, max_i, n, m + double max_val + vector[int] indexes + + N = X.shape[0] + M = X.shape[1] + + indexes = vector[int]() + + for m in range(M): + + max_i = 0 + max_val = X[0, m] + + for n in range(N): + + if X[n, m] > max_val: + + max_i = n + max_val = X[n, m] + + indexes.push_back(max_i) + + return indexes + + +#Performs normalization of a 2d memoryview +cdef inline double[:, :] c_normalize_array(double[:, :] X, vector[int] extremes_max, vector[int] extremes_min): + + cdef: + int N = X.shape[0] + int M = X.shape[1] + int n, m, l, u + double l_val, u_val, diff_val + vector[double] min_vals, max_vals + + min_vals = vector[double]() + max_vals = vector[double]() + + m = 0 + for u in extremes_max: + u_val = X[u, m] + max_vals.push_back(u_val) + m = m + 1 + + m = 0 + for l in extremes_min: + l_val = X[l, m] + min_vals.push_back(l_val) + m = m + 1 + + for m in range(M): + + diff_val = max_vals[m] - min_vals[m] + if diff_val == 0.0: + diff_val = 1.0 + + for n in range(N): + + X[n, m] = (X[n, m] - min_vals[m]) / diff_val + + return X \ No newline at end of file diff --git a/tests/algorithms/test_rank_and_crowding.py b/tests/algorithms/test_rank_and_crowding.py new file mode 100644 index 000000000..b08adde36 --- /dev/null +++ b/tests/algorithms/test_rank_and_crowding.py @@ -0,0 +1,117 @@ +import pytest + +import numpy as np +from pymoo.optimize import minimize +from pymoo.problems import get_problem +from pymoo.indicators.igd import IGD +from pymoo.algorithms.moo.nsga2 import NSGA2 +from pymoo.operators.survival.rank_and_crowding import RankAndCrowding, ConstrRankAndCrowding +from pymoo.operators.survival.rank_and_crowding.metrics import calc_crowding_distance +from pymoo.util.function_loader import load_function +from pymoo.util.mnn import calc_mnn as calc_mnn_python +from pymoo.util.mnn import calc_2nn as calc_2nn_python + + +calc_mnn = load_function("calc_mnn") +calc_2nn = load_function("calc_2nn") +calc_pcd = load_function("calc_pcd") + + +@pytest.mark.parametrize('crowding_func', ["mnn", "2nn", "cd", "pcd", "ce"]) +@pytest.mark.parametrize('survival', [RankAndCrowding, ConstrRankAndCrowding]) +def test_multi_run(crowding_func, survival): + + problem = get_problem("truss2d") + + NGEN = 250 + POPSIZE = 100 + SEED = 5 + + nsga2 = NSGA2(pop_size=POPSIZE, survival=survival(crowding_func=crowding_func)) + + res = minimize(problem, + nsga2, + ('n_gen', NGEN), + seed=SEED, + save_history=False, + verbose=False) + + assert len(res.opt) > 0 + + +def test_cd_and_pcd(): + + problem = get_problem("truss2d") + + NGEN = 200 + POPSIZE = 100 + SEED = 5 + + nsga2 = NSGA2(pop_size=POPSIZE, survival=RankAndCrowding(crowding_func="pcd")) + + res = minimize(problem, + nsga2, + ('n_gen', NGEN), + seed=SEED, + save_history=False, + verbose=False) + + cd = calc_crowding_distance(res.F) + pcd = calc_pcd(res.F) + + assert np.sum(np.abs(cd[~np.isinf(cd)] - pcd[~np.isinf(pcd)])) <= 1e-8 + + new_F = res.F.copy() + + for j in range(10): + + cd = calc_crowding_distance(new_F) + k = np.argmin(cd) + new_F = new_F[np.arange(len(new_F)) != k] + + pcd = calc_pcd(res.F, n_remove=10) + ind = np.argpartition(pcd, 10)[:10] + + new_F_alt = res.F.copy()[np.setdiff1d(np.arange(len(res.F)), ind)] + + assert np.sum(np.abs(new_F - new_F_alt)) <= 1e-8 + + +def test_mnn(): + + problem = get_problem("dtlz2") + + NGEN = 200 + POPSIZE = 100 + SEED = 5 + + nsga2 = NSGA2(pop_size=POPSIZE, survival=RankAndCrowding(crowding_func="mnn")) + + res = minimize(problem, + nsga2, + ('n_gen', NGEN), + seed=SEED, + save_history=False, + verbose=False) + + surv_mnn = RankAndCrowding(crowding_func="mnn") + surv_2nn = RankAndCrowding(crowding_func="2nn") + + surv_mnn_py = RankAndCrowding(crowding_func=calc_mnn_python) + surv_2nn_py = RankAndCrowding(crowding_func=calc_2nn_python) + + np.random.seed(12) + pop_mnn = surv_mnn.do(problem, res.pop, n_survive=80) + + np.random.seed(12) + pop_mnn_py = surv_mnn_py.do(problem, res.pop, n_survive=80) + + assert np.sum(np.abs(pop_mnn.get("F") - pop_mnn_py.get("F"))) <= 1e-8 + + np.random.seed(12) + pop_2nn = surv_2nn.do(problem, res.pop, n_survive=70) + + np.random.seed(12) + pop_2nn_py = surv_2nn_py.do(problem, res.pop, n_survive=70) + + assert np.sum(np.abs(pop_2nn.get("F") - pop_2nn_py.get("F"))) <= 1e-8 \ No newline at end of file From d850b3d15e87b723bc617cbc9907187e73de9ca0 Mon Sep 17 00:00:00 2001 From: mooscalia project <93492480+mooscaliaproject@users.noreply.github.com> Date: Mon, 31 Oct 2022 21:40:17 -0300 Subject: [PATCH 06/32] Fix constrained rank and crowding --- pymoo/operators/survival/rank_and_crowding/classes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymoo/operators/survival/rank_and_crowding/classes.py b/pymoo/operators/survival/rank_and_crowding/classes.py index dcf3bc1fd..56493bc3f 100644 --- a/pymoo/operators/survival/rank_and_crowding/classes.py +++ b/pymoo/operators/survival/rank_and_crowding/classes.py @@ -155,7 +155,7 @@ def _do(self, if problem.n_constr > 0: #Split by feasibility - feas, infeas = split_by_feasibility(pop, eps=0.0, sort_infeasbible_by_cv=True) + feas, infeas = feas, infeas = split_by_feasibility(pop, sort_infeas_by_cv=True, sort_feas_by_obj=False, return_pop=False) #Obtain len of feasible n_feas = len(feas) From 97bf0ca21aa9045a972ebcb600cd0801cf9220e7 Mon Sep 17 00:00:00 2001 From: mooscalia project <93492480+mooscaliaproject@users.noreply.github.com> Date: Mon, 31 Oct 2022 21:46:30 -0300 Subject: [PATCH 07/32] Fix crowding dist w duplicates test --- tests/misc/test_crowding_distance.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/misc/test_crowding_distance.py b/tests/misc/test_crowding_distance.py index 750c67c98..e6434fbe3 100644 --- a/tests/misc/test_crowding_distance.py +++ b/tests/misc/test_crowding_distance.py @@ -3,31 +3,32 @@ import numpy as np import pytest -from pymoo.operators.survival.rank_and_crowding.metrics import calc_crowding_distance +from pymoo.operators.survival.rank_and_crowding.metrics import calc_crowding_distance, FunctionalDiversity from pymoo.config import get_pymoo +crowding_func = FunctionalDiversity(calc_crowding_distance, filter_out_duplicates=True) @pytest.mark.skip(reason="check if this is supposed to work or not at all") def test_crowding_distance(): D = np.loadtxt(os.path.join(get_pymoo(), "tests", "resources", "test_crowding.dat")) F, cd = D[:, :-1], D[:, -1] - assert np.all(np.abs(cd - calc_crowding_distance(F)) < 0.001) + assert np.all(np.abs(cd - crowding_func.do(F)) < 0.001) def test_crowding_distance_one_duplicate(): F = np.array([[1.0, 1.0], [1.0, 1.0], [0.5, 1.5], [0.0, 2.0]]) - cd = calc_crowding_distance(F) + cd = crowding_func.do(F) np.testing.assert_almost_equal(cd, np.array([np.inf, 0.0, 1.0, np.inf])) def test_crowding_distance_two_duplicates(): F = np.array([[1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [0.5, 1.5], [0.0, 2.0]]) - cd = calc_crowding_distance(F) + cd = crowding_func.do(F) np.testing.assert_almost_equal(cd, np.array([np.inf, 0.0, 0.0, 1.0, np.inf])) def test_crowding_distance_norm_equals_zero(): F = np.array([[1.0, 1.5, 0.5, 1.0], [1.0, 0.5, 1.5, 1.0], [1.0, 0.0, 2.0, 1.5]]) - cd = calc_crowding_distance(F) + cd = crowding_func.do(F) np.testing.assert_almost_equal(cd, np.array([np.inf, 0.75, np.inf])) From 2cd022cef0151d6481bf818204b19070df41b0c4 Mon Sep 17 00:00:00 2001 From: mooscalia project <93492480+mooscaliaproject@users.noreply.github.com> Date: Mon, 31 Oct 2022 23:10:11 -0300 Subject: [PATCH 08/32] Fix functional diversity n_points n_obj --- pymoo/operators/survival/rank_and_crowding/metrics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymoo/operators/survival/rank_and_crowding/metrics.py b/pymoo/operators/survival/rank_and_crowding/metrics.py index 751f4fe73..f40924ce7 100644 --- a/pymoo/operators/survival/rank_and_crowding/metrics.py +++ b/pymoo/operators/survival/rank_and_crowding/metrics.py @@ -46,7 +46,7 @@ def _do(self, F, **kwargs): n_points, n_obj = F.shape - if n_points <= F.shape[1]: + if n_points <= 2: return np.full(n_points, np.inf) else: From 5ceda2a0815560f90a32283492983e7150c5c42c Mon Sep 17 00:00:00 2001 From: mooscalia project <93492480+mooscaliaproject@users.noreply.github.com> Date: Mon, 31 Oct 2022 23:44:02 -0300 Subject: [PATCH 09/32] Fix mnn N versus M --- pymoo/cython/mnn.pyx | 3 +++ pymoo/operators/survival/rank_and_crowding/metrics.py | 2 +- pymoo/util/mnn.py | 3 +++ 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/pymoo/cython/mnn.pyx b/pymoo/cython/mnn.pyx index a3a8552fd..854400a95 100644 --- a/pymoo/cython/mnn.pyx +++ b/pymoo/cython/mnn.pyx @@ -31,6 +31,9 @@ def calc_mnn(double[:, :] X, int n_remove=0): N = X.shape[0] M = X.shape[1] + if N <= M: + return np.full(N, HUGE_VAL) + if n_remove <= (N - M): if n_remove < 0: n_remove = 0 diff --git a/pymoo/operators/survival/rank_and_crowding/metrics.py b/pymoo/operators/survival/rank_and_crowding/metrics.py index f40924ce7..f1ebeb8bc 100644 --- a/pymoo/operators/survival/rank_and_crowding/metrics.py +++ b/pymoo/operators/survival/rank_and_crowding/metrics.py @@ -19,7 +19,7 @@ def get_crowding_function(label): elif hasattr(label, "__call__"): fun = FunctionalDiversity(label, filter_out_duplicates=True) else: - raise KeyError("Crwoding function not defined") + raise KeyError("Crowding function not defined") return fun diff --git a/pymoo/util/mnn.py b/pymoo/util/mnn.py index 53a641ace..39aa63271 100644 --- a/pymoo/util/mnn.py +++ b/pymoo/util/mnn.py @@ -11,6 +11,9 @@ def calc_mnn_base(X, n_remove=0, twonn=False): N = X.shape[0] M = X.shape[1] + + if N <= M: + return np.full(N, np.inf) if n_remove <= (N - M): if n_remove < 0: From f64dbd5133d625def17b970174307679013cbccd Mon Sep 17 00:00:00 2001 From: mooscalia project <93492480+mooscaliaproject@users.noreply.github.com> Date: Mon, 31 Oct 2022 23:58:33 -0300 Subject: [PATCH 10/32] Include CowdingDiversity as a valid kwarg --- .../survival/rank_and_crowding/metrics.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/pymoo/operators/survival/rank_and_crowding/metrics.py b/pymoo/operators/survival/rank_and_crowding/metrics.py index f1ebeb8bc..f6fb4b846 100644 --- a/pymoo/operators/survival/rank_and_crowding/metrics.py +++ b/pymoo/operators/survival/rank_and_crowding/metrics.py @@ -13,11 +13,13 @@ def get_crowding_function(label): elif label == "ce": fun = FunctionalDiversity(calc_crowding_entropy, filter_out_duplicates=True) elif label == "mnn": - fun = FunctionalDiversity(load_function("calc_mnn"), filter_out_duplicates=True) + fun = FuncionalDiversityMNN(load_function("calc_mnn"), filter_out_duplicates=True) elif label == "2nn": - fun = FunctionalDiversity(load_function("calc_2nn"), filter_out_duplicates=True) + fun = FuncionalDiversityMNN(load_function("calc_2nn"), filter_out_duplicates=True) elif hasattr(label, "__call__"): fun = FunctionalDiversity(label, filter_out_duplicates=True) + elif isinstance(label, CrowdingDiversity): + fun = label else: raise KeyError("Crowding function not defined") return fun @@ -69,6 +71,19 @@ def _do(self, F, **kwargs): return d +class FuncionalDiversityMNN(FunctionalDiversity): + + def _do(self, F, **kwargs): + + n_points, n_obj = F.shape + + if n_points <= n_obj: + return np.full(n_points, np.inf) + + else: + return super()._do(F, **kwargs) + + def calc_crowding_distance(F, **kwargs): n_points, n_obj = F.shape From 69b92a903a15afdfeeca25cdce373c590910ef15 Mon Sep 17 00:00:00 2001 From: mooscalia project <93492480+mooscaliaproject@users.noreply.github.com> Date: Tue, 1 Nov 2022 00:32:21 -0300 Subject: [PATCH 11/32] Fix exp crossover for new DE --- pymoo/operators/crossover/expx.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pymoo/operators/crossover/expx.py b/pymoo/operators/crossover/expx.py index 18c47e11d..dd0b96c73 100644 --- a/pymoo/operators/crossover/expx.py +++ b/pymoo/operators/crossover/expx.py @@ -6,7 +6,11 @@ def mut_exp(n_matings, n_var, prob, at_least_once=True): - assert len(prob) == n_matings + + if isinstance(np.float64(prob), float) or isinstance(np.float64(prob), int): + prob = np.ones(n_matings) * prob + else: + assert len(prob) == n_matings # the mask do to the crossover M = np.full((n_matings, n_var), False) From a0616bcd4ee004ff3414c8918be75f0dcf8ae737 Mon Sep 17 00:00:00 2001 From: mooscalia project <93492480+mooscaliaproject@users.noreply.github.com> Date: Tue, 1 Nov 2022 00:32:39 -0300 Subject: [PATCH 12/32] Include tests for soo DE --- tests/algorithms/test_de.py | 74 ++++++++++++++++++++++++++-------- tests/algorithms/test_de_ep.py | 25 ++++++++++++ 2 files changed, 83 insertions(+), 16 deletions(-) create mode 100644 tests/algorithms/test_de_ep.py diff --git a/tests/algorithms/test_de.py b/tests/algorithms/test_de.py index 426120e18..1f5ddd9ff 100644 --- a/tests/algorithms/test_de.py +++ b/tests/algorithms/test_de.py @@ -1,25 +1,67 @@ import pytest -from pymoo.algorithms.soo.nonconvex.de import DE -from pymoo.problems import get_problem -from pymoo.operators.sampling.lhs import LHS from pymoo.optimize import minimize +from pymoo.problems import get_problem +from pymoo.algorithms.soo.nonconvex.de import DE +from pymoo.algorithms.moo.gde3 import GDE3 -@pytest.mark.parametrize('selection', ["rand", "best", "target-to-best"]) +@pytest.mark.parametrize('selection', ["rand", "best", "current-to-best", "current-to-rand", "ranked"]) @pytest.mark.parametrize('crossover', ["bin", "exp"]) -def test_de(selection, crossover): - problem = get_problem("ackley", n_var=10) +@pytest.mark.parametrize('repair', ["bounce-back", "midway", "rand-init", "to-bounds"]) +def test_de_run(selection, crossover, repair): + problem = get_problem("rastrigin") + + NGEN = 30 + POPSIZE = 20 + SEED = 3 + + #DE Parameters + CR = 0.5 + F = (0.3, 1.0) + + de = DE(pop_size=POPSIZE, variant=f"DE/{selection}/1/{crossover}", CR=CR, F=F, repair=repair) + + res_de = minimize(problem, + de, + ('n_gen', NGEN), + seed=SEED, + save_history=False, + verbose=False) + + assert len(res_de.opt) > 0 + + +def test_de_perf(): + problem = get_problem("rastrigin") + + NGEN = 100 + POPSIZE = 20 + SEED = 3 + + #DE Parameters + CR = 0.5 + F = (0.3, 1.0) - algorithm = DE( - pop_size=100, - sampling=LHS(), - variant=f"DE/{selection}/1/{crossover}") + de = DE(pop_size=POPSIZE, variant="DE/rand/1/bin", CR=CR, F=F) - ret = minimize(problem, - algorithm, - ('n_gen', 20), - seed=1, - verbose=True) + res_de = minimize(problem, + de, + ('n_gen', NGEN), + seed=SEED, + save_history=False, + verbose=False) - assert len(ret.opt) > 0 + assert len(res_de.opt) > 0 + assert res_de.F <= 1e-6 + + gde3 = GDE3(pop_size=POPSIZE, variant="DE/rand/1/bin", CR=CR, F=F) + + res_gde3 = minimize(problem, + gde3, + ('n_gen', NGEN), + seed=SEED, + save_history=False, + verbose=False) + + assert res_gde3.F <= 1e-6 \ No newline at end of file diff --git a/tests/algorithms/test_de_ep.py b/tests/algorithms/test_de_ep.py new file mode 100644 index 000000000..3ebbf8c38 --- /dev/null +++ b/tests/algorithms/test_de_ep.py @@ -0,0 +1,25 @@ +import pytest + +from pymoo.algorithms.soo.nonconvex.de_ep import EPDE +from pymoo.problems import get_problem +from pymoo.operators.sampling.lhs import LHS +from pymoo.optimize import minimize + + +@pytest.mark.parametrize('selection', ["rand", "best", "target-to-best"]) +@pytest.mark.parametrize('crossover', ["bin", "exp"]) +def test_de(selection, crossover): + problem = get_problem("ackley", n_var=10) + + algorithm = EPDE( + pop_size=100, + sampling=LHS(), + variant=f"DE/{selection}/1/{crossover}") + + ret = minimize(problem, + algorithm, + ('n_gen', 20), + seed=1, + verbose=True) + + assert len(ret.opt) > 0 From cae48e209a835244daec7cbb3ef717faaabc9013 Mon Sep 17 00:00:00 2001 From: mooscalia project <93492480+mooscaliaproject@users.noreply.github.com> Date: Tue, 1 Nov 2022 00:44:18 -0300 Subject: [PATCH 13/32] Fix test DE repair --- tests/algorithms/test_de.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/algorithms/test_de.py b/tests/algorithms/test_de.py index 1f5ddd9ff..fc533729a 100644 --- a/tests/algorithms/test_de.py +++ b/tests/algorithms/test_de.py @@ -8,8 +8,8 @@ @pytest.mark.parametrize('selection', ["rand", "best", "current-to-best", "current-to-rand", "ranked"]) @pytest.mark.parametrize('crossover', ["bin", "exp"]) -@pytest.mark.parametrize('repair', ["bounce-back", "midway", "rand-init", "to-bounds"]) -def test_de_run(selection, crossover, repair): +@pytest.mark.parametrize('de_repair', ["bounce-back", "midway", "rand-init", "to-bounds"]) +def test_de_run(selection, crossover, de_repair): problem = get_problem("rastrigin") NGEN = 30 @@ -20,7 +20,7 @@ def test_de_run(selection, crossover, repair): CR = 0.5 F = (0.3, 1.0) - de = DE(pop_size=POPSIZE, variant=f"DE/{selection}/1/{crossover}", CR=CR, F=F, repair=repair) + de = DE(pop_size=POPSIZE, variant=f"DE/{selection}/1/{crossover}", CR=CR, F=F, de_repair=de_repair) res_de = minimize(problem, de, From cebc0b6eafc18677144d8d838269f9b0a3fea01f Mon Sep 17 00:00:00 2001 From: mooscalia project <93492480+mooscaliaproject@users.noreply.github.com> Date: Tue, 1 Nov 2022 00:44:29 -0300 Subject: [PATCH 14/32] Include tests MODE --- tests/algorithms/test_mode.py | 143 ++++++++++++++++++++++++++++++++++ 1 file changed, 143 insertions(+) create mode 100644 tests/algorithms/test_mode.py diff --git a/tests/algorithms/test_mode.py b/tests/algorithms/test_mode.py new file mode 100644 index 000000000..4573fe65d --- /dev/null +++ b/tests/algorithms/test_mode.py @@ -0,0 +1,143 @@ +import pytest + +import numpy as np + +from pymoo.optimize import minimize +from pymoo.problems import get_problem +from pymoo.indicators.igd import IGD +from pymoo.algorithms.moo.nsder import NSDER +from pymoo.algorithms.moo.nsde import NSDE +from pymoo.algorithms.moo.gde3 import GDE3 +from pymoo.operators.survival.rank_and_crowding import RankAndCrowding, ConstrRankAndCrowding +from pymoo.util.ref_dirs import get_reference_directions + + +@pytest.mark.parametrize('survival', [RankAndCrowding, ConstrRankAndCrowding]) +@pytest.mark.parametrize('crowding_func', ["mnn", "2nn", "cd", "pcd", "ce"]) +def test_multi_run(survival, crowding_func): + + problem = get_problem("truss2d") + + NGEN = 250 + POPSIZE = 100 + SEED = 5 + + gde3 = GDE3(pop_size=POPSIZE, variant="DE/rand/1/bin", CR=0.5, F=(0.0, 0.9), de_repair="bounce-back", + survival=survival(crowding_func=crowding_func)) + + res_gde3 = minimize(problem, + gde3, + ('n_gen', NGEN), + seed=SEED, + save_history=False, + verbose=False) + + assert len(res_gde3.opt) > 0 + + +def test_multi_perf(): + + problem = get_problem("truss2d") + igd = IGD(pf=problem.pareto_front(), zero_to_one=True) + + NGEN = 250 + POPSIZE = 100 + SEED = 5 + + gde3 = GDE3(pop_size=POPSIZE, variant="DE/rand/1/bin", CR=0.5, F=(0.0, 0.9), de_repair="bounce-back", + survival=RankAndCrowding(crowding_func="cd")) + + res_gde3 = minimize(problem, + gde3, + ('n_gen', NGEN), + seed=SEED, + save_history=False, + verbose=False) + + igd_gde3 = igd.do(res_gde3.F) + assert igd_gde3 <= 0.08 + + gde3p = GDE3(pop_size=POPSIZE, variant="DE/rand/1/bin", CR=0.5, F=(0.0, 0.9), de_repair="bounce-back", + survival=RankAndCrowding(crowding_func="pcd")) + + res_gde3p = minimize(problem, + gde3p, + ('n_gen', NGEN), + seed=SEED, + save_history=False, + verbose=False) + + igd_gde3p = igd.do(res_gde3p.F) + assert igd_gde3p <= 0.01 + + nsde = NSDE(pop_size=POPSIZE, variant="DE/rand/1/bin", CR=0.5, F=(0.0, 0.9), de_repair="bounce-back", + survival=RankAndCrowding(crowding_func="pcd")) + + res_nsde = minimize(problem, + nsde, + ('n_gen', NGEN), + seed=SEED, + save_history=False, + verbose=False) + + igd_nsde = igd.do(res_nsde.F) + assert igd_nsde <= 0.01 + +@pytest.mark.parametrize('selection', ["rand", "current-to-rand", "ranked"]) +@pytest.mark.parametrize('crossover', ["bin", "exp"]) +@pytest.mark.parametrize('crowding_func', ["mnn", "2nn"]) +def test_many_run(selection, crossover, crowding_func): + + problem = get_problem("dtlz2") + + NGEN = 50 + POPSIZE = 136 + SEED = 5 + + gde3 = GDE3(pop_size=POPSIZE, variant=f"DE/{selection}/1/{crossover}", CR=0.2, F=(0.0, 1.0), gamma=1e-4, + survival=RankAndCrowding(crowding_func=crowding_func)) + + res_gde3 = minimize(problem, + gde3, + ('n_gen', NGEN), + seed=SEED, + save_history=False, + verbose=False) + + assert len(res_gde3.opt) > 0 + + +def test_many_perf(): + + problem = get_problem("dtlz2") + ref_dirs = get_reference_directions("das-dennis", 3, n_partitions=15) + igd = IGD(pf=problem.pareto_front(), zero_to_one=True) + + NGEN = 150 + POPSIZE = 136 + SEED = 5 + + gde3 = GDE3(pop_size=POPSIZE, variant="DE/rand/1/bin", CR=0.2, F=(0.0, 1.0), gamma=1e-4, + survival=RankAndCrowding(crowding_func="mnn")) + + res_gde3 = minimize(problem, + gde3, + ('n_gen', NGEN), + seed=SEED, + save_history=False, + verbose=False) + + igd_gde3 = igd.do(res_gde3.F) + assert igd_gde3 <= 0.07 + + nsder = NSDER(ref_dirs=ref_dirs, pop_size=POPSIZE, variant="DE/rand/1/bin", CR=0.5, F=(0.0, 1.0), gamma=1e-4) + + res_nsder = minimize(problem, + nsder, + ('n_gen', NGEN), + seed=SEED, + save_history=False, + verbose=False) + + igd_nsder = igd.do(res_nsder.F) + assert igd_nsder <= 0.01 \ No newline at end of file From 12737a0a4ede23c36adf3c524992aba8d1df97ce Mon Sep 17 00:00:00 2001 From: mooscalia project <93492480+mooscaliaproject@users.noreply.github.com> Date: Tue, 1 Nov 2022 01:14:57 -0300 Subject: [PATCH 15/32] Refactor DEX --- pymoo/operators/crossover/dex.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/pymoo/operators/crossover/dex.py b/pymoo/operators/crossover/dex.py index 3e2d5135f..aba6cbaf7 100644 --- a/pymoo/operators/crossover/dex.py +++ b/pymoo/operators/crossover/dex.py @@ -43,13 +43,12 @@ def __init__(self, self.F = F self.gamma = gamma self.de_repair = de_repair + + def __call__(self, problem, pop, parents, **kwargs): + return self.do(problem, pop, parents, **kwargs) - def do(self, problem, pop, parents, **kwargs): + def do(self, problem, Xr, **kwargs): - # Get all X values for mutation parents - Xr = pop.get("X")[parents.T].copy() - assert len(Xr.shape) == 3, "Please provide a three-dimensional matrix n_parents x pop_size x n_vars." - # Create mutation vectors V, diffs = self.de_mutation(Xr, return_differentials=True) @@ -139,17 +138,19 @@ def __init__(self, self.at_least_once = at_least_once super().__init__(2 + 2 * n_diffs, 1, prob=1.0, **kwargs) + - def do(self, problem, pop, parents, **kwargs): + def _do(self, problem, X, **kwargs): # Get target vectors - X = pop.get("X")[parents[:, 0]] + Xp = X[0] + Xr = X[1:] # About Xi n_matings, n_var = X.shape # Obtain mutants - mutants = self.dem.do(problem, pop, parents[:, 1:], **kwargs) + mutants = self.dem.do(problem, Xr, **kwargs) # Obtain V V = mutants.get("X") From cf99048b8bce10cfbd9b93ecde6a8014ee91aefb Mon Sep 17 00:00:00 2001 From: mooscalia project <93492480+mooscaliaproject@users.noreply.github.com> Date: Tue, 1 Nov 2022 01:19:34 -0300 Subject: [PATCH 16/32] Fix new DEX --- pymoo/operators/crossover/dex.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pymoo/operators/crossover/dex.py b/pymoo/operators/crossover/dex.py index aba6cbaf7..2615d936c 100644 --- a/pymoo/operators/crossover/dex.py +++ b/pymoo/operators/crossover/dex.py @@ -147,7 +147,7 @@ def _do(self, problem, X, **kwargs): Xr = X[1:] # About Xi - n_matings, n_var = X.shape + n_matings, n_var = Xp.shape # Obtain mutants mutants = self.dem.do(problem, Xr, **kwargs) @@ -165,9 +165,9 @@ def _do(self, problem, X, **kwargs): raise Exception(f"Unknown variant: {self.variant}") # Add mutated elements in corresponding main parent - X[M] = V[M] + Xp[M] = V[M] - off = Population.new("X", X) + off = Population.new("X", Xp) return off From e5d27338a5ef77fd02f7ce84acd8194410c8f774 Mon Sep 17 00:00:00 2001 From: mooscalia project <93492480+mooscaliaproject@users.noreply.github.com> Date: Tue, 1 Nov 2022 01:35:47 -0300 Subject: [PATCH 17/32] Revert DEX arguments --- pymoo/operators/crossover/dex.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/pymoo/operators/crossover/dex.py b/pymoo/operators/crossover/dex.py index 2615d936c..8fc802473 100644 --- a/pymoo/operators/crossover/dex.py +++ b/pymoo/operators/crossover/dex.py @@ -47,9 +47,13 @@ def __init__(self, def __call__(self, problem, pop, parents, **kwargs): return self.do(problem, pop, parents, **kwargs) - def do(self, problem, Xr, **kwargs): + def do(self, problem, pop, parents, **kwargs): - # Create mutation vectors + #Get all X values for mutation parents + Xr = pop.get("X")[parents.T].copy() + assert len(Xr.shape) == 3, "Please provide a three-dimensional matrix n_parents x pop_size x n_vars." + + #Create mutation vectors V, diffs = self.de_mutation(Xr, return_differentials=True) # If the problem has boundaries to be considered @@ -140,17 +144,16 @@ def __init__(self, super().__init__(2 + 2 * n_diffs, 1, prob=1.0, **kwargs) - def _do(self, problem, X, **kwargs): + def do(self, problem, pop, parents, **kwargs): - # Get target vectors - Xp = X[0] - Xr = X[1:] + #Get target vectors + X = pop.get("X")[parents[:, 0]] - # About Xi - n_matings, n_var = Xp.shape + #About Xi + n_matings, n_var = X.shape - # Obtain mutants - mutants = self.dem.do(problem, Xr, **kwargs) + #Obtain mutants + mutants = self.dem.do(problem, pop, parents[:, 1:], **kwargs) # Obtain V V = mutants.get("X") From c57f47849764565bed233ba2900f7f4e947886dd Mon Sep 17 00:00:00 2001 From: mooscalia project <93492480+mooscaliaproject@users.noreply.github.com> Date: Tue, 1 Nov 2022 01:36:01 -0300 Subject: [PATCH 18/32] Remove DEX from test_crossover --- tests/operators/test_crossover.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/operators/test_crossover.py b/tests/operators/test_crossover.py index ff022c079..f123fb65d 100644 --- a/tests/operators/test_crossover.py +++ b/tests/operators/test_crossover.py @@ -5,8 +5,6 @@ from pymoo.operators.crossover.ox import OrderCrossover from pymoo.operators.crossover.spx import SPX -from pymoo.operators.crossover.dex import DEX - from pymoo.algorithms.moo.nsga2 import NSGA2 from pymoo.algorithms.soo.nonconvex.ga import GA from pymoo.operators.crossover.pntx import TwoPointCrossover @@ -19,7 +17,7 @@ from pymoo.problems.single.traveling_salesman import create_random_tsp_problem -@pytest.mark.parametrize('crossover', [DEX(), SBX()]) +@pytest.mark.parametrize('crossover', [SBX()]) def test_crossover_real(crossover): method = GA(pop_size=20, crossover=crossover) minimize(get_problem("sphere"), method, ("n_gen", 20)) From 25eea6dbaabe6ebd09931740f9610c8398fd9f93 Mon Sep 17 00:00:00 2001 From: mooscalia project <93492480+mooscaliaproject@users.noreply.github.com> Date: Tue, 1 Nov 2022 01:39:50 -0300 Subject: [PATCH 19/32] Fix DEX do --- pymoo/operators/crossover/dex.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymoo/operators/crossover/dex.py b/pymoo/operators/crossover/dex.py index 8fc802473..ef13d8202 100644 --- a/pymoo/operators/crossover/dex.py +++ b/pymoo/operators/crossover/dex.py @@ -168,9 +168,9 @@ def do(self, problem, pop, parents, **kwargs): raise Exception(f"Unknown variant: {self.variant}") # Add mutated elements in corresponding main parent - Xp[M] = V[M] + X[M] = V[M] - off = Population.new("X", Xp) + off = Population.new("X", X) return off From c8b1431cc8943a2af95c04daf38df2bd93dea072 Mon Sep 17 00:00:00 2001 From: mooscalia project <93492480+mooscaliaproject@users.noreply.github.com> Date: Tue, 1 Nov 2022 18:23:43 -0300 Subject: [PATCH 20/32] Refactor DEX and DEM to match SBX format --- pymoo/algorithms/soo/nonconvex/de.py | 11 ++++----- pymoo/operators/crossover/dex.py | 37 +++++++++++++++++----------- tests/operators/test_crossover.py | 3 ++- 3 files changed, 30 insertions(+), 21 deletions(-) diff --git a/pymoo/algorithms/soo/nonconvex/de.py b/pymoo/algorithms/soo/nonconvex/de.py index a01c39fc3..d61801183 100755 --- a/pymoo/algorithms/soo/nonconvex/de.py +++ b/pymoo/algorithms/soo/nonconvex/de.py @@ -41,7 +41,7 @@ def __init__(self, if F is None: F = (0.0, 1.0) - # Define crossover strategy + # Define crossover strategy (DE mutation is included) self.crossover = DEX(variant=crossover_variant, CR=CR, F=F, @@ -57,15 +57,14 @@ def __init__(self, def do(self, problem, pop, n_offsprings, **kwargs): # Select parents including donor vector - parents = self.selection.do(problem, pop, n_offsprings, self.crossover.n_parents, - to_pop=False, **kwargs) + parents = self.selection(problem, pop, n_offsprings, self.crossover.n_parents, to_pop=True, **kwargs) # Perform mutation included in DEX and crossover - off = self.crossover.do(problem, pop, parents, **kwargs) + off = self.crossover(problem, parents, **kwargs) # Perform posterior mutation and repair if passed - off = self.mutation.do(problem, off) - off = self.repair.do(problem, off) + off = self.mutation(problem, off) + off = self.repair(problem, off) return off diff --git a/pymoo/operators/crossover/dex.py b/pymoo/operators/crossover/dex.py index ef13d8202..67a487965 100644 --- a/pymoo/operators/crossover/dex.py +++ b/pymoo/operators/crossover/dex.py @@ -9,12 +9,13 @@ # Implementation # ========================================================================================================= -class DEM: +class DEM(Crossover): def __init__(self, F=None, gamma=1e-4, de_repair="bounce-back", + n_diffs=1, **kwargs): # Default value for F @@ -43,17 +44,20 @@ def __init__(self, self.F = F self.gamma = gamma self.de_repair = de_repair + + super().__init__(1 + 2 * n_diffs, 1, prob=1.0, **kwargs) - def __call__(self, problem, pop, parents, **kwargs): - return self.do(problem, pop, parents, **kwargs) - def do(self, problem, pop, parents, **kwargs): - - #Get all X values for mutation parents - Xr = pop.get("X")[parents.T].copy() - assert len(Xr.shape) == 3, "Please provide a three-dimensional matrix n_parents x pop_size x n_vars." + def do(self, problem, pop, parents=None, **kwargs): - #Create mutation vectors + # Convert pop if parents is not None + if not parents is None: + pop = pop[parents] + + # Get all X values for mutation parents + Xr = np.swapaxes(pop, 0, 1).get("X") + + # Create mutation vectors V, diffs = self.de_mutation(Xr, return_differentials=True) # If the problem has boundaries to be considered @@ -135,7 +139,8 @@ def __init__(self, # Create instace for mutation self.dem = DEM(F=F, gamma=gamma, - de_repair=de_repair) + de_repair=de_repair, + n_diffs=n_diffs) self.CR = CR self.variant = variant @@ -144,16 +149,20 @@ def __init__(self, super().__init__(2 + 2 * n_diffs, 1, prob=1.0, **kwargs) - def do(self, problem, pop, parents, **kwargs): + def do(self, problem, pop, parents=None, **kwargs): + + # Convert pop if parents is not None + if not parents is None: + pop = pop[parents] - #Get target vectors - X = pop.get("X")[parents[:, 0]] + # Get all X values for mutation parents + X = pop[:, 0].get("X") #About Xi n_matings, n_var = X.shape #Obtain mutants - mutants = self.dem.do(problem, pop, parents[:, 1:], **kwargs) + mutants = self.dem.do(problem, pop[:, 1:], **kwargs) # Obtain V V = mutants.get("X") diff --git a/tests/operators/test_crossover.py b/tests/operators/test_crossover.py index f123fb65d..1ec45e536 100644 --- a/tests/operators/test_crossover.py +++ b/tests/operators/test_crossover.py @@ -9,6 +9,7 @@ from pymoo.algorithms.soo.nonconvex.ga import GA from pymoo.operators.crossover.pntx import TwoPointCrossover from pymoo.operators.crossover.sbx import SBX +from pymoo.operators.crossover.dex import DEM, DEX from pymoo.operators.crossover.ux import UX from pymoo.operators.mutation.inversion import InversionMutation from pymoo.operators.sampling.rnd import PermutationRandomSampling @@ -17,7 +18,7 @@ from pymoo.problems.single.traveling_salesman import create_random_tsp_problem -@pytest.mark.parametrize('crossover', [SBX()]) +@pytest.mark.parametrize('crossover', [DEX(), DEM(), SBX()]) def test_crossover_real(crossover): method = GA(pop_size=20, crossover=crossover) minimize(get_problem("sphere"), method, ("n_gen", 20)) From fbe12daf1a451dbab8c25f4b1330410e5572b30e Mon Sep 17 00:00:00 2001 From: mooscalia project <93492480+mooscaliaproject@users.noreply.github.com> Date: Tue, 1 Nov 2022 19:08:32 -0300 Subject: [PATCH 21/32] Fix DE de_repair on mating --- pymoo/algorithms/soo/nonconvex/de.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymoo/algorithms/soo/nonconvex/de.py b/pymoo/algorithms/soo/nonconvex/de.py index d61801183..5101f24d0 100755 --- a/pymoo/algorithms/soo/nonconvex/de.py +++ b/pymoo/algorithms/soo/nonconvex/de.py @@ -48,7 +48,7 @@ def __init__(self, gamma=gamma, n_diffs=n_diffs, at_least_once=True, - repair=repair) + de_repair=de_repair) # Define posterior mutation strategy and repair self.mutation = mutation if mutation is not None else NoMutation() From edbaec9c0f8b2797708c27ea47d3a771277824ed Mon Sep 17 00:00:00 2001 From: mooscalia project <93492480+mooscaliaproject@users.noreply.github.com> Date: Wed, 2 Nov 2022 17:06:54 -0300 Subject: [PATCH 22/32] Refactor infill of DE new class with inheritance --- pymoo/algorithms/soo/nonconvex/de.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/pymoo/algorithms/soo/nonconvex/de.py b/pymoo/algorithms/soo/nonconvex/de.py index 5101f24d0..57d8b13c7 100755 --- a/pymoo/algorithms/soo/nonconvex/de.py +++ b/pymoo/algorithms/soo/nonconvex/de.py @@ -3,19 +3,19 @@ from pymoo.algorithms.soo.nonconvex.ga import FitnessSurvival from pymoo.core.replacement import ImprovementReplacement from pymoo.operators.mutation.nom import NoMutation -from pymoo.core.repair import NoRepair from pymoo.operators.sampling.lhs import LHS from pymoo.termination.default import DefaultSingleObjectiveTermination from pymoo.util.display.single import SingleObjectiveOutput from pymoo.operators.selection.des import DES from pymoo.operators.crossover.dex import DEX +from pymoo.core.infill import InfillCriterion # ========================================================================================================= # Implementation # ========================================================================================================= -class InfillDE: +class VariantDE(InfillCriterion): def __init__(self, variant="DE/rand/1/bin", @@ -23,8 +23,7 @@ def __init__(self, F=(0.5, 1.0), gamma=1e-4, de_repair="bounce-back", - mutation=None, - repair=None): + mutation=None): # Parse the information from the string _, selection_variant, n_diff, crossover_variant, = variant.split("/") @@ -37,7 +36,7 @@ def __init__(self, # Define parent selection operator self.selection = DES(selection_variant) - #Default value for F + # Default value for F if F is None: F = (0.0, 1.0) @@ -52,9 +51,9 @@ def __init__(self, # Define posterior mutation strategy and repair self.mutation = mutation if mutation is not None else NoMutation() - self.repair = repair if repair is not None else NoRepair() + - def do(self, problem, pop, n_offsprings, **kwargs): + def _do(self, problem, pop, n_offsprings, **kwargs): # Select parents including donor vector parents = self.selection(problem, pop, n_offsprings, self.crossover.n_parents, to_pop=True, **kwargs) @@ -64,7 +63,6 @@ def do(self, problem, pop, n_offsprings, **kwargs): # Perform posterior mutation and repair if passed off = self.mutation(problem, off) - off = self.repair(problem, off) return off From 8fdd853792da0ea5e55e68b7cc75def28aae3448 Mon Sep 17 00:00:00 2001 From: mooscalia project <93492480+mooscaliaproject@users.noreply.github.com> Date: Wed, 2 Nov 2022 17:09:34 -0300 Subject: [PATCH 23/32] Fix DE algorithms with new mating --- pymoo/algorithms/moo/nsde.py | 16 ++++++++-------- pymoo/algorithms/soo/nonconvex/de.py | 14 +++++++------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/pymoo/algorithms/moo/nsde.py b/pymoo/algorithms/moo/nsde.py index e730350e6..69d581d39 100644 --- a/pymoo/algorithms/moo/nsde.py +++ b/pymoo/algorithms/moo/nsde.py @@ -1,6 +1,6 @@ from pymoo.algorithms.moo.nsga2 import NSGA2 from pymoo.operators.sampling.lhs import LHS -from pymoo.algorithms.soo.nonconvex.de import InfillDE +from pymoo.algorithms.soo.nonconvex.de import VariantDE from pymoo.operators.survival.rank_and_crowding import RankAndCrowding @@ -90,13 +90,13 @@ def __init__(self, n_offsprings = pop_size #Mating - mating = InfillDE(variant=variant, - CR=CR, - F=F, - gamma=gamma, - de_repair=de_repair, - mutation=mutation, - repair=repair) + mating = VariantDE(variant=variant, + CR=CR, + F=F, + gamma=gamma, + de_repair=de_repair, + mutation=mutation, + repair=repair) #Init from pymoo's NSGA2 super().__init__(pop_size=pop_size, diff --git a/pymoo/algorithms/soo/nonconvex/de.py b/pymoo/algorithms/soo/nonconvex/de.py index 57d8b13c7..52766c920 100755 --- a/pymoo/algorithms/soo/nonconvex/de.py +++ b/pymoo/algorithms/soo/nonconvex/de.py @@ -136,13 +136,13 @@ def __init__(self, Pymoo's repair operator after mutation. Defaults to NoRepair(). """ - mating = InfillDE(variant=variant, - CR=CR, - F=F, - gamma=gamma, - de_repair=de_repair, - mutation=mutation, - repair=repair) + mating = VariantDE(variant=variant, + CR=CR, + F=F, + gamma=gamma, + de_repair=de_repair, + mutation=mutation, + repair=repair) # Number of offsprings at each generation n_offsprings = pop_size From 4646879f3af0ea79f24b355a5e5a33e5f3f79b57 Mon Sep 17 00:00:00 2001 From: mooscalia project <93492480+mooscaliaproject@users.noreply.github.com> Date: Wed, 2 Nov 2022 17:32:31 -0300 Subject: [PATCH 24/32] Fix kwargs passed to DE variant --- pymoo/algorithms/soo/nonconvex/de.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/pymoo/algorithms/soo/nonconvex/de.py b/pymoo/algorithms/soo/nonconvex/de.py index 52766c920..73be395dd 100755 --- a/pymoo/algorithms/soo/nonconvex/de.py +++ b/pymoo/algorithms/soo/nonconvex/de.py @@ -23,7 +23,11 @@ def __init__(self, F=(0.5, 1.0), gamma=1e-4, de_repair="bounce-back", - mutation=None): + mutation=None, + **kwargs): + + # Default initialization of InfillCriterion + super().__init__(eliminate_duplicates=False, **kwargs) # Parse the information from the string _, selection_variant, n_diff, crossover_variant, = variant.split("/") @@ -77,8 +81,6 @@ def __init__(self, F=(0.5, 1.0), gamma=1e-4, de_repair="bounce-back", - mutation=None, - repair=None, output=SingleObjectiveOutput(), **kwargs): """ @@ -141,8 +143,7 @@ def __init__(self, F=F, gamma=gamma, de_repair=de_repair, - mutation=mutation, - repair=repair) + **kwargs) # Number of offsprings at each generation n_offsprings = pop_size @@ -162,7 +163,7 @@ def _initialize_advance(self, infills=None, **kwargs): def _infill(self): - infills = self.mating.do(self.problem, self.pop, self.n_offsprings, algorithm=self) + infills = self.mating(self.problem, self.pop, self.n_offsprings, algorithm=self) return infills From 884cb02c1b15caa43051fe69b8582bfe61ec73b3 Mon Sep 17 00:00:00 2001 From: mooscalia project <93492480+mooscaliaproject@users.noreply.github.com> Date: Wed, 2 Nov 2022 17:35:54 -0300 Subject: [PATCH 25/32] Style comments --- pymoo/algorithms/soo/nonconvex/de.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pymoo/algorithms/soo/nonconvex/de.py b/pymoo/algorithms/soo/nonconvex/de.py index 73be395dd..375002653 100755 --- a/pymoo/algorithms/soo/nonconvex/de.py +++ b/pymoo/algorithms/soo/nonconvex/de.py @@ -171,11 +171,11 @@ def _advance(self, infills=None, **kwargs): assert infills is not None, "This algorithms uses the AskAndTell interface thus infills must be provided." - #One-to-one replacement survival + # One-to-one replacement survival self.pop = ImprovementReplacement().do(self.problem, self.pop, infills) - #Sort the population by fitness to make the selection simpler for mating (not an actual survival, just sorting) + # Sort the population by fitness to make the selection simpler for mating (not an actual survival, just sorting) self.pop = FitnessSurvival().do(self.problem, self.pop) - #Set ranks + # Set ranks self.pop.set("rank", np.arange(self.pop_size)) From f2c3d502f479c3ba71f0f237d189d7536d40cdac Mon Sep 17 00:00:00 2001 From: mooscalia project <93492480+mooscaliaproject@users.noreply.github.com> Date: Wed, 2 Nov 2022 17:48:02 -0300 Subject: [PATCH 26/32] Fix eliminate duplicates None --- pymoo/algorithms/moo/nsde.py | 15 +++++---------- pymoo/algorithms/soo/nonconvex/de.py | 4 ++-- 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/pymoo/algorithms/moo/nsde.py b/pymoo/algorithms/moo/nsde.py index 69d581d39..e0224340c 100644 --- a/pymoo/algorithms/moo/nsde.py +++ b/pymoo/algorithms/moo/nsde.py @@ -13,14 +13,11 @@ class NSDE(NSGA2): def __init__(self, pop_size=100, - sampling=LHS(), variant="DE/rand/1/bin", CR=0.7, F=None, gamma=1e-4, de_repair="bounce-back", - mutation=None, - repair=None, survival=RankAndCrowding(), **kwargs): """ @@ -86,23 +83,21 @@ def __init__(self, In GDE3, the survival strategy is applied after a one-to-one comparison between child vector and corresponding parent when both are non-dominated by the other. """ - #Number of offsprings at each generation + # Number of offsprings at each generation n_offsprings = pop_size - #Mating + # Mating mating = VariantDE(variant=variant, CR=CR, F=F, gamma=gamma, de_repair=de_repair, - mutation=mutation, - repair=repair) + **kwargs) - #Init from pymoo's NSGA2 + # Init from pymoo's NSGA2 super().__init__(pop_size=pop_size, - sampling=sampling, mating=mating, survival=survival, - eliminate_duplicates=False, + eliminate_duplicates=None, n_offsprings=n_offsprings, **kwargs) diff --git a/pymoo/algorithms/soo/nonconvex/de.py b/pymoo/algorithms/soo/nonconvex/de.py index 375002653..c667060b6 100755 --- a/pymoo/algorithms/soo/nonconvex/de.py +++ b/pymoo/algorithms/soo/nonconvex/de.py @@ -27,7 +27,7 @@ def __init__(self, **kwargs): # Default initialization of InfillCriterion - super().__init__(eliminate_duplicates=False, **kwargs) + super().__init__(eliminate_duplicates=None, **kwargs) # Parse the information from the string _, selection_variant, n_diff, crossover_variant, = variant.split("/") @@ -152,7 +152,7 @@ def __init__(self, sampling=sampling, mating=mating, n_offsprings=n_offsprings, - eliminate_duplicates=False, + eliminate_duplicates=None, output=output, **kwargs) From c8d215ceef09cf4973f454f6173af5e1cc4f0258 Mon Sep 17 00:00:00 2001 From: mooscalia project <93492480+mooscaliaproject@users.noreply.github.com> Date: Wed, 2 Nov 2022 17:50:08 -0300 Subject: [PATCH 27/32] Restore sampling LHS on NSDE --- pymoo/algorithms/moo/nsde.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pymoo/algorithms/moo/nsde.py b/pymoo/algorithms/moo/nsde.py index e0224340c..79b012d1d 100644 --- a/pymoo/algorithms/moo/nsde.py +++ b/pymoo/algorithms/moo/nsde.py @@ -13,6 +13,7 @@ class NSDE(NSGA2): def __init__(self, pop_size=100, + sampling=LHS(), variant="DE/rand/1/bin", CR=0.7, F=None, @@ -96,6 +97,7 @@ def __init__(self, # Init from pymoo's NSGA2 super().__init__(pop_size=pop_size, + sampling=sampling, mating=mating, survival=survival, eliminate_duplicates=None, From e5baab7823b64eef08e3665bc72b96410c9cb9dd Mon Sep 17 00:00:00 2001 From: mooscalia project <93492480+mooscaliaproject@users.noreply.github.com> Date: Wed, 2 Nov 2022 17:55:45 -0300 Subject: [PATCH 28/32] Style sampling of NSDER --- pymoo/algorithms/moo/nsder.py | 1 - pymoo/algorithms/soo/nonconvex/de.py | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/pymoo/algorithms/moo/nsder.py b/pymoo/algorithms/moo/nsder.py index e5e0349e1..000342450 100644 --- a/pymoo/algorithms/moo/nsder.py +++ b/pymoo/algorithms/moo/nsder.py @@ -13,7 +13,6 @@ class NSDER(NSDE): def __init__(self, ref_dirs, pop_size=100, - sampling=LHS(), variant="DE/rand/1/bin", CR=0.7, F=None, diff --git a/pymoo/algorithms/soo/nonconvex/de.py b/pymoo/algorithms/soo/nonconvex/de.py index c667060b6..c48499b61 100755 --- a/pymoo/algorithms/soo/nonconvex/de.py +++ b/pymoo/algorithms/soo/nonconvex/de.py @@ -138,6 +138,7 @@ def __init__(self, Pymoo's repair operator after mutation. Defaults to NoRepair(). """ + # Mating mating = VariantDE(variant=variant, CR=CR, F=F, From 977f7583e81e5afebd0e984686c38f10046d5110 Mon Sep 17 00:00:00 2001 From: mooscalia project <93492480+mooscaliaproject@users.noreply.github.com> Date: Wed, 2 Nov 2022 17:58:58 -0300 Subject: [PATCH 29/32] Fix test DE GDE3 assertion --- tests/algorithms/test_de.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/algorithms/test_de.py b/tests/algorithms/test_de.py index fc533729a..96335dbc9 100644 --- a/tests/algorithms/test_de.py +++ b/tests/algorithms/test_de.py @@ -1,5 +1,6 @@ import pytest +import numpy as np from pymoo.optimize import minimize from pymoo.problems import get_problem from pymoo.algorithms.soo.nonconvex.de import DE @@ -64,4 +65,4 @@ def test_de_perf(): save_history=False, verbose=False) - assert res_gde3.F <= 1e-6 \ No newline at end of file + assert np.all(res_gde3.F <= 1e-6) \ No newline at end of file From 3536c233eb421e592dbaa1b37732c6b0e39d4755 Mon Sep 17 00:00:00 2001 From: mooscalia project <93492480+mooscaliaproject@users.noreply.github.com> Date: Wed, 2 Nov 2022 18:03:54 -0300 Subject: [PATCH 30/32] Fix sampling on NSDER init --- pymoo/algorithms/moo/nsder.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pymoo/algorithms/moo/nsder.py b/pymoo/algorithms/moo/nsder.py index 000342450..d0abd821c 100644 --- a/pymoo/algorithms/moo/nsder.py +++ b/pymoo/algorithms/moo/nsder.py @@ -100,7 +100,6 @@ def __init__(self, survival = ReferenceDirectionSurvival(ref_dirs) super().__init__(pop_size=pop_size, - sampling=sampling, variant=variant, CR=CR, F=F, From e73644db7915241a183a66ce8cba090fc9da9755 Mon Sep 17 00:00:00 2001 From: mooscalia project <93492480+mooscaliaproject@users.noreply.github.com> Date: Fri, 18 Nov 2022 00:14:40 -0300 Subject: [PATCH 31/32] Style code formatting with autopep8 --- pymoo/cython/mnn.pyx | 84 +++++++++---------- pymoo/cython/pruning_cd.pyx | 58 ++++++------- pymoo/cython/utils.pxd | 38 ++++----- .../survival/rank_and_crowding/classes.py | 70 ++++++++-------- .../survival/rank_and_crowding/metrics.py | 58 ++++++------- 5 files changed, 154 insertions(+), 154 deletions(-) diff --git a/pymoo/cython/mnn.pyx b/pymoo/cython/mnn.pyx index 854400a95..11496135e 100644 --- a/pymoo/cython/mnn.pyx +++ b/pymoo/cython/mnn.pyx @@ -52,7 +52,7 @@ def calc_mnn(double[:, :] X, int n_remove=0): for n in extremes_max: extremes.insert(n) - + X = c_normalize_array(X, extremes_max, extremes_min) return c_calc_mnn(X, n_remove, N, M, extremes) @@ -88,7 +88,7 @@ def calc_2nn(double[:, :] X, int n_remove=0): extremes.insert(n) X = c_normalize_array(X, extremes_max, extremes_min) - + M = 2 return c_calc_mnn(X, n_remove, N, M, extremes) @@ -104,27 +104,27 @@ cdef c_calc_mnn(double[:, :] X, int n_remove, int N, int M, cpp_set[int] extreme double[:, :] D double[:] d int[:, :] Mnn - - #Define items to calculate distances + + # Define items to calculate distances calc_items = cpp_set[int]() for n in range(N): calc_items.insert(n) for n in extremes: calc_items.erase(n) - - #Define remaining items to evaluate + + # Define remaining items to evaluate H = cpp_set[int]() for n in range(N): H.insert(n) - - #Instantiate distances array + + # Instantiate distances array _D = np.empty((N, N), dtype=np.double) D = _D[:, :] - #Shape of X + # Shape of X MM = X.shape[1] - - #Fill values on D + + # Fill values on D for i in range(N - 1): D[i, i] = 0.0 @@ -139,10 +139,10 @@ cdef c_calc_mnn(double[:, :] X, int n_remove, int N, int M, cpp_set[int] extreme D[N-1, N-1] = 0.0 - #Initialize + # Initialize n_removed = 0 - #Initialize neighbors and distances + # Initialize neighbors and distances # _Mnn = np.full((N, M), -1, dtype=np.intc) _Mnn = np.argpartition(D, range(1, M+1), axis=1)[:, 1:M+1].astype(np.intc) dd = np.full((N,), HUGE_VAL, dtype=np.double) @@ -150,25 +150,25 @@ cdef c_calc_mnn(double[:, :] X, int n_remove, int N, int M, cpp_set[int] extreme Mnn = _Mnn[:, :] d = dd[:] - #Obtain distance metrics + # Obtain distance metrics c_calc_d(d, Mnn, D, calc_items, M) - #While n_remove not acheived (no need to recalculate if only one item should be removed) + # While n_remove not acheived (no need to recalculate if only one item should be removed) while n_removed < (n_remove - 1): - #Obtain element to drop + # Obtain element to drop k = c_get_drop(d, H) H.erase(k) - #Update index + # Update index n_removed = n_removed + 1 - #Get items to be recalculated + # Get items to be recalculated calc_items = c_get_calc_items(Mnn, H, k, M) for n in extremes: calc_items.erase(n) - - #Fill in neighbors and distance matrix + + # Fill in neighbors and distance matrix c_calc_mnn_iter( X, Mnn, @@ -178,7 +178,7 @@ cdef c_calc_mnn(double[:, :] X, int n_remove, int N, int M, cpp_set[int] extreme H ) - #Obtain distance metrics + # Obtain distance metrics c_calc_d(d, Mnn, D, calc_items, M) return dd @@ -195,51 +195,51 @@ cdef c_calc_mnn_iter( cdef: int i, j, m - - #Iterate over items to calculate + + # Iterate over items to calculate for i in calc_items: - #Iterate over elements in X + # Iterate over elements in X for j in H: - #Go to next if same element + # Go to next if same element if (j == i): continue - - #Replace at least the last neighbor + + # Replace at least the last neighbor elif (D[i, j] <= D[i, Mnn[i, M-1]]) or (Mnn[i, M-1] == -1): - - #Iterate over current values + + # Iterate over current values for m in range(M): - #Set to current if unassigned + # Set to current if unassigned if (Mnn[i, m] == -1): - #Set last neighbor to index + # Set last neighbor to index Mnn[i, m] = j break - #Break if checking already corresponding index + # Break if checking already corresponding index elif (j == Mnn[i, m]): break - #Distance satisfies condition + # Distance satisfies condition elif (D[i, j] <= D[i, Mnn[i, m]]): - - #Replace higher values + + # Replace higher values Mnn[i, m + 1:] = Mnn[i, m:-1] - - #Replace current value + + # Replace current value Mnn[i, m] = j break -#Calculate crowding metric +# Calculate crowding metric cdef c_calc_d(double[:] d, int[:, :] Mnn, double[:, :] D, cpp_set[int] calc_items, int M): cdef: int i, m - + for i in calc_items: d[i] = 1 @@ -247,7 +247,7 @@ cdef c_calc_d(double[:] d, int[:, :] Mnn, double[:, :] D, cpp_set[int] calc_item d[i] = d[i] * D[i, Mnn[i, m]] -#Returns indexes of items to be recalculated after removal +# Returns indexes of items to be recalculated after removal cdef cpp_set[int] c_get_calc_items( int[:, :] Mnn, cpp_set[int] H, @@ -256,7 +256,7 @@ cdef cpp_set[int] c_get_calc_items( cdef: int i, m cpp_set[int] calc_items - + calc_items = cpp_set[int]() for i in H: @@ -269,5 +269,5 @@ cdef cpp_set[int] c_get_calc_items( Mnn[i, M-1] = -1 calc_items.insert(i) - + return calc_items diff --git a/pymoo/cython/pruning_cd.pyx b/pymoo/cython/pruning_cd.pyx index a08c07f5a..6a602d60e 100644 --- a/pymoo/cython/pruning_cd.pyx +++ b/pymoo/cython/pruning_cd.pyx @@ -14,7 +14,7 @@ cdef extern from "math.h": double HUGE_VAL -#Python definition +# Python definition def calc_pcd(double[:, :] X, int n_remove=0): cdef: @@ -53,7 +53,7 @@ def calc_pcd(double[:, :] X, int n_remove=0): return c_calc_pcd(X, I, n_remove, N, M, extremes) -#Returns crowding metrics with recursive elimination +# Returns crowding metrics with recursive elimination cdef c_calc_pcd(double[:, :] X, int[:, :] I, int n_remove, int N, int M, cpp_set[int] extremes): cdef: @@ -62,30 +62,30 @@ cdef c_calc_pcd(double[:, :] X, int[:, :] I, int n_remove, int N, int M, cpp_set cpp_set[int] H double[:, :] D double[:] d - - #Define items to calculate distances + + # Define items to calculate distances calc_items = cpp_set[int]() for n in range(N): calc_items.insert(n) for n in extremes: calc_items.erase(n) - - #Define remaining items to evaluate + + # Define remaining items to evaluate H = cpp_set[int]() for n in range(N): H.insert(n) - #Initialize + # Initialize n_removed = 0 - #Initialize neighbors and distances + # Initialize neighbors and distances _D = np.full((N, M), HUGE_VAL, dtype=np.double) dd = np.full((N,), HUGE_VAL, dtype=np.double) D = _D[:, :] d = dd[:] - #Fill in neighbors and distance matrix + # Fill in neighbors and distance matrix c_calc_pcd_iter( X, I, @@ -94,25 +94,25 @@ cdef c_calc_pcd(double[:, :] X, int[:, :] I, int n_remove, int N, int M, cpp_set calc_items, ) - #Obtain distance metrics + # Obtain distance metrics c_calc_d(d, D, calc_items, M) - #While n_remove not acheived + # While n_remove not acheived while n_removed < (n_remove - 1): - #Obtain element to drop + # Obtain element to drop k = c_get_drop(d, H) H.erase(k) - #Update index + # Update index n_removed = n_removed + 1 - #Get items to be recalculated + # Get items to be recalculated calc_items = c_get_calc_items(I, k, M, N) for n in extremes: calc_items.erase(n) - - #Fill in neighbors and distance matrix + + # Fill in neighbors and distance matrix c_calc_pcd_iter( X, I, @@ -121,13 +121,13 @@ cdef c_calc_pcd(double[:, :] X, int[:, :] I, int n_remove, int N, int M, cpp_set calc_items, ) - #Obtain distance metrics + # Obtain distance metrics c_calc_d(d, D, calc_items, M) return dd -#Iterate +# Iterate cdef c_calc_pcd_iter( double[:, :] X, int[:, :] I, @@ -138,11 +138,11 @@ cdef c_calc_pcd_iter( cdef: int i, m, n, l, u - - #Iterate over items to calculate + + # Iterate over items to calculate for i in calc_items: - #Iterate over elements in X + # Iterate over elements in X for m in range(M): for n in range(N): @@ -155,12 +155,12 @@ cdef c_calc_pcd_iter( D[i, m] = (X[u, m] - X[l, m]) / M -#Calculate crowding metric +# Calculate crowding metric cdef c_calc_d(double[:] d, double[:, :] D, cpp_set[int] calc_items, int M): cdef: int i, m - + for i in calc_items: d[i] = 0 @@ -168,7 +168,7 @@ cdef c_calc_d(double[:] d, double[:, :] D, cpp_set[int] calc_items, int M): d[i] = d[i] + D[i, m] -#Returns indexes of items to be recalculated after removal +# Returns indexes of items to be recalculated after removal cdef cpp_set[int] c_get_calc_items( int[:, :] I, int k, int M, int N @@ -177,21 +177,21 @@ cdef cpp_set[int] c_get_calc_items( cdef: int n, m cpp_set[int] calc_items - + calc_items = cpp_set[int]() - #Iterate over all elements in I + # Iterate over all elements in I for m in range(M): for n in range(N): if I[n, m] == k: - #Add to set of items to be recalculated + # Add to set of items to be recalculated calc_items.insert(I[n - 1, m]) calc_items.insert(I[n + 1, m]) - #Remove element from sorted array + # Remove element from sorted array I[n:-1, m] = I[n + 1:, m] - + return calc_items diff --git a/pymoo/cython/utils.pxd b/pymoo/cython/utils.pxd index 1d15b672c..303807cf2 100644 --- a/pymoo/cython/utils.pxd +++ b/pymoo/cython/utils.pxd @@ -10,9 +10,9 @@ from libcpp.set cimport set as cpp_set cdef extern from "math.h": double HUGE_VAL - -#Returns elements to remove based on crowding metric d and heap of remaining elements H + +# Returns elements to remove based on crowding metric d and heap of remaining elements H cdef inline int c_get_drop(double[:] d, cpp_set[int] H): cdef: @@ -27,23 +27,23 @@ cdef inline int c_get_drop(double[:] d, cpp_set[int] H): if d[i] <= min_d: min_d = d[i] min_i = i - + return min_i -#Returns vector of positions of minimum values along axis 0 of a 2d memoryview +# Returns vector of positions of minimum values along axis 0 of a 2d memoryview cdef inline vector[int] c_get_argmin(double[:, :] X): cdef: int N, M, min_i, n, m double min_val vector[int] indexes - + N = X.shape[0] M = X.shape[1] indexes = vector[int]() - + for m in range(M): min_i = 0 @@ -55,25 +55,25 @@ cdef inline vector[int] c_get_argmin(double[:, :] X): min_i = n min_val = X[n, m] - + indexes.push_back(min_i) - + return indexes -#Returns vector of positions of maximum values along axis 0 of a 2d memoryview +# Returns vector of positions of maximum values along axis 0 of a 2d memoryview cdef inline vector[int] c_get_argmax(double[:, :] X): cdef: int N, M, max_i, n, m double max_val vector[int] indexes - + N = X.shape[0] M = X.shape[1] indexes = vector[int]() - + for m in range(M): max_i = 0 @@ -85,13 +85,13 @@ cdef inline vector[int] c_get_argmax(double[:, :] X): max_i = n max_val = X[n, m] - + indexes.push_back(max_i) - + return indexes -#Performs normalization of a 2d memoryview +# Performs normalization of a 2d memoryview cdef inline double[:, :] c_normalize_array(double[:, :] X, vector[int] extremes_max, vector[int] extremes_min): cdef: @@ -100,7 +100,7 @@ cdef inline double[:, :] c_normalize_array(double[:, :] X, vector[int] extremes_ int n, m, l, u double l_val, u_val, diff_val vector[double] min_vals, max_vals - + min_vals = vector[double]() max_vals = vector[double]() @@ -109,13 +109,13 @@ cdef inline double[:, :] c_normalize_array(double[:, :] X, vector[int] extremes_ u_val = X[u, m] max_vals.push_back(u_val) m = m + 1 - + m = 0 for l in extremes_min: l_val = X[l, m] min_vals.push_back(l_val) m = m + 1 - + for m in range(M): diff_val = max_vals[m] - min_vals[m] @@ -125,5 +125,5 @@ cdef inline double[:, :] c_normalize_array(double[:, :] X, vector[int] extremes_ for n in range(N): X[n, m] = (X[n, m] - min_vals[m]) / diff_val - - return X \ No newline at end of file + + return X diff --git a/pymoo/operators/survival/rank_and_crowding/classes.py b/pymoo/operators/survival/rank_and_crowding/classes.py index 56493bc3f..bc326221c 100644 --- a/pymoo/operators/survival/rank_and_crowding/classes.py +++ b/pymoo/operators/survival/rank_and_crowding/classes.py @@ -13,9 +13,9 @@ def __init__(self, nds=None, crowding_func="cd"): A generalization of the NSGA-II survival operator that ranks individuals by dominance criteria and sorts the last front by some user-specified crowding metric. The default is NSGA-II's crowding distances although others might be more effective. - + For many-objective problems, try using 'mnn' or '2nn'. - + For Bi-objective problems, 'pcd' is very effective. Parameters @@ -25,29 +25,29 @@ def __init__(self, nds=None, crowding_func="cd"): crowding_func : str or callable, optional Crowding metric. Options are: - + - 'cd': crowding distances - 'pcd' or 'pruning-cd': improved pruning based on crowding distances - 'ce': crowding entropy - 'mnn': M-Neaest Neighbors - '2nn': 2-Neaest Neighbors - + If callable, it has the form ``fun(F, filter_out_duplicates=None, n_remove=None, **kwargs)`` in which F (n, m) and must return metrics in a (n,) array. - + The options 'pcd', 'cd', and 'ce' are recommended for two-objective problems, whereas 'mnn' and '2nn' for many objective. When using 'pcd', 'mnn', or '2nn', individuals are already eliminated in a 'single' manner. Due to Cython implementation, they are as fast as the corresponding 'cd', 'mnn-fast', or '2nn-fast', although they can singnificantly improve diversity of solutions. Defaults to 'cd'. """ - + crowding_func_ = get_crowding_function(crowding_func) super().__init__(filter_infeasible=True) self.nds = nds if nds is not None else NonDominatedSorting() self.crowding_func = crowding_func_ - + def _do(self, problem, pop, @@ -68,19 +68,19 @@ def _do(self, # current front sorted by crowding distance if splitting while len(survivors) + len(front) > n_survive: - - #Define how many will be removed + + # Define how many will be removed n_remove = len(survivors) + len(front) - n_survive - + # re-calculate the crowding distance of the front crowding_of_front = \ self.crowding_func.do( F[front, :], n_remove=n_remove ) - + I = randomized_argsort(crowding_of_front, order='descending', method='numpy') - + I = I[:-n_remove] front = front[I] @@ -105,7 +105,7 @@ def _do(self, class ConstrRankAndCrowding(Survival): - + def __init__(self, nds=None, crowding_func="cd"): """ The Rank and Crowding survival approach for handling constraints proposed on @@ -118,27 +118,27 @@ def __init__(self, nds=None, crowding_func="cd"): crowding_func : str or callable, optional Crowding metric. Options are: - + - 'cd': crowding distances - 'pcd' or 'pruning-cd': improved pruning based on crowding distances - 'ce': crowding entropy - 'mnn': M-Neaest Neighbors - '2nn': 2-Neaest Neighbors - + If callable, it has the form ``fun(F, filter_out_duplicates=None, n_remove=None, **kwargs)`` in which F (n, m) and must return metrics in a (n,) array. - + The options 'pcd', 'cd', and 'ce' are recommended for two-objective problems, whereas 'mnn' and '2nn' for many objective. When using 'pcd', 'mnn', or '2nn', individuals are already eliminated in a 'single' manner. Due to Cython implementation, they are as fast as the corresponding 'cd', 'mnn-fast', or '2nn-fast', although they can singnificantly improve diversity of solutions. Defaults to 'cd'. """ - + super().__init__(filter_infeasible=False) self.nds = nds if nds is not None else NonDominatedSorting() self.ranking = RankAndCrowding(nds=nds, crowding_func=crowding_func) - + def _do(self, problem, pop, @@ -151,49 +151,49 @@ def _do(self, n_survive = min(n_survive, len(pop)) - #If the split should be done beforehand + # If the split should be done beforehand if problem.n_constr > 0: - #Split by feasibility + # Split by feasibility feas, infeas = feas, infeas = split_by_feasibility(pop, sort_infeas_by_cv=True, sort_feas_by_obj=False, return_pop=False) - #Obtain len of feasible + # Obtain len of feasible n_feas = len(feas) - #Assure there is at least_one survivor + # Assure there is at least_one survivor if n_feas == 0: survivors = Population() else: survivors = self.ranking.do(problem, pop[feas], *args, n_survive=min(len(feas), n_survive), **kwargs) - #Calculate how many individuals are still remaining to be filled up with infeasible ones + # Calculate how many individuals are still remaining to be filled up with infeasible ones n_remaining = n_survive - len(survivors) - #If infeasible solutions need to be added + # If infeasible solutions need to be added if n_remaining > 0: - - #Constraints to new ranking + + # Constraints to new ranking G = pop[infeas].get("G") G = np.maximum(G, 0) - - #Fronts in infeasible population + + # Fronts in infeasible population infeas_fronts = self.nds.do(G, n_stop_if_ranked=n_remaining) - - #Iterate over fronts + + # Iterate over fronts for k, front in enumerate(infeas_fronts): - #Save ranks + # Save ranks pop[infeas][front].set("cv_rank", k) - #Current front sorted by CV + # Current front sorted by CV if len(survivors) + len(front) > n_survive: - - #Obtain CV of front + + # Obtain CV of front CV = pop[infeas][front].get("CV").flatten() I = randomized_argsort(CV, order='ascending', method='numpy') I = I[:(n_survive - len(survivors))] - #Otherwise take the whole front unsorted + # Otherwise take the whole front unsorted else: I = np.arange(len(front)) diff --git a/pymoo/operators/survival/rank_and_crowding/metrics.py b/pymoo/operators/survival/rank_and_crowding/metrics.py index f6fb4b846..ba45dee32 100644 --- a/pymoo/operators/survival/rank_and_crowding/metrics.py +++ b/pymoo/operators/survival/rank_and_crowding/metrics.py @@ -23,29 +23,29 @@ def get_crowding_function(label): else: raise KeyError("Crowding function not defined") return fun - + class CrowdingDiversity: - + def do(self, F, n_remove=0): - #Converting types Python int to Cython int would fail in some cases converting to long instead + # Converting types Python int to Cython int would fail in some cases converting to long instead n_remove = np.intc(n_remove) F = np.array(F, dtype=np.double) return self._do(F, n_remove=n_remove) - + def _do(self, F, n_remove=None): pass class FunctionalDiversity(CrowdingDiversity): - + def __init__(self, function=None, filter_out_duplicates=True): self.function = function self.filter_out_duplicates = filter_out_duplicates super().__init__() - + def _do(self, F, **kwargs): - + n_points, n_obj = F.shape if n_points <= 2: @@ -62,24 +62,24 @@ def _do(self, F, **kwargs): # index the unique points of the array _F = F[is_unique] - + _d = self.function(_F, **kwargs) - + d = np.zeros(n_points) d[is_unique] = _d - + return d class FuncionalDiversityMNN(FunctionalDiversity): - + def _do(self, F, **kwargs): - + n_points, n_obj = F.shape if n_points <= n_obj: return np.full(n_points, np.inf) - + else: return super()._do(F, **kwargs) @@ -113,7 +113,7 @@ def calc_crowding_distance(F, **kwargs): cd = np.sum(dist_to_last[J, np.arange(n_obj)] + dist_to_next[J, np.arange(n_obj)], axis=1) / n_obj return cd - + def calc_crowding_entropy(F, **kwargs): """Wang, Y.-N., Wu, L.-H. & Yuan, X.-F., 2010. Multi-objective self-adaptive differential @@ -148,24 +148,24 @@ def calc_crowding_entropy(F, **kwargs): # prepare the distance to last and next vectors dl = dist.copy()[:-1] du = dist.copy()[1:] - - #Fix nan + + # Fix nan dl[np.isnan(dl)] = 0.0 du[np.isnan(du)] = 0.0 - - #Total distance + + # Total distance cd = dl + du - #Get relative positions + # Get relative positions pl = (dl[1:-1] / cd[1:-1]) pu = (du[1:-1] / cd[1:-1]) - #Entropy + # Entropy entropy = np.row_stack([np.full(n_obj, np.inf), -(pl * np.log2(pl) + pu * np.log2(pu)), np.full(n_obj, np.inf)]) - - #Crowding entropy + + # Crowding entropy J = np.argsort(I, axis=0) _cej = cd[J, np.arange(n_obj)] * entropy[J, np.arange(n_obj)] / norm _cej[np.isnan(_cej)] = 0.0 @@ -176,8 +176,8 @@ def calc_crowding_entropy(F, **kwargs): def calc_mnn_fast(F, **kwargs): return _calc_mnn_fast(F, F.shape[1], **kwargs) - - + + def calc_2nn_fast(F, **kwargs): return _calc_mnn_fast(F, 2, **kwargs) @@ -187,20 +187,20 @@ def _calc_mnn_fast(F, n_neighbors, **kwargs): # calculate the norm for each objective - set to NaN if all values are equal norm = np.max(F, axis=0) - np.min(F, axis=0) norm[norm == 0] = 1.0 - + # F normalized F = (F - F.min(axis=0)) / norm - + # Distances pairwise (Inefficient) D = squareform(pdist(F, metric="sqeuclidean")) - + # M neighbors M = F.shape[1] _D = np.partition(D, range(1, M+1), axis=1)[:, 1:M+1] - + # Metric d d = np.prod(_D, axis=1) - + # Set top performers as np.inf _extremes = np.concatenate((np.argmin(F, axis=0), np.argmax(F, axis=0))) d[_extremes] = np.inf From 3cb4f2b94a00d487dedeb308b0bfd5080f073868 Mon Sep 17 00:00:00 2001 From: mooscalia project <93492480+mooscaliaproject@users.noreply.github.com> Date: Fri, 18 Nov 2022 00:19:42 -0300 Subject: [PATCH 32/32] Style code formatting with autopep8 --- pymoo/algorithms/moo/gde3.py | 63 +++++++------- pymoo/algorithms/moo/nsde.py | 38 ++++---- pymoo/algorithms/moo/nsder.py | 39 +++++---- pymoo/algorithms/soo/nonconvex/de.py | 61 +++++++------ pymoo/operators/crossover/dex.py | 119 +++++++++++++------------ pymoo/operators/selection/des.py | 126 ++++++++++++++------------- 6 files changed, 226 insertions(+), 220 deletions(-) diff --git a/pymoo/algorithms/moo/gde3.py b/pymoo/algorithms/moo/gde3.py index 00067f1f4..636086c5a 100644 --- a/pymoo/algorithms/moo/gde3.py +++ b/pymoo/algorithms/moo/gde3.py @@ -10,7 +10,7 @@ class GDE3(NSDE): - + def __init__(self, pop_size=100, variant="DE/rand/1/bin", @@ -21,11 +21,11 @@ def __init__(self, """ GDE3 is an extension of DE to multi-objective problems using a mixed type survival strategy. It is implemented in this version with the same constraint handling strategy of NSGA-II by default. - + Derived algorithms GDE3-MNN and GDE3-2NN use by default survival RankAndCrowding with metrics 'mnn' and '2nn'. - + For many-objective problems, try using NSDE-R, GDE3-MNN, or GDE3-2NN. - + For Bi-objective problems, survival = RankAndCrowding(crowding_func='pcd') is very effective. Kukkonen, S. & Lampinen, J., 2005. GDE3: The third evolution step of generalized differential evolution. 2005 IEEE congress on evolutionary computation, Volume 1, pp. 443-450. @@ -34,13 +34,13 @@ def __init__(self, ---------- pop_size : int, optional Population size. Defaults to 100. - + sampling : Sampling, optional Sampling strategy of pymoo. Defaults to LHS(). - + variant : str, optional Differential evolution strategy. Must be a string in the format: "DE/selection/n/crossover", in which, n in an integer of number of difference vectors, and crossover is either 'bin' or 'exp'. Selection variants are: - + - 'ranked' - 'rand' - 'best' @@ -48,43 +48,43 @@ def __init__(self, - 'current-to-best' - 'current-to-rand' - 'rand-to-best' - + The selection strategy 'ranked' might be helpful to improve convergence speed without much harm to diversity. Defaults to 'DE/rand/1/bin'. - + CR : float, optional Crossover parameter. Defined in the range [0, 1] To reinforce mutation, use higher values. To control convergence speed, use lower values. - + F : iterable of float or float, optional Scale factor or mutation parameter. Defined in the range (0, 2] To reinforce exploration, use higher values; for exploitation, use lower values. - + gamma : float, optional Jitter deviation parameter. Should be in the range (0, 2). Defaults to 1e-4. - + de_repair : str, optional Repair of DE mutant vectors. Is either callable or one of: - + - 'bounce-back' - 'midway' - 'rand-init' - 'to-bounds' - + If callable, has the form fun(X, Xb, xl, xu) in which X contains mutated vectors including violations, Xb contains reference vectors for repair in feasible space, xl is a 1d vector of lower bounds, and xu a 1d vector of upper bounds. Defaults to 'bounce-back'. - + mutation : Mutation, optional Pymoo's mutation operator after crossover. Defaults to NoMutation(). - + repair : Repair, optional Pymoo's repair operator after mutation. Defaults to NoRepair(). - + survival : Survival, optional Pymoo's survival strategy. Defaults to RankAndCrowding() with crowding distances ('cd'). In GDE3, the survival strategy is applied after a one-to-one comparison between child vector and corresponding parent when both are non-dominated by the other. """ - + super().__init__(pop_size=pop_size, variant=variant, CR=CR, @@ -93,55 +93,56 @@ def __init__(self, **kwargs) def _advance(self, infills=None, **kwargs): - + assert infills is not None, "This algorithms uses the AskAndTell interface thus 'infills' must to be provided." - #The individuals that are considered for the survival later and final survive + # The individuals that are considered for the survival later and final survive survivors = [] # now for each of the infill solutions for k in range(len(self.pop)): - #Get the offspring an the parent it is coming from + # Get the offspring an the parent it is coming from off, parent = infills[k], self.pop[k] - #Check whether the new solution dominates the parent or not + # Check whether the new solution dominates the parent or not rel = get_relation(parent, off) - #If indifferent we add both + # If indifferent we add both if rel == 0: survivors.extend([parent, off]) - #If offspring dominates parent + # If offspring dominates parent elif rel == -1: survivors.append(off) - #If parent dominates offspring + # If parent dominates offspring else: survivors.append(parent) - #Create the population + # Create the population survivors = Population.create(*survivors) - #Perform a survival to reduce to pop size + # Perform a survival to reduce to pop size self.pop = self.survival.do(self.problem, survivors, n_survive=self.n_offsprings) class GDE3MNN(GDE3): - + def __init__(self, pop_size=100, variant="DE/rand/1/bin", CR=0.5, F=None, gamma=0.0001, **kwargs): survival = RankAndCrowding(crowding_func="mnn") super().__init__(pop_size, variant, CR, F, gamma, survival=survival, **kwargs) class GDE32NN(GDE3): - + def __init__(self, pop_size=100, variant="DE/rand/1/bin", CR=0.5, F=None, gamma=0.0001, **kwargs): survival = RankAndCrowding(crowding_func="2nn") super().__init__(pop_size, variant, CR, F, gamma, survival=survival, **kwargs) + class GDE3PCD(GDE3): - + def __init__(self, pop_size=100, variant="DE/rand/1/bin", CR=0.5, F=None, gamma=0.0001, **kwargs): survival = RankAndCrowding(crowding_func="pcd") - super().__init__(pop_size, variant, CR, F, gamma, survival=survival, **kwargs) \ No newline at end of file + super().__init__(pop_size, variant, CR, F, gamma, survival=survival, **kwargs) diff --git a/pymoo/algorithms/moo/nsde.py b/pymoo/algorithms/moo/nsde.py index 79b012d1d..8bf11596c 100644 --- a/pymoo/algorithms/moo/nsde.py +++ b/pymoo/algorithms/moo/nsde.py @@ -10,7 +10,7 @@ class NSDE(NSGA2): - + def __init__(self, pop_size=100, sampling=LHS(), @@ -24,22 +24,22 @@ def __init__(self, """ NSDE is an algorithm that combines that combines NSGA-II sorting and survival strategies to DE mutation and crossover. - + For many-objective problems, try using NSDE-R, GDE3-MNN, or GDE3-2NN. - + For Bi-objective problems, survival = RankAndCrowding(crowding_func='pcd') is very effective. Parameters ---------- pop_size : int, optional Population size. Defaults to 100. - + sampling : Sampling, optional Sampling strategy of pymoo. Defaults to LHS(). - + variant : str, optional Differential evolution strategy. Must be a string in the format: "DE/selection/n/crossover", in which, n in an integer of number of difference vectors, and crossover is either 'bin' or 'exp'. Selection variants are: - + - "ranked' - 'rand' - 'best' @@ -47,46 +47,46 @@ def __init__(self, - 'current-to-best' - 'current-to-rand' - 'rand-to-best' - + The selection strategy 'ranked' might be helpful to improve convergence speed without much harm to diversity. Defaults to 'DE/rand/1/bin'. - + CR : float, optional Crossover parameter. Defined in the range [0, 1] To reinforce mutation, use higher values. To control convergence speed, use lower values. - + F : iterable of float or float, optional Scale factor or mutation parameter. Defined in the range (0, 2] To reinforce exploration, use higher values; for exploitation, use lower values. - + gamma : float, optional Jitter deviation parameter. Should be in the range (0, 2). Defaults to 1e-4. - + de_repair : str, optional Repair of DE mutant vectors. Is either callable or one of: - + - 'bounce-back' - 'midway' - 'rand-init' - 'to-bounds' - + If callable, has the form fun(X, Xb, xl, xu) in which X contains mutated vectors including violations, Xb contains reference vectors for repair in feasible space, xl is a 1d vector of lower bounds, and xu a 1d vector of upper bounds. Defaults to 'bounce-back'. - + mutation : Mutation, optional Pymoo's mutation operator after crossover. Defaults to NoMutation(). - + repair : Repair, optional Pymoo's repair operator after mutation. Defaults to NoRepair(). - + survival : Survival, optional Pymoo's survival strategy. Defaults to RankAndCrowding() with crowding distances ('cd'). In GDE3, the survival strategy is applied after a one-to-one comparison between child vector and corresponding parent when both are non-dominated by the other. """ - + # Number of offsprings at each generation n_offsprings = pop_size - + # Mating mating = VariantDE(variant=variant, CR=CR, @@ -94,7 +94,7 @@ def __init__(self, gamma=gamma, de_repair=de_repair, **kwargs) - + # Init from pymoo's NSGA2 super().__init__(pop_size=pop_size, sampling=sampling, diff --git a/pymoo/algorithms/moo/nsder.py b/pymoo/algorithms/moo/nsder.py index d0abd821c..a0a8f5de5 100644 --- a/pymoo/algorithms/moo/nsder.py +++ b/pymoo/algorithms/moo/nsder.py @@ -8,8 +8,9 @@ # Implementation # ========================================================================================================= + class NSDER(NSDE): - + def __init__(self, ref_dirs, pop_size=100, @@ -20,23 +21,23 @@ def __init__(self, **kwargs): """ NSDE-R is an extension of NSDE to many-objective problems (Reddy & Dulikravich, 2019) using NSGA-III survival. - + S. R. Reddy and G. S. Dulikravich, "Many-objective differential evolution optimization based on reference points: NSDE-R," Struct. Multidisc. Optim., vol. 60, pp. 1455-1473, 2019. Parameters ---------- ref_dirs : array like The reference directions that should be used during the optimization. - + pop_size : int, optional Population size. Defaults to 100. - + sampling : Sampling, optional Sampling strategy of pymoo. Defaults to LHS(). - + variant : str, optional Differential evolution strategy. Must be a string in the format: "DE/selection/n/crossover", in which, n in an integer of number of difference vectors, and crossover is either 'bin' or 'exp'. Selection variants are: - + - "ranked' - 'rand' - 'best' @@ -44,42 +45,42 @@ def __init__(self, - 'current-to-best' - 'current-to-rand' - 'rand-to-best' - + The selection strategy 'ranked' might be helpful to improve convergence speed without much harm to diversity. Defaults to 'DE/rand/1/bin'. - + CR : float, optional Crossover parameter. Defined in the range [0, 1] To reinforce mutation, use higher values. To control convergence speed, use lower values. - + F : iterable of float or float, optional Scale factor or mutation parameter. Defined in the range (0, 2] To reinforce exploration, use higher values; for exploitation, use lower values. - + gamma : float, optional Jitter deviation parameter. Should be in the range (0, 2). Defaults to 1e-4. - + de_repair : str, optional Repair of DE mutant vectors. Is either callable or one of: - + - 'bounce-back' - 'midway' - 'rand-init' - 'to-bounds' - + If callable, has the form fun(X, Xb, xl, xu) in which X contains mutated vectors including violations, Xb contains reference vectors for repair in feasible space, xl is a 1d vector of lower bounds, and xu a 1d vector of upper bounds. Defaults to 'bounce-back'. - + mutation : Mutation, optional Pymoo's mutation operator after crossover. Defaults to NoMutation(). - + repair : Repair, optional Pymoo's repair operator after mutation. Defaults to NoRepair(). - + survival : Survival, optional Pymoo's survival strategy. Defaults to ReferenceDirectionSurvival(). """ - + self.ref_dirs = ref_dirs if self.ref_dirs is not None: @@ -98,7 +99,7 @@ def __init__(self, del kwargs['survival'] else: survival = ReferenceDirectionSurvival(ref_dirs) - + super().__init__(pop_size=pop_size, variant=variant, CR=CR, @@ -114,7 +115,7 @@ def _setup(self, problem, **kwargs): raise Exception( "Dimensionality of reference points must be equal to the number of objectives: %s != %s" % (self.ref_dirs.shape[1], problem.n_obj)) - + def _set_optimum(self, **kwargs): if not has_feasible(self.pop): self.opt = self.pop[[np.argmin(self.pop.get("CV"))]] diff --git a/pymoo/algorithms/soo/nonconvex/de.py b/pymoo/algorithms/soo/nonconvex/de.py index c48499b61..d4ae4a586 100755 --- a/pymoo/algorithms/soo/nonconvex/de.py +++ b/pymoo/algorithms/soo/nonconvex/de.py @@ -16,7 +16,7 @@ # ========================================================================================================= class VariantDE(InfillCriterion): - + def __init__(self, variant="DE/rand/1/bin", CR=0.7, @@ -25,25 +25,25 @@ def __init__(self, de_repair="bounce-back", mutation=None, **kwargs): - + # Default initialization of InfillCriterion super().__init__(eliminate_duplicates=None, **kwargs) - + # Parse the information from the string _, selection_variant, n_diff, crossover_variant, = variant.split("/") n_diffs = int(n_diff) - + # When "to" in variant there are more than 1 difference vectors if "-to-" in variant: n_diffs += 1 - + # Define parent selection operator self.selection = DES(selection_variant) - + # Default value for F if F is None: F = (0.0, 1.0) - + # Define crossover strategy (DE mutation is included) self.crossover = DEX(variant=crossover_variant, CR=CR, @@ -52,24 +52,23 @@ def __init__(self, n_diffs=n_diffs, at_least_once=True, de_repair=de_repair) - + # Define posterior mutation strategy and repair self.mutation = mutation if mutation is not None else NoMutation() - def _do(self, problem, pop, n_offsprings, **kwargs): - + # Select parents including donor vector parents = self.selection(problem, pop, n_offsprings, self.crossover.n_parents, to_pop=True, **kwargs) - + # Perform mutation included in DEX and crossover off = self.crossover(problem, parents, **kwargs) - + # Perform posterior mutation and repair if passed off = self.mutation(problem, off) - + return off - + class DE(GeneticAlgorithm): @@ -92,13 +91,13 @@ def __init__(self, ---------- pop_size : int, optional Population size. Defaults to 100. - + sampling : Sampling, optional Sampling strategy of pymoo. Defaults to LHS(). - + variant : str, optional Differential evolution strategy. Must be a string in the format: "DE/selection/n/crossover", in which, n in an integer of number of difference vectors, and crossover is either 'bin' or 'exp'. Selection variants are: - + - 'ranked' - 'rand' - 'best' @@ -106,38 +105,38 @@ def __init__(self, - 'current-to-best' - 'current-to-rand' - 'rand-to-best' - + The selection strategy 'ranked' might be helpful to improve convergence speed without much harm to diversity. Defaults to 'DE/rand/1/bin'. - + CR : float, optional Crossover parameter. Defined in the range [0, 1] To reinforce mutation, use higher values. To control convergence speed, use lower values. - + F : iterable of float or float, optional Scale factor or mutation parameter. Defined in the range (0, 2] To reinforce exploration, use higher values; for exploitation, use lower values. - + gamma : float, optional Jitter deviation parameter. Should be in the range (0, 2). Defaults to 1e-4. - + de_repair : str, optional Repair of DE mutant vectors. Is either callable or one of: - + - 'bounce-back' - 'midway' - 'rand-init' - 'to-bounds' - + If callable, has the form fun(X, Xb, xl, xu) in which X contains mutated vectors including violations, Xb contains reference vectors for repair in feasible space, xl is a 1d vector of lower bounds, and xu a 1d vector of upper bounds. Defaults to 'bounce-back'. - + mutation : Mutation, optional Pymoo's mutation operator after crossover. Defaults to NoMutation(). - + repair : Repair, optional Pymoo's repair operator after mutation. Defaults to NoRepair(). """ - + # Mating mating = VariantDE(variant=variant, CR=CR, @@ -145,7 +144,7 @@ def __init__(self, gamma=gamma, de_repair=de_repair, **kwargs) - + # Number of offsprings at each generation n_offsprings = pop_size @@ -163,13 +162,13 @@ def _initialize_advance(self, infills=None, **kwargs): self.pop = FitnessSurvival().do(self.problem, infills, n_survive=self.pop_size) def _infill(self): - + infills = self.mating(self.problem, self.pop, self.n_offsprings, algorithm=self) return infills def _advance(self, infills=None, **kwargs): - + assert infills is not None, "This algorithms uses the AskAndTell interface thus infills must be provided." # One-to-one replacement survival @@ -177,6 +176,6 @@ def _advance(self, infills=None, **kwargs): # Sort the population by fitness to make the selection simpler for mating (not an actual survival, just sorting) self.pop = FitnessSurvival().do(self.problem, self.pop) - + # Set ranks self.pop.set("rank", np.arange(self.pop_size)) diff --git a/pymoo/operators/crossover/dex.py b/pymoo/operators/crossover/dex.py index 67a487965..acecac7fb 100644 --- a/pymoo/operators/crossover/dex.py +++ b/pymoo/operators/crossover/dex.py @@ -10,7 +10,7 @@ # ========================================================================================================= class DEM(Crossover): - + def __init__(self, F=None, gamma=1e-4, @@ -21,55 +21,54 @@ def __init__(self, # Default value for F if F is None: F = (0.0, 1.0) - + # Define which method will be used to generate F values if hasattr(F, "__iter__"): self.scale_factor = self._randomize_scale_factor else: self.scale_factor = self._scalar_scale_factor - + # Define which method will be used to generate F values if not hasattr(de_repair, "__call__"): try: de_repair = REPAIRS[de_repair] except: raise KeyError("Repair must be either callable or in " + str(list(REPAIRS.keys()))) - + # Define which strategy of rotation will be used if gamma is None: self.get_diff = self._diff_simple else: self.get_diff = self._diff_jitter - + self.F = F self.gamma = gamma self.de_repair = de_repair - + super().__init__(1 + 2 * n_diffs, 1, prob=1.0, **kwargs) - - + def do(self, problem, pop, parents=None, **kwargs): - + # Convert pop if parents is not None if not parents is None: pop = pop[parents] - + # Get all X values for mutation parents Xr = np.swapaxes(pop, 0, 1).get("X") - + # Create mutation vectors V, diffs = self.de_mutation(Xr, return_differentials=True) # If the problem has boundaries to be considered if problem.has_bounds(): - + # Do repair V = self.de_repair(V, Xr[0], *problem.bounds()) - + return Population.new("X", V) - + def de_mutation(self, Xr, return_differentials=True): - + n_parents, n_matings, n_var = Xr.shape assert n_parents % 2 == 1, "For the differential an odd number of values need to be provided" @@ -86,42 +85,42 @@ def de_mutation(self, Xr, return_differentials=True): return V, diffs else: return V - + def _randomize_scale_factor(self, n_matings): - return (self.F[0] + np.random.random(n_matings) * (self.F[1] - self.F[0])) - + return (self.F[0] + np.random.random(n_matings) * (self.F[1] - self.F[0])) + def _scalar_scale_factor(self, n_matings): - return np.full(n_matings, self.F) - + return np.full(n_matings, self.F) + def _diff_jitter(self, F, Xi, Xj, n_matings, n_var): F = F[:, None] * (1 + self.gamma * (np.random.random((n_matings, n_var)) - 0.5)) return F * (Xi - Xj) - + def _diff_simple(self, F, Xi, Xj, n_matings, n_var): return F[:, None] * (Xi - Xj) - + def get_diffs(self, Xr, pairs, n_matings, n_var): - + # The differentials from each pair subtraction diffs = np.zeros((n_matings, n_var)) - + # For each difference for i, j in pairs: - + # Obtain F randomized in range F = self.scale_factor(n_matings) - + # New difference vector diff = self.get_diff(F, Xr[i], Xr[j], n_matings, n_var) - + # Add the difference to the first vector diffs = diffs + diff - + return diffs - - + + class DEX(Crossover): - + def __init__(self, variant="bin", CR=0.7, @@ -131,42 +130,41 @@ def __init__(self, at_least_once=True, de_repair="bounce-back", **kwargs): - + # Default value for F if F is None: F = (0.0, 1.0) - + # Create instace for mutation self.dem = DEM(F=F, gamma=gamma, de_repair=de_repair, n_diffs=n_diffs) - + self.CR = CR self.variant = variant self.at_least_once = at_least_once - + super().__init__(2 + 2 * n_diffs, 1, prob=1.0, **kwargs) - def do(self, problem, pop, parents=None, **kwargs): - + # Convert pop if parents is not None if not parents is None: pop = pop[parents] - + # Get all X values for mutation parents X = pop[:, 0].get("X") - - #About Xi + + # About Xi n_matings, n_var = X.shape - - #Obtain mutants + + # Obtain mutants mutants = self.dem.do(problem, pop[:, 1:], **kwargs) - + # Obtain V V = mutants.get("X") - + # Binomial crossover if self.variant == "bin": M = mut_binomial(n_matings, n_var, self.CR, at_least_once=self.at_least_once) @@ -180,9 +178,9 @@ def do(self, problem, pop, parents=None, **kwargs): X[M] = V[M] off = Population.new("X", X) - + return off - + def bounce_back(X, Xb, xl, xu): """Repair strategy @@ -196,7 +194,7 @@ def bounce_back(X, Xb, xl, xu): Returns: 2d array like: Repaired vectors. """ - + XL = xl[None, :].repeat(len(X), axis=0) XU = xu[None, :].repeat(len(X), axis=0) @@ -210,6 +208,7 @@ def bounce_back(X, Xb, xl, xu): return X + def midway(X, Xb, xl, xu): """Repair strategy @@ -222,7 +221,7 @@ def midway(X, Xb, xl, xu): Returns: 2d array like: Repaired vectors. """ - + XL = xl[None, :].repeat(len(X), axis=0) XU = xu[None, :].repeat(len(X), axis=0) @@ -236,6 +235,7 @@ def midway(X, Xb, xl, xu): return X + def to_bounds(X, Xb, xl, xu): """Repair strategy @@ -248,7 +248,7 @@ def to_bounds(X, Xb, xl, xu): Returns: 2d array like: Repaired vectors. """ - + XL = xl[None, :].repeat(len(X), axis=0) XU = xu[None, :].repeat(len(X), axis=0) @@ -262,6 +262,7 @@ def to_bounds(X, Xb, xl, xu): return X + def rand_init(X, Xb, xl, xu): """Repair strategy @@ -274,7 +275,7 @@ def rand_init(X, Xb, xl, xu): Returns: 2d array like: Repaired vectors. """ - + XL = xl[None, :].repeat(len(X), axis=0) XU = xu[None, :].repeat(len(X), axis=0) @@ -301,7 +302,7 @@ def squared_bounce_back(X, Xb, xl, xu): Returns: 2d array like: Repaired vectors. """ - + XL = xl[None, :].repeat(len(X), axis=0) XU = xu[None, :].repeat(len(X), axis=0) @@ -315,17 +316,19 @@ def squared_bounce_back(X, Xb, xl, xu): return X + def normalize_fun(fun): - + fmin = fun.min(axis=0) fmax = fun.max(axis=0) den = fmax - fmin - + den[den <= 1e-16] = 1.0 - + return (fun - fmin)/den -REPAIRS = {"bounce-back":bounce_back, - "midway":midway, - "rand-init":rand_init, - "to-bounds":to_bounds} + +REPAIRS = {"bounce-back": bounce_back, + "midway": midway, + "rand-init": rand_init, + "to-bounds": to_bounds} diff --git a/pymoo/operators/selection/des.py b/pymoo/operators/selection/des.py index 897eac412..d889f95fe 100644 --- a/pymoo/operators/selection/des.py +++ b/pymoo/operators/selection/des.py @@ -13,201 +13,203 @@ class DES(Selection): def __init__(self, variant, **kwargs): - + super().__init__() self.variant = variant def _do(self, problem, pop, n_select, n_parents, **kwargs): - + # Obtain number of elements in population n_pop = len(pop) - + # For most variants n_select must be equal to len(pop) variant = self.variant - + if variant == "ranked": """Proposed by Zhang et al. (2021). doi.org/10.1016/j.asoc.2021.107317""" P = self._ranked(pop, n_select, n_parents) - + elif variant == "best": P = self._best(pop, n_select, n_parents) - + elif variant == "current-to-best": P = self._current_to_best(pop, n_select, n_parents) - + elif variant == "current-to-rand": P = self._current_to_rand(pop, n_select, n_parents) - + else: P = self._rand(pop, n_select, n_parents) return P - + def _rand(self, pop, n_select, n_parents, **kwargs): - + # len of pop n_pop = len(pop) # Base form P = np.empty([n_select, n_parents], dtype=int) - + # Fill first column with corresponding parent P[:, 0] = np.arange(n_pop) # Fill next columns in loop for j in range(1, n_parents): - - P[:, j] = np.random.choice(n_pop, n_select) + + P[:, j] = np.random.choice(n_pop, n_select) reselect = (P[:, j].reshape([-1, 1]) == P[:, :j]).any(axis=1) - + while np.any(reselect): P[reselect, j] = np.random.choice(n_pop, reselect.sum()) reselect = (P[:, j].reshape([-1, 1]) == P[:, :j]).any(axis=1) - + return P - + def _best(self, pop, n_select, n_parents, **kwargs): - + # len of pop n_pop = len(pop) # Base form P = np.empty([n_select, n_parents], dtype=int) - + # Fill first column with corresponding parent P[:, 0] = np.arange(n_pop) - + # Fill first column with best candidate P[:, 1] = 0 # Fill next columns in loop for j in range(2, n_parents): - - P[:, j] = np.random.choice(n_pop, n_select) + + P[:, j] = np.random.choice(n_pop, n_select) reselect = (P[:, j].reshape([-1, 1]) == P[:, :j]).any(axis=1) - + while np.any(reselect): P[reselect, j] = np.random.choice(n_pop, reselect.sum()) reselect = (P[:, j].reshape([-1, 1]) == P[:, :j]).any(axis=1) - + return P - + def _current_to_best(self, pop, n_select, n_parents, **kwargs): - + # len of pop n_pop = len(pop) # Base form P = np.empty([n_select, n_parents], dtype=int) - + # Fill first column with corresponding parent P[:, 0] = np.arange(n_pop) - + # Fill first column with current candidate P[:, 1] = np.arange(n_pop) - + # Fill first direction from current P[:, 3] = np.arange(n_pop) - + # Towards best P[:, 2] = 0 # Fill next columns in loop for j in range(4, n_parents): - - P[:, j] = np.random.choice(n_pop, n_select) + + P[:, j] = np.random.choice(n_pop, n_select) reselect = (P[:, j].reshape([-1, 1]) == P[:, :j]).any(axis=1) - + while np.any(reselect): P[reselect, j] = np.random.choice(n_pop, reselect.sum()) reselect = (P[:, j].reshape([-1, 1]) == P[:, :j]).any(axis=1) - + return P - + def _current_to_rand(self, pop, n_select, n_parents, **kwargs): - + # len of pop n_pop = len(pop) # Base form P = np.empty([n_select, n_parents], dtype=int) - + # Fill first column with corresponding parent P[:, 0] = np.arange(n_pop) - + # Fill first column with current candidate P[:, 1] = np.arange(n_pop) - + # Fill first direction from current P[:, 3] = np.arange(n_pop) - + # Towards random - P[:, 2] = np.random.choice(n_pop, n_select) + P[:, 2] = np.random.choice(n_pop, n_select) reselect = (P[:, 2].reshape([-1, 1]) == P[:, [0, 1, 3]]).any(axis=1) - + while np.any(reselect): P[reselect, 2] = np.random.choice(n_pop, reselect.sum()) reselect = (P[:, 2].reshape([-1, 1]) == P[:, [0, 1, 3]]).any(axis=1) # Fill next columns in loop for j in range(4, n_parents): - - P[:, j] = np.random.choice(n_pop, n_select) + + P[:, j] = np.random.choice(n_pop, n_select) reselect = (P[:, j].reshape([-1, 1]) == P[:, :j]).any(axis=1) - + while np.any(reselect): P[reselect, j] = np.random.choice(n_pop, reselect.sum()) reselect = (P[:, j].reshape([-1, 1]) == P[:, :j]).any(axis=1) - + return P - + def _ranked(self, pop, n_select, n_parents, **kwargs): - + P = self._rand(pop, n_select, n_parents, **kwargs) P[:, 1:] = rank_sort(P[:, 1:], pop) - + return P - + def ranks_from_cv(pop): - + ranks = pop.get("rank") cv_elements = ranks == None - + if np.any(cv_elements): ranks[cv_elements] = np.arange(len(pop))[cv_elements] - + return ranks + def rank_sort(P, pop): - + ranks = ranks_from_cv(pop) - - sorted = np.argsort(ranks[P], axis=1, kind="stable") + + sorted = np.argsort(ranks[P], axis=1, kind="stable") S = np.take_along_axis(P, sorted, axis=1) P[:, 0] = S[:, 0] - + n_diffs = int((P.shape[1] - 1) / 2) for j in range(1, n_diffs + 1): P[:, 2*j - 1] = S[:, j] P[:, 2*j] = S[:, -j] - + return P + def reiforce_directions(P, pop): - + ranks = ranks_from_cv(pop) - - ranks = ranks[P] + + ranks = ranks[P] S = P.copy() - + n_diffs = int(P.shape[1] / 2) for j in range(0, n_diffs): bad_directions = ranks[:, 2*j] > ranks[:, 2*j + 1] P[bad_directions, 2*j] = S[bad_directions, 2*j + 1] P[bad_directions, 2*j + 1] = S[bad_directions, 2*j] - + return P