From 0aa37c672b47b8fe3923328839e57a377ca51b31 Mon Sep 17 00:00:00 2001 From: Kor de Jong Date: Sat, 6 Apr 2024 15:16:44 +0200 Subject: [PATCH] Add support for adding action combinations --- .../desktop/application.py | 216 ++++++++++++++---- .../desktop/model/action.py | 9 +- .../desktop/model/sequence.py | 14 +- .../desktop/ui/edit_action_dialog.ui | 24 +- .../package/adaptation_pathways/io/sqlite.py | 132 +++++------ source/test/ap_test/io/sqlite_test.py | 36 +-- source/test/ap_test/test_data.py | 12 + 7 files changed, 303 insertions(+), 140 deletions(-) diff --git a/source/package/adaptation_pathways/desktop/application.py b/source/package/adaptation_pathways/desktop/application.py index 28be71d..fce9800 100644 --- a/source/package/adaptation_pathways/desktop/application.py +++ b/source/package/adaptation_pathways/desktop/application.py @@ -10,6 +10,7 @@ from .. import alias from ..action import Action +from ..action_combination import ActionCombination from ..graph import ( PathwayGraph, PathwayMap, @@ -47,6 +48,9 @@ # pylint: disable=too-many-instance-attributes, too-many-locals +ap_repo_url = "https://github.com/Deltares-research/PathwaysGenerator" + + loader = QUiLoader() @@ -77,6 +81,25 @@ def show_error_message(parent, message): # sys.stderr.write(f"{traceback}") +def handle_exceptions(function): + def wrap(self, *args, **kwargs): + try: + function(self, *args, **kwargs) + except KeyError as exception: + # TODO Add traceback + show_error_message( + self.ui, + f"Key error: {exception}\n" + "This is likely a bug.\n" + f"Please report it at {ap_repo_url}/issues", + ) + raise + except ValueError as exception: + show_error_message(self.ui, f"{exception}") + + return wrap + + class MainUI(QObject): # Not a widget def __init__(self): # pylint: disable=too-many-statements @@ -133,14 +156,14 @@ def __init__(self): QtWidgets.QAbstractItemView.InternalMove ) - self.colour_by_action: dict[Action, Colour] = {} # type: ignore + self.colour_by_action_name: dict[str, Colour] = {} # type: ignore self.actions: list[list] = [] # type: ignore - self.action_model = ActionModel(self.actions, self.colour_by_action) + self.action_model = ActionModel(self.actions, self.colour_by_action_name) self.ui.table_actions.setModel(self.action_model) self.sequences: list[list[Action]] = [] # type: ignore - self.sequence_model = SequenceModel(self.sequences, self.colour_by_action) + self.sequence_model = SequenceModel(self.sequences, self.colour_by_action_name) self.ui.table_sequences.setModel(self.sequence_model) self.action_model.rowsAboutToBeRemoved.connect( @@ -198,11 +221,8 @@ def eventFilter(self, object_, event): return False def _plot_sequence_graph(self, sequence_graph: SequenceGraph) -> None: - colour_by_action_name = { - action.name: colour for action, colour in self.colour_by_action.items() - } plot_colours = PlotColours( - sequence_graph_node_colours(sequence_graph, colour_by_action_name), + sequence_graph_node_colours(sequence_graph, self.colour_by_action_name), default_edge_colours(sequence_graph), default_node_edge_colours(sequence_graph), default_label_colour(), @@ -215,11 +235,8 @@ def _plot_sequence_graph(self, sequence_graph: SequenceGraph) -> None: self.sequence_graph_widget.draw() def _plot_pathway_graph(self, pathway_graph: PathwayGraph) -> None: - colour_by_action_name = { - action.name: colour for action, colour in self.colour_by_action.items() - } plot_colours = PlotColours( - pathway_graph_node_colours(pathway_graph, colour_by_action_name), + pathway_graph_node_colours(pathway_graph, self.colour_by_action_name), default_edge_colours(pathway_graph), default_node_edge_colours(pathway_graph), default_label_colour(), @@ -232,12 +249,9 @@ def _plot_pathway_graph(self, pathway_graph: PathwayGraph) -> None: self.pathway_graph_widget.draw() def _plot_pathway_map(self, pathway_map: PathwayMap) -> None: - colour_by_action_name = { - action.name: colour for action, colour in self.colour_by_action.items() - } plot_colours = PlotColours( - pathway_map_node_colours(pathway_map, colour_by_action_name), - pathway_map_edge_colours(pathway_map, colour_by_action_name), + pathway_map_node_colours(pathway_map, self.colour_by_action_name), + pathway_map_edge_colours(pathway_map, self.colour_by_action_name), default_node_edge_colours(pathway_map), default_label_colour(), ) @@ -290,8 +304,12 @@ def _open_dataset(self): dbms.read_dataset(dataset_pathname) ) - self.colour_by_action.clear() - self.colour_by_action.update(dict(colour_by_action.items())) + colour_by_action_name = { + action.name: colour for action, colour in colour_by_action.items() + } + + self.colour_by_action_name.clear() + self.colour_by_action_name.update(dict(colour_by_action_name.items())) self._set_dataset_pathname(dataset_pathname) @@ -309,6 +327,7 @@ def _open_dataset(self): self._update_plots() + @handle_exceptions def _save_dataset(self, dataset_pathname: str): """ Save all data to a dataset @@ -322,7 +341,16 @@ def _save_dataset(self, dataset_pathname: str): actions = [record[0] for record in self.actions] sequences = [(sequence[0], sequence[1]) for sequence in self.sequences] tipping_point_by_action: alias.TippingPointByAction = {} # TODO - colour_by_action = dict(self.colour_by_action.items()) + colour_by_action = {} + + for action_name, colour in self.colour_by_action_name.items(): + colour_by_action[ + next( + record[0] + for record in self.actions + if record[0].name == action_name + ) + ] = colour try: dbms.write_dataset( @@ -389,7 +417,10 @@ def _on_actions_table_context_menu(self, pos): lambda: self._remove_actions(action_idx, 1) ) context.addAction(remove_action_action) - remove_action_action.setEnabled(action_idx != -1) + # The root action can only be removed if it is the only action + remove_action_action.setEnabled( + action_idx != -1 and (action_idx > 0 or len(self.actions) == 1) + ) add_action_action = QtGui.QAction("Add action...", self.ui.table_actions) add_action_action.triggered.connect(self._add_action) @@ -404,6 +435,7 @@ def _on_actions_table_context_menu(self, pos): context.exec(self.ui.table_actions.viewport().mapToGlobal(pos)) + @handle_exceptions def _add_action(self): current_nr_actions = len(self.actions) name = f"Name{current_nr_actions + 1}" @@ -411,14 +443,18 @@ def _add_action(self): colour = self._current_palette[colour_idx] action = Action(name) - self.colour_by_action[action] = colour + assert not name in self.colour_by_action_name, name + self.colour_by_action_name[action.name] = colour self.actions.append([action]) self._set_data_changed(True) self.ui.table_actions.model().layoutChanged.emit() self._edit_action(len(self.actions) - 1, default_values=True) - def _edit_action(self, idx, default_values=False): + @handle_exceptions + def _edit_action( + self, idx, default_values=False + ): # pylint: disable=too-many-statements action_record = self.actions[idx] action = action_record[0] @@ -431,7 +467,7 @@ def _edit_action(self, idx, default_values=False): palette = dialog.select_colour_button.palette() role = dialog.select_colour_button.backgroundRole() - colour = QtGui.QColor.fromRgbF(*self.colour_by_action[action]) + colour = QtGui.QColor.fromRgbF(*self.colour_by_action_name[action.name]) palette.setColor(role, colour) dialog.select_colour_button.setPalette(palette) dialog.select_colour_button.setAutoFillBackground(True) @@ -453,32 +489,86 @@ def select_colour(): dialog.select_colour_button.setPalette(palette) dialog.select_colour_button.clicked.connect(select_colour) + + actions_menu = QtWidgets.QMenu(dialog) + + combined_action_names = ( + [] + if not isinstance(action, ActionCombination) + else [combined_action.name for combined_action in action.actions] + ) + + for record in self.actions: + action_name = record[0].name + if action_name != action.name: + qaction = actions_menu.addAction(action_name) + qaction.setCheckable(True) + qaction.setChecked(action_name in combined_action_names) + + dialog.select_actions_button.setMenu(actions_menu) + dialog.select_actions_button.setPopupMode(QtWidgets.QToolButton.InstantPopup) + # Enable if there are two or more other actions that can be combined + dialog.select_actions_button.setEnabled(len(self.actions) > 2) + dialog.adjustSize() if dialog.exec(): new_name = dialog.name_edit.text() - something_changed = new_name != action.name or new_colour != colour + new_combined_action_names = [ + menu_action.text() + for menu_action in actions_menu.actions() + if menu_action.isChecked() + ] + + something_changed = ( + new_name != action.name + or new_colour != colour + or new_combined_action_names != combined_action_names + ) if something_changed: + old_action = action + + if len(new_combined_action_names) == 0: + new_action = Action(new_name) + else: + combined_actions = [] + for name in new_combined_action_names: + combined_actions.append( + next( + record[0] + for record in self.actions + if record[0].name == name + ) + ) + new_action = ActionCombination(new_name, combined_actions) + # self.actions[idx][0].name = new_name # self.actions[idx][1] = new_colour.getRgbF() self._set_data_changed(True) - self.colour_by_action[action] = new_colour.getRgbF() - self.actions[idx] = [action] + self.colour_by_action_name.pop(old_action.name) + self.colour_by_action_name[new_action.name] = new_colour.getRgbF() + self.actions[idx] = [new_action] # [action] - old_name = action.name + old_name = old_action.name for sequence in self.sequences: from_action, to_action = sequence if from_action.name == old_name: - from_action.name = new_name + # from_action.name = new_name + sequence[0] = copy.copy(new_action) if to_action.name == old_name: - to_action.name = new_name + # to_action.name = new_name + sequence[1] = copy.copy(new_action) + + # self.actions[idx][0].name = new_name - self.actions[idx][0].name = new_name + assert len(self.colour_by_action_name) == len( + self.actions + ), f"{self.colour_by_action_name} ↔ {self.actions}" # new_action_tuple = ( # Action(new_name), @@ -512,9 +602,44 @@ def _actions_about_to_be_removed( ): # pylint: disable=unused-argument # Handle the situation that actions are about to be removed from the table: # - Sequences that involve the actions must be removed as well + # - Action combinations that involve the actions must be removed as well # - Colours associated with the actions must be removed + actions = [record[0] for record in self.actions[first_idx : last_idx + 1]] action_names = [action.name for action in actions] + + action_combinations = [ + record[0] + for record in self.actions + if isinstance(record[0], ActionCombination) + and any( + combined_action_name in action_names + for combined_action_name in [ + combined_action.name for combined_action in record[0].actions + ] + ) + ] + + if all( + action_combination.name in action_names + for action_combination in action_combinations + ): + action_combinations.clear() + + if len(action_combinations) > 0: + # TODO Prevent this from happening of deal with it. Here it is too late to cancel. + action_combination_names = ", ".join( + action_combination.name for action_combination in action_combinations + ) + QtWidgets.QMessageBox.warning( + self.ui, + "Warning", + "You should not remove actions that are still being used in action combinations.\n" + f"Remove those first: {action_combination_names}.\n" + "We will continue, but the application will likely crash due to inconsistencies.", + QtWidgets.QMessageBox.Close, + ) + sequences = [ record for record in self.sequences @@ -524,8 +649,8 @@ def _actions_about_to_be_removed( for sequence in sequences: self.ui.table_sequences.model().removeRow(self.sequences.index(sequence)) - for action in actions: - del self.colour_by_action[action] + for action_name in action_names: + del self.colour_by_action_name[action_name] def _actions_removed( self, parent, first_idx, last_idx @@ -593,6 +718,7 @@ def _add_sequence(self): self._edit_sequence(len(self.sequences) - 1) self._update_plots() + @handle_exceptions def _edit_sequence(self, idx): # pylint: disable=too-many-statements sequence_record = self.sequences[idx] from_action, to_action = sequence_record @@ -622,7 +748,7 @@ def _edit_sequence(self, idx): # pylint: disable=too-many-statements # To end a sequence, one of the existing actions must be selected for action in actions: image = QtGui.QPixmap(16, 16) - image.fill(QtGui.QColor.fromRgbF(*self.colour_by_action[action])) + image.fill(QtGui.QColor.fromRgbF(*self.colour_by_action_name[action.name])) dialog.from_action_start_combo_box.addItem(image, action.name) dialog.to_action_combo_box.addItem(image, action.name) @@ -632,10 +758,18 @@ def _edit_sequence(self, idx): # pylint: disable=too-many-statements image = QtGui.QImage(16, 16, QtGui.QImage.Format_RGB32) painter = QtGui.QPainter(image) painter.fillRect( - 0, 0, 8, 16, QtGui.QColor.fromRgbF(*self.colour_by_action[from_action_]) + 0, + 0, + 8, + 16, + QtGui.QColor.fromRgbF(*self.colour_by_action_name[from_action_.name]), ) painter.fillRect( - 8, 0, 8, 16, QtGui.QColor.fromRgbF(*self.colour_by_action[to_action_]) + 8, + 0, + 8, + 16, + QtGui.QColor.fromRgbF(*self.colour_by_action_name[to_action_.name]), ) painter.end() dialog.from_action_continue_combo_box.addItem( @@ -683,9 +817,9 @@ def _edit_sequence(self, idx): # pylint: disable=too-many-statements else: # Deep copy the existing action new_from_action = copy.deepcopy(actions[action_idx]) - self.colour_by_action[new_from_action] = self.colour_by_action[ - actions[action_idx] - ] + self.colour_by_action_name[new_from_action.name] = ( + self.colour_by_action_name[actions[action_idx].name] + ) else: sequence_idx = dialog.from_action_continue_combo_box.currentIndex() @@ -700,9 +834,9 @@ def _edit_sequence(self, idx): # pylint: disable=too-many-statements else: # Deep copy the existing action new_to_action = copy.deepcopy(actions[action_idx]) - self.colour_by_action[new_to_action] = self.colour_by_action[ - actions[action_idx] - ] + self.colour_by_action_name[new_to_action.name] = ( + self.colour_by_action_name[actions[action_idx].name] + ) something_changed = ( new_from_action != from_action or new_to_action != to_action diff --git a/source/package/adaptation_pathways/desktop/model/action.py b/source/package/adaptation_pathways/desktop/model/action.py index 5a39fcf..ae5126e 100644 --- a/source/package/adaptation_pathways/desktop/model/action.py +++ b/source/package/adaptation_pathways/desktop/model/action.py @@ -1,7 +1,6 @@ from PySide6 import QtCore, QtGui from PySide6.QtCore import Qt -from ...action import Action from ...plot.colour import Colour @@ -9,14 +8,14 @@ class ActionModel(QtCore.QAbstractTableModel): _actions: list[list] _headers: tuple[str] - _colour_by_action: dict[Action, Colour] + _colour_by_action_name: dict[str, Colour] - def __init__(self, actions: list[list], colour_by_action: dict[Action, Colour]): + def __init__(self, actions: list[list], colour_by_action_name: dict[str, Colour]): super().__init__() self._actions = actions self._headers = ("Name",) - self._colour_by_action = colour_by_action + self._colour_by_action_name = colour_by_action_name # pylint: disable=inconsistent-return-statements, no-else-return def data(self, index, role): @@ -28,7 +27,7 @@ def data(self, index, role): elif role == Qt.DecorationRole: if index.column() == 0: action = self._actions[index.row()][0] - colour = self._colour_by_action[action] + colour = self._colour_by_action_name[action.name] return QtGui.QColor.fromRgbF(*colour) def rowCount(self, index): # pylint: disable=unused-argument diff --git a/source/package/adaptation_pathways/desktop/model/sequence.py b/source/package/adaptation_pathways/desktop/model/sequence.py index 3ec1c20..f6647bc 100644 --- a/source/package/adaptation_pathways/desktop/model/sequence.py +++ b/source/package/adaptation_pathways/desktop/model/sequence.py @@ -8,15 +8,15 @@ class SequenceModel(QtCore.QAbstractTableModel): _sequences: list[list[Action]] _horizonal_headers: tuple[str, str] - _colour_by_action: dict[Action, Colour] + _colour_by_action_name: dict[str, Colour] def __init__( - self, sequences: list[list[Action]], colour_by_action: dict[Action, Colour] + self, sequences: list[list[Action]], colour_by_action_name: dict[str, Colour] ): super().__init__() self._sequences = sequences self._horizonal_headers = ("From action", "To action") - self._colour_by_action = colour_by_action + self._colour_by_action_name = colour_by_action_name # pylint: disable=inconsistent-return-statements def data(self, index, role): @@ -26,11 +26,11 @@ def data(self, index, role): if role == Qt.DecorationRole: action = self._sequences[index.row()][index.column()] - colour = self._colour_by_action[ + colour = self._colour_by_action_name[ next( - action_ - for action_ in self._colour_by_action - if action_.name == action.name + action_name + for action_name in self._colour_by_action_name + if action_name == action.name ) ] return QtGui.QColor.fromRgbF(*colour) diff --git a/source/package/adaptation_pathways/desktop/ui/edit_action_dialog.ui b/source/package/adaptation_pathways/desktop/ui/edit_action_dialog.ui index 40b84e0..9acc721 100644 --- a/source/package/adaptation_pathways/desktop/ui/edit_action_dialog.ui +++ b/source/package/adaptation_pathways/desktop/ui/edit_action_dialog.ui @@ -6,8 +6,8 @@ 0 0 - 1048 - 605 + 768 + 367 @@ -46,6 +46,26 @@ + + + + Combination + + + + + + + + 0 + 0 + + + + Actions + + + diff --git a/source/package/adaptation_pathways/io/sqlite.py b/source/package/adaptation_pathways/io/sqlite.py index fa62f91..07d5876 100644 --- a/source/package/adaptation_pathways/io/sqlite.py +++ b/source/package/adaptation_pathways/io/sqlite.py @@ -108,14 +108,14 @@ def write_dataset( # pylint: disable=too-many-locals, too-many-arguments f""" CREATE TABLE {_action_combination_table_name} ( - edition_id INTEGER NOT NULL, - combined_edition_id INTEGER NOT NULL, + action_id INTEGER NOT NULL, + combined_action_id INTEGER NOT NULL, - UNIQUE (edition_id, combined_edition_id), - FOREIGN KEY (edition_id) - REFERENCES {_edition_table_name} (edition_id), - FOREIGN KEY (combined_edition_id) - REFERENCES {_edition_table_name} (edition_id) + UNIQUE (action_id, combined_action_id), + FOREIGN KEY (action_id) + REFERENCES {_action_table_name} (action_id), + FOREIGN KEY (combined_action_id) + REFERENCES {_action_table_name} (action_id) ) """ ) @@ -270,13 +270,13 @@ def add_action_instance(action): # any number (≥ 2) of actions can be combined into a single action combination. action_combination_records = [] - for action, edition_id in edition_id_by_instance.items(): + for action in actions: if isinstance(action, ActionCombination): for combined_action in action.actions: action_combination_records.append( { - "edition_id": edition_id, - "combined_edition_id": edition_id_by_instance[combined_action], + "action_id": action_id_by_name[action.name], + "combined_action_id": action_id_by_name[combined_action.name], } ) @@ -285,13 +285,13 @@ def add_action_instance(action): f""" INSERT INTO {_action_combination_table_name} ( - edition_id, - combined_edition_id + action_id, + combined_action_id ) VALUES ( - :edition_id, - :combined_edition_id + :action_id, + :combined_action_id ) """, action_combination_records, @@ -342,77 +342,66 @@ def read_dataset( # pylint: disable=too-many-locals # TODO Use SQL for this?! We've got all relations set up. Use'm! - action_combination_data = list( + action_data = list( connection.execute( f""" - SELECT edition_id, combined_edition_id - FROM {_action_combination_table_name} + SELECT action_id, name + FROM {_action_table_name} """ ) ) - combined_edition_ids_by_edition_id: dict[int, list[int]] = {} - for edition_id, combined_edition_id in action_combination_data: - combined_edition_ids_by_edition_id.setdefault(edition_id, []).append( - combined_edition_id - ) + action_name_by_id = {} - edition_data = list( + for action_id, name in action_data: + action_name_by_id[action_id] = name + + action_combination_data = list( connection.execute( f""" - SELECT action_id, edition_id - FROM {_edition_table_name} + SELECT action_id, combined_action_id + FROM {_action_combination_table_name} """ ) ) + combined_action_ids_by_action_id: dict[int, list[int]] = {} - action_id_by_edition_id: dict[int, int] = { - edition_id: action_id for action_id, edition_id in edition_data - } - - combined_action_ids_by_action_id: dict[int, list[int]] = { - action_id_by_edition_id[edition_id]: [ - action_id_by_edition_id[edition_id] for edition_id in combined_edition_ids - ] - for edition_id, combined_edition_ids in combined_edition_ids_by_edition_id.items() - } - - action_data = list( - connection.execute( - f""" - SELECT action_id, name - FROM {_action_table_name} - """ + for action_id, combined_action_id in action_combination_data: + combined_action_ids_by_action_id.setdefault(action_id, []).append( + combined_action_id ) - ) - action_instance_by_id: dict[int, Action | ActionCombination] = {} - for action_id, name in action_data: - if action_id not in combined_action_ids_by_action_id: - action_instance_by_id[action_id] = Action(name) - else: - # Placeholder. First add the regular actions. The ones to combine may come after the - # current combination. - action_instance_by_id[action_id] = Action(name) + action_by_id: dict[int, alias.Action] = {} - for action_id, name in action_data: + # First add a regular action instance for all actions. This will keep the order as is. + for action_id, action_name in action_name_by_id.items(): + action_by_id[action_id] = Action(action_name) + + # Now replace some of the actions by action combinations that combine regular actions + for action_id, action_name in action_name_by_id.items(): if action_id in combined_action_ids_by_action_id: - combined_action_ids = combined_action_ids_by_action_id[action_id] combined_actions = [ - action_instance_by_id[combined_action_id] - for combined_action_id in combined_action_ids + action_by_id[combined_action_id] + for combined_action_id in combined_action_ids_by_action_id[action_id] ] - action_instance_by_id[action_id] = ActionCombination(name, combined_actions) + action_by_id[action_id] = ActionCombination(action_name, combined_actions) - actions: list[Action | ActionCombination] = [ - action_instance_by_id[record[0]] for record in action_data - ] + edition_data = list( + connection.execute( + f""" + SELECT action_id, edition_id + FROM {_edition_table_name} + """ + ) + ) action_instance_by_edition: dict[tuple[str, int], Action] = { - edition_id: copy.copy(action_instance_by_id[action_id]) + edition_id: copy.copy(action_by_id[action_id]) for action_id, edition_id in edition_data } + actions: alias.Actions = list(action_by_id.values()) + sequence_data = list( connection.execute( f""" @@ -430,15 +419,16 @@ def read_dataset( # pylint: disable=too-many-locals for sequence_record in sequence_data ] - # One of the sequences relates the root action with itself. This is the one sequence which - # we must remove from the collection. - root_sequences = [ - (sequence[0], sequence[1]) - for sequence in sequences - if sequence[0] == sequence[1] - ] - assert len(root_sequences) == 1, f"{root_sequences}" - sequences.remove(root_sequences[0]) + if len(sequences) > 0: + # One of the sequences relates the root action with itself. This is the one sequence which + # we must remove from the collection. + root_sequences = [ + (sequence[0], sequence[1]) + for sequence in sequences + if sequence[0] == sequence[1] + ] + assert len(root_sequences) == 1, f"{root_sequences}" + sequences.remove(root_sequences[0]) tipping_point_by_action = { action_instance_by_edition[sequence_record[2]]: sequence_record[3] @@ -453,11 +443,11 @@ def read_dataset( # pylint: disable=too-many-locals ) colour_by_action = { - action_instance_by_id[plot_record[0]]: hex_to_rgba(plot_record[1]) + action_by_id[plot_record[0]]: hex_to_rgba(plot_record[1]) for plot_record in plot_data } - for action in actions: + for action in action_by_id.values(): if not action in colour_by_action: colour_by_action[action] = default_node_colour() diff --git a/source/test/ap_test/io/sqlite_test.py b/source/test/ap_test/io/sqlite_test.py index 404df35..e37e775 100644 --- a/source/test/ap_test/io/sqlite_test.py +++ b/source/test/ap_test/io/sqlite_test.py @@ -76,20 +76,23 @@ def _test_round_trip( self, database_path: str, actions: alias.Actions, sequences: alias.Sequences ): - # Add tipping point for the root action, which is not part of the sequences collection - to_action_names = [sequence[1].name for sequence in sequences] - root_actions = { - sequence[0] - for sequence in sequences - if sequence[0].name not in to_action_names - } - assert len(root_actions) == 1, f"{root_actions}" - root_action = root_actions.pop() - tipping_point_by_action = {root_action: random.randint(2020, 2100)} - - tipping_point_by_action |= { - sequence[1]: random.randint(2020, 2100) for sequence in sequences - } + tipping_point_by_action = {} + + if len(sequences) > 0: + # Add tipping point for the root action, which is not part of the sequences collection + to_action_names = [sequence[1].name for sequence in sequences] + root_actions = { + sequence[0] + for sequence in sequences + if sequence[0].name not in to_action_names + } + assert len(root_actions) == 1, f"{root_actions}" + root_action = root_actions.pop() + tipping_point_by_action = {root_action: random.randint(2020, 2100)} + + tipping_point_by_action |= { + sequence[1]: random.randint(2020, 2100) for sequence in sequences + } colours = list(default_action_colours(len(actions))) colour_by_action = {action: colours[idx] for idx, action in enumerate(actions)} @@ -161,6 +164,11 @@ def test_converging_pathway(self): actions, sequences = test_data.converging_pathway() self._test_round_trip(database_path, actions, sequences) + def test_action_combination_01_actions(self): + database_path = "test_action_combination_01_actions.db" + actions, sequences = test_data.action_combination_01_actions() + self._test_round_trip(database_path, actions, sequences) + def test_action_combination_01_pathway(self): database_path = "test_action_combination_01_pathway.db" actions, sequences = test_data.action_combination_01_pathway() diff --git a/source/test/ap_test/test_data.py b/source/test/ap_test/test_data.py index e2fd275..7394536 100644 --- a/source/test/ap_test/test_data.py +++ b/source/test/ap_test/test_data.py @@ -71,6 +71,18 @@ def converging_pathway(): return actions, sequences +def action_combination_01_actions(): + current = Action("current") + a = Action("a") + b = Action("b") + c = ActionCombination("c", [a, b]) + + actions = [current, a, b, c] + sequences = [] + + return actions, sequences + + def action_combination_01_pathway(): current = Action("current") a = Action("a")