diff --git a/llm_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py b/llm_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py index f1b601ad..f1ee7679 100644 --- a/llm_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py +++ b/llm_demo/orchestrator/langchain_tools/langchain_tools_orchestrator.py @@ -28,7 +28,7 @@ from langchain_core.messages import AIMessage, BaseMessage, HumanMessage from langchain_core.prompts import ChatPromptTemplate from langchain_core.tools import StructuredTool -from langchain_google_vertexai import VertexAI +from langchain_google_vertexai import ChatVertexAI from pytz import timezone from ..orchestrator import BaseOrchestrator, classproperty @@ -69,7 +69,8 @@ def initialize_agent( prompt: ChatPromptTemplate, model: str, ) -> "UserAgent": - llm = VertexAI(max_output_tokens=512, model_name=model, temperature=0.0) + # TODO: Use .bind_tools(tools) to bind the tools with the LLM. + llm = ChatVertexAI(max_output_tokens=512, model_name=model, temperature=0.0) memory = ConversationBufferMemory( chat_memory=ChatMessageHistory(messages=history), memory_key="chat_history", diff --git a/llm_demo/orchestrator/langgraph/langgraph_orchestrator.py b/llm_demo/orchestrator/langgraph/langgraph_orchestrator.py index 6a2a71a7..4b2d7a19 100644 --- a/llm_demo/orchestrator/langgraph/langgraph_orchestrator.py +++ b/llm_demo/orchestrator/langgraph/langgraph_orchestrator.py @@ -160,7 +160,9 @@ def user_session_reset(self, session: dict[str, Any], uuid: str): # Reset graph checkpointer checkpoint = empty_checkpoint() config = self.get_config(uuid) - self._checkpointer.put(config=config, checkpoint=checkpoint, metadata={}) + self._checkpointer.put( + config=config, checkpoint=checkpoint, metadata={}, new_versions={} + ) # Update state with message history self._langgraph_app.update_state(config, {"messages": history}) @@ -242,7 +244,7 @@ def get_base_history(self, session: dict[str, Any]): return BASE_HISTORY def get_config(self, uuid: str): - return {"configurable": {"thread_id": uuid}} + return {"configurable": {"thread_id": uuid, "checkpoint_ns": ""}} async def user_session_signout(self, uuid: str): checkpoint = empty_checkpoint() diff --git a/llm_demo/orchestrator/langgraph/react_graph.py b/llm_demo/orchestrator/langgraph/react_graph.py index cd3dc44f..b05134f1 100644 --- a/llm_demo/orchestrator/langgraph/react_graph.py +++ b/llm_demo/orchestrator/langgraph/react_graph.py @@ -26,7 +26,7 @@ ) from langchain_core.prompts.chat import ChatPromptTemplate from langchain_core.runnables import RunnableConfig, RunnableLambda -from langchain_google_vertexai import VertexAI +from langchain_google_vertexai import ChatVertexAI from langgraph.checkpoint.memory import MemorySaver from langgraph.graph import END, StateGraph from langgraph.graph.message import add_messages @@ -85,7 +85,8 @@ async def create_graph( tool_node = ToolNode(tools) # model node - model = VertexAI(max_output_tokens=512, model_name=model_name, temperature=0.0) + # TODO: Use .bind_tools(tools) to bind the tools with the LLM. + model = ChatVertexAI(max_output_tokens=512, model_name=model_name, temperature=0.0) # Add the prompt to the model to create a model runnable model_runnable = prompt | model @@ -97,27 +98,36 @@ async def acall_model(state: UserState, config: RunnableConfig): """ messages = state["messages"] res = await model_runnable.ainvoke({"messages": messages}, config) - response = res.replace("```json", "").replace("```", "") - try: - json_response = json.loads(response) - action = json_response.get("action") - action_input = json_response.get("action_input") - if action == "Final Answer": - new_message = AIMessage(content=action_input) - else: - new_message = AIMessage( - content="suggesting a tool call", - tool_calls=[ - ToolCall(id=str(uuid.uuid4()), name=action, args=action_input) - ], + + # TODO: Remove the temporary fix of parsing LLM response and invoking + # tools until we use bind_tools API and have automatic response parsing + # and tool calling. (see + # https://langchain-ai.github.io/langgraph/#example) + if "```json" in res.content: + try: + response = str(res.content).replace("```json", "").replace("```", "") + json_response = json.loads(response) + action = json_response.get("action") + action_input = json_response.get("action_input") + if action == "Final Answer": + res = AIMessage(content=action_input) + else: + res = AIMessage( + content="suggesting a tool call", + tool_calls=[ + ToolCall( + id=str(uuid.uuid4()), name=action, args=action_input + ) + ], + ) + except Exception as e: + json_response = response + res = AIMessage( + content="Sorry, failed to generate the right format for response" ) - except Exception as e: - json_response = response - new_message = AIMessage( - content="Sorry, failed to generate the right format for response" - ) + # if model exceed the number of steps and has not yet return a final answer - if state["is_last_step"] and hasattr(new_message, "tool_calls"): + if state["is_last_step"] and hasattr(res, "tool_calls"): return { "messages": [ AIMessage( @@ -125,7 +135,7 @@ async def acall_model(state: UserState, config: RunnableConfig): ) ] } - return {"messages": [new_message]} + return {"messages": [res]} def agent_should_continue( state: UserState, diff --git a/llm_demo/requirements.txt b/llm_demo/requirements.txt index f0c63c2b..72550b2b 100644 --- a/llm_demo/requirements.txt +++ b/llm_demo/requirements.txt @@ -5,7 +5,7 @@ itsdangerous==2.2.0 jinja2==3.1.4 langchain-community==0.3.2 langchain==0.3.7 -langchain-google-vertexai==2.0.4 +langchain-google-vertexai==2.0.7 markdown==3.7 types-Markdown==3.7.0.20240822 uvicorn[standard]==0.31.0 @@ -16,3 +16,4 @@ langgraph==0.2.48 httpx==0.27.2 pandas-stubs==2.2.2.240807 pandas==2.2.3 +pydantic==2.9.0 \ No newline at end of file diff --git a/retrieval_service/requirements.txt b/retrieval_service/requirements.txt index 33a7e640..07b163d0 100644 --- a/retrieval_service/requirements.txt +++ b/retrieval_service/requirements.txt @@ -6,7 +6,7 @@ google-cloud-aiplatform==1.72.0 google-cloud-spanner==3.49.1 langchain-core==0.3.18 pgvector==0.3.5 -pydantic==2.9.2 +pydantic==2.9.0 uvicorn[standard]==0.31.0 cloud-sql-python-connector==1.12.1 google-cloud-alloydb-connector[asyncpg]==1.4.0