Skip to content

Commit

Permalink
think I fixed it :)
Browse files Browse the repository at this point in the history
  • Loading branch information
EveCharbie committed Feb 23, 2024
1 parent ce5a784 commit 7494252
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 11 deletions.
7 changes: 0 additions & 7 deletions bioptim/limits/penalty.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,6 @@ def minimize_controls(penalty: PenaltyOption, controller: PenaltyController, key
"""

penalty.quadratic = True if penalty.quadratic is None else penalty.quadratic
if key in controller.get_nlp.variable_mappings:
target_mapping = controller.get_nlp.variable_mappings[key]
else:
target_mapping = BiMapping(
to_first=list(range(controller.get_nlp.controls[key].cx_start.shape[0])),
to_second=list(range(controller.get_nlp.controls[key].cx_start.shape[0])),
) # TODO: why if condition, target_mapping not used (Pariterre?)

if penalty.integration_rule == QuadratureRule.RECTANGLE_LEFT:
# TODO: for trapezoidal integration (This should not be done here but in _set_penalty_function)
Expand Down
9 changes: 5 additions & 4 deletions bioptim/limits/penalty_option.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .penalty_controller import PenaltyController
from ..misc.enums import Node, PlotType, ControlType, PenaltyType, QuadratureRule, PhaseDynamics
from ..misc.options import OptionGeneric
from ..misc.mapping import BiMapping
from ..models.protocols.stochastic_biomodel import StochasticBioModel
from ..limits.penalty_helpers import PenaltyHelpers

Expand Down Expand Up @@ -730,8 +731,8 @@ def vertcat_cx_end():
return u

@staticmethod
def define_target_mapping(controller: PenaltyController, key: str):
target_mapping = controller.get_nlp.variable_mappings[key]
def define_target_mapping(controller: PenaltyController, key: str, rows):
target_mapping = BiMapping(range(len(controller.get_nlp.variable_mappings[key].to_first.map_idx)), list(rows))
return target_mapping

def add_target_to_plot(self, controller: PenaltyController, combine_to: str):
Expand Down Expand Up @@ -777,15 +778,15 @@ def plot_function(t0, phases_dt, node_idx, x, u, p, a, penalty=None):
else:
plot_type = PlotType.POINT

target_mapping = self.define_target_mapping(controller, self.params["key"])
target_mapping = self.define_target_mapping(controller, self.params["key"], self.rows)

Check warning on line 781 in bioptim/limits/penalty_option.py

View check run for this annotation

Codecov / codecov/patch

bioptim/limits/penalty_option.py#L781

Added line #L781 was not covered by tests
controller.ocp.add_plot(
self.target_plot_name,
plot_function,
penalty=self if plot_type == PlotType.POINT else None,
color="tab:red",
plot_type=plot_type,
phase=controller.get_nlp.phase_idx,
axes_idx=target_mapping, # TODO verify if not all elements has target
axes_idx=target_mapping,
node_idx=controller.t,
)

Expand Down

0 comments on commit 7494252

Please sign in to comment.