Skip to content

Commit

Permalink
feat: Switch the llm to ChatVertexAI (#486)
Browse files Browse the repository at this point in the history
Co-authored-by: Yuan <[email protected]>
  • Loading branch information
anubhav756 and Yuan325 authored Dec 3, 2024
1 parent aab0b9c commit 479c5e5
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.tools import StructuredTool
from langchain_google_vertexai import VertexAI
from langchain_google_vertexai import ChatVertexAI
from pytz import timezone

from ..orchestrator import BaseOrchestrator, classproperty
Expand Down Expand Up @@ -69,7 +69,8 @@ def initialize_agent(
prompt: ChatPromptTemplate,
model: str,
) -> "UserAgent":
llm = VertexAI(max_output_tokens=512, model_name=model, temperature=0.0)
# TODO: Use .bind_tools(tools) to bind the tools with the LLM.
llm = ChatVertexAI(max_output_tokens=512, model_name=model, temperature=0.0)
memory = ConversationBufferMemory(
chat_memory=ChatMessageHistory(messages=history),
memory_key="chat_history",
Expand Down
6 changes: 4 additions & 2 deletions llm_demo/orchestrator/langgraph/langgraph_orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,9 @@ def user_session_reset(self, session: dict[str, Any], uuid: str):
# Reset graph checkpointer
checkpoint = empty_checkpoint()
config = self.get_config(uuid)
self._checkpointer.put(config=config, checkpoint=checkpoint, metadata={})
self._checkpointer.put(
config=config, checkpoint=checkpoint, metadata={}, new_versions={}
)

# Update state with message history
self._langgraph_app.update_state(config, {"messages": history})
Expand Down Expand Up @@ -242,7 +244,7 @@ def get_base_history(self, session: dict[str, Any]):
return BASE_HISTORY

def get_config(self, uuid: str):
return {"configurable": {"thread_id": uuid}}
return {"configurable": {"thread_id": uuid, "checkpoint_ns": ""}}

async def user_session_signout(self, uuid: str):
checkpoint = empty_checkpoint()
Expand Down
54 changes: 32 additions & 22 deletions llm_demo/orchestrator/langgraph/react_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
)
from langchain_core.prompts.chat import ChatPromptTemplate
from langchain_core.runnables import RunnableConfig, RunnableLambda
from langchain_google_vertexai import VertexAI
from langchain_google_vertexai import ChatVertexAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages
Expand Down Expand Up @@ -85,7 +85,8 @@ async def create_graph(
tool_node = ToolNode(tools)

# model node
model = VertexAI(max_output_tokens=512, model_name=model_name, temperature=0.0)
# TODO: Use .bind_tools(tools) to bind the tools with the LLM.
model = ChatVertexAI(max_output_tokens=512, model_name=model_name, temperature=0.0)

# Add the prompt to the model to create a model runnable
model_runnable = prompt | model
Expand All @@ -97,35 +98,44 @@ async def acall_model(state: UserState, config: RunnableConfig):
"""
messages = state["messages"]
res = await model_runnable.ainvoke({"messages": messages}, config)
response = res.replace("```json", "").replace("```", "")
try:
json_response = json.loads(response)
action = json_response.get("action")
action_input = json_response.get("action_input")
if action == "Final Answer":
new_message = AIMessage(content=action_input)
else:
new_message = AIMessage(
content="suggesting a tool call",
tool_calls=[
ToolCall(id=str(uuid.uuid4()), name=action, args=action_input)
],

# TODO: Remove the temporary fix of parsing LLM response and invoking
# tools until we use bind_tools API and have automatic response parsing
# and tool calling. (see
# https://langchain-ai.github.io/langgraph/#example)
if "```json" in res.content:
try:
response = str(res.content).replace("```json", "").replace("```", "")
json_response = json.loads(response)
action = json_response.get("action")
action_input = json_response.get("action_input")
if action == "Final Answer":
res = AIMessage(content=action_input)
else:
res = AIMessage(
content="suggesting a tool call",
tool_calls=[
ToolCall(
id=str(uuid.uuid4()), name=action, args=action_input
)
],
)
except Exception as e:
json_response = response
res = AIMessage(
content="Sorry, failed to generate the right format for response"
)
except Exception as e:
json_response = response
new_message = AIMessage(
content="Sorry, failed to generate the right format for response"
)

# if model exceed the number of steps and has not yet return a final answer
if state["is_last_step"] and hasattr(new_message, "tool_calls"):
if state["is_last_step"] and hasattr(res, "tool_calls"):
return {
"messages": [
AIMessage(
content="Sorry, need more steps to process this request.",
)
]
}
return {"messages": [new_message]}
return {"messages": [res]}

def agent_should_continue(
state: UserState,
Expand Down
3 changes: 2 additions & 1 deletion llm_demo/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ itsdangerous==2.2.0
jinja2==3.1.4
langchain-community==0.3.2
langchain==0.3.7
langchain-google-vertexai==2.0.4
langchain-google-vertexai==2.0.7
markdown==3.7
types-Markdown==3.7.0.20240822
uvicorn[standard]==0.31.0
Expand All @@ -16,3 +16,4 @@ langgraph==0.2.48
httpx==0.27.2
pandas-stubs==2.2.2.240807
pandas==2.2.3
pydantic==2.9.0
2 changes: 1 addition & 1 deletion retrieval_service/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ google-cloud-aiplatform==1.72.0
google-cloud-spanner==3.49.1
langchain-core==0.3.18
pgvector==0.3.5
pydantic==2.9.2
pydantic==2.9.0
uvicorn[standard]==0.31.0
cloud-sql-python-connector==1.12.1
google-cloud-alloydb-connector[asyncpg]==1.4.0
Expand Down

0 comments on commit 479c5e5

Please sign in to comment.