From 84da383c747ca345d91a4b70a301510a7c6df79d Mon Sep 17 00:00:00 2001 From: renathossain <114009484+renathossain@users.noreply.github.com> Date: Sat, 14 Dec 2024 00:50:06 -0500 Subject: [PATCH] feat(google-common): Grounding with Google Search and Vertex AI Search (#7280) Co-authored-by: William <73299741+williamc99@users.noreply.github.com> Co-authored-by: jacoblee93 --- .../integrations/chat/google_vertex_ai.ipynb | 141 +++++++++++++++++- libs/langchain-google-common/src/types.ts | 16 ++ .../src/utils/common.ts | 61 ++++---- .../src/utils/gemini.ts | 37 ++--- .../src/tests/chat_models.int.test.ts | 47 ++++++ 5 files changed, 253 insertions(+), 49 deletions(-) diff --git a/docs/core_docs/docs/integrations/chat/google_vertex_ai.ipynb b/docs/core_docs/docs/integrations/chat/google_vertex_ai.ipynb index d4de68c3f5e2..8d046defbac1 100644 --- a/docs/core_docs/docs/integrations/chat/google_vertex_ai.ipynb +++ b/docs/core_docs/docs/integrations/chat/google_vertex_ai.ipynb @@ -21,8 +21,8 @@ "source": [ "# ChatVertexAI\n", "\n", - "[Google Vertex](https://cloud.google.com/vertex-ai) is a service that exposes all foundation models available in Google Cloud, like `gemini-1.5-pro`, `gemini-1.5-flash`, etc.", - "It also provides some non-Google models such as [Anthropic's Claude](https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-claude).", + "[Google Vertex](https://cloud.google.com/vertex-ai) is a service that exposes all foundation models available in Google Cloud, like `gemini-1.5-pro`, `gemini-2.0-flash-exp`, etc.\n", + "It also provides some non-Google models such as [Anthropic's Claude](https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-claude).\n", "\n", "\n", "This will help you getting started with `ChatVertexAI` [chat models](/docs/concepts/chat_models). For detailed documentation of all `ChatVertexAI` features and configurations head to the [API reference](https://api.js.langchain.com/classes/langchain_google_vertexai.ChatVertexAI.html).\n", @@ -116,7 +116,7 @@ "// import { ChatVertexAI } from \"@langchain/google-vertexai-web\"\n", "\n", "const llm = new ChatVertexAI({\n", - " model: \"gemini-1.5-pro\",\n", + " model: \"gemini-2.0-flash-exp\",\n", " temperature: 0,\n", " maxRetries: 2,\n", " // For web, authOptions.credentials\n", @@ -191,6 +191,141 @@ "console.log(aiMsg.content)" ] }, + { + "cell_type": "markdown", + "id": "de2480fa", + "metadata": {}, + "source": [ + "## Tool Calling with Google Search Retrieval\n", + "\n", + "It is possible to call the model with a Google search tool which you can use to [ground](https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/grounding) content generation with real-world information and reduce hallucinations.\n", + "\n", + "Grounding is currently not supported by `gemini-2.0-flash-exp`.\n", + "\n", + "You can choose to either ground using Google Search or by using a custom data store. Here are examples of both: " + ] + }, + { + "cell_type": "markdown", + "id": "fd2091ba", + "metadata": {}, + "source": [ + "### Google Search Retrieval\n", + "\n", + "Grounding example that uses Google Search:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "65d019ee", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The Boston Celtics won the 2024 NBA Finals, defeating the Dallas Mavericks 4-1 in the series to claim their 18th NBA championship. This victory marked their first title since 2008 and established them as the team with the most NBA championships, surpassing the Los Angeles Lakers' 17 titles.\n", + "\n" + ] + } + ], + "source": [ + "import { ChatVertexAI } from \"@langchain/google-vertexai\"\n", + "\n", + "const searchRetrievalTool = {\n", + " googleSearchRetrieval: {\n", + " dynamicRetrievalConfig: {\n", + " mode: \"MODE_DYNAMIC\", // Use Dynamic Retrieval\n", + " dynamicThreshold: 0.7, // Default for Dynamic Retrieval threshold\n", + " },\n", + " },\n", + "};\n", + "\n", + "const searchRetrievalModel = new ChatVertexAI({\n", + " model: \"gemini-1.5-pro\",\n", + " temperature: 0,\n", + " maxRetries: 0,\n", + "}).bindTools([searchRetrievalTool]);\n", + "\n", + "const searchRetrievalResult = await searchRetrievalModel.invoke(\"Who won the 2024 NBA Finals?\");\n", + "\n", + "console.log(searchRetrievalResult.content);" + ] + }, + { + "cell_type": "markdown", + "id": "ac3a4a98", + "metadata": {}, + "source": [ + "### Google Search Retrieval with Data Store\n", + "\n", + "First, set up your data store (this is a schema of an example data store):\n", + "\n", + "| ID | Date | Team 1 | Score | Team 2 |\n", + "|:-------:|:------------:|:-----------:|:--------:|:----------:|\n", + "| 3001 | 2023-09-07 | Argentina | 1 - 0 | Ecuador |\n", + "| 3002 | 2023-09-12 | Venezuela | 1 - 0 | Paraguay |\n", + "| 3003 | 2023-09-12 | Chile | 0 - 0 | Colombia |\n", + "| 3004 | 2023-09-12 | Peru | 0 - 1 | Brazil |\n", + "| 3005 | 2024-10-15 | Argentina | 6 - 0 | Bolivia |\n", + "\n", + "Then, use this data store in the example provided below:\n", + "\n", + "(Note that you have to use your own variables for `projectId` and `datastoreId`)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a6a539d9", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Argentina won against Bolivia with a score of 6-0 on October 15, 2024.\n", + "\n" + ] + } + ], + "source": [ + "import { ChatVertexAI } from \"@langchain/google-vertexai\";\n", + "\n", + "const projectId = \"YOUR_PROJECT_ID\";\n", + "const datastoreId = \"YOUR_DATASTORE_ID\";\n", + "\n", + "const searchRetrievalToolWithDataset = {\n", + " retrieval: {\n", + " vertexAiSearch: {\n", + " datastore: `projects/${projectId}/locations/global/collections/default_collection/dataStores/${datastoreId}`,\n", + " },\n", + " disableAttribution: false,\n", + " },\n", + "};\n", + "\n", + "const searchRetrievalModelWithDataset = new ChatVertexAI({\n", + " model: \"gemini-1.5-pro\",\n", + " temperature: 0,\n", + " maxRetries: 0,\n", + "}).bindTools([searchRetrievalToolWithDataset]);\n", + "\n", + "const searchRetrievalModelResult = await searchRetrievalModelWithDataset.invoke(\n", + " \"What is the score of Argentina vs Bolivia football game?\"\n", + ");\n", + "\n", + "console.log(searchRetrievalModelResult.content);" + ] + }, + { + "cell_type": "markdown", + "id": "8d11f2be", + "metadata": {}, + "source": [ + "You should now get results that are grounded in the data from your provided data store." + ] + }, { "cell_type": "markdown", "id": "18e2bfc0-7e78-4528-a73f-499ac150dca8", diff --git a/libs/langchain-google-common/src/types.ts b/libs/langchain-google-common/src/types.ts index bb49cf2edd4f..b88b3e01d090 100644 --- a/libs/langchain-google-common/src/types.ts +++ b/libs/langchain-google-common/src/types.ts @@ -309,6 +309,22 @@ export interface GeminiContent { export interface GeminiTool { functionDeclarations?: GeminiFunctionDeclaration[]; + googleSearchRetrieval?: GoogleSearchRetrieval; + retrieval?: VertexAIRetrieval; +} + +export interface GoogleSearchRetrieval { + dynamicRetrievalConfig?: { + mode?: string; + dynamicThreshold?: number; + }; +} + +export interface VertexAIRetrieval { + vertexAiSearch: { + datastore: string; + }; + disableAttribution?: boolean; } export interface GeminiFunctionDeclaration { diff --git a/libs/langchain-google-common/src/utils/common.ts b/libs/langchain-google-common/src/utils/common.ts index bf8ddb228382..b40ce25fe3fc 100644 --- a/libs/langchain-google-common/src/utils/common.ts +++ b/libs/langchain-google-common/src/utils/common.ts @@ -62,32 +62,43 @@ function processToolChoice( } export function convertToGeminiTools(tools: GoogleAIToolType[]): GeminiTool[] { - const geminiTools: GeminiTool[] = [ - { - functionDeclarations: [], - }, - ]; + const geminiTools: GeminiTool[] = []; + let functionDeclarationsIndex = -1; tools.forEach((tool) => { - if ( - "functionDeclarations" in tool && - Array.isArray(tool.functionDeclarations) - ) { - const funcs: GeminiFunctionDeclaration[] = tool.functionDeclarations; - geminiTools[0].functionDeclarations?.push(...funcs); - } else if (isLangChainTool(tool)) { - const jsonSchema = zodToGeminiParameters(tool.schema); - geminiTools[0].functionDeclarations?.push({ - name: tool.name, - description: tool.description ?? `A function available to call.`, - parameters: jsonSchema as GeminiFunctionSchema, - }); - } else if (isOpenAITool(tool)) { - geminiTools[0].functionDeclarations?.push({ - name: tool.function.name, - description: - tool.function.description ?? `A function available to call.`, - parameters: jsonSchemaToGeminiParameters(tool.function.parameters), - }); + if ("googleSearchRetrieval" in tool || "retrieval" in tool) { + geminiTools.push(tool); + } else { + if (functionDeclarationsIndex === -1) { + geminiTools.push({ + functionDeclarations: [], + }); + functionDeclarationsIndex = geminiTools.length - 1; + } + if ( + "functionDeclarations" in tool && + Array.isArray(tool.functionDeclarations) + ) { + const funcs: GeminiFunctionDeclaration[] = tool.functionDeclarations; + geminiTools[functionDeclarationsIndex].functionDeclarations!.push( + ...funcs + ); + } else if (isLangChainTool(tool)) { + const jsonSchema = zodToGeminiParameters(tool.schema); + geminiTools[functionDeclarationsIndex].functionDeclarations!.push({ + name: tool.name, + description: tool.description ?? `A function available to call.`, + parameters: jsonSchema as GeminiFunctionSchema, + }); + } else if (isOpenAITool(tool)) { + geminiTools[functionDeclarationsIndex].functionDeclarations!.push({ + name: tool.function.name, + description: + tool.function.description ?? `A function available to call.`, + parameters: jsonSchemaToGeminiParameters(tool.function.parameters), + }); + } else { + throw new Error(`Received invalid tool: ${JSON.stringify(tool)}`); + } } }); return geminiTools; diff --git a/libs/langchain-google-common/src/utils/gemini.ts b/libs/langchain-google-common/src/utils/gemini.ts index e6d0f6e96001..48bd41fb5c2f 100644 --- a/libs/langchain-google-common/src/utils/gemini.ts +++ b/libs/langchain-google-common/src/utils/gemini.ts @@ -1015,34 +1015,29 @@ export function getGeminiAPI(config?: GeminiAPIConfig): GoogleAIAPI { }; } - function structuredToolsToGeminiTools( - tools: StructuredToolParams[] - ): GeminiTool[] { - return [ - { - functionDeclarations: tools.map(structuredToolToFunctionDeclaration), - }, - ]; - } - function formatTools(parameters: GoogleAIModelRequestParams): GeminiTool[] { const tools: GoogleAIToolType[] | undefined = parameters?.tools; if (!tools || tools.length === 0) { return []; } - if (tools.every(isLangChainTool)) { - return structuredToolsToGeminiTools(tools); - } else { - if ( - tools.length === 1 && - (!("functionDeclarations" in tools[0]) || - !tools[0].functionDeclarations?.length) - ) { - return []; - } - return tools as GeminiTool[]; + // Group all LangChain tools into a single functionDeclarations array + const langChainTools = tools.filter(isLangChainTool); + const otherTools = tools.filter( + (tool) => !isLangChainTool(tool) + ) as GeminiTool[]; + + const result: GeminiTool[] = [...otherTools]; + + if (langChainTools.length > 0) { + result.push({ + functionDeclarations: langChainTools.map( + structuredToolToFunctionDeclaration + ), + }); } + + return result; } function formatToolConfig( diff --git a/libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts b/libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts index a3b8bbe4b2d8..ddcdf579a394 100644 --- a/libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts +++ b/libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts @@ -617,3 +617,50 @@ describe("GAuth Anthropic Chat", () => { expect(toolCalls?.[0].args).toHaveProperty("location"); }); }); + +describe("GoogleSearchRetrievalTool", () => { + test("Supports GoogleSearchRetrievalTool", async () => { + const searchRetrievalTool = { + googleSearchRetrieval: { + dynamicRetrievalConfig: { + mode: "MODE_DYNAMIC", + dynamicThreshold: 0.7, // default is 0.7 + }, + }, + }; + const model = new ChatVertexAI({ + model: "gemini-1.5-pro", + temperature: 0, + maxRetries: 0, + }).bindTools([searchRetrievalTool]); + + const result = await model.invoke("Who won the 2024 MLB World Series?"); + expect(result.content as string).toContain("Dodgers"); + }); + + test("Can stream GoogleSearchRetrievalTool", async () => { + const searchRetrievalTool = { + googleSearchRetrieval: { + dynamicRetrievalConfig: { + mode: "MODE_DYNAMIC", + dynamicThreshold: 0.7, // default is 0.7 + }, + }, + }; + const model = new ChatVertexAI({ + model: "gemini-1.5-pro", + temperature: 0, + maxRetries: 0, + }).bindTools([searchRetrievalTool]); + + const stream = await model.stream("Who won the 2024 MLB World Series?"); + let finalMsg: AIMessageChunk | undefined; + for await (const msg of stream) { + finalMsg = finalMsg ? concat(finalMsg, msg) : msg; + } + if (!finalMsg) { + throw new Error("finalMsg is undefined"); + } + expect(finalMsg.content as string).toContain("Dodgers"); + }); +});