From 818c888e162a456767135e343c426d6479daa8e2 Mon Sep 17 00:00:00 2001 From: FilipZmijewski Date: Thu, 30 Jan 2025 19:30:43 +0100 Subject: [PATCH 1/3] feature: Add chat deployment to chat class (#41) * Rename auth method in docs * Rename auth method in docs * Add deployment chat to chat class * Upadate Watsonx sdk * Rework interfaces in llms as well * Bump watsonx-ai sdk version * Remove unused code * Add fake auth --- libs/langchain-community/package.json | 2 +- .../src/chat_models/ibm.ts | 179 ++++++++++++------ .../src/chat_models/tests/ibm.int.test.ts | 35 +++- .../src/chat_models/tests/ibm.test.ts | 57 +++++- libs/langchain-community/src/llms/ibm.ts | 111 +++++++---- .../src/llms/tests/ibm.test.ts | 15 +- libs/langchain-community/src/types/ibm.ts | 22 ++- libs/langchain-community/src/utils/ibm.ts | 10 +- yarn.lock | 13 +- 9 files changed, 324 insertions(+), 120 deletions(-) diff --git a/libs/langchain-community/package.json b/libs/langchain-community/package.json index 6b59f23ee454..8467575c9c46 100644 --- a/libs/langchain-community/package.json +++ b/libs/langchain-community/package.json @@ -79,7 +79,7 @@ "@gradientai/nodejs-sdk": "^1.2.0", "@huggingface/inference": "^2.6.4", "@huggingface/transformers": "^3.2.3", - "@ibm-cloud/watsonx-ai": "^1.3.0", + "@ibm-cloud/watsonx-ai": "^1.4.0", "@jest/globals": "^29.5.0", "@lancedb/lancedb": "^0.13.0", "@langchain/core": "workspace:*", diff --git a/libs/langchain-community/src/chat_models/ibm.ts b/libs/langchain-community/src/chat_models/ibm.ts index 992419649fb1..17e80922a8db 100644 --- a/libs/langchain-community/src/chat_models/ibm.ts +++ b/libs/langchain-community/src/chat_models/ibm.ts @@ -33,6 +33,7 @@ import { } from "@langchain/core/outputs"; import { AsyncCaller } from "@langchain/core/utils/async_caller"; import { + DeploymentsTextChatParams, RequestCallbacks, TextChatMessagesTextChatMessageAssistant, TextChatParameterTools, @@ -65,7 +66,13 @@ import { import { isZodSchema } from "@langchain/core/utils/types"; import { zodToJsonSchema } from "zod-to-json-schema"; import { NewTokenIndices } from "@langchain/core/callbacks/base"; -import { WatsonxAuth, WatsonxParams } from "../types/ibm.js"; +import { + Neverify, + WatsonxAuth, + WatsonxChatBasicOptions, + WatsonxDeployedParams, + WatsonxParams, +} from "../types/ibm.js"; import { _convertToolCallIdToMistralCompatible, authenticateAndSetInstance, @@ -80,16 +87,24 @@ export interface WatsonxDeltaStream { } export interface WatsonxCallParams - extends Partial> { - maxRetries?: number; - watsonxCallbacks?: RequestCallbacks; -} + extends Partial< + Omit + > {} + +export interface WatsonxCallDeployedParams extends DeploymentsTextChatParams {} + export interface WatsonxCallOptionsChat extends Omit, - WatsonxCallParams { + WatsonxCallParams, + WatsonxChatBasicOptions { promptIndex?: number; tool_choice?: TextChatParameterTools | string | "auto" | "any"; - watsonxCallbacks?: RequestCallbacks; +} + +export interface WatsonxCallOptionsDeployedChat + extends WatsonxCallDeployedParams, + WatsonxChatBasicOptions { + promptIndex?: number; } type ChatWatsonxToolType = BindToolsInput | TextChatParameterTools; @@ -97,10 +112,18 @@ type ChatWatsonxToolType = BindToolsInput | TextChatParameterTools; export interface ChatWatsonxInput extends BaseChatModelParams, WatsonxParams, - WatsonxCallParams { - streaming?: boolean; -} + WatsonxCallParams, + Neverify {} + +export interface ChatWatsonxDeployedInput + extends BaseChatModelParams, + WatsonxDeployedParams, + Neverify {} +export type ChatWatsonxConstructor = BaseChatModelParams & + Partial & + WatsonxDeployedParams & + WatsonxCallParams; function _convertToValidToolId(model: string, tool_call_id: string) { if (model.startsWith("mistralai")) return _convertToolCallIdToMistralCompatible(tool_call_id); @@ -335,10 +358,12 @@ function _convertToolChoiceToWatsonxToolChoice( } export class ChatWatsonx< - CallOptions extends WatsonxCallOptionsChat = WatsonxCallOptionsChat + CallOptions extends WatsonxCallOptionsChat = + | WatsonxCallOptionsChat + | WatsonxCallOptionsDeployedChat > extends BaseChatModel - implements ChatWatsonxInput + implements ChatWatsonxConstructor { static lc_name() { return "ChatWatsonx"; @@ -380,8 +405,8 @@ export class ChatWatsonx< ls_provider: "watsonx", ls_model_name: this.model, ls_model_type: "chat", - ls_temperature: params.temperature ?? undefined, - ls_max_tokens: params.maxTokens ?? undefined, + ls_temperature: params?.temperature ?? undefined, + ls_max_tokens: params?.maxTokens ?? undefined, }; } @@ -399,6 +424,8 @@ export class ChatWatsonx< projectId?: string; + idOrName?: string; + frequencyPenalty?: number; logprobs?: boolean; @@ -425,37 +452,44 @@ export class ChatWatsonx< watsonxCallbacks?: RequestCallbacks; - constructor(fields: ChatWatsonxInput & WatsonxAuth) { + constructor( + fields: (ChatWatsonxInput | ChatWatsonxDeployedInput) & WatsonxAuth + ) { super(fields); if ( - (fields.projectId && fields.spaceId) || - (fields.idOrName && fields.projectId) || - (fields.spaceId && fields.idOrName) + ("projectId" in fields && "spaceId" in fields) || + ("projectId" in fields && "idOrName" in fields) || + ("spaceId" in fields && "idOrName" in fields) ) throw new Error("Maximum 1 id type can be specified per instance"); - if (!fields.projectId && !fields.spaceId && !fields.idOrName) + if (!("projectId" in fields || "spaceId" in fields || "idOrName" in fields)) throw new Error( "No id specified! At least id of 1 type has to be specified" ); - this.projectId = fields?.projectId; - this.spaceId = fields?.spaceId; - this.temperature = fields?.temperature; - this.maxRetries = fields?.maxRetries || this.maxRetries; - this.maxConcurrency = fields?.maxConcurrency; - this.frequencyPenalty = fields?.frequencyPenalty; - this.topLogprobs = fields?.topLogprobs; - this.maxTokens = fields?.maxTokens ?? this.maxTokens; - this.presencePenalty = fields?.presencePenalty; - this.topP = fields?.topP; - this.timeLimit = fields?.timeLimit; - this.responseFormat = fields?.responseFormat ?? this.responseFormat; + + if ("model" in fields) { + this.projectId = fields?.projectId; + this.spaceId = fields?.spaceId; + this.temperature = fields?.temperature; + this.maxRetries = fields?.maxRetries || this.maxRetries; + this.maxConcurrency = fields?.maxConcurrency; + this.frequencyPenalty = fields?.frequencyPenalty; + this.topLogprobs = fields?.topLogprobs; + this.maxTokens = fields?.maxTokens ?? this.maxTokens; + this.presencePenalty = fields?.presencePenalty; + this.topP = fields?.topP; + this.timeLimit = fields?.timeLimit; + this.responseFormat = fields?.responseFormat ?? this.responseFormat; + this.streaming = fields?.streaming ?? this.streaming; + this.n = fields?.n ?? this.n; + this.model = fields?.model ?? this.model; + } else this.idOrName = fields?.idOrName; + + this.watsonxCallbacks = fields?.watsonxCallbacks ?? this.watsonxCallbacks; this.serviceUrl = fields?.serviceUrl; - this.streaming = fields?.streaming ?? this.streaming; - this.n = fields?.n ?? this.n; - this.model = fields?.model ?? this.model; this.version = fields?.version ?? this.version; - this.watsonxCallbacks = fields?.watsonxCallbacks ?? this.watsonxCallbacks; + const { watsonxAIApikey, watsonxAIAuthType, @@ -486,6 +520,11 @@ export class ChatWatsonx< } invocationParams(options: this["ParsedCallOptions"]) { + const { signal, promptIndex, ...rest } = options; + if (this.idOrName && Object.keys(rest).length > 0) + throw new Error("Options cannot be provided to a deployed model"); + if (this.idOrName) return undefined; + const params = { maxTokens: options.maxTokens ?? this.maxTokens, temperature: options?.temperature ?? this.temperature, @@ -521,10 +560,16 @@ export class ChatWatsonx< } as CallOptions); } - scopeId() { + scopeId(): + | { idOrName: string } + | { projectId: string; modelId: string } + | { spaceId: string; modelId: string } { if (this.projectId) return { projectId: this.projectId, modelId: this.model }; - else return { spaceId: this.spaceId, modelId: this.model }; + else if (this.spaceId) + return { spaceId: this.spaceId, modelId: this.model }; + else if (this.idOrName) return { idOrName: this.idOrName }; + else throw new Error("No scope id provided"); } async completionWithRetry( @@ -595,23 +640,30 @@ export class ChatWatsonx< .map(([_, value]) => value); return { generations, llmOutput: { tokenUsage } }; } else { - const params = { - ...this.invocationParams(options), - ...this.scopeId(), - }; + const params = this.invocationParams(options); + const scopeId = this.scopeId(); const watsonxCallbacks = this.invocationCallbacks(options); const watsonxMessages = _convertMessagesToWatsonxMessages( messages, this.model ); const callback = () => - this.service.textChat( - { - ...params, - messages: watsonxMessages, - }, - watsonxCallbacks - ); + "idOrName" in scopeId + ? this.service.deploymentsTextChat( + { + ...scopeId, + messages: watsonxMessages, + }, + watsonxCallbacks + ) + : this.service.textChat( + { + ...params, + ...scopeId, + messages: watsonxMessages, + }, + watsonxCallbacks + ); const { result } = await this.completionWithRetry(callback, options); const generations: ChatGeneration[] = []; for (const part of result.choices) { @@ -646,21 +698,33 @@ export class ChatWatsonx< options: this["ParsedCallOptions"], _runManager?: CallbackManagerForLLMRun ): AsyncGenerator { - const params = { ...this.invocationParams(options), ...this.scopeId() }; + const params = this.invocationParams(options); + const scopeId = this.scopeId(); const watsonxMessages = _convertMessagesToWatsonxMessages( messages, this.model ); const watsonxCallbacks = this.invocationCallbacks(options); const callback = () => - this.service.textChatStream( - { - ...params, - messages: watsonxMessages, - returnObject: true, - }, - watsonxCallbacks - ); + "idOrName" in scopeId + ? this.service.deploymentsTextChatStream( + { + ...scopeId, + messages: watsonxMessages, + returnObject: true, + }, + watsonxCallbacks + ) + : this.service.textChatStream( + { + ...params, + ...scopeId, + messages: watsonxMessages, + returnObject: true, + }, + watsonxCallbacks + ); + const stream = await this.completionWithRetry(callback, options); let defaultRole; let usage: TextChatUsage | undefined; @@ -707,7 +771,6 @@ export class ChatWatsonx< if (message === null || (!delta.content && !delta.tool_calls)) { continue; } - const generationChunk = new ChatGenerationChunk({ message, text: delta.content ?? "", diff --git a/libs/langchain-community/src/chat_models/tests/ibm.int.test.ts b/libs/langchain-community/src/chat_models/tests/ibm.int.test.ts index ae47345a1add..1cdc836ba9c8 100644 --- a/libs/langchain-community/src/chat_models/tests/ibm.int.test.ts +++ b/libs/langchain-community/src/chat_models/tests/ibm.int.test.ts @@ -16,7 +16,7 @@ import { ChatWatsonx } from "../ibm.js"; describe("Tests for chat", () => { describe("Test ChatWatsonx invoke and generate", () => { - test("Basic invoke", async () => { + test("Basic invoke with projectId", async () => { const service = new ChatWatsonx({ model: "mistralai/mistral-large", version: "2024-05-31", @@ -26,6 +26,37 @@ describe("Tests for chat", () => { const res = await service.invoke("Print hello world"); expect(res).toBeInstanceOf(AIMessage); }); + test("Basic invoke with spaceId", async () => { + const service = new ChatWatsonx({ + model: "mistralai/mistral-large", + version: "2024-05-31", + serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString", + spaceId: process.env.WATSONX_AI_SPACE_ID ?? "testString", + }); + const res = await service.invoke("Print hello world"); + expect(res).toBeInstanceOf(AIMessage); + }); + test("Basic invoke with idOrName", async () => { + const service = new ChatWatsonx({ + version: "2024-05-31", + serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString", + idOrName: process.env.WATSONX_AI_ID_OR_NAME ?? "testString", + }); + const res = await service.invoke("Print hello world"); + expect(res).toBeInstanceOf(AIMessage); + }); + test("Invalide invoke with idOrName and options as second argument", async () => { + const service = new ChatWatsonx({ + version: "2024-05-31", + serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString", + idOrName: process.env.WATSONX_AI_ID_OR_NAME ?? "testString", + }); + await expect(() => + service.invoke("Print hello world", { + maxTokens: 100, + }) + ).rejects.toThrow("Options cannot be provided to a deployed model"); + }); test("Basic generate", async () => { const service = new ChatWatsonx({ model: "mistralai/mistral-large", @@ -710,7 +741,7 @@ describe("Tests for chat", () => { test("Schema with zod and stream", async () => { const service = new ChatWatsonx({ - model: "mistralai/mistral-large", + model: "meta-llama/llama-3-1-70b-instruct", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString", projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString", diff --git a/libs/langchain-community/src/chat_models/tests/ibm.test.ts b/libs/langchain-community/src/chat_models/tests/ibm.test.ts index f52a689f6755..b35b59d8ccbd 100644 --- a/libs/langchain-community/src/chat_models/tests/ibm.test.ts +++ b/libs/langchain-community/src/chat_models/tests/ibm.test.ts @@ -1,7 +1,12 @@ /* eslint-disable no-process-env */ /* eslint-disable @typescript-eslint/no-explicit-any */ import WatsonxAiMlVml_v1 from "@ibm-cloud/watsonx-ai/dist/watsonx-ai-ml/vml_v1.js"; -import { ChatWatsonx, ChatWatsonxInput, WatsonxCallParams } from "../ibm.js"; +import { + ChatWatsonx, + ChatWatsonxConstructor, + ChatWatsonxInput, + WatsonxCallParams, +} from "../ibm.js"; import { authenticateAndSetInstance } from "../../utils/ibm.js"; const fakeAuthProp = { @@ -13,7 +18,7 @@ export function getKey(key: K): K { } export const testProperties = ( instance: ChatWatsonx, - testProps: ChatWatsonxInput, + testProps: ChatWatsonxConstructor, notExTestProps?: { [key: string]: any } ) => { const checkProperty = ( @@ -24,13 +29,19 @@ export const testProperties = ( Object.keys(testProps).forEach((key) => { const keys = getKey(key); type Type = Pick; - if (typeof testProps[key as keyof T] === "object") - checkProperty(testProps[key as keyof T], instance[key], existing); + checkProperty( + testProps[key as keyof T], + instance[key as keyof typeof instance], + existing + ); else { if (existing) - expect(instance[key as keyof T]).toBe(testProps[key as keyof T]); - else if (instance) expect(instance[key as keyof T]).toBeUndefined(); + expect(instance[key as keyof typeof instance]).toBe( + testProps[key as keyof T] + ); + else if (instance) + expect(instance[key as keyof typeof instance]).toBeUndefined(); } }); }; @@ -62,6 +73,40 @@ describe("LLM unit tests", () => { testProperties(instance, testProps); }); + test("Authenticate with projectId", async () => { + const testProps = { + model: "mistralai/mistral-large", + version: "2024-05-31", + serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, + projectId: process.env.WATSONX_AI_PROJECT_ID || "testString", + }; + const instance = new ChatWatsonx({ ...testProps, ...fakeAuthProp }); + + testProperties(instance, testProps); + }); + + test("Authenticate with spaceId", async () => { + const testProps = { + model: "mistralai/mistral-large", + version: "2024-05-31", + serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, + spaceId: process.env.WATSONX_AI_SPACE_ID || "testString", + }; + const instance = new ChatWatsonx({ ...testProps, ...fakeAuthProp }); + + testProperties(instance, testProps); + }); + + test("Authenticate with idOrName", async () => { + const testProps = { + version: "2024-05-31", + serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, + idOrName: process.env.WATSONX_AI_ID_OR_NAME || "testString", + }; + const instance = new ChatWatsonx({ ...testProps, ...fakeAuthProp }); + testProperties(instance, testProps); + }); + test("Test methods after init", () => { const testProps: ChatWatsonxInput = { model: "mistralai/mistral-large", diff --git a/libs/langchain-community/src/llms/ibm.ts b/libs/langchain-community/src/llms/ibm.ts index 75e65fd6873d..97fb287a982a 100644 --- a/libs/langchain-community/src/llms/ibm.ts +++ b/libs/langchain-community/src/llms/ibm.ts @@ -3,7 +3,6 @@ import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager"; import { BaseLLM, BaseLLMParams } from "@langchain/core/language_models/llms"; import { WatsonXAI } from "@ibm-cloud/watsonx-ai"; import { - DeploymentTextGenProperties, RequestCallbacks, ReturnOptionProperties, TextGenLengthPenalty, @@ -21,9 +20,11 @@ import { AsyncCaller } from "@langchain/core/utils/async_caller"; import { authenticateAndSetInstance } from "../utils/ibm.js"; import { GenerationInfo, + Neverify, ResponseChunk, TokenUsage, WatsonxAuth, + WatsonxDeployedParams, WatsonxParams, } from "../types/ibm.js"; @@ -31,15 +32,7 @@ import { * Input to LLM class. */ -export interface WatsonxCallOptionsLLM extends BaseLanguageModelCallOptions { - maxRetries?: number; - parameters?: Partial; - idOrName?: string; - watsonxCallbacks?: RequestCallbacks; -} - -export interface WatsonxInputLLM extends WatsonxParams, BaseLLMParams { - streaming?: boolean; +export interface WatsonxLLMParams { maxNewTokens?: number; decodingMethod?: TextGenParameters.Constants.DecodingMethod | string; lengthPenalty?: TextGenLengthPenalty; @@ -54,9 +47,36 @@ export interface WatsonxInputLLM extends WatsonxParams, BaseLLMParams { truncateInpuTokens?: number; returnOptions?: ReturnOptionProperties; includeStopSequence?: boolean; +} + +export interface WatsonxDeploymentLLMParams { + idOrName: string; +} + +export interface WatsonxCallOptionsLLM extends BaseLanguageModelCallOptions { + maxRetries?: number; + parameters?: Partial; watsonxCallbacks?: RequestCallbacks; } +export interface WatsonxInputLLM + extends WatsonxParams, + BaseLLMParams, + WatsonxLLMParams, + Neverify {} + +export interface WatsonxDeployedInputLLM + extends WatsonxDeployedParams, + BaseLLMParams, + Neverify { + model?: never; +} + +export type WatsonxLLMConstructor = BaseLLMParams & + WatsonxLLMParams & + Partial & + WatsonxDeployedParams; + /** * Integration with an LLM. */ @@ -64,7 +84,7 @@ export class WatsonxLLM< CallOptions extends WatsonxCallOptionsLLM = WatsonxCallOptionsLLM > extends BaseLLM - implements WatsonxInputLLM + implements WatsonxLLMConstructor { // Used for tracing, replace with the same name as your class static lc_name() { @@ -123,43 +143,51 @@ export class WatsonxLLM< private service: WatsonXAI; - constructor(fields: WatsonxInputLLM & WatsonxAuth) { + constructor( + fields: (WatsonxInputLLM | WatsonxDeployedInputLLM) & WatsonxAuth + ) { super(fields); - this.model = fields.model ?? this.model; - this.version = fields.version; - this.maxNewTokens = fields.maxNewTokens ?? this.maxNewTokens; - this.serviceUrl = fields.serviceUrl; - this.decodingMethod = fields.decodingMethod; - this.lengthPenalty = fields.lengthPenalty; - this.minNewTokens = fields.minNewTokens; - this.randomSeed = fields.randomSeed; - this.stopSequence = fields.stopSequence; - this.temperature = fields.temperature; - this.timeLimit = fields.timeLimit; - this.topK = fields.topK; - this.topP = fields.topP; - this.repetitionPenalty = fields.repetitionPenalty; - this.truncateInpuTokens = fields.truncateInpuTokens; - this.returnOptions = fields.returnOptions; - this.includeStopSequence = fields.includeStopSequence; + + if (fields.model) { + this.model = fields.model ?? this.model; + this.version = fields.version; + this.maxNewTokens = fields.maxNewTokens ?? this.maxNewTokens; + this.serviceUrl = fields.serviceUrl; + this.decodingMethod = fields.decodingMethod; + this.lengthPenalty = fields.lengthPenalty; + this.minNewTokens = fields.minNewTokens; + this.randomSeed = fields.randomSeed; + this.stopSequence = fields.stopSequence; + this.temperature = fields.temperature; + this.timeLimit = fields.timeLimit; + this.topK = fields.topK; + this.topP = fields.topP; + this.repetitionPenalty = fields.repetitionPenalty; + this.truncateInpuTokens = fields.truncateInpuTokens; + this.returnOptions = fields.returnOptions; + this.includeStopSequence = fields.includeStopSequence; + this.projectId = fields?.projectId; + this.spaceId = fields?.spaceId; + } else { + this.idOrName = fields?.idOrName; + } + this.maxRetries = fields.maxRetries || this.maxRetries; this.maxConcurrency = fields.maxConcurrency; this.streaming = fields.streaming || this.streaming; this.watsonxCallbacks = fields.watsonxCallbacks || this.watsonxCallbacks; + if ( - (fields.projectId && fields.spaceId) || - (fields.idOrName && fields.projectId) || - (fields.spaceId && fields.idOrName) + ("projectId" in fields && "spaceId" in fields) || + ("projectId" in fields && "idOrName" in fields) || + ("spaceId" in fields && "idOrName" in fields) ) throw new Error("Maximum 1 id type can be specified per instance"); - if (!fields.projectId && !fields.spaceId && !fields.idOrName) + if (!("projectId" in fields || "spaceId" in fields || "idOrName" in fields)) throw new Error( "No id specified! At least id of 1 type has to be specified" ); - this.projectId = fields?.projectId; - this.spaceId = fields?.spaceId; - this.idOrName = fields?.idOrName; this.serviceUrl = fields?.serviceUrl; const { @@ -215,11 +243,12 @@ export class WatsonxLLM< }; } - invocationParams( - options: this["ParsedCallOptions"] - ): TextGenParameters | DeploymentTextGenProperties { + invocationParams(options: this["ParsedCallOptions"]) { const { parameters } = options; - + const { signal, ...rest } = options; + if (this.idOrName && Object.keys(rest).length > 0) + throw new Error("Options cannot be provided to a deployed model"); + if (this.idOrName) return undefined; return { max_new_tokens: parameters?.maxNewTokens ?? this.maxNewTokens, decoding_method: parameters?.decodingMethod ?? this.decodingMethod, @@ -293,7 +322,7 @@ export class WatsonxLLM< ...requestOptions } = options; const tokenUsage = { generated_token_count: 0, input_token_count: 0 }; - const idOrName = options?.idOrName ?? this.idOrName; + const idOrName = this.idOrName; const parameters = this.invocationParams(options); const watsonxCallbacks = this.invocationCallbacks(options); if (stream) { diff --git a/libs/langchain-community/src/llms/tests/ibm.test.ts b/libs/langchain-community/src/llms/tests/ibm.test.ts index 6237cb1d14c1..0669af2f811b 100644 --- a/libs/langchain-community/src/llms/tests/ibm.test.ts +++ b/libs/langchain-community/src/llms/tests/ibm.test.ts @@ -1,7 +1,7 @@ /* eslint-disable no-process-env */ /* eslint-disable @typescript-eslint/no-explicit-any */ import WatsonxAiMlVml_v1 from "@ibm-cloud/watsonx-ai/dist/watsonx-ai-ml/vml_v1.js"; -import { WatsonxLLM, WatsonxInputLLM } from "../ibm.js"; +import { WatsonxLLM, WatsonxInputLLM, WatsonxLLMConstructor } from "../ibm.js"; import { authenticateAndSetInstance } from "../../utils/ibm.js"; import { WatsonxEmbeddings } from "../../embeddings/ibm.js"; @@ -14,7 +14,7 @@ export function getKey(key: K): K { } export const testProperties = ( instance: WatsonxLLM | WatsonxEmbeddings, - testProps: WatsonxInputLLM, + testProps: WatsonxLLMConstructor, notExTestProps?: { [key: string]: any } ) => { const checkProperty = ( @@ -63,6 +63,17 @@ describe("LLM unit tests", () => { testProperties(instance, testProps); }); + test("Test basic properties after init", async () => { + const testProps = { + version: "2024-05-31", + serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, + idOrName: process.env.WATSONX_AI_PROJECT_ID || "testString", + }; + const instance = new WatsonxLLM({ ...testProps, ...fakeAuthProp }); + + testProperties(instance, testProps); + }); + test("Test methods after init", () => { const testProps: WatsonxInputLLM = { model: "ibm/granite-13b-chat-v2", diff --git a/libs/langchain-community/src/types/ibm.ts b/libs/langchain-community/src/types/ibm.ts index ee5db8532036..f5d4b72de7b4 100644 --- a/libs/langchain-community/src/types/ibm.ts +++ b/libs/langchain-community/src/types/ibm.ts @@ -1,3 +1,5 @@ +import { RequestCallbacks } from "@ibm-cloud/watsonx-ai/dist/watsonx-ai-ml/vml_v1.js"; + export interface TokenUsage { generated_token_count: number; input_token_count: number; @@ -17,13 +19,27 @@ export interface WatsonxInit { version: string; } -export interface WatsonxParams extends WatsonxInit { +export interface WatsonxChatBasicOptions { + maxConcurrency?: number; + maxRetries?: number; + streaming?: boolean; + watsonxCallbacks?: RequestCallbacks; +} + +export interface WatsonxParams extends WatsonxInit, WatsonxChatBasicOptions { model: string; spaceId?: string; projectId?: string; +} + +export type Neverify = { + [K in keyof T]?: never; +}; + +export interface WatsonxDeployedParams + extends WatsonxInit, + WatsonxChatBasicOptions { idOrName?: string; - maxConcurrency?: number; - maxRetries?: number; } export interface GenerationInfo { diff --git a/libs/langchain-community/src/utils/ibm.ts b/libs/langchain-community/src/utils/ibm.ts index ccbe1204ef60..8786a0263198 100644 --- a/libs/langchain-community/src/utils/ibm.ts +++ b/libs/langchain-community/src/utils/ibm.ts @@ -184,10 +184,18 @@ export class WatsonxToolsOutputParser< const tool = message.tool_calls; return tool; }); + if (tools[0] === undefined) { - if (this.latestCorrect) tools.push(this.latestCorrect); + if (this.latestCorrect) { + tools.push(this.latestCorrect); + } else { + const toolCall: ToolCall = { name: "", args: {} }; + tools.push(toolCall); + } } + const [tool] = tools; + tool.name = ""; this.latestCorrect = tool; return tool.args as T; } diff --git a/yarn.lock b/yarn.lock index fefd13294652..fe891b2b0113 100644 --- a/yarn.lock +++ b/yarn.lock @@ -10669,14 +10669,15 @@ __metadata: languageName: node linkType: hard -"@ibm-cloud/watsonx-ai@npm:^1.3.0": - version: 1.3.0 - resolution: "@ibm-cloud/watsonx-ai@npm:1.3.0" +"@ibm-cloud/watsonx-ai@npm:^1.4.0": + version: 1.4.0 + resolution: "@ibm-cloud/watsonx-ai@npm:1.4.0" dependencies: + "@langchain/textsplitters": ^0.1.0 "@types/node": ^18.0.0 extend: 3.0.2 ibm-cloud-sdk-core: ^5.0.2 - checksum: 6a2127391ca70005b942d3c4ab1abc738946c42bbf3ee0f8eb6f778434b5f8806d622f1f36446f00b9fb82dc2c8aea3526426ec46cc53fa8a075ba7a294da096 + checksum: 5250816f9ad93839cf26e3788eeace8155721765c39c65547eff8ebbd5fc8a0dfa107f6e799593f1209f4b3489be24aa674aa92b7ecbc5fc2bd29390a28e84ff languageName: node linkType: hard @@ -11899,7 +11900,7 @@ __metadata: "@gradientai/nodejs-sdk": ^1.2.0 "@huggingface/inference": ^2.6.4 "@huggingface/transformers": ^3.2.3 - "@ibm-cloud/watsonx-ai": ^1.3.0 + "@ibm-cloud/watsonx-ai": ^1.4.0 "@jest/globals": ^29.5.0 "@lancedb/lancedb": ^0.13.0 "@langchain/core": "workspace:*" @@ -13237,7 +13238,7 @@ __metadata: languageName: unknown linkType: soft -"@langchain/textsplitters@>=0.0.0 <0.2.0, @langchain/textsplitters@workspace:*, @langchain/textsplitters@workspace:libs/langchain-textsplitters": +"@langchain/textsplitters@>=0.0.0 <0.2.0, @langchain/textsplitters@^0.1.0, @langchain/textsplitters@workspace:*, @langchain/textsplitters@workspace:libs/langchain-textsplitters": version: 0.0.0-use.local resolution: "@langchain/textsplitters@workspace:libs/langchain-textsplitters" dependencies: From f818281aa450006a292b8c8d57537e23d98c6e2d Mon Sep 17 00:00:00 2001 From: FilipZmijewski Date: Tue, 4 Feb 2025 13:28:46 +0100 Subject: [PATCH 2/3] fix: Add fixes regarding PR comments, change models in tests (#42) --- .../src/chat_models/ibm.ts | 19 +++--- .../tests/ibm.standard.int.test.ts | 9 +++ libs/langchain-community/src/llms/ibm.ts | 9 +-- .../src/llms/tests/ibm.int.test.ts | 62 ++++++++++--------- 4 files changed, 55 insertions(+), 44 deletions(-) diff --git a/libs/langchain-community/src/chat_models/ibm.ts b/libs/langchain-community/src/chat_models/ibm.ts index 17e80922a8db..3bb4f4a0adf4 100644 --- a/libs/langchain-community/src/chat_models/ibm.ts +++ b/libs/langchain-community/src/chat_models/ibm.ts @@ -150,7 +150,7 @@ function _convertToolToWatsonxTool( function _convertMessagesToWatsonxMessages( messages: BaseMessage[], - model: string + model?: string ): TextChatResultMessage[] { const getRole = (role: MessageType) => { switch (role) { @@ -174,7 +174,7 @@ function _convertMessagesToWatsonxMessages( return message.tool_calls .map((toolCall) => ({ ...toolCall, - id: _convertToValidToolId(model, toolCall.id ?? ""), + id: _convertToValidToolId(model ?? "", toolCall.id ?? ""), })) .map(convertLangChainToolCallToOpenAI) as TextChatToolCall[]; } @@ -189,7 +189,7 @@ function _convertMessagesToWatsonxMessages( role: getRole(message._getType()), content, name: message.name, - tool_call_id: _convertToValidToolId(model, message.tool_call_id), + tool_call_id: _convertToValidToolId(model ?? "", message.tool_call_id), }; } @@ -252,7 +252,7 @@ function _watsonxResponseToChatMessage( function _convertDeltaToMessageChunk( delta: WatsonxDeltaStream, rawData: TextChatResponse, - model: string, + model?: string, usage?: TextChatUsage, defaultRole?: TextChatMessagesTextChatMessageAssistant.Constants.Role ) { @@ -268,7 +268,7 @@ function _convertDeltaToMessageChunk( } => ({ ...toolCall, index, - id: _convertToValidToolId(model, toolCall.id), + id: _convertToValidToolId(model ?? "", toolCall.id), type: "function", }) ) @@ -321,7 +321,7 @@ function _convertDeltaToMessageChunk( return new ToolMessageChunk({ content, additional_kwargs, - tool_call_id: _convertToValidToolId(model, rawToolCalls?.[0].id), + tool_call_id: _convertToValidToolId(model ?? "", rawToolCalls?.[0].id), }); } else if (role === "function") { return new FunctionMessageChunk({ @@ -410,7 +410,7 @@ export class ChatWatsonx< }; } - model: string; + model?: string; version = "2024-05-31"; @@ -523,7 +523,6 @@ export class ChatWatsonx< const { signal, promptIndex, ...rest } = options; if (this.idOrName && Object.keys(rest).length > 0) throw new Error("Options cannot be provided to a deployed model"); - if (this.idOrName) return undefined; const params = { maxTokens: options.maxTokens ?? this.maxTokens, @@ -564,9 +563,9 @@ export class ChatWatsonx< | { idOrName: string } | { projectId: string; modelId: string } | { spaceId: string; modelId: string } { - if (this.projectId) + if (this.projectId && this.model) return { projectId: this.projectId, modelId: this.model }; - else if (this.spaceId) + else if (this.spaceId && this.model) return { spaceId: this.spaceId, modelId: this.model }; else if (this.idOrName) return { idOrName: this.idOrName }; else throw new Error("No scope id provided"); diff --git a/libs/langchain-community/src/chat_models/tests/ibm.standard.int.test.ts b/libs/langchain-community/src/chat_models/tests/ibm.standard.int.test.ts index 68b967d972b7..0a247720cbf2 100644 --- a/libs/langchain-community/src/chat_models/tests/ibm.standard.int.test.ts +++ b/libs/langchain-community/src/chat_models/tests/ibm.standard.int.test.ts @@ -34,6 +34,15 @@ class ChatWatsonxStandardIntegrationTests extends ChatModelIntegrationTests< }, }); } + + async testInvokeMoreComplexTools() { + this.skipTestMessage( + "testInvokeMoreComplexTools", + "ChatWatsonx", + "Watsonx does not support tool schemas which contain object with unknown/any parameters." + + "Watsonx only supports objects in schemas when the parameters are defined." + ); + } } const testClass = new ChatWatsonxStandardIntegrationTests(); diff --git a/libs/langchain-community/src/llms/ibm.ts b/libs/langchain-community/src/llms/ibm.ts index 97fb287a982a..8bf94854800d 100644 --- a/libs/langchain-community/src/llms/ibm.ts +++ b/libs/langchain-community/src/llms/ibm.ts @@ -454,10 +454,11 @@ export class WatsonxLLM< geneartionsArray[completion].stop_reason = chunk?.generationInfo?.stop_reason; geneartionsArray[completion].text += chunk.text; - void runManager?.handleLLMNewToken(chunk.text, { - prompt: promptIdx, - completion: 0, - }); + if (chunk.text) + void runManager?.handleLLMNewToken(chunk.text, { + prompt: promptIdx, + completion: 0, + }); } return geneartionsArray.map((item) => { diff --git a/libs/langchain-community/src/llms/tests/ibm.int.test.ts b/libs/langchain-community/src/llms/tests/ibm.int.test.ts index f6a6e9e1ddcc..75af46c694bf 100644 --- a/libs/langchain-community/src/llms/tests/ibm.int.test.ts +++ b/libs/langchain-community/src/llms/tests/ibm.int.test.ts @@ -11,7 +11,7 @@ describe("Text generation", () => { describe("Test invoke method", () => { test("Correct value", async () => { const watsonXInstance = new WatsonxLLM({ - model: "ibm/granite-13b-chat-v2", + model: "ibm/granite-3-8b-instruct", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, projectId: process.env.WATSONX_AI_PROJECT_ID, @@ -21,7 +21,7 @@ describe("Text generation", () => { test("Overwritte params", async () => { const watsonXInstance = new WatsonxLLM({ - model: "ibm/granite-13b-chat-v2", + model: "ibm/granite-3-8b-instruct", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, projectId: process.env.WATSONX_AI_PROJECT_ID, @@ -33,7 +33,7 @@ describe("Text generation", () => { test("Invalid projectId", async () => { const watsonXInstance = new WatsonxLLM({ - model: "ibm/granite-13b-chat-v2", + model: "ibm/granite-3-8b-instruct", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, projectId: "Test wrong value", @@ -43,7 +43,7 @@ describe("Text generation", () => { test("Invalid credentials", async () => { const watsonXInstance = new WatsonxLLM({ - model: "ibm/granite-13b-chat-v2", + model: "ibm/granite-3-8b-instruct", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, projectId: "Test wrong value", @@ -56,7 +56,7 @@ describe("Text generation", () => { test("Wrong value", async () => { const watsonXInstance = new WatsonxLLM({ - model: "ibm/granite-13b-chat-v2", + model: "ibm/granite-3-8b-instruct", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, projectId: process.env.WATSONX_AI_PROJECT_ID, @@ -67,7 +67,7 @@ describe("Text generation", () => { test("Stop", async () => { const watsonXInstance = new WatsonxLLM({ - model: "ibm/granite-13b-chat-v2", + model: "ibm/granite-3-8b-instruct", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, projectId: process.env.WATSONX_AI_PROJECT_ID, @@ -79,7 +79,7 @@ describe("Text generation", () => { test("Stop with timeout", async () => { const watsonXInstance = new WatsonxLLM({ - model: "ibm/granite-13b-chat-v2", + model: "ibm/granite-3-8b-instruct", version: "2024-05-31", serviceUrl: "sdadasdas" as string, projectId: process.env.WATSONX_AI_PROJECT_ID, @@ -94,7 +94,7 @@ describe("Text generation", () => { test("Signal in call options", async () => { const watsonXInstance = new WatsonxLLM({ - model: "ibm/granite-13b-chat-v2", + model: "ibm/granite-3-8b-instruct", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, projectId: process.env.WATSONX_AI_PROJECT_ID, @@ -119,7 +119,7 @@ describe("Text generation", () => { test("Concurenccy", async () => { const model = new WatsonxLLM({ - model: "ibm/granite-13b-chat-v2", + model: "ibm/granite-3-8b-instruct", maxConcurrency: 1, version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, @@ -139,7 +139,7 @@ describe("Text generation", () => { input_token_count: 0, }; const model = new WatsonxLLM({ - model: "ibm/granite-13b-chat-v2", + model: "ibm/granite-3-8b-instruct", version: "2024-05-31", maxNewTokens: 1, maxConcurrency: 1, @@ -171,7 +171,7 @@ describe("Text generation", () => { let streamedText = ""; let usedTokens = 0; const model = new WatsonxLLM({ - model: "ibm/granite-13b-chat-v2", + model: "ibm/granite-3-8b-instruct", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, projectId: process.env.WATSONX_AI_PROJECT_ID, @@ -198,7 +198,7 @@ describe("Text generation", () => { describe("Test generate methods", () => { test("Basic usage", async () => { const model = new WatsonxLLM({ - model: "ibm/granite-13b-chat-v2", + model: "ibm/granite-3-8b-instruct", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, projectId: process.env.WATSONX_AI_PROJECT_ID, @@ -213,7 +213,7 @@ describe("Text generation", () => { test("Stop", async () => { const model = new WatsonxLLM({ - model: "ibm/granite-13b-chat-v2", + model: "ibm/granite-3-8b-instruct", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, projectId: process.env.WATSONX_AI_PROJECT_ID, @@ -221,12 +221,14 @@ describe("Text generation", () => { }); const res = await model.generate( - ["Print hello world!", "Print hello world hello!"], + [ + "Print hello world in JavaScript!!", + "Print hello world twice in Python!", + ], { - stop: ["Hello"], + stop: ["hello"], } ); - expect( res.generations .map((generation) => generation.map((item) => item.text)) @@ -239,7 +241,7 @@ describe("Text generation", () => { const nrNewTokens = [0, 0, 0]; const completions = ["", "", ""]; const model = new WatsonxLLM({ - model: "ibm/granite-13b-chat-v2", + model: "ibm/granite-3-8b-instruct", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, projectId: process.env.WATSONX_AI_PROJECT_ID, @@ -270,7 +272,7 @@ describe("Text generation", () => { test("Prompt value", async () => { const model = new WatsonxLLM({ - model: "ibm/granite-13b-chat-v2", + model: "ibm/granite-3-8b-instruct", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, projectId: process.env.WATSONX_AI_PROJECT_ID, @@ -290,7 +292,7 @@ describe("Text generation", () => { let countedTokens = 0; let streamedText = ""; const model = new WatsonxLLM({ - model: "ibm/granite-13b-chat-v2", + model: "ibm/granite-3-8b-instruct", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, projectId: process.env.WATSONX_AI_PROJECT_ID, @@ -313,15 +315,15 @@ describe("Text generation", () => { test("Stop", async () => { const model = new WatsonxLLM({ - model: "ibm/granite-13b-chat-v2", + model: "ibm/granite-3-8b-instruct", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, projectId: process.env.WATSONX_AI_PROJECT_ID, maxNewTokens: 100, }); - const stream = await model.stream("Print hello world!", { - stop: ["Hello"], + const stream = await model.stream("Print hello world in JavaScript!", { + stop: ["hello"], }); const chunks = []; for await (const chunk of stream) { @@ -332,7 +334,7 @@ describe("Text generation", () => { test("Timeout", async () => { const model = new WatsonxLLM({ - model: "ibm/granite-13b-chat-v2", + model: "ibm/granite-3-8b-instruct", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, projectId: process.env.WATSONX_AI_PROJECT_ID, @@ -354,7 +356,7 @@ describe("Text generation", () => { test("Signal in call options", async () => { const model = new WatsonxLLM({ - model: "ibm/granite-13b-chat-v2", + model: "ibm/granite-3-8b-instruct", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, projectId: process.env.WATSONX_AI_PROJECT_ID, @@ -384,7 +386,7 @@ describe("Text generation", () => { describe("Test getNumToken method", () => { test("Passing correct value", async () => { const testProps: WatsonxInputLLM = { - model: "ibm/granite-13b-chat-v2", + model: "ibm/granite-3-8b-instruct", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, projectId: process.env.WATSONX_AI_PROJECT_ID, @@ -402,7 +404,7 @@ describe("Text generation", () => { test("Passing wrong value", async () => { const testProps: WatsonxInputLLM = { - model: "ibm/granite-13b-chat-v2", + model: "ibm/granite-3-8b-instruct", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, projectId: process.env.WATSONX_AI_PROJECT_ID, @@ -425,7 +427,7 @@ describe("Text generation", () => { test("Single request callback", async () => { let callbackFlag = false; const service = new WatsonxLLM({ - model: "mistralai/mistral-large", + model: "ibm/granite-3-8b-instruct", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString", projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString", @@ -445,7 +447,7 @@ describe("Text generation", () => { test("Single response callback", async () => { let callbackFlag = false; const service = new WatsonxLLM({ - model: "mistralai/mistral-large", + model: "ibm/granite-3-8b-instruct", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString", projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString", @@ -467,7 +469,7 @@ describe("Text generation", () => { let callbackFlagReq = false; let callbackFlagRes = false; const service = new WatsonxLLM({ - model: "mistralai/mistral-large", + model: "ibm/granite-3-8b-instruct", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString", projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString", @@ -495,7 +497,7 @@ describe("Text generation", () => { let langchainCallback = false; const service = new WatsonxLLM({ - model: "mistralai/mistral-large", + model: "ibm/granite-3-8b-instruct", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString", projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString", From 0dc5cacdfc82f03220078f8c9ba08e788de3eb9a Mon Sep 17 00:00:00 2001 From: FilipZmijewski Date: Tue, 4 Feb 2025 13:58:41 +0100 Subject: [PATCH 3/3] fix: Remove optional checker --- libs/langchain-community/src/chat_models/ibm.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libs/langchain-community/src/chat_models/ibm.ts b/libs/langchain-community/src/chat_models/ibm.ts index 3bb4f4a0adf4..7db2c9cec1c6 100644 --- a/libs/langchain-community/src/chat_models/ibm.ts +++ b/libs/langchain-community/src/chat_models/ibm.ts @@ -405,8 +405,8 @@ export class ChatWatsonx< ls_provider: "watsonx", ls_model_name: this.model, ls_model_type: "chat", - ls_temperature: params?.temperature ?? undefined, - ls_max_tokens: params?.maxTokens ?? undefined, + ls_temperature: params.temperature ?? undefined, + ls_max_tokens: params.maxTokens ?? undefined, }; }