feat(google-common): Grounding with Google Search and Vertex AI Search #7280

Merged: 13 commits (Dec 14, 2024)
Changes from 5 commits
130 changes: 130 additions & 0 deletions docs/core_docs/docs/integrations/chat/google_vertex_ai.ipynb
@@ -191,6 +191,136 @@
"console.log(aiMsg.content)"
]
},
{
"cell_type": "markdown",
"id": "de2480fa",
"metadata": {},
"source": [
"## Tool Calling with Google Search Retrieval\n",
"\n",
"It is possible to call the model with a Google search tool which you can use to [ground](https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/grounding) content generation with real-world information and reduce hallucinations. \n",
"\n",
"You can choose to either ground using Google Search or by using a custom data store. Here are examples of both: "
]
},
{
"cell_type": "markdown",
"id": "fd2091ba",
"metadata": {},
"source": [
"### Google Search Retrieval\n",
"\n",
"Grounding example that uses Google Search:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "65d019ee",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The Boston Celtics won the 2024 NBA Finals, defeating the Dallas Mavericks 4-1 in the series to claim their 18th NBA championship. This victory marked their first title since 2008 and established them as the team with the most NBA championships, surpassing the Los Angeles Lakers' 17 titles.\n",
"\n"
]
}
],
"source": [
"import { ChatVertexAI } from \"@langchain/google-vertexai\"\n",
"\n",
"const searchRetrievalTool = {\n",
" googleSearchRetrieval: {\n",
" dynamicRetrievalConfig: {\n",
" mode: \"MODE_DYNAMIC\", // Use Dynamic Retrieval\n",
" dynamicThreshold: 0.7, // Default for Dynamic Retrieval threshold\n",
" },\n",
" },\n",
"};\n",
"\n",
"const searchRetrievalModel = new ChatVertexAI({\n",
" model: \"gemini-1.5-pro\",\n",
" temperature: 0,\n",
" maxRetries: 0,\n",
"}).bindTools([searchRetrievalTool]);\n",
"\n",
"const searchRetrievalResult = await searchRetrievalModel.invoke(\"Who won the 2024 NBA Finals?\");\n",
"\n",
"console.log(searchRetrievalResult.content);"
]
},
{
"cell_type": "markdown",
"id": "ac3a4a98",
"metadata": {},
"source": [
"### Google Search Retrieval with Data Store\n",
"\n",
"First, set up your data store (this is a schema of an example data store):\n",
"\n",
"| ID | Date | Team 1 | Score | Team 2 |\n",
"|:-------:|:------------:|:-----------:|:--------:|:----------:|\n",
"| 3001 | 2023-09-07 | Argentina | 1 - 0 | Ecuador |\n",
"| 3002 | 2023-09-12 | Venezuela | 1 - 0 | Paraguay |\n",
"| 3003 | 2023-09-12 | Chile | 0 - 0 | Colombia |\n",
"| 3004 | 2023-09-12 | Peru | 0 - 1 | Brazil |\n",
"| 3005 | 2024-10-15 | Argentina | 6 - 0 | Bolivia |\n",
"\n",
"Then, use this data store in the example provided below:\n",
"\n",
"(Note that you have to use your own variables for `projectId` and `datastoreId`)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a6a539d9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Argentina won against Bolivia with a score of 6-0 on October 15, 2024.\n",
"\n"
]
}
],
"source": [
"import { ChatVertexAI } from \"@langchain/google-vertexai\";\n",
"\n",
"const searchRetrievalTool = {\n",
" retrieval: {\n",
" vertexAiSearch: {\n",
" datastore: `projects/${projectId}/locations/global/collections/default_collection/dataStores/${datastoreId}`,\n",
" },\n",
" disableAttribution: false,\n",
" },\n",
"};\n",
"\n",
"const searchRetrievalModel = new ChatVertexAI({\n",
" model: \"gemini-1.5-pro\",\n",
" temperature: 0,\n",
" maxRetries: 0,\n",
"}).bindTools([searchRetrievalTool]);\n",
"\n",
"const searchRetrievalModel = await searchRetrievalModel.invoke(\n",
"\"What is the score of Argentina vs Bolivia football game?\"\n",
");\n",
"\n",
"console.log(searchRetrievalModel.content);"
]
},
{
"cell_type": "markdown",
"id": "8d11f2be",
"metadata": {},
"source": [
"You should now get results that are grounded in the data from your provided data store."
]
},
{
"cell_type": "markdown",
"id": "18e2bfc0-7e78-4528-a73f-499ac150dca8",
16 changes: 16 additions & 0 deletions libs/langchain-google-common/src/types.ts
@@ -309,6 +309,22 @@ export interface GeminiContent {

export interface GeminiTool {
functionDeclarations?: GeminiFunctionDeclaration[];
googleSearchRetrieval?: GoogleSearchRetrieval;
retrieval?: VertexAIRetrieval;
}

export interface GoogleSearchRetrieval {
dynamicRetrievalConfig?: {
mode?: string;
dynamicThreshold?: number;
};
}

export interface VertexAIRetrieval {
vertexAiSearch: {
datastore: string;
};
disableAttribution?: boolean;
}

export interface GeminiFunctionDeclaration {
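For orientation (not part of the diff), here is a minimal sketch of tool objects that satisfy the extended `GeminiTool` interface. The import path assumes `GeminiTool` is re-exported from the package entry point (adjust if not), and the project and data store IDs are placeholders:

```ts
import type { GeminiTool } from "@langchain/google-common"; // assumed re-export path

// Google Search grounding tool; dynamicRetrievalConfig is optional.
const searchGroundingTool: GeminiTool = {
  googleSearchRetrieval: {
    dynamicRetrievalConfig: {
      mode: "MODE_DYNAMIC",
      dynamicThreshold: 0.7,
    },
  },
};

// Vertex AI Search grounding tool pointing at a custom data store.
// Replace the placeholder project and data store IDs with your own.
const dataStoreGroundingTool: GeminiTool = {
  retrieval: {
    vertexAiSearch: {
      datastore:
        "projects/my-project/locations/global/collections/default_collection/dataStores/my-datastore",
    },
    disableAttribution: false,
  },
};
```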
36 changes: 21 additions & 15 deletions libs/langchain-google-common/src/utils/common.ts
@@ -62,32 +62,38 @@ function processToolChoice(
}

export function convertToGeminiTools(tools: GoogleAIToolType[]): GeminiTool[] {
const geminiTools: GeminiTool[] = [
{
functionDeclarations: [],
},
];
const geminiTools: GeminiTool[] = [];
tools.forEach((tool) => {
if (
"functionDeclarations" in tool &&
Array.isArray(tool.functionDeclarations)
) {
const funcs: GeminiFunctionDeclaration[] = tool.functionDeclarations;
geminiTools[0].functionDeclarations?.push(...funcs);
geminiTools.push({ functionDeclarations: [...funcs] });
[Review thread on this line]

@jacoblee93 (Collaborator), Dec 10, 2024:
I vaguely remember this not working due to Google rejecting it - can you add a test using multiple traditional tools? And maybe one using search grounding alongside several tools as well.

Contributor (Author):
Hello @afirstenberg and @jacoblee93,
Thank you for the review! We're glad to hear that you appreciate our code contribution. We'll make sure to add the additional test cases as requested. Please note that libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts already includes test cases for grounding.

Contributor:
I think there are some known bugs on Google's side around having multiple tools at once - but Google considers this a bug on their end. It should be possible to submit multiple tools at once and to mix functions with other tools.

Collaborator:
Yes, this breaks multiple functions; will revert partially.

[End of review thread]

} else if (isLangChainTool(tool)) {
const jsonSchema = zodToGeminiParameters(tool.schema);
geminiTools[0].functionDeclarations?.push({
name: tool.name,
description: tool.description ?? `A function available to call.`,
parameters: jsonSchema as GeminiFunctionSchema,
geminiTools.push({
functionDeclarations: [
{
name: tool.name,
description: tool.description ?? `A function available to call.`,
parameters: jsonSchema as GeminiFunctionSchema,
},
],
});
} else if (isOpenAITool(tool)) {
geminiTools[0].functionDeclarations?.push({
name: tool.function.name,
description:
tool.function.description ?? `A function available to call.`,
parameters: jsonSchemaToGeminiParameters(tool.function.parameters),
geminiTools.push({
functionDeclarations: [
{
name: tool.function.name,
description:
tool.function.description ?? `A function available to call.`,
parameters: jsonSchemaToGeminiParameters(tool.function.parameters),
},
],
});
} else if ("googleSearchRetrieval" in tool || "retrieval" in tool) {
geminiTools.push(tool);
}
});
return geminiTools;
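As a quick illustration of the change (a sketch, not part of the diff): with the per-tool `push`, a mixed input now yields one `GeminiTool` entry per input tool instead of a single merged `functionDeclarations` array. The tool name and schema below are illustrative, and note the review comment above indicating this may be partially reverted for multiple function tools:

```ts
import { z } from "zod";
import { tool } from "@langchain/core/tools";

// An illustrative LangChain tool (name and schema are placeholders).
const weatherTool = tool(
  async ({ location }: { location: string }) => `It is sunny in ${location}.`,
  {
    name: "get_weather",
    description: "Get the current weather for a location.",
    schema: z.object({ location: z.string() }),
  }
);

const searchRetrievalTool = {
  googleSearchRetrieval: {
    dynamicRetrievalConfig: { mode: "MODE_DYNAMIC", dynamicThreshold: 0.7 },
  },
};

// As this revision stands, convertToGeminiTools([weatherTool, searchRetrievalTool])
// is expected to return two entries, roughly:
// [
//   { functionDeclarations: [{ name: "get_weather", /* description, parameters */ }] },
//   { googleSearchRetrieval: { dynamicRetrievalConfig: { mode: "MODE_DYNAMIC", dynamicThreshold: 0.7 } } },
// ]
```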
37 changes: 15 additions & 22 deletions libs/langchain-google-common/src/utils/gemini.ts
@@ -1015,34 +1015,27 @@ export function getGeminiAPI(config?: GeminiAPIConfig): GoogleAIAPI {
};
}

function structuredToolsToGeminiTools(
tools: StructuredToolParams[]
): GeminiTool[] {
return [
{
functionDeclarations: tools.map(structuredToolToFunctionDeclaration),
},
];
}

function formatTools(parameters: GoogleAIModelRequestParams): GeminiTool[] {
const tools: GoogleAIToolType[] | undefined = parameters?.tools;
if (!tools || tools.length === 0) {
return [];
}

if (tools.every(isLangChainTool)) {
return structuredToolsToGeminiTools(tools);
} else {
if (
tools.length === 1 &&
(!("functionDeclarations" in tools[0]) ||
!tools[0].functionDeclarations?.length)
) {
return [];
return tools.reduce((acc: GeminiTool[], tool) => {
if (isLangChainTool(tool)) {
if (!acc[0]) {
acc[0] = { functionDeclarations: [] };
}
if (!acc[0].functionDeclarations) {
acc[0].functionDeclarations = [];
}
acc[0].functionDeclarations.push(
structuredToolToFunctionDeclaration(tool)
);
} else {
acc.push(tool as GeminiTool);
}
return tools as GeminiTool[];
}
return acc;
}, []) as GeminiTool[];
}

function formatToolConfig(
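To make the reduce concrete, here is a self-contained sketch (simplified stand-in types, not the library's own) of the grouping it performs: LangChain-style tools are folded into a single `functionDeclarations` bucket at index 0, while prebuilt Gemini tools such as grounding tools are pushed as separate entries.

```ts
// Simplified stand-ins for the library's types, for illustration only.
type SketchGeminiTool = {
  functionDeclarations?: { name: string }[];
  googleSearchRetrieval?: Record<string, unknown>;
};

type SketchInputTool =
  | { name: string }      // stands in for a LangChain tool
  | SketchGeminiTool;     // stands in for a prebuilt Gemini tool

function groupTools(tools: SketchInputTool[]): SketchGeminiTool[] {
  return tools.reduce((acc: SketchGeminiTool[], tool) => {
    if ("name" in tool) {
      // Fold every function-style tool into the first accumulator entry.
      if (!acc[0]) {
        acc[0] = { functionDeclarations: [] };
      }
      if (!acc[0].functionDeclarations) {
        acc[0].functionDeclarations = [];
      }
      acc[0].functionDeclarations.push({ name: tool.name });
    } else {
      // Grounding and other prebuilt tools become their own entries.
      acc.push(tool);
    }
    return acc;
  }, []);
}

// groupTools([{ name: "get_weather" }, { name: "get_time" }, { googleSearchRetrieval: {} }])
// => [ { functionDeclarations: [{ name: "get_weather" }, { name: "get_time" }] },
//      { googleSearchRetrieval: {} } ]
```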
47 changes: 47 additions & 0 deletions libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts
@@ -617,3 +617,50 @@ describe("GAuth Anthropic Chat", () => {
expect(toolCalls?.[0].args).toHaveProperty("location");
});
});

describe("GoogleSearchRetrievalTool", () => {
test("Supports GoogleSearchRetrievalTool", async () => {
const searchRetrievalTool = {
googleSearchRetrieval: {
dynamicRetrievalConfig: {
mode: "MODE_DYNAMIC",
dynamicThreshold: 0.7, // default is 0.7
},
},
};
const model = new ChatVertexAI({
model: "gemini-1.5-pro",
temperature: 0,
maxRetries: 0,
}).bindTools([searchRetrievalTool]);

const result = await model.invoke("Who won the 2024 MLB World Series?");
expect(result.content as string).toContain("Dodgers");
});

test("Can stream GoogleSearchRetrievalTool", async () => {
const searchRetrievalTool = {
googleSearchRetrieval: {
dynamicRetrievalConfig: {
mode: "MODE_DYNAMIC",
dynamicThreshold: 0.7, // default is 0.7
},
},
};
const model = new ChatVertexAI({
model: "gemini-1.5-pro",
temperature: 0,
maxRetries: 0,
}).bindTools([searchRetrievalTool]);

const stream = await model.stream("Who won the 2024 MLB World Series?");
let finalMsg: AIMessageChunk | undefined;
for await (const msg of stream) {
finalMsg = finalMsg ? concat(finalMsg, msg) : msg;
}
if (!finalMsg) {
throw new Error("finalMsg is undefined");
}
expect(finalMsg.content as string).toContain("Dodgers");
});
});
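Following up on the review thread in common.ts about mixing grounding with traditional tools, a follow-up test might look roughly like the sketch below. The `get_current_weather` tool, its schema, and the imports of `tool` and `z` are illustrative additions that are not part of this diff, and whether the API accepts mixed tools in a single request is, per the thread, still subject to known issues on Google's side:

```ts
import { z } from "zod";
import { tool } from "@langchain/core/tools";
import { ChatVertexAI } from "@langchain/google-vertexai";

test("Can mix search grounding with a traditional tool", async () => {
  const searchRetrievalTool = {
    googleSearchRetrieval: {
      dynamicRetrievalConfig: {
        mode: "MODE_DYNAMIC",
        dynamicThreshold: 0.7,
      },
    },
  };
  // Illustrative function tool; name and schema are placeholders.
  const weatherTool = tool(
    async ({ location }: { location: string }) => `It is sunny in ${location}.`,
    {
      name: "get_current_weather",
      description: "Get the current weather for a location.",
      schema: z.object({ location: z.string() }),
    }
  );

  const model = new ChatVertexAI({
    model: "gemini-1.5-pro",
    temperature: 0,
    maxRetries: 0,
  }).bindTools([searchRetrievalTool, weatherTool]);

  const result = await model.invoke("Who won the 2024 MLB World Series?");
  // Loose assertion: with mixed tools the model may answer directly or call the tool.
  expect(result).toBeDefined();
});
```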