From ff8f9b877fa64f6fc27c4ac5a68b25077e8be6d8 Mon Sep 17 00:00:00 2001
From: bracesproul
Date: Wed, 24 Jul 2024 11:23:26 -0700
Subject: [PATCH] google-common[minor]: Add tool choice param

---
 .../src/chat_models.ts                          |  48 +----
 libs/langchain-google-common/src/connection.ts  |  21 ++
 libs/langchain-google-common/src/types.ts       |  24 +++
 .../src/utils/common.ts                         | 197 +++++++++++++-----
 libs/langchain-google-vertexai/package.json     |   3 +-
 .../src/tests/chat_models.int.test.ts           |  37 ++++
 yarn.lock                                       |   1 +
 7 files changed, 227 insertions(+), 104 deletions(-)

diff --git a/libs/langchain-google-common/src/chat_models.ts b/libs/langchain-google-common/src/chat_models.ts
index ca63052bacbe..bd06fbfe0dc4 100644
--- a/libs/langchain-google-common/src/chat_models.ts
+++ b/libs/langchain-google-common/src/chat_models.ts
@@ -13,7 +13,6 @@ import {
   BaseLanguageModelInput,
   StructuredOutputMethodOptions,
   ToolDefinition,
-  isOpenAITool,
 } from "@langchain/core/language_models/base";
 import type { z } from "zod";
 import {
@@ -24,7 +23,6 @@ import {
 } from "@langchain/core/runnables";
 import { JsonOutputKeyToolsParser } from "@langchain/core/output_parsers/openai_tools";
 import { BaseLLMOutputParser } from "@langchain/core/output_parsers";
-import { isStructuredTool } from "@langchain/core/utils/function_calling";
 import { AsyncCaller } from "@langchain/core/utils/async_caller";
 import { StructuredToolInterface } from "@langchain/core/tools";
 import { concat } from "@langchain/core/utils/stream";
@@ -39,6 +37,7 @@ import {
   GoogleAIBaseLanguageModelCallOptions,
 } from "./types.js";
 import {
+  convertToGeminiTools,
   copyAIModelParams,
   copyAndValidateModelParamsInto,
 } from "./utils/common.js";
@@ -60,7 +59,6 @@ import type {
   GeminiFunctionSchema,
 } from "./types.js";
 import {
-  jsonSchemaToGeminiParameters,
   zodToGeminiParameters,
 } from "./utils/zod_to_gemini_parameters.js";
@@ -160,44 +158,6 @@ export interface ChatGoogleBaseInput
     GoogleAISafetyParams,
     Pick {}
 
-function convertToGeminiTools(
-  structuredTools: (
-    | StructuredToolInterface
-    | Record
-    | ToolDefinition
-    | RunnableToolLike
-  )[]
-): GeminiTool[] {
-  return [
-    {
-      functionDeclarations: structuredTools.map(
-        (structuredTool): GeminiFunctionDeclaration => {
-          if (isStructuredTool(structuredTool)) {
-            const jsonSchema = zodToGeminiParameters(structuredTool.schema);
-            return {
-              name: structuredTool.name,
-              description: structuredTool.description,
-              parameters: jsonSchema as GeminiFunctionSchema,
-            };
-          }
-          if (isOpenAITool(structuredTool)) {
-            return {
-              name: structuredTool.function.name,
-              description:
-                structuredTool.function.description ??
-                `A function available to call.`,
-              parameters: jsonSchemaToGeminiParameters(
-                structuredTool.function.parameters
-              ),
-            };
-          }
-          return structuredTool as unknown as GeminiFunctionDeclaration;
-        }
-      ),
-    },
-  ];
-}
-
 /**
  * Integration with a chat model.
  */
@@ -342,12 +302,6 @@ export abstract class ChatGoogleBase
    * Get the parameters used to invoke the model
    */
   override invocationParams(options?: this["ParsedCallOptions"]) {
-    if (options?.tool_choice) {
-      throw new Error(
-        `'tool_choice' call option is not supported by ${this.getName()}.`
-      );
-    }
-
     return copyAIModelParams(this, options);
   }
 
diff --git a/libs/langchain-google-common/src/connection.ts b/libs/langchain-google-common/src/connection.ts
index 212bfa886b8f..740223b26d6e 100644
--- a/libs/langchain-google-common/src/connection.ts
+++ b/libs/langchain-google-common/src/connection.ts
@@ -101,6 +101,7 @@ export abstract class GoogleConnection<
     if (data && method === "POST") {
       opts.data = data;
     }
+    console.log("data", data)
     if (this.streaming) {
       opts.responseType = "stream";
     } else {
@@ -350,6 +351,21 @@ export abstract class AbstractGoogleLLMConnection<
     }
   }
 
+  formatToolConfig(
+    parameters: GoogleAIModelRequestParams
+  ): GeminiRequest["toolConfig"] | undefined {
+    if (!parameters.tool_choice || typeof parameters.tool_choice !== "string") {
+      return undefined;
+    }
+
+    return {
+      functionCallingConfig: {
+        mode: parameters.tool_choice as "auto" | "any" | "none",
+        allowedFunctionNames: parameters.allowed_function_names,
+      }
+    }
+  }
+
   formatData(
     input: MessageType,
     parameters: GoogleAIModelRequestParams
@@ -357,6 +373,7 @@ export abstract class AbstractGoogleLLMConnection<
     const contents = this.formatContents(input, parameters);
     const generationConfig = this.formatGenerationConfig(input, parameters);
     const tools = this.formatTools(input, parameters);
+    const toolConfig = this.formatToolConfig(parameters);
     const safetySettings = this.formatSafetySettings(input, parameters);
     const systemInstruction = this.formatSystemInstruction(input, parameters);
 
@@ -365,8 +382,12 @@
       generationConfig,
     };
     if (tools && tools.length) {
+      console.log("HAVE TOOLS!!!!!!!")
       ret.tools = tools;
     }
+    if (toolConfig) {
+      ret.toolConfig = toolConfig;
+    }
     if (safetySettings && safetySettings.length) {
       ret.safetySettings = safetySettings;
     }
diff --git a/libs/langchain-google-common/src/types.ts b/libs/langchain-google-common/src/types.ts
index 3d316f52ddbe..04ad60e7f01f 100644
--- a/libs/langchain-google-common/src/types.ts
+++ b/libs/langchain-google-common/src/types.ts
@@ -117,6 +117,24 @@ export interface GoogleAIModelParams {
  */
 export interface GoogleAIModelRequestParams extends GoogleAIModelParams {
   tools?: StructuredToolInterface[] | GeminiTool[];
+  /**
+   * Force the model to use tools in a specific way.
+   *
+   * | Mode     | Description                                                                                                                                              |
+   * |----------|----------------------------------------------------------------------------------------------------------------------------------------------------------|
+   * | "auto"   | The default model behavior. The model decides whether to predict a function call or a natural language response.                                          |
+   * | "any"    | The model must predict only function calls. To limit the model to a subset of functions, define the allowed function names in `allowed_function_names`.   |
+   * | "none"   | The model must not predict function calls. This behavior is equivalent to a model request without any associated function declarations.                   |
+   * | string   | The string value must be one of the function names. This will force the model to predict the specified function call.                                     |
+   *
+   * The tool configuration's "any" mode ("forced function calling") is supported for Gemini 1.5 Pro models only.
+   */
+  tool_choice?: string | "auto" | "any" | "none" | Record;
+  /**
+   * Allowed functions to call when the mode is "any".
+   * If empty, any one of the provided functions may be called.
+   */
+  allowed_function_names?: string[];
 }
 
 export interface GoogleAIBaseLLMInput
@@ -251,6 +269,12 @@ export interface GeminiRequest {
   contents?: GeminiContent[];
   systemInstruction?: GeminiContent;
   tools?: GeminiTool[];
+  toolConfig?: {
+    functionCallingConfig: {
+      mode: "auto" | "any" | "none";
+      allowedFunctionNames?: string[];
+    }
+  }
   safetySettings?: GeminiSafetySetting[];
   generationConfig?: GeminiGenerationConfig;
 }
diff --git a/libs/langchain-google-common/src/utils/common.ts b/libs/langchain-google-common/src/utils/common.ts
index 6ea6533d8225..f16c96188597 100644
--- a/libs/langchain-google-common/src/utils/common.ts
+++ b/libs/langchain-google-common/src/utils/common.ts
@@ -1,5 +1,7 @@
 import { StructuredToolInterface } from "@langchain/core/tools";
 import type {
+  GeminiFunctionDeclaration,
+  GeminiFunctionSchema,
   GeminiTool,
   GoogleAIBaseLanguageModelCallOptions,
   GoogleAIModelParams,
@@ -7,6 +9,10 @@ import type {
   GoogleLLMModelFamily,
 } from "../types.js";
 import { isModelGemini, validateGeminiParams } from "./gemini.js";
+import { isOpenAITool, ToolDefinition } from "@langchain/core/language_models/base";
+import { RunnableToolLike } from "@langchain/core/runnables";
+import { isStructuredTool } from "@langchain/core/utils/function_calling";
+import { jsonSchemaToGeminiParameters, zodToGeminiParameters } from "./zod_to_gemini_parameters.js";
 
 export function copyAIModelParams(
   params: GoogleAIModelParams | undefined,
@@ -15,6 +21,77 @@ export function copyAIModelParams(
   return copyAIModelParamsInto(params, options, {});
 }
 
+function processToolChoice(toolChoice: GoogleAIBaseLanguageModelCallOptions["tool_choice"], allowedFunctionNames: GoogleAIBaseLanguageModelCallOptions["allowed_function_names"]): {
+  tool_choice: "any" | "auto" | "none";
+  allowed_function_names?: string[];
+} | undefined {
+
+  if (!toolChoice) {
+    if (allowedFunctionNames) {
+      // Allowed function names were passed, so return 'any' to force the model to use a tool.
+      return {
+        tool_choice: "any",
+        allowed_function_names: allowedFunctionNames,
+      };
+    }
+    return undefined;
+  }
+
+  if (toolChoice === "any" || toolChoice === "auto" || toolChoice === "none") {
+    return {
+      tool_choice: toolChoice,
+      allowed_function_names: allowedFunctionNames,
+    };
+  }
+  if (typeof toolChoice === "string") {
+    // A string representing the function name.
+    // Return 'any' to force the model to predict the specified function call.
+    return {
+      tool_choice: "any",
+      allowed_function_names: [...(allowedFunctionNames ?? []), toolChoice],
+    };
+  }
+  throw new Error("Object inputs for tool_choice not supported.")
+}
+
+export function convertToGeminiTools(
+  structuredTools: (
+    | StructuredToolInterface
+    | Record
+    | ToolDefinition
+    | RunnableToolLike
+  )[]
+): GeminiTool[] {
+  return [
+    {
+      functionDeclarations: structuredTools.map(
+        (structuredTool): GeminiFunctionDeclaration => {
+          if (isStructuredTool(structuredTool)) {
+            const jsonSchema = zodToGeminiParameters(structuredTool.schema);
+            return {
+              name: structuredTool.name,
+              description: structuredTool.description,
+              parameters: jsonSchema as GeminiFunctionSchema,
+            };
+          }
+          if (isOpenAITool(structuredTool)) {
+            return {
+              name: structuredTool.function.name,
+              description:
+                structuredTool.function.description ??
+                `A function available to call.`,
+              parameters: jsonSchemaToGeminiParameters(
+                structuredTool.function.parameters
+              ),
+            };
+          }
+          return structuredTool as unknown as GeminiFunctionDeclaration;
+        }
+      ),
+    },
+  ];
+}
+
 export function copyAIModelParamsInto(
   params: GoogleAIModelParams | undefined,
   options: GoogleAIBaseLanguageModelCallOptions | undefined,
@@ -46,66 +123,74 @@ export function copyAIModelParamsInto(
     params?.responseMimeType ?? target?.responseMimeType;
   ret.streaming = options?.streaming ?? params?.streaming ?? target?.streaming;
+  const toolChoice = processToolChoice(options?.tool_choice, options?.allowed_function_names);
+  if (toolChoice) {
+    ret.tool_choice = toolChoice.tool_choice;
+    ret.allowed_function_names = toolChoice.allowed_function_names;
+  }
 
-  ret.tools = options?.tools;
-  // Ensure tools are formatted properly for Gemini
-  const geminiTools = options?.tools
-    ?.map((tool) => {
-      if (
-        "function" in tool &&
-        // eslint-disable-next-line @typescript-eslint/no-explicit-any
-        "parameters" in (tool.function as Record)
-      ) {
-        // Tool is in OpenAI format. Convert to Gemini then return.
+  const tools = options?.tools;
+  if (tools) {
+    ret.tools = convertToGeminiTools(tools as Record[]);
+  }
+  // // Ensure tools are formatted properly for Gemini
+  // const geminiTools = options?.tools
+  //   ?.map((tool) => {
+  //     if (
+  //       "function" in tool &&
+  //       // eslint-disable-next-line @typescript-eslint/no-explicit-any
+  //       "parameters" in (tool.function as Record)
+  //     ) {
+  //       // Tool is in OpenAI format. Convert to Gemini then return.
-        // eslint-disable-next-line @typescript-eslint/no-explicit-any
-        const castTool = tool.function as Record;
-        const cleanedParameters = castTool.parameters;
-        if ("$schema" in cleanedParameters) {
-          delete cleanedParameters.$schema;
-        }
-        if ("additionalProperties" in cleanedParameters) {
-          delete cleanedParameters.additionalProperties;
-        }
-        const toolInGeminiFormat: GeminiTool = {
-          functionDeclarations: [
-            {
-              name: castTool.name,
-              description: castTool.description,
-              parameters: cleanedParameters,
-            },
-          ],
-        };
-        return toolInGeminiFormat;
-      } else if ("functionDeclarations" in tool) {
-        return tool;
-      } else {
-        return null;
-      }
-    })
-    .filter((tool): tool is GeminiTool => tool !== null);
+  //       // eslint-disable-next-line @typescript-eslint/no-explicit-any
+  //       const castTool = tool.function as Record;
+  //       const cleanedParameters = castTool.parameters;
+  //       if ("$schema" in cleanedParameters) {
+  //         delete cleanedParameters.$schema;
+  //       }
+  //       if ("additionalProperties" in cleanedParameters) {
+  //         delete cleanedParameters.additionalProperties;
+  //       }
+  //       const toolInGeminiFormat: GeminiTool = {
+  //         functionDeclarations: [
+  //           {
+  //             name: castTool.name,
+  //             description: castTool.description,
+  //             parameters: cleanedParameters,
+  //           },
+  //         ],
+  //       };
+  //       return toolInGeminiFormat;
+  //     } else if ("functionDeclarations" in tool) {
+  //       return tool;
+  //     } else {
+  //       return convertToGeminiTools([tool]);
+  //     }
+  //   })
+  //   .filter((tool): tool is GeminiTool => tool !== null);
 
-  const structuredOutputTools = options?.tools
-    ?.map((tool) => {
-      if ("lc_namespace" in tool) {
-        return tool;
-      } else {
-        return null;
-      }
-    })
-    .filter((tool): tool is StructuredToolInterface => tool !== null);
+  // const structuredOutputTools = options?.tools
+  //   ?.map((tool) => {
+  //     if ("lc_namespace" in tool) {
+  //       return tool;
+  //     } else {
+  //       return null;
+  //     }
+  //   })
+  //   .filter((tool): tool is StructuredToolInterface => tool !== null);
-  if (
-    structuredOutputTools &&
-    structuredOutputTools.length > 0 &&
-    geminiTools &&
-    geminiTools.length > 0
-  ) {
-    throw new Error(
-      `Cannot mix structured tools with Gemini tools.\nReceived ${structuredOutputTools.length} structured tools and ${geminiTools.length} Gemini tools.`
-    );
-  }
-  ret.tools = geminiTools ?? structuredOutputTools;
+  // if (
+  //   structuredOutputTools &&
+  //   structuredOutputTools.length > 0 &&
+  //   geminiTools &&
+  //   geminiTools.length > 0
+  // ) {
+  //   throw new Error(
+  //     `Cannot mix structured tools with Gemini tools.\nReceived ${structuredOutputTools.length} structured tools and ${geminiTools.length} Gemini tools.`
+  //   );
+  // }
+  // ret.tools = geminiTools ?? structuredOutputTools;
 
   return ret;
 }
diff --git a/libs/langchain-google-vertexai/package.json b/libs/langchain-google-vertexai/package.json
index 8d37a978879c..e1e62e633ea1 100644
--- a/libs/langchain-google-vertexai/package.json
+++ b/libs/langchain-google-vertexai/package.json
@@ -70,7 +70,8 @@
     "release-it": "^15.10.1",
     "rollup": "^4.5.2",
     "ts-jest": "^29.1.0",
-    "typescript": "<5.2.0"
+    "typescript": "<5.2.0",
+    "zod": "^3.22.3"
   },
   "publishConfig": {
     "access": "public"
diff --git a/libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts b/libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts
index 054462d7d1c0..618d0ea2d13a 100644
--- a/libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts
+++ b/libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts
@@ -13,6 +13,8 @@ import {
 } from "@langchain/core/messages";
 import { ChatVertexAI } from "../chat_models.js";
 import { GeminiTool } from "../types.js";
+import { tool } from "@langchain/core/tools";
+import { z } from "zod";
 
 describe("GAuth Chat", () => {
   test("invoke", async () => {
@@ -322,3 +324,38 @@ test("Streaming true constructor param will stream", async () => {
 
   expect(totalTokenCount).toBeGreaterThan(1);
 });
+
+test.only("tool_choice works", async () => {
+  const model = new ChatVertexAI({
+    model: "gemini-1.5-pro",
+  });
+  const weatherTool = tool((_) => {
+    return "no-op"
+  }, {
+    name: "get_weather",
+    description: "Get the weather of a specific location and return the temperature in Celsius.",
+    schema: z.object({
+      location: z.string().describe("The name of city to get the weather for."),
+    })
+  });
+  const calculatorTool = tool((_) => {
+    return "no-op"
+  }, {
+    name: "calculator",
+    description: "Calculate the result of a math expression.",
+    schema: z.object({
+      expression: z.string().describe("The math expression to calculate."),
+    })
+  });
+  const modelWithTools = model.bind({
+    tools: [weatherTool, calculatorTool],
+    tool_choice: "get_weather"
+  });
+  // const modelWithTools = model.bindTools([weatherTool]);
+
+  // const result = await modelWithTools.invoke("Whats the weather like in paris today?");
+  const result = await modelWithTools.invoke("Whats the weather like in paris today? also, what's 18628362 plus 18361?");
+  console.log(result);
+  expect(result.tool_calls).toBeDefined();
+  console.log(result.tool_calls);
+})
\ No newline at end of file
diff --git a/yarn.lock b/yarn.lock
index 3c34ea815cab..e8eaadd92600 100644
--- a/yarn.lock
+++ b/yarn.lock
@@ -11257,6 +11257,7 @@ __metadata:
     rollup: ^4.5.2
     ts-jest: ^29.1.0
    typescript: <5.2.0
+    zod: ^3.22.3
  languageName: unknown
  linkType: soft
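
A minimal usage sketch of the new tool_choice call option, distilled from the integration test above. The model name, tool definition, and prompt are illustrative only; "auto", "any", "none", or the name of a bound tool are all accepted, and a specific name is rewritten internally to mode "any" plus allowed_function_names.

import { ChatVertexAI } from "@langchain/google-vertexai";
import { tool } from "@langchain/core/tools";
import { z } from "zod";

// Illustrative tool; any structured tool bound via bind() works the same way.
const getWeather = tool(async () => "no-op", {
  name: "get_weather",
  description: "Get the weather of a specific location.",
  schema: z.object({
    location: z.string().describe("The name of the city to get the weather for."),
  }),
});

const model = new ChatVertexAI({ model: "gemini-1.5-pro" });

// "auto", "any", and "none" map directly onto functionCallingConfig.mode;
// a function name such as "get_weather" forces that specific call.
const modelWithTools = model.bind({
  tools: [getWeather],
  tool_choice: "get_weather",
});

const result = await modelWithTools.invoke("What's the weather like in Paris today?");
console.log(result.tool_calls);

With this binding, processToolChoice and formatToolConfig from this patch produce a request body containing toolConfig: { functionCallingConfig: { mode: "any", allowedFunctionNames: ["get_weather"] } }.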