From 5c9e0f58e2ec1b41574121b5b5d77ae1a3e46253 Mon Sep 17 00:00:00 2001 From: jacoblee93 Date: Wed, 3 Jul 2024 15:35:59 -0700 Subject: [PATCH] Adds mermaid graph format --- langchain-core/src/runnables/graph.ts | 66 +++-- langchain-core/src/runnables/graph_mermaid.ts | 243 ++++++++++++++++++ .../runnables/tests/runnable_graph.test.ts | 21 ++ langchain-core/src/runnables/types.ts | 13 + 4 files changed, 321 insertions(+), 22 deletions(-) create mode 100644 langchain-core/src/runnables/graph_mermaid.ts diff --git a/langchain-core/src/runnables/graph.ts b/langchain-core/src/runnables/graph.ts index 854688e95ecf..f5c028561c2e 100644 --- a/langchain-core/src/runnables/graph.ts +++ b/langchain-core/src/runnables/graph.ts @@ -1,19 +1,13 @@ import { zodToJsonSchema } from "zod-to-json-schema"; import { v4 as uuidv4, validate as isUuid } from "uuid"; -import type { RunnableInterface, RunnableIOSchema } from "./types.js"; +import type { + RunnableInterface, + RunnableIOSchema, + Node, + Edge, +} from "./types.js"; import { isRunnableInterface } from "./utils.js"; - -interface Edge { - source: string; - target: string; - data?: string; -} - -interface Node { - id: string; - - data: RunnableIOSchema | RunnableInterface; -} +import { drawMermaid } from "./graph_mermaid.js"; const MAX_DATA_DISPLAY_NAME_LENGTH = 42; @@ -22,17 +16,12 @@ export function nodeDataStr(node: Node): string { return node.id; } else if (isRunnableInterface(node.data)) { try { - let data = node.data.toString(); - if ( - data.startsWith("<") || - data[0] !== data[0].toUpperCase() || - data.split("\n").length > 1 - ) { - data = node.data.getName(); - } else if (data.length > MAX_DATA_DISPLAY_NAME_LENGTH) { + let data = node.data.getName(); + data = data.startsWith("Runnable") ? data.slice("Runnable".length) : data; + if (data.length > MAX_DATA_DISPLAY_NAME_LENGTH) { data = `${data.substring(0, MAX_DATA_DISPLAY_NAME_LENGTH)}...`; } - return data.startsWith("Runnable") ? data.slice("Runnable".length) : data; + return data; } catch (error) { return node.data.getName(); } @@ -179,4 +168,37 @@ export class Graph { } } } + + drawMermaid(params?: { + withStyles?: boolean; + curveStyle?: string; + nodeColors?: Record; + wrapLabelNWords?: number; + }): string { + const { + withStyles, + curveStyle, + nodeColors = { start: "#ffdfba", end: "#baffc9", other: "#fad7de" }, + wrapLabelNWords, + } = params ?? {}; + const nodes: Record = {}; + for (const node of Object.values(this.nodes)) { + nodes[node.id] = nodeDataStr(node); + } + + const firstNode = this.firstNode(); + const firstNodeLabel = firstNode ? nodeDataStr(firstNode) : undefined; + + const lastNode = this.lastNode(); + const lastNodeLabel = lastNode ? nodeDataStr(lastNode) : undefined; + + return drawMermaid(nodes, this.edges, { + firstNodeLabel, + lastNodeLabel, + withStyles, + curveStyle, + nodeColors, + wrapLabelNWords, + }); + } } diff --git a/langchain-core/src/runnables/graph_mermaid.ts b/langchain-core/src/runnables/graph_mermaid.ts new file mode 100644 index 000000000000..d201655d3f27 --- /dev/null +++ b/langchain-core/src/runnables/graph_mermaid.ts @@ -0,0 +1,243 @@ +import { Edge } from "./types.js"; + +function _escapeNodeLabel(nodeLabel: string): string { + // Escapes the node label for Mermaid syntax. + return nodeLabel.replace(/[^a-zA-Z-_0-9]/g, "_"); +} + +// Adjusts Mermaid edge to map conditional nodes to pure nodes. +function _adjustMermaidEdge(edge: Edge, nodes: Record) { + const sourceNodeLabel = nodes[edge.source] ?? edge.source; + const targetNodeLabel = nodes[edge.target] ?? edge.target; + return [sourceNodeLabel, targetNodeLabel]; +} + +function _generateMermaidGraphStyles( + nodeColors: Record +): string { + let styles = ""; + for (const [className, color] of Object.entries(nodeColors)) { + styles += `\tclassDef ${className}class fill:${color};\n`; + } + return styles; +} + +/** + * Draws a Mermaid graph using the provided graph data + */ +export function drawMermaid( + nodes: Record, + edges: Edge[], + config?: { + firstNodeLabel?: string; + lastNodeLabel?: string; + curveStyle?: string; + withStyles?: boolean; + nodeColors?: Record; + wrapLabelNWords?: number; + } +): string { + const { + firstNodeLabel, + lastNodeLabel, + nodeColors, + withStyles = true, + curveStyle = "linear", + wrapLabelNWords = 9, + } = config ?? {}; + // Initialize Mermaid graph configuration + let mermaidGraph = withStyles + ? `%%{init: {'flowchart': {'curve': '${curveStyle}'}}}%%\ngraph TD;\n` + : "graph TD;\n"; + if (withStyles) { + // Node formatting templates + const defaultClassLabel = "default"; + const formatDict: Record = { + [defaultClassLabel]: "{0}([{1}]):::otherclass", + }; + if (firstNodeLabel !== undefined) { + formatDict[firstNodeLabel] = "{0}[{0}]:::startclass"; + } + if (lastNodeLabel !== undefined) { + formatDict[lastNodeLabel] = "{0}[{0}]:::endclass"; + } + + // Add nodes to the graph + for (const node of Object.values(nodes)) { + const nodeLabel = formatDict[node] ?? formatDict[defaultClassLabel]; + const escapedNodeLabel = _escapeNodeLabel(node); + const nodeParts = node.split(":"); + const nodeSplit = nodeParts[nodeParts.length - 1]; + mermaidGraph += `\t${nodeLabel + .replace(/\{0\}/g, escapedNodeLabel) + .replace(/\{1\}/g, nodeSplit)};\n`; + } + } + let subgraph = ""; + // Add edges to the graph + for (const edge of edges) { + const sourcePrefix = edge.source.includes(":") + ? edge.source.split(":")[0] + : undefined; + const targetPrefix = edge.target.includes(":") + ? edge.target.split(":")[0] + : undefined; + // Exit subgraph if source or target is not in the same subgraph + if ( + (subgraph !== "" && subgraph !== sourcePrefix) || + subgraph !== targetPrefix + ) { + mermaidGraph += "\tend\n"; + subgraph = ""; + } + // Enter subgraph if source and target are in the same subgraph + if ( + subgraph === "" && + sourcePrefix !== undefined && + sourcePrefix === targetPrefix + ) { + mermaidGraph = `\tsubgraph ${sourcePrefix}\n`; + subgraph = sourcePrefix; + } + const [source, target] = _adjustMermaidEdge(edge, nodes); + let edgeLabel = ""; + // Add BR every wrapLabelNWords words + if (edge.data !== undefined) { + let edgeData = edge.data; + const words = edgeData.split(" "); + // Group words into chunks of wrapLabelNWords size + if (words.length > wrapLabelNWords) { + edgeData = words + .reduce((acc: string[], word: string, i: number) => { + if (i % wrapLabelNWords === 0) acc.push(""); + acc[acc.length - 1] += ` ${word}`; + return acc; + }, []) + .join("
"); + if (edge.conditional) { + edgeLabel = ` -. ${edgeData} .-> `; + } else { + edgeLabel = ` -- ${edgeData} --> `; + } + } + } else { + if (edge.conditional) { + edgeLabel = ` -.-> `; + } else { + edgeLabel = ` --> `; + } + } + mermaidGraph += `\t${_escapeNodeLabel( + source + )}${edgeLabel}${_escapeNodeLabel(target)};\n`; + } + if (subgraph !== undefined) { + mermaidGraph += "end\n"; + } + + // Add custom styles for nodes + if (withStyles && nodeColors !== undefined) { + mermaidGraph += _generateMermaidGraphStyles(nodeColors); + } + return mermaidGraph; +} + +// subgraph = "" +// # Add edges to the graph +// for edge in edges: +// src_prefix = edge.source.split(":")[0] if ":" in edge.source else None +// tgt_prefix = edge.target.split(":")[0] if ":" in edge.target else None +// # exit subgraph if source or target is not in the same subgraph +// if subgraph and (subgraph != src_prefix or subgraph != tgt_prefix): +// mermaid_graph += "\tend\n" +// subgraph = "" +// # enter subgraph if source and target are in the same subgraph +// if not subgraph and src_prefix and src_prefix == tgt_prefix: +// mermaid_graph += f"\tsubgraph {src_prefix}\n" +// subgraph = src_prefix +// adjusted_edge = _adjust_mermaid_edge(edge=edge, nodes=nodes) + +// source, target = adjusted_edge + +// # Add BR every wrap_label_n_words words +// if edge.data is not None: +// edge_data = edge.data +// words = str(edge_data).split() # Split the string into words +// # Group words into chunks of wrap_label_n_words size +// if len(words) > wrap_label_n_words: +// edge_data = "
".join( +// [ +// " ".join(words[i : i + wrap_label_n_words]) +// for i in range(0, len(words), wrap_label_n_words) +// ] +// ) +// if edge.conditional: +// edge_label = f" -. {edge_data} .-> " +// else: +// edge_label = f" -- {edge_data} --> " +// else: +// if edge.conditional: +// edge_label = " -.-> " +// else: +// edge_label = " --> " +// mermaid_graph += ( +// f"\t{_escape_node_label(source)}{edge_label}" +// f"{_escape_node_label(target)};\n" +// ) +// if subgraph: +// mermaid_graph += "end\n" + +// # Add custom styles for nodes +// if with_styles: +// mermaid_graph += _generate_mermaid_graph_styles(node_colors) +// return mermaid_graph + +/** + * Renders Mermaid graph using the Mermaid.INK API. + */ +// export async function drawMermaidPng( +// mermaidSyntax: string, +// config = { +// backgroundColor: "white", +// } +// ) { +// let encoder = new TextEncoder(); +// let data = encoder.encode(mermaidSyntax); +// let mermaidSyntaxEncoded = btoa(String.fromCharCode.apply(null, data)); +// } + +// try: +// import requests # type: ignore[import] +// except ImportError as e: +// raise ImportError( +// "Install the `requests` module to use the Mermaid.INK API: " +// "`pip install requests`." +// ) from e + +// # Use Mermaid API to render the image +// mermaid_syntax_encoded = base64.b64encode(mermaid_syntax.encode("utf8")).decode( +// "ascii" +// ) + +// # Check if the background color is a hexadecimal color code using regex +// if backgroundColor is not None: +// hex_color_pattern = re.compile(r"^#(?:[0-9a-fA-F]{3}){1,2}$") +// if not hex_color_pattern.match(backgroundColor): +// backgroundColor = f"!{backgroundColor}" + +// image_url = ( +// f"https://mermaid.ink/img/{mermaid_syntax_encoded}?bgColor={backgroundColor}" +// ) +// response = requests.get(image_url) +// if response.status_code == 200: +// img_bytes = response.content +// if output_file_path is not None: +// with open(output_file_path, "wb") as file: +// file.write(response.content) + +// return img_bytes +// else: +// raise ValueError( +// f"Failed to render the graph using the Mermaid.INK API. " +// f"Status code: {response.status_code}." +// ) diff --git a/langchain-core/src/runnables/tests/runnable_graph.test.ts b/langchain-core/src/runnables/tests/runnable_graph.test.ts index 352eacc618dd..fde659e6b076 100644 --- a/langchain-core/src/runnables/tests/runnable_graph.test.ts +++ b/langchain-core/src/runnables/tests/runnable_graph.test.ts @@ -87,4 +87,25 @@ test("Test graph sequence", async () => { { source: 2, target: 3 }, ], }); + expect(graph.drawMermaid()) + .toEqual(`%%{init: {'flowchart': {'curve': 'linear'}}}%% +graph TD; +\tPromptTemplateInput[PromptTemplateInput]:::startclass; +\tPromptTemplate([PromptTemplate]):::otherclass; +\tFakeLLM([FakeLLM]):::otherclass; +\tCommaSeparatedListOutputParser([CommaSeparatedListOutputParser]):::otherclass; +\tCommaSeparatedListOutputParserOutput[CommaSeparatedListOutputParserOutput]:::endclass; +\tend +\tPromptTemplateInput --> PromptTemplate; +\tend +\tPromptTemplate --> FakeLLM; +\tend +\tCommaSeparatedListOutputParser --> CommaSeparatedListOutputParserOutput; +\tend +\tFakeLLM --> CommaSeparatedListOutputParser; +end +\tclassDef startclass fill:#ffdfba; +\tclassDef endclass fill:#baffc9; +\tclassDef otherclass fill:#fad7de; +`); }); diff --git a/langchain-core/src/runnables/types.ts b/langchain-core/src/runnables/types.ts index 0e7e319ddd8e..eccc800864a7 100644 --- a/langchain-core/src/runnables/types.ts +++ b/langchain-core/src/runnables/types.ts @@ -61,3 +61,16 @@ export interface RunnableInterface< getName(suffix?: string): string; } + +export interface Edge { + source: string; + target: string; + data?: string; + conditional?: boolean; +} + +export interface Node { + id: string; + + data: RunnableIOSchema | RunnableInterface; +}