Skip to content

Commit

Permalink
some fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
lkuligin committed Feb 7, 2024
1 parent 2d4135d commit 11a1744
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,14 @@
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import root_validator
from vertexai.language_models import ( # type: ignore
ChatMessage,
ChatModel,
ChatSession,
CodeChatModel,
CodeChatSession,
InputOutputTextPair,
)
from vertexai.preview.generative_models import ( # type: ignore
Candidate,
Content,
Expand All @@ -38,12 +46,10 @@
Part,
)
from vertexai.preview.language_models import ( # type: ignore
ChatMessage,
ChatModel,
ChatSession,
CodeChatModel,
CodeChatSession,
InputOutputTextPair,
ChatModel as PreviewChatModel,
)
from vertexai.preview.language_models import (
CodeChatModel as PreviewCodeChatModel,
)

from langchain_google_vertexai._utils import (
Expand Down Expand Up @@ -322,10 +328,14 @@ def validate_environment(cls, values: Dict) -> Dict:
else:
if is_codey_model(values["model_name"]):
model_cls = CodeChatModel
model_cls_preview = PreviewCodeChatModel
else:
model_cls = ChatModel
model_cls_preview = PreviewChatModel
values["client"] = model_cls.from_pretrained(values["model_name"])
values["client_preview"] = model_cls.from_pretrained(values["model_name"])
values["client_preview"] = model_cls_preview.from_pretrained(
values["model_name"]
)
return values

def _generate(
Expand Down
15 changes: 12 additions & 3 deletions libs/partners/google-vertexai/langchain_google_vertexai/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@
Image,
)
from vertexai.preview.language_models import ( # type: ignore[import-untyped]
ChatModel as PreviewChatModel,
)
from vertexai.preview.language_models import (
CodeChatModel as PreviewCodeChatModel,
)
from vertexai.preview.language_models import (
CodeGenerationModel as PreviewCodeGenerationModel,
)
from vertexai.preview.language_models import (
Expand Down Expand Up @@ -250,10 +256,13 @@ def get_num_tokens(self, text: str) -> int:
Returns:
The integer number of tokens in the text.
"""
try:
is_palm_chat_model = isinstance(
self.client_preview, PreviewChatModel
) or isinstance(self.client_preview, PreviewCodeChatModel)
if is_palm_chat_model:
result = self.client_preview.start_chat().count_tokens(text)
else:
result = self.client_preview.count_tokens([text])
except AttributeError:
raise NotImplementedError(f"Not yet implemented for {self.model_name}")

return result.total_tokens

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,16 +222,13 @@ def test_chat_vertexai_system_message(model_name: Optional[str]) -> None:
assert isinstance(response.content, str)


def test_get_num_tokens_from_messages() -> None:
model = ChatVertexAI(model_name="gemini-pro", temperature=0.0)
@pytest.mark.parametrize("model_name", model_names_to_test)
def test_get_num_tokens_from_messages(model_name: str) -> None:
if model_name:
model = ChatVertexAI(model_name=model_name, temperature=0.0)
else:
model = ChatVertexAI(temperature=0.0)
message = HumanMessage(content="Hello")
token = model.get_num_tokens_from_messages(messages=[message])
assert isinstance(token, int)
assert token == 3

# test for exception with chat-bison
model = ChatVertexAI(model_name="chat-bison", temperature=0.0)
with pytest.raises(Exception):
token = model.get_num_tokens_from_messages(messages=[message])
assert isinstance(token, int)
assert token == 3
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,11 @@ def parse(self, text: str) -> Union[AgentAction, AgentFinish]:


def test_tools() -> None:
from langchain.agents import AgentExecutor # type: ignore[import-not-found]
from langchain.agents.format_scratchpad import ( # type: ignore[import-not-found]
from langchain.agents import AgentExecutor
from langchain.agents.format_scratchpad import (
format_to_openai_function_messages,
)
from langchain.chains import LLMMathChain # type: ignore[import-not-found]
from langchain.chains import LLMMathChain

llm = ChatVertexAI(model_name="gemini-pro")
math_chain = LLMMathChain.from_llm(llm=llm)
Expand Down Expand Up @@ -78,10 +78,9 @@ def test_tools() -> None:
| llm_with_tools
| _TestOutputParser()
)
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True) # type: ignore

response = agent_executor.invoke({"input": "What is 6 raised to the 0.43 power?"})
print(response)
assert isinstance(response, dict)
assert response["input"] == "What is 6 raised to the 0.43 power?"

Expand All @@ -106,7 +105,6 @@ def test_stream() -> None:
]
response = list(llm.stream("What is 6 raised to the 0.43 power?", functions=tools))
assert len(response) == 1
# for chunk in response:
assert isinstance(response[0], AIMessageChunk)
assert "function_call" in response[0].additional_kwargs

Expand All @@ -115,7 +113,7 @@ def test_multiple_tools() -> None:
from langchain.agents import AgentExecutor
from langchain.agents.format_scratchpad import format_to_openai_function_messages
from langchain.chains import LLMMathChain
from langchain.utilities import ( # type: ignore[import-not-found]
from langchain.utilities import (
GoogleSearchAPIWrapper,
)

Expand Down Expand Up @@ -160,7 +158,7 @@ def test_multiple_tools() -> None:
| llm_with_tools
| _TestOutputParser()
)
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True) # type: ignore

question = (
"Who is Leo DiCaprio's girlfriend? What is her "
Expand Down

0 comments on commit 11a1744

Please sign in to comment.