diff --git a/libs/langchain/langchain/chains/router/multi_prompt.py b/libs/langchain/langchain/chains/router/multi_prompt.py index 214b9a2b37208..0531cdb834db3 100644 --- a/libs/langchain/langchain/chains/router/multi_prompt.py +++ b/libs/langchain/langchain/chains/router/multi_prompt.py @@ -20,9 +20,8 @@ since="0.2.12", removal="1.0", message=( - "Use RunnableLambda to select from multiple prompt templates. See example " - "in API reference: " - "https://api.python.langchain.com/en/latest/chains/langchain.chains.router.multi_prompt.MultiPromptChain.html" # noqa: E501 + "Please see migration guide here for recommended implementation: " + "https://python.langchain.com/docs/versions/migrating_chains/multi_prompt_chain/" # noqa: E501 ), ) class MultiPromptChain(MultiRouteChain): @@ -37,60 +36,109 @@ class MultiPromptChain(MultiRouteChain): from operator import itemgetter from typing import Literal - from typing_extensions import TypedDict from langchain_core.output_parsers import StrOutputParser from langchain_core.prompts import ChatPromptTemplate - from langchain_core.runnables import RunnableLambda, RunnablePassthrough + from langchain_core.runnables import RunnableConfig from langchain_openai import ChatOpenAI + from langgraph.graph import END, START, StateGraph + from typing_extensions import TypedDict llm = ChatOpenAI(model="gpt-4o-mini") + # Define the prompts we will route to prompt_1 = ChatPromptTemplate.from_messages( [ ("system", "You are an expert on animals."), - ("human", "{query}"), + ("human", "{input}"), ] ) prompt_2 = ChatPromptTemplate.from_messages( [ ("system", "You are an expert on vegetables."), - ("human", "{query}"), + ("human", "{input}"), ] ) + # Construct the chains we will route to. These format the input query + # into the respective prompt, run it through a chat model, and cast + # the result to a string. chain_1 = prompt_1 | llm | StrOutputParser() chain_2 = prompt_2 | llm | StrOutputParser() + + # Next: define the chain that selects which branch to route to. + # Here we will take advantage of tool-calling features to force + # the output to select one of two desired branches. route_system = "Route the user's query to either the animal or vegetable expert." route_prompt = ChatPromptTemplate.from_messages( [ ("system", route_system), - ("human", "{query}"), + ("human", "{input}"), ] ) + # Define schema for output: class RouteQuery(TypedDict): - \"\"\"Route query to destination.\"\"\" + \"\"\"Route query to destination expert.\"\"\" + destination: Literal["animal", "vegetable"] - route_chain = ( - route_prompt - | llm.with_structured_output(RouteQuery) - | itemgetter("destination") - ) + route_chain = route_prompt | llm.with_structured_output(RouteQuery) - chain = { - "destination": route_chain, # "animal" or "vegetable" - "query": lambda x: x["query"], # pass through input query - } | RunnableLambda( - # if animal, chain_1. otherwise, chain_2. - lambda x: chain_1 if x["destination"] == "animal" else chain_2, - ) - chain.invoke({"query": "what color are carrots"}) + # For LangGraph, we will define the state of the graph to hold the query, + # destination, and final answer. + class State(TypedDict): + query: str + destination: RouteQuery + answer: str + + + # We define functions for each node, including routing the query: + async def route_query(state: State, config: RunnableConfig): + destination = await route_chain.ainvoke(state["query"], config) + return {"destination": destination} + + + # And one node for each prompt + async def prompt_1(state: State, config: RunnableConfig): + return {"answer": await chain_1.ainvoke(state["query"], config)} + + + async def prompt_2(state: State, config: RunnableConfig): + return {"answer": await chain_2.ainvoke(state["query"], config)} + + + # We then define logic that selects the prompt based on the classification + def select_node(state: State) -> Literal["prompt_1", "prompt_2"]: + if state["destination"] == "animal": + return "prompt_1" + else: + return "prompt_2" + + + # Finally, assemble the multi-prompt chain. This is a sequence of two steps: + # 1) Select "animal" or "vegetable" via the route_chain, and collect the answer + # alongside the input query. + # 2) Route the input query to chain_1 or chain_2, based on the + # selection. + graph = StateGraph(State) + graph.add_node("route_query", route_query) + graph.add_node("prompt_1", prompt_1) + graph.add_node("prompt_2", prompt_2) + + graph.add_edge(START, "route_query") + graph.add_conditional_edges("route_query", select_node) + graph.add_edge("prompt_1", END) + graph.add_edge("prompt_2", END) + app = graph.compile() + + result = await app.ainvoke({"query": "what color are carrots"}) + print(result["destination"]) + print(result["answer"]) """ # noqa: E501 @property