diff --git a/source/package/adaptation_pathways/desktop/application.py b/source/package/adaptation_pathways/desktop/application.py
index 28be71d..fce9800 100644
--- a/source/package/adaptation_pathways/desktop/application.py
+++ b/source/package/adaptation_pathways/desktop/application.py
@@ -10,6 +10,7 @@
from .. import alias
from ..action import Action
+from ..action_combination import ActionCombination
from ..graph import (
PathwayGraph,
PathwayMap,
@@ -47,6 +48,9 @@
# pylint: disable=too-many-instance-attributes, too-many-locals
+ap_repo_url = "https://github.com/Deltares-research/PathwaysGenerator"
+
+
loader = QUiLoader()
@@ -77,6 +81,25 @@ def show_error_message(parent, message):
# sys.stderr.write(f"{traceback}")
+def handle_exceptions(function):
+ def wrap(self, *args, **kwargs):
+ try:
+ function(self, *args, **kwargs)
+ except KeyError as exception:
+ # TODO Add traceback
+ show_error_message(
+ self.ui,
+ f"Key error: {exception}\n"
+ "This is likely a bug.\n"
+ f"Please report it at {ap_repo_url}/issues",
+ )
+ raise
+ except ValueError as exception:
+ show_error_message(self.ui, f"{exception}")
+
+ return wrap
+
+
class MainUI(QObject): # Not a widget
def __init__(self):
# pylint: disable=too-many-statements
@@ -133,14 +156,14 @@ def __init__(self):
QtWidgets.QAbstractItemView.InternalMove
)
- self.colour_by_action: dict[Action, Colour] = {} # type: ignore
+ self.colour_by_action_name: dict[str, Colour] = {} # type: ignore
self.actions: list[list] = [] # type: ignore
- self.action_model = ActionModel(self.actions, self.colour_by_action)
+ self.action_model = ActionModel(self.actions, self.colour_by_action_name)
self.ui.table_actions.setModel(self.action_model)
self.sequences: list[list[Action]] = [] # type: ignore
- self.sequence_model = SequenceModel(self.sequences, self.colour_by_action)
+ self.sequence_model = SequenceModel(self.sequences, self.colour_by_action_name)
self.ui.table_sequences.setModel(self.sequence_model)
self.action_model.rowsAboutToBeRemoved.connect(
@@ -198,11 +221,8 @@ def eventFilter(self, object_, event):
return False
def _plot_sequence_graph(self, sequence_graph: SequenceGraph) -> None:
- colour_by_action_name = {
- action.name: colour for action, colour in self.colour_by_action.items()
- }
plot_colours = PlotColours(
- sequence_graph_node_colours(sequence_graph, colour_by_action_name),
+ sequence_graph_node_colours(sequence_graph, self.colour_by_action_name),
default_edge_colours(sequence_graph),
default_node_edge_colours(sequence_graph),
default_label_colour(),
@@ -215,11 +235,8 @@ def _plot_sequence_graph(self, sequence_graph: SequenceGraph) -> None:
self.sequence_graph_widget.draw()
def _plot_pathway_graph(self, pathway_graph: PathwayGraph) -> None:
- colour_by_action_name = {
- action.name: colour for action, colour in self.colour_by_action.items()
- }
plot_colours = PlotColours(
- pathway_graph_node_colours(pathway_graph, colour_by_action_name),
+ pathway_graph_node_colours(pathway_graph, self.colour_by_action_name),
default_edge_colours(pathway_graph),
default_node_edge_colours(pathway_graph),
default_label_colour(),
@@ -232,12 +249,9 @@ def _plot_pathway_graph(self, pathway_graph: PathwayGraph) -> None:
self.pathway_graph_widget.draw()
def _plot_pathway_map(self, pathway_map: PathwayMap) -> None:
- colour_by_action_name = {
- action.name: colour for action, colour in self.colour_by_action.items()
- }
plot_colours = PlotColours(
- pathway_map_node_colours(pathway_map, colour_by_action_name),
- pathway_map_edge_colours(pathway_map, colour_by_action_name),
+ pathway_map_node_colours(pathway_map, self.colour_by_action_name),
+ pathway_map_edge_colours(pathway_map, self.colour_by_action_name),
default_node_edge_colours(pathway_map),
default_label_colour(),
)
@@ -290,8 +304,12 @@ def _open_dataset(self):
dbms.read_dataset(dataset_pathname)
)
- self.colour_by_action.clear()
- self.colour_by_action.update(dict(colour_by_action.items()))
+ colour_by_action_name = {
+ action.name: colour for action, colour in colour_by_action.items()
+ }
+
+ self.colour_by_action_name.clear()
+ self.colour_by_action_name.update(dict(colour_by_action_name.items()))
self._set_dataset_pathname(dataset_pathname)
@@ -309,6 +327,7 @@ def _open_dataset(self):
self._update_plots()
+ @handle_exceptions
def _save_dataset(self, dataset_pathname: str):
"""
Save all data to a dataset
@@ -322,7 +341,16 @@ def _save_dataset(self, dataset_pathname: str):
actions = [record[0] for record in self.actions]
sequences = [(sequence[0], sequence[1]) for sequence in self.sequences]
tipping_point_by_action: alias.TippingPointByAction = {} # TODO
- colour_by_action = dict(self.colour_by_action.items())
+ colour_by_action = {}
+
+ for action_name, colour in self.colour_by_action_name.items():
+ colour_by_action[
+ next(
+ record[0]
+ for record in self.actions
+ if record[0].name == action_name
+ )
+ ] = colour
try:
dbms.write_dataset(
@@ -389,7 +417,10 @@ def _on_actions_table_context_menu(self, pos):
lambda: self._remove_actions(action_idx, 1)
)
context.addAction(remove_action_action)
- remove_action_action.setEnabled(action_idx != -1)
+ # The root action can only be removed if it is the only action
+ remove_action_action.setEnabled(
+ action_idx != -1 and (action_idx > 0 or len(self.actions) == 1)
+ )
add_action_action = QtGui.QAction("Add action...", self.ui.table_actions)
add_action_action.triggered.connect(self._add_action)
@@ -404,6 +435,7 @@ def _on_actions_table_context_menu(self, pos):
context.exec(self.ui.table_actions.viewport().mapToGlobal(pos))
+ @handle_exceptions
def _add_action(self):
current_nr_actions = len(self.actions)
name = f"Name{current_nr_actions + 1}"
@@ -411,14 +443,18 @@ def _add_action(self):
colour = self._current_palette[colour_idx]
action = Action(name)
- self.colour_by_action[action] = colour
+ assert not name in self.colour_by_action_name, name
+ self.colour_by_action_name[action.name] = colour
self.actions.append([action])
self._set_data_changed(True)
self.ui.table_actions.model().layoutChanged.emit()
self._edit_action(len(self.actions) - 1, default_values=True)
- def _edit_action(self, idx, default_values=False):
+ @handle_exceptions
+ def _edit_action(
+ self, idx, default_values=False
+ ): # pylint: disable=too-many-statements
action_record = self.actions[idx]
action = action_record[0]
@@ -431,7 +467,7 @@ def _edit_action(self, idx, default_values=False):
palette = dialog.select_colour_button.palette()
role = dialog.select_colour_button.backgroundRole()
- colour = QtGui.QColor.fromRgbF(*self.colour_by_action[action])
+ colour = QtGui.QColor.fromRgbF(*self.colour_by_action_name[action.name])
palette.setColor(role, colour)
dialog.select_colour_button.setPalette(palette)
dialog.select_colour_button.setAutoFillBackground(True)
@@ -453,32 +489,86 @@ def select_colour():
dialog.select_colour_button.setPalette(palette)
dialog.select_colour_button.clicked.connect(select_colour)
+
+ actions_menu = QtWidgets.QMenu(dialog)
+
+ combined_action_names = (
+ []
+ if not isinstance(action, ActionCombination)
+ else [combined_action.name for combined_action in action.actions]
+ )
+
+ for record in self.actions:
+ action_name = record[0].name
+ if action_name != action.name:
+ qaction = actions_menu.addAction(action_name)
+ qaction.setCheckable(True)
+ qaction.setChecked(action_name in combined_action_names)
+
+ dialog.select_actions_button.setMenu(actions_menu)
+ dialog.select_actions_button.setPopupMode(QtWidgets.QToolButton.InstantPopup)
+ # Enable if there are two or more other actions that can be combined
+ dialog.select_actions_button.setEnabled(len(self.actions) > 2)
+
dialog.adjustSize()
if dialog.exec():
new_name = dialog.name_edit.text()
- something_changed = new_name != action.name or new_colour != colour
+ new_combined_action_names = [
+ menu_action.text()
+ for menu_action in actions_menu.actions()
+ if menu_action.isChecked()
+ ]
+
+ something_changed = (
+ new_name != action.name
+ or new_colour != colour
+ or new_combined_action_names != combined_action_names
+ )
if something_changed:
+ old_action = action
+
+ if len(new_combined_action_names) == 0:
+ new_action = Action(new_name)
+ else:
+ combined_actions = []
+ for name in new_combined_action_names:
+ combined_actions.append(
+ next(
+ record[0]
+ for record in self.actions
+ if record[0].name == name
+ )
+ )
+ new_action = ActionCombination(new_name, combined_actions)
+
# self.actions[idx][0].name = new_name
# self.actions[idx][1] = new_colour.getRgbF()
self._set_data_changed(True)
- self.colour_by_action[action] = new_colour.getRgbF()
- self.actions[idx] = [action]
+ self.colour_by_action_name.pop(old_action.name)
+ self.colour_by_action_name[new_action.name] = new_colour.getRgbF()
+ self.actions[idx] = [new_action] # [action]
- old_name = action.name
+ old_name = old_action.name
for sequence in self.sequences:
from_action, to_action = sequence
if from_action.name == old_name:
- from_action.name = new_name
+ # from_action.name = new_name
+ sequence[0] = copy.copy(new_action)
if to_action.name == old_name:
- to_action.name = new_name
+ # to_action.name = new_name
+ sequence[1] = copy.copy(new_action)
+
+ # self.actions[idx][0].name = new_name
- self.actions[idx][0].name = new_name
+ assert len(self.colour_by_action_name) == len(
+ self.actions
+ ), f"{self.colour_by_action_name} ↔ {self.actions}"
# new_action_tuple = (
# Action(new_name),
@@ -512,9 +602,44 @@ def _actions_about_to_be_removed(
): # pylint: disable=unused-argument
# Handle the situation that actions are about to be removed from the table:
# - Sequences that involve the actions must be removed as well
+ # - Action combinations that involve the actions must be removed as well
# - Colours associated with the actions must be removed
+
actions = [record[0] for record in self.actions[first_idx : last_idx + 1]]
action_names = [action.name for action in actions]
+
+ action_combinations = [
+ record[0]
+ for record in self.actions
+ if isinstance(record[0], ActionCombination)
+ and any(
+ combined_action_name in action_names
+ for combined_action_name in [
+ combined_action.name for combined_action in record[0].actions
+ ]
+ )
+ ]
+
+ if all(
+ action_combination.name in action_names
+ for action_combination in action_combinations
+ ):
+ action_combinations.clear()
+
+ if len(action_combinations) > 0:
+ # TODO Prevent this from happening of deal with it. Here it is too late to cancel.
+ action_combination_names = ", ".join(
+ action_combination.name for action_combination in action_combinations
+ )
+ QtWidgets.QMessageBox.warning(
+ self.ui,
+ "Warning",
+ "You should not remove actions that are still being used in action combinations.\n"
+ f"Remove those first: {action_combination_names}.\n"
+ "We will continue, but the application will likely crash due to inconsistencies.",
+ QtWidgets.QMessageBox.Close,
+ )
+
sequences = [
record
for record in self.sequences
@@ -524,8 +649,8 @@ def _actions_about_to_be_removed(
for sequence in sequences:
self.ui.table_sequences.model().removeRow(self.sequences.index(sequence))
- for action in actions:
- del self.colour_by_action[action]
+ for action_name in action_names:
+ del self.colour_by_action_name[action_name]
def _actions_removed(
self, parent, first_idx, last_idx
@@ -593,6 +718,7 @@ def _add_sequence(self):
self._edit_sequence(len(self.sequences) - 1)
self._update_plots()
+ @handle_exceptions
def _edit_sequence(self, idx): # pylint: disable=too-many-statements
sequence_record = self.sequences[idx]
from_action, to_action = sequence_record
@@ -622,7 +748,7 @@ def _edit_sequence(self, idx): # pylint: disable=too-many-statements
# To end a sequence, one of the existing actions must be selected
for action in actions:
image = QtGui.QPixmap(16, 16)
- image.fill(QtGui.QColor.fromRgbF(*self.colour_by_action[action]))
+ image.fill(QtGui.QColor.fromRgbF(*self.colour_by_action_name[action.name]))
dialog.from_action_start_combo_box.addItem(image, action.name)
dialog.to_action_combo_box.addItem(image, action.name)
@@ -632,10 +758,18 @@ def _edit_sequence(self, idx): # pylint: disable=too-many-statements
image = QtGui.QImage(16, 16, QtGui.QImage.Format_RGB32)
painter = QtGui.QPainter(image)
painter.fillRect(
- 0, 0, 8, 16, QtGui.QColor.fromRgbF(*self.colour_by_action[from_action_])
+ 0,
+ 0,
+ 8,
+ 16,
+ QtGui.QColor.fromRgbF(*self.colour_by_action_name[from_action_.name]),
)
painter.fillRect(
- 8, 0, 8, 16, QtGui.QColor.fromRgbF(*self.colour_by_action[to_action_])
+ 8,
+ 0,
+ 8,
+ 16,
+ QtGui.QColor.fromRgbF(*self.colour_by_action_name[to_action_.name]),
)
painter.end()
dialog.from_action_continue_combo_box.addItem(
@@ -683,9 +817,9 @@ def _edit_sequence(self, idx): # pylint: disable=too-many-statements
else:
# Deep copy the existing action
new_from_action = copy.deepcopy(actions[action_idx])
- self.colour_by_action[new_from_action] = self.colour_by_action[
- actions[action_idx]
- ]
+ self.colour_by_action_name[new_from_action.name] = (
+ self.colour_by_action_name[actions[action_idx].name]
+ )
else:
sequence_idx = dialog.from_action_continue_combo_box.currentIndex()
@@ -700,9 +834,9 @@ def _edit_sequence(self, idx): # pylint: disable=too-many-statements
else:
# Deep copy the existing action
new_to_action = copy.deepcopy(actions[action_idx])
- self.colour_by_action[new_to_action] = self.colour_by_action[
- actions[action_idx]
- ]
+ self.colour_by_action_name[new_to_action.name] = (
+ self.colour_by_action_name[actions[action_idx].name]
+ )
something_changed = (
new_from_action != from_action or new_to_action != to_action
diff --git a/source/package/adaptation_pathways/desktop/model/action.py b/source/package/adaptation_pathways/desktop/model/action.py
index 5a39fcf..ae5126e 100644
--- a/source/package/adaptation_pathways/desktop/model/action.py
+++ b/source/package/adaptation_pathways/desktop/model/action.py
@@ -1,7 +1,6 @@
from PySide6 import QtCore, QtGui
from PySide6.QtCore import Qt
-from ...action import Action
from ...plot.colour import Colour
@@ -9,14 +8,14 @@ class ActionModel(QtCore.QAbstractTableModel):
_actions: list[list]
_headers: tuple[str]
- _colour_by_action: dict[Action, Colour]
+ _colour_by_action_name: dict[str, Colour]
- def __init__(self, actions: list[list], colour_by_action: dict[Action, Colour]):
+ def __init__(self, actions: list[list], colour_by_action_name: dict[str, Colour]):
super().__init__()
self._actions = actions
self._headers = ("Name",)
- self._colour_by_action = colour_by_action
+ self._colour_by_action_name = colour_by_action_name
# pylint: disable=inconsistent-return-statements, no-else-return
def data(self, index, role):
@@ -28,7 +27,7 @@ def data(self, index, role):
elif role == Qt.DecorationRole:
if index.column() == 0:
action = self._actions[index.row()][0]
- colour = self._colour_by_action[action]
+ colour = self._colour_by_action_name[action.name]
return QtGui.QColor.fromRgbF(*colour)
def rowCount(self, index): # pylint: disable=unused-argument
diff --git a/source/package/adaptation_pathways/desktop/model/sequence.py b/source/package/adaptation_pathways/desktop/model/sequence.py
index 3ec1c20..f6647bc 100644
--- a/source/package/adaptation_pathways/desktop/model/sequence.py
+++ b/source/package/adaptation_pathways/desktop/model/sequence.py
@@ -8,15 +8,15 @@
class SequenceModel(QtCore.QAbstractTableModel):
_sequences: list[list[Action]]
_horizonal_headers: tuple[str, str]
- _colour_by_action: dict[Action, Colour]
+ _colour_by_action_name: dict[str, Colour]
def __init__(
- self, sequences: list[list[Action]], colour_by_action: dict[Action, Colour]
+ self, sequences: list[list[Action]], colour_by_action_name: dict[str, Colour]
):
super().__init__()
self._sequences = sequences
self._horizonal_headers = ("From action", "To action")
- self._colour_by_action = colour_by_action
+ self._colour_by_action_name = colour_by_action_name
# pylint: disable=inconsistent-return-statements
def data(self, index, role):
@@ -26,11 +26,11 @@ def data(self, index, role):
if role == Qt.DecorationRole:
action = self._sequences[index.row()][index.column()]
- colour = self._colour_by_action[
+ colour = self._colour_by_action_name[
next(
- action_
- for action_ in self._colour_by_action
- if action_.name == action.name
+ action_name
+ for action_name in self._colour_by_action_name
+ if action_name == action.name
)
]
return QtGui.QColor.fromRgbF(*colour)
diff --git a/source/package/adaptation_pathways/desktop/ui/edit_action_dialog.ui b/source/package/adaptation_pathways/desktop/ui/edit_action_dialog.ui
index 40b84e0..9acc721 100644
--- a/source/package/adaptation_pathways/desktop/ui/edit_action_dialog.ui
+++ b/source/package/adaptation_pathways/desktop/ui/edit_action_dialog.ui
@@ -6,8 +6,8 @@
0
0
- 1048
- 605
+ 768
+ 367
@@ -46,6 +46,26 @@
+ -
+
+
+ Combination
+
+
+
+ -
+
+
+
+ 0
+ 0
+
+
+
+ Actions
+
+
+
-
diff --git a/source/package/adaptation_pathways/io/sqlite.py b/source/package/adaptation_pathways/io/sqlite.py
index fa62f91..07d5876 100644
--- a/source/package/adaptation_pathways/io/sqlite.py
+++ b/source/package/adaptation_pathways/io/sqlite.py
@@ -108,14 +108,14 @@ def write_dataset( # pylint: disable=too-many-locals, too-many-arguments
f"""
CREATE TABLE {_action_combination_table_name}
(
- edition_id INTEGER NOT NULL,
- combined_edition_id INTEGER NOT NULL,
+ action_id INTEGER NOT NULL,
+ combined_action_id INTEGER NOT NULL,
- UNIQUE (edition_id, combined_edition_id),
- FOREIGN KEY (edition_id)
- REFERENCES {_edition_table_name} (edition_id),
- FOREIGN KEY (combined_edition_id)
- REFERENCES {_edition_table_name} (edition_id)
+ UNIQUE (action_id, combined_action_id),
+ FOREIGN KEY (action_id)
+ REFERENCES {_action_table_name} (action_id),
+ FOREIGN KEY (combined_action_id)
+ REFERENCES {_action_table_name} (action_id)
)
"""
)
@@ -270,13 +270,13 @@ def add_action_instance(action):
# any number (≥ 2) of actions can be combined into a single action combination.
action_combination_records = []
- for action, edition_id in edition_id_by_instance.items():
+ for action in actions:
if isinstance(action, ActionCombination):
for combined_action in action.actions:
action_combination_records.append(
{
- "edition_id": edition_id,
- "combined_edition_id": edition_id_by_instance[combined_action],
+ "action_id": action_id_by_name[action.name],
+ "combined_action_id": action_id_by_name[combined_action.name],
}
)
@@ -285,13 +285,13 @@ def add_action_instance(action):
f"""
INSERT INTO {_action_combination_table_name}
(
- edition_id,
- combined_edition_id
+ action_id,
+ combined_action_id
)
VALUES
(
- :edition_id,
- :combined_edition_id
+ :action_id,
+ :combined_action_id
)
""",
action_combination_records,
@@ -342,77 +342,66 @@ def read_dataset( # pylint: disable=too-many-locals
# TODO Use SQL for this?! We've got all relations set up. Use'm!
- action_combination_data = list(
+ action_data = list(
connection.execute(
f"""
- SELECT edition_id, combined_edition_id
- FROM {_action_combination_table_name}
+ SELECT action_id, name
+ FROM {_action_table_name}
"""
)
)
- combined_edition_ids_by_edition_id: dict[int, list[int]] = {}
- for edition_id, combined_edition_id in action_combination_data:
- combined_edition_ids_by_edition_id.setdefault(edition_id, []).append(
- combined_edition_id
- )
+ action_name_by_id = {}
- edition_data = list(
+ for action_id, name in action_data:
+ action_name_by_id[action_id] = name
+
+ action_combination_data = list(
connection.execute(
f"""
- SELECT action_id, edition_id
- FROM {_edition_table_name}
+ SELECT action_id, combined_action_id
+ FROM {_action_combination_table_name}
"""
)
)
+ combined_action_ids_by_action_id: dict[int, list[int]] = {}
- action_id_by_edition_id: dict[int, int] = {
- edition_id: action_id for action_id, edition_id in edition_data
- }
-
- combined_action_ids_by_action_id: dict[int, list[int]] = {
- action_id_by_edition_id[edition_id]: [
- action_id_by_edition_id[edition_id] for edition_id in combined_edition_ids
- ]
- for edition_id, combined_edition_ids in combined_edition_ids_by_edition_id.items()
- }
-
- action_data = list(
- connection.execute(
- f"""
- SELECT action_id, name
- FROM {_action_table_name}
- """
+ for action_id, combined_action_id in action_combination_data:
+ combined_action_ids_by_action_id.setdefault(action_id, []).append(
+ combined_action_id
)
- )
- action_instance_by_id: dict[int, Action | ActionCombination] = {}
- for action_id, name in action_data:
- if action_id not in combined_action_ids_by_action_id:
- action_instance_by_id[action_id] = Action(name)
- else:
- # Placeholder. First add the regular actions. The ones to combine may come after the
- # current combination.
- action_instance_by_id[action_id] = Action(name)
+ action_by_id: dict[int, alias.Action] = {}
- for action_id, name in action_data:
+ # First add a regular action instance for all actions. This will keep the order as is.
+ for action_id, action_name in action_name_by_id.items():
+ action_by_id[action_id] = Action(action_name)
+
+ # Now replace some of the actions by action combinations that combine regular actions
+ for action_id, action_name in action_name_by_id.items():
if action_id in combined_action_ids_by_action_id:
- combined_action_ids = combined_action_ids_by_action_id[action_id]
combined_actions = [
- action_instance_by_id[combined_action_id]
- for combined_action_id in combined_action_ids
+ action_by_id[combined_action_id]
+ for combined_action_id in combined_action_ids_by_action_id[action_id]
]
- action_instance_by_id[action_id] = ActionCombination(name, combined_actions)
+ action_by_id[action_id] = ActionCombination(action_name, combined_actions)
- actions: list[Action | ActionCombination] = [
- action_instance_by_id[record[0]] for record in action_data
- ]
+ edition_data = list(
+ connection.execute(
+ f"""
+ SELECT action_id, edition_id
+ FROM {_edition_table_name}
+ """
+ )
+ )
action_instance_by_edition: dict[tuple[str, int], Action] = {
- edition_id: copy.copy(action_instance_by_id[action_id])
+ edition_id: copy.copy(action_by_id[action_id])
for action_id, edition_id in edition_data
}
+ actions: alias.Actions = list(action_by_id.values())
+
sequence_data = list(
connection.execute(
f"""
@@ -430,15 +419,16 @@ def read_dataset( # pylint: disable=too-many-locals
for sequence_record in sequence_data
]
- # One of the sequences relates the root action with itself. This is the one sequence which
- # we must remove from the collection.
- root_sequences = [
- (sequence[0], sequence[1])
- for sequence in sequences
- if sequence[0] == sequence[1]
- ]
- assert len(root_sequences) == 1, f"{root_sequences}"
- sequences.remove(root_sequences[0])
+ if len(sequences) > 0:
+ # One of the sequences relates the root action with itself. This is the one sequence which
+ # we must remove from the collection.
+ root_sequences = [
+ (sequence[0], sequence[1])
+ for sequence in sequences
+ if sequence[0] == sequence[1]
+ ]
+ assert len(root_sequences) == 1, f"{root_sequences}"
+ sequences.remove(root_sequences[0])
tipping_point_by_action = {
action_instance_by_edition[sequence_record[2]]: sequence_record[3]
@@ -453,11 +443,11 @@ def read_dataset( # pylint: disable=too-many-locals
)
colour_by_action = {
- action_instance_by_id[plot_record[0]]: hex_to_rgba(plot_record[1])
+ action_by_id[plot_record[0]]: hex_to_rgba(plot_record[1])
for plot_record in plot_data
}
- for action in actions:
+ for action in action_by_id.values():
if not action in colour_by_action:
colour_by_action[action] = default_node_colour()
diff --git a/source/test/ap_test/io/sqlite_test.py b/source/test/ap_test/io/sqlite_test.py
index 404df35..e37e775 100644
--- a/source/test/ap_test/io/sqlite_test.py
+++ b/source/test/ap_test/io/sqlite_test.py
@@ -76,20 +76,23 @@ def _test_round_trip(
self, database_path: str, actions: alias.Actions, sequences: alias.Sequences
):
- # Add tipping point for the root action, which is not part of the sequences collection
- to_action_names = [sequence[1].name for sequence in sequences]
- root_actions = {
- sequence[0]
- for sequence in sequences
- if sequence[0].name not in to_action_names
- }
- assert len(root_actions) == 1, f"{root_actions}"
- root_action = root_actions.pop()
- tipping_point_by_action = {root_action: random.randint(2020, 2100)}
-
- tipping_point_by_action |= {
- sequence[1]: random.randint(2020, 2100) for sequence in sequences
- }
+ tipping_point_by_action = {}
+
+ if len(sequences) > 0:
+ # Add tipping point for the root action, which is not part of the sequences collection
+ to_action_names = [sequence[1].name for sequence in sequences]
+ root_actions = {
+ sequence[0]
+ for sequence in sequences
+ if sequence[0].name not in to_action_names
+ }
+ assert len(root_actions) == 1, f"{root_actions}"
+ root_action = root_actions.pop()
+ tipping_point_by_action = {root_action: random.randint(2020, 2100)}
+
+ tipping_point_by_action |= {
+ sequence[1]: random.randint(2020, 2100) for sequence in sequences
+ }
colours = list(default_action_colours(len(actions)))
colour_by_action = {action: colours[idx] for idx, action in enumerate(actions)}
@@ -161,6 +164,11 @@ def test_converging_pathway(self):
actions, sequences = test_data.converging_pathway()
self._test_round_trip(database_path, actions, sequences)
+ def test_action_combination_01_actions(self):
+ database_path = "test_action_combination_01_actions.db"
+ actions, sequences = test_data.action_combination_01_actions()
+ self._test_round_trip(database_path, actions, sequences)
+
def test_action_combination_01_pathway(self):
database_path = "test_action_combination_01_pathway.db"
actions, sequences = test_data.action_combination_01_pathway()
diff --git a/source/test/ap_test/test_data.py b/source/test/ap_test/test_data.py
index e2fd275..7394536 100644
--- a/source/test/ap_test/test_data.py
+++ b/source/test/ap_test/test_data.py
@@ -71,6 +71,18 @@ def converging_pathway():
return actions, sequences
+def action_combination_01_actions():
+ current = Action("current")
+ a = Action("a")
+ b = Action("b")
+ c = ActionCombination("c", [a, b])
+
+ actions = [current, a, b, c]
+ sequences = []
+
+ return actions, sequences
+
+
def action_combination_01_pathway():
current = Action("current")
a = Action("a")