Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[community]: Add chat deployment to IBM chat class #7633

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion libs/langchain-community/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
"@gradientai/nodejs-sdk": "^1.2.0",
"@huggingface/inference": "^2.6.4",
"@huggingface/transformers": "^3.2.3",
"@ibm-cloud/watsonx-ai": "^1.3.0",
"@ibm-cloud/watsonx-ai": "^1.4.0",
"@jest/globals": "^29.5.0",
"@lancedb/lancedb": "^0.13.0",
"@langchain/core": "workspace:*",
Expand Down
190 changes: 126 additions & 64 deletions libs/langchain-community/src/chat_models/ibm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import {
} from "@langchain/core/outputs";
import { AsyncCaller } from "@langchain/core/utils/async_caller";
import {
DeploymentsTextChatParams,
RequestCallbacks,
TextChatMessagesTextChatMessageAssistant,
TextChatParameterTools,
Expand Down Expand Up @@ -65,7 +66,13 @@ import {
import { isZodSchema } from "@langchain/core/utils/types";
import { zodToJsonSchema } from "zod-to-json-schema";
import { NewTokenIndices } from "@langchain/core/callbacks/base";
import { WatsonxAuth, WatsonxParams } from "../types/ibm.js";
import {
Neverify,
WatsonxAuth,
WatsonxChatBasicOptions,
WatsonxDeployedParams,
WatsonxParams,
} from "../types/ibm.js";
import {
_convertToolCallIdToMistralCompatible,
authenticateAndSetInstance,
Expand All @@ -80,27 +87,43 @@ export interface WatsonxDeltaStream {
}

export interface WatsonxCallParams
extends Partial<Omit<TextChatParams, "modelId" | "toolChoice">> {
maxRetries?: number;
watsonxCallbacks?: RequestCallbacks;
}
extends Partial<
Omit<TextChatParams, "modelId" | "toolChoice" | "messages" | "headers">
> {}

/**
 * Parameter shape for calls made against a deployment.
 *
 * Declares no members of its own: the shape is exactly that of the imported
 * `DeploymentsTextChatParams`. It exists as a named interface so other
 * interfaces in this file can extend it.
 */
export interface WatsonxCallDeployedParams extends DeploymentsTextChatParams {}

export interface WatsonxCallOptionsChat
extends Omit<BaseChatModelCallOptions, "stop">,
WatsonxCallParams {
WatsonxCallParams,
WatsonxChatBasicOptions {
promptIndex?: number;
tool_choice?: TextChatParameterTools | string | "auto" | "any";
watsonxCallbacks?: RequestCallbacks;
}

/**
 * Per-call options for the deployed-model variant of the chat class.
 *
 * Combines `WatsonxCallDeployedParams` with `WatsonxChatBasicOptions`
 * (declared in ../types/ibm.js, not visible here) and adds `promptIndex`.
 */
export interface WatsonxCallOptionsDeployedChat
extends WatsonxCallDeployedParams,
WatsonxChatBasicOptions {
// Optional. `invocationParams` destructures this key (together with `signal`)
// out of the options before checking whether any other option was supplied
// to a deployed model, so passing it does not trigger that error.
promptIndex?: number;
}

/**
 * Union of the two tool shapes this module works with: `BindToolsInput`
 * or `TextChatParameterTools`. Module-private (not exported).
 */
type ChatWatsonxToolType = BindToolsInput | TextChatParameterTools;

export interface ChatWatsonxInput
extends BaseChatModelParams,
WatsonxParams,
WatsonxCallParams {
streaming?: boolean;
}
WatsonxCallParams,
Neverify<DeploymentsTextChatParams> {}

/**
 * Constructor input for the deployed-model variant. The `ChatWatsonx`
 * constructor accepts
 * `(ChatWatsonxInput | ChatWatsonxDeployedInput) & WatsonxAuth`.
 *
 * Extends `BaseChatModelParams` and `WatsonxDeployedParams`.
 * NOTE(review): `Neverify` is declared in ../types/ibm.js and is not visible
 * here; presumably it marks every `TextChatParams` key as disallowed so that
 * model-level parameters cannot be mixed into a deployed configuration —
 * confirm against its definition.
 */
export interface ChatWatsonxDeployedInput
extends BaseChatModelParams,
WatsonxDeployedParams,
Neverify<TextChatParams> {}

/**
 * Contract named in the `implements` clause of `ChatWatsonx`.
 *
 * Intersection of `BaseChatModelParams`, `Partial<WatsonxParams>`,
 * `WatsonxDeployedParams` and `WatsonxCallParams`. `WatsonxParams` is wrapped
 * in `Partial`, so none of its members are required by this type.
 */
export type ChatWatsonxConstructor = BaseChatModelParams &
Partial<WatsonxParams> &
WatsonxDeployedParams &
WatsonxCallParams;
function _convertToValidToolId(model: string, tool_call_id: string) {
if (model.startsWith("mistralai"))
return _convertToolCallIdToMistralCompatible(tool_call_id);
Expand All @@ -127,7 +150,7 @@ function _convertToolToWatsonxTool(

function _convertMessagesToWatsonxMessages(
messages: BaseMessage[],
model: string
model?: string
): TextChatResultMessage[] {
const getRole = (role: MessageType) => {
switch (role) {
Expand All @@ -151,7 +174,7 @@ function _convertMessagesToWatsonxMessages(
return message.tool_calls
.map((toolCall) => ({
...toolCall,
id: _convertToValidToolId(model, toolCall.id ?? ""),
id: _convertToValidToolId(model ?? "", toolCall.id ?? ""),
}))
.map(convertLangChainToolCallToOpenAI) as TextChatToolCall[];
}
Expand All @@ -166,7 +189,7 @@ function _convertMessagesToWatsonxMessages(
role: getRole(message._getType()),
content,
name: message.name,
tool_call_id: _convertToValidToolId(model, message.tool_call_id),
tool_call_id: _convertToValidToolId(model ?? "", message.tool_call_id),
};
}

Expand Down Expand Up @@ -229,7 +252,7 @@ function _watsonxResponseToChatMessage(
function _convertDeltaToMessageChunk(
delta: WatsonxDeltaStream,
rawData: TextChatResponse,
model: string,
model?: string,
usage?: TextChatUsage,
defaultRole?: TextChatMessagesTextChatMessageAssistant.Constants.Role
) {
Expand All @@ -245,7 +268,7 @@ function _convertDeltaToMessageChunk(
} => ({
...toolCall,
index,
id: _convertToValidToolId(model, toolCall.id),
id: _convertToValidToolId(model ?? "", toolCall.id),
type: "function",
})
)
Expand Down Expand Up @@ -298,7 +321,7 @@ function _convertDeltaToMessageChunk(
return new ToolMessageChunk({
content,
additional_kwargs,
tool_call_id: _convertToValidToolId(model, rawToolCalls?.[0].id),
tool_call_id: _convertToValidToolId(model ?? "", rawToolCalls?.[0].id),
});
} else if (role === "function") {
return new FunctionMessageChunk({
Expand Down Expand Up @@ -335,10 +358,12 @@ function _convertToolChoiceToWatsonxToolChoice(
}

export class ChatWatsonx<
CallOptions extends WatsonxCallOptionsChat = WatsonxCallOptionsChat
CallOptions extends WatsonxCallOptionsChat =
| WatsonxCallOptionsChat
| WatsonxCallOptionsDeployedChat
>
extends BaseChatModel<CallOptions>
implements ChatWatsonxInput
implements ChatWatsonxConstructor
{
static lc_name() {
return "ChatWatsonx";
Expand Down Expand Up @@ -385,7 +410,7 @@ export class ChatWatsonx<
};
}

model: string;
model?: string;

version = "2024-05-31";

Expand All @@ -399,6 +424,8 @@ export class ChatWatsonx<

projectId?: string;

idOrName?: string;

frequencyPenalty?: number;

logprobs?: boolean;
Expand All @@ -425,37 +452,44 @@ export class ChatWatsonx<

watsonxCallbacks?: RequestCallbacks;

constructor(fields: ChatWatsonxInput & WatsonxAuth) {
constructor(
fields: (ChatWatsonxInput | ChatWatsonxDeployedInput) & WatsonxAuth
) {
super(fields);
if (
(fields.projectId && fields.spaceId) ||
(fields.idOrName && fields.projectId) ||
(fields.spaceId && fields.idOrName)
("projectId" in fields && "spaceId" in fields) ||
("projectId" in fields && "idOrName" in fields) ||
("spaceId" in fields && "idOrName" in fields)
)
throw new Error("Maximum 1 id type can be specified per instance");

if (!fields.projectId && !fields.spaceId && !fields.idOrName)
if (!("projectId" in fields || "spaceId" in fields || "idOrName" in fields))
throw new Error(
"No id specified! At least id of 1 type has to be specified"
);
this.projectId = fields?.projectId;
this.spaceId = fields?.spaceId;
this.temperature = fields?.temperature;
this.maxRetries = fields?.maxRetries || this.maxRetries;
this.maxConcurrency = fields?.maxConcurrency;
this.frequencyPenalty = fields?.frequencyPenalty;
this.topLogprobs = fields?.topLogprobs;
this.maxTokens = fields?.maxTokens ?? this.maxTokens;
this.presencePenalty = fields?.presencePenalty;
this.topP = fields?.topP;
this.timeLimit = fields?.timeLimit;
this.responseFormat = fields?.responseFormat ?? this.responseFormat;

if ("model" in fields) {
this.projectId = fields?.projectId;
this.spaceId = fields?.spaceId;
this.temperature = fields?.temperature;
this.maxRetries = fields?.maxRetries || this.maxRetries;
this.maxConcurrency = fields?.maxConcurrency;
this.frequencyPenalty = fields?.frequencyPenalty;
this.topLogprobs = fields?.topLogprobs;
this.maxTokens = fields?.maxTokens ?? this.maxTokens;
this.presencePenalty = fields?.presencePenalty;
this.topP = fields?.topP;
this.timeLimit = fields?.timeLimit;
this.responseFormat = fields?.responseFormat ?? this.responseFormat;
this.streaming = fields?.streaming ?? this.streaming;
this.n = fields?.n ?? this.n;
this.model = fields?.model ?? this.model;
} else this.idOrName = fields?.idOrName;
FilipZmijewski marked this conversation as resolved.
Show resolved Hide resolved

this.watsonxCallbacks = fields?.watsonxCallbacks ?? this.watsonxCallbacks;
this.serviceUrl = fields?.serviceUrl;
this.streaming = fields?.streaming ?? this.streaming;
this.n = fields?.n ?? this.n;
this.model = fields?.model ?? this.model;
this.version = fields?.version ?? this.version;
this.watsonxCallbacks = fields?.watsonxCallbacks ?? this.watsonxCallbacks;

const {
watsonxAIApikey,
watsonxAIAuthType,
Expand Down Expand Up @@ -486,6 +520,10 @@ export class ChatWatsonx<
}

invocationParams(options: this["ParsedCallOptions"]) {
const { signal, promptIndex, ...rest } = options;
if (this.idOrName && Object.keys(rest).length > 0)
throw new Error("Options cannot be provided to a deployed model");

const params = {
maxTokens: options.maxTokens ?? this.maxTokens,
temperature: options?.temperature ?? this.temperature,
Expand Down Expand Up @@ -521,10 +559,16 @@ export class ChatWatsonx<
} as CallOptions);
}

scopeId() {
if (this.projectId)
scopeId():
| { idOrName: string }
| { projectId: string; modelId: string }
| { spaceId: string; modelId: string } {
if (this.projectId && this.model)
return { projectId: this.projectId, modelId: this.model };
else return { spaceId: this.spaceId, modelId: this.model };
else if (this.spaceId && this.model)
return { spaceId: this.spaceId, modelId: this.model };
else if (this.idOrName) return { idOrName: this.idOrName };
else throw new Error("No scope id provided");
}

async completionWithRetry<T>(
Expand Down Expand Up @@ -595,23 +639,30 @@ export class ChatWatsonx<
.map(([_, value]) => value);
return { generations, llmOutput: { tokenUsage } };
} else {
const params = {
...this.invocationParams(options),
...this.scopeId(),
};
const params = this.invocationParams(options);
const scopeId = this.scopeId();
const watsonxCallbacks = this.invocationCallbacks(options);
const watsonxMessages = _convertMessagesToWatsonxMessages(
messages,
this.model
);
const callback = () =>
this.service.textChat(
{
...params,
messages: watsonxMessages,
},
watsonxCallbacks
);
"idOrName" in scopeId
? this.service.deploymentsTextChat(
{
...scopeId,
messages: watsonxMessages,
},
watsonxCallbacks
)
: this.service.textChat(
{
...params,
...scopeId,
messages: watsonxMessages,
},
watsonxCallbacks
);
const { result } = await this.completionWithRetry(callback, options);
const generations: ChatGeneration[] = [];
for (const part of result.choices) {
Expand Down Expand Up @@ -646,21 +697,33 @@ export class ChatWatsonx<
options: this["ParsedCallOptions"],
_runManager?: CallbackManagerForLLMRun
): AsyncGenerator<ChatGenerationChunk> {
const params = { ...this.invocationParams(options), ...this.scopeId() };
const params = this.invocationParams(options);
const scopeId = this.scopeId();
const watsonxMessages = _convertMessagesToWatsonxMessages(
messages,
this.model
);
const watsonxCallbacks = this.invocationCallbacks(options);
const callback = () =>
this.service.textChatStream(
{
...params,
messages: watsonxMessages,
returnObject: true,
},
watsonxCallbacks
);
"idOrName" in scopeId
? this.service.deploymentsTextChatStream(
{
...scopeId,
messages: watsonxMessages,
returnObject: true,
},
watsonxCallbacks
)
: this.service.textChatStream(
{
...params,
...scopeId,
messages: watsonxMessages,
returnObject: true,
},
watsonxCallbacks
);

const stream = await this.completionWithRetry(callback, options);
let defaultRole;
let usage: TextChatUsage | undefined;
Expand Down Expand Up @@ -707,7 +770,6 @@ export class ChatWatsonx<
if (message === null || (!delta.content && !delta.tool_calls)) {
continue;
}

const generationChunk = new ChatGenerationChunk({
message,
text: delta.content ?? "",
Expand Down
Loading