diff --git a/bioptim/limits/penalty.py b/bioptim/limits/penalty.py index bcf36896e..c5411c83a 100644 --- a/bioptim/limits/penalty.py +++ b/bioptim/limits/penalty.py @@ -88,13 +88,6 @@ def minimize_controls(penalty: PenaltyOption, controller: PenaltyController, key """ penalty.quadratic = True if penalty.quadratic is None else penalty.quadratic - if key in controller.get_nlp.variable_mappings: - target_mapping = controller.get_nlp.variable_mappings[key] - else: - target_mapping = BiMapping( - to_first=list(range(controller.get_nlp.controls[key].cx_start.shape[0])), - to_second=list(range(controller.get_nlp.controls[key].cx_start.shape[0])), - ) # TODO: why if condition, target_mapping not used (Pariterre?) if penalty.integration_rule == QuadratureRule.RECTANGLE_LEFT: # TODO: for trapezoidal integration (This should not be done here but in _set_penalty_function) diff --git a/bioptim/limits/penalty_option.py b/bioptim/limits/penalty_option.py index 58964f558..8001bb775 100644 --- a/bioptim/limits/penalty_option.py +++ b/bioptim/limits/penalty_option.py @@ -6,6 +6,7 @@ from .penalty_controller import PenaltyController from ..misc.enums import Node, PlotType, ControlType, PenaltyType, QuadratureRule, PhaseDynamics from ..misc.options import OptionGeneric +from ..misc.mapping import BiMapping from ..models.protocols.stochastic_biomodel import StochasticBioModel from ..limits.penalty_helpers import PenaltyHelpers @@ -730,8 +731,8 @@ def vertcat_cx_end(): return u @staticmethod - def define_target_mapping(controller: PenaltyController, key: str): - target_mapping = controller.get_nlp.variable_mappings[key] + def define_target_mapping(controller: PenaltyController, key: str, rows): + target_mapping = BiMapping(range(len(controller.get_nlp.variable_mappings[key].to_first.map_idx)), list(rows)) return target_mapping def add_target_to_plot(self, controller: PenaltyController, combine_to: str): @@ -777,7 +778,7 @@ def plot_function(t0, phases_dt, node_idx, x, u, p, a, penalty=None): else: plot_type = PlotType.POINT - target_mapping = self.define_target_mapping(controller, self.params["key"]) + target_mapping = self.define_target_mapping(controller, self.params["key"], self.rows) controller.ocp.add_plot( self.target_plot_name, plot_function, @@ -785,7 +786,7 @@ def plot_function(t0, phases_dt, node_idx, x, u, p, a, penalty=None): color="tab:red", plot_type=plot_type, phase=controller.get_nlp.phase_idx, - axes_idx=target_mapping, # TODO verify if not all elements has target + axes_idx=target_mapping, node_idx=controller.t, )