Skip to content

Commit

Permalink
feat(google-common): Grounding with Google Search and Vertex AI Search (
Browse files Browse the repository at this point in the history
#7280)

Co-authored-by: William <[email protected]>
Co-authored-by: jacoblee93 <[email protected]>
  • Loading branch information
3 people authored Dec 14, 2024
1 parent b2afdf1 commit 84da383
Show file tree
Hide file tree
Showing 5 changed files with 253 additions and 49 deletions.
141 changes: 138 additions & 3 deletions docs/core_docs/docs/integrations/chat/google_vertex_ai.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
"source": [
"# ChatVertexAI\n",
"\n",
"[Google Vertex](https://cloud.google.com/vertex-ai) is a service that exposes all foundation models available in Google Cloud, like `gemini-1.5-pro`, `gemini-1.5-flash`, etc.",
"It also provides some non-Google models such as [Anthropic's Claude](https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-claude).",
"[Google Vertex](https://cloud.google.com/vertex-ai) is a service that exposes all foundation models available in Google Cloud, like `gemini-1.5-pro`, `gemini-2.0-flash-exp`, etc.\n",
"It also provides some non-Google models such as [Anthropic's Claude](https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-claude).\n",
"\n",
"\n",
"This will help you getting started with `ChatVertexAI` [chat models](/docs/concepts/chat_models). For detailed documentation of all `ChatVertexAI` features and configurations head to the [API reference](https://api.js.langchain.com/classes/langchain_google_vertexai.ChatVertexAI.html).\n",
Expand Down Expand Up @@ -116,7 +116,7 @@
"// import { ChatVertexAI } from \"@langchain/google-vertexai-web\"\n",
"\n",
"const llm = new ChatVertexAI({\n",
" model: \"gemini-1.5-pro\",\n",
" model: \"gemini-2.0-flash-exp\",\n",
" temperature: 0,\n",
" maxRetries: 2,\n",
" // For web, authOptions.credentials\n",
Expand Down Expand Up @@ -191,6 +191,141 @@
"console.log(aiMsg.content)"
]
},
{
"cell_type": "markdown",
"id": "de2480fa",
"metadata": {},
"source": [
"## Tool Calling with Google Search Retrieval\n",
"\n",
"It is possible to call the model with a Google search tool which you can use to [ground](https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/grounding) content generation with real-world information and reduce hallucinations.\n",
"\n",
"Grounding is currently not supported by `gemini-2.0-flash-exp`.\n",
"\n",
"You can choose to either ground using Google Search or by using a custom data store. Here are examples of both: "
]
},
{
"cell_type": "markdown",
"id": "fd2091ba",
"metadata": {},
"source": [
"### Google Search Retrieval\n",
"\n",
"Grounding example that uses Google Search:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "65d019ee",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The Boston Celtics won the 2024 NBA Finals, defeating the Dallas Mavericks 4-1 in the series to claim their 18th NBA championship. This victory marked their first title since 2008 and established them as the team with the most NBA championships, surpassing the Los Angeles Lakers' 17 titles.\n",
"\n"
]
}
],
"source": [
"import { ChatVertexAI } from \"@langchain/google-vertexai\"\n",
"\n",
"const searchRetrievalTool = {\n",
" googleSearchRetrieval: {\n",
" dynamicRetrievalConfig: {\n",
" mode: \"MODE_DYNAMIC\", // Use Dynamic Retrieval\n",
" dynamicThreshold: 0.7, // Default for Dynamic Retrieval threshold\n",
" },\n",
" },\n",
"};\n",
"\n",
"const searchRetrievalModel = new ChatVertexAI({\n",
" model: \"gemini-1.5-pro\",\n",
" temperature: 0,\n",
" maxRetries: 0,\n",
"}).bindTools([searchRetrievalTool]);\n",
"\n",
"const searchRetrievalResult = await searchRetrievalModel.invoke(\"Who won the 2024 NBA Finals?\");\n",
"\n",
"console.log(searchRetrievalResult.content);"
]
},
{
"cell_type": "markdown",
"id": "ac3a4a98",
"metadata": {},
"source": [
"### Google Search Retrieval with Data Store\n",
"\n",
"First, set up your data store (this is a schema of an example data store):\n",
"\n",
"| ID | Date | Team 1 | Score | Team 2 |\n",
"|:-------:|:------------:|:-----------:|:--------:|:----------:|\n",
"| 3001 | 2023-09-07 | Argentina | 1 - 0 | Ecuador |\n",
"| 3002 | 2023-09-12 | Venezuela | 1 - 0 | Paraguay |\n",
"| 3003 | 2023-09-12 | Chile | 0 - 0 | Colombia |\n",
"| 3004 | 2023-09-12 | Peru | 0 - 1 | Brazil |\n",
"| 3005 | 2024-10-15 | Argentina | 6 - 0 | Bolivia |\n",
"\n",
"Then, use this data store in the example provided below:\n",
"\n",
"(Note that you have to use your own variables for `projectId` and `datastoreId`)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a6a539d9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Argentina won against Bolivia with a score of 6-0 on October 15, 2024.\n",
"\n"
]
}
],
"source": [
"import { ChatVertexAI } from \"@langchain/google-vertexai\";\n",
"\n",
"const projectId = \"YOUR_PROJECT_ID\";\n",
"const datastoreId = \"YOUR_DATASTORE_ID\";\n",
"\n",
"const searchRetrievalToolWithDataset = {\n",
" retrieval: {\n",
" vertexAiSearch: {\n",
" datastore: `projects/${projectId}/locations/global/collections/default_collection/dataStores/${datastoreId}`,\n",
" },\n",
" disableAttribution: false,\n",
" },\n",
"};\n",
"\n",
"const searchRetrievalModelWithDataset = new ChatVertexAI({\n",
" model: \"gemini-1.5-pro\",\n",
" temperature: 0,\n",
" maxRetries: 0,\n",
"}).bindTools([searchRetrievalToolWithDataset]);\n",
"\n",
"const searchRetrievalModelResult = await searchRetrievalModelWithDataset.invoke(\n",
" \"What is the score of Argentina vs Bolivia football game?\"\n",
");\n",
"\n",
"console.log(searchRetrievalModelResult.content);"
]
},
{
"cell_type": "markdown",
"id": "8d11f2be",
"metadata": {},
"source": [
"You should now get results that are grounded in the data from your provided data store."
]
},
{
"cell_type": "markdown",
"id": "18e2bfc0-7e78-4528-a73f-499ac150dca8",
Expand Down
16 changes: 16 additions & 0 deletions libs/langchain-google-common/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,22 @@ export interface GeminiContent {

export interface GeminiTool {
functionDeclarations?: GeminiFunctionDeclaration[];
googleSearchRetrieval?: GoogleSearchRetrieval;
retrieval?: VertexAIRetrieval;
}

export interface GoogleSearchRetrieval {
dynamicRetrievalConfig?: {
mode?: string;
dynamicThreshold?: number;
};
}

export interface VertexAIRetrieval {
vertexAiSearch: {
datastore: string;
};
disableAttribution?: boolean;
}

export interface GeminiFunctionDeclaration {
Expand Down
61 changes: 36 additions & 25 deletions libs/langchain-google-common/src/utils/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,32 +62,43 @@ function processToolChoice(
}

export function convertToGeminiTools(tools: GoogleAIToolType[]): GeminiTool[] {
const geminiTools: GeminiTool[] = [
{
functionDeclarations: [],
},
];
const geminiTools: GeminiTool[] = [];
let functionDeclarationsIndex = -1;
tools.forEach((tool) => {
if (
"functionDeclarations" in tool &&
Array.isArray(tool.functionDeclarations)
) {
const funcs: GeminiFunctionDeclaration[] = tool.functionDeclarations;
geminiTools[0].functionDeclarations?.push(...funcs);
} else if (isLangChainTool(tool)) {
const jsonSchema = zodToGeminiParameters(tool.schema);
geminiTools[0].functionDeclarations?.push({
name: tool.name,
description: tool.description ?? `A function available to call.`,
parameters: jsonSchema as GeminiFunctionSchema,
});
} else if (isOpenAITool(tool)) {
geminiTools[0].functionDeclarations?.push({
name: tool.function.name,
description:
tool.function.description ?? `A function available to call.`,
parameters: jsonSchemaToGeminiParameters(tool.function.parameters),
});
if ("googleSearchRetrieval" in tool || "retrieval" in tool) {
geminiTools.push(tool);
} else {
if (functionDeclarationsIndex === -1) {
geminiTools.push({
functionDeclarations: [],
});
functionDeclarationsIndex = geminiTools.length - 1;
}
if (
"functionDeclarations" in tool &&
Array.isArray(tool.functionDeclarations)
) {
const funcs: GeminiFunctionDeclaration[] = tool.functionDeclarations;
geminiTools[functionDeclarationsIndex].functionDeclarations!.push(
...funcs
);
} else if (isLangChainTool(tool)) {
const jsonSchema = zodToGeminiParameters(tool.schema);
geminiTools[functionDeclarationsIndex].functionDeclarations!.push({
name: tool.name,
description: tool.description ?? `A function available to call.`,
parameters: jsonSchema as GeminiFunctionSchema,
});
} else if (isOpenAITool(tool)) {
geminiTools[functionDeclarationsIndex].functionDeclarations!.push({
name: tool.function.name,
description:
tool.function.description ?? `A function available to call.`,
parameters: jsonSchemaToGeminiParameters(tool.function.parameters),
});
} else {
throw new Error(`Received invalid tool: ${JSON.stringify(tool)}`);
}
}
});
return geminiTools;
Expand Down
37 changes: 16 additions & 21 deletions libs/langchain-google-common/src/utils/gemini.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1015,34 +1015,29 @@ export function getGeminiAPI(config?: GeminiAPIConfig): GoogleAIAPI {
};
}

function structuredToolsToGeminiTools(
tools: StructuredToolParams[]
): GeminiTool[] {
return [
{
functionDeclarations: tools.map(structuredToolToFunctionDeclaration),
},
];
}

function formatTools(parameters: GoogleAIModelRequestParams): GeminiTool[] {
const tools: GoogleAIToolType[] | undefined = parameters?.tools;
if (!tools || tools.length === 0) {
return [];
}

if (tools.every(isLangChainTool)) {
return structuredToolsToGeminiTools(tools);
} else {
if (
tools.length === 1 &&
(!("functionDeclarations" in tools[0]) ||
!tools[0].functionDeclarations?.length)
) {
return [];
}
return tools as GeminiTool[];
// Group all LangChain tools into a single functionDeclarations array
const langChainTools = tools.filter(isLangChainTool);
const otherTools = tools.filter(
(tool) => !isLangChainTool(tool)
) as GeminiTool[];

const result: GeminiTool[] = [...otherTools];

if (langChainTools.length > 0) {
result.push({
functionDeclarations: langChainTools.map(
structuredToolToFunctionDeclaration
),
});
}

return result;
}

function formatToolConfig(
Expand Down
47 changes: 47 additions & 0 deletions libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -617,3 +617,50 @@ describe("GAuth Anthropic Chat", () => {
expect(toolCalls?.[0].args).toHaveProperty("location");
});
});

describe("GoogleSearchRetrievalTool", () => {
test("Supports GoogleSearchRetrievalTool", async () => {
const searchRetrievalTool = {
googleSearchRetrieval: {
dynamicRetrievalConfig: {
mode: "MODE_DYNAMIC",
dynamicThreshold: 0.7, // default is 0.7
},
},
};
const model = new ChatVertexAI({
model: "gemini-1.5-pro",
temperature: 0,
maxRetries: 0,
}).bindTools([searchRetrievalTool]);

const result = await model.invoke("Who won the 2024 MLB World Series?");
expect(result.content as string).toContain("Dodgers");
});

test("Can stream GoogleSearchRetrievalTool", async () => {
const searchRetrievalTool = {
googleSearchRetrieval: {
dynamicRetrievalConfig: {
mode: "MODE_DYNAMIC",
dynamicThreshold: 0.7, // default is 0.7
},
},
};
const model = new ChatVertexAI({
model: "gemini-1.5-pro",
temperature: 0,
maxRetries: 0,
}).bindTools([searchRetrievalTool]);

const stream = await model.stream("Who won the 2024 MLB World Series?");
let finalMsg: AIMessageChunk | undefined;
for await (const msg of stream) {
finalMsg = finalMsg ? concat(finalMsg, msg) : msg;
}
if (!finalMsg) {
throw new Error("finalMsg is undefined");
}
expect(finalMsg.content as string).toContain("Dodgers");
});
});

0 comments on commit 84da383

Please sign in to comment.