Skip to content

Commit

Permalink
refactored tools (#226)
Browse files · Browse the repository at this point in the history
  • Loading branch information
alx13 authored May 17, 2024
1 parent e872a4c commit 421d9ca
Show file tree
Hide file tree
Showing 7 changed files with 545 additions and 396 deletions.
39 changes: 25 additions & 14 deletions libs/vertexai/langchain_google_vertexai/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
from google.cloud.aiplatform_v1beta1.types import (
FunctionCallingConfig,
FunctionDeclaration,
Schema,
ToolConfig,
Type,
)

from langchain_google_vertexai._enums import HarmBlockThreshold, HarmCategory
from langchain_google_vertexai.chains import create_structured_runnable
from langchain_google_vertexai.chat_models import ChatVertexAI
from langchain_google_vertexai.embeddings import VertexAIEmbeddings
from langchain_google_vertexai.functions_utils import (
PydanticFunctionsOutputParser,
ToolConfig,
)
from langchain_google_vertexai.gemma import (
GemmaChatLocalHF,
Expand Down Expand Up @@ -33,28 +40,32 @@

# Public API of langchain_google_vertexai.
# NOTE(review): the previous list contained many duplicated entries
# (e.g. "GemmaVertexAIModelGarden", "VertexAI", "create_structured_runnable"
# appeared twice) — a merge/diff artifact. Deduplicated and kept in
# case-sensitive alphabetical order so future additions are easy to diff.
__all__ = [
    "ChatVertexAI",
    "DataStoreDocumentStorage",
    "FunctionCallingConfig",
    "FunctionDeclaration",
    "GCSDocumentStorage",
    "GemmaChatLocalHF",
    "GemmaChatLocalKaggle",
    "GemmaChatVertexAIModelGarden",
    "GemmaLocalHF",
    "GemmaLocalKaggle",
    "GemmaVertexAIModelGarden",
    "HarmBlockThreshold",
    "HarmCategory",
    "PydanticFunctionsOutputParser",
    "Schema",
    "ToolConfig",
    "Type",
    "VectorSearchVectorStore",
    "VectorSearchVectorStoreDatastore",
    "VectorSearchVectorStoreGCS",
    "VertexAI",
    "VertexAIEmbeddings",
    "VertexAIImageCaptioning",
    "VertexAIImageCaptioningChat",
    "VertexAIImageEditorChat",
    "VertexAIImageGeneratorChat",
    "VertexAIModelGarden",
    "VertexAIVisualQnAChat",
    "create_structured_runnable",
]
6 changes: 2 additions & 4 deletions libs/vertexai/langchain_google_vertexai/chains.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Union,
)

import google.cloud.aiplatform_v1beta1.types as gapic
from langchain_core.output_parsers import (
BaseGenerationOutputParser,
BaseOutputParser,
Expand All @@ -14,9 +15,6 @@
from langchain_core.prompts import BasePromptTemplate, ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import Runnable
from vertexai.generative_models._generative_models import ( # type: ignore
ToolConfig,
)

from langchain_google_vertexai.functions_utils import PydanticFunctionsOutputParser

Expand Down Expand Up @@ -59,7 +57,7 @@ def _create_structured_runnable_extra_step(
functions=functions,
tool_config={
"function_calling_config": {
"mode": ToolConfig.FunctionCallingConfig.Mode.ANY,
"mode": gapic.FunctionCallingConfig.Mode.ANY,
"allowed_function_names": names,
}
},
Expand Down
73 changes: 17 additions & 56 deletions libs/vertexai/langchain_google_vertexai/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,8 @@
_ToolConfigDict,
_tool_choice_to_tool_config,
_ToolChoiceType,
_FunctionDeclarationLike,
_VertexToolDict,
_format_to_vertex_tool,
_format_functions_to_vertex_tool_dict,
_ToolsType,
_format_to_gapic_tool,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -705,8 +703,8 @@ def _prepare_request_gemini(
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
stream: bool = False,
tools: Optional[List[Union[_VertexToolDict, VertexTool]]] = None,
functions: Optional[List[_FunctionDeclarationLike]] = None,
tools: Optional[_ToolsType] = None,
functions: Optional[_ToolsType] = None,
tool_config: Optional[Union[_ToolConfigDict, ToolConfig]] = None,
safety_settings: Optional[SafetySettingsType] = None,
**kwargs,
Expand Down Expand Up @@ -778,13 +776,18 @@ def get_num_tokens(self, text: str) -> int:

def _tools_gemini(
    self,
    tools: Optional[_ToolsType] = None,
    functions: Optional[_ToolsType] = None,
) -> Optional[Sequence[GapicTool]]:
    """Convert user-supplied tools/functions into a gapic ``Tool`` list.

    Args:
        tools: Tool-like objects to expose to the model; takes precedence.
        functions: Function declarations, consulted only when ``tools``
            is not given.

    Returns:
        A single-element list with the converted gapic tool, or ``None``
        when neither argument was supplied.
    """
    if tools and functions:
        # BUG FIX: the message was passed as two positional arguments;
        # logging treats the second one as a %-format argument, and with
        # no placeholder in the message the record fails to format.
        # Emit one complete message string instead.
        logger.warning(
            "Binding tools and functions together is not supported. "
            "Only tools will be used."
        )
    if tools:
        return [_format_to_gapic_tool(tools)]
    if functions:
        return [_format_to_gapic_tool(functions)]
    return None

def _tool_config_gemini(
Expand Down Expand Up @@ -1109,7 +1112,7 @@ class AnswerWithJustification(BaseModel):

def bind_tools(
self,
tools: Sequence[Union[_FunctionDeclarationLike, VertexTool]],
tools: _ToolsType,
tool_config: Optional[_ToolConfigDict] = None,
*,
tool_choice: Optional[Union[_ToolChoiceType, bool]] = None,
Expand All @@ -1132,25 +1135,12 @@ def bind_tools(
"Must specify at most one of tool_choice and tool_config, received "
f"both:\n\n{tool_choice=}\n\n{tool_config=}"
)
vertexai_tools: List[_VertexToolDict] = []
vertexai_functions = []
for schema in tools:
if isinstance(schema, VertexTool):
vertexai_tools.append(
{"function_declarations": schema.to_dict()["function_declarations"]}
)
elif isinstance(schema, dict) and "function_declarations" in schema:
vertexai_tools.append(cast(_VertexToolDict, schema))
else:
vertexai_functions.append(schema)
vertexai_tools.append(_format_functions_to_vertex_tool_dict(vertexai_functions))
vertexai_tool = _format_to_gapic_tool(tools)
if tool_choice:
all_names = [
f["name"] for vt in vertexai_tools for f in vt["function_declarations"]
]
all_names = [f["name"] for f in vertexai_tool.function_declarations]
tool_config = _tool_choice_to_tool_config(tool_choice, all_names)
# Bind dicts for easier serialization/deserialization.
return self.bind(tools=vertexai_tools, tool_config=tool_config, **kwargs)
return self.bind(tools=[vertexai_tool], tool_config=tool_config, **kwargs)

def _start_chat(
self, history: _ChatHistory, **kwargs: Any
Expand All @@ -1162,35 +1152,6 @@ def _start_chat(
else:
return self.client.start_chat(message_history=history.history, **kwargs)

def _gemini_params(
    self,
    *,
    stop: Optional[List[str]] = None,
    stream: bool = False,
    tools: Optional[List[Union[_VertexToolDict, VertexTool]]] = None,
    functions: Optional[List[_FunctionDeclarationLike]] = None,
    tool_config: Optional[Union[_ToolConfigDict, ToolConfig]] = None,
    safety_settings: Optional[SafetySettingsType] = None,
    **kwargs: Any,
) -> _GeminiGenerateContentKwargs:
    """Assemble the keyword arguments for a Gemini generate-content call.

    ``tools`` takes precedence over ``functions``; extra ``kwargs`` are
    forwarded to ``_prepare_params`` together with ``stop``/``stream``.
    """
    # Generation parameters (stop sequences, streaming flag, etc.).
    generation_config = self._prepare_params(stop=stop, stream=stream, **kwargs)

    # Normalize whichever tool spec was provided into Vertex tool objects;
    # explicit tools win over bare function declarations.
    if tools:
        tools = [_format_to_vertex_tool(t) for t in tools]
    elif functions:
        tools = [_format_to_vertex_tool(functions)]

    # A plain dict tool_config still needs converting to the proto type.
    if tool_config and not isinstance(tool_config, ToolConfig):
        tool_config = _format_tool_config(cast(_ToolConfigDict, tool_config))

    return _GeminiGenerateContentKwargs(
        generation_config=generation_config,
        tools=tools,
        tool_config=tool_config,
        safety_settings=safety_settings,
    )

def _gemini_response_to_chat_result(
self, response: GenerationResponse
) -> ChatResult:
Expand Down
Loading

0 comments on commit 421d9ca

Please sign in to comment.