diff --git a/source/package/adaptation_pathways/cli/plot_pathway_map.py b/source/package/adaptation_pathways/cli/plot_pathway_map.py index 99a4604..3a7bdff 100644 --- a/source/package/adaptation_pathways/cli/plot_pathway_map.py +++ b/source/package/adaptation_pathways/cli/plot_pathway_map.py @@ -54,6 +54,20 @@ def plot_map( return 0 +def parse_spread(spread: str) -> tuple[float, float]: + spreads = spread.split(",") + + if len(spreads) == 1: + result = float(spreads[0]), float(spreads[0]) + else: + assert ( + len(spreads) == 2 + ), "Pass in a single floating point value, or two separated by a comma" + result = float(spreads[0]), float(spreads[1]) + + return result + + def main() -> int: command = os.path.basename(sys.argv[0]) usage = f"""\ @@ -78,9 +92,13 @@ def main() -> int: bit beyond the actual point --show_legend Show legend --spread= Separate overlapping lines by a percentage [0, 1] of - the range passed in. A value of 0.01 means 1% of the - range of x-coordinates. Passing in a value > 0.02 is - likely not useful. [default: 0] + the data range. A value of 0.01 means 1% of the + range. Passing in a value > 0.02 is likely not useful. + Pass in a tuple of hspread,vspread to separate between + horizontal and vertical spread. Horizontal spread is + about the separation of vertical lines (transitions). + Vertical spread is about horizontal lines (actions). + [default: 0] --title= Title --x_label=<label> Label of x-axis @@ -100,7 +118,7 @@ def main() -> int: x_label = arguments["--x_label"] if arguments["--x_label"] is not None else "" show_legend = arguments["--show_legend"] overshoot = arguments["--overshoot"] - overlapping_lines_spread = float(arguments["--spread"]) + overlapping_lines_spread: tuple[float, float] = parse_spread(arguments["--spread"]) plot_arguments: dict[str, typing.Any] = { "title": title, diff --git a/source/package/adaptation_pathways/plot/pathway_map/classic.py b/source/package/adaptation_pathways/plot/pathway_map/classic.py index 2743113..883722e 100644 --- a/source/package/adaptation_pathways/plot/pathway_map/classic.py +++ b/source/package/adaptation_pathways/plot/pathway_map/classic.py @@ -567,7 +567,7 @@ def _distribute_vertically( def _layout( pathway_map: PathwayMap, *, - overlapping_lines_spread: float, + overlapping_lines_spread: float | tuple[float, float], ) -> tuple[dict[ActionBegin | ActionEnd, np.ndarray], dict[str, float]]: """ Layout that replicates the pathway map layout of the original (pre-2024) pathway generator @@ -610,11 +610,18 @@ def _layout( pathway_map, root_action_begin, position_by_node ) - if overlapping_lines_spread > 0: - _spread_horizontally( - pathway_map, position_by_node, overlapping_lines_spread + if not isinstance(overlapping_lines_spread, tuple): + overlapping_lines_spread = ( + overlapping_lines_spread, + overlapping_lines_spread, ) - _spread_vertically(pathway_map, position_by_node, overlapping_lines_spread) + + horizontal_spread, vertical_spread = overlapping_lines_spread + + if horizontal_spread > 0: + _spread_horizontally(pathway_map, position_by_node, horizontal_spread) + if vertical_spread > 0: + _spread_vertically(pathway_map, position_by_node, vertical_spread) return position_by_node, y_coordinate_by_action_name @@ -633,7 +640,9 @@ def plot( if legend_arguments is None: legend_arguments = {} - overlapping_lines_spread: float = arguments.get("overlapping_lines_spread", 0) + overlapping_lines_spread: tuple[float, float] = arguments.get( + "overlapping_lines_spread", (0, 0) + ) layout, y_coordinate_by_action_name = _layout( pathway_map, overlapping_lines_spread=overlapping_lines_spread