Skip to content

Commit

Permalink
Mermaid: Draw subgraphes for groups of consecutive edges with common …
Browse files Browse the repository at this point in the history
…prefix
  • Loading branch information
nfcampos committed Apr 24, 2024
1 parent dd7b852 commit b90545a
Showing 1 changed file with 16 additions and 7 deletions.
23 changes: 16 additions & 7 deletions libs/core/langchain_core/runnables/graph_mermaid.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def draw_mermaid(
if with_styles:
# Node formatting templates
default_class_label = "default"
format_dict = {default_class_label: "{0}([{0}]):::otherclass"}
format_dict = {default_class_label: "{0}([{1}]):::otherclass"}
if first_node_label is not None:
format_dict[first_node_label] = "{0}[{0}]:::startclass"
if last_node_label is not None:
Expand All @@ -57,17 +57,24 @@ def draw_mermaid(
# Add nodes to the graph
for node in nodes.values():
node_label = format_dict.get(node, format_dict[default_class_label]).format(
_escape_node_label(node)
_escape_node_label(node), _escape_node_label(node.split(":", 1)[-1])
)
mermaid_graph += f"\t{node_label};\n"

subgraph = ""
# Add edges to the graph
for edge in edges:
src_prefix = edge.source.split(":")[0]
tgt_prefix = edge.target.split(":")[0]
# exit subgraph if source or target is not in the same subgraph
if subgraph and (subgraph != src_prefix or subgraph != tgt_prefix):
mermaid_graph += "\tend\n"
subgraph = ""
# enter subgraph if source and target are in the same subgraph
if not subgraph and src_prefix and src_prefix == tgt_prefix:
mermaid_graph += f"\tsubgraph {src_prefix}\n"
subgraph = src_prefix
adjusted_edge = _adjust_mermaid_edge(edge=edge, nodes=nodes)
if (
adjusted_edge is None
): # Ignore if it is connection between source and intermediate node
continue

source, target = adjusted_edge

Expand Down Expand Up @@ -96,6 +103,8 @@ def draw_mermaid(
f"\t{_escape_node_label(source)}{edge_label}"
f"{_escape_node_label(target)};\n"
)
if subgraph:
mermaid_graph += "end\n"

# Add custom styles for nodes
if with_styles:
Expand All @@ -111,7 +120,7 @@ def _escape_node_label(node_label: str) -> str:
def _adjust_mermaid_edge(
edge: Edge,
nodes: Dict[str, str],
) -> Optional[Tuple[str, str]]:
) -> Tuple[str, str]:
"""Adjusts Mermaid edge to map conditional nodes to pure nodes."""
source_node_label = nodes.get(edge.source, edge.source)
target_node_label = nodes.get(edge.target, edge.target)
Expand Down

0 comments on commit b90545a

Please sign in to comment.