Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(google-common): Grounding with Google Search and Vertex AI Search #7280

Merged
merged 13 commits into from
Dec 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 138 additions & 3 deletions docs/core_docs/docs/integrations/chat/google_vertex_ai.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
"source": [
"# ChatVertexAI\n",
"\n",
"[Google Vertex](https://cloud.google.com/vertex-ai) is a service that exposes all foundation models available in Google Cloud, like `gemini-1.5-pro`, `gemini-1.5-flash`, etc.",
"It also provides some non-Google models such as [Anthropic's Claude](https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-claude).",
"[Google Vertex](https://cloud.google.com/vertex-ai) is a service that exposes all foundation models available in Google Cloud, like `gemini-1.5-pro`, `gemini-2.0-flash-exp`, etc.\n",
"It also provides some non-Google models such as [Anthropic's Claude](https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-claude).\n",
"\n",
"\n",
"This will help you getting started with `ChatVertexAI` [chat models](/docs/concepts/chat_models). For detailed documentation of all `ChatVertexAI` features and configurations head to the [API reference](https://api.js.langchain.com/classes/langchain_google_vertexai.ChatVertexAI.html).\n",
Expand Down Expand Up @@ -116,7 +116,7 @@
"// import { ChatVertexAI } from \"@langchain/google-vertexai-web\"\n",
"\n",
"const llm = new ChatVertexAI({\n",
" model: \"gemini-1.5-pro\",\n",
" model: \"gemini-2.0-flash-exp\",\n",
" temperature: 0,\n",
" maxRetries: 2,\n",
" // For web, authOptions.credentials\n",
Expand Down Expand Up @@ -191,6 +191,141 @@
"console.log(aiMsg.content)"
]
},
{
"cell_type": "markdown",
"id": "de2480fa",
"metadata": {},
"source": [
"## Tool Calling with Google Search Retrieval\n",
"\n",
"It is possible to call the model with a Google search tool which you can use to [ground](https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/grounding) content generation with real-world information and reduce hallucinations.\n",
"\n",
"Grounding is currently not supported by `gemini-2.0-flash-exp`.\n",
"\n",
"You can choose to either ground using Google Search or by using a custom data store. Here are examples of both: "
]
},
{
"cell_type": "markdown",
"id": "fd2091ba",
"metadata": {},
"source": [
"### Google Search Retrieval\n",
"\n",
"Grounding example that uses Google Search:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "65d019ee",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The Boston Celtics won the 2024 NBA Finals, defeating the Dallas Mavericks 4-1 in the series to claim their 18th NBA championship. This victory marked their first title since 2008 and established them as the team with the most NBA championships, surpassing the Los Angeles Lakers' 17 titles.\n",
"\n"
]
}
],
"source": [
"import { ChatVertexAI } from \"@langchain/google-vertexai\"\n",
"\n",
"const searchRetrievalTool = {\n",
" googleSearchRetrieval: {\n",
" dynamicRetrievalConfig: {\n",
" mode: \"MODE_DYNAMIC\", // Use Dynamic Retrieval\n",
" dynamicThreshold: 0.7, // Default for Dynamic Retrieval threshold\n",
" },\n",
" },\n",
"};\n",
"\n",
"const searchRetrievalModel = new ChatVertexAI({\n",
" model: \"gemini-1.5-pro\",\n",
" temperature: 0,\n",
" maxRetries: 0,\n",
"}).bindTools([searchRetrievalTool]);\n",
"\n",
"const searchRetrievalResult = await searchRetrievalModel.invoke(\"Who won the 2024 NBA Finals?\");\n",
"\n",
"console.log(searchRetrievalResult.content);"
]
},
{
"cell_type": "markdown",
"id": "ac3a4a98",
"metadata": {},
"source": [
"### Google Search Retrieval with Data Store\n",
"\n",
"First, set up your data store (this is a schema of an example data store):\n",
"\n",
"| ID | Date | Team 1 | Score | Team 2 |\n",
"|:-------:|:------------:|:-----------:|:--------:|:----------:|\n",
"| 3001 | 2023-09-07 | Argentina | 1 - 0 | Ecuador |\n",
"| 3002 | 2023-09-12 | Venezuela | 1 - 0 | Paraguay |\n",
"| 3003 | 2023-09-12 | Chile | 0 - 0 | Colombia |\n",
"| 3004 | 2023-09-12 | Peru | 0 - 1 | Brazil |\n",
"| 3005 | 2024-10-15 | Argentina | 6 - 0 | Bolivia |\n",
"\n",
"Then, use this data store in the example provided below:\n",
"\n",
"(Note that you have to use your own variables for `projectId` and `datastoreId`)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a6a539d9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Argentina won against Bolivia with a score of 6-0 on October 15, 2024.\n",
"\n"
]
}
],
"source": [
"import { ChatVertexAI } from \"@langchain/google-vertexai\";\n",
"\n",
"const projectId = \"YOUR_PROJECT_ID\";\n",
"const datastoreId = \"YOUR_DATASTORE_ID\";\n",
"\n",
"const searchRetrievalToolWithDataset = {\n",
" retrieval: {\n",
" vertexAiSearch: {\n",
" datastore: `projects/${projectId}/locations/global/collections/default_collection/dataStores/${datastoreId}`,\n",
" },\n",
" disableAttribution: false,\n",
" },\n",
"};\n",
"\n",
"const searchRetrievalModelWithDataset = new ChatVertexAI({\n",
" model: \"gemini-1.5-pro\",\n",
" temperature: 0,\n",
" maxRetries: 0,\n",
"}).bindTools([searchRetrievalToolWithDataset]);\n",
"\n",
"const searchRetrievalModelResult = await searchRetrievalModelWithDataset.invoke(\n",
" \"What is the score of Argentina vs Bolivia football game?\"\n",
");\n",
"\n",
"console.log(searchRetrievalModelResult.content);"
]
},
{
"cell_type": "markdown",
"id": "8d11f2be",
"metadata": {},
"source": [
"You should now get results that are grounded in the data from your provided data store."
]
},
{
"cell_type": "markdown",
"id": "18e2bfc0-7e78-4528-a73f-499ac150dca8",
Expand Down
16 changes: 16 additions & 0 deletions libs/langchain-google-common/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,22 @@ export interface GeminiContent {

export interface GeminiTool {
functionDeclarations?: GeminiFunctionDeclaration[];
googleSearchRetrieval?: GoogleSearchRetrieval;
retrieval?: VertexAIRetrieval;
}

export interface GoogleSearchRetrieval {
dynamicRetrievalConfig?: {
mode?: string;
dynamicThreshold?: number;
};
}

export interface VertexAIRetrieval {
vertexAiSearch: {
datastore: string;
};
disableAttribution?: boolean;
}

export interface GeminiFunctionDeclaration {
Expand Down
61 changes: 36 additions & 25 deletions libs/langchain-google-common/src/utils/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,32 +62,43 @@ function processToolChoice(
}

export function convertToGeminiTools(tools: GoogleAIToolType[]): GeminiTool[] {
const geminiTools: GeminiTool[] = [
{
functionDeclarations: [],
},
];
const geminiTools: GeminiTool[] = [];
let functionDeclarationsIndex = -1;
tools.forEach((tool) => {
if (
"functionDeclarations" in tool &&
Array.isArray(tool.functionDeclarations)
) {
const funcs: GeminiFunctionDeclaration[] = tool.functionDeclarations;
geminiTools[0].functionDeclarations?.push(...funcs);
} else if (isLangChainTool(tool)) {
const jsonSchema = zodToGeminiParameters(tool.schema);
geminiTools[0].functionDeclarations?.push({
name: tool.name,
description: tool.description ?? `A function available to call.`,
parameters: jsonSchema as GeminiFunctionSchema,
});
} else if (isOpenAITool(tool)) {
geminiTools[0].functionDeclarations?.push({
name: tool.function.name,
description:
tool.function.description ?? `A function available to call.`,
parameters: jsonSchemaToGeminiParameters(tool.function.parameters),
});
if ("googleSearchRetrieval" in tool || "retrieval" in tool) {
geminiTools.push(tool);
} else {
if (functionDeclarationsIndex === -1) {
geminiTools.push({
functionDeclarations: [],
});
functionDeclarationsIndex = geminiTools.length - 1;
}
if (
"functionDeclarations" in tool &&
Array.isArray(tool.functionDeclarations)
) {
const funcs: GeminiFunctionDeclaration[] = tool.functionDeclarations;
geminiTools[functionDeclarationsIndex].functionDeclarations!.push(
...funcs
);
} else if (isLangChainTool(tool)) {
const jsonSchema = zodToGeminiParameters(tool.schema);
geminiTools[functionDeclarationsIndex].functionDeclarations!.push({
name: tool.name,
description: tool.description ?? `A function available to call.`,
parameters: jsonSchema as GeminiFunctionSchema,
});
} else if (isOpenAITool(tool)) {
geminiTools[functionDeclarationsIndex].functionDeclarations!.push({
name: tool.function.name,
description:
tool.function.description ?? `A function available to call.`,
parameters: jsonSchemaToGeminiParameters(tool.function.parameters),
});
} else {
throw new Error(`Received invalid tool: ${JSON.stringify(tool)}`);
}
}
});
return geminiTools;
Expand Down
37 changes: 16 additions & 21 deletions libs/langchain-google-common/src/utils/gemini.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1015,34 +1015,29 @@ export function getGeminiAPI(config?: GeminiAPIConfig): GoogleAIAPI {
};
}

function structuredToolsToGeminiTools(
tools: StructuredToolParams[]
): GeminiTool[] {
return [
{
functionDeclarations: tools.map(structuredToolToFunctionDeclaration),
},
];
}

function formatTools(parameters: GoogleAIModelRequestParams): GeminiTool[] {
const tools: GoogleAIToolType[] | undefined = parameters?.tools;
if (!tools || tools.length === 0) {
return [];
}

if (tools.every(isLangChainTool)) {
return structuredToolsToGeminiTools(tools);
} else {
if (
tools.length === 1 &&
(!("functionDeclarations" in tools[0]) ||
!tools[0].functionDeclarations?.length)
) {
return [];
}
return tools as GeminiTool[];
// Group all LangChain tools into a single functionDeclarations array
const langChainTools = tools.filter(isLangChainTool);
const otherTools = tools.filter(
(tool) => !isLangChainTool(tool)
) as GeminiTool[];

const result: GeminiTool[] = [...otherTools];

if (langChainTools.length > 0) {
result.push({
functionDeclarations: langChainTools.map(
structuredToolToFunctionDeclaration
),
});
}

return result;
}

function formatToolConfig(
Expand Down
47 changes: 47 additions & 0 deletions libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -617,3 +617,50 @@ describe("GAuth Anthropic Chat", () => {
expect(toolCalls?.[0].args).toHaveProperty("location");
});
});

describe("GoogleSearchRetrievalTool", () => {
test("Supports GoogleSearchRetrievalTool", async () => {
const searchRetrievalTool = {
googleSearchRetrieval: {
dynamicRetrievalConfig: {
mode: "MODE_DYNAMIC",
dynamicThreshold: 0.7, // default is 0.7
},
},
};
const model = new ChatVertexAI({
model: "gemini-1.5-pro",
temperature: 0,
maxRetries: 0,
}).bindTools([searchRetrievalTool]);

const result = await model.invoke("Who won the 2024 MLB World Series?");
expect(result.content as string).toContain("Dodgers");
});

test("Can stream GoogleSearchRetrievalTool", async () => {
const searchRetrievalTool = {
googleSearchRetrieval: {
dynamicRetrievalConfig: {
mode: "MODE_DYNAMIC",
dynamicThreshold: 0.7, // default is 0.7
},
},
};
const model = new ChatVertexAI({
model: "gemini-1.5-pro",
temperature: 0,
maxRetries: 0,
}).bindTools([searchRetrievalTool]);

const stream = await model.stream("Who won the 2024 MLB World Series?");
let finalMsg: AIMessageChunk | undefined;
for await (const msg of stream) {
finalMsg = finalMsg ? concat(finalMsg, msg) : msg;
}
if (!finalMsg) {
throw new Error("finalMsg is undefined");
}
expect(finalMsg.content as string).toContain("Dodgers");
});
});
Loading