Skip to content

Commit

Permalink
google-common[minor]: Add tool choice param
Browse files Browse the repository at this point in the history
  • Loading branch information
bracesproul committed Jul 24, 2024
1 parent a8e74c1 commit ff8f9b8
Show file tree
Hide file tree
Showing 7 changed files with 227 additions and 104 deletions.
48 changes: 1 addition & 47 deletions libs/langchain-google-common/src/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import {
BaseLanguageModelInput,
StructuredOutputMethodOptions,
ToolDefinition,
isOpenAITool,
} from "@langchain/core/language_models/base";
import type { z } from "zod";
import {
Expand All @@ -24,7 +23,6 @@ import {
} from "@langchain/core/runnables";
import { JsonOutputKeyToolsParser } from "@langchain/core/output_parsers/openai_tools";
import { BaseLLMOutputParser } from "@langchain/core/output_parsers";
import { isStructuredTool } from "@langchain/core/utils/function_calling";
import { AsyncCaller } from "@langchain/core/utils/async_caller";
import { StructuredToolInterface } from "@langchain/core/tools";
import { concat } from "@langchain/core/utils/stream";
Expand All @@ -39,6 +37,7 @@ import {
GoogleAIBaseLanguageModelCallOptions,
} from "./types.js";
import {
convertToGeminiTools,
copyAIModelParams,
copyAndValidateModelParamsInto,
} from "./utils/common.js";
Expand All @@ -60,7 +59,6 @@ import type {
GeminiFunctionSchema,
} from "./types.js";
import {
jsonSchemaToGeminiParameters,
zodToGeminiParameters,
} from "./utils/zod_to_gemini_parameters.js";

Expand Down Expand Up @@ -160,44 +158,6 @@ export interface ChatGoogleBaseInput<AuthOptions>
GoogleAISafetyParams,
Pick<GoogleAIBaseLanguageModelCallOptions, "streamUsage"> {}

function convertToGeminiTools(
structuredTools: (
| StructuredToolInterface
| Record<string, unknown>
| ToolDefinition
| RunnableToolLike
)[]
): GeminiTool[] {
return [
{
functionDeclarations: structuredTools.map(
(structuredTool): GeminiFunctionDeclaration => {
if (isStructuredTool(structuredTool)) {
const jsonSchema = zodToGeminiParameters(structuredTool.schema);
return {
name: structuredTool.name,
description: structuredTool.description,
parameters: jsonSchema as GeminiFunctionSchema,
};
}
if (isOpenAITool(structuredTool)) {
return {
name: structuredTool.function.name,
description:
structuredTool.function.description ??
`A function available to call.`,
parameters: jsonSchemaToGeminiParameters(
structuredTool.function.parameters
),
};
}
return structuredTool as unknown as GeminiFunctionDeclaration;
}
),
},
];
}

/**
* Integration with a chat model.
*/
Expand Down Expand Up @@ -342,12 +302,6 @@ export abstract class ChatGoogleBase<AuthOptions>
* Get the parameters used to invoke the model
*/
override invocationParams(options?: this["ParsedCallOptions"]) {
if (options?.tool_choice) {
throw new Error(
`'tool_choice' call option is not supported by ${this.getName()}.`
);
}

return copyAIModelParams(this, options);
}

Expand Down
21 changes: 21 additions & 0 deletions libs/langchain-google-common/src/connection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ export abstract class GoogleConnection<
if (data && method === "POST") {
opts.data = data;
}
console.log("data", data)
if (this.streaming) {
opts.responseType = "stream";
} else {
Expand Down Expand Up @@ -350,13 +351,29 @@ export abstract class AbstractGoogleLLMConnection<
}
}

formatToolConfig(
parameters: GoogleAIModelRequestParams
): GeminiRequest["toolConfig"] | undefined {
if (!parameters.tool_choice || typeof parameters.tool_choice !== "string") {
return undefined;
}

return {
functionCallingConfig: {
mode: parameters.tool_choice as "auto" | "any" | "none",
allowedFunctionNames: parameters.allowed_function_names,
}
}
}

formatData(
input: MessageType,
parameters: GoogleAIModelRequestParams
): GeminiRequest {
const contents = this.formatContents(input, parameters);
const generationConfig = this.formatGenerationConfig(input, parameters);
const tools = this.formatTools(input, parameters);
const toolConfig = this.formatToolConfig(parameters);
const safetySettings = this.formatSafetySettings(input, parameters);
const systemInstruction = this.formatSystemInstruction(input, parameters);

Expand All @@ -365,8 +382,12 @@ export abstract class AbstractGoogleLLMConnection<
generationConfig,
};
if (tools && tools.length) {
console.log("HAVE TOOLS!!!!!!!")
ret.tools = tools;
}
if (toolConfig) {
ret.toolConfig = toolConfig;
}
if (safetySettings && safetySettings.length) {
ret.safetySettings = safetySettings;
}
Expand Down
24 changes: 24 additions & 0 deletions libs/langchain-google-common/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,24 @@ export interface GoogleAIModelParams {
*/
export interface GoogleAIModelRequestParams extends GoogleAIModelParams {
tools?: StructuredToolInterface[] | GeminiTool[];
/**
* Force the model to use tools in a specific way.
*
* | Mode | Description |
* |----------|---------------------------------------------------------------------------------------------------------------------------------------------------------|
* | "auto" | The default model behavior. The model decides whether to predict a function call or a natural language response. |
* | "any" | The model must predict only function calls. To limit the model to a subset of functions, define the allowed function names in `allowed_function_names`. |
* | "none" | The model must not predict function calls. This behavior is equivalent to a model request without any associated function declarations. |
* | string | The string value must be one of the function names. This will force the model to predict the specified function call. |
*
* The tool configuration's "any" mode ("forced function calling") is supported for Gemini 1.5 Pro models only.
*/
tool_choice?: string | "auto" | "any" | "none" | Record<string, any>;
/**
* Allowed functions to call when the mode is "any".
* If empty, any one of the provided functions are called.
*/
allowed_function_names?: string[];
}

export interface GoogleAIBaseLLMInput<AuthOptions>
Expand Down Expand Up @@ -251,6 +269,12 @@ export interface GeminiRequest {
contents?: GeminiContent[];
systemInstruction?: GeminiContent;
tools?: GeminiTool[];
toolConfig?: {
functionCallingConfig: {
mode: "auto" | "any" | "none";
allowedFunctionNames?: string[];
}
}
safetySettings?: GeminiSafetySetting[];
generationConfig?: GeminiGenerationConfig;
}
Expand Down
Loading

0 comments on commit ff8f9b8

Please sign in to comment.