\n",
+ " \n",
+ "
\n",
+ " \n",
+ " import functools\n",
+ " from typing import Annotated, Any, Callable, Dict, List, Optional, Union\n",
+ "\n",
+ " from langchain_community.adapters.openai import convert_message_to_dict\n",
+ " from langchain_core.messages import AIMessage, AnyMessage, BaseMessage, HumanMessage\n",
+ " from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder\n",
+ " from langchain_core.runnables import Runnable, RunnableLambda\n",
+ " from langchain_core.runnables import chain as as_runnable\n",
+ " from langchain_openai import ChatOpenAI\n",
+ " from typing_extensions import TypedDict\n",
+ "\n",
+ " from langgraph.graph import END, StateGraph, START\n",
+ "\n",
+ "\n",
+ " def langchain_to_openai_messages(messages: List[BaseMessage]):\n",
+ " \"\"\"\n",
+ " Convert a list of langchain base messages to a list of openai messages.\n",
+ "\n",
+ " Parameters:\n",
+ " messages (List[BaseMessage]): A list of langchain base messages.\n",
+ "\n",
+ " Returns:\n",
+ " List[dict]: A list of openai messages.\n",
+ " \"\"\"\n",
+ "\n",
+ " return [\n",
+ " convert_message_to_dict(m) if isinstance(m, BaseMessage) else m\n",
+ " for m in messages\n",
+ " ]\n",
+ "\n",
+ "\n",
+ " def create_simulated_user(\n",
+ " system_prompt: str, llm: Runnable | None = None\n",
+ " ) -> Runnable[Dict, AIMessage]:\n",
+ " \"\"\"\n",
+ " Creates a simulated user for chatbot simulation.\n",
+ "\n",
+ " Args:\n",
+ " system_prompt (str): The system prompt to be used by the simulated user.\n",
+ " llm (Runnable | None, optional): The language model to be used for the simulation.\n",
+ " Defaults to gpt-3.5-turbo.\n",
+ "\n",
+ " Returns:\n",
+ " Runnable[Dict, AIMessage]: The simulated user for chatbot simulation.\n",
+ " \"\"\"\n",
+ " return ChatPromptTemplate.from_messages(\n",
+ " [\n",
+ " (\"system\", system_prompt),\n",
+ " MessagesPlaceholder(variable_name=\"messages\"),\n",
+ " ]\n",
+ " ) | (llm or ChatOpenAI(model=\"gpt-3.5-turbo\")).with_config(\n",
+ " run_name=\"simulated_user\"\n",
+ " )\n",
+ "\n",
+ "\n",
+ " Messages = Union[list[AnyMessage], AnyMessage]\n",
+ "\n",
+ "\n",
+ " def add_messages(left: Messages, right: Messages) -> Messages:\n",
+ " if not isinstance(left, list):\n",
+ " left = [left]\n",
+ " if not isinstance(right, list):\n",
+ " right = [right]\n",
+ " return left + right\n",
+ "\n",
+ "\n",
+ " class SimulationState(TypedDict):\n",
+ " \"\"\"\n",
+ " Represents the state of a simulation.\n",
+ "\n",
+ " Attributes:\n",
+ " messages (List[AnyMessage]): A list of messages in the simulation.\n",
+ " inputs (Optional[dict[str, Any]]): Optional inputs for the simulation.\n",
+ " \"\"\"\n",
+ "\n",
+ " messages: Annotated[List[AnyMessage], add_messages]\n",
+ " inputs: Optional[dict[str, Any]]\n",
+ "\n",
+ "\n",
+ " def create_chat_simulator(\n",
+ " assistant: (\n",
+ " Callable[[List[AnyMessage]], str | AIMessage]\n",
+ " | Runnable[List[AnyMessage], str | AIMessage]\n",
+ " ),\n",
+ " simulated_user: Runnable[Dict, AIMessage],\n",
+ " *,\n",
+ " input_key: str,\n",
+ " max_turns: int = 6,\n",
+ " should_continue: Optional[Callable[[SimulationState], str]] = None,\n",
+ " ):\n",
+ " \"\"\"Creates a chat simulator for evaluating a chatbot.\n",
+ "\n",
+ " Args:\n",
+ " assistant: The chatbot assistant function or runnable object.\n",
+ " simulated_user: The simulated user object.\n",
+ " input_key: The key for the input to the chat simulation.\n",
+ " max_turns: The maximum number of turns in the chat simulation. Default is 6.\n",
+ " should_continue: Optional function to determine if the simulation should continue.\n",
+ " If not provided, a default function will be used.\n",
+ "\n",
+ " Returns:\n",
+ " The compiled chat simulation graph.\n",
+ "\n",
+ " \"\"\"\n",
+ " graph_builder = StateGraph(SimulationState)\n",
+ " graph_builder.add_node(\n",
+ " \"user\",\n",
+ " _create_simulated_user_node(simulated_user),\n",
+ " )\n",
+ " graph_builder.add_node(\n",
+ " \"assistant\", _fetch_messages | assistant | _coerce_to_message\n",
+ " )\n",
+ " graph_builder.add_edge(\"assistant\", \"user\")\n",
+ " graph_builder.add_conditional_edges(\n",
+ " \"user\",\n",
+ " should_continue or functools.partial(_should_continue, max_turns=max_turns),\n",
+ " )\n",
+ " # If your dataset has a 'leading question/input', then we route first to the assistant, otherwise, we let the user take the lead.\n",
+ " graph_builder.add_edge(START, \"assistant\" if input_key is not None else \"user\")\n",
+ "\n",
+ " return (\n",
+ " RunnableLambda(_prepare_example).bind(input_key=input_key)\n",
+ " | graph_builder.compile()\n",
+ " )\n",
+ "\n",
+ "\n",
+ " ## Private methods\n",
+ "\n",
+ "\n",
+ " def _prepare_example(inputs: dict[str, Any], input_key: Optional[str] = None):\n",
+ " if input_key is not None:\n",
+ " if input_key not in inputs:\n",
+ " raise ValueError(\n",
+ " f\"Dataset's example input must contain the provided input key: '{input_key}'.\\nFound: {list(inputs.keys())}\"\n",
+ " )\n",
+ " messages = [HumanMessage(content=inputs[input_key])]\n",
+ " return {\n",
+ " \"inputs\": {k: v for k, v in inputs.items() if k != input_key},\n",
+ " \"messages\": messages,\n",
+ " }\n",
+ " return {\"inputs\": inputs, \"messages\": []}\n",
+ "\n",
+ "\n",
+ " def _invoke_simulated_user(state: SimulationState, simulated_user: Runnable):\n",
+ " \"\"\"Invoke the simulated user node.\"\"\"\n",
+ " runnable = (\n",
+ " simulated_user\n",
+ " if isinstance(simulated_user, Runnable)\n",
+ " else RunnableLambda(simulated_user)\n",
+ " )\n",
+ " inputs = state.get(\"inputs\", {})\n",
+ " inputs[\"messages\"] = state[\"messages\"]\n",
+ " return runnable.invoke(inputs)\n",
+ "\n",
+ "\n",
+ " def _swap_roles(state: SimulationState):\n",
+ " new_messages = []\n",
+ " for m in state[\"messages\"]:\n",
+ " if isinstance(m, AIMessage):\n",
+ " new_messages.append(HumanMessage(content=m.content))\n",
+ " else:\n",
+ " new_messages.append(AIMessage(content=m.content))\n",
+ " return {\n",
+ " \"inputs\": state.get(\"inputs\", {}),\n",
+ " \"messages\": new_messages,\n",
+ " }\n",
+ "\n",
+ "\n",
+ " @as_runnable\n",
+ " def _fetch_messages(state: SimulationState):\n",
+ " \"\"\"Invoke the simulated user node.\"\"\"\n",
+ " return state[\"messages\"]\n",
+ "\n",
+ "\n",
+ " def _convert_to_human_message(message: BaseMessage):\n",
+ " return {\"messages\": [HumanMessage(content=message.content)]}\n",
+ "\n",
+ "\n",
+ " def _create_simulated_user_node(simulated_user: Runnable):\n",
+ " \"\"\"Simulated user accepts a {\"messages\": [...]} argument and returns a single message.\"\"\"\n",
+ " return (\n",
+ " _swap_roles\n",
+ " | RunnableLambda(_invoke_simulated_user).bind(simulated_user=simulated_user)\n",
+ " | _convert_to_human_message\n",
+ " )\n",
+ "\n",
+ "\n",
+ " def _coerce_to_message(assistant_output: str | BaseMessage):\n",
+ " if isinstance(assistant_output, str):\n",
+ " return {\"messages\": [AIMessage(content=assistant_output)]}\n",
+ " else:\n",
+ " return {\"messages\": [assistant_output]}\n",
+ "\n",
+ "\n",
+ " def _should_continue(state: SimulationState, max_turns: int = 6):\n",
+ " messages = state[\"messages\"]\n",
+ " # TODO support other stop criteria\n",
+ " if len(messages) > max_turns:\n",
+ " return END\n",
+ " elif messages[-1].content.strip() == \"FINISHED\":\n",
+ " return END\n",
+ " else:\n",
+ " return \"assistant\"\n",
+ "\n",
+ "\n",
+ "
\n",
+ "