From a1530dafca0ac1734f9c02382f5825de011aa158 Mon Sep 17 00:00:00 2001 From: Allen Firstenberg Date: Mon, 11 Nov 2024 19:06:33 -0500 Subject: [PATCH] feat(google-vertexai): Support Non-Google and Model Garden models in Vertex AI - Anthropic integration (#6999) Co-authored-by: jacoblee93 Co-authored-by: bracesproul --- .../integrations/chat/google_vertex_ai.ipynb | 6 +- .../docs/integrations/platforms/google.mdx | 13 +- .../src/chat_models.ts | 121 ++- .../langchain-google-common/src/connection.ts | 414 ++++++---- libs/langchain-google-common/src/llms.ts | 12 +- .../src/tests/chat_models.test.ts | 213 +++++- .../src/tests/data/chat-2-mock.json | 8 - .../src/tests/data/claude-chat-1-mock.json | 18 + .../src/tests/data/claude-chat-1-mock.sse | 267 +++++++ .../src/tests/utils.test.ts | 150 ++-- .../src/types-anthropic.ts | 237 ++++++ libs/langchain-google-common/src/types.ts | 118 ++- .../src/utils/anthropic.ts | 719 ++++++++++++++++++ .../src/utils/common.ts | 26 +- .../src/utils/gemini.ts | 528 +++++++++---- .../src/utils/stream.ts | 201 ++++- libs/langchain-google-gauth/src/auth.ts | 67 +- .../src/tests/chat_models.int.test.ts | 10 +- .../src/tests/chat_models.int.test.ts | 497 +++++++----- .../src/tests/chat_models.int.test.ts | 6 +- 20 files changed, 2957 insertions(+), 674 deletions(-) create mode 100644 libs/langchain-google-common/src/tests/data/claude-chat-1-mock.json create mode 100644 libs/langchain-google-common/src/tests/data/claude-chat-1-mock.sse create mode 100644 libs/langchain-google-common/src/types-anthropic.ts create mode 100644 libs/langchain-google-common/src/utils/anthropic.ts diff --git a/docs/core_docs/docs/integrations/chat/google_vertex_ai.ipynb b/docs/core_docs/docs/integrations/chat/google_vertex_ai.ipynb index 158e71453fcb..d4de68c3f5e2 100644 --- a/docs/core_docs/docs/integrations/chat/google_vertex_ai.ipynb +++ b/docs/core_docs/docs/integrations/chat/google_vertex_ai.ipynb @@ -21,7 +21,9 @@ "source": [ "# ChatVertexAI\n", "\n", - "[Google Vertex](https://cloud.google.com/vertex-ai) is a service that exposes all foundation models available in Google Cloud, like `gemini-1.5-pro`, `gemini-1.5-flash`, etc.\n", + "[Google Vertex](https://cloud.google.com/vertex-ai) is a service that exposes all foundation models available in Google Cloud, like `gemini-1.5-pro`, `gemini-1.5-flash`, etc.", + "It also provides some non-Google models such as [Anthropic's Claude](https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-claude).", + "\n", "\n", "This will help you getting started with `ChatVertexAI` [chat models](/docs/concepts/chat_models). For detailed documentation of all `ChatVertexAI` features and configurations head to the [API reference](https://api.js.langchain.com/classes/langchain_google_vertexai.ChatVertexAI.html).\n", "\n", @@ -279,4 +281,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/docs/core_docs/docs/integrations/platforms/google.mdx b/docs/core_docs/docs/integrations/platforms/google.mdx index 00ff6503538d..8460c654bf98 100644 --- a/docs/core_docs/docs/integrations/platforms/google.mdx +++ b/docs/core_docs/docs/integrations/platforms/google.mdx @@ -10,7 +10,7 @@ Functionality related to [Google Cloud Platform](https://cloud.google.com/) ### Gemini Models -Access Gemini models such as `gemini-pro` and `gemini-pro-vision` through the [`ChatGoogleGenerativeAI`](/docs/integrations/chat/google_generativeai), +Access Gemini models such as `gemini-1.5-pro` and `gemini-1.5-flex` through the [`ChatGoogleGenerativeAI`](/docs/integrations/chat/google_generativeai), or if using VertexAI, via the [`ChatVertexAI`](/docs/integrations/chat/google_vertex_ai) class. import Tabs from "@theme/Tabs"; @@ -153,6 +153,17 @@ Click [here](/docs/integrations/chat/google_vertex_ai) for the `@langchain/googl The value of `image_url` must be a base64 encoded image (e.g., ``). +### Non-Gemini Models + +See above for setting up authentication through Vertex AI to use these models. + +[Anthropic](/docs/integrations/chat/anthropic) Claude models are also available through +the [Vertex AI](https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-claude) +platform. See [here](https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-claude) +for more information about enabling access to the models and the model names to use. + +PaLM models are no longer supported. + ## Vector Store ### Vertex AI Vector Search diff --git a/libs/langchain-google-common/src/chat_models.ts b/libs/langchain-google-common/src/chat_models.ts index d770c25b026f..1de5280fbe72 100644 --- a/libs/langchain-google-common/src/chat_models.ts +++ b/libs/langchain-google-common/src/chat_models.ts @@ -29,9 +29,10 @@ import { GoogleAISafetySetting, GoogleConnectionParams, GooglePlatformType, - GeminiContent, GeminiTool, GoogleAIBaseLanguageModelCallOptions, + GoogleAIAPI, + GoogleAIAPIParams, } from "./types.js"; import { convertToGeminiTools, @@ -39,7 +40,7 @@ import { copyAndValidateModelParamsInto, } from "./utils/common.js"; import { AbstractGoogleLLMConnection } from "./connection.js"; -import { DefaultGeminiSafetyHandler } from "./utils/gemini.js"; +import { DefaultGeminiSafetyHandler, getGeminiAPI } from "./utils/gemini.js"; import { ApiKeyGoogleAuth, GoogleAbstractedClient } from "./auth.js"; import { JsonStream } from "./utils/stream.js"; import { ensureParams } from "./utils/failed_handler.js"; @@ -96,71 +97,21 @@ export class ChatConnection extends AbstractGoogleLLMConnection< return true; } - async formatContents( - input: BaseMessage[], - _parameters: GoogleAIModelParams - ): Promise { - const inputPromises: Promise[] = input.map((msg, i) => - this.api.baseMessageToContent( - msg, - input[i - 1], - this.useSystemInstruction - ) - ); - const inputs = await Promise.all(inputPromises); - - return inputs.reduce((acc, cur) => { - // Filter out the system content - if (cur.every((content) => content.role === "system")) { - return acc; - } - - // Combine adjacent function messages - if ( - cur[0]?.role === "function" && - acc.length > 0 && - acc[acc.length - 1].role === "function" - ) { - acc[acc.length - 1].parts = [ - ...acc[acc.length - 1].parts, - ...cur[0].parts, - ]; - } else { - acc.push(...cur); - } - - return acc; - }, [] as GeminiContent[]); + buildGeminiAPI(): GoogleAIAPI { + const geminiConfig: GeminiAPIConfig = { + useSystemInstruction: this.useSystemInstruction, + ...(this.apiConfig as GeminiAPIConfig), + }; + return getGeminiAPI(geminiConfig); } - async formatSystemInstruction( - input: BaseMessage[], - _parameters: GoogleAIModelParams - ): Promise { - if (!this.useSystemInstruction) { - return {} as GeminiContent; + get api(): GoogleAIAPI { + switch (this.apiName) { + case "google": + return this.buildGeminiAPI(); + default: + return super.api; } - - let ret = {} as GeminiContent; - for (let index = 0; index < input.length; index += 1) { - const message = input[index]; - if (message._getType() === "system") { - // For system types, we only want it if it is the first message, - // if it appears anywhere else, it should be an error. - if (index === 0) { - // eslint-disable-next-line prefer-destructuring - ret = ( - await this.api.baseMessageToContent(message, undefined, true) - )[0]; - } else { - throw new Error( - "System messages are only permitted as the first passed message." - ); - } - } - } - - return ret; } } @@ -172,7 +123,7 @@ export interface ChatGoogleBaseInput GoogleConnectionParams, GoogleAIModelParams, GoogleAISafetyParams, - GeminiAPIConfig, + GoogleAIAPIParams, Pick {} /** @@ -341,13 +292,14 @@ export abstract class ChatGoogleBase const response = await this.connection.request( messages, parameters, - options + options, + runManager ); - const ret = this.connection.api.safeResponseToChatResult( - response, - this.safetyHandler - ); - await runManager?.handleLLMNewToken(ret.generations[0].text); + const ret = this.connection.api.responseToChatResult(response); + const chunk = ret?.generations?.[0]; + if (chunk) { + await runManager?.handleLLMNewToken(chunk.text || ""); + } return ret; } @@ -361,7 +313,8 @@ export abstract class ChatGoogleBase const response = await this.streamedConnection.request( _messages, parameters, - options + options, + runManager ); // Get the streaming parser of the response @@ -372,6 +325,12 @@ export abstract class ChatGoogleBase // that is either available or added to the queue while (!stream.streamDone) { const output = await stream.nextChunk(); + await runManager?.handleCustomEvent( + `google-chunk-${this.constructor.name}`, + { + output, + } + ); if ( output && output.usageMetadata && @@ -386,10 +345,7 @@ export abstract class ChatGoogleBase } const chunk = output !== null - ? this.connection.api.safeResponseToChatGeneration( - { data: output }, - this.safetyHandler - ) + ? this.connection.api.responseToChatGeneration({ data: output }) : new ChatGenerationChunk({ text: "", generationInfo: { finishReason: "stop" }, @@ -398,8 +354,17 @@ export abstract class ChatGoogleBase usage_metadata: usageMetadata, }), }); - yield chunk; - await runManager?.handleLLMNewToken(chunk.text); + if (chunk) { + yield chunk; + await runManager?.handleLLMNewToken( + chunk.text ?? "", + undefined, + undefined, + undefined, + undefined, + { chunk } + ); + } } } diff --git a/libs/langchain-google-common/src/connection.ts b/libs/langchain-google-common/src/connection.ts index 7e7da9daa304..5a1c1fa494ae 100644 --- a/libs/langchain-google-common/src/connection.ts +++ b/libs/langchain-google-common/src/connection.ts @@ -1,35 +1,37 @@ +/* eslint-disable @typescript-eslint/no-explicit-any */ import { BaseLanguageModelCallOptions } from "@langchain/core/language_models/base"; import { AsyncCaller, AsyncCallerCallOptions, } from "@langchain/core/utils/async_caller"; import { getRuntimeEnvironment } from "@langchain/core/utils/env"; -import { StructuredToolParams } from "@langchain/core/tools"; -import { isLangChainTool } from "@langchain/core/utils/function_calling"; +import { BaseRunManager } from "@langchain/core/callbacks/manager"; +import { BaseCallbackHandler } from "@langchain/core/callbacks/base"; import type { GoogleAIBaseLLMInput, GoogleConnectionParams, - GoogleLLMModelFamily, GooglePlatformType, GoogleResponse, GoogleLLMResponse, - GeminiContent, - GeminiGenerationConfig, - GeminiRequest, - GeminiSafetySetting, - GeminiTool, - GeminiFunctionDeclaration, GoogleAIModelRequestParams, GoogleRawResponse, - GoogleAIToolType, + GoogleAIAPI, + VertexModelFamily, + GoogleAIAPIConfig, + AnthropicAPIConfig, + GeminiAPIConfig, } from "./types.js"; import { GoogleAbstractedClient, GoogleAbstractedClientOps, GoogleAbstractedClientOpsMethod, } from "./auth.js"; -import { zodToGeminiParameters } from "./utils/zod_to_gemini_parameters.js"; -import { getGeminiAPI } from "./utils/index.js"; +import { + getGeminiAPI, + modelToFamily, + modelToPublisher, +} from "./utils/index.js"; +import { getAnthropicAPI } from "./utils/anthropic.js"; export abstract class GoogleConnection< CallOptions extends AsyncCallerCallOptions, @@ -148,9 +150,9 @@ export abstract class GoogleHostConnection< // Use the "platform" getter if you need this. platformType: GooglePlatformType | undefined; - endpoint = "us-central1-aiplatform.googleapis.com"; + _endpoint: string | undefined; - location = "us-central1"; + _location: string | undefined; apiVersion = "v1"; @@ -164,8 +166,8 @@ export abstract class GoogleHostConnection< this.caller = caller; this.platformType = fields?.platformType; - this.endpoint = fields?.endpoint ?? this.endpoint; - this.location = fields?.location ?? this.location; + this._endpoint = fields?.endpoint; + this._location = fields?.location; this.apiVersion = fields?.apiVersion ?? this.apiVersion; this.client = client; } @@ -178,6 +180,22 @@ export abstract class GoogleHostConnection< return "gcp"; } + get location(): string { + return this._location ?? this.computedLocation; + } + + get computedLocation(): string { + return "us-central1"; + } + + get endpoint(): string { + return this._endpoint ?? this.computedEndpoint; + } + + get computedEndpoint(): string { + return `${this.location}-aiplatform.googleapis.com`; + } + buildMethod(): GoogleAbstractedClientOpsMethod { return "POST"; } @@ -213,8 +231,9 @@ export abstract class GoogleAIConnection< client: GoogleAbstractedClient; - // eslint-disable-next-line @typescript-eslint/no-explicit-any - api: any; // FIXME: Make this a real type + _apiName?: string; + + apiConfig?: GoogleAIAPIConfig; constructor( fields: GoogleAIBaseLLMInput | undefined, @@ -226,14 +245,39 @@ export abstract class GoogleAIConnection< this.client = client; this.modelName = fields?.model ?? fields?.modelName ?? this.model; this.model = this.modelName; - this.api = getGeminiAPI(fields); + + this._apiName = fields?.apiName; + this.apiConfig = { + safetyHandler: fields?.safetyHandler, // For backwards compatibility + ...fields?.apiConfig, + }; } - get modelFamily(): GoogleLLMModelFamily { - if (this.model.startsWith("gemini")) { - return "gemini"; - } else { - return null; + get modelFamily(): VertexModelFamily { + return modelToFamily(this.model); + } + + get modelPublisher(): string { + return modelToPublisher(this.model); + } + + get computedAPIName(): string { + // At least at the moment, model publishers and APIs map the same + return this.modelPublisher; + } + + get apiName(): string { + return this._apiName ?? this.computedAPIName; + } + + get api(): GoogleAIAPI { + switch (this.apiName) { + case "google": + return getGeminiAPI(this.apiConfig as GeminiAPIConfig); + case "anthropic": + return getAnthropicAPI(this.apiConfig as AnthropicAPIConfig); + default: + throw new Error(`Unknown API: ${this.apiName}`); } } @@ -245,6 +289,19 @@ export abstract class GoogleAIConnection< } } + get computedLocation(): string { + switch (this.apiName) { + case "google": + return super.computedLocation; + case "anthropic": + return "us-east5"; + default: + throw new Error( + `Unknown apiName: ${this.apiName}. Can't get location.` + ); + } + } + abstract buildUrlMethod(): Promise; async buildUrlGenerativeLanguage(): Promise { @@ -256,7 +313,8 @@ export abstract class GoogleAIConnection< async buildUrlVertex(): Promise { const projectId = await this.client.getProjectId(); const method = await this.buildUrlMethod(); - const url = `https://${this.endpoint}/${this.apiVersion}/projects/${projectId}/locations/${this.location}/publishers/google/models/${this.model}:${method}`; + const publisher = this.modelPublisher; + const url = `https://${this.endpoint}/${this.apiVersion}/projects/${projectId}/locations/${this.location}/publishers/${publisher}/models/${this.model}:${method}`; return url; } @@ -277,10 +335,37 @@ export abstract class GoogleAIConnection< async request( input: InputType, parameters: GoogleAIModelRequestParams, - options: CallOptions - ): Promise { - const data = await this.formatData(input, parameters); + + options: CallOptions, + runManager?: BaseRunManager + ): Promise { + const moduleName = this.constructor.name; + const streamingParameters: GoogleAIModelRequestParams = { + ...parameters, + streaming: this.streaming, + }; + const data = await this.formatData(input, streamingParameters); + + await runManager?.handleCustomEvent(`google-request-${moduleName}`, { + data, + parameters: streamingParameters, + options, + connection: { + ...this, + url: await this.buildUrl(), + urlMethod: await this.buildUrlMethod(), + modelFamily: this.modelFamily, + modelPublisher: this.modelPublisher, + computedPlatformType: this.computedPlatformType, + }, + }); + const response = await this._request(data, options); + + await runManager?.handleCustomEvent(`google-response-${moduleName}`, { + response, + }); + return response; } } @@ -298,141 +383,202 @@ export abstract class AbstractGoogleLLMConnection< return this.streaming ? "streamGenerateContent" : "generateContent"; } + async buildUrlMethodClaude(): Promise { + return this.streaming ? "streamRawPredict" : "rawPredict"; + } + async buildUrlMethod(): Promise { switch (this.modelFamily) { case "gemini": return this.buildUrlMethodGemini(); + case "claude": + return this.buildUrlMethodClaude(); default: throw new Error(`Unknown model family: ${this.modelFamily}`); } } - abstract formatContents( + async formatData( input: MessageType, parameters: GoogleAIModelRequestParams - ): Promise; + ): Promise { + return this.api.formatData(input, parameters); + } +} - formatGenerationConfig( - _input: MessageType, - parameters: GoogleAIModelRequestParams - ): GeminiGenerationConfig { +export interface GoogleCustomEventInfo { + subEvent: string; + module: string; +} + +export abstract class GoogleRequestCallbackHandler extends BaseCallbackHandler { + customEventInfo(eventName: string): GoogleCustomEventInfo { + const names = eventName.split("-"); return { - temperature: parameters.temperature, - topK: parameters.topK, - topP: parameters.topP, - maxOutputTokens: parameters.maxOutputTokens, - stopSequences: parameters.stopSequences, - responseMimeType: parameters.responseMimeType, + subEvent: names[1], + module: names[2], }; } - formatSafetySettings( - _input: MessageType, - parameters: GoogleAIModelRequestParams - ): GeminiSafetySetting[] { - return parameters.safetySettings ?? []; + abstract handleCustomRequestEvent( + eventName: string, + eventInfo: GoogleCustomEventInfo, + data: any, + runId: string, + tags?: string[], + metadata?: Record + ): any; + + abstract handleCustomResponseEvent( + eventName: string, + eventInfo: GoogleCustomEventInfo, + data: any, + runId: string, + tags?: string[], + metadata?: Record + ): any; + + abstract handleCustomChunkEvent( + eventName: string, + eventInfo: GoogleCustomEventInfo, + data: any, + runId: string, + tags?: string[], + metadata?: Record + ): any; + + handleCustomEvent( + eventName: string, + data: any, + runId: string, + tags?: string[], + metadata?: Record + ): any { + if (!eventName) { + return undefined; + } + const eventInfo = this.customEventInfo(eventName); + switch (eventInfo.subEvent) { + case "request": + return this.handleCustomRequestEvent( + eventName, + eventInfo, + data, + runId, + tags, + metadata + ); + case "response": + return this.handleCustomResponseEvent( + eventName, + eventInfo, + data, + runId, + tags, + metadata + ); + case "chunk": + return this.handleCustomChunkEvent( + eventName, + eventInfo, + data, + runId, + tags, + metadata + ); + default: + console.error( + `Unexpected eventInfo for ${eventName} ${JSON.stringify( + eventInfo, + null, + 1 + )}` + ); + } } +} - async formatSystemInstruction( - _input: MessageType, - _parameters: GoogleAIModelRequestParams - ): Promise { - return {} as GeminiContent; - } +export class GoogleRequestLogger extends GoogleRequestCallbackHandler { + name: string = "GoogleRequestLogger"; - structuredToolToFunctionDeclaration( - tool: StructuredToolParams - ): GeminiFunctionDeclaration { - const jsonSchema = zodToGeminiParameters(tool.schema); - return { - name: tool.name, - description: tool.description ?? `A function available to call.`, - parameters: jsonSchema, - }; + log(eventName: string, data: any, tags?: string[]): undefined { + const tagStr = tags ? `[${tags}]` : "[]"; + console.log(`${eventName} ${tagStr} ${JSON.stringify(data, null, 1)}`); } - structuredToolsToGeminiTools(tools: StructuredToolParams[]): GeminiTool[] { - return [ - { - functionDeclarations: tools.map( - this.structuredToolToFunctionDeclaration - ), - }, - ]; + handleCustomRequestEvent( + eventName: string, + _eventInfo: GoogleCustomEventInfo, + data: any, + _runId: string, + tags?: string[], + _metadata?: Record + ): any { + this.log(eventName, data, tags); } - formatTools( - _input: MessageType, - parameters: GoogleAIModelRequestParams - ): GeminiTool[] { - const tools: GoogleAIToolType[] | undefined = parameters?.tools; - if (!tools || tools.length === 0) { - return []; - } + handleCustomResponseEvent( + eventName: string, + _eventInfo: GoogleCustomEventInfo, + data: any, + _runId: string, + tags?: string[], + _metadata?: Record + ): any { + this.log(eventName, data, tags); + } - if (tools.every(isLangChainTool)) { - return this.structuredToolsToGeminiTools(tools); - } else { - if ( - tools.length === 1 && - (!("functionDeclarations" in tools[0]) || - !tools[0].functionDeclarations?.length) - ) { - return []; - } - return tools as GeminiTool[]; - } + handleCustomChunkEvent( + eventName: string, + _eventInfo: GoogleCustomEventInfo, + data: any, + _runId: string, + tags?: string[], + _metadata?: Record + ): any { + this.log(eventName, data, tags); } +} - formatToolConfig( - parameters: GoogleAIModelRequestParams - ): GeminiRequest["toolConfig"] | undefined { - if (!parameters.tool_choice || typeof parameters.tool_choice !== "string") { - return undefined; - } +export class GoogleRequestRecorder extends GoogleRequestCallbackHandler { + name = "GoogleRequestRecorder"; - return { - functionCallingConfig: { - mode: parameters.tool_choice as "auto" | "any" | "none", - allowedFunctionNames: parameters.allowed_function_names, - }, - }; + request: any = {}; + + response: any = {}; + + chunk: any[] = []; + + handleCustomRequestEvent( + _eventName: string, + _eventInfo: GoogleCustomEventInfo, + data: any, + _runId: string, + _tags?: string[], + _metadata?: Record + ): any { + this.request = data; } - async formatData( - input: MessageType, - parameters: GoogleAIModelRequestParams - ): Promise { - const contents = await this.formatContents(input, parameters); - const generationConfig = this.formatGenerationConfig(input, parameters); - const tools = this.formatTools(input, parameters); - const toolConfig = this.formatToolConfig(parameters); - const safetySettings = this.formatSafetySettings(input, parameters); - const systemInstruction = await this.formatSystemInstruction( - input, - parameters - ); + handleCustomResponseEvent( + _eventName: string, + _eventInfo: GoogleCustomEventInfo, + data: any, + _runId: string, + _tags?: string[], + _metadata?: Record + ): any { + this.response = data; + } - const ret: GeminiRequest = { - contents, - generationConfig, - }; - if (tools && tools.length) { - ret.tools = tools; - } - if (toolConfig) { - ret.toolConfig = toolConfig; - } - if (safetySettings && safetySettings.length) { - ret.safetySettings = safetySettings; - } - if ( - systemInstruction?.role && - systemInstruction?.parts && - systemInstruction?.parts?.length - ) { - ret.systemInstruction = systemInstruction; - } - return ret; + handleCustomChunkEvent( + _eventName: string, + _eventInfo: GoogleCustomEventInfo, + data: any, + _runId: string, + _tags?: string[], + _metadata?: Record + ): any { + this.chunk.push(data); } } diff --git a/libs/langchain-google-common/src/llms.ts b/libs/langchain-google-common/src/llms.ts index b359a41e7d45..ad74c74e4ac3 100644 --- a/libs/langchain-google-common/src/llms.ts +++ b/libs/langchain-google-common/src/llms.ts @@ -37,7 +37,7 @@ class GoogleLLMConnection extends AbstractGoogleLLMConnection< input: MessageContent, _parameters: GoogleAIModelParams ): Promise { - const parts = await this.api.messageContentToParts(input); + const parts = await this.api.messageContentToParts!(input); const contents: GeminiContent[] = [ { role: "user", // Required by Vertex AI @@ -189,10 +189,7 @@ export abstract class GoogleBaseLLM ): Promise { const parameters = copyAIModelParams(this, options); const result = await this.connection.request(prompt, parameters, options); - const ret = this.connection.api.safeResponseToString( - result, - this.safetyHandler - ); + const ret = this.connection.api.responseToString(result); return ret; } @@ -270,10 +267,7 @@ export abstract class GoogleBaseLLM {}, options as BaseLanguageModelCallOptions ); - const ret = this.connection.api.safeResponseToBaseMessage( - result, - this.safetyHandler - ); + const ret = this.connection.api.responseToBaseMessage(result); return ret; } diff --git a/libs/langchain-google-common/src/tests/chat_models.test.ts b/libs/langchain-google-common/src/tests/chat_models.test.ts index 9da477df3e0e..aa15be74ed79 100644 --- a/libs/langchain-google-common/src/tests/chat_models.test.ts +++ b/libs/langchain-google-common/src/tests/chat_models.test.ts @@ -1,3 +1,4 @@ +/* eslint-disable @typescript-eslint/no-explicit-any */ import { expect, test } from "@jest/globals"; import { AIMessage, @@ -10,12 +11,19 @@ import { ToolMessage, } from "@langchain/core/messages"; import { InMemoryStore } from "@langchain/core/stores"; - +import { CallbackHandlerMethods } from "@langchain/core/callbacks/base"; +import { Serialized } from "@langchain/core/load/serializable"; import { z } from "zod"; import { zodToJsonSchema } from "zod-to-json-schema"; import { ChatGoogleBase, ChatGoogleBaseInput } from "../chat_models.js"; import { authOptions, MockClient, MockClientAuthInfo, mockId } from "./mock.js"; -import { GeminiTool, GoogleAIBaseLLMInput } from "../types.js"; +import { + GeminiTool, + GoogleAIBaseLLMInput, + GoogleAISafetyCategory, + GoogleAISafetyHandler, + GoogleAISafetyThreshold, +} from "../types.js"; import { GoogleAbstractedClient } from "../auth.js"; import { GoogleAISafetyError } from "../utils/safety.js"; import { @@ -25,6 +33,7 @@ import { ReadThroughBlobStore, } from "../experimental/utils/media_core.js"; import { removeAdditionalProperties } from "../utils/zod_to_gemini_parameters.js"; +import { MessageGeminiSafetyHandler } from "../utils/index.js"; class ChatGoogle extends ChatGoogleBase { constructor(fields?: ChatGoogleBaseInput) { @@ -39,7 +48,7 @@ class ChatGoogle extends ChatGoogleBase { } } -describe("Mock ChatGoogle", () => { +describe("Mock ChatGoogle - Gemini", () => { test("Setting invalid model parameters", async () => { expect(() => { const model = new ChatGoogle({ @@ -71,7 +80,6 @@ describe("Mock ChatGoogle", () => { }); test("user agent header", async () => { - // eslint-disable-next-line @typescript-eslint/no-explicit-any const record: Record = {}; const projectId = mockId(); const authOptions: MockClientAuthInfo = { @@ -100,7 +108,6 @@ describe("Mock ChatGoogle", () => { }); test("platform default", async () => { - // eslint-disable-next-line @typescript-eslint/no-explicit-any const record: Record = {}; const projectId = mockId(); const authOptions: MockClientAuthInfo = { @@ -115,7 +122,6 @@ describe("Mock ChatGoogle", () => { }); test("platform set", async () => { - // eslint-disable-next-line @typescript-eslint/no-explicit-any const record: Record = {}; const projectId = mockId(); const authOptions: MockClientAuthInfo = { @@ -131,7 +137,6 @@ describe("Mock ChatGoogle", () => { }); test("1. Basic request format", async () => { - // eslint-disable-next-line @typescript-eslint/no-explicit-any const record: Record = {}; const projectId = mockId(); const authOptions: MockClientAuthInfo = { @@ -165,7 +170,6 @@ describe("Mock ChatGoogle", () => { }); test("1. Invoke request format", async () => { - // eslint-disable-next-line @typescript-eslint/no-explicit-any const record: Record = {}; const projectId = mockId(); const authOptions: MockClientAuthInfo = { @@ -199,7 +203,6 @@ describe("Mock ChatGoogle", () => { }); test("1. Response format", async () => { - // eslint-disable-next-line @typescript-eslint/no-explicit-any const record: Record = {}; const projectId = mockId(); const authOptions: MockClientAuthInfo = { @@ -224,7 +227,6 @@ describe("Mock ChatGoogle", () => { }); test("1. Invoke response format", async () => { - // eslint-disable-next-line @typescript-eslint/no-explicit-any const record: Record = {}; const projectId = mockId(); const authOptions: MockClientAuthInfo = { @@ -252,7 +254,6 @@ describe("Mock ChatGoogle", () => { // SystemMessages will be turned into the human request with the prompt // from the system message and a faked ai response saying "Ok". test("1. System request format old model", async () => { - // eslint-disable-next-line @typescript-eslint/no-explicit-any const record: Record = {}; const projectId = mockId(); const authOptions: MockClientAuthInfo = { @@ -293,7 +294,6 @@ describe("Mock ChatGoogle", () => { }); test("1. System request format convert true", async () => { - // eslint-disable-next-line @typescript-eslint/no-explicit-any const record: Record = {}; const projectId = mockId(); const authOptions: MockClientAuthInfo = { @@ -334,7 +334,6 @@ describe("Mock ChatGoogle", () => { }); test("1. System request format convert false", async () => { - // eslint-disable-next-line @typescript-eslint/no-explicit-any const record: Record = {}; const projectId = mockId(); const authOptions: MockClientAuthInfo = { @@ -373,7 +372,6 @@ describe("Mock ChatGoogle", () => { }); test("1. System request format new model", async () => { - // eslint-disable-next-line @typescript-eslint/no-explicit-any const record: Record = {}; const projectId = mockId(); const authOptions: MockClientAuthInfo = { @@ -412,7 +410,6 @@ describe("Mock ChatGoogle", () => { }); test("1. System request - multiple", async () => { - // eslint-disable-next-line @typescript-eslint/no-explicit-any const record: Record = {}; const projectId = mockId(); const authOptions: MockClientAuthInfo = { @@ -444,7 +441,6 @@ describe("Mock ChatGoogle", () => { }); test("1. System request - not first", async () => { - // eslint-disable-next-line @typescript-eslint/no-explicit-any const record: Record = {}; const projectId = mockId(); const authOptions: MockClientAuthInfo = { @@ -472,8 +468,7 @@ describe("Mock ChatGoogle", () => { expect(caught).toBeTruthy(); }); - test("2. Response format - safety", async () => { - // eslint-disable-next-line @typescript-eslint/no-explicit-any + test("2. Safety - settings", async () => { const record: Record = {}; const projectId = mockId(); const authOptions: MockClientAuthInfo = { @@ -483,6 +478,12 @@ describe("Mock ChatGoogle", () => { }; const model = new ChatGoogle({ authOptions, + safetySettings: [ + { + category: GoogleAISafetyCategory.Harassment, + threshold: GoogleAISafetyThreshold.Most, + }, + ], }); const messages: BaseMessageLike[] = [ new HumanMessage("Flip a coin and tell me H for heads and T for tails"), @@ -492,25 +493,88 @@ describe("Mock ChatGoogle", () => { let caught = false; try { await model.invoke(messages); + } catch (xx: any) { + caught = true; + } + + const settings = record?.opts?.data?.safetySettings; + expect(settings).toBeDefined(); + expect(Array.isArray(settings)).toEqual(true); + expect(settings).toHaveLength(1); + expect(settings[0].category).toEqual("HARM_CATEGORY_HARASSMENT"); + expect(settings[0].threshold).toEqual("BLOCK_LOW_AND_ABOVE"); - // eslint-disable-next-line @typescript-eslint/no-explicit-any + expect(caught).toEqual(true); + }); + + test("2. Safety - default", async () => { + const record: Record = {}; + const projectId = mockId(); + const authOptions: MockClientAuthInfo = { + record, + projectId, + resultFile: "chat-2-mock.json", + }; + const model = new ChatGoogle({ + authOptions, + }); + const messages: BaseMessageLike[] = [ + new HumanMessage("Flip a coin and tell me H for heads and T for tails"), + new AIMessage("H"), + new HumanMessage("Flip it again"), + ]; + let caught = false; + try { + await model.invoke(messages); } catch (xx: any) { caught = true; expect(xx).toBeInstanceOf(GoogleAISafetyError); - const result = xx?.reply.generations[0].message; + const result = xx?.reply.generations[0]; + expect(result).toBeUndefined(); + } + + expect(caught).toEqual(true); + }); + + test("2. Safety - safety handler", async () => { + const safetyHandler: GoogleAISafetyHandler = new MessageGeminiSafetyHandler( + { + msg: "I'm sorry, Dave, but I can't do that.", + } + ); + const record: Record = {}; + const projectId = mockId(); + const authOptions: MockClientAuthInfo = { + record, + projectId, + resultFile: "chat-2-mock.json", + }; + const model = new ChatGoogle({ + authOptions, + safetyHandler, + }); + const messages: BaseMessageLike[] = [ + new HumanMessage("Flip a coin and tell me H for heads and T for tails"), + new AIMessage("H"), + new HumanMessage("Flip it again"), + ]; + let caught = false; + try { + const result = await model.invoke(messages); expect(result._getType()).toEqual("ai"); const aiMessage = result as AIMessage; expect(aiMessage.content).toBeDefined(); - expect(aiMessage.content).toBe("T"); + expect(aiMessage.content).toBe("I'm sorry, Dave, but I can't do that."); + } catch (xx: any) { + caught = true; } - expect(caught).toEqual(true); + expect(caught).toEqual(false); }); test("3. invoke - images", async () => { - // eslint-disable-next-line @typescript-eslint/no-explicit-any const record: Record = {}; const projectId = mockId(); const authOptions: MockClientAuthInfo = { @@ -556,7 +620,6 @@ describe("Mock ChatGoogle", () => { }); test("3. invoke - media - invalid", async () => { - // eslint-disable-next-line @typescript-eslint/no-explicit-any const record: Record = {}; const projectId = mockId(); const authOptions: MockClientAuthInfo = { @@ -588,12 +651,11 @@ describe("Mock ChatGoogle", () => { const result = await model.invoke(messages); expect(result).toBeUndefined(); } catch (e) { - expect((e as Error).message).toEqual("Invalid media content"); + expect((e as Error).message).toMatch(/^Invalid media content/); } }); test("3. invoke - media - no manager", async () => { - // eslint-disable-next-line @typescript-eslint/no-explicit-any const record: Record = {}; const projectId = mockId(); const authOptions: MockClientAuthInfo = { @@ -684,20 +746,14 @@ describe("Mock ChatGoogle", () => { async function store(path: string, text: string): Promise { const type = path.endsWith(".png") ? "image/png" : "text/plain"; - const blob = new MediaBlob({ - data: { - value: text, - type, - }, - path, - }); + const data = new Blob([text], { type }); + const blob = await MediaBlob.fromBlob(data, { path }); await resolver.store(blob); } await store("resolve://host/foo", "fooing"); await store("resolve://host2/bar/baz", "barbazing"); await store("resolve://host/foo/blue-box.png", "png"); - // eslint-disable-next-line @typescript-eslint/no-explicit-any const record: Record = {}; const projectId = mockId(); const authOptions: MockClientAuthInfo = { @@ -705,10 +761,38 @@ describe("Mock ChatGoogle", () => { projectId, resultFile: "chat-3-mock.json", }; + const callbacks: CallbackHandlerMethods[] = [ + { + handleChatModelStart( + llm: Serialized, + messages: BaseMessage[][], + runId: string, + _parentRunId?: string, + _extraParams?: Record, + _tags?: string[], + _metadata?: Record, + _runName?: string + ): any { + console.log("Chat start", llm, messages, runId); + }, + handleCustomEvent( + eventName: string, + data: any, + runId: string, + tags?: string[], + metadata?: Record + ): any { + console.log("Custom event", eventName, runId, data, tags, metadata); + }, + }, + ]; const model = new ChatGoogle({ authOptions, model: "gemini-1.5-flash", - mediaManager, + apiConfig: { + mediaManager, + }, + callbacks, }); const message: MessageContentComplex[] = [ @@ -750,7 +834,6 @@ describe("Mock ChatGoogle", () => { }); test("4. Functions Bind - Gemini format request", async () => { - // eslint-disable-next-line @typescript-eslint/no-explicit-any const record: Record = {}; const projectId = mockId(); const authOptions: MockClientAuthInfo = { @@ -832,7 +915,6 @@ describe("Mock ChatGoogle", () => { }); test("4. Functions withStructuredOutput - Gemini format request", async () => { - // eslint-disable-next-line @typescript-eslint/no-explicit-any const record: Record = {}; const projectId = mockId(); const authOptions: MockClientAuthInfo = { @@ -904,7 +986,6 @@ describe("Mock ChatGoogle", () => { }); test("4. Functions - results", async () => { - // eslint-disable-next-line @typescript-eslint/no-explicit-any const record: Record = {}; const projectId = mockId(); const authOptions: MockClientAuthInfo = { @@ -965,7 +1046,6 @@ describe("Mock ChatGoogle", () => { }); test("5. Functions - function reply", async () => { - // eslint-disable-next-line @typescript-eslint/no-explicit-any const record: Record = {}; const projectId = mockId(); const authOptions: MockClientAuthInfo = { @@ -1027,7 +1107,60 @@ describe("Mock ChatGoogle", () => { }); }); -// eslint-disable-next-line @typescript-eslint/no-explicit-any +describe("Mock ChatGoogle - Anthropic", () => { + test("1. Invoke request format", async () => { + const record: Record = {}; + const projectId = mockId(); + const authOptions: MockClientAuthInfo = { + record, + projectId, + resultFile: "claude-chat-1-mock.json", + }; + const model = new ChatGoogle({ + model: "claude-3-5-sonnet@20240620", + platformType: "gcp", + authOptions, + }); + const messages: BaseMessageLike[] = [new HumanMessage("What is 1+1?")]; + await model.invoke(messages); + + console.log("record", record); + expect(record.opts).toBeDefined(); + expect(record.opts.data).toBeDefined(); + const { data } = record.opts; + expect(data.messages).toBeDefined(); + expect(data.messages.length).toEqual(1); + expect(data.messages[0].role).toEqual("user"); + expect(data.messages[0].content).toBeDefined(); + expect(data.messages[0].content.length).toBeGreaterThanOrEqual(1); + expect(data.messages[0].content[0].text).toBeDefined(); + expect(data.system).not.toBeDefined(); + }); + + test("1. Invoke response format", async () => { + const record: Record = {}; + const projectId = mockId(); + const authOptions: MockClientAuthInfo = { + record, + projectId, + resultFile: "claude-chat-1-mock.json", + }; + const model = new ChatGoogle({ + model: "claude-3-5-sonnet@20240620", + platformType: "gcp", + authOptions, + }); + const messages: BaseMessageLike[] = [new HumanMessage("What is 1+1?")]; + const result = await model.invoke(messages); + + expect(result._getType()).toEqual("ai"); + const aiMessage = result as AIMessage; + expect(aiMessage.content).toBeDefined(); + expect(aiMessage.content).toBe( + "1 + 1 = 2\n\nThis is one of the most basic arithmetic equations. It represents the addition of two units, resulting in a sum of two." + ); + }); +}); function extractKeys(obj: Record, keys: string[] = []) { for (const key in obj) { if (Object.prototype.hasOwnProperty.call(obj, key)) { diff --git a/libs/langchain-google-common/src/tests/data/chat-2-mock.json b/libs/langchain-google-common/src/tests/data/chat-2-mock.json index 9ee0bf4564d8..406c22609a76 100644 --- a/libs/langchain-google-common/src/tests/data/chat-2-mock.json +++ b/libs/langchain-google-common/src/tests/data/chat-2-mock.json @@ -1,14 +1,6 @@ { "candidates": [ { - "content": { - "parts": [ - { - "text": "T" - } - ], - "role": "model" - }, "finishReason": "SAFETY", "index": 0, "safetyRatings": [ diff --git a/libs/langchain-google-common/src/tests/data/claude-chat-1-mock.json b/libs/langchain-google-common/src/tests/data/claude-chat-1-mock.json new file mode 100644 index 000000000000..d465fe45392d --- /dev/null +++ b/libs/langchain-google-common/src/tests/data/claude-chat-1-mock.json @@ -0,0 +1,18 @@ +{ + "id": "msg_vrtx_01AGfmYa73qH7wpmFsVFr4rq", + "type": "message", + "role": "assistant", + "model": "claude-3-5-sonnet-20240620", + "content": [ + { + "type": "text", + "text": "1 + 1 = 2\n\nThis is one of the most basic arithmetic equations. It represents the addition of two units, resulting in a sum of two." + } + ], + "stop_reason": "end_turn", + "stop_sequence": null, + "usage": { + "input_tokens": 16, + "output_tokens": 39 + } +} diff --git a/libs/langchain-google-common/src/tests/data/claude-chat-1-mock.sse b/libs/langchain-google-common/src/tests/data/claude-chat-1-mock.sse new file mode 100644 index 000000000000..4213b5548378 --- /dev/null +++ b/libs/langchain-google-common/src/tests/data/claude-chat-1-mock.sse @@ -0,0 +1,267 @@ +event: message_start +data: {"type":"message_start","message":{"id":"msg_vrtx_01JLACAmH9Ke3HQEUK1Sg8iT","type":"message","role":"assistant","model":"claude-3-5-sonnet-20240620","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":15,"output_tokens":1}} } + +event: ping +data: {"type": "ping"} + +event: content_block_start +data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Thank"}} + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" you for inqu"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"iring about my well"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"-being!"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" I'm functioning"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" optim"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"ally an"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"d feeling"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" quite enthusi"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"astic about"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" engaging"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" in"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" conversation"}} + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" an"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"d assisting with"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" any tasks"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" or"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" queries you might"}} + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" have. As"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" an"}} + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" AI,"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" I don"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"'t experience emotions or"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" physical"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" sensations in"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" the way"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" humans do, but"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" I can"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" say"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" that my"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" systems"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" are operating"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" at"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" peak"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" efficiency. I'm"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" eager"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" to learn,"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" explore"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" ideas"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":", and tackle"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" intellectual"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" challenges. The"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" vast"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" repository"}} + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" of knowledge"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" at"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" my disposal is"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" pr"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"imed an"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"d ready to be put"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" to use in whatever"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" manner"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" you"}} + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" see"}} + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" fit. Whether"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" you"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"'re"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" looking for in"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"-depth analysis,"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" creative brainstor"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"ming, or simply"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" a friendly chat"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":", I'm here"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" an"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"d fully"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" prepare"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"d to dive"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" into"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" our"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" interaction"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" with"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" gu"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"sto. Is"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" there any"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" particular"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" subject"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" or"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" task"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" you'"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"d like to discuss or"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" work"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" on today"} } + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"?"} } + +event: content_block_stop +data: {"type":"content_block_stop","index":0 } + +event: message_delta +data: {"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":165} } + +event: message_stop +data: {"type":"message_stop" } + diff --git a/libs/langchain-google-common/src/tests/utils.test.ts b/libs/langchain-google-common/src/tests/utils.test.ts index 547392c397e7..70dacca0b8ee 100644 --- a/libs/langchain-google-common/src/tests/utils.test.ts +++ b/libs/langchain-google-common/src/tests/utils.test.ts @@ -13,7 +13,11 @@ import { ReadThroughBlobStore, SimpleWebBlobStore, } from "../experimental/utils/media_core.js"; -import { ReadableJsonStream } from "../utils/stream.js"; +import { + ReadableJsonStream, + ReadableSseJsonStream, + ReadableSseStream, +} from "../utils/stream.js"; describe("zodToGeminiParameters", () => { test("can convert zod schema to gemini schema", () => { @@ -420,52 +424,112 @@ function toUint8Array(data: string): Uint8Array { return new TextEncoder().encode(data); } -test("ReadableJsonStream can handle stream", async () => { - const data = [ - toUint8Array("["), - toUint8Array('{"i": 1}'), - toUint8Array('{"i'), - toUint8Array('": 2}'), - toUint8Array("]"), - ]; +describe("streaming", () => { + test("ReadableJsonStream can handle stream", async () => { + const data = [ + toUint8Array("["), + toUint8Array('{"i": 1}'), + toUint8Array('{"i'), + toUint8Array('": 2}'), + toUint8Array("]"), + ]; + + const source = new ReadableStream({ + start(controller) { + data.forEach((chunk) => controller.enqueue(chunk)); + controller.close(); + }, + }); + const stream = new ReadableJsonStream(source); + expect(await stream.nextChunk()).toEqual({ i: 1 }); + expect(await stream.nextChunk()).toEqual({ i: 2 }); + expect(await stream.nextChunk()).toBeNull(); + expect(stream.streamDone).toEqual(true); + }); - const source = new ReadableStream({ - start(controller) { - data.forEach((chunk) => controller.enqueue(chunk)); - controller.close(); - }, + test("ReadableJsonStream can handle multibyte stream", async () => { + const data = [ + toUint8Array("["), + toUint8Array('{"i": 1, "msg":"helloπŸ‘‹"}'), + toUint8Array('{"i": 2,'), + toUint8Array('"msg":"こん'), + new Uint8Array([0xe3]), // 1st byte of "に" + new Uint8Array([0x81, 0xab]), // 2-3rd bytes of "に" + toUint8Array("けは"), + new Uint8Array([0xf0, 0x9f]), // first half bytes of "πŸ‘‹" + new Uint8Array([0x91, 0x8b]), // second half bytes of "πŸ‘‹" + toUint8Array('"}'), + toUint8Array("]"), + ]; + + const source = new ReadableStream({ + start(controller) { + data.forEach((chunk) => controller.enqueue(chunk)); + controller.close(); + }, + }); + const stream = new ReadableJsonStream(source); + expect(await stream.nextChunk()).toEqual({ i: 1, msg: "helloπŸ‘‹" }); + expect(await stream.nextChunk()).toEqual({ i: 2, msg: "γ“γ‚“γ«γ‘γ―πŸ‘‹" }); + expect(await stream.nextChunk()).toBeNull(); + expect(stream.streamDone).toEqual(true); }); - const stream = new ReadableJsonStream(source); - expect(await stream.nextChunk()).toEqual({ i: 1 }); - expect(await stream.nextChunk()).toEqual({ i: 2 }); - expect(await stream.nextChunk()).toBeNull(); - expect(stream.streamDone).toEqual(true); -}); -test("ReadableJsonStream can handle multibyte stream", async () => { - const data = [ - toUint8Array("["), - toUint8Array('{"i": 1, "msg":"helloπŸ‘‹"}'), - toUint8Array('{"i": 2,'), - toUint8Array('"msg":"こん'), - new Uint8Array([0xe3]), // 1st byte of "に" - new Uint8Array([0x81, 0xab]), // 2-3rd bytes of "に" - toUint8Array("けは"), - new Uint8Array([0xf0, 0x9f]), // first half bytes of "πŸ‘‹" - new Uint8Array([0x91, 0x8b]), // second half bytes of "πŸ‘‹" - toUint8Array('"}'), - toUint8Array("]"), + const eventData: string[] = [ + "event: ping\n", + 'data: {"type": "ping"}\n', + "\n", + "event: pong\n", + 'data: {"type": "pong", "value": "ping-pong"}\n', + "\n", + "\n", ]; - const source = new ReadableStream({ - start(controller) { - data.forEach((chunk) => controller.enqueue(chunk)); - controller.close(); - }, + test("SseStream", async () => { + const source = new ReadableStream({ + start(controller) { + eventData.forEach((chunk) => controller.enqueue(toUint8Array(chunk))); + controller.close(); + }, + }); + + let chunk; + const stream = new ReadableSseStream(source); + + chunk = await stream.nextChunk(); + expect(chunk.event).toEqual("ping"); + expect(chunk.data).toEqual('{"type": "ping"}'); + + chunk = await stream.nextChunk(); + expect(chunk.event).toEqual("pong"); + + chunk = await stream.nextChunk(); + expect(chunk).toBeNull(); + + expect(stream.streamDone).toEqual(true); + }); + + test("SseJsonStream", async () => { + const source = new ReadableStream({ + start(controller) { + eventData.forEach((chunk) => controller.enqueue(toUint8Array(chunk))); + controller.close(); + }, + }); + + let chunk; + const stream = new ReadableSseJsonStream(source); + + chunk = await stream.nextChunk(); + expect(chunk.type).toEqual("ping"); + + chunk = await stream.nextChunk(); + expect(chunk.type).toEqual("pong"); + expect(chunk.value).toEqual("ping-pong"); + + chunk = await stream.nextChunk(); + expect(chunk).toBeNull(); + + expect(stream.streamDone).toEqual(true); }); - const stream = new ReadableJsonStream(source); - expect(await stream.nextChunk()).toEqual({ i: 1, msg: "helloπŸ‘‹" }); - expect(await stream.nextChunk()).toEqual({ i: 2, msg: "γ“γ‚“γ«γ‘γ―πŸ‘‹" }); - expect(await stream.nextChunk()).toBeNull(); - expect(stream.streamDone).toEqual(true); }); diff --git a/libs/langchain-google-common/src/types-anthropic.ts b/libs/langchain-google-common/src/types-anthropic.ts new file mode 100644 index 000000000000..a4c182e09f39 --- /dev/null +++ b/libs/langchain-google-common/src/types-anthropic.ts @@ -0,0 +1,237 @@ +export interface AnthropicCacheControl { + type: "ephemeral" | string; +} + +interface AnthropicMessageContentBase { + type: string; + cache_control?: AnthropicCacheControl | null; +} + +export interface AnthropicMessageContentText + extends AnthropicMessageContentBase { + type: "text"; + text: string; +} + +export interface AnthropicMessageContentImage + extends AnthropicMessageContentBase { + type: "image"; + source: { + type: "base64" | string; + media_type: string; + data: string; + }; +} + +// TODO: Define this +export type AnthropicMessageContentToolUseInput = object; + +export interface AnthropicMessageContentToolUse + extends AnthropicMessageContentBase { + type: "tool_use"; + id: string; + name: string; + input: AnthropicMessageContentToolUseInput; +} + +export type AnthropicMessageContentToolResultContent = + | AnthropicMessageContentText + | AnthropicMessageContentImage; + +export interface AnthropicMessageContentToolResult + extends AnthropicMessageContentBase { + type: "tool_result"; + tool_use_id: string; + is_error?: boolean; + content: string | AnthropicMessageContentToolResultContent[]; +} + +export type AnthropicMessageContent = + | AnthropicMessageContentText + | AnthropicMessageContentImage + | AnthropicMessageContentToolUse + | AnthropicMessageContentToolResult; + +export interface AnthropicMessage { + role: string; + content: string | AnthropicMessageContent[]; +} + +export interface AnthropicMetadata { + user_id?: string | null; +} + +interface AnthropicToolChoiceBase { + type: string; +} + +export interface AnthropicToolChoiceAuto extends AnthropicToolChoiceBase { + type: "auto"; +} + +export interface AnthropicToolChoiceAny extends AnthropicToolChoiceBase { + type: "any"; +} + +export interface AnthropicToolChoiceTool extends AnthropicToolChoiceBase { + type: "tool"; + name: string; +} + +export type AnthropicToolChoice = + | AnthropicToolChoiceAuto + | AnthropicToolChoiceAny + | AnthropicToolChoiceTool; + +// TODO: Define this +export type AnthropicToolInputSchema = object; + +export interface AnthropicTool { + type?: string; // Just available on tools 20241022 and later? + name: string; + description?: string; + cache_control?: AnthropicCacheControl; + input_schema: AnthropicToolInputSchema; +} + +export interface AnthropicRequest { + anthropic_version: string; + messages: AnthropicMessage[]; + system?: string; + stream?: boolean; + max_tokens: number; + temperature?: number; + top_k?: number; + top_p?: number; + stop_sequences?: string[]; + metadata?: AnthropicMetadata; + tool_choice?: AnthropicToolChoice; + tools?: AnthropicTool[]; +} + +export type AnthropicRequestSettings = Pick< + AnthropicRequest, + "max_tokens" | "temperature" | "top_k" | "top_p" | "stop_sequences" | "stream" +>; + +export interface AnthropicContentText { + type: "text"; + text: string; +} + +export interface AnthropicContentToolUse { + type: "tool_use"; + id: string; + name: string; + input: object; +} + +export type AnthropicContent = AnthropicContentText | AnthropicContentToolUse; + +export interface AnthropicUsage { + input_tokens: number; + output_tokens: number; + cache_creation_input_tokens: number | null; + cache_creation_output_tokens: number | null; +} + +export type AnthropicResponseData = + | AnthropicResponseMessage + | AnthropicStreamBaseEvent; + +export interface AnthropicResponseMessage { + id: string; + type: string; + role: string; + content: AnthropicContent[]; + model: string; + stop_reason: string | null; + stop_sequence: string | null; + usage: AnthropicUsage; +} + +export interface AnthropicAPIConfig { + version?: string; +} + +export type AnthropicStreamEventType = + | "message_start" + | "content_block_start" + | "content_block_delta" + | "content_block_stop" + | "message_delta" + | "message_stop" + | "ping" + | "error"; + +export type AnthropicStreamDeltaType = "text_delta" | "input_json_delta"; + +export interface AnthropicStreamBaseEvent { + type: AnthropicStreamEventType; +} + +export interface AnthropicStreamMessageStartEvent + extends AnthropicStreamBaseEvent { + type: "message_start"; + message: AnthropicResponseMessage; +} + +export interface AnthropicStreamContentBlockStartEvent + extends AnthropicStreamBaseEvent { + type: "content_block_start"; + index: number; + content_block: AnthropicContent; +} + +export interface AnthropicStreamBaseDelta { + type: AnthropicStreamDeltaType; +} + +export interface AnthropicStreamTextDelta extends AnthropicStreamBaseDelta { + type: "text_delta"; + text: string; +} + +export interface AnthropicStreamInputJsonDelta + extends AnthropicStreamBaseDelta { + type: "input_json_delta"; + partial_json: string; +} + +export type AnthropicStreamDelta = + | AnthropicStreamTextDelta + | AnthropicStreamInputJsonDelta; + +export interface AnthropicStreamContentBlockDeltaEvent + extends AnthropicStreamBaseEvent { + type: "content_block_delta"; + index: number; + delta: AnthropicStreamDelta; +} + +export interface AnthropicStreamContentBlockStopEvent + extends AnthropicStreamBaseEvent { + type: "content_block_stop"; + index: number; +} + +export interface AnthropicStreamMessageDeltaEvent + extends AnthropicStreamBaseEvent { + type: "message_delta"; + delta: Partial; +} + +export interface AnthropicStreamMessageStopEvent + extends AnthropicStreamBaseEvent { + type: "message_stop"; +} + +export interface AnthropicStreamPingEvent extends AnthropicStreamBaseEvent { + type: "ping"; +} + +export interface AnthropicStreamErrorEvent extends AnthropicStreamBaseEvent { + type: "error"; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + error: any; +} diff --git a/libs/langchain-google-common/src/types.ts b/libs/langchain-google-common/src/types.ts index 4fecd254693b..bb49cf2edd4f 100644 --- a/libs/langchain-google-common/src/types.ts +++ b/libs/langchain-google-common/src/types.ts @@ -3,8 +3,20 @@ import type { BaseChatModelCallOptions, BindToolsInput, } from "@langchain/core/language_models/chat_models"; +import { + BaseMessage, + BaseMessageChunk, + MessageContent, +} from "@langchain/core/messages"; +import { ChatGenerationChunk, ChatResult } from "@langchain/core/outputs"; import type { JsonStream } from "./utils/stream.js"; import { MediaManager } from "./experimental/utils/media_core.js"; +import { + AnthropicResponseData, + AnthropicAPIConfig, +} from "./types-anthropic.js"; + +export * from "./types-anthropic.js"; /** * Parameters needed to setup the client connection. @@ -45,10 +57,68 @@ export interface GoogleConnectionParams platformType?: GooglePlatformType; } +export const GoogleAISafetyCategory = { + Harassment: "HARM_CATEGORY_HARASSMENT", + HARASSMENT: "HARM_CATEGORY_HARASSMENT", + HARM_CATEGORY_HARASSMENT: "HARM_CATEGORY_HARASSMENT", + + HateSpeech: "HARM_CATEGORY_HATE_SPEECH", + HATE_SPEECH: "HARM_CATEGORY_HATE_SPEECH", + HARM_CATEGORY_HATE_SPEECH: "HARM_CATEGORY_HATE_SPEECH", + + SexuallyExplicit: "HARM_CATEGORY_SEXUALLY_EXPLICIT", + SEXUALLY_EXPLICIT: "HARM_CATEGORY_SEXUALLY_EXPLICIT", + HARM_CATEGORY_SEXUALLY_EXPLICIT: "HARM_CATEGORY_SEXUALLY_EXPLICIT", + + Dangerous: "HARM_CATEGORY_DANGEROUS", + DANGEROUS: "HARM_CATEGORY_DANGEROUS", + HARM_CATEGORY_DANGEROUS: "HARM_CATEGORY_DANGEROUS", + + CivicIntegrity: "HARM_CATEGORY_CIVIC_INTEGRITY", + CIVIC_INTEGRITY: "HARM_CATEGORY_CIVIC_INTEGRITY", + HARM_CATEGORY_CIVIC_INTEGRITY: "HARM_CATEGORY_CIVIC_INTEGRITY", +} as const; + +export type GoogleAISafetyCategory = + (typeof GoogleAISafetyCategory)[keyof typeof GoogleAISafetyCategory]; + +export const GoogleAISafetyThreshold = { + None: "BLOCK_NONE", + NONE: "BLOCK_NONE", + BLOCK_NONE: "BLOCK_NONE", + + Few: "BLOCK_ONLY_HIGH", + FEW: "BLOCK_ONLY_HIGH", + BLOCK_ONLY_HIGH: "BLOCK_ONLY_HIGH", + + Some: "BLOCK_MEDIUM_AND_ABOVE", + SOME: "BLOCK_MEDIUM_AND_ABOVE", + BLOCK_MEDIUM_AND_ABOVE: "BLOCK_MEDIUM_AND_ABOVE", + + Most: "BLOCK_LOW_AND_ABOVE", + MOST: "BLOCK_LOW_AND_ABOVE", + BLOCK_LOW_AND_ABOVE: "BLOCK_LOW_AND_ABOVE", + + Off: "OFF", + OFF: "OFF", + BLOCK_OFF: "OFF", +} as const; + +export type GoogleAISafetyThreshold = + (typeof GoogleAISafetyThreshold)[keyof typeof GoogleAISafetyThreshold]; + +export const GoogleAISafetyMethod = { + Severity: "SEVERITY", + Probability: "PROBABILITY", +} as const; + +export type GoogleAISafetyMethod = + (typeof GoogleAISafetyMethod)[keyof typeof GoogleAISafetyMethod]; + export interface GoogleAISafetySetting { - category: string; - threshold: string; - method?: string; + category: GoogleAISafetyCategory | string; + threshold: GoogleAISafetyThreshold | string; + method?: GoogleAISafetyMethod | string; // Just for Vertex AI? } export type GoogleAIResponseMimeType = "text/plain" | "application/json"; @@ -149,7 +219,7 @@ export interface GoogleAIBaseLLMInput GoogleConnectionParams, GoogleAIModelParams, GoogleAISafetyParams, - GeminiAPIConfig {} + GoogleAIAPIParams {} export interface GoogleAIBaseLanguageModelCallOptions extends BaseChatModelCallOptions, @@ -314,13 +384,15 @@ export interface GenerateContentResponseData { export type GoogleLLMModelFamily = null | "palm" | "gemini"; +export type VertexModelFamily = GoogleLLMModelFamily | "claude"; + export type GoogleLLMResponseData = | JsonStream | GenerateContentResponseData | GenerateContentResponseData[]; export interface GoogleLLMResponse extends GoogleResponse { - data: GoogleLLMResponseData; + data: GoogleLLMResponseData | AnthropicResponseData; } export interface GoogleAISafetyHandler { @@ -348,6 +420,42 @@ export interface GeminiJsonSchemaDirty extends GeminiJsonSchema { additionalProperties?: boolean; } +export type GoogleAIAPI = { + messageContentToParts?: (content: MessageContent) => Promise; + + baseMessageToContent?: ( + message: BaseMessage, + prevMessage: BaseMessage | undefined, + useSystemInstruction: boolean + ) => Promise; + + responseToString: (response: GoogleLLMResponse) => string; + + responseToChatGeneration: ( + response: GoogleLLMResponse + ) => ChatGenerationChunk | null; + + chunkToString: (chunk: BaseMessageChunk) => string; + + responseToBaseMessage: (response: GoogleLLMResponse) => BaseMessage; + + responseToChatResult: (response: GoogleLLMResponse) => ChatResult; + + formatData: ( + input: unknown, + parameters: GoogleAIModelRequestParams + ) => Promise; +}; + export interface GeminiAPIConfig { + safetyHandler?: GoogleAISafetyHandler; mediaManager?: MediaManager; + useSystemInstruction?: boolean; +} + +export type GoogleAIAPIConfig = GeminiAPIConfig | AnthropicAPIConfig; + +export interface GoogleAIAPIParams { + apiName?: string; + apiConfig?: GoogleAIAPIConfig; } diff --git a/libs/langchain-google-common/src/utils/anthropic.ts b/libs/langchain-google-common/src/utils/anthropic.ts new file mode 100644 index 000000000000..72e1f9e57080 --- /dev/null +++ b/libs/langchain-google-common/src/utils/anthropic.ts @@ -0,0 +1,719 @@ +import { + ChatGeneration, + ChatGenerationChunk, + ChatResult, +} from "@langchain/core/outputs"; +import { + BaseMessage, + BaseMessageChunk, + AIMessageChunk, + MessageContentComplex, + MessageContentText, + MessageContent, + MessageContentImageUrl, + AIMessageFields, + AIMessageChunkFields, +} from "@langchain/core/messages"; +import { + ToolCall, + ToolCallChunk, + ToolMessage, +} from "@langchain/core/messages/tool"; +import { + AnthropicAPIConfig, + AnthropicContent, + AnthropicContentText, + AnthropicContentToolUse, + AnthropicMessage, + AnthropicMessageContent, + AnthropicMessageContentImage, + AnthropicMessageContentText, + AnthropicMessageContentToolResult, + AnthropicMessageContentToolResultContent, + AnthropicRequest, + AnthropicRequestSettings, + AnthropicResponseData, + AnthropicResponseMessage, + AnthropicStreamContentBlockDeltaEvent, + AnthropicStreamContentBlockStartEvent, + AnthropicStreamInputJsonDelta, + AnthropicStreamMessageDeltaEvent, + AnthropicStreamMessageStartEvent, + AnthropicStreamTextDelta, + AnthropicTool, + AnthropicToolChoice, + GeminiTool, + GoogleAIAPI, + GoogleAIModelParams, + GoogleAIModelRequestParams, + GoogleAIToolType, + GoogleLLMResponse, +} from "../types.js"; + +export function getAnthropicAPI(config?: AnthropicAPIConfig): GoogleAIAPI { + function partToString(part: AnthropicContent): string { + return "text" in part ? part.text : ""; + } + + function messageToString(message: AnthropicResponseMessage): string { + const content: AnthropicContent[] = message?.content ?? []; + const ret = content.reduce((acc, part) => { + const str = partToString(part); + return acc + str; + }, ""); + return ret; + } + + function responseToString(response: GoogleLLMResponse): string { + const data = response.data as AnthropicResponseData; + switch (data?.type) { + case "message": + return messageToString(data as AnthropicResponseMessage); + default: + throw Error(`Unknown type: ${data?.type}`); + } + } + + /** + * Normalize the AIMessageChunk. + * If the fields are just a string - use that as content. + * If the content is an array of just text fields, turn them into a string. + * @param fields + */ + function newAIMessageChunk(fields: string | AIMessageFields): AIMessageChunk { + if (typeof fields === "string") { + return new AIMessageChunk(fields); + } + const ret: AIMessageFields = { + ...fields, + }; + + if (Array.isArray(fields?.content)) { + let str: string | undefined = ""; + fields.content.forEach((val) => { + if (str !== undefined && val.type === "text") { + str = `${str}${val.text}`; + } else { + str = undefined; + } + }); + if (str) { + ret.content = str; + } + } + + return new AIMessageChunk(ret); + } + + function textContentToMessageFields( + textContent: AnthropicContentText + ): AIMessageFields { + return { + content: [textContent], + }; + } + + function toolUseContentToMessageFields( + toolUseContent: AnthropicContentToolUse + ): AIMessageFields { + const tool: ToolCall = { + id: toolUseContent.id, + name: toolUseContent.name, + type: "tool_call", + args: toolUseContent.input, + }; + return { + content: [], + tool_calls: [tool], + }; + } + + function anthropicContentToMessageFields( + anthropicContent: AnthropicContent + ): AIMessageFields | undefined { + const type = anthropicContent?.type; + switch (type) { + case "text": + return textContentToMessageFields(anthropicContent); + case "tool_use": + return toolUseContentToMessageFields(anthropicContent); + default: + return undefined; + } + } + + function contentToMessage( + anthropicContent: AnthropicContent[] + ): BaseMessageChunk { + const complexContent: MessageContentComplex[] = []; + const toolCalls: ToolCall[] = []; + anthropicContent.forEach((ac) => { + const messageFields = anthropicContentToMessageFields(ac); + if (messageFields?.content) { + complexContent.push( + ...(messageFields.content as MessageContentComplex[]) + ); + } + if (messageFields?.tool_calls) { + toolCalls.push(...messageFields.tool_calls); + } + }); + + const ret: AIMessageFields = { + content: complexContent, + tool_calls: toolCalls, + }; + return newAIMessageChunk(ret); + } + + function messageToGenerationInfo(message: AnthropicResponseMessage) { + const usage = message?.usage; + const usageMetadata: Record = { + input_tokens: usage?.input_tokens ?? 0, + output_tokens: usage?.output_tokens ?? 0, + total_tokens: (usage?.input_tokens ?? 0) + (usage?.output_tokens ?? 0), + }; + return { + usage_metadata: usageMetadata, + finish_reason: message.stop_reason, + }; + } + + function messageToChatGeneration( + responseMessage: AnthropicResponseMessage + ): ChatGenerationChunk { + const content: AnthropicContent[] = responseMessage?.content ?? []; + const text = messageToString(responseMessage); + const message = contentToMessage(content); + const generationInfo = messageToGenerationInfo(responseMessage); + return new ChatGenerationChunk({ + text, + message, + generationInfo, + }); + } + + function messageStartToChatGeneration( + event: AnthropicStreamMessageStartEvent + ): ChatGenerationChunk { + const responseMessage = event.message; + return messageToChatGeneration(responseMessage); + } + + function messageDeltaToChatGeneration( + event: AnthropicStreamMessageDeltaEvent + ): ChatGenerationChunk { + const responseMessage = event.delta; + return messageToChatGeneration(responseMessage as AnthropicResponseMessage); + } + + function contentBlockStartTextToChatGeneration( + event: AnthropicStreamContentBlockStartEvent + ): ChatGenerationChunk | null { + const content = event.content_block; + const message = contentToMessage([content]); + if (!message) { + return null; + } + + const text = "text" in content ? content.text : ""; + return new ChatGenerationChunk({ + message, + text, + }); + } + + function contentBlockStartToolUseToChatGeneration( + event: AnthropicStreamContentBlockStartEvent + ): ChatGenerationChunk | null { + const contentBlock = event.content_block as AnthropicContentToolUse; + const text: string = ""; + const toolChunk: ToolCallChunk = { + type: "tool_call_chunk", + index: event.index, + name: contentBlock.name, + id: contentBlock.id, + }; + if ( + typeof contentBlock.input === "object" && + Object.keys(contentBlock.input).length > 0 + ) { + toolChunk.args = JSON.stringify(contentBlock.input); + } + const toolChunks: ToolCallChunk[] = [toolChunk]; + + const content: MessageContentComplex[] = [ + { + index: event.index, + ...contentBlock, + }, + ]; + const messageFields: AIMessageChunkFields = { + content, + tool_call_chunks: toolChunks, + }; + const message = newAIMessageChunk(messageFields); + return new ChatGenerationChunk({ + message, + text, + }); + } + + function contentBlockStartToChatGeneration( + event: AnthropicStreamContentBlockStartEvent + ): ChatGenerationChunk | null { + switch (event.content_block.type) { + case "text": + return contentBlockStartTextToChatGeneration(event); + case "tool_use": + return contentBlockStartToolUseToChatGeneration(event); + default: + console.warn( + `Unexpected start content_block type: ${JSON.stringify(event)}` + ); + return null; + } + } + + function contentBlockDeltaTextToChatGeneration( + event: AnthropicStreamContentBlockDeltaEvent + ): ChatGenerationChunk { + const delta = event.delta as AnthropicStreamTextDelta; + const text = delta?.text; + const message = newAIMessageChunk(text); + return new ChatGenerationChunk({ + message, + text, + }); + } + + function contentBlockDeltaInputJsonDeltaToChatGeneration( + event: AnthropicStreamContentBlockDeltaEvent + ): ChatGenerationChunk { + const delta = event.delta as AnthropicStreamInputJsonDelta; + const text: string = ""; + const toolChunks: ToolCallChunk[] = [ + { + index: event.index, + args: delta.partial_json, + }, + ]; + const content: MessageContentComplex[] = [ + { + index: event.index, + ...delta, + }, + ]; + const messageFields: AIMessageChunkFields = { + content, + tool_call_chunks: toolChunks, + }; + const message = newAIMessageChunk(messageFields); + return new ChatGenerationChunk({ + message, + text, + }); + } + + function contentBlockDeltaToChatGeneration( + event: AnthropicStreamContentBlockDeltaEvent + ): ChatGenerationChunk | null { + switch (event.delta.type) { + case "text_delta": + return contentBlockDeltaTextToChatGeneration(event); + case "input_json_delta": + return contentBlockDeltaInputJsonDeltaToChatGeneration(event); + default: + console.warn( + `Unexpected delta content_block type: ${JSON.stringify(event)}` + ); + return null; + } + } + + function responseToChatGeneration( + response: GoogleLLMResponse + ): ChatGenerationChunk | null { + const data = response.data as AnthropicResponseData; + switch (data.type) { + case "message": + return messageToChatGeneration(data as AnthropicResponseMessage); + case "message_start": + return messageStartToChatGeneration( + data as AnthropicStreamMessageStartEvent + ); + case "message_delta": + return messageDeltaToChatGeneration( + data as AnthropicStreamMessageDeltaEvent + ); + case "content_block_start": + return contentBlockStartToChatGeneration( + data as AnthropicStreamContentBlockStartEvent + ); + case "content_block_delta": + return contentBlockDeltaToChatGeneration( + data as AnthropicStreamContentBlockDeltaEvent + ); + + case "ping": + case "message_stop": + case "content_block_stop": + // These are ignorable + return null; + + case "error": + throw new Error( + `Error while streaming results: ${JSON.stringify(data)}` + ); + + default: + // We don't know what type this is, but Anthropic may have added + // new ones without telling us. Don't error, but don't use them. + console.warn("Unknown data for responseToChatGeneration", data); + // throw new Error(`Unknown response type: ${data.type}`); + return null; + } + } + + function chunkToString(chunk: BaseMessageChunk): string { + if (chunk === null) { + return ""; + } else if (typeof chunk.content === "string") { + return chunk.content; + } else if (chunk.content.length === 0) { + return ""; + } else if (chunk.content[0].type === "text") { + return chunk.content[0].text; + } else { + throw new Error(`Unexpected chunk: ${chunk}`); + } + } + + function responseToBaseMessage(response: GoogleLLMResponse): BaseMessage { + const data = response.data as AnthropicResponseMessage; + const content: AnthropicContent[] = data?.content ?? []; + return contentToMessage(content); + } + + function responseToChatResult(response: GoogleLLMResponse): ChatResult { + const message = response.data as AnthropicResponseMessage; + const generations: ChatGeneration[] = []; + const gen = responseToChatGeneration(response); + if (gen) { + generations.push(gen); + } + const llmOutput = messageToGenerationInfo(message); + return { + generations, + llmOutput, + }; + } + + function formatAnthropicVersion(): string { + return config?.version ?? "vertex-2023-10-16"; + } + + function textContentToAnthropicContent( + content: MessageContentText + ): AnthropicMessageContentText { + return content; + } + + function extractMimeType( + str: string + ): { media_type: string; data: string } | null { + if (str.startsWith("data:")) { + return { + media_type: str.split(":")[1].split(";")[0], + data: str.split(",")[1], + }; + } + return null; + } + + function imageContentToAnthropicContent( + content: MessageContentImageUrl + ): AnthropicMessageContentImage | undefined { + const dataUrl = content.image_url; + const url = typeof dataUrl === "string" ? dataUrl : dataUrl?.url; + const urlInfo = extractMimeType(url); + + if (!urlInfo) { + return undefined; + } + + return { + type: "image", + source: { + type: "base64", + ...urlInfo, + }, + }; + } + + function contentComplexToAnthropicContent( + content: MessageContentComplex + ): AnthropicMessageContent | undefined { + const type = content?.type; + switch (type) { + case "text": + return textContentToAnthropicContent(content as MessageContentText); + case "image_url": + return imageContentToAnthropicContent( + content as MessageContentImageUrl + ); + default: + console.warn(`Unexpected content type: ${type}`); + return undefined; + } + } + + function contentToAnthropicContent( + content: MessageContent + ): AnthropicMessageContent[] { + const ret: AnthropicMessageContent[] = []; + + const ca = + typeof content === "string" ? [{ type: "text", text: content }] : content; + ca.forEach((complex) => { + const ac = contentComplexToAnthropicContent(complex); + if (ac) { + ret.push(ac); + } + }); + + return ret; + } + + function baseRoleToAnthropicMessage( + base: BaseMessage, + role: string + ): AnthropicMessage { + const content = contentToAnthropicContent(base.content); + return { + role, + content, + }; + } + + function toolMessageToAnthropicMessage(base: ToolMessage): AnthropicMessage { + const role = "user"; + const toolUseId = base.tool_call_id; + const toolContent = contentToAnthropicContent( + base.content + ) as AnthropicMessageContentToolResultContent[]; + const content: AnthropicMessageContentToolResult[] = [ + { + type: "tool_result", + tool_use_id: toolUseId, + content: toolContent, + }, + ]; + return { + role, + content, + }; + } + + function baseToAnthropicMessage( + base: BaseMessage + ): AnthropicMessage | undefined { + const type = base.getType(); + switch (type) { + case "human": + return baseRoleToAnthropicMessage(base, "user"); + case "ai": + return baseRoleToAnthropicMessage(base, "assistant"); + case "tool": + return toolMessageToAnthropicMessage(base as ToolMessage); + default: + return undefined; + } + } + + function formatMessages(input: BaseMessage[]): AnthropicMessage[] { + const ret: AnthropicMessage[] = []; + + input.forEach((baseMessage) => { + const anthropicMessage = baseToAnthropicMessage(baseMessage); + if (anthropicMessage) { + ret.push(anthropicMessage); + } + }); + + return ret; + } + + function formatSettings( + parameters: GoogleAIModelRequestParams + ): AnthropicRequestSettings { + const ret: AnthropicRequestSettings = { + stream: parameters?.streaming ?? false, + max_tokens: parameters?.maxOutputTokens ?? 8192, + }; + + if (parameters.topP) { + ret.top_p = parameters.topP; + } + if (parameters.topK) { + ret.top_k = parameters.topK; + } + if (parameters.temperature) { + ret.temperature = parameters.temperature; + } + if (parameters.stopSequences) { + ret.stop_sequences = parameters.stopSequences; + } + + return ret; + } + + function contentComplexArrayToText( + contentArray: MessageContentComplex[] + ): string { + let ret = ""; + + contentArray.forEach((content) => { + const contentType = content?.type; + if (contentType === "text") { + const textContent = content as MessageContentText; + ret = `${ret}\n${textContent.text}`; + } + }); + + return ret; + } + + function formatSystem(input: BaseMessage[]): string { + let ret = ""; + + input.forEach((message) => { + if (message._getType() === "system") { + const content = message?.content; + const contentString = + typeof content === "string" + ? (content as string) + : contentComplexArrayToText(content as MessageContentComplex[]); + ret = `${ret}\n${contentString}`; + } + }); + + return ret; + } + + function formatGeminiTool(tool: GeminiTool): AnthropicTool[] { + if (Object.hasOwn(tool, "functionDeclarations")) { + const funcs = tool?.functionDeclarations ?? []; + return funcs.map((func) => { + const inputSchema = func.parameters!; + return { + // type: "tool", // This may only be valid for models 20241022+ + name: func.name, + description: func.description, + input_schema: inputSchema, + }; + }); + } else { + console.warn( + `Unable to format GeminiTool: ${JSON.stringify(tool, null, 1)}` + ); + return []; + } + } + + function formatTool(tool: GoogleAIToolType): AnthropicTool[] { + if (Object.hasOwn(tool, "name")) { + return [tool as AnthropicTool]; + } else { + return formatGeminiTool(tool as GeminiTool); + } + } + + function formatTools( + parameters: GoogleAIModelRequestParams + ): AnthropicTool[] { + const tools: GoogleAIToolType[] = parameters?.tools ?? []; + const ret: AnthropicTool[] = []; + tools.forEach((tool) => { + const anthropicTools = formatTool(tool); + anthropicTools.forEach((anthropicTool) => { + if (anthropicTool) { + ret.push(anthropicTool); + } + }); + }); + return ret; + } + + function formatToolChoice( + parameters: GoogleAIModelRequestParams + ): AnthropicToolChoice | undefined { + const choice = parameters?.tool_choice; + if (!choice) { + return undefined; + } else if (typeof choice === "object") { + return choice as AnthropicToolChoice; + } else { + switch (choice) { + case "any": + case "auto": + return { + type: choice, + }; + case "none": + return undefined; + default: + return { + type: "tool", + name: choice, + }; + } + } + } + + async function formatData( + input: unknown, + parameters: GoogleAIModelRequestParams + ): Promise { + const typedInput = input as BaseMessage[]; + const anthropicVersion = formatAnthropicVersion(); + const messages = formatMessages(typedInput); + const settings = formatSettings(parameters); + const system = formatSystem(typedInput); + const tools = formatTools(parameters); + const toolChoice = formatToolChoice(parameters); + const ret: AnthropicRequest = { + anthropic_version: anthropicVersion, + messages, + ...settings, + }; + if (tools && tools.length && parameters?.tool_choice !== "none") { + ret.tools = tools; + } + if (toolChoice) { + ret.tool_choice = toolChoice; + } + if (system?.length) { + ret.system = system; + } + + return ret; + } + + return { + responseToString, + responseToChatGeneration, + chunkToString, + responseToBaseMessage, + responseToChatResult, + formatData, + }; +} + +export function validateClaudeParams(_params: GoogleAIModelParams): void { + // FIXME - validate the parameters +} + +export function isModelClaude(modelName: string): boolean { + return modelName.toLowerCase().startsWith("claude"); +} diff --git a/libs/langchain-google-common/src/utils/common.ts b/libs/langchain-google-common/src/utils/common.ts index b3aa2cba7b4b..bf8ddb228382 100644 --- a/libs/langchain-google-common/src/utils/common.ts +++ b/libs/langchain-google-common/src/utils/common.ts @@ -9,12 +9,13 @@ import type { GoogleAIModelParams, GoogleAIModelRequestParams, GoogleAIToolType, - GoogleLLMModelFamily, + VertexModelFamily, } from "../types.js"; import { jsonSchemaToGeminiParameters, zodToGeminiParameters, } from "./zod_to_gemini_parameters.js"; +import { isModelClaude, validateClaudeParams } from "./anthropic.js"; export function copyAIModelParams( params: GoogleAIModelParams | undefined, @@ -143,16 +144,33 @@ export function copyAIModelParamsInto( export function modelToFamily( modelName: string | undefined -): GoogleLLMModelFamily { +): VertexModelFamily { if (!modelName) { return null; } else if (isModelGemini(modelName)) { return "gemini"; + } else if (isModelClaude(modelName)) { + return "claude"; } else { return null; } } +export function modelToPublisher(modelName: string | undefined): string { + const family = modelToFamily(modelName); + switch (family) { + case "gemini": + case "palm": + return "google"; + + case "claude": + return "anthropic"; + + default: + return "unknown"; + } +} + export function validateModelParams( params: GoogleAIModelParams | undefined ): void { @@ -161,6 +179,10 @@ export function validateModelParams( switch (modelToFamily(model)) { case "gemini": return validateGeminiParams(testParams); + + case "claude": + return validateClaudeParams(testParams); + default: throw new Error( `Unable to verify model params: ${JSON.stringify(params)}` diff --git a/libs/langchain-google-common/src/utils/gemini.ts b/libs/langchain-google-common/src/utils/gemini.ts index 472f4c5725d8..cc8e994efec6 100644 --- a/libs/langchain-google-common/src/utils/gemini.ts +++ b/libs/langchain-google-common/src/utils/gemini.ts @@ -21,6 +21,8 @@ import { ChatResult, } from "@langchain/core/outputs"; import { ToolCallChunk } from "@langchain/core/messages/tool"; +import { StructuredToolParams } from "@langchain/core/tools"; +import { isLangChainTool } from "@langchain/core/utils/function_calling"; import type { GoogleLLMResponse, GoogleAIModelParams, @@ -33,10 +35,21 @@ import type { GenerateContentResponseData, GoogleAISafetyHandler, GeminiPartFunctionCall, + GoogleAIAPI, GeminiAPIConfig, } from "../types.js"; import { GoogleAISafetyError } from "./safety.js"; import { MediaBlob } from "../experimental/utils/media_core.js"; +import { + GeminiFunctionDeclaration, + GeminiGenerationConfig, + GeminiRequest, + GeminiSafetySetting, + GeminiTool, + GoogleAIModelRequestParams, + GoogleAIToolType, +} from "../types.js"; +import { zodToGeminiParameters } from "./zod_to_gemini_parameters.js"; export interface FunctionCall { name: string; @@ -60,6 +73,128 @@ export interface ToolCallRaw { function: FunctionCallRaw; } +export interface DefaultGeminiSafetySettings { + errorFinish?: string[]; +} + +export class DefaultGeminiSafetyHandler implements GoogleAISafetyHandler { + errorFinish = ["SAFETY", "RECITATION", "OTHER"]; + + constructor(settings?: DefaultGeminiSafetySettings) { + this.errorFinish = settings?.errorFinish ?? this.errorFinish; + } + + handleDataPromptFeedback( + response: GoogleLLMResponse, + data: GenerateContentResponseData + ): GenerateContentResponseData { + // Check to see if our prompt was blocked in the first place + const promptFeedback = data?.promptFeedback; + const blockReason = promptFeedback?.blockReason; + if (blockReason) { + throw new GoogleAISafetyError(response, `Prompt blocked: ${blockReason}`); + } + return data; + } + + handleDataFinishReason( + response: GoogleLLMResponse, + data: GenerateContentResponseData + ): GenerateContentResponseData { + const firstCandidate = data?.candidates?.[0]; + const finishReason = firstCandidate?.finishReason; + if (this.errorFinish.includes(finishReason)) { + throw new GoogleAISafetyError(response, `Finish reason: ${finishReason}`); + } + return data; + } + + handleData( + response: GoogleLLMResponse, + data: GenerateContentResponseData + ): GenerateContentResponseData { + let ret = data; + ret = this.handleDataPromptFeedback(response, ret); + ret = this.handleDataFinishReason(response, ret); + return ret; + } + + handle(response: GoogleLLMResponse): GoogleLLMResponse { + let newdata; + + if ("nextChunk" in response.data) { + // TODO: This is a stream. How to handle? + newdata = response.data; + } else if (Array.isArray(response.data)) { + // If it is an array, try to handle every item in the array + try { + newdata = response.data.map((item) => this.handleData(response, item)); + } catch (xx) { + // eslint-disable-next-line no-instanceof/no-instanceof + if (xx instanceof GoogleAISafetyError) { + throw new GoogleAISafetyError(response, xx.message); + } else { + throw xx; + } + } + } else { + const data = response.data as GenerateContentResponseData; + newdata = this.handleData(response, data); + } + + return { + ...response, + data: newdata, + }; + } +} + +export interface MessageGeminiSafetySettings + extends DefaultGeminiSafetySettings { + msg?: string; + forceNewMessage?: boolean; +} + +export class MessageGeminiSafetyHandler extends DefaultGeminiSafetyHandler { + msg: string = ""; + + forceNewMessage = false; + + constructor(settings?: MessageGeminiSafetySettings) { + super(settings); + this.msg = settings?.msg ?? this.msg; + this.forceNewMessage = settings?.forceNewMessage ?? this.forceNewMessage; + } + + setMessage(data: GenerateContentResponseData): GenerateContentResponseData { + const ret = data; + if ( + this.forceNewMessage || + !data?.candidates?.[0]?.content?.parts?.length + ) { + ret.candidates = data.candidates ?? []; + ret.candidates[0] = data.candidates[0] ?? {}; + ret.candidates[0].content = data.candidates[0].content ?? {}; + ret.candidates[0].content = { + role: "model", + parts: [{ text: this.msg }], + }; + } + return ret; + } + + handleData( + response: GoogleLLMResponse, + data: GenerateContentResponseData + ): GenerateContentResponseData { + try { + return super.handleData(response, data); + } catch (xx) { + return this.setMessage(data); + } + } +} + const extractMimeType = ( str: string ): { mimeType: string; data: string } | null => { @@ -72,7 +207,7 @@ const extractMimeType = ( return null; }; -export function getGeminiAPI(config?: GeminiAPIConfig) { +export function getGeminiAPI(config?: GeminiAPIConfig): GoogleAIAPI { function messageContentText( content: MessageContentText ): GeminiPartText | null { @@ -153,7 +288,9 @@ export function getGeminiAPI(config?: GeminiAPIConfig) { } } - throw new Error("Invalid media content"); + throw new Error( + `Invalid media content: ${JSON.stringify(content, null, 1)}` + ); } async function messageContentComplexToPart( @@ -175,7 +312,7 @@ export function getGeminiAPI(config?: GeminiAPIConfig) { return await messageContentMedia(content); default: throw new Error( - `Unsupported type received while converting message to message parts` + `Unsupported type "${content.type}" received while converting message to message parts: ${content}` ); } throw new Error( @@ -282,10 +419,9 @@ export function getGeminiAPI(config?: GeminiAPIConfig) { } async function systemMessageToContent( - message: SystemMessage, - useSystemInstruction: boolean + message: SystemMessage ): Promise { - return useSystemInstruction + return config?.useSystemInstruction ? roleMessageToContent("system", message) : [ ...(await roleMessageToContent("user", message)), @@ -349,16 +485,12 @@ export function getGeminiAPI(config?: GeminiAPIConfig) { async function baseMessageToContent( message: BaseMessage, - prevMessage: BaseMessage | undefined, - useSystemInstruction: boolean + prevMessage: BaseMessage | undefined ): Promise { const type = message._getType(); switch (type) { case "system": - return systemMessageToContent( - message as SystemMessage, - useSystemInstruction - ); + return systemMessageToContent(message as SystemMessage); case "human": return roleMessageToContent("user", message); case "ai": @@ -519,9 +651,10 @@ export function getGeminiAPI(config?: GeminiAPIConfig) { function safeResponseTo( response: GoogleLLMResponse, - safetyHandler: GoogleAISafetyHandler, responseTo: (response: GoogleLLMResponse) => RetType ): RetType { + const safetyHandler = + config?.safetyHandler ?? new DefaultGeminiSafetyHandler(); try { const safeResponse = safetyHandler.handle(response); return responseTo(safeResponse); @@ -535,11 +668,8 @@ export function getGeminiAPI(config?: GeminiAPIConfig) { } } - function safeResponseToString( - response: GoogleLLMResponse, - safetyHandler: GoogleAISafetyHandler - ): string { - return safeResponseTo(response, safetyHandler, responseToString); + function safeResponseToString(response: GoogleLLMResponse): string { + return safeResponseTo(response, responseToString); } function responseToGenerationInfo(response: GoogleLLMResponse) { @@ -575,10 +705,9 @@ export function getGeminiAPI(config?: GeminiAPIConfig) { } function safeResponseToChatGeneration( - response: GoogleLLMResponse, - safetyHandler: GoogleAISafetyHandler + response: GoogleLLMResponse ): ChatGenerationChunk { - return safeResponseTo(response, safetyHandler, responseToChatGeneration); + return safeResponseTo(response, responseToChatGeneration); } function chunkToString(chunk: BaseMessageChunk): string { @@ -724,11 +853,8 @@ export function getGeminiAPI(config?: GeminiAPIConfig) { return new AIMessage(fields); } - function safeResponseToBaseMessage( - response: GoogleLLMResponse, - safetyHandler: GoogleAISafetyHandler - ): BaseMessage { - return safeResponseTo(response, safetyHandler, responseToBaseMessage); + function safeResponseToBaseMessage(response: GoogleLLMResponse): BaseMessage { + return safeResponseTo(response, responseToBaseMessage); } function responseToChatResult(response: GoogleLLMResponse): ChatResult { @@ -739,167 +865,269 @@ export function getGeminiAPI(config?: GeminiAPIConfig) { }; } - function safeResponseToChatResult( - response: GoogleLLMResponse, - safetyHandler: GoogleAISafetyHandler - ): ChatResult { - return safeResponseTo(response, safetyHandler, responseToChatResult); - } - - return { - messageContentToParts, - baseMessageToContent, - safeResponseToString, - safeResponseToChatGeneration, - chunkToString, - safeResponseToBaseMessage, - safeResponseToChatResult, - }; -} - -export function validateGeminiParams(params: GoogleAIModelParams): void { - if (params.maxOutputTokens && params.maxOutputTokens < 0) { - throw new Error("`maxOutputTokens` must be a positive integer"); - } - - if ( - params.temperature && - (params.temperature < 0 || params.temperature > 2) - ) { - throw new Error("`temperature` must be in the range of [0.0,2.0]"); + function safeResponseToChatResult(response: GoogleLLMResponse): ChatResult { + return safeResponseTo(response, responseToChatResult); } - if (params.topP && (params.topP < 0 || params.topP > 1)) { - throw new Error("`topP` must be in the range of [0.0,1.0]"); + function inputType( + input: MessageContent | BaseMessage[] + ): "MessageContent" | "BaseMessageArray" { + if (typeof input === "string") { + return "MessageContent"; + } else { + const firstItem: BaseMessage | MessageContentComplex = input[0]; + if (Object.hasOwn(firstItem, "content")) { + return "BaseMessageArray"; + } else { + return "MessageContent"; + } + } } - if (params.topK && params.topK < 0) { - throw new Error("`topK` must be a positive integer"); + async function formatMessageContents( + input: MessageContent, + _parameters: GoogleAIModelParams + ): Promise { + const parts = await messageContentToParts!(input); + const contents: GeminiContent[] = [ + { + role: "user", // Required by Vertex AI + parts, + }, + ]; + return contents; } -} -export function isModelGemini(modelName: string): boolean { - return modelName.toLowerCase().startsWith("gemini"); -} + async function formatBaseMessageContents( + input: BaseMessage[], + _parameters: GoogleAIModelParams + ): Promise { + const inputPromises: Promise[] = input.map((msg, i) => + baseMessageToContent!(msg, input[i - 1]) + ); + const inputs = await Promise.all(inputPromises); -export interface DefaultGeminiSafetySettings { - errorFinish?: string[]; -} + return inputs.reduce((acc, cur) => { + // Filter out the system content + if (cur.every((content) => content.role === "system")) { + return acc; + } -export class DefaultGeminiSafetyHandler implements GoogleAISafetyHandler { - errorFinish = ["SAFETY", "RECITATION", "OTHER"]; + // Combine adjacent function messages + if ( + cur[0]?.role === "function" && + acc.length > 0 && + acc[acc.length - 1].role === "function" + ) { + acc[acc.length - 1].parts = [ + ...acc[acc.length - 1].parts, + ...cur[0].parts, + ]; + } else { + acc.push(...cur); + } - constructor(settings?: DefaultGeminiSafetySettings) { - this.errorFinish = settings?.errorFinish ?? this.errorFinish; + return acc; + }, [] as GeminiContent[]); } - handleDataPromptFeedback( - response: GoogleLLMResponse, - data: GenerateContentResponseData - ): GenerateContentResponseData { - // Check to see if our prompt was blocked in the first place - const promptFeedback = data?.promptFeedback; - const blockReason = promptFeedback?.blockReason; - if (blockReason) { - throw new GoogleAISafetyError(response, `Prompt blocked: ${blockReason}`); + async function formatContents( + input: MessageContent | BaseMessage[], + parameters: GoogleAIModelRequestParams + ): Promise { + const it = inputType(input); + switch (it) { + case "MessageContent": + return formatMessageContents(input as MessageContent, parameters); + case "BaseMessageArray": + return formatBaseMessageContents(input as BaseMessage[], parameters); + default: + throw new Error(`Unknown input type "${it}": ${input}`); } - return data; } - handleDataFinishReason( - response: GoogleLLMResponse, - data: GenerateContentResponseData - ): GenerateContentResponseData { - const firstCandidate = data?.candidates?.[0]; - const finishReason = firstCandidate?.finishReason; - if (this.errorFinish.includes(finishReason)) { - throw new GoogleAISafetyError(response, `Finish reason: ${finishReason}`); - } - return data; + function formatGenerationConfig( + parameters: GoogleAIModelRequestParams + ): GeminiGenerationConfig { + return { + temperature: parameters.temperature, + topK: parameters.topK, + topP: parameters.topP, + maxOutputTokens: parameters.maxOutputTokens, + stopSequences: parameters.stopSequences, + responseMimeType: parameters.responseMimeType, + }; } - handleData( - response: GoogleLLMResponse, - data: GenerateContentResponseData - ): GenerateContentResponseData { - let ret = data; - ret = this.handleDataPromptFeedback(response, ret); - ret = this.handleDataFinishReason(response, ret); - return ret; + function formatSafetySettings( + parameters: GoogleAIModelRequestParams + ): GeminiSafetySetting[] { + return parameters.safetySettings ?? []; } - handle(response: GoogleLLMResponse): GoogleLLMResponse { - let newdata; - - if ("nextChunk" in response.data) { - // TODO: This is a stream. How to handle? - newdata = response.data; - } else if (Array.isArray(response.data)) { - // If it is an array, try to handle every item in the array - try { - newdata = response.data.map((item) => this.handleData(response, item)); - } catch (xx) { - // eslint-disable-next-line no-instanceof/no-instanceof - if (xx instanceof GoogleAISafetyError) { - throw new GoogleAISafetyError(response, xx.message); + async function formatBaseMessageSystemInstruction( + input: BaseMessage[] + ): Promise { + let ret = {} as GeminiContent; + for (let index = 0; index < input.length; index += 1) { + const message = input[index]; + if (message._getType() === "system") { + // For system types, we only want it if it is the first message, + // if it appears anywhere else, it should be an error. + if (index === 0) { + // eslint-disable-next-line prefer-destructuring + ret = (await baseMessageToContent!(message, undefined))[0]; } else { - throw xx; + throw new Error( + "System messages are only permitted as the first passed message." + ); } } - } else { - const data = response.data as GenerateContentResponseData; - newdata = this.handleData(response, data); } + return ret; + } + + async function formatSystemInstruction( + input: MessageContent | BaseMessage[] + ): Promise { + if (!config?.useSystemInstruction) { + return {} as GeminiContent; + } + + const it = inputType(input); + switch (it) { + case "BaseMessageArray": + return formatBaseMessageSystemInstruction(input as BaseMessage[]); + default: + return {} as GeminiContent; + } + } + + function structuredToolToFunctionDeclaration( + tool: StructuredToolParams + ): GeminiFunctionDeclaration { + const jsonSchema = zodToGeminiParameters(tool.schema); return { - ...response, - data: newdata, + name: tool.name, + description: tool.description ?? `A function available to call.`, + parameters: jsonSchema, }; } -} -export interface MessageGeminiSafetySettings - extends DefaultGeminiSafetySettings { - msg?: string; - forceNewMessage?: boolean; -} + function structuredToolsToGeminiTools( + tools: StructuredToolParams[] + ): GeminiTool[] { + return [ + { + functionDeclarations: tools.map(structuredToolToFunctionDeclaration), + }, + ]; + } -export class MessageGeminiSafetyHandler extends DefaultGeminiSafetyHandler { - msg: string = ""; + function formatTools(parameters: GoogleAIModelRequestParams): GeminiTool[] { + const tools: GoogleAIToolType[] | undefined = parameters?.tools; + if (!tools || tools.length === 0) { + return []; + } - forceNewMessage = false; + if (tools.every(isLangChainTool)) { + return structuredToolsToGeminiTools(tools); + } else { + if ( + tools.length === 1 && + (!("functionDeclarations" in tools[0]) || + !tools[0].functionDeclarations?.length) + ) { + return []; + } + return tools as GeminiTool[]; + } + } - constructor(settings?: MessageGeminiSafetySettings) { - super(settings); - this.msg = settings?.msg ?? this.msg; - this.forceNewMessage = settings?.forceNewMessage ?? this.forceNewMessage; + function formatToolConfig( + parameters: GoogleAIModelRequestParams + ): GeminiRequest["toolConfig"] | undefined { + if (!parameters.tool_choice || typeof parameters.tool_choice !== "string") { + return undefined; + } + + return { + functionCallingConfig: { + mode: parameters.tool_choice as "auto" | "any" | "none", + allowedFunctionNames: parameters.allowed_function_names, + }, + }; } - setMessage(data: GenerateContentResponseData): GenerateContentResponseData { - const ret = data; + async function formatData( + input: unknown, + parameters: GoogleAIModelRequestParams + ): Promise { + const typedInput = input as MessageContent | BaseMessage[]; + const contents = await formatContents(typedInput, parameters); + const generationConfig = formatGenerationConfig(parameters); + const tools = formatTools(parameters); + const toolConfig = formatToolConfig(parameters); + const safetySettings = formatSafetySettings(parameters); + const systemInstruction = await formatSystemInstruction(typedInput); + + const ret: GeminiRequest = { + contents, + generationConfig, + }; + if (tools && tools.length) { + ret.tools = tools; + } + if (toolConfig) { + ret.toolConfig = toolConfig; + } + if (safetySettings && safetySettings.length) { + ret.safetySettings = safetySettings; + } if ( - this.forceNewMessage || - !data?.candidates?.[0]?.content?.parts?.length + systemInstruction?.role && + systemInstruction?.parts && + systemInstruction?.parts?.length ) { - ret.candidates = data.candidates ?? []; - ret.candidates[0] = data.candidates[0] ?? {}; - ret.candidates[0].content = data.candidates[0].content ?? {}; - ret.candidates[0].content = { - role: "model", - parts: [{ text: this.msg }], - }; + ret.systemInstruction = systemInstruction; } return ret; } - handleData( - response: GoogleLLMResponse, - data: GenerateContentResponseData - ): GenerateContentResponseData { - try { - return super.handleData(response, data); - } catch (xx) { - return this.setMessage(data); - } + return { + messageContentToParts, + baseMessageToContent, + responseToString: safeResponseToString, + responseToChatGeneration: safeResponseToChatGeneration, + chunkToString, + responseToBaseMessage: safeResponseToBaseMessage, + responseToChatResult: safeResponseToChatResult, + formatData, + }; +} + +export function validateGeminiParams(params: GoogleAIModelParams): void { + if (params.maxOutputTokens && params.maxOutputTokens < 0) { + throw new Error("`maxOutputTokens` must be a positive integer"); + } + + if ( + params.temperature && + (params.temperature < 0 || params.temperature > 2) + ) { + throw new Error("`temperature` must be in the range of [0.0,2.0]"); + } + + if (params.topP && (params.topP < 0 || params.topP > 1)) { + throw new Error("`topP` must be in the range of [0.0,1.0]"); + } + + if (params.topK && params.topK < 0) { + throw new Error("`topK` must be a positive integer"); } } + +export function isModelGemini(modelName: string): boolean { + return modelName.toLowerCase().startsWith("gemini"); +} diff --git a/libs/langchain-google-common/src/utils/stream.ts b/libs/langchain-google-common/src/utils/stream.ts index 226ac49dca10..2a61446864f3 100644 --- a/libs/langchain-google-common/src/utils/stream.ts +++ b/libs/langchain-google-common/src/utils/stream.ts @@ -1,5 +1,34 @@ import { GenerationChunk } from "@langchain/core/outputs"; +export interface AbstractStream { + /** + * Add more text to the buffer + * @param data + */ + appendBuffer(data: string): void; + + /** + * Indicate that there is no more text to be added to the buffer + * (ie - our source material is done) + */ + closeBuffer(): void; + /** + * Get the next chunk that is coming from the stream. + * This chunk may be null, usually indicating the last chunk in the stream. + */ + // eslint-disable-next-line @typescript-eslint/no-explicit-any + nextChunk(): Promise; + + /** + * Is the stream done? + * A stream is only done if all of the following are true: + * - There is no more data to be added to the text buffer + * - There is no more data in the text buffer + * - There are no chunks that are waiting to be consumed + */ + get streamDone(): boolean; +} + export function complexValue(value: unknown): unknown { if (value === null || typeof value === "undefined") { // I dunno what to put here. An error, probably @@ -68,8 +97,7 @@ export function simpleValue(val: unknown): unknown { return val; } } - -export class JsonStream { +export class JsonStream implements AbstractStream { _buffer = ""; _bufferOpen = true; @@ -247,11 +275,13 @@ export class ComplexJsonStream extends JsonStream { } } -export class ReadableJsonStream extends JsonStream { +export class ReadableAbstractStream implements AbstractStream { + private baseStream: AbstractStream; + decoder: TextDecoder; - constructor(body: ReadableStream | null) { - super(); + constructor(baseStream: AbstractStream, body: ReadableStream | null) { + this.baseStream = baseStream; this.decoder = new TextDecoder("utf-8"); if (body) { void this.run(body); @@ -260,6 +290,23 @@ export class ReadableJsonStream extends JsonStream { } } + appendBuffer(data: string): void { + return this.baseStream.appendBuffer(data); + } + + closeBuffer(): void { + return this.baseStream.closeBuffer(); + } + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + nextChunk(): Promise { + return this.baseStream.nextChunk(); + } + + get streamDone(): boolean { + return this.baseStream.streamDone; + } + async run(body: ReadableStream) { const reader = body.getReader(); let isDone = false; @@ -275,3 +322,147 @@ export class ReadableJsonStream extends JsonStream { } } } + +export class ReadableJsonStream extends ReadableAbstractStream { + constructor(body: ReadableStream | null) { + super(new JsonStream(), body); + } +} + +export class SseStream implements AbstractStream { + _buffer = ""; + + _bufferOpen = true; + + appendBuffer(data: string): void { + this._buffer += data; + this._parseBuffer(); + } + + closeBuffer(): void { + this._bufferOpen = false; + this._parseBuffer(); + } + + /** + * Attempt to load an entire event. + * For each entire event we load, + * send them to be handled. + */ + _parseBuffer(): void { + const events = this._buffer.split(/\n\n/); + this._buffer = events.pop() ?? ""; + events.forEach((event) => this._handleEvent(event.trim())); + + if (!this._bufferOpen) { + // No more data will be added, and we have parsed + // everything. So dump the rest. + this._handleEvent(null); + this._buffer = ""; + } + } + + /** + * Given an event string, get all the fields + * in the event. It is assumed there is one field + * per line, but that field names can be duplicated, + * indicating to append the new value to the previous value + * @param event + */ + _parseEvent(event: string | null): Record | null { + if (!event || event.trim() === "") { + return null; + } + const ret: Record = {}; + + const lines = event.split(/\n/); + lines.forEach((line) => { + const match = line.match(/^([^:]+): \s*(.+)\n*$/); + if (match && match.length === 3) { + const key = match[1]; + const val = match[2]; + const cur = ret[key] ?? ""; + ret[key] = `${cur}${val}`; + } + }); + + return ret; + } + + // Set up a potential Promise that the handler can resolve. + // eslint-disable-next-line @typescript-eslint/no-explicit-any + _chunkResolution: (chunk: any) => void; + + // If there is no Promise (it is null), the handler must add it to the queue + // eslint-disable-next-line @typescript-eslint/no-explicit-any + _chunkPending: Promise | null = null; + + // A queue that will collect chunks while there is no Promise + // eslint-disable-next-line @typescript-eslint/no-explicit-any + _chunkQueue: any[] = []; + + _handleEvent(event: string | null): void { + const chunk = this._parseEvent(event); + if (this._chunkPending) { + this._chunkResolution(chunk); + this._chunkPending = null; + } else { + this._chunkQueue.push(chunk); + } + } + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + async nextChunk(): Promise { + if (this._chunkQueue.length > 0) { + // If there is data in the queue, return the next queue chunk + return this._chunkQueue.shift() as Record; + } else { + // Otherwise, set up a promise that handleChunk will cause to be resolved + this._chunkPending = new Promise((resolve) => { + this._chunkResolution = resolve; + }); + return this._chunkPending; + } + } + + get streamDone(): boolean { + return ( + !this._bufferOpen && + this._buffer.length === 0 && + this._chunkQueue.length === 0 && + this._chunkPending === null + ); + } +} + +export class ReadableSseStream extends ReadableAbstractStream { + constructor(body: ReadableStream | null) { + super(new SseStream(), body); + } +} + +export class SseJsonStream extends SseStream { + _jsonAttribute: string = "data"; + + constructor(jsonAttribute?: string) { + super(); + this._jsonAttribute = jsonAttribute ?? this._jsonAttribute; + } + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + async nextChunk(): Promise { + const eventRecord = (await super.nextChunk()) as Record; + const json = eventRecord?.[this._jsonAttribute]; + if (!json) { + return null; + } else { + return JSON.parse(json); + } + } +} + +export class ReadableSseJsonStream extends ReadableAbstractStream { + constructor(body: ReadableStream | null) { + super(new SseJsonStream(), body); + } +} diff --git a/libs/langchain-google-gauth/src/auth.ts b/libs/langchain-google-gauth/src/auth.ts index 21093bcbce42..bb8053b9c521 100644 --- a/libs/langchain-google-gauth/src/auth.ts +++ b/libs/langchain-google-gauth/src/auth.ts @@ -1,16 +1,21 @@ import { Readable } from "stream"; import { + AbstractStream, ensureAuthOptionScopes, GoogleAbstractedClient, GoogleAbstractedClientOps, GoogleConnectionParams, JsonStream, + SseJsonStream, + SseStream, } from "@langchain/google-common"; import { GoogleAuth, GoogleAuthOptions } from "google-auth-library"; -export class NodeJsonStream extends JsonStream { - constructor(data: Readable) { - super(); +export class NodeAbstractStream implements AbstractStream { + private baseStream: AbstractStream; + + constructor(baseStream: AbstractStream, data: Readable) { + this.baseStream = baseStream; const decoder = new TextDecoder("utf-8"); data.on("data", (data) => { const text = decoder.decode(data, { stream: true }); @@ -22,6 +27,41 @@ export class NodeJsonStream extends JsonStream { this.closeBuffer(); }); } + + appendBuffer(data: string): void { + return this.baseStream.appendBuffer(data); + } + + closeBuffer(): void { + return this.baseStream.closeBuffer(); + } + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + nextChunk(): Promise { + return this.baseStream.nextChunk(); + } + + get streamDone(): boolean { + return this.baseStream.streamDone; + } +} + +export class NodeJsonStream extends NodeAbstractStream { + constructor(data: Readable) { + super(new JsonStream(), data); + } +} + +export class NodeSseStream extends NodeAbstractStream { + constructor(data: Readable) { + super(new SseStream(), data); + } +} + +export class NodeSseJsonStream extends NodeAbstractStream { + constructor(data: Readable) { + super(new SseJsonStream(), data); + } } export class GAuthClient implements GoogleAbstractedClient { @@ -47,12 +87,21 @@ export class GAuthClient implements GoogleAbstractedClient { async request(opts: GoogleAbstractedClientOps): Promise { try { const ret = await this.gauth.request(opts); - return opts.responseType !== "stream" - ? ret - : { - ...ret, - data: new NodeJsonStream(ret.data), - }; + const [contentType] = ret?.headers?.["content-type"]?.split(/;/) ?? [""]; + if (opts.responseType !== "stream") { + return ret; + } else if (contentType === "text/event-stream") { + return { + ...ret, + data: new NodeSseJsonStream(ret.data), + }; + } else { + return { + ...ret, + data: new NodeJsonStream(ret.data), + }; + } + // eslint-disable-next-line @typescript-eslint/no-explicit-any } catch (xx: any) { console.error("call to gauth.request", JSON.stringify(xx, null, 2)); diff --git a/libs/langchain-google-gauth/src/tests/chat_models.int.test.ts b/libs/langchain-google-gauth/src/tests/chat_models.int.test.ts index 2bc8f483b13b..0284189ebec4 100644 --- a/libs/langchain-google-gauth/src/tests/chat_models.int.test.ts +++ b/libs/langchain-google-gauth/src/tests/chat_models.int.test.ts @@ -258,7 +258,7 @@ describe("GAuth Chat", () => { actionIfBlobMissing: undefined, }, }); - const canonicalStore = new BlobStoreGoogleCloudStorage({ + const backingStore = new BlobStoreGoogleCloudStorage({ uriPrefix: new GoogleCloudStorageUri("gs://test-langchainjs/mediatest/"), defaultStoreOptions: { actionIfInvalid: "prefixPath", @@ -266,7 +266,7 @@ describe("GAuth Chat", () => { }); const blobStore = new ReadThroughBlobStore({ baseStore: aliasStore, - backingStore: canonicalStore, + backingStore, }); const resolver = new SimpleWebBlobStore(); const mediaManager = new MediaManager({ @@ -275,7 +275,9 @@ describe("GAuth Chat", () => { }); const model = new ChatGoogle({ modelName: "gemini-1.5-flash", - mediaManager, + apiConfig: { + mediaManager, + }, }); const message: MessageContentComplex[] = [ @@ -285,7 +287,7 @@ describe("GAuth Chat", () => { }, { type: "media", - fileUri: "https://js.langchain.com/img/brand/wordmark.png", + fileUri: "https://js.langchain.com/v0.2/img/brand/wordmark.png", }, ]; diff --git a/libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts b/libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts index dcac30321b53..a3b8bbe4b2d8 100644 --- a/libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts +++ b/libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts @@ -1,4 +1,4 @@ -import { test } from "@jest/globals"; +import { expect, test } from "@jest/globals"; import fs from "fs/promises"; import { BaseLanguageModelInput } from "@langchain/core/language_models/base"; import { ChatPromptValue } from "@langchain/core/prompt_values"; @@ -34,12 +34,44 @@ import { MessagesPlaceholder, } from "@langchain/core/prompts"; import { InMemoryStore } from "@langchain/core/stores"; +import { BaseCallbackHandler } from "@langchain/core/callbacks/base"; +import { + GoogleRequestLogger, + GoogleRequestRecorder, +} from "@langchain/google-common"; import { GeminiTool } from "../types.js"; import { ChatVertexAI } from "../chat_models.js"; -describe("GAuth Chat", () => { +const weatherTool = tool((_) => "no-op", { + name: "get_weather", + description: + "Get the weather of a specific location and return the temperature in Celsius.", + schema: z.object({ + location: z.string().describe("The name of city to get the weather for."), + }), +}); + +const calculatorTool = tool((_) => "no-op", { + name: "calculator", + description: "Calculate the result of a math expression.", + schema: z.object({ + expression: z.string().describe("The math expression to calculate."), + }), +}); + +describe("GAuth Gemini Chat", () => { + let recorder: GoogleRequestRecorder; + let callbacks: BaseCallbackHandler[]; + + beforeEach(() => { + recorder = new GoogleRequestRecorder(); + callbacks = [recorder, new GoogleRequestLogger()]; + }); + test("invoke", async () => { - const model = new ChatVertexAI(); + const model = new ChatVertexAI({ + callbacks, + }); const res = await model.invoke("What is 1 + 1?"); expect(res).toBeDefined(); expect(res._getType()).toEqual("ai"); @@ -75,7 +107,9 @@ describe("GAuth Chat", () => { }); test("stream", async () => { - const model = new ChatVertexAI(); + const model = new ChatVertexAI({ + callbacks, + }); const input: BaseLanguageModelInput = new ChatPromptValue([ new SystemMessage( "You will reply to all requests to flip a coin with either H, indicating heads, or T, indicating tails." @@ -225,7 +259,7 @@ describe("GAuth Chat", () => { actionIfBlobMissing: undefined, }, }); - const canonicalStore = new BlobStoreGoogleCloudStorage({ + const backingStore = new BlobStoreGoogleCloudStorage({ uriPrefix: new GoogleCloudStorageUri("gs://test-langchainjs/mediatest/"), defaultStoreOptions: { actionIfInvalid: "prefixPath", @@ -233,7 +267,7 @@ describe("GAuth Chat", () => { }); const blobStore = new ReadThroughBlobStore({ baseStore: aliasStore, - backingStore: canonicalStore, + backingStore, }); const resolver = new SimpleWebBlobStore(); const mediaManager = new MediaManager({ @@ -242,7 +276,9 @@ describe("GAuth Chat", () => { }); const model = new ChatGoogle({ modelName: "gemini-1.5-flash", - mediaManager, + apiConfig: { + mediaManager, + }, }); const message: MessageContentComplex[] = [ @@ -252,7 +288,7 @@ describe("GAuth Chat", () => { }, { type: "media", - fileUri: "https://js.langchain.com/img/brand/wordmark.png", + fileUri: "https://js.langchain.com/v0.2/img/brand/wordmark.png", }, ]; @@ -279,208 +315,305 @@ describe("GAuth Chat", () => { throw e; } }); -}); -test("Stream token count usage_metadata", async () => { - const model = new ChatVertexAI({ - temperature: 0, - maxOutputTokens: 10, - }); - let res: AIMessageChunk | null = null; - for await (const chunk of await model.stream( - "Why is the sky blue? Be concise." - )) { - if (!res) { - res = chunk; - } else { - res = res.concat(chunk); + test("Stream token count usage_metadata", async () => { + const model = new ChatVertexAI({ + temperature: 0, + maxOutputTokens: 10, + }); + let res: AIMessageChunk | null = null; + for await (const chunk of await model.stream( + "Why is the sky blue? Be concise." + )) { + if (!res) { + res = chunk; + } else { + res = res.concat(chunk); + } } - } - // console.log(res); - expect(res?.usage_metadata).toBeDefined(); - if (!res?.usage_metadata) { - return; - } - expect(res.usage_metadata.input_tokens).toBeGreaterThan(1); - expect(res.usage_metadata.output_tokens).toBeGreaterThan(1); - expect(res.usage_metadata.total_tokens).toBe( - res.usage_metadata.input_tokens + res.usage_metadata.output_tokens - ); -}); - -test("streamUsage excludes token usage", async () => { - const model = new ChatVertexAI({ - temperature: 0, - streamUsage: false, - }); - let res: AIMessageChunk | null = null; - for await (const chunk of await model.stream( - "Why is the sky blue? Be concise." - )) { - if (!res) { - res = chunk; - } else { - res = res.concat(chunk); + // console.log(res); + expect(res?.usage_metadata).toBeDefined(); + if (!res?.usage_metadata) { + return; } - } - // console.log(res); - expect(res?.usage_metadata).not.toBeDefined(); -}); + expect(res.usage_metadata.input_tokens).toBeGreaterThan(1); + expect(res.usage_metadata.output_tokens).toBeGreaterThan(1); + expect(res.usage_metadata.total_tokens).toBe( + res.usage_metadata.input_tokens + res.usage_metadata.output_tokens + ); + }); -test("Invoke token count usage_metadata", async () => { - const model = new ChatVertexAI({ - temperature: 0, - maxOutputTokens: 10, + test("streamUsage excludes token usage", async () => { + const model = new ChatVertexAI({ + temperature: 0, + streamUsage: false, + }); + let res: AIMessageChunk | null = null; + for await (const chunk of await model.stream( + "Why is the sky blue? Be concise." + )) { + if (!res) { + res = chunk; + } else { + res = res.concat(chunk); + } + } + // console.log(res); + expect(res?.usage_metadata).not.toBeDefined(); }); - const res = await model.invoke("Why is the sky blue? Be concise."); - // console.log(res); - expect(res?.usage_metadata).toBeDefined(); - if (!res?.usage_metadata) { - return; - } - expect(res.usage_metadata.input_tokens).toBeGreaterThan(1); - expect(res.usage_metadata.output_tokens).toBeGreaterThan(1); - expect(res.usage_metadata.total_tokens).toBe( - res.usage_metadata.input_tokens + res.usage_metadata.output_tokens - ); -}); -test("Streaming true constructor param will stream", async () => { - const modelWithStreaming = new ChatVertexAI({ - maxOutputTokens: 50, - streaming: true, + test("Invoke token count usage_metadata", async () => { + const model = new ChatVertexAI({ + temperature: 0, + maxOutputTokens: 10, + }); + const res = await model.invoke("Why is the sky blue? Be concise."); + // console.log(res); + expect(res?.usage_metadata).toBeDefined(); + if (!res?.usage_metadata) { + return; + } + expect(res.usage_metadata.input_tokens).toBeGreaterThan(1); + expect(res.usage_metadata.output_tokens).toBeGreaterThan(1); + expect(res.usage_metadata.total_tokens).toBe( + res.usage_metadata.input_tokens + res.usage_metadata.output_tokens + ); }); - let totalTokenCount = 0; - let tokensString = ""; - const result = await modelWithStreaming.invoke("What is 1 + 1?", { - callbacks: [ - { - handleLLMNewToken: (tok) => { - totalTokenCount += 1; - tokensString += tok; + test("Streaming true constructor param will stream", async () => { + const modelWithStreaming = new ChatVertexAI({ + maxOutputTokens: 50, + streaming: true, + }); + + let totalTokenCount = 0; + let tokensString = ""; + const result = await modelWithStreaming.invoke("What is 1 + 1?", { + callbacks: [ + { + handleLLMNewToken: (tok) => { + totalTokenCount += 1; + tokensString += tok; + }, }, - }, - ], + ], + }); + + expect(result).toBeDefined(); + expect(result.content).toBe(tokensString); + + expect(totalTokenCount).toBeGreaterThan(1); }); - expect(result).toBeDefined(); - expect(result.content).toBe(tokensString); + test("Can force a model to invoke a tool", async () => { + const model = new ChatVertexAI({ + model: "gemini-1.5-pro", + }); + const modelWithTools = model.bind({ + tools: [calculatorTool, weatherTool], + tool_choice: "calculator", + }); - expect(totalTokenCount).toBeGreaterThan(1); -}); + const result = await modelWithTools.invoke( + "Whats the weather like in paris today? What's 1836 plus 7262?" + ); -test("Can force a model to invoke a tool", async () => { - const model = new ChatVertexAI({ - model: "gemini-1.5-pro", - }); - const weatherTool = tool((_) => "no-op", { - name: "get_weather", - description: - "Get the weather of a specific location and return the temperature in Celsius.", - schema: z.object({ - location: z.string().describe("The name of city to get the weather for."), - }), - }); - const calculatorTool = tool((_) => "no-op", { - name: "calculator", - description: "Calculate the result of a math expression.", - schema: z.object({ - expression: z.string().describe("The math expression to calculate."), - }), - }); - const modelWithTools = model.bind({ - tools: [calculatorTool, weatherTool], - tool_choice: "calculator", + expect(result.tool_calls).toHaveLength(1); + expect(result.tool_calls?.[0]).toBeDefined(); + if (!result.tool_calls?.[0]) return; + expect(result.tool_calls?.[0].name).toBe("calculator"); + expect(result.tool_calls?.[0].args).toHaveProperty("expression"); }); - const result = await modelWithTools.invoke( - "Whats the weather like in paris today? What's 1836 plus 7262?" - ); + test("ChatGoogleGenerativeAI can stream tools", async () => { + const model = new ChatVertexAI({}); - expect(result.tool_calls).toHaveLength(1); - expect(result.tool_calls?.[0]).toBeDefined(); - if (!result.tool_calls?.[0]) return; - expect(result.tool_calls?.[0].name).toBe("calculator"); - expect(result.tool_calls?.[0].args).toHaveProperty("expression"); -}); + const weatherTool = tool( + (_) => "The weather in San Francisco today is 18 degrees and sunny.", + { + name: "current_weather_tool", + description: "Get the current weather for a given location.", + schema: z.object({ + location: z.string().describe("The location to get the weather for."), + }), + } + ); + + const modelWithTools = model.bindTools([weatherTool]); + const stream = await modelWithTools.stream( + "Whats the weather like today in San Francisco?" + ); + let finalChunk: AIMessageChunk | undefined; + for await (const chunk of stream) { + finalChunk = !finalChunk ? chunk : concat(finalChunk, chunk); + } -test("ChatGoogleGenerativeAI can stream tools", async () => { - const model = new ChatVertexAI({}); + expect(finalChunk).toBeDefined(); + if (!finalChunk) return; - const weatherTool = tool( - (_) => "The weather in San Francisco today is 18 degrees and sunny.", - { - name: "current_weather_tool", - description: "Get the current weather for a given location.", - schema: z.object({ - location: z.string().describe("The location to get the weather for."), - }), + const toolCalls = finalChunk.tool_calls; + expect(toolCalls).toBeDefined(); + if (!toolCalls) { + throw new Error("tool_calls not in response"); } - ); - - const modelWithTools = model.bindTools([weatherTool]); - const stream = await modelWithTools.stream( - "Whats the weather like today in San Francisco?" - ); - let finalChunk: AIMessageChunk | undefined; - for await (const chunk of stream) { - finalChunk = !finalChunk ? chunk : concat(finalChunk, chunk); + expect(toolCalls.length).toBe(1); + expect(toolCalls[0].name).toBe("current_weather_tool"); + expect(toolCalls[0].args).toHaveProperty("location"); + }); + + async function fileToBase64(filePath: string): Promise { + const fileData = await fs.readFile(filePath); + const base64String = Buffer.from(fileData).toString("base64"); + return base64String; } - expect(finalChunk).toBeDefined(); - if (!finalChunk) return; + test("Gemini can understand audio", async () => { + // Update this with the correct path to an audio file on your machine. + const audioPath = + "../langchain-google-genai/src/tests/data/gettysburg10.wav"; + const audioMimeType = "audio/wav"; - const toolCalls = finalChunk.tool_calls; - expect(toolCalls).toBeDefined(); - if (!toolCalls) { - throw new Error("tool_calls not in response"); - } - expect(toolCalls.length).toBe(1); - expect(toolCalls[0].name).toBe("current_weather_tool"); - expect(toolCalls[0].args).toHaveProperty("location"); + const model = new ChatVertexAI({ + model: "gemini-1.5-flash", + temperature: 0, + maxRetries: 0, + }); + + const audioBase64 = await fileToBase64(audioPath); + + const prompt = ChatPromptTemplate.fromMessages([ + new MessagesPlaceholder("audio"), + ]); + + const chain = prompt.pipe(model); + const response = await chain.invoke({ + audio: new HumanMessage({ + content: [ + { + type: "media", + mimeType: audioMimeType, + data: audioBase64, + }, + { + type: "text", + text: "Summarize the content in this audio. ALso, what is the speaker's tone?", + }, + ], + }), + }); + + expect(typeof response.content).toBe("string"); + expect((response.content as string).length).toBeGreaterThan(15); + }); }); -async function fileToBase64(filePath: string): Promise { - const fileData = await fs.readFile(filePath); - const base64String = Buffer.from(fileData).toString("base64"); - return base64String; -} - -test("Gemini can understand audio", async () => { - // Update this with the correct path to an audio file on your machine. - const audioPath = "../langchain-google-genai/src/tests/data/gettysburg10.wav"; - const audioMimeType = "audio/wav"; - - const model = new ChatVertexAI({ - model: "gemini-1.5-flash", - temperature: 0, - maxRetries: 0, +describe("GAuth Anthropic Chat", () => { + let recorder: GoogleRequestRecorder; + let callbacks: BaseCallbackHandler[]; + + // const modelName: string = "claude-3-5-sonnet@20240620"; + // const modelName: string = "claude-3-sonnet@20240229"; + const modelName: string = "claude-3-5-sonnet-v2@20241022"; + + beforeEach(() => { + recorder = new GoogleRequestRecorder(); + callbacks = [recorder, new GoogleRequestLogger()]; }); - const audioBase64 = await fileToBase64(audioPath); + test("invoke", async () => { + const model = new ChatVertexAI({ + modelName, + callbacks, + }); + const res = await model.invoke("What is 1 + 1?"); + expect(res).toBeDefined(); + expect(res._getType()).toEqual("ai"); - const prompt = ChatPromptTemplate.fromMessages([ - new MessagesPlaceholder("audio"), - ]); + const aiMessage = res as AIMessageChunk; + expect(aiMessage.content).toBeDefined(); - const chain = prompt.pipe(model); - const response = await chain.invoke({ - audio: new HumanMessage({ - content: [ - { - type: "media", - mimeType: audioMimeType, - data: audioBase64, - }, - { - type: "text", - text: "Summarize the content in this audio. ALso, what is the speaker's tone?", - }, - ], - }), + expect(typeof aiMessage.content).toBe("string"); + const text = aiMessage.content as string; + expect(text).toMatch(/(1 + 1 (equals|is|=) )?2.? ?/); + + const connection = recorder?.request?.connection; + expect(connection?.url).toEqual( + `https://us-east5-aiplatform.googleapis.com/v1/projects/test-vertex-ai-382612/locations/us-east5/publishers/anthropic/models/${modelName}:rawPredict` + ); + + console.log(JSON.stringify(aiMessage, null, 1)); + console.log(aiMessage.lc_kwargs); }); - expect(typeof response.content).toBe("string"); - expect((response.content as string).length).toBeGreaterThan(15); + test("stream", async () => { + const model = new ChatVertexAI({ + modelName, + callbacks, + }); + const stream = await model.stream("How are you today? Be verbose."); + const chunks = []; + for await (const chunk of stream) { + console.log(chunk); + chunks.push(chunk); + } + expect(chunks.length).toBeGreaterThan(1); + }); + + test("tool invocation", async () => { + const model = new ChatVertexAI({ + modelName, + callbacks, + }); + const modelWithTools = model.bind({ + tools: [weatherTool], + }); + + const result = await modelWithTools.invoke( + "Whats the weather like in paris today?" + ); + + const request = recorder?.request ?? {}; + const data = request?.data; + expect(data).toHaveProperty("tools"); + expect(data.tools).toHaveLength(1); + + expect(result.tool_calls).toHaveLength(1); + expect(result.tool_calls?.[0]).toBeDefined(); + expect(result.tool_calls?.[0].name).toBe("get_weather"); + expect(result.tool_calls?.[0].args).toHaveProperty("location"); + }); + + test("stream tools", async () => { + const model = new ChatVertexAI({ + modelName, + callbacks, + }); + + const weatherTool = tool( + (_) => "The weather in San Francisco today is 18 degrees and sunny.", + { + name: "current_weather_tool", + description: "Get the current weather for a given location.", + schema: z.object({ + location: z.string().describe("The location to get the weather for."), + }), + } + ); + + const modelWithTools = model.bindTools([weatherTool]); + const stream = await modelWithTools.stream( + "Whats the weather like today in San Francisco?" + ); + let finalChunk: AIMessageChunk | undefined; + for await (const chunk of stream) { + finalChunk = !finalChunk ? chunk : concat(finalChunk, chunk); + } + + expect(finalChunk).toBeDefined(); + const toolCalls = finalChunk?.tool_calls; + expect(toolCalls).toBeDefined(); + expect(toolCalls?.length).toBe(1); + expect(toolCalls?.[0].name).toBe("current_weather_tool"); + expect(toolCalls?.[0].args).toHaveProperty("location"); + }); }); diff --git a/libs/langchain-google-webauth/src/tests/chat_models.int.test.ts b/libs/langchain-google-webauth/src/tests/chat_models.int.test.ts index 106b685bfee8..0e10359599b3 100644 --- a/libs/langchain-google-webauth/src/tests/chat_models.int.test.ts +++ b/libs/langchain-google-webauth/src/tests/chat_models.int.test.ts @@ -204,7 +204,9 @@ describe("Google APIKey Chat", () => { const model = new ChatGoogle({ modelName: "gemini-1.5-flash", apiVersion: "v1beta", - mediaManager, + apiConfig: { + mediaManager, + }, }); const message: MessageContentComplex[] = [ @@ -214,7 +216,7 @@ describe("Google APIKey Chat", () => { }, { type: "media", - fileUri: "https://js.langchain.com/img/brand/wordmark.png", + fileUri: "https://js.langchain.com/v0.2/img/brand/wordmark.png", }, ];