diff --git a/libs/langchain-mistralai/src/chat_models.ts b/libs/langchain-mistralai/src/chat_models.ts index 1a1fdd36e12b..4c14e304a9dd 100644 --- a/libs/langchain-mistralai/src/chat_models.ts +++ b/libs/langchain-mistralai/src/chat_models.ts @@ -66,6 +66,7 @@ import { } from "@langchain/core/runnables"; import { zodToJsonSchema } from "zod-to-json-schema"; import { ToolCallChunk } from "@langchain/core/messages/tool"; +import { _convertToolCallIdToMistralCompatible } from "./utils.js"; interface TokenUsage { completionTokens?: number; @@ -199,7 +200,10 @@ function convertMessagesToMistralMessages( const getTools = (message: BaseMessage): MistralAIToolCalls[] | undefined => { if (isAIMessage(message) && !!message.tool_calls?.length) { return message.tool_calls - .map((toolCall) => ({ ...toolCall, id: toolCall.id })) + .map((toolCall) => ({ + ...toolCall, + id: _convertToolCallIdToMistralCompatible(toolCall.id ?? ""), + })) .map(convertLangChainToolCallToOpenAI) as MistralAIToolCalls[]; } if (!message.additional_kwargs.tool_calls?.length) { @@ -208,7 +212,7 @@ function convertMessagesToMistralMessages( const toolCalls: Omit[] = message.additional_kwargs.tool_calls; return toolCalls?.map((toolCall) => ({ - id: toolCall.id, + id: _convertToolCallIdToMistralCompatible(toolCall.id), type: "function", function: toolCall.function, })); @@ -217,6 +221,17 @@ function convertMessagesToMistralMessages( return messages.map((message) => { const toolCalls = getTools(message); const content = toolCalls === undefined ? getContent(message.content) : ""; + if ("tool_call_id" in message && typeof message.tool_call_id === "string") { + return { + role: getRole(message._getType()), + content, + name: message.name, + tool_call_id: _convertToolCallIdToMistralCompatible( + message.tool_call_id + ), + }; + } + return { role: getRole(message._getType()), content, diff --git a/libs/langchain-mistralai/src/tests/chat_models.standard.test.ts b/libs/langchain-mistralai/src/tests/chat_models.standard.test.ts index c925f3dc6d47..d77997e8e258 100644 --- a/libs/langchain-mistralai/src/tests/chat_models.standard.test.ts +++ b/libs/langchain-mistralai/src/tests/chat_models.standard.test.ts @@ -24,7 +24,7 @@ class ChatMistralAIStandardUnitTests extends ChatModelUnitTests< expectedLsParams(): Partial { console.warn( - "Overriding testStandardParams. ChatCloudflareWorkersAI does not support stop sequences." + "Overriding testStandardParams. ChatMistralAI does not support stop sequences." ); return { ls_provider: "string", diff --git a/libs/langchain-mistralai/src/tests/chat_models.test.ts b/libs/langchain-mistralai/src/tests/chat_models.test.ts new file mode 100644 index 000000000000..b1dedfb0ba56 --- /dev/null +++ b/libs/langchain-mistralai/src/tests/chat_models.test.ts @@ -0,0 +1,29 @@ +import { + _isValidMistralToolCallId, + _convertToolCallIdToMistralCompatible, +} from "../utils.js"; + +describe("Mistral Tool Call ID Conversion", () => { + test("valid and invalid Mistral tool call IDs", () => { + expect(_isValidMistralToolCallId("ssAbar4Dr")).toBe(true); + expect(_isValidMistralToolCallId("abc123")).toBe(false); + expect(_isValidMistralToolCallId("call_JIIjI55tTipFFzpcP8re3BpM")).toBe( + false + ); + }); + + test("tool call ID conversion", () => { + const resultMap: Record = { + ssAbar4Dr: "ssAbar4Dr", + abc123: "0001yoN1K", + call_JIIjI55tTipFFzpcP8re3BpM: "0001sqrj5", + 12345: "00003akVR", + }; + + for (const [inputId, expectedOutput] of Object.entries(resultMap)) { + const convertedId = _convertToolCallIdToMistralCompatible(inputId); + expect(convertedId).toBe(expectedOutput); + expect(_isValidMistralToolCallId(convertedId)).toBe(true); + } + }); +}); diff --git a/libs/langchain-mistralai/src/utils.ts b/libs/langchain-mistralai/src/utils.ts new file mode 100644 index 000000000000..193efb570555 --- /dev/null +++ b/libs/langchain-mistralai/src/utils.ts @@ -0,0 +1,46 @@ +// Mistral enforces a specific pattern for tool call IDs +const TOOL_CALL_ID_PATTERN = /^[a-zA-Z0-9]{9}$/; + +export function _isValidMistralToolCallId(toolCallId: string): boolean { + return TOOL_CALL_ID_PATTERN.test(toolCallId); +} + +function _base62Encode(num: number): string { + let numCopy = num; + const base62 = + "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"; + if (numCopy === 0) return base62[0]; + const arr: string[] = []; + const base = base62.length; + while (numCopy) { + arr.push(base62[numCopy % base]); + numCopy = Math.floor(numCopy / base); + } + return arr.reverse().join(""); +} + +function _simpleHash(str: string): number { + let hash = 0; + for (let i = 0; i < str.length; i += 1) { + const char = str.charCodeAt(i); + hash = (hash << 5) - hash + char; + hash &= hash; // Convert to 32-bit integer + } + return Math.abs(hash); +} + +export function _convertToolCallIdToMistralCompatible( + toolCallId: string +): string { + if (_isValidMistralToolCallId(toolCallId)) { + return toolCallId; + } else { + const hash = _simpleHash(toolCallId); + const base62Str = _base62Encode(hash); + if (base62Str.length >= 9) { + return base62Str.slice(0, 9); + } else { + return base62Str.padStart(9, "0"); + } + } +}