From 477eb1745c1ad69d755330f508e7f8b46359e2d4 Mon Sep 17 00:00:00 2001
From: Nuno Campos
Date: Wed, 24 Apr 2024 12:32:52 -0700
Subject: [PATCH 1/2] Better support for subgraphs in graph viz (#20840)

---
 libs/core/langchain_core/runnables/base.py    | 11 ++----
 libs/core/langchain_core/runnables/graph.py   | 35 +++++++++++++++++--
 .../langchain_core/runnables/graph_mermaid.py | 23 ++++++++----
 3 files changed, 51 insertions(+), 18 deletions(-)

diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py
index ee48234df3a80..3484e8bba3c69 100644
--- a/libs/core/langchain_core/runnables/base.py
+++ b/libs/core/langchain_core/runnables/base.py
@@ -2409,8 +2409,7 @@ def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph:
                 step_graph.trim_first_node()
                 if step is not self.last:
                     step_graph.trim_last_node()
-                graph.extend(step_graph)
-                step_first_node = step_graph.first_node()
+                step_first_node, _ = graph.extend(step_graph)
                 if not step_first_node:
                     raise ValueError(f"Runnable {step} has no first node")
                 if current_last_node:
@@ -3082,11 +3081,9 @@ def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph:
             if not step_graph:
                 graph.add_edge(input_node, output_node)
             else:
-                graph.extend(step_graph)
-                step_first_node = step_graph.first_node()
+                step_first_node, step_last_node = graph.extend(step_graph)
                 if not step_first_node:
                     raise ValueError(f"Runnable {step} has no first node")
-                step_last_node = step_graph.last_node()
                 if not step_last_node:
                     raise ValueError(f"Runnable {step} has no last node")
                 graph.add_edge(input_node, step_first_node)
@@ -3779,11 +3776,9 @@ def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph:
             if not dep_graph:
                 graph.add_edge(input_node, output_node)
             else:
-                graph.extend(dep_graph)
-                dep_first_node = dep_graph.first_node()
+                dep_first_node, dep_last_node = graph.extend(dep_graph)
                 if not dep_first_node:
                     raise ValueError(f"Runnable {dep} has no first node")
-                dep_last_node = dep_graph.last_node()
                 if not dep_last_node:
                     raise ValueError(f"Runnable {dep} has no last node")
                 graph.add_edge(input_node, dep_first_node)
diff --git a/libs/core/langchain_core/runnables/graph.py b/libs/core/langchain_core/runnables/graph.py
index f92b4b064a271..4486fca525245 100644
--- a/libs/core/langchain_core/runnables/graph.py
+++ b/libs/core/langchain_core/runnables/graph.py
@@ -11,6 +11,7 @@
     List,
     NamedTuple,
     Optional,
+    Tuple,
     Type,
     TypedDict,
     Union,
@@ -236,11 +237,39 @@ def add_edge(
         self.edges.append(edge)
         return edge
 
-    def extend(self, graph: Graph) -> None:
+    def extend(
+        self, graph: Graph, *, prefix: str = ""
+    ) -> Tuple[Optional[Node], Optional[Node]]:
         """Add all nodes and edges from another graph.
         Note this doesn't check for duplicates, nor does it connect the graphs."""
-        self.nodes.update(graph.nodes)
-        self.edges.extend(graph.edges)
+        if all(is_uuid(node.id) for node in graph.nodes.values()):
+            prefix = ""
+
+        def prefixed(id: str) -> str:
+            return f"{prefix}:{id}" if prefix else id
+
+        # prefix each node
+        self.nodes.update(
+            {prefixed(k): Node(prefixed(k), v.data) for k, v in graph.nodes.items()}
+        )
+        # prefix each edge's source and target
+        self.edges.extend(
+            [
+                Edge(
+                    prefixed(edge.source),
+                    prefixed(edge.target),
+                    edge.data,
+                    edge.conditional,
+                )
+                for edge in graph.edges
+            ]
+        )
+        # return (prefixed) first and last nodes of the subgraph
+        first, last = graph.first_node(), graph.last_node()
+        return (
+            Node(prefixed(first.id), first.data) if first else None,
+            Node(prefixed(last.id), last.data) if last else None,
+        )
 
     def first_node(self) -> Optional[Node]:
         """Find the single node that is not a target of any edge.
diff --git a/libs/core/langchain_core/runnables/graph_mermaid.py b/libs/core/langchain_core/runnables/graph_mermaid.py
index 93e052d8919db..61b922ae68fbf 100644
--- a/libs/core/langchain_core/runnables/graph_mermaid.py
+++ b/libs/core/langchain_core/runnables/graph_mermaid.py
@@ -48,7 +48,7 @@ def draw_mermaid(
     if with_styles:
         # Node formatting templates
         default_class_label = "default"
-        format_dict = {default_class_label: "{0}([{0}]):::otherclass"}
+        format_dict = {default_class_label: "{0}([{1}]):::otherclass"}
         if first_node_label is not None:
             format_dict[first_node_label] = "{0}[{0}]:::startclass"
         if last_node_label is not None:
@@ -57,17 +57,24 @@ def draw_mermaid(
         # Add nodes to the graph
         for node in nodes.values():
             node_label = format_dict.get(node, format_dict[default_class_label]).format(
-                _escape_node_label(node)
+                _escape_node_label(node), _escape_node_label(node.split(":", 1)[-1])
             )
             mermaid_graph += f"\t{node_label};\n"
 
+    subgraph = ""
     # Add edges to the graph
     for edge in edges:
+        src_prefix = edge.source.split(":")[0]
+        tgt_prefix = edge.target.split(":")[0]
+        # exit subgraph if source or target is not in the same subgraph
+        if subgraph and (subgraph != src_prefix or subgraph != tgt_prefix):
+            mermaid_graph += "\tend\n"
+            subgraph = ""
+        # enter subgraph if source and target are in the same subgraph
+        if not subgraph and src_prefix and src_prefix == tgt_prefix:
+            mermaid_graph += f"\tsubgraph {src_prefix}\n"
+            subgraph = src_prefix
         adjusted_edge = _adjust_mermaid_edge(edge=edge, nodes=nodes)
-        if (
-            adjusted_edge is None
-        ):  # Ignore if it is connection between source and intermediate node
-            continue
 
         source, target = adjusted_edge
 
@@ -96,6 +103,8 @@
             f"\t{_escape_node_label(source)}{edge_label}"
             f"{_escape_node_label(target)};\n"
         )
+    if subgraph:
+        mermaid_graph += "end\n"
 
     # Add custom styles for nodes
     if with_styles:
@@ -111,7 +120,7 @@ def _escape_node_label(node_label: str) -> str:
 def _adjust_mermaid_edge(
     edge: Edge,
     nodes: Dict[str, str],
-) -> Optional[Tuple[str, str]]:
+) -> Tuple[str, str]:
     """Adjusts Mermaid edge to map conditional nodes to pure nodes."""
     source_node_label = nodes.get(edge.source, edge.source)
     target_node_label = nodes.get(edge.target, edge.target)

From 8c95ac3145600904ae717e451e5391e8da2e1351 Mon Sep 17 00:00:00 2001
From: Erick Friis
Date: Wed, 24 Apr 2024 12:34:57 -0700
Subject: [PATCH 2/2] docs, multiple: de-beta with_structured_output (#20850)

---
 .../model_io/chat/structured_output.ipynb     | 115 +++++++++--------
 .../langchain_core/language_models/base.py    |   3 +-
 .../langchain_anthropic/chat_models.py        |   1 -
 .../langchain_fireworks/chat_models.py        |   2 -
 .../groq/langchain_groq/chat_models.py        |   2 -
 .../langchain_mistralai/chat_models.py        |   2 -
 .../langchain_openai/chat_models/base.py      |   2 -
 7 files changed, 60 insertions(+), 67 deletions(-)

diff --git a/docs/docs/modules/model_io/chat/structured_output.ipynb b/docs/docs/modules/model_io/chat/structured_output.ipynb
index aa0c956ee1040..e76f9926438bb 100644
--- a/docs/docs/modules/model_io/chat/structured_output.ipynb
+++ b/docs/docs/modules/model_io/chat/structured_output.ipynb
@@ -15,7 +15,7 @@
    "id": "6e3f0f72",
    "metadata": {},
    "source": [
-    "# [beta] Structured Output\n",
+    "# Structured Output\n",
     "\n",
     "It is often crucial to have LLMs return structured output. This is because oftentimes the outputs of the LLMs are used in downstream applications, where specific arguments are required. Having the LLM return structured output reliably is necessary for that.\n",
@@ -39,21 +39,14 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 2,
-    "id": "08029f4e",
-    "metadata": {},
-    "outputs": [],
-    "source": [
-     "from langchain_core.pydantic_v1 import BaseModel, Field"
-    ]
-   },
-   {
-    "cell_type": "code",
-    "execution_count": 3,
+    "execution_count": 1,
     "id": "070bf702",
     "metadata": {},
     "outputs": [],
     "source": [
+     "from langchain_core.pydantic_v1 import BaseModel, Field\n",
+     "\n",
+     "\n",
      "class Joke(BaseModel):\n",
      "    setup: str = Field(description=\"The setup of the joke\")\n",
      "    punchline: str = Field(description=\"The punchline to the joke\")"
@@ -93,7 +86,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 9,
+    "execution_count": 3,
     "id": "6700994a",
     "metadata": {},
     "outputs": [],
@@ -104,17 +97,17 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 10,
+    "execution_count": 4,
     "id": "c55a61b8",
     "metadata": {},
     "outputs": [
      {
       "data": {
        "text/plain": [
-        "Joke(setup='Why was the cat sitting on the computer?', punchline='It wanted to keep an eye on the mouse!')"
+        "Joke(setup='Why was the cat sitting on the computer?', punchline='To keep an eye on the mouse!')"
       ]
      },
-     "execution_count": 10,
+     "execution_count": 4,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -135,7 +128,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 11,
+    "execution_count": 5,
     "id": "df0370e3",
     "metadata": {},
     "outputs": [],
@@ -145,17 +138,17 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 14,
+    "execution_count": 6,
    "id": "23844a26",
     "metadata": {},
     "outputs": [
      {
       "data": {
        "text/plain": [
-        "Joke(setup=\"Why don't cats play poker in the jungle?\", punchline='Too many cheetahs!')"
+        "Joke(setup='Why was the cat sitting on the computer?', punchline='Because it wanted to keep an eye on the mouse!')"
       ]
      },
-     "execution_count": 14,
+     "execution_count": 6,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -180,7 +173,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 1,
+    "execution_count": 7,
     "id": "ad45fdd8",
     "metadata": {},
     "outputs": [],
@@ -252,7 +245,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 12,
+    "execution_count": 11,
     "id": "649f9632",
     "metadata": {},
     "outputs": [
      {
       "data": {
        "text/plain": [
         "Joke(setup='Why did the dog sit in the shade?', punchline='To avoid getting burned.')"
        ]
       },
-     "execution_count": 12,
+     "execution_count": 11,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -287,7 +280,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 3,
+    "execution_count": 12,
     "id": "bffd3fad",
     "metadata": {},
     "outputs": [],
@@ -297,7 +290,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 11,
+    "execution_count": 13,
"id": "c8bd7549", "metadata": {}, "outputs": [], @@ -308,10 +301,21 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "id": "17b15816", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "Joke(setup=\"Why don't cats play poker in the jungle?\", punchline='Too many cheetahs!')" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "structured_llm.invoke(\"Tell me a joke about cats\")" ] @@ -328,7 +332,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 15, "id": "9b9617e3", "metadata": {}, "outputs": [], @@ -340,7 +344,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 16, "id": "90549664", "metadata": {}, "outputs": [], @@ -355,7 +359,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 17, "id": "01da39be", "metadata": {}, "outputs": [ @@ -365,7 +369,7 @@ "Joke(setup='Why did the cat sit on the computer?', punchline='To keep an eye on the mouse!')" ] }, - "execution_count": 25, + "execution_count": 17, "metadata": {}, "output_type": "execute_result" } @@ -388,7 +392,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 18, "id": "70511bc3", "metadata": {}, "outputs": [], @@ -408,19 +412,10 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 19, "id": "be9fdf04", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/reag/src/langchain/libs/core/langchain_core/_api/beta_decorator.py:87: LangChainBetaWarning: The function `with_structured_output` is in beta. It is actively being worked on, so the API may change.\n", - " warn_beta(\n" - ] - } - ], + "outputs": [], "source": [ "model = ChatGroq()\n", "structured_llm = model.with_structured_output(Joke)" @@ -428,7 +423,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 20, "id": "e13f4676", "metadata": {}, "outputs": [ @@ -438,7 +433,7 @@ "Joke(setup=\"Why don't cats play poker in the jungle?\", punchline='Too many cheetahs!')" ] }, - "execution_count": 7, + "execution_count": 20, "metadata": {}, "output_type": "execute_result" } @@ -459,7 +454,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 21, "id": "86574fb8", "metadata": {}, "outputs": [], @@ -469,7 +464,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 22, "id": "01dced9c", "metadata": {}, "outputs": [ @@ -479,7 +474,7 @@ "Joke(setup=\"Why don't cats play poker in the jungle?\", punchline='Too many cheetahs!')" ] }, - "execution_count": 9, + "execution_count": 22, "metadata": {}, "output_type": "execute_result" } @@ -504,7 +499,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 23, "id": "12682237-6689-4408-88b1-3595feac447f", "metadata": {}, "outputs": [ @@ -514,7 +509,7 @@ "Joke(setup='What do you call a cat that loves to bowl?', punchline='An alley cat!')" ] }, - "execution_count": 5, + "execution_count": 23, "metadata": {}, "output_type": "execute_result" } @@ -541,17 +536,17 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 2, "id": "24421189-02bf-4589-a91a-197584c4a696", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "Joke(setup='A cat-ch', punchline='What do you call a cat that loves to play fetch?')" + "Joke(setup='Why did the scarecrow win an award?', punchline='Why did the scarecrow win an award? 
Because he was outstanding in his field.')" ] }, - "execution_count": 7, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } @@ -563,13 +558,21 @@ "structured_llm = llm.with_structured_output(Joke)\n", "structured_llm.invoke(\"Tell me a joke about cats\")" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2630a2cb", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { "kernelspec": { - "display_name": "poetry-venv-2", + "display_name": ".venv", "language": "python", - "name": "poetry-venv-2" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -581,7 +584,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.1" + "version": "3.11.4" } }, "nbformat": 4, diff --git a/libs/core/langchain_core/language_models/base.py b/libs/core/langchain_core/language_models/base.py index a6addf8f6b0aa..a62809fbe9c43 100644 --- a/libs/core/langchain_core/language_models/base.py +++ b/libs/core/langchain_core/language_models/base.py @@ -19,7 +19,7 @@ from typing_extensions import TypeAlias -from langchain_core._api import beta, deprecated +from langchain_core._api import deprecated from langchain_core.messages import ( AnyMessage, BaseMessage, @@ -201,7 +201,6 @@ async def agenerate_prompt( prompt and additional model provider-specific output. """ - @beta() def with_structured_output( self, schema: Union[Dict, Type[BaseModel]], **kwargs: Any ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]: diff --git a/libs/partners/anthropic/langchain_anthropic/chat_models.py b/libs/partners/anthropic/langchain_anthropic/chat_models.py index 5b2c715029afb..29a75ba9bdd42 100644 --- a/libs/partners/anthropic/langchain_anthropic/chat_models.py +++ b/libs/partners/anthropic/langchain_anthropic/chat_models.py @@ -580,7 +580,6 @@ class GetWeather(BaseModel): formatted_tools = [convert_to_anthropic_tool(tool) for tool in tools] return self.bind(tools=formatted_tools, **kwargs) - @beta() def with_structured_output( self, schema: Union[Dict, Type[BaseModel]], diff --git a/libs/partners/fireworks/langchain_fireworks/chat_models.py b/libs/partners/fireworks/langchain_fireworks/chat_models.py index fc5960eea98b5..52dea080894bf 100644 --- a/libs/partners/fireworks/langchain_fireworks/chat_models.py +++ b/libs/partners/fireworks/langchain_fireworks/chat_models.py @@ -24,7 +24,6 @@ ) from fireworks.client import AsyncFireworks, Fireworks # type: ignore -from langchain_core._api import beta from langchain_core.callbacks import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, @@ -671,7 +670,6 @@ def bind_tools( kwargs["tool_choice"] = tool_choice return super().bind(tools=formatted_tools, **kwargs) - @beta() def with_structured_output( self, schema: Optional[Union[Dict, Type[BaseModel]]] = None, diff --git a/libs/partners/groq/langchain_groq/chat_models.py b/libs/partners/groq/langchain_groq/chat_models.py index 2650f7608c21d..86db80e8e22a7 100644 --- a/libs/partners/groq/langchain_groq/chat_models.py +++ b/libs/partners/groq/langchain_groq/chat_models.py @@ -23,7 +23,6 @@ cast, ) -from langchain_core._api import beta from langchain_core.callbacks import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, @@ -595,7 +594,6 @@ def bind_tools( kwargs["tool_choice"] = tool_choice return super().bind(tools=formatted_tools, **kwargs) - @beta() def with_structured_output( self, schema: Optional[Union[Dict, Type[BaseModel]]] = None, diff --git a/libs/partners/mistralai/langchain_mistralai/chat_models.py 
b/libs/partners/mistralai/langchain_mistralai/chat_models.py index ab3027c94d969..a00943d2c00a3 100644 --- a/libs/partners/mistralai/langchain_mistralai/chat_models.py +++ b/libs/partners/mistralai/langchain_mistralai/chat_models.py @@ -22,7 +22,6 @@ import httpx from httpx_sse import EventSource, aconnect_sse, connect_sse -from langchain_core._api import beta from langchain_core.callbacks import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, @@ -588,7 +587,6 @@ def bind_tools( formatted_tools = [convert_to_openai_tool(tool) for tool in tools] return super().bind(tools=formatted_tools, **kwargs) - @beta() def with_structured_output( self, schema: Union[Dict, Type[BaseModel]], diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index 915557baa567b..e68377c2a85d2 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -29,7 +29,6 @@ import openai import tiktoken -from langchain_core._api import beta from langchain_core.callbacks import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, @@ -885,7 +884,6 @@ def with_structured_output( ) -> Runnable[LanguageModelInput, _DictOrPydantic]: ... - @beta() def with_structured_output( self, schema: Optional[_DictOrPydanticClass] = None,
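A quick sketch of the new API from PATCH 1/2. This is not part of the patches themselves: it assumes the Graph, Node, and Edge types from langchain_core.runnables.graph as patched above, and the node ids and the "child" prefix are made up for illustration.

    from langchain_core.runnables import RunnableLambda
    from langchain_core.runnables.graph import Graph

    # Build a small subgraph with two nodes and one edge.
    child = Graph()
    a = child.add_node(RunnableLambda(lambda x: x), id="a")
    b = child.add_node(RunnableLambda(lambda x: x), id="b")
    child.add_edge(a, b)

    # extend() now namespaces the subgraph's ids as "child:a"/"child:b"
    # (unless every id is a UUID) and returns the prefixed first and last
    # nodes directly, which is what the updated get_graph() call sites in
    # base.py use instead of re-deriving them via first_node()/last_node().
    parent = Graph()
    first, last = parent.extend(child, prefix="child")
    assert first is not None and last is not None
    print(first.id, last.id)  # -> child:a child:b

    # Because consecutive edges now share the "child" prefix, the patched
    # draw_mermaid() wraps their nodes in a `subgraph child ... end` block.
    print(parent.draw_mermaid())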
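And for PATCH 2/2, a hedged sketch of calling with_structured_output now that the @beta() decorators are gone, mirroring the Groq cells in the notebook above (assumes the langchain-groq package is installed and GROQ_API_KEY is set; the Joke schema is the one defined in the notebook):

    from langchain_core.pydantic_v1 import BaseModel, Field
    from langchain_groq import ChatGroq


    class Joke(BaseModel):
        setup: str = Field(description="The setup of the joke")
        punchline: str = Field(description="The punchline to the joke")


    # No LangChainBetaWarning is emitted any more -- the notebook diff above
    # removes exactly that stderr block from the Groq cell.
    model = ChatGroq()
    structured_llm = model.with_structured_output(Joke)
    print(structured_llm.invoke("Tell me a joke about cats"))
    # e.g. Joke(setup="Why don't cats play poker in the jungle?", punchline='Too many cheetahs!')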