diff --git a/bioptim/dynamics/configure_problem.py b/bioptim/dynamics/configure_problem.py index 63e64d711..3e6d8195b 100644 --- a/bioptim/dynamics/configure_problem.py +++ b/bioptim/dynamics/configure_problem.py @@ -558,7 +558,6 @@ def stochastic_torque_driven_free_floating_base( ConfigureProblem.torque_driven_free_floating_base( ocp=ocp, nlp=nlp, - with_contact=with_contact, # TODO : this should be removed with_friction=with_friction, ) @@ -566,7 +565,6 @@ def stochastic_torque_driven_free_floating_base( ocp, nlp, DynamicsFunctions.stochastic_torque_driven_free_floating_base, - with_contact=with_contact, with_friction=with_friction, ) diff --git a/bioptim/dynamics/dynamics_functions.py b/bioptim/dynamics/dynamics_functions.py index 09537a747..d8d83a3b2 100644 --- a/bioptim/dynamics/dynamics_functions.py +++ b/bioptim/dynamics/dynamics_functions.py @@ -330,7 +330,6 @@ def stochastic_torque_driven_free_floating_base( parameters: MX.sym, algebraic_states: MX.sym, nlp, - with_contact: bool, with_friction: bool, ) -> DynamicsEvaluation: """ @@ -350,8 +349,6 @@ def stochastic_torque_driven_free_floating_base( The algebraic states of the system nlp: NonLinearProgram The definition of the system - with_contact: bool - If the dynamic with contact should be used with_friction: bool If the dynamic with friction should be used @@ -386,7 +383,7 @@ def stochastic_torque_driven_free_floating_base( tau_full = vertcat(MX.zeros(nlp.model.nb_root), tau_joints) dq = DynamicsFunctions.compute_qdot(nlp, q_full, qdot_full) - ddq = DynamicsFunctions.forward_dynamics(nlp, q_full, qdot_full, tau_full, with_contact) + ddq = DynamicsFunctions.forward_dynamics(nlp, q_full, qdot_full, tau_full, with_contact=False) dxdt = MX(nlp.states.shape, ddq.shape[1]) dxdt[:n_q, :] = horzcat(*[dq for _ in range(ddq.shape[1])]) dxdt[n_q:, :] = ddq diff --git a/bioptim/limits/penalty.py b/bioptim/limits/penalty.py index bcf36896e..c5411c83a 100644 --- a/bioptim/limits/penalty.py +++ b/bioptim/limits/penalty.py @@ -88,13 +88,6 @@ def minimize_controls(penalty: PenaltyOption, controller: PenaltyController, key """ penalty.quadratic = True if penalty.quadratic is None else penalty.quadratic - if key in controller.get_nlp.variable_mappings: - target_mapping = controller.get_nlp.variable_mappings[key] - else: - target_mapping = BiMapping( - to_first=list(range(controller.get_nlp.controls[key].cx_start.shape[0])), - to_second=list(range(controller.get_nlp.controls[key].cx_start.shape[0])), - ) # TODO: why if condition, target_mapping not used (Pariterre?) if penalty.integration_rule == QuadratureRule.RECTANGLE_LEFT: # TODO: for trapezoidal integration (This should not be done here but in _set_penalty_function) diff --git a/bioptim/limits/penalty_option.py b/bioptim/limits/penalty_option.py index 58964f558..8001bb775 100644 --- a/bioptim/limits/penalty_option.py +++ b/bioptim/limits/penalty_option.py @@ -6,6 +6,7 @@ from .penalty_controller import PenaltyController from ..misc.enums import Node, PlotType, ControlType, PenaltyType, QuadratureRule, PhaseDynamics from ..misc.options import OptionGeneric +from ..misc.mapping import BiMapping from ..models.protocols.stochastic_biomodel import StochasticBioModel from ..limits.penalty_helpers import PenaltyHelpers @@ -730,8 +731,8 @@ def vertcat_cx_end(): return u @staticmethod - def define_target_mapping(controller: PenaltyController, key: str): - target_mapping = controller.get_nlp.variable_mappings[key] + def define_target_mapping(controller: PenaltyController, key: str, rows): + target_mapping = BiMapping(range(len(controller.get_nlp.variable_mappings[key].to_first.map_idx)), list(rows)) return target_mapping def add_target_to_plot(self, controller: PenaltyController, combine_to: str): @@ -777,7 +778,7 @@ def plot_function(t0, phases_dt, node_idx, x, u, p, a, penalty=None): else: plot_type = PlotType.POINT - target_mapping = self.define_target_mapping(controller, self.params["key"]) + target_mapping = self.define_target_mapping(controller, self.params["key"], self.rows) controller.ocp.add_plot( self.target_plot_name, plot_function, @@ -785,7 +786,7 @@ def plot_function(t0, phases_dt, node_idx, x, u, p, a, penalty=None): color="tab:red", plot_type=plot_type, phase=controller.get_nlp.phase_idx, - axes_idx=target_mapping, # TODO verify if not all elements has target + axes_idx=target_mapping, node_idx=controller.t, )