From ea0c5a8be4f47ce8ad6a02f49fbe45468f243b79 Mon Sep 17 00:00:00 2001 From: Brace Sproul Date: Mon, 18 Nov 2024 15:53:46 -0800 Subject: [PATCH 01/27] fix(openai): Support o1 streaming (#7229) --- libs/langchain-openai/src/chat_models.ts | 14 ---------- .../src/tests/chat_models.int.test.ts | 28 +++++++++++++++++++ 2 files changed, 28 insertions(+), 14 deletions(-) diff --git a/libs/langchain-openai/src/chat_models.ts b/libs/langchain-openai/src/chat_models.ts index 1db33c8728abe..8d2145fc8be74 100644 --- a/libs/langchain-openai/src/chat_models.ts +++ b/libs/langchain-openai/src/chat_models.ts @@ -14,7 +14,6 @@ import { ToolMessageChunk, OpenAIToolCall, isAIMessage, - convertToChunk, UsageMetadata, } from "@langchain/core/messages"; import { @@ -1360,19 +1359,6 @@ export class ChatOpenAI< options: this["ParsedCallOptions"], runManager?: CallbackManagerForLLMRun ): AsyncGenerator { - if (this.model.includes("o1-")) { - console.warn( - "[WARNING]: OpenAI o1 models do not yet support token-level streaming. Streaming will yield single chunk." - ); - const result = await this._generate(messages, options, runManager); - const messageChunk = convertToChunk(result.generations[0].message); - yield new ChatGenerationChunk({ - message: messageChunk, - text: - typeof messageChunk.content === "string" ? 
messageChunk.content : "", - }); - return; - } const messagesMapped: OpenAICompletionParam[] = _convertMessagesToOpenAIParams(messages); const params = { diff --git a/libs/langchain-openai/src/tests/chat_models.int.test.ts b/libs/langchain-openai/src/tests/chat_models.int.test.ts index a5bed0811e617..be5f8c7d7a903 100644 --- a/libs/langchain-openai/src/tests/chat_models.int.test.ts +++ b/libs/langchain-openai/src/tests/chat_models.int.test.ts @@ -1166,3 +1166,31 @@ describe("Audio output", () => { ).toBeGreaterThan(1); }); }); + +test("Can stream o1 requests", async () => { + const model = new ChatOpenAI({ + model: "o1-mini", + }); + const stream = await model.stream( + "Write me a very simple hello world program in Python. Ensure it is wrapped in a function called 'hello_world' and has descriptive comments." + ); + let finalMsg: AIMessageChunk | undefined; + let numChunks = 0; + for await (const chunk of stream) { + finalMsg = finalMsg ? concat(finalMsg, chunk) : chunk; + numChunks += 1; + } + + expect(finalMsg).toBeTruthy(); + if (!finalMsg) { + throw new Error("No final message found"); + } + if (typeof finalMsg.content === "string") { + expect(finalMsg.content.length).toBeGreaterThan(10); + } else { + expect(finalMsg.content.length).toBeGreaterThanOrEqual(1); + } + + // A + expect(numChunks).toBeGreaterThan(3); +}); From 029240abf4f60ff123af76d386c9f2f476ffca85 Mon Sep 17 00:00:00 2001 From: Jacob Lee Date: Mon, 18 Nov 2024 15:56:54 -0800 Subject: [PATCH 02/27] chore(openai): Release 0.3.14 (#7230) --- libs/langchain-openai/package.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/langchain-openai/package.json b/libs/langchain-openai/package.json index aeed4d5af3028..c0b4e1956e091 100644 --- a/libs/langchain-openai/package.json +++ b/libs/langchain-openai/package.json @@ -1,6 +1,6 @@ { "name": "@langchain/openai", - "version": "0.3.13", + "version": "0.3.14", "description": "OpenAI integrations for LangChain.js", "type": "module", 
"engines": { From b9b414a668d1cc816fa83d2bcdd996880d39e197 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Wed, 20 Nov 2024 14:34:19 -0500 Subject: [PATCH 03/27] docs: Add missing streaming concept (#7228) --- docs/core_docs/docs/concepts/index.mdx | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/core_docs/docs/concepts/index.mdx b/docs/core_docs/docs/concepts/index.mdx index 709e31ef9af16..4bda066c00bd0 100644 --- a/docs/core_docs/docs/concepts/index.mdx +++ b/docs/core_docs/docs/concepts/index.mdx @@ -22,6 +22,7 @@ The conceptual guide does not cover step-by-step instructions or specific implem - **[Memory](https://langchain-ai.github.io/langgraphjs/concepts/memory/)**: Information about a conversation that is persisted so that it can be used in future conversations. - **[Multimodality](/docs/concepts/multimodality)**: The ability to work with data that comes in different forms, such as text, audio, images, and video. - **[Runnable interface](/docs/concepts/runnables)**: The base abstraction that many LangChain components and the LangChain Expression Language are built on. +- **[Streaming](/docs/concepts/streaming)**: LangChain streaming APIs for surfacing results as they are generated. - **[LangChain Expression Language (LCEL)](/docs/concepts/lcel)**: A syntax for orchestrating LangChain components. Most useful for simpler applications. - **[Document loaders](/docs/concepts/document_loaders)**: Load a source as a list of documents. - **[Retrieval](/docs/concepts/retrieval)**: Information retrieval systems can retrieve structured or unstructured data from a datasource in response to a query. 
From d2e1f4f2763e7cfae822fe52e4e0e03acb466e4f Mon Sep 17 00:00:00 2001 From: Brace Sproul Date: Thu, 21 Nov 2024 18:01:35 -0800 Subject: [PATCH 04/27] fix(ci): Update release script to be more lenient (#7234) --- release_workspace.js | 309 +++++++++++++++++++++++++++++++------------ 1 file changed, 225 insertions(+), 84 deletions(-) diff --git a/release_workspace.js b/release_workspace.js index 3fe8d3b994a83..9e127961a1e74 100644 --- a/release_workspace.js +++ b/release_workspace.js @@ -4,20 +4,21 @@ const fs = require("fs"); const path = require("path"); const { spawn } = require("child_process"); const readline = require("readline"); -const semver = require('semver') +const semver = require("semver"); -const PRIMARY_PROJECTS = ["langchain", "@langchain/core", "@langchain/community"]; const RELEASE_BRANCH = "release"; const MAIN_BRANCH = "main"; /** * Get the version of a workspace inside a directory. - * - * @param {string} workspaceDirectory + * + * @param {string} workspaceDirectory * @returns {string} The version of the workspace in the input directory. */ function getWorkspaceVersion(workspaceDirectory) { - const pkgJsonFile = fs.readFileSync(path.join(process.cwd(), workspaceDirectory, "package.json")); + const pkgJsonFile = fs.readFileSync( + path.join(process.cwd(), workspaceDirectory, "package.json") + ); const parsedJSONFile = JSON.parse(pkgJsonFile); return parsedJSONFile.version; } @@ -26,29 +27,45 @@ function getWorkspaceVersion(workspaceDirectory) { * Finds all workspaces in the monorepo and returns an array of objects. * Each object in the return value contains the relative path to the workspace * directory, along with the full package.json file contents. 
- * + * * @returns {Array<{ dir: string, packageJSON: Record}>} */ function getAllWorkspaces() { - const possibleWorkspaceDirectories = ["./libs/*", "./langchain", "./langchain-core"]; - const allWorkspaces = possibleWorkspaceDirectories.flatMap((workspaceDirectory) => { - if (workspaceDirectory.endsWith("*")) { - // List all folders inside directory, require, and return the package.json. - const allDirs = fs.readdirSync(path.join(process.cwd(), workspaceDirectory.replace("*", ""))); - const subDirs = allDirs.map((dir) => { - return { - dir: `${workspaceDirectory.replace("*", "")}${dir}`, - packageJSON: require(path.join(process.cwd(), `${workspaceDirectory.replace("*", "")}${dir}`, "package.json")) - } - }); - return subDirs; + const possibleWorkspaceDirectories = [ + "./libs/*", + "./langchain", + "./langchain-core", + ]; + const allWorkspaces = possibleWorkspaceDirectories.flatMap( + (workspaceDirectory) => { + if (workspaceDirectory.endsWith("*")) { + // List all folders inside directory, require, and return the package.json. + const allDirs = fs.readdirSync( + path.join(process.cwd(), workspaceDirectory.replace("*", "")) + ); + const subDirs = allDirs.map((dir) => { + return { + dir: `${workspaceDirectory.replace("*", "")}${dir}`, + packageJSON: require(path.join( + process.cwd(), + `${workspaceDirectory.replace("*", "")}${dir}`, + "package.json" + )), + }; + }); + return subDirs; + } + const packageJSON = require(path.join( + process.cwd(), + workspaceDirectory, + "package.json" + )); + return { + dir: workspaceDirectory, + packageJSON, + }; } - const packageJSON = require(path.join(process.cwd(), workspaceDirectory, "package.json")); - return { - dir: workspaceDirectory, - packageJSON, - }; - }); + ); return allWorkspaces; } @@ -56,18 +73,24 @@ function getAllWorkspaces() { * Writes the JSON file with the updated dependency version. Accounts * for version prefixes, eg ~, ^, >, <, >=, <=, ||, *. Also skips * versions which are "latest" or "workspace:*". 
- * - * @param {Array} workspaces - * @param {"dependencies" | "devDependencies" | "peerDependencies"} dependencyType - * @param {string} workspaceName - * @param {string} newVersion + * + * @param {Array} workspaces + * @param {"dependencies" | "devDependencies" | "peerDependencies"} dependencyType + * @param {string} workspaceName + * @param {string} newVersion */ -function updateDependencies(workspaces, dependencyType, workspaceName, newVersion) { +function updateDependencies( + workspaces, + dependencyType, + workspaceName, + newVersion +) { const versionPrefixes = ["~", "^", ">", "<", ">=", "<=", "||", "*"]; const skipVersions = ["latest", "workspace:*"]; workspaces.forEach((workspace) => { - const currentVersion = workspace.packageJSON[dependencyType]?.[workspaceName]; + const currentVersion = + workspace.packageJSON[dependencyType]?.[workspaceName]; if (currentVersion) { const prefix = versionPrefixes.find((p) => currentVersion.startsWith(p)); const shouldSkip = skipVersions.some((v) => currentVersion === v); @@ -75,7 +98,10 @@ function updateDependencies(workspaces, dependencyType, workspaceName, newVersio if (!shouldSkip) { const versionToUpdate = prefix ? `${prefix}${newVersion}` : newVersion; workspace.packageJSON[dependencyType][workspaceName] = versionToUpdate; - fs.writeFileSync(path.join(workspace.dir, "package.json"), JSON.stringify(workspace.packageJSON, null, 2) + "\n"); + fs.writeFileSync( + path.join(workspace.dir, "package.json"), + JSON.stringify(workspace.packageJSON, null, 2) + "\n" + ); } } }); @@ -85,7 +111,7 @@ function updateDependencies(workspaces, dependencyType, workspaceName, newVersio * Runs `release-it` with args in the input package directory, * passing the new version as an argument, along with other * release-it args. - * + * * @param {string} packageDirectory The directory to run yarn release in. * @param {string} npm2FACode The 2FA code for NPM. * @param {string | undefined} tag An optional tag to publish to. 
@@ -95,22 +121,43 @@ async function runYarnRelease(packageDirectory, npm2FACode, tag) { return new Promise((resolve, reject) => { const workingDirectory = path.join(process.cwd(), packageDirectory); const tagArg = tag ? `--npm.tag=${tag}` : ""; - const args = ["release-it", `--npm.otp=${npm2FACode}`, tagArg, "--config", ".release-it.json"]; - + const args = [ + "release-it", + `--npm.otp=${npm2FACode}`, + tagArg, + "--config", + ".release-it.json", + ]; + console.log(`Running command: "yarn ${args.join(" ")}"`); - const yarnReleaseProcess = spawn("yarn", args, { stdio: "inherit", cwd: workingDirectory }); + const yarnReleaseProcess = spawn("yarn", args, { cwd: workingDirectory }); + + let stdout = ""; + let stderr = ""; + + yarnReleaseProcess.stdout.on("data", (data) => { + stdout += data; + // Still show output in real-time + process.stdout.write(data); + }); + + yarnReleaseProcess.stderr.on("data", (data) => { + stderr += data; + // Still show errors in real-time + process.stderr.write(data); + }); yarnReleaseProcess.on("close", (code) => { if (code === 0) { resolve(); } else { - reject(`Process exited with code ${code}`); + reject(`Process exited with code ${code}.\nError: ${stderr}`); } }); yarnReleaseProcess.on("error", (err) => { - reject(err); + reject(`Failed to start process: ${err.message}\nError: ${stderr}`); }); }); } @@ -119,7 +166,7 @@ async function runYarnRelease(packageDirectory, npm2FACode, tag) { * Finds all `package.json`'s which contain the input workspace as a dependency. * Then, updates the dependency to the new version, runs yarn install and * commits the changes. - * + * * @param {string} workspaceName The name of the workspace to bump dependencies for. * @param {string} workspaceDirectory The path to the workspace directory. 
* @param {Array<{ dir: string, packageJSON: Record}>} allWorkspaces @@ -127,7 +174,13 @@ async function runYarnRelease(packageDirectory, npm2FACode, tag) { * @param {string} preReleaseVersion The version of the workspace before it was released. * @returns {void} */ -function bumpDeps(workspaceName, workspaceDirectory, allWorkspaces, tag, preReleaseVersion) { +function bumpDeps( + workspaceName, + workspaceDirectory, + allWorkspaces, + tag, + preReleaseVersion +) { // Read workspace file, get version (edited by release-it), and bump pkgs to that version. let updatedWorkspaceVersion = getWorkspaceVersion(workspaceDirectory); if (!semver.valid(updatedWorkspaceVersion)) { @@ -138,11 +191,15 @@ function bumpDeps(workspaceName, workspaceDirectory, allWorkspaces, tag, preRele // If the updated version is not greater than the pre-release version, // the branch is out of sync. Pull from github and check again. if (!semver.gt(updatedWorkspaceVersion, preReleaseVersion)) { - console.log("Updated version is not greater than the pre-release version. Pulling from github and checking again."); + console.log( + "Updated version is not greater than the pre-release version. Pulling from github and checking again." + ); execSync(`git pull origin ${RELEASE_BRANCH}`); updatedWorkspaceVersion = getWorkspaceVersion(workspaceDirectory); if (!semver.gt(updatedWorkspaceVersion, preReleaseVersion)) { - console.warn(`Workspace version has not changed in repo. Version in repo: ${updatedWorkspaceVersion}. Exiting.`); + console.warn( + `Workspace version has not changed in repo. Version in repo: ${updatedWorkspaceVersion}. 
Exiting.` + ); process.exit(0); } } @@ -161,81 +218,138 @@ function bumpDeps(workspaceName, workspaceDirectory, allWorkspaces, tag, preRele console.log(`Checking out new branch: ${newBranchName}`); execSync(`git checkout -b ${newBranchName}`); - const allWorkspacesWhichDependOn = allWorkspaces.filter(({ packageJSON }) => + const allWorkspacesWhichDependOn = allWorkspaces.filter(({ packageJSON }) => Object.keys(packageJSON.dependencies ?? {}).includes(workspaceName) ); - const allWorkspacesWhichDevDependOn = allWorkspaces.filter(({ packageJSON }) => - Object.keys(packageJSON.devDependencies ?? {}).includes(workspaceName) + const allWorkspacesWhichDevDependOn = allWorkspaces.filter( + ({ packageJSON }) => + Object.keys(packageJSON.devDependencies ?? {}).includes(workspaceName) ); - const allWorkspacesWhichPeerDependOn = allWorkspaces.filter(({ packageJSON }) => - Object.keys(packageJSON.peerDependencies ?? {}).includes(workspaceName) + const allWorkspacesWhichPeerDependOn = allWorkspaces.filter( + ({ packageJSON }) => + Object.keys(packageJSON.peerDependencies ?? {}).includes(workspaceName) ); // For console log, get all workspaces which depend and filter out duplicates. - const allWhichDependOn = new Set([ - ...allWorkspacesWhichDependOn, - ...allWorkspacesWhichDevDependOn, - ...allWorkspacesWhichPeerDependOn, - ].map(({ packageJSON }) => packageJSON.name)); + const allWhichDependOn = new Set( + [ + ...allWorkspacesWhichDependOn, + ...allWorkspacesWhichDevDependOn, + ...allWorkspacesWhichPeerDependOn, + ].map(({ packageJSON }) => packageJSON.name) + ); if (allWhichDependOn.size !== 0) { - console.log(`Found ${[...allWhichDependOn].length} workspaces which depend on ${workspaceName}. + console.log(`Found ${ + [...allWhichDependOn].length + } workspaces which depend on ${workspaceName}. Workspaces: - ${[...allWhichDependOn].map((name) => name).join("\n- ")} `); // Update packages which depend on the input workspace. 
- updateDependencies(allWorkspacesWhichDependOn, "dependencies", workspaceName, updatedWorkspaceVersion); - updateDependencies(allWorkspacesWhichDevDependOn, "devDependencies", workspaceName, updatedWorkspaceVersion); - updateDependencies(allWorkspacesWhichPeerDependOn, "peerDependencies", workspaceName, updatedWorkspaceVersion); + updateDependencies( + allWorkspacesWhichDependOn, + "dependencies", + workspaceName, + updatedWorkspaceVersion + ); + updateDependencies( + allWorkspacesWhichDevDependOn, + "devDependencies", + workspaceName, + updatedWorkspaceVersion + ); + updateDependencies( + allWorkspacesWhichPeerDependOn, + "peerDependencies", + workspaceName, + updatedWorkspaceVersion + ); console.log("Updated package.json's! Running yarn install."); try { execSync(`yarn install`); } catch (_) { - console.log("Yarn install failed. Likely because NPM has not finished publishing the new version. Continuing.") + console.log( + "Yarn install failed. Likely because NPM has not finished publishing the new version. Continuing." + ); } // Add all current changes, commit, push and log branch URL. console.log("Adding and committing all changes."); execSync(`git add -A`); - execSync(`git commit -m "all[minor]: bump deps on ${workspaceName} to ${versionString}"`); + execSync( + `git commit -m "all[minor]: bump deps on ${workspaceName} to ${versionString}"` + ); console.log("Pushing changes."); execSync(`git push -u origin ${newBranchName}`); - console.log("🔗 Open %s and merge the bump-deps PR.", `\x1b[34mhttps://github.com/langchain-ai/langchainjs/compare/${newBranchName}?expand=1\x1b[0m`); + console.log( + "🔗 Open %s and merge the bump-deps PR.", + `\x1b[34mhttps://github.com/langchain-ai/langchainjs/compare/${newBranchName}?expand=1\x1b[0m` + ); } else { console.log(`No workspaces depend on ${workspaceName}.`); } } +/** + * Create a commit message for the input workspace and version. 
+ * + * @param {string} workspaceName + * @param {string} version + */ +function createCommitMessage(workspaceName, version) { + return `release(${workspaceName}): ${version}`; +} + +/** + * Commits all changes and pushes to the current branch. + * + * @param {string} workspaceName The name of the workspace being released + * @param {string} version The new version being released + * @returns {void} + */ +function commitAndPushChanges(workspaceName, version) { + console.log("Committing and pushing changes..."); + const commitMsg = createCommitMessage(workspaceName, version); + execSync("git add -A"); + execSync(`git commit -m "${commitMsg}"`); + // Pushes to the current branch + execSync("git push -u origin $(git rev-parse --abbrev-ref HEAD)"); + console.log("Successfully committed and pushed changes."); +} + /** * Verifies the current branch is main, then checks out a new release branch * and pushes an empty commit. - * + * * @returns {void} * @throws {Error} If the current branch is not main. */ function checkoutReleaseBranch() { const currentBranch = execSync("git branch --show-current").toString().trim(); - if (currentBranch === MAIN_BRANCH) { + if (currentBranch === MAIN_BRANCH || currentBranch === RELEASE_BRANCH) { console.log(`Checking out '${RELEASE_BRANCH}' branch.`); execSync(`git checkout -B ${RELEASE_BRANCH}`); execSync(`git push -u origin ${RELEASE_BRANCH}`); } else { - throw new Error(`Current branch is not ${MAIN_BRANCH}. Current branch: ${currentBranch}`); + throw new Error( + `Current branch is not ${MAIN_BRANCH} or ${RELEASE_BRANCH}. Current branch: ${currentBranch}` + ); } } /** * Prompts the user for input and returns the input. This is used * for requesting an OTP from the user for NPM 2FA. - * + * * @param {string} question The question to log to the users terminal. * @returns {Promise} The user input. 
*/ async function getUserInput(question) { const rl = readline.createInterface({ input: process.stdin, - output: process.stdout + output: process.stdout, }); return new Promise((resolve) => { @@ -246,13 +360,32 @@ async function getUserInput(question) { }); } +/** + * Checks if there are any uncommitted changes in the git repository + * + * @returns {boolean} True if there are uncommitted changes, false otherwise + */ +function hasUncommittedChanges() { + try { + // This command returns empty string if no changes, or a string with changes if there are any + const output = execSync("git status --porcelain").toString(); + return output.length > 0; + } catch (error) { + console.error("Error checking git status:", error); + // If we can't check, better to assume there are changes + return true; + } +} async function main() { const program = new Command(); program .description("Release a new workspace version to NPM.") .option("--workspace ", "Workspace name, eg @langchain/core") - .option("--bump-deps", "Whether or not to bump other workspaces that depend on this one.") + .option( + "--bump-deps", + "Whether or not to bump other workspaces that depend on this one." + ) .option("--tag ", "Optionally specify a tag to publish to."); program.parse(); @@ -265,10 +398,18 @@ async function main() { throw new Error("--workspace is a required flag."); } + if (hasUncommittedChanges()) { + console.warn( + "[WARNING]: You have uncommitted changes. These will be included in the release commit." + ); + } + // Find the workspace package.json's. 
const allWorkspaces = getAllWorkspaces(); - const matchingWorkspace = allWorkspaces.find(({ packageJSON }) => packageJSON.name === options.workspace); - + const matchingWorkspace = allWorkspaces.find( + ({ packageJSON }) => packageJSON.name === options.workspace + ); + if (!matchingWorkspace) { throw new Error(`Could not find workspace ${options.workspace}`); } @@ -278,30 +419,30 @@ async function main() { // Run build, lint, tests console.log("Running build, lint, and tests."); - execSync(`yarn turbo:command run --filter ${options.workspace} build lint test --concurrency 1`); + execSync( + `yarn turbo:command run --filter ${options.workspace} build lint test --concurrency 1` + ); console.log("Successfully ran build, lint, and tests."); - // Only run export tests for primary projects. - if (PRIMARY_PROJECTS.includes(options.workspace.trim())) { - // Run export tests. - // LangChain must be built before running export tests. - console.log("Building 'langchain' and running export tests."); - execSync(`yarn run turbo:command build --filter=langchain`); - execSync(`yarn run test:exports:docker`); - console.log("Successfully built langchain, and tested exports."); - } else { - console.log("Skipping export tests for non primary project."); - } - - const npm2FACode = await getUserInput("Please enter your NPM 2FA authentication code:"); + const npm2FACode = await getUserInput( + "Please enter your NPM 2FA authentication code:" + ); const preReleaseVersion = getWorkspaceVersion(matchingWorkspace.dir); // Run `release-it` on workspace await runYarnRelease(matchingWorkspace.dir, npm2FACode, options.tag); - + + if (hasUncommittedChanges()) { + const updatedVersion = getWorkspaceVersion(matchingWorkspace.dir); + commitAndPushChanges(options.workspace, updatedVersion); + } + // Log release branch URL - console.log("🔗 Open %s and merge the release PR.", `\x1b[34mhttps://github.com/langchain-ai/langchainjs/compare/release?expand=1\x1b[0m`); + console.log( + "🔗 Open %s and merge 
the release PR.", + `\x1b[34mhttps://github.com/langchain-ai/langchainjs/compare/release?expand=1\x1b[0m` + ); // If `bump-deps` flag is set, find all workspaces which depend on the input workspace. // Then, update their package.json to use the new version of the input workspace. @@ -315,6 +456,6 @@ async function main() { preReleaseVersion ); } -}; +} main(); From 78a69951d4626892fed2c035693f10e069c4f989 Mon Sep 17 00:00:00 2001 From: CarterMorris <114012427+CarterMorris@users.noreply.github.com> Date: Thu, 21 Nov 2024 23:39:41 -0500 Subject: [PATCH 05/27] feat(mistral): Mistral 1.3.1 migration (#7218) Co-authored-by: Ashtian Co-authored-by: BaNg-W Co-authored-by: CarterMorris Co-authored-by: BaNg-W Co-authored-by: BaNg-W <114012080+BaNg-W@users.noreply.github.com> Co-authored-by: jacoblee93 --- .../docs/integrations/chat/mistral.ipynb | 117 +++- .../docs/integrations/llms/mistral.ipynb | 105 +++- .../text_embedding/mistralai.ipynb | 105 +++- libs/langchain-mistralai/package.json | 6 +- libs/langchain-mistralai/src/chat_models.ts | 503 ++++++++++++------ libs/langchain-mistralai/src/embeddings.ts | 168 +++++- libs/langchain-mistralai/src/llms.ts | 294 ++++++++-- .../src/tests/chat_models.int.test.ts | 437 ++++++++++++--- .../tests/chat_models.standard.int.test.ts | 6 +- .../src/tests/chat_models.test.ts | 1 + .../src/tests/embeddings.int.test.ts | 194 +++++++ .../src/tests/llms.int.test.ts | 196 ++++++- libs/langchain-mistralai/src/utils.ts | 40 ++ yarn.lock | 18 +- 14 files changed, 1866 insertions(+), 324 deletions(-) diff --git a/docs/core_docs/docs/integrations/chat/mistral.ipynb b/docs/core_docs/docs/integrations/chat/mistral.ipynb index ccb2c6590f908..679bde45a9257 100644 --- a/docs/core_docs/docs/integrations/chat/mistral.ipynb +++ b/docs/core_docs/docs/integrations/chat/mistral.ipynb @@ -38,7 +38,7 @@ "\n", "| [Tool calling](/docs/how_to/tool_calling) | [Structured output](/docs/how_to/structured_output/) | JSON mode | [Image 
input](/docs/how_to/multimodal_inputs/) | Audio input | Video input | [Token-level streaming](/docs/how_to/chat_streaming/) | [Token usage](/docs/how_to/chat_token_usage_tracking/) | [Logprobs](/docs/how_to/logprobs/) |\n", "| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |\n", - "| ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ✅ | ❌ | \n", + "| ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | \n", "\n", "## Setup\n", "\n", @@ -88,7 +88,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "id": "cb09c344-1836-4e0c-acf8-11d13ac1dbae", "metadata": {}, "outputs": [], @@ -323,6 +323,117 @@ "console.log(calcToolRes.tool_calls);" ] }, + { + "cell_type": "markdown", + "id": "85dcbecc", + "metadata": {}, + "source": [ + "## Hooks\n", + "\n", + "Mistral AI supports custom hooks for three events: beforeRequest, requestError, and reponse. Examples of the function signature for each hook type can be seen below:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "74b8b855", + "metadata": {}, + "outputs": [], + "source": [ + "const beforeRequestHook = (req: Request): Request | void | Promise => {\n", + " // Code to run before a request is processed by Mistral\n", + "};\n", + "\n", + "const requestErrorHook = (err: unknown, req: Request): void | Promise => {\n", + " // Code to run when an error occurs as Mistral is processing a request\n", + "};\n", + "\n", + "const responseHook = (res: Response, req: Request): void | Promise => {\n", + " // Code to run before Mistral sends a successful response\n", + "};" + ] + }, + { + "cell_type": "markdown", + "id": "930df6c4", + "metadata": {}, + "source": [ + "To add these hooks to the chat model, either pass them as arguments and they are automatically added:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8b8084f6", + "metadata": {}, + "outputs": [], + "source": [ + "import { ChatMistralAI } from \"@langchain/mistralai\" \n", + "\n", + "const modelWithHooks = new 
ChatMistralAI({\n", + " model: \"mistral-large-latest\",\n", + " temperature: 0,\n", + " maxRetries: 2,\n", + " beforeRequestHooks: [ beforeRequestHook ],\n", + " requestErrorHooks: [ requestErrorHook ],\n", + " responseHooks: [ responseHook ],\n", + " // other params...\n", + "});" + ] + }, + { + "cell_type": "markdown", + "id": "cc9478f3", + "metadata": {}, + "source": [ + "Or assign and add them manually after instantiation:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "daa70dc3", + "metadata": {}, + "outputs": [], + "source": [ + "import { ChatMistralAI } from \"@langchain/mistralai\" \n", + "\n", + "const model = new ChatMistralAI({\n", + " model: \"mistral-large-latest\",\n", + " temperature: 0,\n", + " maxRetries: 2,\n", + " // other params...\n", + "});\n", + "\n", + "model.beforeRequestHooks = [ ...model.beforeRequestHooks, beforeRequestHook ];\n", + "model.requestErrorHooks = [ ...model.requestErrorHooks, requestErrorHook ];\n", + "model.responseHooks = [ ...model.responseHooks, responseHook ];\n", + "\n", + "model.addAllHooksToHttpClient();" + ] + }, + { + "cell_type": "markdown", + "id": "389f5159", + "metadata": {}, + "source": [ + "The method addAllHooksToHttpClient clears all currently added hooks before assigning the entire updated hook lists to avoid hook duplication.\n", + "\n", + "Hooks can be removed one at a time, or all hooks can be cleared from the model at once." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a56b64bb", + "metadata": {}, + "outputs": [], + "source": [ + "model.removeHookFromHttpClient(beforeRequestHook);\n", + "\n", + "model.removeAllHooksFromHttpClient();" + ] + }, { "cell_type": "markdown", "id": "3a5bb5ca-c3ae-4a58-be67-2cd18574b9a3", @@ -354,4 +465,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/docs/core_docs/docs/integrations/llms/mistral.ipynb b/docs/core_docs/docs/integrations/llms/mistral.ipynb index cb7cf11f6d015..0e59324552914 100644 --- a/docs/core_docs/docs/integrations/llms/mistral.ipynb +++ b/docs/core_docs/docs/integrations/llms/mistral.ipynb @@ -271,6 +271,109 @@ "console.log(customOutputParser(resWithParser));" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Hooks\n", + "\n", + "Mistral AI supports custom hooks for three events: beforeRequest, requestError, and reponse. Examples of the function signature for each hook type can be seen below:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "const beforeRequestHook = (req: Request): Request | void | Promise => {\n", + " // Code to run before a request is processed by Mistral\n", + "};\n", + "\n", + "const requestErrorHook = (err: unknown, req: Request): void | Promise => {\n", + " // Code to run when an error occurs as Mistral is processing a request\n", + "};\n", + "\n", + "const responseHook = (res: Response, req: Request): void | Promise => {\n", + " // Code to run before Mistral sends a successful response\n", + "};" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To add these hooks to the chat model, either pass them as arguments and they are automatically added:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import { ChatMistralAI } from \"@langchain/mistralai\" \n", + "\n", + "const 
modelWithHooks = new ChatMistralAI({\n", + " model: \"mistral-large-latest\",\n", + " temperature: 0,\n", + " maxRetries: 2,\n", + " beforeRequestHooks: [ beforeRequestHook ],\n", + " requestErrorHooks: [ requestErrorHook ],\n", + " responseHooks: [ responseHook ],\n", + " // other params...\n", + "});" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Or assign and add them manually after instantiation:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import { ChatMistralAI } from \"@langchain/mistralai\" \n", + "\n", + "const model = new ChatMistralAI({\n", + " model: \"mistral-large-latest\",\n", + " temperature: 0,\n", + " maxRetries: 2,\n", + " // other params...\n", + "});\n", + "\n", + "model.beforeRequestHooks = [ ...model.beforeRequestHooks, beforeRequestHook ];\n", + "model.requestErrorHooks = [ ...model.requestErrorHooks, requestErrorHook ];\n", + "model.responseHooks = [ ...model.responseHooks, responseHook ];\n", + "\n", + "model.addAllHooksToHttpClient();" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The method addAllHooksToHttpClient clears all currently added hooks before assigning the entire updated hook lists to avoid hook duplication.\n", + "\n", + "Hooks can be removed one at a time, or all hooks can be cleared from the model at once." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model.removeHookFromHttpClient(beforeRequestHook);\n", + "\n", + "model.removeAllHooksFromHttpClient();" + ] + }, { "cell_type": "markdown", "id": "e9bdfcef", @@ -307,4 +410,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/docs/core_docs/docs/integrations/text_embedding/mistralai.ipynb b/docs/core_docs/docs/integrations/text_embedding/mistralai.ipynb index 97660a2b1d908..b272f32088438 100644 --- a/docs/core_docs/docs/integrations/text_embedding/mistralai.ipynb +++ b/docs/core_docs/docs/integrations/text_embedding/mistralai.ipynb @@ -310,6 +310,109 @@ "console.log(vectors[1].slice(0, 100));" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Hooks\n", + "\n", + "Mistral AI supports custom hooks for three events: beforeRequest, requestError, and response. Examples of the function signature for each hook type can be seen below:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "const beforeRequestHook = (req: Request): Request | void | Promise => {\n", + " // Code to run before a request is processed by Mistral\n", + "};\n", + "\n", + "const requestErrorHook = (err: unknown, req: Request): void | Promise => {\n", + " // Code to run when an error occurs as Mistral is processing a request\n", + "};\n", + "\n", + "const responseHook = (res: Response, req: Request): void | Promise => {\n", + " // Code to run before Mistral sends a successful response\n", + "};" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To add these hooks to the chat model, either pass them as arguments and they are automatically added:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import { ChatMistralAI } from \"@langchain/mistralai\" \n", + "\n", + "const 
modelWithHooks = new ChatMistralAI({\n", + " model: \"mistral-large-latest\",\n", + " temperature: 0,\n", + " maxRetries: 2,\n", + " beforeRequestHooks: [ beforeRequestHook ],\n", + " requestErrorHooks: [ requestErrorHook ],\n", + " responseHooks: [ responseHook ],\n", + " // other params...\n", + "});" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Or assign and add them manually after instantiation:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import { ChatMistralAI } from \"@langchain/mistralai\" \n", + "\n", + "const model = new ChatMistralAI({\n", + " model: \"mistral-large-latest\",\n", + " temperature: 0,\n", + " maxRetries: 2,\n", + " // other params...\n", + "});\n", + "\n", + "model.beforeRequestHooks = [ ...model.beforeRequestHooks, beforeRequestHook ];\n", + "model.requestErrorHooks = [ ...model.requestErrorHooks, requestErrorHook ];\n", + "model.responseHooks = [ ...model.responseHooks, responseHook ];\n", + "\n", + "model.addAllHooksToHttpClient();" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The method addAllHooksToHttpClient clears all currently added hooks before assigning the entire updated hook lists to avoid hook duplication.\n", + "\n", + "Hooks can be removed one at a time, or all hooks can be cleared from the model at once." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model.removeHookFromHttpClient(beforeRequestHook);\n", + "\n", + "model.removeAllHooksFromHttpClient();" + ] + }, { "cell_type": "markdown", "id": "8938e581", @@ -341,4 +444,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/libs/langchain-mistralai/package.json b/libs/langchain-mistralai/package.json index 35d65923078fd..7876ea753710f 100644 --- a/libs/langchain-mistralai/package.json +++ b/libs/langchain-mistralai/package.json @@ -35,13 +35,13 @@ "author": "LangChain", "license": "MIT", "dependencies": { - "@mistralai/mistralai": "^0.4.0", + "@mistralai/mistralai": "^1.3.1", "uuid": "^10.0.0", - "zod": "^3.22.4", + "zod": "^3.23.8", "zod-to-json-schema": "^3.22.4" }, "peerDependencies": { - "@langchain/core": ">=0.2.21 <0.4.0" + "@langchain/core": ">=0.3.7 <0.4.0" }, "devDependencies": { "@jest/globals": "^29.5.0", diff --git a/libs/langchain-mistralai/src/chat_models.ts b/libs/langchain-mistralai/src/chat_models.ts index 4d9e1a447fd9a..97a195da1ed7e 100644 --- a/libs/langchain-mistralai/src/chat_models.ts +++ b/libs/langchain-mistralai/src/chat_models.ts @@ -1,19 +1,28 @@ import { v4 as uuidv4 } from "uuid"; +import { Mistral as MistralClient } from "@mistralai/mistralai"; import { - ChatCompletionResponse, - Function as MistralAIFunction, - ToolCalls as MistralAIToolCalls, - ResponseFormat, - ChatCompletionResponseChunk, - ChatRequest, - Tool as MistralAITool, - Message as MistralAIMessage, - TokenUsage as MistralAITokenUsage, -} from "@mistralai/mistralai"; + ChatCompletionRequest as MistralAIChatCompletionRequest, + ChatCompletionRequestToolChoice as MistralAIToolChoice, + Messages as MistralAIMessage, +} from "@mistralai/mistralai/models/components/chatcompletionrequest.js"; +import { ContentChunk as MistralAIContentChunk } from "@mistralai/mistralai/models/components/contentchunk.js"; +import { Tool as 
MistralAITool } from "@mistralai/mistralai/models/components/tool.js"; +import { ToolCall as MistralAIToolCall } from "@mistralai/mistralai/models/components/toolcall.js"; +import { ChatCompletionStreamRequest as MistralAIChatCompletionStreamRequest } from "@mistralai/mistralai/models/components/chatcompletionstreamrequest.js"; +import { UsageInfo as MistralAITokenUsage } from "@mistralai/mistralai/models/components/usageinfo.js"; +import { CompletionEvent as MistralAIChatCompletionEvent } from "@mistralai/mistralai/models/components/completionevent.js"; +import { ChatCompletionResponse as MistralAIChatCompletionResponse } from "@mistralai/mistralai/models/components/chatcompletionresponse.js"; import { + type BeforeRequestHook, + type RequestErrorHook, + type ResponseHook, + HTTPClient as MistralAIHTTPClient, +} from "@mistralai/mistralai/lib/http.js"; +import { + BaseMessage, MessageType, - type BaseMessage, MessageContent, + MessageContentComplex, AIMessage, HumanMessage, HumanMessageChunk, @@ -21,13 +30,11 @@ import { ToolMessageChunk, ChatMessageChunk, FunctionMessageChunk, - OpenAIToolCall, isAIMessage, } from "@langchain/core/messages"; import type { BaseLanguageModelInput, BaseLanguageModelCallOptions, - StructuredOutputMethodParams, StructuredOutputMethodOptions, FunctionDefinition, } from "@langchain/core/language_models/base"; @@ -44,6 +51,7 @@ import { ChatGenerationChunk, ChatResult, } from "@langchain/core/outputs"; +import { AsyncCaller } from "@langchain/core/utils/async_caller"; import { getEnvironmentVariable } from "@langchain/core/utils/env"; import { NewTokenIndices } from "@langchain/core/callbacks/base"; import { z } from "zod"; @@ -65,7 +73,10 @@ import { } from "@langchain/core/runnables"; import { zodToJsonSchema } from "zod-to-json-schema"; import { ToolCallChunk } from "@langchain/core/messages/tool"; -import { _convertToolCallIdToMistralCompatible } from "./utils.js"; +import { + _convertToolCallIdToMistralCompatible, + 
_mistralContentChunkToMessageContentComplex, +} from "./utils.js"; interface TokenUsage { completionTokens?: number; @@ -73,14 +84,7 @@ interface TokenUsage { totalTokens?: number; } -export type MistralAIToolChoice = "auto" | "any" | "none"; - -type MistralAIToolInput = { type: string; function: MistralAIFunction }; - -type ChatMistralAIToolType = - | MistralAIToolInput - | MistralAITool - | BindToolsInput; +type ChatMistralAIToolType = MistralAIToolCall | MistralAITool | BindToolsInput; export interface ChatMistralAICallOptions extends Omit { @@ -110,6 +114,7 @@ export interface ChatMistralAIInput /** * The name of the model to use. * Alias for `model` + * @deprecated Use `model` instead. * @default {"mistral-small-latest"} */ modelName?: string; @@ -119,9 +124,14 @@ export interface ChatMistralAIInput */ model?: string; /** - * Override the default endpoint. + * Override the default server URL used by the Mistral SDK. + * @deprecated use serverURL instead */ endpoint?: string; + /** + * Override the default server URL used by the Mistral SDK. + */ + serverURL?: string; /** * What sampling temperature to use, between 0.0 and 2.0. * Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. @@ -165,6 +175,42 @@ export interface ChatMistralAIInput * The seed to use for random sampling. If set, different calls will generate deterministic results. */ seed?: number; + /** + * A list of custom hooks that must follow (req: Request) => Awaitable + * They are automatically added when a ChatMistralAI instance is created. + */ + beforeRequestHooks?: BeforeRequestHook[]; + /** + * A list of custom hooks that must follow (err: unknown, req: Request) => Awaitable. + * They are automatically added when a ChatMistralAI instance is created. + */ + requestErrorHooks?: RequestErrorHook[]; + /** + * A list of custom hooks that must follow (res: Response, req: Request) => Awaitable. 
+ * They are automatically added when a ChatMistralAI instance is created. + */ + responseHooks?: ResponseHook[]; + /** + * Custom HTTP client to manage API requests. + * Allows users to add custom fetch implementations, hooks, as well as error and response processing. + */ + httpClient?: MistralAIHTTPClient; + /** + * Determines how much the model penalizes the repetition of words or phrases. A higher presence + * penalty encourages the model to use a wider variety of words and phrases, making the output + * more diverse and creative. + */ + presencePenalty?: number; + /** + * Penalizes the repetition of words based on their frequency in the generated text. A higher + * frequency penalty discourages the model from repeating words that have already appeared frequently + * in the output, promoting diversity and reducing repetition. + */ + frequencyPenalty?: number; + /** + * Number of completions to return for each request, input tokens are only billed once. + */ + numCompletions?: number; } function convertMessagesToMistralMessages( @@ -187,12 +233,69 @@ function convertMessagesToMistralMessages( } }; - const getContent = (content: MessageContent): string => { + const getContent = ( + content: MessageContent, + type: MessageType + ): string | MistralAIContentChunk[] => { + const _generateContentChunk = ( + complex: MessageContentComplex, + role: string + ): MistralAIContentChunk => { + if ( + complex.type === "image_url" && + (role === "user" || role === "assistant") + ) { + return { + type: complex.type, + imageUrl: complex?.image_url, + }; + } + + if (complex.type === "text") { + return { + type: complex.type, + text: complex?.text, + }; + } + + throw new Error( + `ChatMistralAI only supports messages of "image_url" for roles "user" and "assistant", and "text" for all others.\n\nReceived: ${JSON.stringify( + content, + null, + 2 + )}` + ); + }; + if (typeof content === "string") { return content; } + + if (Array.isArray(content)) { + const mistralRole = 
getRole(type); + // Mistral "assistant" and "user" roles can support Mistral ContentChunks + // Mistral "system" role can support Mistral TextChunks + const newContent: MistralAIContentChunk[] = []; + content.forEach((messageContentComplex) => { + // Mistral content chunks only support type "text" and "image_url" + if ( + messageContentComplex.type === "text" || + messageContentComplex.type === "image_url" + ) { + newContent.push( + _generateContentChunk(messageContentComplex, mistralRole) + ); + } else { + throw new Error( + `Mistral only supports types "text" or "image_url" for complex message types.` + ); + } + }); + return newContent; + } + throw new Error( - `ChatMistralAI does not support non text message content. Received: ${JSON.stringify( + `Message content must be a string or an array.\n\nReceived: ${JSON.stringify( content, null, 2 @@ -200,61 +303,66 @@ function convertMessagesToMistralMessages( ); }; - const getTools = (message: BaseMessage): MistralAIToolCalls[] | undefined => { + const getTools = (message: BaseMessage): MistralAIToolCall[] | undefined => { if (isAIMessage(message) && !!message.tool_calls?.length) { return message.tool_calls .map((toolCall) => ({ ...toolCall, id: _convertToolCallIdToMistralCompatible(toolCall.id ?? ""), })) - .map(convertLangChainToolCallToOpenAI) as MistralAIToolCalls[]; + .map(convertLangChainToolCallToOpenAI) as MistralAIToolCall[]; } - if (!message.additional_kwargs.tool_calls?.length) { - return undefined; - } - const toolCalls: Omit[] = - message.additional_kwargs.tool_calls; - return toolCalls?.map((toolCall) => ({ - id: _convertToolCallIdToMistralCompatible(toolCall.id), - type: "function", - function: toolCall.function, - })); + return undefined; }; return messages.map((message) => { const toolCalls = getTools(message); - const content = toolCalls === undefined ? 
getContent(message.content) : ""; + const content = getContent(message.content, message.getType()); if ("tool_call_id" in message && typeof message.tool_call_id === "string") { return { - role: getRole(message._getType()), + role: getRole(message.getType()), content, name: message.name, - tool_call_id: _convertToolCallIdToMistralCompatible( - message.tool_call_id - ), + toolCallId: _convertToolCallIdToMistralCompatible(message.tool_call_id), }; + // Mistral "assistant" role can only support either content or tool calls but not both + } else if (isAIMessage(message)) { + if (toolCalls === undefined) { + return { + role: getRole(message.getType()), + content, + }; + } else { + return { + role: getRole(message.getType()), + toolCalls, + }; + } } return { - role: getRole(message._getType()), + role: getRole(message.getType()), content, - tool_calls: toolCalls, }; }) as MistralAIMessage[]; } function mistralAIResponseToChatMessage( - choice: ChatCompletionResponse["choices"][0], + choice: NonNullable[0], usage?: MistralAITokenUsage ): BaseMessage { const { message } = choice; - // MistralAI SDK does not include tool_calls in the non + if (message === undefined) { + throw new Error("No message found in response"); + } + // MistralAI SDK does not include toolCalls in the non // streaming return type, so we need to extract it like this // to satisfy typescript. - let rawToolCalls: MistralAIToolCalls[] = []; - if ("tool_calls" in message && Array.isArray(message.tool_calls)) { - rawToolCalls = message.tool_calls as MistralAIToolCalls[]; + let rawToolCalls: MistralAIToolCall[] = []; + if ("toolCalls" in message && Array.isArray(message.toolCalls)) { + rawToolCalls = message.toolCalls; } + const content = _mistralContentChunkToMessageContentComplex(message.content); switch (message.role) { case "assistant": { const toolCalls = []; @@ -272,48 +380,41 @@ function mistralAIResponseToChatMessage( } } return new AIMessage({ - content: message.content ?? 
"", + content, tool_calls: toolCalls, invalid_tool_calls: invalidToolCalls, - additional_kwargs: { - tool_calls: rawToolCalls.length - ? rawToolCalls.map((toolCall) => ({ - ...toolCall, - type: "function", - })) - : undefined, - }, + additional_kwargs: {}, usage_metadata: usage ? { - input_tokens: usage.prompt_tokens, - output_tokens: usage.completion_tokens, - total_tokens: usage.total_tokens, + input_tokens: usage.promptTokens, + output_tokens: usage.completionTokens, + total_tokens: usage.totalTokens, } : undefined, }); } default: - return new HumanMessage(message.content ?? ""); + return new HumanMessage({ content }); } } function _convertDeltaToMessageChunk( delta: { - role?: string | undefined; - content?: string | undefined; - tool_calls?: MistralAIToolCalls[] | undefined; + role?: string | null | undefined; + content?: string | MistralAIContentChunk[] | null | undefined; + toolCalls?: MistralAIToolCall[] | null | undefined; }, usage?: MistralAITokenUsage | null ) { - if (!delta.content && !delta.tool_calls) { + if (!delta.content && !delta.toolCalls) { if (usage) { return new AIMessageChunk({ content: "", usage_metadata: usage ? { - input_tokens: usage.prompt_tokens, - output_tokens: usage.completion_tokens, - total_tokens: usage.total_tokens, + input_tokens: usage.promptTokens, + output_tokens: usage.completionTokens, + total_tokens: usage.totalTokens, } : undefined, }); @@ -323,9 +424,9 @@ function _convertDeltaToMessageChunk( // Our merge additional kwargs util function will throw unless there // is an index key in each tool object (as seen in OpenAI's) so we // need to insert it here. - const rawToolCallChunksWithIndex = delta.tool_calls?.length - ? delta.tool_calls?.map( - (toolCall, index): OpenAIToolCall => ({ + const rawToolCallChunksWithIndex = delta.toolCalls?.length + ? delta.toolCalls?.map( + (toolCall, index): MistralAIToolCall & { index: number } => ({ ...toolCall, index, id: toolCall.id ?? 
uuidv4().replace(/-/g, ""), @@ -338,17 +439,20 @@ function _convertDeltaToMessageChunk( if (delta.role) { role = delta.role; } - const content = delta.content ?? ""; + const content = _mistralContentChunkToMessageContentComplex(delta.content); + let additional_kwargs; const toolCallChunks: ToolCallChunk[] = []; if (rawToolCallChunksWithIndex !== undefined) { - additional_kwargs = { - tool_calls: rawToolCallChunksWithIndex, - }; for (const rawToolCallChunk of rawToolCallChunksWithIndex) { + const rawArgs = rawToolCallChunk.function?.arguments; + const args = + rawArgs === undefined || typeof rawArgs === "string" + ? rawArgs + : JSON.stringify(rawArgs); toolCallChunks.push({ name: rawToolCallChunk.function?.name, - args: rawToolCallChunk.function?.arguments, + args, id: rawToolCallChunk.id, index: rawToolCallChunk.index, type: "tool_call_chunk", @@ -367,9 +471,9 @@ function _convertDeltaToMessageChunk( additional_kwargs, usage_metadata: usage ? { - input_tokens: usage.prompt_tokens, - output_tokens: usage.completion_tokens, - total_tokens: usage.total_tokens, + input_tokens: usage.promptTokens, + output_tokens: usage.completionTokens, + total_tokens: usage.totalTokens, } : undefined, }); @@ -748,13 +852,16 @@ export class ChatMistralAI< lc_namespace = ["langchain", "chat_models", "mistralai"]; - modelName = "mistral-small-latest"; - model = "mistral-small-latest"; apiKey: string; - endpoint?: string; + /** + * @deprecated use serverURL instead + */ + endpoint: string; + + serverURL?: string; temperature = 0.7; @@ -775,10 +882,26 @@ export class ChatMistralAI< seed?: number; + maxRetries?: number; + lc_serializable = true; streamUsage = true; + beforeRequestHooks?: Array; + + requestErrorHooks?: Array; + + responseHooks?: Array; + + httpClient?: MistralAIHTTPClient; + + presencePenalty?: number; + + frequencyPenalty?: number; + + numCompletions?: number; + constructor(fields?: ChatMistralAIInput) { super(fields ?? {}); const apiKey = fields?.apiKey ?? 
getEnvironmentVariable("MISTRAL_API_KEY"); @@ -789,17 +912,26 @@ export class ChatMistralAI< } this.apiKey = apiKey; this.streaming = fields?.streaming ?? this.streaming; - this.endpoint = fields?.endpoint; + this.serverURL = fields?.serverURL ?? this.serverURL; this.temperature = fields?.temperature ?? this.temperature; this.topP = fields?.topP ?? this.topP; this.maxTokens = fields?.maxTokens ?? this.maxTokens; - this.safeMode = fields?.safeMode ?? this.safeMode; this.safePrompt = fields?.safePrompt ?? this.safePrompt; this.randomSeed = fields?.seed ?? fields?.randomSeed ?? this.seed; this.seed = this.randomSeed; - this.modelName = fields?.model ?? fields?.modelName ?? this.model; - this.model = this.modelName; + this.maxRetries = fields?.maxRetries; + this.httpClient = fields?.httpClient; + this.model = fields?.model ?? fields?.modelName ?? this.model; this.streamUsage = fields?.streamUsage ?? this.streamUsage; + this.beforeRequestHooks = + fields?.beforeRequestHooks ?? this.beforeRequestHooks; + this.requestErrorHooks = + fields?.requestErrorHooks ?? this.requestErrorHooks; + this.responseHooks = fields?.responseHooks ?? this.responseHooks; + this.presencePenalty = fields?.presencePenalty ?? this.presencePenalty; + this.frequencyPenalty = fields?.frequencyPenalty ?? this.frequencyPenalty; + this.numCompletions = fields?.numCompletions ?? this.numCompletions; + this.addAllHooksToHttpClient(); } get lc_secrets(): { [key: string]: string } | undefined { @@ -834,22 +966,27 @@ export class ChatMistralAI< */ invocationParams( options?: this["ParsedCallOptions"] - ): Omit { + ): Omit< + MistralAIChatCompletionRequest | MistralAIChatCompletionStreamRequest, + "messages" + > { const { response_format, tools, tool_choice } = options ?? {}; const mistralAITools: Array | undefined = tools?.length ? 
_convertToolToMistralTool(tools) : undefined; - const params: Omit = { + const params: Omit = { model: this.model, tools: mistralAITools, temperature: this.temperature, maxTokens: this.maxTokens, topP: this.topP, randomSeed: this.seed, - safeMode: this.safeMode, safePrompt: this.safePrompt, toolChoice: tool_choice, - responseFormat: response_format as ResponseFormat, + responseFormat: response_format, + presencePenalty: this.presencePenalty, + frequencyPenalty: this.frequencyPenalty, + n: this.numCompletions, }; return params; } @@ -867,41 +1004,55 @@ export class ChatMistralAI< /** * Calls the MistralAI API with retry logic in case of failures. * @param {ChatRequest} input The input to send to the MistralAI API. - * @returns {Promise>} The response from the MistralAI API. + * @returns {Promise>} The response from the MistralAI API. */ async completionWithRetry( - input: ChatRequest, + input: MistralAIChatCompletionStreamRequest, streaming: true - ): Promise>; + ): Promise>; async completionWithRetry( - input: ChatRequest, + input: MistralAIChatCompletionRequest, streaming: false - ): Promise; + ): Promise; async completionWithRetry( - input: ChatRequest, + input: + | MistralAIChatCompletionRequest + | MistralAIChatCompletionStreamRequest, streaming: boolean ): Promise< - ChatCompletionResponse | AsyncGenerator + | MistralAIChatCompletionResponse + | AsyncIterable > { - const { MistralClient } = await this.imports(); - const client = new MistralClient(this.apiKey, this.endpoint); + const caller = new AsyncCaller({ + maxRetries: this.maxRetries, + }); + const client = new MistralClient({ + apiKey: this.apiKey, + serverURL: this.serverURL, + // If httpClient exists, pass it into constructor + ...(this.httpClient ? 
{ httpClient: this.httpClient } : {}), + }); - return this.caller.call(async () => { + return caller.call(async () => { try { let res: - | ChatCompletionResponse - | AsyncGenerator; + | MistralAIChatCompletionResponse + | AsyncIterable; if (streaming) { - res = client.chatStream(input); + res = await client.chat.stream(input); } else { - res = await client.chat(input); + res = await client.chat.complete(input); } return res; // eslint-disable-next-line @typescript-eslint/no-explicit-any } catch (e: any) { - if (e.message?.includes("status: 400")) { + if ( + e.message?.includes("status: 400") || + e.message?.toLowerCase().includes("status 400") || + e.message?.includes("validation failed") + ) { e.status = 400; } throw e; @@ -925,7 +1076,7 @@ export class ChatMistralAI< // Enable streaming for signal controller or timeout due // to SDK limitations on canceling requests. - const shouldStream = !!options.signal ?? !!options.timeout; + const shouldStream = options.signal ?? !!options.timeout; // Handle streaming if (this.streaming || shouldStream) { @@ -950,11 +1101,8 @@ export class ChatMistralAI< // Not streaming, so we can just call the API once. const response = await this.completionWithRetry(input, false); - const { - completion_tokens: completionTokens, - prompt_tokens: promptTokens, - total_tokens: totalTokens, - } = response?.usage ?? {}; + const { completionTokens, promptTokens, totalTokens } = + response?.usage ?? {}; if (completionTokens) { tokenUsage.completionTokens = @@ -977,13 +1125,16 @@ export class ChatMistralAI< if (!("message" in part)) { throw new Error("No message found in the choice."); } - const text = part.message?.content ?? ""; + let text = part.message?.content ?? ""; + if (Array.isArray(text)) { + text = text[0].type === "text" ? 
text[0].text : ""; + } const generation: ChatGeneration = { text, message: mistralAIResponseToChatMessage(part, response?.usage), }; - if (part.finish_reason) { - generation.generationInfo = { finish_reason: part.finish_reason }; + if (part.finishReason) { + generation.generationInfo = { finishReason: part.finishReason }; } generations.push(generation); } @@ -1006,7 +1157,7 @@ export class ChatMistralAI< }; const streamIterable = await this.completionWithRetry(input, true); - for await (const data of streamIterable) { + for await (const { data } of streamIterable) { if (options.signal?.aborted) { throw new Error("AbortError"); } @@ -1032,9 +1183,13 @@ export class ChatMistralAI< // Do not yield a chunk if the message is empty continue; } + let text = delta.content ?? ""; + if (Array.isArray(text)) { + text = text[0].type === "text" ? text[0].text : ""; + } const generationChunk = new ChatGenerationChunk({ message, - text: delta.content ?? "", + text, generationInfo: newTokenIndices, }); yield generationChunk; @@ -1050,6 +1205,79 @@ export class ChatMistralAI< } } + addAllHooksToHttpClient() { + try { + // To prevent duplicate hooks + this.removeAllHooksFromHttpClient(); + + // If the user wants to use hooks, but hasn't created an HTTPClient yet + const hasHooks = [ + this.beforeRequestHooks, + this.requestErrorHooks, + this.responseHooks, + ].some((hook) => hook && hook.length > 0); + if (hasHooks && !this.httpClient) { + this.httpClient = new MistralAIHTTPClient(); + } + + if (this.beforeRequestHooks) { + for (const hook of this.beforeRequestHooks) { + this.httpClient?.addHook("beforeRequest", hook); + } + } + + if (this.requestErrorHooks) { + for (const hook of this.requestErrorHooks) { + this.httpClient?.addHook("requestError", hook); + } + } + + if (this.responseHooks) { + for (const hook of this.responseHooks) { + this.httpClient?.addHook("response", hook); + } + } + } catch { + throw new Error("Error in adding all hooks"); + } + } + + 
removeAllHooksFromHttpClient() { + try { + if (this.beforeRequestHooks) { + for (const hook of this.beforeRequestHooks) { + this.httpClient?.removeHook("beforeRequest", hook); + } + } + + if (this.requestErrorHooks) { + for (const hook of this.requestErrorHooks) { + this.httpClient?.removeHook("requestError", hook); + } + } + + if (this.responseHooks) { + for (const hook of this.responseHooks) { + this.httpClient?.removeHook("response", hook); + } + } + } catch { + throw new Error("Error in removing hooks"); + } + } + + removeHookFromHttpClient( + hook: BeforeRequestHook | RequestErrorHook | ResponseHook + ) { + try { + this.httpClient?.removeHook("beforeRequest", hook as BeforeRequestHook); + this.httpClient?.removeHook("requestError", hook as RequestErrorHook); + this.httpClient?.removeHook("response", hook as ResponseHook); + } catch { + throw new Error("Error in removing hook"); + } + } + /** @ignore */ _combineLLMOutput() { return []; @@ -1060,7 +1288,6 @@ export class ChatMistralAI< RunOutput extends Record = Record >( outputSchema: - | StructuredOutputMethodParams | z.ZodType // eslint-disable-next-line @typescript-eslint/no-explicit-any | Record, @@ -1072,7 +1299,6 @@ export class ChatMistralAI< RunOutput extends Record = Record >( outputSchema: - | StructuredOutputMethodParams | z.ZodType // eslint-disable-next-line @typescript-eslint/no-explicit-any | Record, @@ -1084,7 +1310,6 @@ export class ChatMistralAI< RunOutput extends Record = Record >( outputSchema: - | StructuredOutputMethodParams | z.ZodType // eslint-disable-next-line @typescript-eslint/no-explicit-any | Record, @@ -1096,21 +1321,11 @@ export class ChatMistralAI< { raw: BaseMessage; parsed: RunOutput } > { // eslint-disable-next-line @typescript-eslint/no-explicit-any - let schema: z.ZodType | Record; - let name; - let method; - let includeRaw; - if (isStructuredOutputMethodParams(outputSchema)) { - schema = outputSchema.schema; - name = outputSchema.name; - method = outputSchema.method; - 
includeRaw = outputSchema.includeRaw; - } else { - schema = outputSchema; - name = config?.name; - method = config?.method; - includeRaw = config?.includeRaw; - } + const schema: z.ZodType | Record = outputSchema; + const name = config?.name; + const method = config?.method; + const includeRaw = config?.includeRaw; + let llm: Runnable; let outputParser: BaseLLMOutputParser; @@ -1205,12 +1420,6 @@ export class ChatMistralAI< parsedWithFallback, ]); } - - /** @ignore */ - private async imports() { - const { default: MistralClient } = await import("@mistralai/mistralai"); - return { MistralClient }; - } } function isZodSchema< @@ -1223,15 +1432,3 @@ function isZodSchema< // Check for a characteristic method of Zod schemas return typeof (input as z.ZodType)?.parse === "function"; } - -function isStructuredOutputMethodParams( - x: unknown - // eslint-disable-next-line @typescript-eslint/no-explicit-any -): x is StructuredOutputMethodParams> { - return ( - x !== undefined && - // eslint-disable-next-line @typescript-eslint/no-explicit-any - typeof (x as StructuredOutputMethodParams>).schema === - "object" - ); -} diff --git a/libs/langchain-mistralai/src/embeddings.ts b/libs/langchain-mistralai/src/embeddings.ts index f750f53cf5a83..c21d1d4947c7b 100644 --- a/libs/langchain-mistralai/src/embeddings.ts +++ b/libs/langchain-mistralai/src/embeddings.ts @@ -1,7 +1,14 @@ import { getEnvironmentVariable } from "@langchain/core/utils/env"; import { Embeddings, type EmbeddingsParams } from "@langchain/core/embeddings"; import { chunkArray } from "@langchain/core/utils/chunk_array"; -import { EmbeddingResponse } from "@mistralai/mistralai"; +import { EmbeddingRequest as MistralAIEmbeddingsRequest } from "@mistralai/mistralai/src/models/components/embeddingrequest.js"; +import { EmbeddingResponse as MistralAIEmbeddingsResponse } from "@mistralai/mistralai/src/models/components/embeddingresponse.js"; +import { + BeforeRequestHook, + RequestErrorHook, + ResponseHook, + HTTPClient as 
MistralAIHTTPClient, +} from "@mistralai/mistralai/lib/http.js"; /** * Interface for MistralAIEmbeddings parameters. Extends EmbeddingsParams and @@ -30,9 +37,14 @@ export interface MistralAIEmbeddingsParams extends EmbeddingsParams { */ encodingFormat?: string; /** - * Override the default endpoint. + * Override the default server URL used by the Mistral SDK. + * @deprecated use serverURL instead */ endpoint?: string; + /** + * Override the default server URL used by the Mistral SDK. + */ + serverURL?: string; /** * The maximum number of documents to embed in a single request. * @default {512} @@ -44,6 +56,26 @@ export interface MistralAIEmbeddingsParams extends EmbeddingsParams { * @default {true} */ stripNewLines?: boolean; + /** + * A list of custom hooks that must follow (req: Request) => Awaitable + * They are automatically added when a ChatMistralAI instance is created + */ + beforeRequestHooks?: BeforeRequestHook[]; + /** + * A list of custom hooks that must follow (err: unknown, req: Request) => Awaitable + * They are automatically added when a ChatMistralAI instance is created + */ + requestErrorHooks?: RequestErrorHook[]; + /** + * A list of custom hooks that must follow (res: Response, req: Request) => Awaitable + * They are automatically added when a ChatMistralAI instance is created + */ + responseHooks?: ResponseHook[]; + /** + * Optional custom HTTP client to manage API requests + * Allows users to add custom fetch implementations, hooks, as well as error and response processing. + */ + httpClient?: MistralAIHTTPClient; } /** @@ -65,7 +97,20 @@ export class MistralAIEmbeddings apiKey: string; - endpoint?: string; + /** + * @deprecated use serverURL instead + */ + endpoint: string; + + serverURL?: string; + + beforeRequestHooks?: Array; + + requestErrorHooks?: Array; + + responseHooks?: Array; + + httpClient?: MistralAIHTTPClient; constructor(fields?: Partial) { super(fields ?? 
{}); @@ -74,12 +119,19 @@ export class MistralAIEmbeddings throw new Error("API key missing for MistralAI, but it is required."); } this.apiKey = apiKey; - this.endpoint = fields?.endpoint; + this.serverURL = fields?.serverURL ?? this.serverURL; this.modelName = fields?.model ?? fields?.modelName ?? this.model; this.model = this.modelName; this.encodingFormat = fields?.encodingFormat ?? this.encodingFormat; this.batchSize = fields?.batchSize ?? this.batchSize; this.stripNewLines = fields?.stripNewLines ?? this.stripNewLines; + this.beforeRequestHooks = + fields?.beforeRequestHooks ?? this.beforeRequestHooks; + this.requestErrorHooks = + fields?.requestErrorHooks ?? this.requestErrorHooks; + this.responseHooks = fields?.responseHooks ?? this.responseHooks; + this.httpClient = fields?.httpClient ?? this.httpClient; + this.addAllHooksToHttpClient(); } /** @@ -105,7 +157,7 @@ export class MistralAIEmbeddings const batch = batches[i]; const { data: batchResponse } = batchResponses[i]; for (let j = 0; j < batch.length; j += 1) { - embeddings.push(batchResponse[j].embedding); + embeddings.push(batchResponse[j].embedding ?? []); } } return embeddings; @@ -121,33 +173,113 @@ export class MistralAIEmbeddings const { data } = await this.embeddingWithRetry( this.stripNewLines ? text.replace(/\n/g, " ") : text ); - return data[0].embedding; + return data[0].embedding ?? []; } /** * Private method to make a request to the MistralAI API to generate * embeddings. Handles the retry logic and returns the response from the * API. - * @param {string | Array} input Text to send to the MistralAI API. - * @returns {Promise} Promise that resolves to the response from the API. + * @param {string | Array} inputs Text to send to the MistralAI API. + * @returns {Promise} Promise that resolves to the response from the API. 
*/ private async embeddingWithRetry( - input: string | Array - ): Promise { - const { MistralClient } = await this.imports(); - const client = new MistralClient(this.apiKey, this.endpoint); + inputs: string | Array + ): Promise { + const { Mistral } = await this.imports(); + const client = new Mistral({ + apiKey: this.apiKey, + serverURL: this.serverURL, + // If httpClient exists, pass it into constructor + ...(this.httpClient ? { httpClient: this.httpClient } : {}), + }); + const embeddingsRequest: MistralAIEmbeddingsRequest = { + model: this.model, + inputs, + encodingFormat: this.encodingFormat, + }; return this.caller.call(async () => { - const res = await client.embeddings({ - model: this.model, - input, - }); + const res = await client.embeddings.create(embeddingsRequest); return res; }); } + addAllHooksToHttpClient() { + try { + // To prevent duplicate hooks + this.removeAllHooksFromHttpClient(); + + // If the user wants to use hooks, but hasn't created an HTTPClient yet + const hasHooks = [ + this.beforeRequestHooks, + this.requestErrorHooks, + this.responseHooks, + ].some((hook) => hook && hook.length > 0); + if (hasHooks && !this.httpClient) { + this.httpClient = new MistralAIHTTPClient(); + } + + if (this.beforeRequestHooks) { + for (const hook of this.beforeRequestHooks) { + this.httpClient?.addHook("beforeRequest", hook); + } + } + + if (this.requestErrorHooks) { + for (const hook of this.requestErrorHooks) { + this.httpClient?.addHook("requestError", hook); + } + } + + if (this.responseHooks) { + for (const hook of this.responseHooks) { + this.httpClient?.addHook("response", hook); + } + } + } catch { + throw new Error("Error in adding all hooks"); + } + } + + removeAllHooksFromHttpClient() { + try { + if (this.beforeRequestHooks) { + for (const hook of this.beforeRequestHooks) { + this.httpClient?.removeHook("beforeRequest", hook); + } + } + + if (this.requestErrorHooks) { + for (const hook of this.requestErrorHooks) { + 
this.httpClient?.removeHook("requestError", hook); + } + } + + if (this.responseHooks) { + for (const hook of this.responseHooks) { + this.httpClient?.removeHook("response", hook); + } + } + } catch { + throw new Error("Error in removing hooks"); + } + } + + removeHookFromHttpClient( + hook: BeforeRequestHook | RequestErrorHook | ResponseHook + ) { + try { + this.httpClient?.removeHook("beforeRequest", hook as BeforeRequestHook); + this.httpClient?.removeHook("requestError", hook as RequestErrorHook); + this.httpClient?.removeHook("response", hook as ResponseHook); + } catch { + throw new Error("Error in removing hook"); + } + } + /** @ignore */ private async imports() { - const { default: MistralClient } = await import("@mistralai/mistralai"); - return { MistralClient }; + const { Mistral } = await import("@mistralai/mistralai"); + return { Mistral }; } } diff --git a/libs/langchain-mistralai/src/llms.ts b/libs/langchain-mistralai/src/llms.ts index 2a232a9c0449b..859369f4f0ab6 100644 --- a/libs/langchain-mistralai/src/llms.ts +++ b/libs/langchain-mistralai/src/llms.ts @@ -2,12 +2,18 @@ import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager"; import { BaseLLMParams, LLM } from "@langchain/core/language_models/llms"; import { type BaseLanguageModelCallOptions } from "@langchain/core/language_models/base"; import { GenerationChunk, LLMResult } from "@langchain/core/outputs"; +import { FIMCompletionRequest as MistralAIFIMCompletionRequest } from "@mistralai/mistralai/models/components/fimcompletionrequest.js"; +import { FIMCompletionStreamRequest as MistralAIFIMCompletionStreamRequest } from "@mistralai/mistralai/models/components/fimcompletionstreamrequest.js"; +import { FIMCompletionResponse as MistralAIFIMCompletionResponse } from "@mistralai/mistralai/models/components/fimcompletionresponse.js"; +import { ChatCompletionChoice as MistralAIChatCompletionChoice } from "@mistralai/mistralai/models/components/chatcompletionchoice.js"; +import { 
CompletionEvent as MistralAIChatCompletionEvent } from "@mistralai/mistralai/models/components/completionevent.js"; +import { CompletionChunk as MistralAICompetionChunk } from "@mistralai/mistralai/models/components/completionchunk.js"; import { - ChatCompletionResponse, - ChatCompletionResponseChoice, - ChatCompletionResponseChunk, - type CompletionRequest, -} from "@mistralai/mistralai"; + BeforeRequestHook, + RequestErrorHook, + ResponseHook, + HTTPClient as MistralAIHTTPClient, +} from "@mistralai/mistralai/lib/http.js"; import { getEnvironmentVariable } from "@langchain/core/utils/env"; import { chunkArray } from "@langchain/core/utils/chunk_array"; import { AsyncCaller } from "@langchain/core/utils/async_caller"; @@ -34,9 +40,14 @@ export interface MistralAIInput extends BaseLLMParams { */ apiKey?: string; /** - * Override the default endpoint. + * Override the default server URL used by the Mistral SDK. + * @deprecated use serverURL instead */ endpoint?: string; + /** + * Override the default server URL used by the Mistral SDK. + */ + serverURL?: string; /** * What sampling temperature to use, between 0.0 and 2.0. * Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. 
@@ -69,6 +80,26 @@ export interface MistralAIInput extends BaseLLMParams { * Batch size to use when passing multiple documents to generate */ batchSize?: number; + /** + * A list of custom hooks that must follow (req: Request) => Awaitable + * They are automatically added when a ChatMistralAI instance is created + */ + beforeRequestHooks?: BeforeRequestHook[]; + /** + * A list of custom hooks that must follow (err: unknown, req: Request) => Awaitable + * They are automatically added when a ChatMistralAI instance is created + */ + requestErrorHooks?: RequestErrorHook[]; + /** + * A list of custom hooks that must follow (res: Response, req: Request) => Awaitable + * They are automatically added when a ChatMistralAI instance is created + */ + responseHooks?: ResponseHook[]; + /** + * Optional custom HTTP client to manage API requests + * Allows users to add custom fetch implementations, hooks, as well as error and response processing. + */ + httpClient?: MistralAIHTTPClient; } /** @@ -102,12 +133,25 @@ export class MistralAI apiKey: string; - endpoint?: string; + /** + * @deprecated use serverURL instead + */ + endpoint: string; + + serverURL?: string; maxRetries?: number; maxConcurrency?: number; + beforeRequestHooks?: Array; + + requestErrorHooks?: Array; + + responseHooks?: Array; + + httpClient?: MistralAIHTTPClient; + constructor(fields?: MistralAIInput) { super(fields ?? {}); @@ -118,9 +162,15 @@ export class MistralAI this.randomSeed = fields?.randomSeed ?? this.randomSeed; this.batchSize = fields?.batchSize ?? this.batchSize; this.streaming = fields?.streaming ?? this.streaming; - this.endpoint = fields?.endpoint; + this.serverURL = fields?.serverURL ?? this.serverURL; this.maxRetries = fields?.maxRetries; this.maxConcurrency = fields?.maxConcurrency; + this.beforeRequestHooks = + fields?.beforeRequestHooks ?? this.beforeRequestHooks; + this.requestErrorHooks = + fields?.requestErrorHooks ?? 
this.requestErrorHooks; + this.responseHooks = fields?.responseHooks ?? this.responseHooks; + this.httpClient = fields?.httpClient ?? this.httpClient; const apiKey = fields?.apiKey ?? getEnvironmentVariable("MISTRAL_API_KEY"); if (!apiKey) { @@ -130,6 +180,8 @@ Either provide one via the "apiKey" field in the constructor, or set the "MISTRA ); } this.apiKey = apiKey; + + this.addAllHooksToHttpClient(); } get lc_secrets(): { [key: string]: string } | undefined { @@ -150,7 +202,10 @@ Either provide one via the "apiKey" field in the constructor, or set the "MISTRA invocationParams( options: this["ParsedCallOptions"] - ): Omit { + ): Omit< + MistralAIFIMCompletionRequest | MistralAIFIMCompletionStreamRequest, + "prompt" + > { return { model: this.model, suffix: options.suffix, @@ -177,7 +232,11 @@ Either provide one via the "apiKey" field in the constructor, or set the "MISTRA prompt, }; const result = await this.completionWithRetry(params, options, false); - return result.choices[0].message.content ?? ""; + let content = result?.choices?.[0].message.content ?? ""; + if (Array.isArray(content)) { + content = content[0].type === "text" ? 
content[0].text : ""; + } + return content; } async _generate( @@ -186,7 +245,7 @@ Either provide one via the "apiKey" field in the constructor, or set the "MISTRA runManager?: CallbackManagerForLLMRun ): Promise { const subPrompts = chunkArray(prompts, this.batchSize); - const choices: ChatCompletionResponseChoice[][] = []; + const choices: MistralAIChatCompletionChoice[][] = []; const params = this.invocationParams(options); @@ -194,14 +253,14 @@ Either provide one via the "apiKey" field in the constructor, or set the "MISTRA const data = await (async () => { if (this.streaming) { const responseData: Array< - { choices: ChatCompletionResponseChoice[] } & Partial< - Omit + { choices: MistralAIChatCompletionChoice[] } & Partial< + Omit > > = []; for (let x = 0; x < subPrompts[i].length; x += 1) { - const choices: ChatCompletionResponseChoice[] = []; + const choices: MistralAIChatCompletionChoice[] = []; let response: - | Omit + | Omit | undefined; const stream = await this.completionWithRetry( { @@ -211,35 +270,52 @@ Either provide one via the "apiKey" field in the constructor, or set the "MISTRA options, true ); - for await (const message of stream) { + for await (const { data } of stream) { // on the first message set the response properties if (!response) { response = { - id: message.id, + id: data.id, object: "chat.completion", - created: message.created, - model: message.model, + created: data.created, + model: data.model, }; } // on all messages, update choice - for (const part of message.choices) { + for (const part of data.choices) { + let content = part.delta.content ?? ""; + // Convert MistralContentChunk data into a string + if (Array.isArray(content)) { + let strContent = ""; + for (const contentChunk of content) { + if (contentChunk.type === "text") { + strContent += contentChunk.text; + } else if (contentChunk.type === "image_url") { + const imageURL = + typeof contentChunk.imageUrl === "string" + ? 
contentChunk.imageUrl + : contentChunk.imageUrl.url; + strContent += imageURL; + } + } + content = strContent; + } if (!choices[part.index]) { choices[part.index] = { index: part.index, message: { - role: part.delta.role ?? "assistant", - content: part.delta.content ?? "", - tool_calls: null, + role: "assistant", + content, + toolCalls: null, }, - finish_reason: part.finish_reason, + finishReason: part.finishReason ?? "length", }; } else { const choice = choices[part.index]; - choice.message.content += part.delta.content ?? ""; - choice.finish_reason = part.finish_reason; + choice.message.content += content; + choice.finishReason = part.finishReason ?? "length"; } - void runManager?.handleLLMNewToken(part.delta.content ?? "", { + void runManager?.handleLLMNewToken(content, { prompt: part.index, completion: part.index, }); @@ -255,7 +331,7 @@ Either provide one via the "apiKey" field in the constructor, or set the "MISTRA } return responseData; } else { - const responseData: Array = []; + const responseData: Array = []; for (let x = 0; x < subPrompts[i].length; x += 1) { const res = await this.completionWithRetry( { @@ -271,16 +347,22 @@ Either provide one via the "apiKey" field in the constructor, or set the "MISTRA } })(); - choices.push(...data.map((d) => d.choices)); + choices.push(...data.map((d) => d.choices ?? [])); } const generations = choices.map((promptChoices) => - promptChoices.map((choice) => ({ - text: choice.message.content ?? "", - generationInfo: { - finishReason: choice.finish_reason, - }, - })) + promptChoices.map((choice) => { + let text = choice.message?.content ?? ""; + if (Array.isArray(text)) { + text = text[0].type === "text" ? 
text[0].text : ""; + } + return { + text, + generationInfo: { + finishReason: choice.finishReason, + }, + }; + }) ); return { generations, @@ -288,45 +370,63 @@ Either provide one via the "apiKey" field in the constructor, or set the "MISTRA } async completionWithRetry( - request: CompletionRequest, + request: MistralAIFIMCompletionRequest, options: this["ParsedCallOptions"], stream: false - ): Promise; + ): Promise; async completionWithRetry( - request: CompletionRequest, + request: MistralAIFIMCompletionStreamRequest, options: this["ParsedCallOptions"], stream: true - ): Promise>; + ): Promise>; async completionWithRetry( - request: CompletionRequest, + request: + | MistralAIFIMCompletionRequest + | MistralAIFIMCompletionStreamRequest, options: this["ParsedCallOptions"], stream: boolean ): Promise< - | ChatCompletionResponse - | AsyncGenerator + MistralAIFIMCompletionResponse | AsyncIterable > { - const { MistralClient } = await this.imports(); + const { Mistral } = await this.imports(); const caller = new AsyncCaller({ maxConcurrency: options.maxConcurrency || this.maxConcurrency, maxRetries: this.maxRetries, }); - const client = new MistralClient( - this.apiKey, - this.endpoint, - this.maxRetries, - options.timeout - ); + const client = new Mistral({ + apiKey: this.apiKey, + serverURL: this.serverURL, + timeoutMs: options.timeout, + // If httpClient exists, pass it into constructor + ...(this.httpClient ? 
{ httpClient: this.httpClient } : {}), + }); return caller.callWithOptions( { signal: options.signal, }, async () => { - if (stream) { - return client.completionStream(request); - } else { - return client.completion(request); + try { + let res: + | MistralAIFIMCompletionResponse + | AsyncIterable; + if (stream) { + res = await client.fim.stream(request); + } else { + res = await client.fim.complete(request); + } + return res; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } catch (e: any) { + if ( + e.message?.includes("status: 400") || + e.message?.toLowerCase().includes("status 400") || + e.message?.includes("validation failed") + ) { + e.status = 400; + } + throw e; } } ); @@ -342,15 +442,20 @@ Either provide one via the "apiKey" field in the constructor, or set the "MISTRA prompt, }; const stream = await this.completionWithRetry(params, options, true); - for await (const data of stream) { + for await (const message of stream) { + const { data } = message; const choice = data?.choices[0]; if (!choice) { continue; } + let text = choice.delta.content ?? ""; + if (Array.isArray(text)) { + text = text[0].type === "text" ? text[0].text : ""; + } const chunk = new GenerationChunk({ - text: choice.delta.content ?? 
"", + text, generationInfo: { - finishReason: choice.finish_reason, + finishReason: choice.finishReason, tokenUsage: data.usage, }, }); @@ -363,9 +468,82 @@ Either provide one via the "apiKey" field in the constructor, or set the "MISTRA } } + addAllHooksToHttpClient() { + try { + // To prevent duplicate hooks + this.removeAllHooksFromHttpClient(); + + // If the user wants to use hooks, but hasn't created an HTTPClient yet + const hasHooks = [ + this.beforeRequestHooks, + this.requestErrorHooks, + this.responseHooks, + ].some((hook) => hook && hook.length > 0); + if (hasHooks && !this.httpClient) { + this.httpClient = new MistralAIHTTPClient(); + } + + if (this.beforeRequestHooks) { + for (const hook of this.beforeRequestHooks) { + this.httpClient?.addHook("beforeRequest", hook); + } + } + + if (this.requestErrorHooks) { + for (const hook of this.requestErrorHooks) { + this.httpClient?.addHook("requestError", hook); + } + } + + if (this.responseHooks) { + for (const hook of this.responseHooks) { + this.httpClient?.addHook("response", hook); + } + } + } catch { + throw new Error("Error in adding all hooks"); + } + } + + removeAllHooksFromHttpClient() { + try { + if (this.beforeRequestHooks) { + for (const hook of this.beforeRequestHooks) { + this.httpClient?.removeHook("beforeRequest", hook); + } + } + + if (this.requestErrorHooks) { + for (const hook of this.requestErrorHooks) { + this.httpClient?.removeHook("requestError", hook); + } + } + + if (this.responseHooks) { + for (const hook of this.responseHooks) { + this.httpClient?.removeHook("response", hook); + } + } + } catch { + throw new Error("Error in removing hooks"); + } + } + + removeHookFromHttpClient( + hook: BeforeRequestHook | RequestErrorHook | ResponseHook + ) { + try { + this.httpClient?.removeHook("beforeRequest", hook as BeforeRequestHook); + this.httpClient?.removeHook("requestError", hook as RequestErrorHook); + this.httpClient?.removeHook("response", hook as ResponseHook); + } catch { + throw new 
Error("Error in removing hook"); + } + } + /** @ignore */ private async imports() { - const { default: MistralClient } = await import("@mistralai/mistralai"); - return { MistralClient }; + const { Mistral } = await import("@mistralai/mistralai"); + return { Mistral }; } } diff --git a/libs/langchain-mistralai/src/tests/chat_models.int.test.ts b/libs/langchain-mistralai/src/tests/chat_models.int.test.ts index 39e88c34a0db9..c50c8261dda73 100644 --- a/libs/langchain-mistralai/src/tests/chat_models.int.test.ts +++ b/libs/langchain-mistralai/src/tests/chat_models.int.test.ts @@ -5,14 +5,17 @@ import { z } from "zod"; import { AIMessage, AIMessageChunk, - BaseMessage, HumanMessage, + SystemMessage, ToolMessage, } from "@langchain/core/messages"; +import { ContentChunk as MistralAIContentChunk } from "@mistralai/mistralai/models/components/contentchunk.js"; +import { HTTPClient } from "@mistralai/mistralai/lib/http.js"; import { zodToJsonSchema } from "zod-to-json-schema"; import { ChatMistralAI } from "../chat_models.js"; +import { _mistralContentChunkToMessageContentComplex } from "../utils.js"; -test("Test ChatMistralAI can invoke", async () => { +test("Test ChatMistralAI can invoke hello", async () => { const model = new ChatMistralAI({ model: "mistral-tiny", }); @@ -80,19 +83,13 @@ test("Can call tools using structured tools", async () => { const chain = prompt.pipe(model); const response = await chain.invoke({}); - expect("tool_calls" in response.additional_kwargs).toBe(true); + expect("tool_calls" in response).toBe(true); // console.log(response.additional_kwargs.tool_calls?.[0]); - expect(response.additional_kwargs.tool_calls?.[0].function.name).toBe( - "calculator" - ); - expect( - JSON.parse( - response.additional_kwargs.tool_calls?.[0].function.arguments ?? 
"{}" - ).calculator - ).toBeDefined(); + expect(response.tool_calls?.[0].name).toBe("calculator"); + expect(response.tool_calls?.[0].args?.calculator).toBeDefined(); }); -test("Can call tools", async () => { +test("Can call tools using raw tools", async () => { const tools = [ { type: "function", @@ -130,20 +127,8 @@ test("Can call tools", async () => { const response = await chain.invoke({}); // console.log(response); expect(response.tool_calls?.length).toEqual(1); - expect(response.tool_calls?.[0].args).toEqual( - JSON.parse( - response.additional_kwargs.tool_calls?.[0].function.arguments ?? "{}" - ) - ); - expect("tool_calls" in response.additional_kwargs).toBe(true); - expect(response.additional_kwargs.tool_calls?.[0].function.name).toBe( - "calculator" - ); - expect( - JSON.parse( - response.additional_kwargs.tool_calls?.[0].function.arguments ?? "{}" - ).calculator - ).toBeDefined(); + expect(response.tool_calls?.[0].name).toBe("calculator"); + expect(response.tool_calls?.[0].args?.calculator).toBeDefined(); }); test("Can call .stream with tool calling", async () => { @@ -179,7 +164,7 @@ test("Can call .stream with tool calling", async () => { const chain = prompt.pipe(model); const response = await chain.stream({}); - let finalRes: BaseMessage | null = null; + let finalRes: AIMessageChunk | null = null; for await (const chunk of response) { // console.log(chunk); finalRes = chunk; @@ -188,16 +173,10 @@ test("Can call .stream with tool calling", async () => { throw new Error("No final response found"); } - expect("tool_calls" in finalRes.additional_kwargs).toBe(true); + expect("tool_calls" in finalRes).toBe(true); // console.log(finalRes.additional_kwargs.tool_calls?.[0]); - expect(finalRes.additional_kwargs.tool_calls?.[0].function.name).toBe( - "calculator" - ); - expect( - JSON.parse( - finalRes.additional_kwargs.tool_calls?.[0].function.arguments ?? 
"{}" - ).calculator - ).toBeDefined(); + expect(finalRes.tool_calls?.[0].name).toBe("calculator"); + expect(finalRes.tool_calls?.[0].args.calculator).toBeDefined(); }); test("Can use json mode response format", async () => { @@ -302,7 +281,7 @@ test("Can stream and concat responses for a complex tool", async () => { const chain = prompt.pipe(model); const response = await chain.stream({}); - let finalRes: BaseMessage[] = []; + let finalRes: AIMessageChunk[] = []; for await (const chunk of response) { // console.log(chunk); finalRes = finalRes.concat(chunk); @@ -311,11 +290,10 @@ test("Can stream and concat responses for a complex tool", async () => { throw new Error("No final response found"); } - expect(finalRes[0].additional_kwargs.tool_calls?.[0]).toBeDefined(); - const toolCall = finalRes[0].additional_kwargs.tool_calls?.[0]; - expect(toolCall?.function.name).toBe("person_traits"); - const args = JSON.parse(toolCall?.function.arguments ?? "{}"); - const { person } = args; + expect(finalRes[0].tool_calls?.[0]).toBeDefined(); + const toolCall = finalRes[0].tool_calls?.[0]; + expect(toolCall?.name).toBe("person_traits"); + const person = toolCall?.args?.person; expect(person).toBeDefined(); expect(person.name).toBeDefined(); expect(person.age).toBeDefined(); @@ -406,7 +384,7 @@ describe("withStructuredOutput", () => { ]); const chain = prompt.pipe(modelWithStructuredOutput); const result = await chain.invoke({}); - // console.log(result); + console.log(result); expect("operation" in result).toBe(true); expect("number1" in result).toBe(true); expect("number2" in result).toBe(true); @@ -609,28 +587,11 @@ describe("withStructuredOutput", () => { throw new Error("raw not in result"); } const { raw } = result as { raw: AIMessage }; - expect(raw.additional_kwargs.tool_calls?.length).toBeGreaterThan(0); - expect(raw.additional_kwargs.tool_calls?.[0].function.name).toBe( - "calculator" - ); - expect( - "operation" in - JSON.parse( - 
raw.additional_kwargs.tool_calls?.[0].function.arguments ?? "" - ) - ).toBe(true); - expect( - "number1" in - JSON.parse( - raw.additional_kwargs.tool_calls?.[0].function.arguments ?? "" - ) - ).toBe(true); - expect( - "number2" in - JSON.parse( - raw.additional_kwargs.tool_calls?.[0].function.arguments ?? "" - ) - ).toBe(true); + expect(raw.tool_calls?.length).toBeGreaterThan(0); + expect(raw.tool_calls?.[0].name).toBe("calculator"); + expect("operation" in (raw.tool_calls?.[0]?.args ?? {})).toBe(true); + expect("number1" in (raw.tool_calls?.[0]?.args ?? {})).toBe(true); + expect("number2" in (raw.tool_calls?.[0]?.args ?? {})).toBe(true); }); }); @@ -812,7 +773,7 @@ describe("codestral-latest", () => { expect(fullMessage.toLowerCase()).toContain("world"); }); - test("Can call tools using structured tools codestral-latest", async () => { + test("Can call tools using codestral-latest structured tools", async () => { class CodeSandbox extends StructuredTool { name = "code_sandbox"; @@ -850,16 +811,15 @@ describe("codestral-latest", () => { "Write a function that takes in a single argument and logs it to the console. 
Ensure the code is in Python.", }); // console.log(response); - expect("tool_calls" in response.additional_kwargs).toBe(true); - // console.log(response.additional_kwargs.tool_calls?.[0]); - if (!response.additional_kwargs.tool_calls?.[0]) { + expect("tool_calls" in response).toBe(true); + // console.log(response.tool_calls?.[0]); + if (!response.tool_calls?.[0]) { throw new Error("No tool call found"); } - const sandboxTool = response.additional_kwargs.tool_calls[0]; - expect(sandboxTool.function.name).toBe("code_sandbox"); - const parsedArgs = JSON.parse(sandboxTool.function.arguments); - expect(parsedArgs.code).toBeDefined(); - // console.log(parsedArgs.code); + const sandboxTool = response.tool_calls[0]; + expect(sandboxTool.name).toBe("code_sandbox"); + expect(sandboxTool.args?.code).toBeDefined(); + // console.log(sandboxTool.args?.code); }); }); @@ -957,3 +917,332 @@ test("withStructuredOutput will always force tool usage", async () => { const castMessage = response.raw as AIMessage; expect(castMessage.tool_calls).toHaveLength(1); }); + +test("Test ChatMistralAI can invoke with MessageContent input types", async () => { + const model = new ChatMistralAI({ + model: "pixtral-12b-2409", + }); + const messagesListContent = [ + new SystemMessage({ + content: "List the top 5 countries in Europe with the highest GDP", + }), + new HumanMessage({ + content: [ + { + type: "text", + text: "Here is an infographic with European GPDs", + }, + { + type: "image_url", + image_url: "https://mistral.ai/images/news/pixtral-12b/gdp.png", + }, + ], + }), + ]; + const response = await model.invoke(messagesListContent); + console.log("response", response); + expect(response.content.length).toBeGreaterThan(1); +}); + +test("Mistral ContentChunk to MessageContentComplex conversion", () => { + const mistralMessages = [ + { + type: "text", + text: "Test message", + }, + { + type: "image_url", + imageUrl: "https://mistral.ai/images/news/pixtral-12b/gdp.png", + }, + { + type: 
"image_url", + imageUrl: { + url: "https://mistral.ai/images/news/pixtral-12b/gdp.png", + detail: "high", + }, + }, + { + type: "image_url", + imageUrl: { + url: "https://mistral.ai/images/news/pixtral-12b/gdp.png", + detail: "medium", + }, + }, + { + type: "image_url", + imageUrl: { + url: "https://mistral.ai/images/news/pixtral-12b/gdp.png", + }, + }, + ] as MistralAIContentChunk[]; + + expect(_mistralContentChunkToMessageContentComplex(mistralMessages)).toEqual([ + { + type: "text", + text: "Test message", + }, + { + type: "image_url", + image_url: "https://mistral.ai/images/news/pixtral-12b/gdp.png", + }, + { + type: "image_url", + image_url: { + url: "https://mistral.ai/images/news/pixtral-12b/gdp.png", + detail: "high", + }, + }, + { + type: "image_url", + image_url: { + url: "https://mistral.ai/images/news/pixtral-12b/gdp.png", + }, + }, + { + type: "image_url", + image_url: { + url: "https://mistral.ai/images/news/pixtral-12b/gdp.png", + }, + }, + ]); +}); + +test("Test ChatMistralAI can register BeforeRequestHook function", async () => { + const model = new ChatMistralAI({ + model: "mistral-tiny", + }); + const prompt = ChatPromptTemplate.fromMessages([ + ["system", "You are a helpful assistant"], + ["human", "{input}"], + ]); + + let count = 0; + const addCount = () => { + count += 1; + }; + + const beforeRequestHook = (): void => { + addCount(); + }; + model.beforeRequestHooks = [beforeRequestHook]; + model.addAllHooksToHttpClient(); + + await prompt.pipe(model).invoke({ + input: "Hello", + }); + // console.log(count); + expect(count).toEqual(1); +}); + +test("Test ChatMistralAI can register RequestErrorHook function", async () => { + const fetcher = (): Promise => + Promise.reject(new Error("Intended fetcher error")); + const customHttpClient = new HTTPClient({ fetcher }); + + const model = new ChatMistralAI({ + model: "mistral-tiny", + httpClient: customHttpClient, + maxRetries: 0, + }); + const prompt = ChatPromptTemplate.fromMessages([ + ["system", 
"You are a helpful assistant"], + ["human", "{input}"], + ]); + + let count = 0; + const addCount = () => { + count += 1; + }; + + const RequestErrorHook = (): void => { + addCount(); + console.log("In request error hook"); + }; + model.requestErrorHooks = [RequestErrorHook]; + model.addAllHooksToHttpClient(); + + try { + await prompt.pipe(model).invoke({ + input: "Hello", + }); + } catch (e: unknown) { + // Intended error, do not rethrow + } + + // console.log(count); + expect(count).toEqual(1); +}); + +test("Test ChatMistralAI can register ResponseHook function", async () => { + const model = new ChatMistralAI({ + model: "mistral-tiny", + }); + const prompt = ChatPromptTemplate.fromMessages([ + ["system", "You are a helpful assistant"], + ["human", "{input}"], + ]); + + let count = 0; + const addCount = () => { + count += 1; + }; + + const ResponseHook = (): void => { + addCount(); + }; + model.responseHooks = [ResponseHook]; + model.addAllHooksToHttpClient(); + + await prompt.pipe(model).invoke({ + input: "Hello", + }); + // console.log(count); + expect(count).toEqual(1); +}); + +test("Test ChatMistralAI can register multiple hook functions with success", async () => { + const model = new ChatMistralAI({ + model: "mistral-tiny", + }); + const prompt = ChatPromptTemplate.fromMessages([ + ["system", "You are a helpful assistant"], + ["human", "{input}"], + ]); + + let count = 0; + const addCount = () => { + count += 1; + }; + + const beforeRequestHook = (): void => { + addCount(); + }; + const ResponseHook = (): void => { + addCount(); + }; + model.beforeRequestHooks = [beforeRequestHook]; + model.responseHooks = [ResponseHook]; + model.addAllHooksToHttpClient(); + + await prompt.pipe(model).invoke({ + input: "Hello", + }); + // console.log(count); + expect(count).toEqual(2); +}); + +test("Test ChatMistralAI can register multiple hook functions with error", async () => { + const fetcher = (): Promise => + Promise.reject(new Error("Intended fetcher error")); + 
const customHttpClient = new HTTPClient({ fetcher }); + + const model = new ChatMistralAI({ + model: "mistral-tiny", + httpClient: customHttpClient, + maxRetries: 0, + }); + + const prompt = ChatPromptTemplate.fromMessages([ + ["system", "You are a helpful assistant"], + ["human", "{input}"], + ]); + + let count = 0; + const addCount = () => { + count += 1; + }; + + const beforeRequestHook = (): void => { + addCount(); + }; + const RequestErrorHook = (): void => { + addCount(); + }; + model.beforeRequestHooks = [beforeRequestHook]; + model.requestErrorHooks = [RequestErrorHook]; + model.addAllHooksToHttpClient(); + + try { + await prompt.pipe(model).invoke({ + input: "Hello", + }); + } catch (e: unknown) { + // Intended error, do not rethrow + } + // console.log(count); + expect(count).toEqual(2); +}); + +test("Test ChatMistralAI can remove hook", async () => { + const model = new ChatMistralAI({ + model: "mistral-tiny", + }); + const prompt = ChatPromptTemplate.fromMessages([ + ["system", "You are a helpful assistant"], + ["human", "{input}"], + ]); + + let count = 0; + const addCount = () => { + count += 1; + }; + + const beforeRequestHook = (): void => { + addCount(); + }; + model.beforeRequestHooks = [beforeRequestHook]; + model.addAllHooksToHttpClient(); + + await prompt.pipe(model).invoke({ + input: "Hello", + }); + // console.log(count); + expect(count).toEqual(1); + + model.removeHookFromHttpClient(beforeRequestHook); + + await prompt.pipe(model).invoke({ + input: "Hello", + }); + // console.log(count); + expect(count).toEqual(1); +}); + +test("Test ChatMistralAI can remove all hooks", async () => { + const model = new ChatMistralAI({ + model: "mistral-tiny", + }); + const prompt = ChatPromptTemplate.fromMessages([ + ["system", "You are a helpful assistant"], + ["human", "{input}"], + ]); + + let count = 0; + const addCount = () => { + count += 1; + }; + + const beforeRequestHook = (): void => { + addCount(); + }; + const ResponseHook = (): void => { + 
addCount(); + }; + model.beforeRequestHooks = [beforeRequestHook]; + model.responseHooks = [ResponseHook]; + model.addAllHooksToHttpClient(); + + await prompt.pipe(model).invoke({ + input: "Hello", + }); + // console.log(count); + expect(count).toEqual(2); + + model.removeAllHooksFromHttpClient(); + + await prompt.pipe(model).invoke({ + input: "Hello", + }); + // console.log(count); + expect(count).toEqual(2); +}); diff --git a/libs/langchain-mistralai/src/tests/chat_models.standard.int.test.ts b/libs/langchain-mistralai/src/tests/chat_models.standard.int.test.ts index a17ff1085925a..985cc5650b8d9 100644 --- a/libs/langchain-mistralai/src/tests/chat_models.standard.int.test.ts +++ b/libs/langchain-mistralai/src/tests/chat_models.standard.int.test.ts @@ -24,11 +24,11 @@ class ChatMistralAIStandardIntegrationTests extends ChatModelIntegrationTests< }); } - async testCacheComplexMessageTypes() { + async testToolMessageHistoriesListContent() { this.skipTestMessage( - "testCacheComplexMessageTypes", + "testToolMessageHistoriesListContent", "ChatMistralAI", - "Complex message types not properly implemented" + "tool_use message blocks not supported" ); } } diff --git a/libs/langchain-mistralai/src/tests/chat_models.test.ts b/libs/langchain-mistralai/src/tests/chat_models.test.ts index 5577c0f1be918..1d1a894ce1969 100644 --- a/libs/langchain-mistralai/src/tests/chat_models.test.ts +++ b/libs/langchain-mistralai/src/tests/chat_models.test.ts @@ -2,6 +2,7 @@ import { ChatMistralAI } from "../chat_models.js"; import { _isValidMistralToolCallId, _convertToolCallIdToMistralCompatible, + _mistralContentChunkToMessageContentComplex, } from "../utils.js"; describe("Mistral Tool Call ID Conversion", () => { diff --git a/libs/langchain-mistralai/src/tests/embeddings.int.test.ts b/libs/langchain-mistralai/src/tests/embeddings.int.test.ts index 103218b5de332..3f2805bbe51e6 100644 --- a/libs/langchain-mistralai/src/tests/embeddings.int.test.ts +++ 
b/libs/langchain-mistralai/src/tests/embeddings.int.test.ts @@ -1,4 +1,5 @@ import { test } from "@jest/globals"; +import { HTTPClient } from "@mistralai/mistralai/lib/http.js"; import { MistralAIEmbeddings } from "../embeddings.js"; test("Test MistralAIEmbeddings can embed query", async () => { @@ -21,3 +22,196 @@ test("Test MistralAIEmbeddings can embed documents", async () => { expect(embeddings[0].length).toBe(1024); expect(embeddings[1].length).toBe(1024); }); + +test("Test MistralAIEmbeddings can register BeforeRequestHook function", async () => { + const model = new MistralAIEmbeddings({ + model: "mistral-embed", + }); + + let count = 0; + const addCount = () => { + count += 1; + }; + + const beforeRequestHook = (): void => { + addCount(); + }; + model.beforeRequestHooks = [beforeRequestHook]; + model.addAllHooksToHttpClient(); + + await model.embedQuery("Hello"); + // console.log(count); + expect(count).toEqual(1); +}); + +test("Test MistralAIEmbeddings can register RequestErrorHook function", async () => { + const fetcher = (): Promise => + Promise.reject(new Error("Intended fetcher error")); + const customHttpClient = new HTTPClient({ fetcher }); + + const model = new MistralAIEmbeddings({ + model: "mistral-embed", + httpClient: customHttpClient, + maxRetries: 0, + }); + + let count = 0; + const addCount = () => { + count += 1; + }; + + const RequestErrorHook = (): void => { + addCount(); + console.log("In request error hook"); + }; + model.requestErrorHooks = [RequestErrorHook]; + model.addAllHooksToHttpClient(); + + try { + await model.embedQuery("Hello"); + } catch (e: unknown) { + // Intended error, do not rethrow + } + + // console.log(count); + expect(count).toEqual(1); +}); + +test("Test MistralAIEmbeddings can register ResponseHook function", async () => { + const model = new MistralAIEmbeddings({ + model: "mistral-embed", + }); + + let count = 0; + const addCount = () => { + count += 1; + }; + + const ResponseHook = (): void => { + addCount(); + 
}; + model.responseHooks = [ResponseHook]; + model.addAllHooksToHttpClient(); + + await model.embedQuery("Hello"); + // console.log(count); + expect(count).toEqual(1); +}); + +test("Test MistralAIEmbeddings can register multiple hook functions with success", async () => { + const model = new MistralAIEmbeddings({ + model: "mistral-embed", + }); + + let count = 0; + const addCount = () => { + count += 1; + }; + + const beforeRequestHook = (): void => { + addCount(); + }; + const ResponseHook = (): void => { + addCount(); + }; + model.beforeRequestHooks = [beforeRequestHook]; + model.responseHooks = [ResponseHook]; + model.addAllHooksToHttpClient(); + + await model.embedQuery("Hello"); + // console.log(count); + expect(count).toEqual(2); +}); + +test("Test MistralAIEmbeddings can register multiple hook functions with error", async () => { + const fetcher = (): Promise => + Promise.reject(new Error("Intended fetcher error")); + const customHttpClient = new HTTPClient({ fetcher }); + + const model = new MistralAIEmbeddings({ + model: "mistral-embed", + httpClient: customHttpClient, + maxRetries: 0, + }); + + let count = 0; + const addCount = () => { + count += 1; + }; + + const beforeRequestHook = (): void => { + addCount(); + }; + const RequestErrorHook = (): void => { + addCount(); + }; + model.beforeRequestHooks = [beforeRequestHook]; + model.requestErrorHooks = [RequestErrorHook]; + model.addAllHooksToHttpClient(); + + try { + await model.embedQuery("Hello"); + } catch (e: unknown) { + // Intended error, do not rethrow + } + // console.log(count); + expect(count).toEqual(2); +}); + +test("Test MistralAIEmbeddings can remove hook", async () => { + const model = new MistralAIEmbeddings({ + model: "mistral-embed", + }); + + let count = 0; + const addCount = () => { + count += 1; + }; + + const beforeRequestHook = (): void => { + addCount(); + }; + model.beforeRequestHooks = [beforeRequestHook]; + model.addAllHooksToHttpClient(); + + await model.embedQuery("Hello"); + 
// console.log(count); + expect(count).toEqual(1); + + model.removeHookFromHttpClient(beforeRequestHook); + + await model.embedQuery("Hello"); + // console.log(count); + expect(count).toEqual(1); +}); + +test("Test MistralAIEmbeddings can remove all hooks", async () => { + const model = new MistralAIEmbeddings({ + model: "mistral-embed", + }); + + let count = 0; + const addCount = () => { + count += 1; + }; + + const beforeRequestHook = (): void => { + addCount(); + }; + const ResponseHook = (): void => { + addCount(); + }; + model.beforeRequestHooks = [beforeRequestHook]; + model.responseHooks = [ResponseHook]; + model.addAllHooksToHttpClient(); + + await model.embedQuery("Hello"); + // console.log(count); + expect(count).toEqual(2); + + model.removeAllHooksFromHttpClient(); + + await model.embedQuery("Hello"); + // console.log(count); + expect(count).toEqual(2); +}); diff --git a/libs/langchain-mistralai/src/tests/llms.int.test.ts b/libs/langchain-mistralai/src/tests/llms.int.test.ts index 6fc263a0eb000..5317b2d760ce3 100644 --- a/libs/langchain-mistralai/src/tests/llms.int.test.ts +++ b/libs/langchain-mistralai/src/tests/llms.int.test.ts @@ -2,12 +2,13 @@ import { test, expect } from "@jest/globals"; import { CallbackManager } from "@langchain/core/callbacks/manager"; +import { HTTPClient } from "@mistralai/mistralai/lib/http.js"; import { MistralAI } from "../llms.js"; // Save the original value of the 'LANGCHAIN_CALLBACKS_BACKGROUND' environment variable const originalBackground = process.env.LANGCHAIN_CALLBACKS_BACKGROUND; -test("Test MistralAI", async () => { +test("Test MistralAI default", async () => { const model = new MistralAI({ maxTokens: 5, model: "codestral-latest", @@ -173,3 +174,196 @@ test("Test MistralAI stream method with early break", async () => { } expect(i).toBeGreaterThan(5); }); + +test("Test MistralAI can register BeforeRequestHook function", async () => { + const model = new MistralAI({ + model: "codestral-latest", + }); + + let count = 
0; + const addCount = () => { + count += 1; + }; + + const beforeRequestHook = (): void => { + addCount(); + }; + model.beforeRequestHooks = [beforeRequestHook]; + model.addAllHooksToHttpClient(); + + await model.invoke("Log 'Hello world' to the console in javascript: ."); + // console.log(count); + expect(count).toEqual(1); +}); + +test("Test MistralAI can register RequestErrorHook function", async () => { + const fetcher = (): Promise => + Promise.reject(new Error("Intended fetcher error")); + const customHttpClient = new HTTPClient({ fetcher }); + + const model = new MistralAI({ + model: "codestral-latest", + httpClient: customHttpClient, + maxRetries: 0, + }); + + let count = 0; + const addCount = () => { + count += 1; + }; + + const RequestErrorHook = (): void => { + addCount(); + console.log("In request error hook"); + }; + model.requestErrorHooks = [RequestErrorHook]; + model.addAllHooksToHttpClient(); + + try { + await model.invoke("Log 'Hello world' to the console in javascript: ."); + } catch (e: unknown) { + // Intended error, do not rethrow + } + + // console.log(count); + expect(count).toEqual(1); +}); + +test("Test MistralAI can register ResponseHook function", async () => { + const model = new MistralAI({ + model: "codestral-latest", + }); + + let count = 0; + const addCount = () => { + count += 1; + }; + + const ResponseHook = (): void => { + addCount(); + }; + model.responseHooks = [ResponseHook]; + model.addAllHooksToHttpClient(); + + await model.invoke("Log 'Hello world' to the console in javascript: ."); + // console.log(count); + expect(count).toEqual(1); +}); + +test("Test MistralAI can register multiple hook functions with success", async () => { + const model = new MistralAI({ + model: "codestral-latest", + }); + + let count = 0; + const addCount = () => { + count += 1; + }; + + const beforeRequestHook = (): void => { + addCount(); + }; + const ResponseHook = (): void => { + addCount(); + }; + model.beforeRequestHooks = [beforeRequestHook]; 
+ model.responseHooks = [ResponseHook]; + model.addAllHooksToHttpClient(); + + await model.invoke("Log 'Hello world' to the console in javascript: "); + // console.log(count); + expect(count).toEqual(2); +}); + +test("Test MistralAI can register multiple hook functions with error", async () => { + const fetcher = (): Promise => + Promise.reject(new Error("Intended fetcher error")); + const customHttpClient = new HTTPClient({ fetcher }); + + const model = new MistralAI({ + model: "codestral-latest", + httpClient: customHttpClient, + maxRetries: 0, + }); + + let count = 0; + const addCount = () => { + count += 1; + }; + + const beforeRequestHook = (): void => { + addCount(); + }; + const RequestErrorHook = (): void => { + addCount(); + }; + model.beforeRequestHooks = [beforeRequestHook]; + model.requestErrorHooks = [RequestErrorHook]; + model.addAllHooksToHttpClient(); + + try { + await model.invoke("Log 'Hello world' to the console in javascript: "); + } catch (e: unknown) { + // Intended error, do not rethrow + } + // console.log(count); + expect(count).toEqual(2); +}); + +test("Test MistralAI can remove hook", async () => { + const model = new MistralAI({ + model: "codestral-latest", + }); + + let count = 0; + const addCount = () => { + count += 1; + }; + + const beforeRequestHook = (): void => { + addCount(); + }; + model.beforeRequestHooks = [beforeRequestHook]; + model.addAllHooksToHttpClient(); + + await model.invoke("Log 'Hello world' to the console in javascript: "); + // console.log(count); + expect(count).toEqual(1); + + model.removeHookFromHttpClient(beforeRequestHook); + + await model.invoke("Log 'Hello world' to the console in javascript: "); + // console.log(count); + expect(count).toEqual(1); +}); + +test("Test MistralAI can remove all hooks", async () => { + const model = new MistralAI({ + model: "codestral-latest", + }); + + let count = 0; + const addCount = () => { + count += 1; + }; + + const beforeRequestHook = (): void => { + addCount(); + }; + 
const ResponseHook = (): void => { + addCount(); + }; + model.beforeRequestHooks = [beforeRequestHook]; + model.responseHooks = [ResponseHook]; + model.addAllHooksToHttpClient(); + + await model.invoke("Log 'Hello world' to the console in javascript: "); + // console.log(count); + expect(count).toEqual(2); + + model.removeAllHooksFromHttpClient(); + + await model.invoke("Log 'Hello world' to the console in javascript: "); + // console.log(count); + expect(count).toEqual(2); +}); diff --git a/libs/langchain-mistralai/src/utils.ts b/libs/langchain-mistralai/src/utils.ts index 193efb570555d..9c94b6558602b 100644 --- a/libs/langchain-mistralai/src/utils.ts +++ b/libs/langchain-mistralai/src/utils.ts @@ -1,3 +1,6 @@ +import { ContentChunk as MistralAIContentChunk } from "@mistralai/mistralai/models/components/contentchunk.js"; +import { MessageContentComplex } from "@langchain/core/messages"; + // Mistral enforces a specific pattern for tool call IDs const TOOL_CALL_ID_PATTERN = /^[a-zA-Z0-9]{9}$/; @@ -44,3 +47,40 @@ export function _convertToolCallIdToMistralCompatible( } } } + +export function _mistralContentChunkToMessageContentComplex( + content: string | MistralAIContentChunk[] | null | undefined +): string | MessageContentComplex[] { + if (!content) { + return ""; + } + if (typeof content === "string") { + return content; + } + return content.map((contentChunk) => { + // Only Mistral ImageURLChunks need conversion to MessageContentComplex + if (contentChunk.type === "image_url") { + if ( + typeof contentChunk.imageUrl !== "string" && + contentChunk.imageUrl?.detail + ) { + const { detail } = contentChunk.imageUrl; + // Mistral detail can be any string, but MessageContentComplex only supports + // detail to be "high", "auto", or "low" + if (detail !== "high" && detail !== "auto" && detail !== "low") { + return { + type: contentChunk.type, + image_url: { + url: contentChunk.imageUrl.url, + }, + }; + } + } + return { + type: contentChunk.type, + image_url: 
contentChunk.imageUrl, + }; + } + return contentChunk; + }); +} diff --git a/yarn.lock b/yarn.lock index ee400067548ae..5fdd9da4ebeb4 100644 --- a/yarn.lock +++ b/yarn.lock @@ -12375,7 +12375,7 @@ __metadata: "@langchain/core": "workspace:*" "@langchain/scripts": ">=0.1.0 <0.2.0" "@langchain/standard-tests": 0.0.0 - "@mistralai/mistralai": ^0.4.0 + "@mistralai/mistralai": ^1.3.1 "@swc/core": ^1.3.90 "@swc/jest": ^0.2.29 "@tsconfig/recommended": ^1.0.3 @@ -12397,10 +12397,10 @@ __metadata: ts-jest: ^29.1.0 typescript: <5.2.0 uuid: ^10.0.0 - zod: ^3.22.4 + zod: ^3.23.8 zod-to-json-schema: ^3.22.4 peerDependencies: - "@langchain/core": ">=0.2.21 <0.4.0" + "@langchain/core": ">=0.3.7 <0.4.0" languageName: unknown linkType: soft @@ -13150,12 +13150,12 @@ __metadata: languageName: node linkType: hard -"@mistralai/mistralai@npm:^0.4.0": - version: 0.4.0 - resolution: "@mistralai/mistralai@npm:0.4.0" - dependencies: - node-fetch: ^2.6.7 - checksum: 1b03fc0b55164c02e5fb29fb2d09ebe4ad44346fc313f7fb3ab09e48f73f975763d1ac9654098d433ea17d7caa20654b2b15510822276acc9fa46db461a254a6 +"@mistralai/mistralai@npm:^1.3.1": + version: 1.3.1 + resolution: "@mistralai/mistralai@npm:1.3.1" + peerDependencies: + zod: ">= 3" + checksum: 9e31a2f760706a9f54347ba2cb2b7784d4f93eb4ff5d87cc7cfac9b7a1a1816f21da2328f5f5e13c11ed8953f1d71f2a2e09d12123ac17d171c189d21b87a977 languageName: node linkType: hard From a3cd5935e4b52335e29270f39ece7d26d023a29c Mon Sep 17 00:00:00 2001 From: Jacob Lee Date: Thu, 21 Nov 2024 20:52:47 -0800 Subject: [PATCH 06/27] chore(mistral): Release 0.2.0 (#7239) --- libs/langchain-mistralai/package.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/langchain-mistralai/package.json b/libs/langchain-mistralai/package.json index 7876ea753710f..d2235871d8c93 100644 --- a/libs/langchain-mistralai/package.json +++ b/libs/langchain-mistralai/package.json @@ -1,6 +1,6 @@ { "name": "@langchain/mistralai", - "version": "0.1.1", + "version": "0.2.0", 
"description": "MistralAI integration for LangChain.js", "type": "module", "engines": { From 9b70c5e16ace0237736125bf69023a296457b35a Mon Sep 17 00:00:00 2001 From: crisjy Date: Fri, 22 Nov 2024 13:14:42 +0800 Subject: [PATCH 07/27] feat(azurecosmosdb): Vector Store Add DiskANN index for CosmosDB (#7225) Co-authored-by: jacoblee93 --- .../src/azure_cosmosdb_mongodb.ts | 70 ++++++++++++++----- 1 file changed, 53 insertions(+), 17 deletions(-) diff --git a/libs/langchain-azure-cosmosdb/src/azure_cosmosdb_mongodb.ts b/libs/langchain-azure-cosmosdb/src/azure_cosmosdb_mongodb.ts index 805d7417e9c05..b21fc4b85c791 100644 --- a/libs/langchain-azure-cosmosdb/src/azure_cosmosdb_mongodb.ts +++ b/libs/langchain-azure-cosmosdb/src/azure_cosmosdb_mongodb.ts @@ -34,7 +34,7 @@ export type AzureCosmosDBMongoDBIndexOptions = { /** Skips automatic index creation. */ readonly skipCreate?: boolean; - readonly indexType?: "ivf" | "hnsw"; + readonly indexType?: "ivf" | "hnsw" | "diskann"; /** Number of clusters that the inverted file (IVF) index uses to group the vector data. */ readonly numLists?: number; /** Number of dimensions for vector similarity. */ @@ -45,6 +45,12 @@ export type AzureCosmosDBMongoDBIndexOptions = { readonly m?: number; /** The size of the dynamic candidate list for constructing the graph with the HNSW index. */ readonly efConstruction?: number; + /** Max number of neighbors withe the Diskann idnex */ + readonly maxDegree?: number; + /** L value for index building withe the Diskann idnex */ + readonly lBuild?: number; + /** L value for index searching withe the Diskann idnex */ + readonly lSearch?: number; }; /** Azure Cosmos DB for MongoDB vCore delete Parameters. 
*/ @@ -234,7 +240,7 @@ export class AzureCosmosDBMongoDBVectorStore extends VectorStore { */ async createIndex( dimensions: number | undefined = undefined, - indexType: "ivf" | "hnsw" = "ivf", + indexType: "ivf" | "hnsw" | "diskann" = "ivf", similarity: AzureCosmosDBMongoDBSimilarityType = AzureCosmosDBMongoDBSimilarityType.COS ): Promise { await this.connectPromise; @@ -246,23 +252,36 @@ export class AzureCosmosDBMongoDBVectorStore extends VectorStore { vectorLength = queryEmbedding.length; } + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const cosmosSearchOptions: any = { + kind: "", + similarity, + dimensions: vectorLength, + }; + + if (indexType === "hnsw") { + cosmosSearchOptions.kind = "vector-hnsw"; + cosmosSearchOptions.m = this.indexOptions.m ?? 16; + cosmosSearchOptions.efConstruction = + this.indexOptions.efConstruction ?? 200; + } else if (indexType === "diskann") { + cosmosSearchOptions.kind = "vector-diskann"; + cosmosSearchOptions.maxDegree = this.indexOptions.maxDegree ?? 40; + cosmosSearchOptions.lBuild = this.indexOptions.lBuild ?? 50; + cosmosSearchOptions.lSearch = this.indexOptions.lSearch ?? 40; + /** Default to IVF index */ + } else { + cosmosSearchOptions.kind = "vector-ivf"; + cosmosSearchOptions.numLists = this.indexOptions.numLists ?? 100; + } + const createIndexCommands = { createIndexes: this.collection.collectionName, indexes: [ { name: this.indexName, key: { [this.embeddingKey]: "cosmosSearch" }, - cosmosSearchOptions: { - kind: indexType === "hnsw" ? "vector-hnsw" : "vector-ivf", - ...(indexType === "hnsw" - ? { - m: this.indexOptions.m ?? 16, - efConstruction: this.indexOptions.efConstruction ?? 200, - } - : { numLists: this.indexOptions.numLists ?? 
100 }), - similarity, - dimensions: vectorLength, - }, + cosmosSearchOptions, }, ], }; @@ -357,7 +376,8 @@ export class AzureCosmosDBMongoDBVectorStore extends VectorStore { */ async similaritySearchVectorWithScore( queryVector: number[], - k = 4 + k: number, + indexType?: "ivf" | "hnsw" | "diskann" ): Promise<[Document, number][]> { await this.initialize(); @@ -367,7 +387,10 @@ export class AzureCosmosDBMongoDBVectorStore extends VectorStore { cosmosSearch: { vector: queryVector, path: this.embeddingKey, - k, + k: k ?? 4, + ...(indexType === "diskann" + ? { lSearch: this.indexOptions.lSearch ?? 40 } + : {}), }, returnStoredSource: true, }, @@ -406,13 +429,26 @@ export class AzureCosmosDBMongoDBVectorStore extends VectorStore { async maxMarginalRelevanceSearch( query: string, options: MaxMarginalRelevanceSearchOptions + ): Promise; + + async maxMarginalRelevanceSearch( + query: string, + options: MaxMarginalRelevanceSearchOptions, + indexType: "ivf" | "hnsw" | "diskann" + ): Promise; + + async maxMarginalRelevanceSearch( + query: string, + options: MaxMarginalRelevanceSearchOptions, + indexType?: "ivf" | "hnsw" | "diskann" ): Promise { const { k, fetchK = 20, lambda = 0.5 } = options; const queryEmbedding = await this.embeddings.embedQuery(query); const docs = await this.similaritySearchVectorWithScore( queryEmbedding, - fetchK + fetchK, + indexType ); const embeddingList = docs.map((doc) => doc[0].metadata[this.embeddingKey]); @@ -449,7 +485,7 @@ export class AzureCosmosDBMongoDBVectorStore extends VectorStore { // Unless skipCreate is set, create the index // This operation is no-op if the index already exists if (!this.indexOptions.skipCreate) { - const indexType = this.indexOptions.indexType === "hnsw" ? 
"hnsw" : "ivf"; + const indexType = this.indexOptions.indexType || "ivf"; await this.createIndex( this.indexOptions.dimensions, indexType, From fc6b925b52cf9f3228c4577c8c2358188914a453 Mon Sep 17 00:00:00 2001 From: Jacob Lee Date: Thu, 21 Nov 2024 21:20:14 -0800 Subject: [PATCH 08/27] chore(azure-cosmosdb): Release 0.2.3 (#7241) --- libs/langchain-azure-cosmosdb/package.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/langchain-azure-cosmosdb/package.json b/libs/langchain-azure-cosmosdb/package.json index df04442e5ed73..7cd8a1cd41018 100644 --- a/libs/langchain-azure-cosmosdb/package.json +++ b/libs/langchain-azure-cosmosdb/package.json @@ -1,6 +1,6 @@ { "name": "@langchain/azure-cosmosdb", - "version": "0.2.2", + "version": "0.2.3", "description": "Azure CosmosDB integration for LangChain.js", "type": "module", "engines": { From 7fd5667bc3599c22ece7af47b969c42160f4bd72 Mon Sep 17 00:00:00 2001 From: Brace Sproul Date: Fri, 22 Nov 2024 11:47:18 -0800 Subject: [PATCH 09/27] fix(community): bedrock parsing array content/tool blocks (#7244) --- .../src/utils/bedrock/anthropic.ts | 33 +++++++++++-------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/libs/langchain-community/src/utils/bedrock/anthropic.ts b/libs/langchain-community/src/utils/bedrock/anthropic.ts index 3f440bd2b0142..e8c2ab16306ba 100644 --- a/libs/langchain-community/src/utils/bedrock/anthropic.ts +++ b/libs/langchain-community/src/utils/bedrock/anthropic.ts @@ -121,7 +121,7 @@ function _formatContent(content: MessageContent) { if (typeof content === "string") { return content; } else { - const contentBlocks = content.map((contentPart) => { + const contentBlocks = content.flatMap((contentPart) => { if (contentPart.type === "image_url") { let source; if (typeof contentPart.image_url === "string") { @@ -133,7 +133,13 @@ function _formatContent(content: MessageContent) { type: "image" as const, // Explicitly setting the type as "image" source, }; - } else if 
(contentPart.type === "text") { + } else if ( + contentPart.type === "text" || + contentPart.type === "text_delta" + ) { + if (contentPart.text === "") { + return []; + } // Assuming contentPart is of type MessageContentText here return { type: "text" as const, // Explicitly setting the type as "text" @@ -148,6 +154,8 @@ function _formatContent(content: MessageContent) { ...contentPart, // eslint-disable-next-line @typescript-eslint/no-explicit-any } as any; + } else if (contentPart.type === "input_json_delta") { + return []; } else { throw new Error("Unsupported message content format"); } @@ -204,21 +212,20 @@ export function formatMessagesForAnthropic(messages: BaseMessage[]): { }; } } else { - const { content } = message; - const hasMismatchedToolCalls = !message.tool_calls.every((toolCall) => - content.find( - (contentPart) => - contentPart.type === "tool_use" && contentPart.id === toolCall.id - ) - ); - if (hasMismatchedToolCalls) { - console.warn( - `The "tool_calls" field on a message is only respected if content is a string.` + const formattedContent = _formatContent(message.content); + if (Array.isArray(formattedContent)) { + const formattedToolsContent = message.tool_calls.map( + _convertLangChainToolCallToAnthropic ); + return { + role, + content: [...formattedContent, ...formattedToolsContent], + }; } + return { role, - content: _formatContent(message.content), + content: formattedContent, }; } } else { From cb2c42c4d49a9492ba023dd2d954769579dc9cba Mon Sep 17 00:00:00 2001 From: Brace Sproul Date: Fri, 22 Nov 2024 15:04:47 -0800 Subject: [PATCH 10/27] fix(ci): Fix release workspace (#7245) --- release_workspace.js | 131 +++++++++++++++++++++++++++++-------------- 1 file changed, 90 insertions(+), 41 deletions(-) diff --git a/release_workspace.js b/release_workspace.js index 9e127961a1e74..419cbd5fa9d63 100644 --- a/release_workspace.js +++ b/release_workspace.js @@ -9,6 +9,26 @@ const semver = require("semver"); const RELEASE_BRANCH = "release"; const 
MAIN_BRANCH = "main"; +/** + * Handles execSync errors and logs them in a readable format. + * @param {string} command + * @param {{ doNotExit?: boolean }} [options] - Optional configuration + * @param {boolean} [options.doNotExit] - Whether or not to exit the process on error + */ +function execSyncWithErrorHandling(command, options = {}) { + try { + execSync( + command, + { stdio: "inherit" } // This will stream output in real-time + ); + } catch (error) { + console.error(error.message); + if (!options.doNotExit) { + process.exit(1); + } + } +} + /** * Get the version of a workspace inside a directory. * @@ -131,33 +151,22 @@ async function runYarnRelease(packageDirectory, npm2FACode, tag) { console.log(`Running command: "yarn ${args.join(" ")}"`); - const yarnReleaseProcess = spawn("yarn", args, { cwd: workingDirectory }); - - let stdout = ""; - let stderr = ""; - - yarnReleaseProcess.stdout.on("data", (data) => { - stdout += data; - // Still show output in real-time - process.stdout.write(data); - }); - - yarnReleaseProcess.stderr.on("data", (data) => { - stderr += data; - // Still show errors in real-time - process.stderr.write(data); + // Use 'inherit' for stdio to allow direct CLI interaction + const yarnReleaseProcess = spawn("yarn", args, { + stdio: "inherit", + cwd: workingDirectory, }); yarnReleaseProcess.on("close", (code) => { if (code === 0) { resolve(); } else { - reject(`Process exited with code ${code}.\nError: ${stderr}`); + reject(`Process exited with code ${code}`); } }); yarnReleaseProcess.on("error", (err) => { - reject(`Failed to start process: ${err.message}\nError: ${stderr}`); + reject(`Failed to start process: ${err.message}`); }); }); } @@ -194,7 +203,7 @@ function bumpDeps( console.log( "Updated version is not greater than the pre-release version. Pulling from github and checking again." 
); - execSync(`git pull origin ${RELEASE_BRANCH}`); + execSyncWithErrorHandling(`git pull origin ${RELEASE_BRANCH}`); updatedWorkspaceVersion = getWorkspaceVersion(workspaceDirectory); if (!semver.gt(updatedWorkspaceVersion, preReleaseVersion)) { console.warn( @@ -213,10 +222,10 @@ function bumpDeps( versionString = `${updatedWorkspaceVersion}-${tag}`; } - execSync(`git checkout ${MAIN_BRANCH}`); + execSyncWithErrorHandling(`git checkout ${MAIN_BRANCH}`); const newBranchName = `bump-${workspaceName}-to-${versionString}`; console.log(`Checking out new branch: ${newBranchName}`); - execSync(`git checkout -b ${newBranchName}`); + execSyncWithErrorHandling(`git checkout -b ${newBranchName}`); const allWorkspacesWhichDependOn = allWorkspaces.filter(({ packageJSON }) => Object.keys(packageJSON.dependencies ?? {}).includes(workspaceName) @@ -268,7 +277,7 @@ Workspaces: console.log("Updated package.json's! Running yarn install."); try { - execSync(`yarn install`); + execSyncWithErrorHandling(`yarn install`); } catch (_) { console.log( "Yarn install failed. Likely because NPM has not finished publishing the new version. Continuing." @@ -277,12 +286,12 @@ Workspaces: // Add all current changes, commit, push and log branch URL. 
console.log("Adding and committing all changes."); - execSync(`git add -A`); - execSync( + execSyncWithErrorHandling(`git add -A`); + execSyncWithErrorHandling( `git commit -m "all[minor]: bump deps on ${workspaceName} to ${versionString}"` ); console.log("Pushing changes."); - execSync(`git push -u origin ${newBranchName}`); + execSyncWithErrorHandling(`git push -u origin ${newBranchName}`); console.log( "🔗 Open %s and merge the bump-deps PR.", `\x1b[34mhttps://github.com/langchain-ai/langchainjs/compare/${newBranchName}?expand=1\x1b[0m` @@ -299,7 +308,8 @@ Workspaces: * @param {string} version */ function createCommitMessage(workspaceName, version) { - return `release(${workspaceName}): ${version}`; + const cleanedWorkspaceName = workspaceName.replace("@langchain/", ""); + return `release(${cleanedWorkspaceName}): ${version}`; } /** @@ -307,15 +317,28 @@ function createCommitMessage(workspaceName, version) { * * @param {string} workspaceName The name of the workspace being released * @param {string} version The new version being released + * @param {boolean} onlyPush Whether or not to only push the changes, and not commit * @returns {void} */ -function commitAndPushChanges(workspaceName, version) { - console.log("Committing and pushing changes..."); - const commitMsg = createCommitMessage(workspaceName, version); - execSync("git add -A"); - execSync(`git commit -m "${commitMsg}"`); +function commitAndPushChanges(workspaceName, version, onlyPush) { + if (!onlyPush) { + console.log("Committing changes..."); + const commitMsg = createCommitMessage(workspaceName, version); + try { + execSyncWithErrorHandling("git add -A", { doNotExit: true }); + execSyncWithErrorHandling(`git commit -m "${commitMsg}"`, { + doNotExit: true, + }); + } catch (_) { + // No-op. Likely erroring because there are no unstaged changes. 
+ } + } + + console.log("Pushing changes..."); // Pushes to the current branch - execSync("git push -u origin $(git rev-parse --abbrev-ref HEAD)"); + execSyncWithErrorHandling( + "git push -u origin $(git rev-parse --abbrev-ref HEAD)" + ); console.log("Successfully committed and pushed changes."); } @@ -330,8 +353,8 @@ function checkoutReleaseBranch() { const currentBranch = execSync("git branch --show-current").toString().trim(); if (currentBranch === MAIN_BRANCH || currentBranch === RELEASE_BRANCH) { console.log(`Checking out '${RELEASE_BRANCH}' branch.`); - execSync(`git checkout -B ${RELEASE_BRANCH}`); - execSync(`git push -u origin ${RELEASE_BRANCH}`); + execSyncWithErrorHandling(`git checkout -B ${RELEASE_BRANCH}`); + execSyncWithErrorHandling(`git push -u origin ${RELEASE_BRANCH}`); } else { throw new Error( `Current branch is not ${MAIN_BRANCH} or ${RELEASE_BRANCH}. Current branch: ${currentBranch}` @@ -361,15 +384,34 @@ async function getUserInput(question) { } /** - * Checks if there are any uncommitted changes in the git repository + * Checks if there are any uncommitted changes in the git repository. * * @returns {boolean} True if there are uncommitted changes, false otherwise */ function hasUncommittedChanges() { try { - // This command returns empty string if no changes, or a string with changes if there are any - const output = execSync("git status --porcelain").toString(); - return output.length > 0; + // Check for uncommitted changes (both staged and unstaged) + const uncommittedOutput = execSync("git status --porcelain").toString(); + + return uncommittedOutput.length > 0; + } catch (error) { + console.error("Error checking git status:", error); + // If we can't check, better to assume there are changes + return true; + } +} + +/** + * Checks if there are any staged commits in the git repository. 
+ * + * @returns {boolean} True if there are staged changes, false otherwise + */ +function hasStagedChanges() { + try { + // Check for staged but unpushed changes + const unPushedOutput = execSync("git log '@{u}..'").toString(); + + return unPushedOutput.length > 0; } catch (error) { console.error("Error checking git status:", error); // If we can't check, better to assume there are changes @@ -419,7 +461,7 @@ async function main() { // Run build, lint, tests console.log("Running build, lint, and tests."); - execSync( + execSyncWithErrorHandling( `yarn turbo:command run --filter ${options.workspace} build lint test --concurrency 1` ); console.log("Successfully ran build, lint, and tests."); @@ -433,9 +475,13 @@ async function main() { // Run `release-it` on workspace await runYarnRelease(matchingWorkspace.dir, npm2FACode, options.tag); - if (hasUncommittedChanges()) { + const hasStaged = hasStagedChanges(); + const hasUnCommitted = hasUncommittedChanges(); + if (hasStaged || hasUnCommitted) { const updatedVersion = getWorkspaceVersion(matchingWorkspace.dir); - commitAndPushChanges(options.workspace, updatedVersion); + // Only push and do not commit if there are staged changes and no uncommitted changes + const onlyPush = hasStaged && !hasUnCommitted; + commitAndPushChanges(options.workspace, updatedVersion, onlyPush); } // Log release branch URL @@ -458,4 +504,7 @@ async function main() { } } -main(); +main().catch((error) => { + console.error(error); + process.exit(1); +}); From 4d2bd6363790753a58aae4afd1ff7e4092871812 Mon Sep 17 00:00:00 2001 From: Felipe Martins Diel <41558831+felipediel@users.noreply.github.com> Date: Mon, 25 Nov 2024 12:15:32 -0300 Subject: [PATCH 11/27] feat(community): Incorporate BM25 score in the results (#7236) Co-authored-by: Jacob Lee --- .../src/retrievers/bm25.ts | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/libs/langchain-community/src/retrievers/bm25.ts 
b/libs/langchain-community/src/retrievers/bm25.ts index dfc04709cba19..716027cf684bc 100644 --- a/libs/langchain-community/src/retrievers/bm25.ts +++ b/libs/langchain-community/src/retrievers/bm25.ts @@ -6,6 +6,7 @@ import { BM25 } from "../utils/@furkantoprak/bm25/BM25.js"; export type BM25RetrieverOptions = { docs: Document[]; k: number; + includeScore?: boolean; } & BaseRetrieverInput; /** @@ -14,6 +15,8 @@ export type BM25RetrieverOptions = { * The k parameter determines the number of documents to return for each query. */ export class BM25Retriever extends BaseRetriever { + includeScore = false; + static lc_name() { return "BM25Retriever"; } @@ -35,6 +38,7 @@ export class BM25Retriever extends BaseRetriever { super(options); this.docs = options.docs; this.k = options.k; + this.includeScore = options.includeScore ?? this.includeScore; } private preprocessFunc(text: string): string[] { @@ -53,6 +57,19 @@ export class BM25Retriever extends BaseRetriever { scoredDocs.sort((a, b) => b.score - a.score); - return scoredDocs.slice(0, this.k).map((item) => item.document); + return scoredDocs.slice(0, this.k).map((item) => { + if (this.includeScore) { + return new Document({ + ...(item.document.id && { id: item.document.id }), + pageContent: item.document.pageContent, + metadata: { + bm25Score: item.score, + ...item.document.metadata, + }, + }); + } else { + return item.document; + } + }); } } From de7bcda54a32f97438f3bdd3dbc347fbcd10da12 Mon Sep 17 00:00:00 2001 From: Filip Michalsky <31483888+filip-michalsky@users.noreply.github.com> Date: Mon, 25 Nov 2024 11:28:50 -0500 Subject: [PATCH 12/27] feat(community): Stagehand tools integration (#7177) Co-authored-by: jacoblee93 --- .../docs/integrations/tools/stagehand.mdx | 139 +++++++ examples/package.json | 1 + .../src/agents/stagehand_ai_web_browser.ts | 87 ++++ libs/langchain-community/langchain.config.js | 2 + libs/langchain-community/package.json | 5 + .../src/agents/toolkits/stagehand.ts | 177 ++++++++ 
.../toolkits/tests/stagehand.int.test.ts | 242 +++++++++++ yarn.lock | 381 ++++++++++++++++++ 8 files changed, 1034 insertions(+) create mode 100644 docs/core_docs/docs/integrations/tools/stagehand.mdx create mode 100644 examples/src/agents/stagehand_ai_web_browser.ts create mode 100644 libs/langchain-community/src/agents/toolkits/stagehand.ts create mode 100644 libs/langchain-community/src/agents/toolkits/tests/stagehand.int.test.ts diff --git a/docs/core_docs/docs/integrations/tools/stagehand.mdx b/docs/core_docs/docs/integrations/tools/stagehand.mdx new file mode 100644 index 0000000000000..389181067593c --- /dev/null +++ b/docs/core_docs/docs/integrations/tools/stagehand.mdx @@ -0,0 +1,139 @@ +--- +sidebar_label: Stagehand AI Web Automation Toolkit +hide_table_of_contents: true +--- + +import CodeBlock from "@theme/CodeBlock"; +import Example from "@examples/agents/stagehand_ai_web_browser.ts"; +import { StagehandToolkit } from "@langchain/community/agents/toolkits/stagehand"; + +# Stagehand Toolkit + +The Stagehand Toolkit equips your AI agent with the following capabilities: + +- **navigate()**: Navigate to a specific URL. +- **act()**: Perform browser automation actions like clicking, typing, and navigation. +- **extract()**: Extract structured data from web pages using Zod schemas. +- **observe()**: Get a list of possible actions and elements on the current page. + +## Setup + +1. Install the required packages: + +```bash +npm install @langchain/langgraph @langchain/community @langchain/core +``` + +2. Create a Stagehand Instance + If you plan to run the browser locally, you'll also need to install Playwright's browser dependencies. + +```bash +npx playwright install +``` + +3. 
Set up your model provider credentials: + +For OpenAI: + +```bash +export OPENAI_API_KEY="your-openai-api-key" +``` + +For Anthropic: + +```bash +export ANTHROPIC_API_KEY="your-anthropic-api-key" +``` + +## Usage, Standalone, Local Browser + +```typescript +import { StagehandToolkit } from "langchain/community/agents/toolkits/stagehand"; +import { ChatOpenAI } from "@langchain/openai"; +import { Stagehand } from "@browserbasehq/stagehand"; + +// Specify your Browserbase credentials. +process.env.BROWSERBASE_API_KEY = ""; +process.env.BROWSERBASE_PROJECT_ID = ""; + +// Specify OpenAI API key. +process.env.OPENAI_API_KEY = ""; + +const stagehand = new Stagehand({ + env: "LOCAL", + headless: false, + verbose: 2, + debugDom: true, + enableCaching: false, +}); + +// Create a Stagehand Toolkit with all the available actions from the Stagehand. +const stagehandToolkit = await StagehandToolkit.fromStagehand(stagehand); + +const navigateTool = stagehandToolkit.tools.find( + (t) => t.name === "stagehand_navigate" +); +if (!navigateTool) { + throw new Error("Navigate tool not found"); +} +await navigateTool.invoke("https://www.google.com"); + +const actionTool = stagehandToolkit.tools.find( + (t) => t.name === "stagehand_act" +); +if (!actionTool) { + throw new Error("Action tool not found"); +} +await actionTool.invoke('Search for "OpenAI"'); + +const observeTool = stagehandToolkit.tools.find( + (t) => t.name === "stagehand_observe" +); +if (!observeTool) { + throw new Error("Observe tool not found"); +} +const result = await observeTool.invoke( + "What actions can be performed on the current page?" 
+); +const observations = JSON.parse(result); + +// Handle observations as needed +console.log(observations); + +const currentUrl = stagehand.page.url(); +expect(currentUrl).toContain("google.com/search?q=OpenAI"); +``` + +## Usage with LangGraph Agents + +{Example} + +## Usage on Browserbase - remote headless browser + +If you want to run the browser remotely, you can use the Browserbase platform. + +You need to set the `BROWSERBASE_API_KEY` environment variable to your Browserbase API key. + +```bash +export BROWSERBASE_API_KEY="your-browserbase-api-key" +``` + +You also need to set `BROWSERBASE_PROJECT_ID` to your Browserbase project ID. + +```bash +export BROWSERBASE_PROJECT_ID="your-browserbase-project-id" +``` + +Then initialize the Stagehand instance with the `BROWSERBASE` environment. + +```typescript +const stagehand = new Stagehand({ + env: "BROWSERBASE", +}); +``` + +## Related + +- Tool [conceptual guide](/docs/concepts/tools) +- Tool [how-to guides](/docs/how_to/#tools) +- [Stagehand Documentation](https://github.com/browserbase/stagehand#readme) diff --git a/examples/package.json b/examples/package.json index 75ab6e9eb4d65..792c866804fb5 100644 --- a/examples/package.json +++ b/examples/package.json @@ -25,6 +25,7 @@ "license": "MIT", "dependencies": { "@azure/identity": "^4.2.1", + "@browserbasehq/stagehand": "^1.3.0", "@clickhouse/client": "^0.2.5", "@elastic/elasticsearch": "^8.4.0", "@faker-js/faker": "^8.4.1", diff --git a/examples/src/agents/stagehand_ai_web_browser.ts b/examples/src/agents/stagehand_ai_web_browser.ts new file mode 100644 index 0000000000000..7797baa080f1e --- /dev/null +++ b/examples/src/agents/stagehand_ai_web_browser.ts @@ -0,0 +1,87 @@ +import { Stagehand } from "@browserbasehq/stagehand"; +import { + StagehandActTool, + StagehandNavigateTool, +} from "@langchain/community/agents/toolkits/stagehand"; +import { ChatOpenAI } from "@langchain/openai"; +import { createReactAgent } from "@langchain/langgraph/prebuilt"; + +async 
function main() { + // Initialize Stagehand once and pass it to the tools + const stagehand = new Stagehand({ + env: "LOCAL", + enableCaching: true, + }); + + const actTool = new StagehandActTool(stagehand); + const navigateTool = new StagehandNavigateTool(stagehand); + + // Initialize the model + const model = new ChatOpenAI({ + modelName: "gpt-4", + temperature: 0, + }); + + // Create the agent using langgraph + const agent = createReactAgent({ + llm: model, + tools: [actTool, navigateTool], + }); + + // Execute the agent using streams + const inputs1 = { + messages: [ + { + role: "user", + content: "Navigate to https://www.google.com", + }, + ], + }; + + const stream1 = await agent.stream(inputs1, { + streamMode: "values", + }); + + for await (const { messages } of stream1) { + const msg = + messages && messages.length > 0 + ? messages[messages.length - 1] + : undefined; + if (msg?.content) { + console.log(msg.content); + } else if (msg?.tool_calls && msg.tool_calls.length > 0) { + console.log(msg.tool_calls); + } else { + console.log(msg); + } + } + + const inputs2 = { + messages: [ + { + role: "user", + content: "Search for 'OpenAI'", + }, + ], + }; + + const stream2 = await agent.stream(inputs2, { + streamMode: "values", + }); + + for await (const { messages } of stream2) { + const msg = + messages && messages.length > 0 + ? 
messages[messages.length - 1] + : undefined; + if (msg?.content) { + console.log(msg.content); + } else if (msg?.tool_calls && msg.tool_calls.length > 0) { + console.log(msg.tool_calls); + } else { + console.log(msg); + } + } +} + +main(); diff --git a/libs/langchain-community/langchain.config.js b/libs/langchain-community/langchain.config.js index 631c2a12879e9..7bd552b4dcb3b 100644 --- a/libs/langchain-community/langchain.config.js +++ b/libs/langchain-community/langchain.config.js @@ -65,6 +65,7 @@ export const config = { "agents/toolkits/aws_sfn": "agents/toolkits/aws_sfn", "agents/toolkits/base": "agents/toolkits/base", "agents/toolkits/connery": "agents/toolkits/connery/index", + "agents/toolkits/stagehand": "agents/toolkits/stagehand", // embeddings "embeddings/alibaba_tongyi": "embeddings/alibaba_tongyi", "embeddings/baidu_qianfan": "embeddings/baidu_qianfan", @@ -336,6 +337,7 @@ export const config = { "tools/gmail", "tools/google_calendar", "agents/toolkits/aws_sfn", + "agents/toolkits/stagehand", "callbacks/handlers/llmonitor", "callbacks/handlers/lunary", "callbacks/handlers/upstash_ratelimit", diff --git a/libs/langchain-community/package.json b/libs/langchain-community/package.json index 8de0a386cb340..7758dd43a8b36 100644 --- a/libs/langchain-community/package.json +++ b/libs/langchain-community/package.json @@ -62,6 +62,7 @@ "@azure/search-documents": "^12.0.0", "@azure/storage-blob": "^12.15.0", "@browserbasehq/sdk": "^1.1.5", + "@browserbasehq/stagehand": "^1.0.0", "@clickhouse/client": "^0.2.5", "@cloudflare/ai": "1.0.12", "@cloudflare/workers-types": "^4.20230922.0", @@ -92,6 +93,7 @@ "@notionhq/client": "^2.2.10", "@opensearch-project/opensearch": "^2.2.0", "@planetscale/database": "^1.8.0", + "@playwright/test": "^1.48.2", "@premai/prem-sdk": "^0.3.25", "@qdrant/js-client-rest": "^1.8.2", "@raycast/api": "^1.83.1", @@ -190,6 +192,7 @@ "node-llama-cpp": "3.1.1", "notion-to-md": "^3.1.0", "officeparser": "^4.0.4", + "openai": "*", "pdf-parse": 
"1.1.1", "pg": "^8.11.0", "pg-copy-streams": "^6.0.5", @@ -231,6 +234,7 @@ "@azure/search-documents": "^12.0.0", "@azure/storage-blob": "^12.15.0", "@browserbasehq/sdk": "*", + "@browserbasehq/stagehand": "^1.0.0", "@clickhouse/client": "^0.2.5", "@cloudflare/ai": "*", "@datastax/astra-db-ts": "^1.0.0", @@ -319,6 +323,7 @@ "neo4j-driver": "*", "notion-to-md": "^3.1.0", "officeparser": "^4.0.4", + "openai": "*", "pdf-parse": "1.1.1", "pg": "^8.11.0", "pg-copy-streams": "^6.0.5", diff --git a/libs/langchain-community/src/agents/toolkits/stagehand.ts b/libs/langchain-community/src/agents/toolkits/stagehand.ts new file mode 100644 index 0000000000000..763560cf1b4b4 --- /dev/null +++ b/libs/langchain-community/src/agents/toolkits/stagehand.ts @@ -0,0 +1,177 @@ +import { + Tool, + BaseToolkit as Toolkit, + ToolInterface, + StructuredTool, +} from "@langchain/core/tools"; +import { Stagehand } from "@browserbasehq/stagehand"; +import { AnyZodObject, z } from "zod"; + +// Documentation is here: +// https://js.langchain.com/docs/integrations/tools/stagehand + +abstract class StagehandToolBase extends Tool { + protected stagehand?: Stagehand; + + private localStagehand?: Stagehand; + + constructor(stagehandInstance?: Stagehand) { + super(); + this.stagehand = stagehandInstance; + } + + protected async getStagehand(): Promise { + if (this.stagehand) return this.stagehand; + + if (!this.localStagehand) { + this.localStagehand = new Stagehand({ + env: "LOCAL", + enableCaching: true, + }); + await this.localStagehand.init(); + } + return this.localStagehand; + } +} + +function isErrorWithMessage(error: unknown): error is { message: string } { + return ( + typeof error === "object" && + error !== null && + "message" in error && + typeof (error as { message: unknown }).message === "string" + ); +} + +export class StagehandNavigateTool extends StagehandToolBase { + name = "stagehand_navigate"; + + description = + "Use this tool to navigate to a specific URL using Stagehand. 
The input should be a valid URL as a string."; + + async _call(input: string): Promise { + const stagehand = await this.getStagehand(); + try { + await stagehand.page.goto(input); + return `Successfully navigated to ${input}.`; + } catch (error: unknown) { + const message = isErrorWithMessage(error) ? error.message : String(error); + return `Failed to navigate: ${message}`; + } + } +} + +export class StagehandActTool extends StagehandToolBase { + name = "stagehand_act"; + + description = + "Use this tool to perform an action on the current web page using Stagehand. The input should be a string describing the action to perform."; + + async _call(input: string): Promise { + const stagehand = await this.getStagehand(); + const result = await stagehand.act({ action: input }); + if (result.success) { + return `Action performed successfully: ${result.message}`; + } else { + return `Failed to perform action: ${result.message}`; + } + } +} + +export class StagehandExtractTool extends StructuredTool { + name = "stagehand_extract"; + + description = + "Use this tool to extract structured information from the current web page using Stagehand. 
The input should include an 'instruction' string and a 'schema' object representing the extraction schema in JSON Schema format."; + + // Define the input schema for the tool + schema = z.object({ + instruction: z.string().describe("Instruction on what to extract"), + schema: z + .record(z.any()) + .describe("Extraction schema in JSON Schema format"), + }); + + private stagehand?: Stagehand; + + constructor(stagehandInstance?: Stagehand) { + super(); + this.stagehand = stagehandInstance; + } + + async _call(input: { + instruction: string; + schema: AnyZodObject; + }): Promise { + const stagehand = await this.getStagehand(); + const { instruction, schema } = input; + + try { + const result = await stagehand.extract({ + instruction, + schema, // Assuming Stagehand accepts the schema in JSON Schema format + }); + return JSON.stringify(result); + } catch (error: unknown) { + const message = isErrorWithMessage(error) ? error.message : String(error); + return `Failed to extract information: ${message}`; + } + } + + protected async getStagehand(): Promise { + if (this.stagehand) return this.stagehand; + + // Initialize local Stagehand instance if not provided + this.stagehand = new Stagehand({ + env: "LOCAL", + enableCaching: true, + }); + await this.stagehand.init(); + return this.stagehand; + } +} + +export class StagehandObserveTool extends StagehandToolBase { + name = "stagehand_observe"; + + description = + "Use this tool to observe the current web page and retrieve possible actions using Stagehand. The input can be an optional instruction string."; + + async _call(input: string): Promise { + const stagehand = await this.getStagehand(); + const instruction = input || undefined; + + try { + const result = await stagehand.observe({ instruction }); + return JSON.stringify(result); + } catch (error: unknown) { + const message = isErrorWithMessage(error) ? 
error.message : String(error); + return `Failed to observe: ${message}`; + } + } +} + +export class StagehandToolkit extends Toolkit { + tools: ToolInterface[]; + + stagehand?: Stagehand; + + constructor(stagehand?: Stagehand) { + super(); + this.stagehand = stagehand; + this.tools = this.initializeTools(); + } + + private initializeTools(): ToolInterface[] { + return [ + new StagehandNavigateTool(this.stagehand), + new StagehandActTool(this.stagehand), + new StagehandExtractTool(this.stagehand), + new StagehandObserveTool(this.stagehand), + ]; + } + + static async fromStagehand(stagehand: Stagehand): Promise { + return new StagehandToolkit(stagehand); + } +} diff --git a/libs/langchain-community/src/agents/toolkits/tests/stagehand.int.test.ts b/libs/langchain-community/src/agents/toolkits/tests/stagehand.int.test.ts new file mode 100644 index 0000000000000..67bc045385b8b --- /dev/null +++ b/libs/langchain-community/src/agents/toolkits/tests/stagehand.int.test.ts @@ -0,0 +1,242 @@ +import { expect, describe, test, beforeEach, afterEach } from "@jest/globals"; +import { Stagehand } from "@browserbasehq/stagehand"; +import { z } from "zod"; +import { ChatOpenAI } from "@langchain/openai"; +// import { createReactAgent } from "@langchain/langgraph/prebuilt"; +import { StagehandToolkit } from "../stagehand.js"; + +describe("StagehandToolkit Integration Tests", () => { + let stagehand: Stagehand; + let toolkit: StagehandToolkit; + + beforeEach(async () => { + stagehand = new Stagehand({ + env: "LOCAL", + headless: false, + verbose: 2, + debugDom: true, + enableCaching: false, + }); + await stagehand.init({ modelName: "gpt-4o-mini" }); + toolkit = await StagehandToolkit.fromStagehand(stagehand); + }); + + afterEach(async () => { + await stagehand.context.close().catch(() => {}); + }); + + test("should perform basic navigation and search", async () => { + const navigateTool = toolkit.tools.find( + (t) => t.name === "stagehand_navigate" + ); + if (!navigateTool) { + throw 
new Error("Navigate tool not found"); + } + await navigateTool.invoke("https://www.google.com"); + + const actionTool = toolkit.tools.find((t) => t.name === "stagehand_act"); + if (!actionTool) { + throw new Error("Action tool not found"); + } + await actionTool.invoke('Search for "OpenAI"'); + + const currentUrl = stagehand.page.url(); + expect(currentUrl).toContain("google.com/search?q=OpenAI"); + }); + + test("should extract structured data from webpage", async () => { + const navigateTool = toolkit.tools.find( + (t) => t.name === "stagehand_navigate" + ); + if (!navigateTool) { + throw new Error("Navigate tool not found"); + } + await navigateTool.invoke( + "https://github.com/facebook/react/graphs/contributors" + ); + + const extractTool = toolkit.tools.find( + (t) => t.name === "stagehand_extract" + ); + if (!extractTool) { + throw new Error("Extract tool not found"); + } + const input = { + instruction: "extract the top contributor", + schema: z.object({ + username: z.string(), + url: z.string(), + }), + }; + const result = await extractTool.invoke(input); + const parsedResult = JSON.parse(result); + const { username, url } = parsedResult; + expect(username).toBeDefined(); + expect(url).toBeDefined(); + }); + + test("should handle tab navigation", async () => { + const navigateTool = toolkit.tools.find( + (t) => t.name === "stagehand_navigate" + ); + if (!navigateTool) { + throw new Error("Navigate tool not found"); + } + await navigateTool.invoke("https://www.google.com/"); + + const actionTool = toolkit.tools.find((t) => t.name === "stagehand_act"); + if (!actionTool) { + throw new Error("Action tool not found"); + } + await actionTool.invoke("click on the about page"); + + const currentUrl = stagehand.page.url(); + expect(currentUrl).toContain("about"); + }); + + test("should use observe tool to get page information", async () => { + await stagehand.page.goto("https://github.com/browserbase/stagehand"); + + const observeTool = toolkit.tools.find( + (t) => 
t.name === "stagehand_observe" + ); + if (!observeTool) { + throw new Error("Observe tool not found"); + } + const result = await observeTool.invoke( + "What actions can be performed on the repository page?" + ); + + const observations = JSON.parse(result); + + expect(Array.isArray(observations)).toBe(true); + expect(observations.length).toBeGreaterThan(0); + expect(observations[0]).toHaveProperty("description"); + expect(observations[0]).toHaveProperty("selector"); + expect(typeof observations[0].description).toBe("string"); + expect(typeof observations[0].selector).toBe("string"); + }); + + test("should perform navigation and search using llm with tools", async () => { + const llm = new ChatOpenAI({ temperature: 0 }); + + if (!llm.bindTools) { + throw new Error("Language model does not support tools."); + } + + // Bind tools to the LLM + const llmWithTools = llm.bindTools(toolkit.tools); + + // Execute queries atomically + const result = await llmWithTools.invoke( + "Navigate to https://www.google.com" + ); + + expect(result.tool_calls).toBeDefined(); + expect(result.tool_calls?.length).toBe(1); + const toolCall = result.tool_calls?.[0]; + expect(toolCall?.name).toBe("stagehand_navigate"); + + const navigateTool = toolkit.tools.find( + (t) => t.name === "stagehand_navigate" + ); + if (!navigateTool) { + throw new Error("Navigate tool not found"); + } + const navigateResult = await navigateTool?.invoke(toolCall?.args?.input); + expect(navigateResult).toContain("Successfully navigated"); + + const result2 = await llmWithTools.invoke('Search for "OpenAI"'); + expect(result2.tool_calls).toBeDefined(); + expect(result2.tool_calls?.length).toBe(1); + const actionToolCall = result2.tool_calls?.[0]; + expect(actionToolCall?.name).toBe("stagehand_act"); + expect(actionToolCall?.args?.input).toBe("search for OpenAI"); + + const actionTool = toolkit.tools.find((t) => t.name === "stagehand_act"); + if (!actionTool) { + throw new Error("Action tool not found"); + } + const 
actionResult = await actionTool.invoke(actionToolCall?.args?.input); + expect(actionResult).toContain("successfully"); + + // Verify the current URL + const currentUrl = stagehand.page.url(); + expect(currentUrl).toContain("google.com/search?q=OpenAI"); + }); + + // test("should work with langgraph", async () => { + // const actTool = toolkit.tools.find((t) => t.name === "stagehand_act"); + // const navigateTool = toolkit.tools.find( + // (t) => t.name === "stagehand_navigate" + // ); + // if (!actTool || !navigateTool) { + // throw new Error("Required tools not found"); + // } + // const tools = [actTool, navigateTool]; + + // const model = new ChatOpenAI({ + // modelName: "gpt-4", + // temperature: 0, + // }); + + // const agent = createReactAgent({ + // llm: model, + // tools, + // }); + // // Navigate to Google + // const inputs1 = { + // messages: [ + // { + // role: "user", + // content: "Navigate to https://www.google.com", + // }, + // ], + // }; + + // const stream1 = await agent.stream(inputs1, { + // streamMode: "values", + // }); + + // for await (const { messages } of stream1) { + // const msg = + // messages && messages.length > 0 + // ? messages[messages.length - 1] + // : undefined; + // if (msg?.content) { + // console.log(msg.content); + // } else if (msg?.tool_calls && msg.tool_calls.length > 0) { + // console.log(msg.tool_calls); + // } else { + // console.log(msg); + // } + // } + + // // Click through to careers page and search + // const inputs2 = { + // messages: [ + // { + // role: "user", + // content: "Click on the About page", + // }, + // ], + // }; + + // const stream2 = await agent.stream(inputs2, { + // streamMode: "values", + // }); + // for await (const { messages } of stream2) { + // const msg = messages ? 
messages[messages.length - 1] : undefined; + // if (msg?.content) { + // console.log(msg.content); + // } else if (msg?.tool_calls && msg.tool_calls.length > 0) { + // console.log(msg.tool_calls); + // } else { + // console.log(msg); + // } + // } + + // const currentUrl = stagehand.page.url(); + // expect(currentUrl).toContain("about"); + // }); +}); diff --git a/yarn.lock b/yarn.lock index 5fdd9da4ebeb4..637c23e532c4d 100644 --- a/yarn.lock +++ b/yarn.lock @@ -8565,6 +8565,41 @@ __metadata: languageName: node linkType: hard +"@browserbasehq/sdk@npm:^2.0.0": + version: 2.0.0 + resolution: "@browserbasehq/sdk@npm:2.0.0" + dependencies: + "@types/node": ^18.11.18 + "@types/node-fetch": ^2.6.4 + abort-controller: ^3.0.0 + agentkeepalive: ^4.2.1 + form-data-encoder: 1.7.2 + formdata-node: ^4.3.2 + node-fetch: ^2.6.7 + checksum: f3ef62ff6817e5095ba1d2477b3ffcbfd7accf9cc1692b8d047803d4c854ad68521724c12af7df584589b9c04eb2fe95ec6f0a20114d1515363b814aa9d8b34e + languageName: node + linkType: hard + +"@browserbasehq/stagehand@npm:^1.0.0, @browserbasehq/stagehand@npm:^1.3.0": + version: 1.3.0 + resolution: "@browserbasehq/stagehand@npm:1.3.0" + dependencies: + "@anthropic-ai/sdk": ^0.27.3 + "@browserbasehq/sdk": ^2.0.0 + anthropic: ^0.0.0 + anthropic-ai: ^0.0.10 + sharp: ^0.33.5 + zod-to-json-schema: ^3.23.3 + peerDependencies: + "@playwright/test": ^1.42.1 + deepmerge: ^4.3.1 + dotenv: ^16.4.5 + openai: ^4.62.1 + zod: ^3.23.8 + checksum: 16962b3a95af92f3d435b5ceca84a5f0728334c5b3ac327f862b09501a5ecc6465d305ce20856fdc5d83606c26f884b69676347e4524a38a5e20795ee3d4d30e + languageName: node + linkType: hard + "@chainsafe/is-ip@npm:^2.0.1": version: 2.0.2 resolution: "@chainsafe/is-ip@npm:2.0.2" @@ -9401,6 +9436,15 @@ __metadata: languageName: node linkType: hard +"@emnapi/runtime@npm:^1.2.0": + version: 1.3.1 + resolution: "@emnapi/runtime@npm:1.3.1" + dependencies: + tslib: ^2.4.0 + checksum: 
9a16ae7905a9c0e8956cf1854ef74e5087fbf36739abdba7aa6b308485aafdc993da07c19d7af104cd5f8e425121120852851bb3a0f78e2160e420a36d47f42f + languageName: node + linkType: hard + "@esbuild-kit/cjs-loader@npm:^2.4.2": version: 2.4.2 resolution: "@esbuild-kit/cjs-loader@npm:2.4.2" @@ -10504,6 +10548,181 @@ __metadata: languageName: node linkType: hard +"@img/sharp-darwin-arm64@npm:0.33.5": + version: 0.33.5 + resolution: "@img/sharp-darwin-arm64@npm:0.33.5" + dependencies: + "@img/sharp-libvips-darwin-arm64": 1.0.4 + dependenciesMeta: + "@img/sharp-libvips-darwin-arm64": + optional: true + conditions: os=darwin & cpu=arm64 + languageName: node + linkType: hard + +"@img/sharp-darwin-x64@npm:0.33.5": + version: 0.33.5 + resolution: "@img/sharp-darwin-x64@npm:0.33.5" + dependencies: + "@img/sharp-libvips-darwin-x64": 1.0.4 + dependenciesMeta: + "@img/sharp-libvips-darwin-x64": + optional: true + conditions: os=darwin & cpu=x64 + languageName: node + linkType: hard + +"@img/sharp-libvips-darwin-arm64@npm:1.0.4": + version: 1.0.4 + resolution: "@img/sharp-libvips-darwin-arm64@npm:1.0.4" + conditions: os=darwin & cpu=arm64 + languageName: node + linkType: hard + +"@img/sharp-libvips-darwin-x64@npm:1.0.4": + version: 1.0.4 + resolution: "@img/sharp-libvips-darwin-x64@npm:1.0.4" + conditions: os=darwin & cpu=x64 + languageName: node + linkType: hard + +"@img/sharp-libvips-linux-arm64@npm:1.0.4": + version: 1.0.4 + resolution: "@img/sharp-libvips-linux-arm64@npm:1.0.4" + conditions: os=linux & cpu=arm64 & libc=glibc + languageName: node + linkType: hard + +"@img/sharp-libvips-linux-arm@npm:1.0.5": + version: 1.0.5 + resolution: "@img/sharp-libvips-linux-arm@npm:1.0.5" + conditions: os=linux & cpu=arm & libc=glibc + languageName: node + linkType: hard + +"@img/sharp-libvips-linux-s390x@npm:1.0.4": + version: 1.0.4 + resolution: "@img/sharp-libvips-linux-s390x@npm:1.0.4" + conditions: os=linux & cpu=s390x & libc=glibc + languageName: node + linkType: hard + 
+"@img/sharp-libvips-linux-x64@npm:1.0.4": + version: 1.0.4 + resolution: "@img/sharp-libvips-linux-x64@npm:1.0.4" + conditions: os=linux & cpu=x64 & libc=glibc + languageName: node + linkType: hard + +"@img/sharp-libvips-linuxmusl-arm64@npm:1.0.4": + version: 1.0.4 + resolution: "@img/sharp-libvips-linuxmusl-arm64@npm:1.0.4" + conditions: os=linux & cpu=arm64 & libc=musl + languageName: node + linkType: hard + +"@img/sharp-libvips-linuxmusl-x64@npm:1.0.4": + version: 1.0.4 + resolution: "@img/sharp-libvips-linuxmusl-x64@npm:1.0.4" + conditions: os=linux & cpu=x64 & libc=musl + languageName: node + linkType: hard + +"@img/sharp-linux-arm64@npm:0.33.5": + version: 0.33.5 + resolution: "@img/sharp-linux-arm64@npm:0.33.5" + dependencies: + "@img/sharp-libvips-linux-arm64": 1.0.4 + dependenciesMeta: + "@img/sharp-libvips-linux-arm64": + optional: true + conditions: os=linux & cpu=arm64 & libc=glibc + languageName: node + linkType: hard + +"@img/sharp-linux-arm@npm:0.33.5": + version: 0.33.5 + resolution: "@img/sharp-linux-arm@npm:0.33.5" + dependencies: + "@img/sharp-libvips-linux-arm": 1.0.5 + dependenciesMeta: + "@img/sharp-libvips-linux-arm": + optional: true + conditions: os=linux & cpu=arm & libc=glibc + languageName: node + linkType: hard + +"@img/sharp-linux-s390x@npm:0.33.5": + version: 0.33.5 + resolution: "@img/sharp-linux-s390x@npm:0.33.5" + dependencies: + "@img/sharp-libvips-linux-s390x": 1.0.4 + dependenciesMeta: + "@img/sharp-libvips-linux-s390x": + optional: true + conditions: os=linux & cpu=s390x & libc=glibc + languageName: node + linkType: hard + +"@img/sharp-linux-x64@npm:0.33.5": + version: 0.33.5 + resolution: "@img/sharp-linux-x64@npm:0.33.5" + dependencies: + "@img/sharp-libvips-linux-x64": 1.0.4 + dependenciesMeta: + "@img/sharp-libvips-linux-x64": + optional: true + conditions: os=linux & cpu=x64 & libc=glibc + languageName: node + linkType: hard + +"@img/sharp-linuxmusl-arm64@npm:0.33.5": + version: 0.33.5 + resolution: 
"@img/sharp-linuxmusl-arm64@npm:0.33.5" + dependencies: + "@img/sharp-libvips-linuxmusl-arm64": 1.0.4 + dependenciesMeta: + "@img/sharp-libvips-linuxmusl-arm64": + optional: true + conditions: os=linux & cpu=arm64 & libc=musl + languageName: node + linkType: hard + +"@img/sharp-linuxmusl-x64@npm:0.33.5": + version: 0.33.5 + resolution: "@img/sharp-linuxmusl-x64@npm:0.33.5" + dependencies: + "@img/sharp-libvips-linuxmusl-x64": 1.0.4 + dependenciesMeta: + "@img/sharp-libvips-linuxmusl-x64": + optional: true + conditions: os=linux & cpu=x64 & libc=musl + languageName: node + linkType: hard + +"@img/sharp-wasm32@npm:0.33.5": + version: 0.33.5 + resolution: "@img/sharp-wasm32@npm:0.33.5" + dependencies: + "@emnapi/runtime": ^1.2.0 + conditions: cpu=wasm32 + languageName: node + linkType: hard + +"@img/sharp-win32-ia32@npm:0.33.5": + version: 0.33.5 + resolution: "@img/sharp-win32-ia32@npm:0.33.5" + conditions: os=win32 & cpu=ia32 + languageName: node + linkType: hard + +"@img/sharp-win32-x64@npm:0.33.5": + version: 0.33.5 + resolution: "@img/sharp-win32-x64@npm:0.33.5" + conditions: os=win32 & cpu=x64 + languageName: node + linkType: hard + "@inquirer/figures@npm:^1.0.3": version: 1.0.5 resolution: "@inquirer/figures@npm:1.0.5" @@ -11484,6 +11703,7 @@ __metadata: "@azure/search-documents": ^12.0.0 "@azure/storage-blob": ^12.15.0 "@browserbasehq/sdk": ^1.1.5 + "@browserbasehq/stagehand": ^1.0.0 "@clickhouse/client": ^0.2.5 "@cloudflare/ai": 1.0.12 "@cloudflare/workers-types": ^4.20230922.0 @@ -11515,6 +11735,7 @@ __metadata: "@notionhq/client": ^2.2.10 "@opensearch-project/opensearch": ^2.2.0 "@planetscale/database": ^1.8.0 + "@playwright/test": ^1.48.2 "@premai/prem-sdk": ^0.3.25 "@qdrant/js-client-rest": ^1.8.2 "@raycast/api": ^1.83.1 @@ -11619,6 +11840,7 @@ __metadata: node-llama-cpp: 3.1.1 notion-to-md: ^3.1.0 officeparser: ^4.0.4 + openai: "*" pdf-parse: 1.1.1 pg: ^8.11.0 pg-copy-streams: ^6.0.5 @@ -11662,6 +11884,7 @@ __metadata: "@azure/search-documents": ^12.0.0 
"@azure/storage-blob": ^12.15.0 "@browserbasehq/sdk": "*" + "@browserbasehq/stagehand": ^1.0.0 "@clickhouse/client": ^0.2.5 "@cloudflare/ai": "*" "@datastax/astra-db-ts": ^1.0.0 @@ -11750,6 +11973,7 @@ __metadata: neo4j-driver: "*" notion-to-md: ^3.1.0 officeparser: ^4.0.4 + openai: "*" pdf-parse: 1.1.1 pg: ^8.11.0 pg-copy-streams: ^6.0.5 @@ -14136,6 +14360,17 @@ __metadata: languageName: node linkType: hard +"@playwright/test@npm:^1.48.2": + version: 1.49.0 + resolution: "@playwright/test@npm:1.49.0" + dependencies: + playwright: 1.49.0 + bin: + playwright: cli.js + checksum: f8477aa61d59fd22c6161c48221ab246340dc37bbe2804e1a7d1be8cbd0fd861747fcb7ca559f4bc7328226ff2c90ccb7efa588a7d7d7829f3e57902b28fe39a + languageName: node + linkType: hard + "@pnpm/config.env-replace@npm:^1.0.0": version: 1.0.0 resolution: "@pnpm/config.env-replace@npm:1.0.0" @@ -21058,6 +21293,20 @@ __metadata: languageName: node linkType: hard +"anthropic-ai@npm:^0.0.10": + version: 0.0.10 + resolution: "anthropic-ai@npm:0.0.10" + checksum: aee9204f298475c6ab7c34c9cada7ba1dbe0337ce331aea4cd4feffd34e3f491ff3cd64b572ceb5a81b6304c3fbb783b4e7c3a1d8837ce0ad1dca80796dca8f5 + languageName: node + linkType: hard + +"anthropic@npm:^0.0.0": + version: 0.0.0 + resolution: "anthropic@npm:0.0.0" + checksum: 67e0d17815883edd7acf967848a4c5d15d36f827c7cc0191c84381d8c04838f2f22b363c299b92bd6d50ff398bfa417279554398b197a72b301ebc4b1888df08 + languageName: node + linkType: hard + "any-promise@npm:^1.0.0": version: 1.3.0 resolution: "any-promise@npm:1.3.0" @@ -25260,6 +25509,13 @@ __metadata: languageName: node linkType: hard +"detect-libc@npm:^2.0.3": + version: 2.0.3 + resolution: "detect-libc@npm:2.0.3" + checksum: 2ba6a939ae55f189aea996ac67afceb650413c7a34726ee92c40fb0deb2400d57ef94631a8a3f052055eea7efb0f99a9b5e6ce923415daa3e68221f963cfc27d + languageName: node + linkType: hard + "detect-newline@npm:^3.0.0": version: 3.1.0 resolution: "detect-newline@npm:3.1.0" @@ -27339,6 +27595,7 @@ __metadata: resolution: 
"examples@workspace:examples" dependencies: "@azure/identity": ^4.2.1 + "@browserbasehq/stagehand": ^1.3.0 "@clickhouse/client": ^0.2.5 "@elastic/elasticsearch": ^8.4.0 "@faker-js/faker": ^8.4.1 @@ -35905,6 +36162,28 @@ __metadata: languageName: node linkType: hard +"openai@npm:*": + version: 4.73.0 + resolution: "openai@npm:4.73.0" + dependencies: + "@types/node": ^18.11.18 + "@types/node-fetch": ^2.6.4 + abort-controller: ^3.0.0 + agentkeepalive: ^4.2.1 + form-data-encoder: 1.7.2 + formdata-node: ^4.3.2 + node-fetch: ^2.6.7 + peerDependencies: + zod: ^3.23.8 + peerDependenciesMeta: + zod: + optional: true + bin: + openai: bin/cli + checksum: 499f20048a15e777a943c8eefc2ae1e1195a88ec2a61337d98233ae7f2c2cba3ff4f473e08bce25af38d24da3a3da2b3db077919ac104ff2449e246e895b5b62 + languageName: node + linkType: hard + "openai@npm:^4.32.1": version: 4.47.1 resolution: "openai@npm:4.47.1" @@ -36959,6 +37238,30 @@ __metadata: languageName: node linkType: hard +"playwright-core@npm:1.49.0": + version: 1.49.0 + resolution: "playwright-core@npm:1.49.0" + bin: + playwright-core: cli.js + checksum: d8423ad0cab2e672856529bf6b98b406e7e605da098b847b9b54ee8ebd8d716ed8880a9afff4b38f0a2e3f59b95661c74589116ce3ff2b5e0ae3561507086c94 + languageName: node + linkType: hard + +"playwright@npm:1.49.0": + version: 1.49.0 + resolution: "playwright@npm:1.49.0" + dependencies: + fsevents: 2.3.2 + playwright-core: 1.49.0 + dependenciesMeta: + fsevents: + optional: true + bin: + playwright: cli.js + checksum: f1bfb2fff65cad2ce996edab74ec231dfd21aeb5961554b765ce1eaec27efb87eaba37b00e91ecd27727b82861e5d8c230abe4960e93f6ada8be5ad1020df306 + languageName: node + linkType: hard + "playwright@npm:^1.32.1": version: 1.32.1 resolution: "playwright@npm:1.32.1" @@ -40006,6 +40309,75 @@ __metadata: languageName: node linkType: hard +"sharp@npm:^0.33.5": + version: 0.33.5 + resolution: "sharp@npm:0.33.5" + dependencies: + "@img/sharp-darwin-arm64": 0.33.5 + "@img/sharp-darwin-x64": 0.33.5 + 
"@img/sharp-libvips-darwin-arm64": 1.0.4 + "@img/sharp-libvips-darwin-x64": 1.0.4 + "@img/sharp-libvips-linux-arm": 1.0.5 + "@img/sharp-libvips-linux-arm64": 1.0.4 + "@img/sharp-libvips-linux-s390x": 1.0.4 + "@img/sharp-libvips-linux-x64": 1.0.4 + "@img/sharp-libvips-linuxmusl-arm64": 1.0.4 + "@img/sharp-libvips-linuxmusl-x64": 1.0.4 + "@img/sharp-linux-arm": 0.33.5 + "@img/sharp-linux-arm64": 0.33.5 + "@img/sharp-linux-s390x": 0.33.5 + "@img/sharp-linux-x64": 0.33.5 + "@img/sharp-linuxmusl-arm64": 0.33.5 + "@img/sharp-linuxmusl-x64": 0.33.5 + "@img/sharp-wasm32": 0.33.5 + "@img/sharp-win32-ia32": 0.33.5 + "@img/sharp-win32-x64": 0.33.5 + color: ^4.2.3 + detect-libc: ^2.0.3 + semver: ^7.6.3 + dependenciesMeta: + "@img/sharp-darwin-arm64": + optional: true + "@img/sharp-darwin-x64": + optional: true + "@img/sharp-libvips-darwin-arm64": + optional: true + "@img/sharp-libvips-darwin-x64": + optional: true + "@img/sharp-libvips-linux-arm": + optional: true + "@img/sharp-libvips-linux-arm64": + optional: true + "@img/sharp-libvips-linux-s390x": + optional: true + "@img/sharp-libvips-linux-x64": + optional: true + "@img/sharp-libvips-linuxmusl-arm64": + optional: true + "@img/sharp-libvips-linuxmusl-x64": + optional: true + "@img/sharp-linux-arm": + optional: true + "@img/sharp-linux-arm64": + optional: true + "@img/sharp-linux-s390x": + optional: true + "@img/sharp-linux-x64": + optional: true + "@img/sharp-linuxmusl-arm64": + optional: true + "@img/sharp-linuxmusl-x64": + optional: true + "@img/sharp-wasm32": + optional: true + "@img/sharp-win32-ia32": + optional: true + "@img/sharp-win32-x64": + optional: true + checksum: 04beae89910ac65c5f145f88de162e8466bec67705f497ace128de849c24d168993e016f33a343a1f3c30b25d2a90c3e62b017a9a0d25452371556f6cd2471e4 + languageName: node + linkType: hard + "shebang-command@npm:^2.0.0": version: 2.0.0 resolution: "shebang-command@npm:2.0.0" @@ -44368,6 +44740,15 @@ __metadata: languageName: node linkType: hard 
+"zod-to-json-schema@npm:^3.23.3": + version: 3.23.5 + resolution: "zod-to-json-schema@npm:3.23.5" + peerDependencies: + zod: ^3.23.3 + checksum: 3ac37128d1b989b027e55684201e1da90237f0955dc9bb40da013bc60f2ed23c57026df2fdc14da187be4d53873daad08210807c28c0dde4375c9df0e5fe7901 + languageName: node + linkType: hard + "zod@npm:3.23.8": version: 3.23.8 resolution: "zod@npm:3.23.8" From f3d504a2f7876d1a6c0bf8c24cd1fb0932153182 Mon Sep 17 00:00:00 2001 From: Audrey Sage Lorberfeld Date: Mon, 25 Nov 2024 09:06:37 -0800 Subject: [PATCH 13/27] Update Pinecone indexing example to show use of deletionProtection (#7249) --- .../vector_stores/pinecone/index_docs.ts | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/examples/src/indexes/vector_stores/pinecone/index_docs.ts b/examples/src/indexes/vector_stores/pinecone/index_docs.ts index 31d2a584b2d03..745b24be524ab 100644 --- a/examples/src/indexes/vector_stores/pinecone/index_docs.ts +++ b/examples/src/indexes/vector_stores/pinecone/index_docs.ts @@ -11,6 +11,23 @@ import { PineconeStore } from "@langchain/pinecone"; const pinecone = new Pinecone(); +// If index already exists: +// const pineconeIndex = pinecone.Index(process.env.PINECONE_INDEX!); + +// If index does not exist, create it: +await pinecone.createIndex({ + name: process.env.PINECONE_INDEX!, + dimension: 1536, + metric: "cosine", + spec: { + serverless: { + cloud: "aws", + region: "us-east-1", + }, + }, + deletionProtection: "disabled", // Note: deletion protection disabled https://docs.pinecone.io/guides/indexes/prevent-index-deletion#disable-deletion-protection +}); + const pineconeIndex = pinecone.Index(process.env.PINECONE_INDEX!); const docs = [ From 2d6cf2737f618e2e206de4b7a2b425ee235d95c4 Mon Sep 17 00:00:00 2001 From: FilipZmijewski Date: Mon, 25 Nov 2024 18:28:31 +0100 Subject: [PATCH 14/27] feat(community): Add rerank solution to existing IBM community implementation (#7200) Co-authored-by: jacoblee93 --- docs/core_docs/.gitignore | 88 
++-- .../document_compressors/ibm.ipynb | 385 ++++++++++++++++++ .../document_compressors/mixedbread_ai.mdx | 2 +- docs/core_docs/sidebars.js | 16 + libs/langchain-community/.gitignore | 4 + libs/langchain-community/langchain.config.js | 4 + libs/langchain-community/package.json | 13 + .../src/document_compressors/ibm.ts | 168 ++++++++ .../tests/ibm.int.test.ts | 80 ++++ .../document_compressors/tests/ibm.test.ts | 159 ++++++++ .../src/load/import_constants.ts | 1 + 11 files changed, 878 insertions(+), 42 deletions(-) create mode 100644 docs/core_docs/docs/integrations/document_compressors/ibm.ipynb create mode 100644 libs/langchain-community/src/document_compressors/ibm.ts create mode 100644 libs/langchain-community/src/document_compressors/tests/ibm.int.test.ts create mode 100644 libs/langchain-community/src/document_compressors/tests/ibm.test.ts diff --git a/docs/core_docs/.gitignore b/docs/core_docs/.gitignore index 099af2f2d6427..8f648f622a5eb 100644 --- a/docs/core_docs/.gitignore +++ b/docs/core_docs/.gitignore @@ -218,14 +218,28 @@ docs/how_to/agent_executor.md docs/how_to/agent_executor.mdx docs/concepts/t.md docs/concepts/t.mdx +docs/troubleshooting/errors/INVALID_TOOL_RESULTS.md +docs/troubleshooting/errors/INVALID_TOOL_RESULTS.mdx docs/versions/migrating_memory/conversation_summary_memory.md docs/versions/migrating_memory/conversation_summary_memory.mdx docs/versions/migrating_memory/conversation_buffer_window_memory.md docs/versions/migrating_memory/conversation_buffer_window_memory.mdx docs/versions/migrating_memory/chat_history.md docs/versions/migrating_memory/chat_history.mdx -docs/troubleshooting/errors/INVALID_TOOL_RESULTS.md -docs/troubleshooting/errors/INVALID_TOOL_RESULTS.mdx +docs/integrations/tools/tavily_search.md +docs/integrations/tools/tavily_search.mdx +docs/integrations/tools/serpapi.md +docs/integrations/tools/serpapi.mdx +docs/integrations/tools/exa_search.md +docs/integrations/tools/exa_search.mdx 
+docs/integrations/tools/duckduckgo_search.md +docs/integrations/tools/duckduckgo_search.mdx +docs/integrations/toolkits/vectorstore.md +docs/integrations/toolkits/vectorstore.mdx +docs/integrations/toolkits/sql.md +docs/integrations/toolkits/sql.mdx +docs/integrations/toolkits/openapi.md +docs/integrations/toolkits/openapi.mdx docs/integrations/vectorstores/weaviate.md docs/integrations/vectorstores/weaviate.mdx docs/integrations/vectorstores/upstash.md @@ -252,22 +266,24 @@ docs/integrations/vectorstores/elasticsearch.md docs/integrations/vectorstores/elasticsearch.mdx docs/integrations/vectorstores/chroma.md docs/integrations/vectorstores/chroma.mdx -docs/integrations/tools/tavily_search.md -docs/integrations/tools/tavily_search.mdx -docs/integrations/tools/serpapi.md -docs/integrations/tools/serpapi.mdx -docs/integrations/tools/exa_search.md -docs/integrations/tools/exa_search.mdx -docs/integrations/tools/duckduckgo_search.md -docs/integrations/tools/duckduckgo_search.mdx -docs/integrations/toolkits/vectorstore.md -docs/integrations/toolkits/vectorstore.mdx -docs/integrations/toolkits/sql.md -docs/integrations/toolkits/sql.mdx -docs/integrations/toolkits/openapi.md -docs/integrations/toolkits/openapi.mdx +docs/integrations/stores/in_memory.md +docs/integrations/stores/in_memory.mdx +docs/integrations/stores/file_system.md +docs/integrations/stores/file_system.mdx +docs/integrations/retrievers/tavily.md +docs/integrations/retrievers/tavily.mdx +docs/integrations/retrievers/kendra-retriever.md +docs/integrations/retrievers/kendra-retriever.mdx +docs/integrations/retrievers/exa.md +docs/integrations/retrievers/exa.mdx +docs/integrations/retrievers/bm25.md +docs/integrations/retrievers/bm25.mdx +docs/integrations/retrievers/bedrock-knowledge-bases.md +docs/integrations/retrievers/bedrock-knowledge-bases.mdx docs/integrations/text_embedding/togetherai.md docs/integrations/text_embedding/togetherai.mdx +docs/integrations/text_embedding/pinecone.md 
+docs/integrations/text_embedding/pinecone.mdx docs/integrations/text_embedding/openai.md docs/integrations/text_embedding/openai.mdx docs/integrations/text_embedding/ollama.md @@ -290,20 +306,6 @@ docs/integrations/text_embedding/bedrock.md docs/integrations/text_embedding/bedrock.mdx docs/integrations/text_embedding/azure_openai.md docs/integrations/text_embedding/azure_openai.mdx -docs/integrations/stores/in_memory.md -docs/integrations/stores/in_memory.mdx -docs/integrations/stores/file_system.md -docs/integrations/stores/file_system.mdx -docs/integrations/retrievers/tavily.md -docs/integrations/retrievers/tavily.mdx -docs/integrations/retrievers/kendra-retriever.md -docs/integrations/retrievers/kendra-retriever.mdx -docs/integrations/retrievers/exa.md -docs/integrations/retrievers/exa.mdx -docs/integrations/retrievers/bm25.md -docs/integrations/retrievers/bm25.mdx -docs/integrations/retrievers/bedrock-knowledge-bases.md -docs/integrations/retrievers/bedrock-knowledge-bases.mdx docs/integrations/llms/together.md docs/integrations/llms/together.mdx docs/integrations/llms/openai.md @@ -328,6 +330,10 @@ docs/integrations/llms/azure.md docs/integrations/llms/azure.mdx docs/integrations/llms/arcjet.md docs/integrations/llms/arcjet.mdx +docs/integrations/document_compressors/ibm.md +docs/integrations/document_compressors/ibm.mdx +docs/integrations/chat/xai.md +docs/integrations/chat/xai.mdx docs/integrations/chat/togetherai.md docs/integrations/chat/togetherai.mdx docs/integrations/chat/openai.md @@ -376,6 +382,16 @@ docs/integrations/retrievers/self_query/hnswlib.md docs/integrations/retrievers/self_query/hnswlib.mdx docs/integrations/retrievers/self_query/chroma.md docs/integrations/retrievers/self_query/chroma.mdx +docs/integrations/document_loaders/file_loaders/unstructured.md +docs/integrations/document_loaders/file_loaders/unstructured.mdx +docs/integrations/document_loaders/file_loaders/text.md +docs/integrations/document_loaders/file_loaders/text.mdx 
+docs/integrations/document_loaders/file_loaders/pdf.md +docs/integrations/document_loaders/file_loaders/pdf.mdx +docs/integrations/document_loaders/file_loaders/directory.md +docs/integrations/document_loaders/file_loaders/directory.mdx +docs/integrations/document_loaders/file_loaders/csv.md +docs/integrations/document_loaders/file_loaders/csv.mdx docs/integrations/document_loaders/web_loaders/web_puppeteer.md docs/integrations/document_loaders/web_loaders/web_puppeteer.mdx docs/integrations/document_loaders/web_loaders/web_cheerio.md @@ -387,14 +403,4 @@ docs/integrations/document_loaders/web_loaders/pdf.mdx docs/integrations/document_loaders/web_loaders/langsmith.md docs/integrations/document_loaders/web_loaders/langsmith.mdx docs/integrations/document_loaders/web_loaders/firecrawl.md -docs/integrations/document_loaders/web_loaders/firecrawl.mdx -docs/integrations/document_loaders/file_loaders/unstructured.md -docs/integrations/document_loaders/file_loaders/unstructured.mdx -docs/integrations/document_loaders/file_loaders/text.md -docs/integrations/document_loaders/file_loaders/text.mdx -docs/integrations/document_loaders/file_loaders/pdf.md -docs/integrations/document_loaders/file_loaders/pdf.mdx -docs/integrations/document_loaders/file_loaders/directory.md -docs/integrations/document_loaders/file_loaders/directory.mdx -docs/integrations/document_loaders/file_loaders/csv.md -docs/integrations/document_loaders/file_loaders/csv.mdx \ No newline at end of file +docs/integrations/document_loaders/web_loaders/firecrawl.mdx \ No newline at end of file diff --git a/docs/core_docs/docs/integrations/document_compressors/ibm.ipynb b/docs/core_docs/docs/integrations/document_compressors/ibm.ipynb new file mode 100644 index 0000000000000..9645a4e126a3c --- /dev/null +++ b/docs/core_docs/docs/integrations/document_compressors/ibm.ipynb @@ -0,0 +1,385 @@ +{ + "cells": [ + { + "cell_type": "raw", + "id": "afaf8039", + "metadata": { + "vscode": { + "languageId": "raw" + } + }, 
+ "source": [ + "---\n", + "sidebar_label: IBM watsonx.ai\n", + "---" + ] + }, + { + "cell_type": "markdown", + "id": "e49f1e0d", + "metadata": {}, + "source": [ + "# IBM watsonx.ai\n", + "\n", + "## Overview\n", + "\n", + "This will help you getting started with the [Watsonx document compressor](/docs/concepts/#document_compressors). For detailed documentation of all Watsonx document compressor features and configurations head to the [API reference]https://api.js.langchain.com/modules/_langchain_community.document_compressors_ibm.WatsonxRerank.html).\n", + "\n", + "### Integration details\n", + "\n", + "| Class | Package | [PY support](https://python.langchain.com/docs/integrations/llms/ibm_watsonx/) | Package downloads | Package latest |\n", + "| :--- | :--- | :---: | :---: | :---: |\n", + "| [`WatsonxRerank`](https://api.js.langchain.com/modules/_langchain_community.document_compressors_ibm.WatsonxRerank.html) | [@langchain/community](https://www.npmjs.com/package/@langchain/community) | ✅ | ![NPM - Downloads](https://img.shields.io/npm/dm/@langchain/community?style=flat-square&label=%20&) | ![NPM - Version](https://img.shields.io/npm/v/@langchain/community?style=flat-square&label=%20&) |\n", + "\n", + "## Setup\n", + "\n", + "To access IBM WatsonxAI models you'll need to create an IBM watsonx.ai account, get an API key or any other type of credentials, and install the `@langchain/community` integration package.\n", + "\n", + "### Credentials\n", + "\n", + "Head to [IBM Cloud](https://cloud.ibm.com/login) to sign up to IBM watsonx.ai and generate an API key or provide any other authentication form as presented below.\n", + "\n", + "#### IAM authentication\n", + "\n", + "```bash\n", + "export WATSONX_AI_AUTH_TYPE=iam\n", + "export WATSONX_AI_APIKEY=\n", + "```\n", + "\n", + "#### Bearer token authentication\n", + "\n", + "```bash\n", + "export WATSONX_AI_AUTH_TYPE=bearertoken\n", + "export WATSONX_AI_BEARER_TOKEN=\n", + "```\n", + "\n", + "#### CP4D 
authentication\n", + "\n", + "```bash\n", + "export WATSONX_AI_AUTH_TYPE=cp4d\n", + "export WATSONX_AI_USERNAME=\n", + "export WATSONX_AI_PASSWORD=\n", + "export WATSONX_AI_URL=\n", + "```\n", + "\n", + "Once these are placed in your environment variables and object is initialized authentication will proceed automatically.\n", + "\n", + "Authentication can also be accomplished by passing these values as parameters to a new instance.\n", + "\n", + "## IAM authentication\n", + "\n", + "```typescript\n", + "import { WatsonxLLM } from \"@langchain/community/llms/ibm\";\n", + "\n", + "const props = {\n", + " version: \"YYYY-MM-DD\",\n", + " serviceUrl: \"\",\n", + " projectId: \"\",\n", + " watsonxAIAuthType: \"iam\",\n", + " watsonxAIApikey: \"\",\n", + "};\n", + "const instance = new WatsonxLLM(props);\n", + "```\n", + "\n", + "## Bearer token authentication\n", + "\n", + "```typescript\n", + "import { WatsonxLLM } from \"@langchain/community/llms/ibm\";\n", + "\n", + "const props = {\n", + " version: \"YYYY-MM-DD\",\n", + " serviceUrl: \"\",\n", + " projectId: \"\",\n", + " watsonxAIAuthType: \"bearertoken\",\n", + " watsonxAIBearerToken: \"\",\n", + "};\n", + "const instance = new WatsonxLLM(props);\n", + "```\n", + "\n", + "### CP4D authentication\n", + "\n", + "```typescript\n", + "import { WatsonxLLM } from \"@langchain/community/llms/ibm\";\n", + "\n", + "const props = {\n", + " version: \"YYYY-MM-DD\",\n", + " serviceUrl: \"\",\n", + " projectId: \"\",\n", + " watsonxAIAuthType: \"cp4d\",\n", + " watsonxAIUsername: \"\",\n", + " watsonxAIPassword: \"\",\n", + " watsonxAIUrl: \"\",\n", + "};\n", + "const instance = new WatsonxLLM(props);\n", + "```\n", + "\n", + "If you want to get automated tracing from individual queries, you can also set your [LangSmith](https://docs.smith.langchain.com/) API key by uncommenting below:\n", + "\n", + "```typescript\n", + "// process.env.LANGSMITH_API_KEY = \"\";\n", + "// process.env.LANGSMITH_TRACING = \"true\";\n", + 
"```\n", + "\n", + "### Installation\n", + "\n", + "This document compressor lives in the `@langchain/community` package:\n", + "\n", + "```{=mdx}\n", + "import IntegrationInstallTooltip from \"@mdx_components/integration_install_tooltip.mdx\";\n", + "import Npm2Yarn from \"@theme/Npm2Yarn\";\n", + "\n", + "\n", + "\n", + "\n", + " @langchain/community @langchain/core\n", + "\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "a38cde65-254d-4219-a441-068766c0d4b5", + "metadata": {}, + "source": [ + "## Instantiation\n", + "\n", + "Now we can instantiate our compressor:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "70cc8e65-2a02-408a-bbc6-8ef649057d82", + "metadata": {}, + "outputs": [], + "source": [ + "import { WatsonxRerank } from \"@langchain/community/document_compressors/ibm\";\n", + "\n", + "const watsonxRerank = new WatsonxRerank({\n", + " version: \"2024-05-31\",\n", + " serviceUrl: process.env.WATSONX_AI_SERVICE_URL,\n", + " projectId: process.env.WATSONX_AI_PROJECT_ID,\n", + " model: \"ibm/slate-125m-english-rtrvr\",\n", + "});" + ] + }, + { + "cell_type": "markdown", + "id": "5c5f2839-4020-424e-9fc9-07777eede442", + "metadata": {}, + "source": [ + "## Usage" + ] + }, + { + "cell_type": "markdown", + "id": "b195b0c7", + "metadata": {}, + "source": [ + "First, set up a basic RAG ingest pipeline with embeddings, a text splitter and a vector store. We'll use this to and rerank some documents regarding the selected query:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "51a60dbe-9f2e-4e04-bb62-23968f17164a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[\n", + " Document {\n", + " pageContent: 'And I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. 
One of our nation’s top legal minds, who will continue Justice Breyer’s legacy of excellence.',\n", + " metadata: { loc: [Object] },\n", + " id: undefined\n", + " },\n", + " Document {\n", + " pageContent: 'I spoke with their families and told them that we are forever in debt for their sacrifice, and we will carry on their mission to restore the trust and safety every community deserves. \\n' +\n", + " '\\n' +\n", + " 'I’ve worked on these issues a long time. \\n' +\n", + " '\\n' +\n", + " 'I know what works: Investing in crime preventionand community police officers who’ll walk the beat, who’ll know the neighborhood, and who can restore trust and safety.',\n", + " metadata: { loc: [Object] },\n", + " id: undefined\n", + " },\n", + " Document {\n", + " pageContent: 'We are the only nation on Earth that has always turned every crisis we have faced into an opportunity. \\n' +\n", + " '\\n' +\n", + " 'The only nation that can be defined by a single word: possibilities. \\n' +\n", + " '\\n' +\n", + " 'So on this night, in our 245th year as a nation, I have come to report on the State of the Union. \\n' +\n", + " '\\n' +\n", + " 'And my report is this: the State of the Union is strong—because you, the American people, are strong.',\n", + " metadata: { loc: [Object] },\n", + " id: undefined\n", + " },\n", + " Document {\n", + " pageContent: 'And I’m taking robust action to make sure the pain of our sanctions is targeted at Russia’s economy. And I will use every tool at our disposal to protect American businesses and consumers. 
\\n' +\n", + " '\\n' +\n", + " 'Tonight, I can announce that the United States has worked with 30 other countries to release 60 Million barrels of oil from reserves around the world.',\n", + " metadata: { loc: [Object] },\n", + " id: undefined\n", + " }\n", + "]\n" + ] + } + ], + "source": [ + "import { readFileSync } from \"node:fs\";\n", + "import { MemoryVectorStore } from \"langchain/vectorstores/memory\";\n", + "import { WatsonxEmbeddings } from \"@langchain/community/embeddings/ibm\";\n", + "import { CharacterTextSplitter } from \"@langchain/textsplitters\";\n", + "\n", + "const embeddings = new WatsonxEmbeddings({\n", + " version: \"YYYY-MM-DD\",\n", + " serviceUrl: process.env.API_URL,\n", + " projectId: \"\",\n", + " spaceId: \"\",\n", + " model: \"ibm/slate-125m-english-rtrvr\",\n", + "});\n", + "\n", + "const textSplitter = new CharacterTextSplitter({\n", + " chunkSize: 400,\n", + " chunkOverlap: 0,\n", + "});\n", + " \n", + "const query = \"What did the president say about Ketanji Brown Jackson\";\n", + "const text = readFileSync(\"state_of_the_union.txt\", \"utf8\");\n", + "\n", + "const docs = await textSplitter.createDocuments([text]);\n", + "const vectorStore = await MemoryVectorStore.fromDocuments(docs, embeddings);\n", + "const vectorStoreRetriever = vectorStore.asRetriever();\n", + "\n", + "const result = await vectorStoreRetriever.invoke(query);\n", + "console.log(result);" + ] + }, + { + "cell_type": "markdown", + "id": "b13ebf96", + "metadata": {}, + "source": [ + "Pass selected documents to rerank and recive specific score for each" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "fad30397", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[\n", + " { index: 0, relevanceScore: 0.726995587348938 },\n", + " { index: 1, relevanceScore: 0.5758284330368042 },\n", + " { index: 2, relevanceScore: 0.5479092597961426 },\n", + " { index: 3, relevanceScore: 0.5468723773956299 }\n", + 
"]\n" + ] + } + ], + "source": [ + "import { WatsonxRerank } from \"@langchain/community/document_compressors/ibm\";\n", + "\n", + "const reranker = new WatsonxRerank({\n", + " version: \"2024-05-31\",\n", + " serviceUrl: process.env.WATSONX_AI_SERVICE_URL,\n", + " projectId: process.env.WATSONX_AI_PROJECT_ID,\n", + " model: \"ibm/slate-125m-english-rtrvr\",\n", + "});\n", + "const compressed = await reranker.rerank(result, query);\n", + "console.log(compressed);" + ] + }, + { + "cell_type": "markdown", + "id": "55f4ec18", + "metadata": {}, + "source": [ + "Or else you could have the documents returned with the result, for that use .compressDocuments() method as below." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e6cc39ac", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[\n", + " Document {\n", + " pageContent: 'And I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nation’s top legal minds, who will continue Justice Breyer’s legacy of excellence.',\n", + " metadata: { loc: [Object], relevanceScore: 0.726995587348938 },\n", + " id: undefined\n", + " },\n", + " Document {\n", + " pageContent: 'I spoke with their families and told them that we are forever in debt for their sacrifice, and we will carry on their mission to restore the trust and safety every community deserves. \\n' +\n", + " '\\n' +\n", + " 'I’ve worked on these issues a long time. \\n' +\n", + " '\\n' +\n", + " 'I know what works: Investing in crime preventionand community police officers who’ll walk the beat, who’ll know the neighborhood, and who can restore trust and safety.',\n", + " metadata: { loc: [Object], relevanceScore: 0.5758284330368042 },\n", + " id: undefined\n", + " },\n", + " Document {\n", + " pageContent: 'We are the only nation on Earth that has always turned every crisis we have faced into an opportunity. 
\\n' +\n", + " '\\n' +\n", + " 'The only nation that can be defined by a single word: possibilities. \\n' +\n", + " '\\n' +\n", + " 'So on this night, in our 245th year as a nation, I have come to report on the State of the Union. \\n' +\n", + " '\\n' +\n", + " 'And my report is this: the State of the Union is strong—because you, the American people, are strong.',\n", + " metadata: { loc: [Object], relevanceScore: 0.5479092597961426 },\n", + " id: undefined\n", + " },\n", + " Document {\n", + " pageContent: 'And I’m taking robust action to make sure the pain of our sanctions is targeted at Russia’s economy. And I will use every tool at our disposal to protect American businesses and consumers. \\n' +\n", + " '\\n' +\n", + " 'Tonight, I can announce that the United States has worked with 30 other countries to release 60 Million barrels of oil from reserves around the world.',\n", + " metadata: { loc: [Object], relevanceScore: 0.5468723773956299 },\n", + " id: undefined\n", + " }\n", + "]\n" + ] + } + ], + "source": [ + "const compressedWithResults = await reranker.compressDocuments(result, query);\n", + "console.log(compressedWithResults);" + ] + }, + { + "cell_type": "markdown", + "id": "3a5bb5ca-c3ae-4a58-be67-2cd18574b9a3", + "metadata": {}, + "source": [ + "## API reference\n", + "\n", + "For detailed documentation of all Watsonx document compressor features and configurations head to the [API reference](https://api.js.langchain.com/modules/_langchain_community.document_compressors_ibm.WatsonxRerank.html)." 
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "JavaScript (Node.js)", + "language": "javascript", + "name": "javascript" + }, + "language_info": { + "file_extension": ".js", + "mimetype": "application/javascript", + "name": "javascript", + "version": "20.17.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/core_docs/docs/integrations/document_compressors/mixedbread_ai.mdx b/docs/core_docs/docs/integrations/document_compressors/mixedbread_ai.mdx index e74cdf488b0d4..5ed78bdbc8645 100644 --- a/docs/core_docs/docs/integrations/document_compressors/mixedbread_ai.mdx +++ b/docs/core_docs/docs/integrations/document_compressors/mixedbread_ai.mdx @@ -51,4 +51,4 @@ console.log(result); ## Additional Resources -For more information, refer to the [Reranking API documentation](https://mixedbread.ai/docs/reranking). +For more information, refer to the [Reranking API documentation](https://www.mixedbread.ai/docs/reranking/overview). diff --git a/docs/core_docs/sidebars.js b/docs/core_docs/sidebars.js index 95bf57ec58590..978a4a1614b2c 100644 --- a/docs/core_docs/sidebars.js +++ b/docs/core_docs/sidebars.js @@ -347,6 +347,22 @@ module.exports = { slug: "integrations/document_transformers", }, }, + { + type: "category", + label: "Document rerankers", + collapsible: false, + items: [ + { + type: "autogenerated", + dirName: "integrations/document_compressors", + className: "hidden", + }, + ], + link: { + type: "generated-index", + slug: "integrations/document_compressors", + }, + }, { type: "category", label: "Model caches", diff --git a/libs/langchain-community/.gitignore b/libs/langchain-community/.gitignore index b554afb6f1ec3..99f77ac328f8c 100644 --- a/libs/langchain-community/.gitignore +++ b/libs/langchain-community/.gitignore @@ -698,6 +698,10 @@ graphs/memgraph_graph.cjs graphs/memgraph_graph.js graphs/memgraph_graph.d.ts graphs/memgraph_graph.d.cts +document_compressors/ibm.cjs +document_compressors/ibm.js 
+document_compressors/ibm.d.ts +document_compressors/ibm.d.cts document_transformers/html_to_text.cjs document_transformers/html_to_text.js document_transformers/html_to_text.d.ts diff --git a/libs/langchain-community/langchain.config.js b/libs/langchain-community/langchain.config.js index 7bd552b4dcb3b..4a402c6941e8d 100644 --- a/libs/langchain-community/langchain.config.js +++ b/libs/langchain-community/langchain.config.js @@ -220,6 +220,8 @@ export const config = { // graphs "graphs/neo4j_graph": "graphs/neo4j_graph", "graphs/memgraph_graph": "graphs/memgraph_graph", + // document_compressors + "document_compressors/ibm": "document_compressors/ibm", // document transformers "document_transformers/html_to_text": "document_transformers/html_to_text", "document_transformers/mozilla_readability": @@ -446,6 +448,8 @@ export const config = { "cache/upstash_redis", "graphs/neo4j_graph", "graphs/memgraph_graph", + // document_compressors + "document_compressors/ibm", // document_transformers "document_transformers/html_to_text", "document_transformers/mozilla_readability", diff --git a/libs/langchain-community/package.json b/libs/langchain-community/package.json index 7758dd43a8b36..be0bd452a0576 100644 --- a/libs/langchain-community/package.json +++ b/libs/langchain-community/package.json @@ -2287,6 +2287,15 @@ "import": "./graphs/memgraph_graph.js", "require": "./graphs/memgraph_graph.cjs" }, + "./document_compressors/ibm": { + "types": { + "import": "./document_compressors/ibm.d.ts", + "require": "./document_compressors/ibm.d.cts", + "default": "./document_compressors/ibm.d.ts" + }, + "import": "./document_compressors/ibm.js", + "require": "./document_compressors/ibm.cjs" + }, "./document_transformers/html_to_text": { "types": { "import": "./document_transformers/html_to_text.d.ts", @@ -3783,6 +3792,10 @@ "graphs/memgraph_graph.js", "graphs/memgraph_graph.d.ts", "graphs/memgraph_graph.d.cts", + "document_compressors/ibm.cjs", + "document_compressors/ibm.js", + 
"document_compressors/ibm.d.ts", + "document_compressors/ibm.d.cts", "document_transformers/html_to_text.cjs", "document_transformers/html_to_text.js", "document_transformers/html_to_text.d.ts", diff --git a/libs/langchain-community/src/document_compressors/ibm.ts b/libs/langchain-community/src/document_compressors/ibm.ts new file mode 100644 index 0000000000000..348f606854803 --- /dev/null +++ b/libs/langchain-community/src/document_compressors/ibm.ts @@ -0,0 +1,168 @@ +import { DocumentInterface } from "@langchain/core/documents"; +import { BaseDocumentCompressor } from "@langchain/core/retrievers/document_compressors"; +import { WatsonXAI } from "@ibm-cloud/watsonx-ai"; +import { AsyncCaller } from "@langchain/core/utils/async_caller"; +import { WatsonxAuth, WatsonxParams } from "../types/ibm.js"; +import { authenticateAndSetInstance } from "../utils/ibm.js"; + +export interface WatsonxInputRerank extends Omit { + truncateInputTokens?: number; + returnOptions?: { + topN?: number; + inputs?: boolean; + }; +} +export class WatsonxRerank + extends BaseDocumentCompressor + implements WatsonxInputRerank +{ + maxRetries = 0; + + version = "2024-05-31"; + + truncateInputTokens?: number | undefined; + + returnOptions?: + | { topN?: number; inputs?: boolean; query?: boolean } + | undefined; + + model: string; + + spaceId?: string | undefined; + + projectId?: string | undefined; + + maxConcurrency?: number | undefined; + + serviceUrl: string; + + service: WatsonXAI; + + constructor(fields: WatsonxInputRerank & WatsonxAuth) { + super(); + if (fields.projectId && fields.spaceId) + throw new Error("Maximum 1 id type can be specified per instance"); + + if (!fields.projectId && !fields.spaceId) + throw new Error( + "No id specified! 
At least id of 1 type has to be specified" + ); + this.model = fields.model; + this.serviceUrl = fields.serviceUrl; + this.version = fields.version; + this.projectId = fields?.projectId; + this.spaceId = fields?.spaceId; + this.maxRetries = fields.maxRetries ?? this.maxRetries; + this.maxConcurrency = fields.maxConcurrency; + this.truncateInputTokens = fields.truncateInputTokens; + this.returnOptions = fields.returnOptions; + + const { + watsonxAIApikey, + watsonxAIAuthType, + watsonxAIBearerToken, + watsonxAIUsername, + watsonxAIPassword, + watsonxAIUrl, + version, + serviceUrl, + } = fields; + + const auth = authenticateAndSetInstance({ + watsonxAIApikey, + watsonxAIAuthType, + watsonxAIBearerToken, + watsonxAIUsername, + watsonxAIPassword, + watsonxAIUrl, + version, + serviceUrl, + }); + if (auth) this.service = auth; + else throw new Error("You have not provided one type of authentication"); + } + + scopeId() { + if (this.projectId) + return { projectId: this.projectId, modelId: this.model }; + else return { spaceId: this.spaceId, modelId: this.model }; + } + + invocationParams(options?: Partial) { + return { + truncate_input_tokens: + options?.truncateInputTokens ?? this.truncateInputTokens, + return_options: { + top_n: options?.returnOptions?.topN ?? this.returnOptions?.topN, + inputs: options?.returnOptions?.inputs ?? 
this.returnOptions?.inputs, + }, + }; + } + + async compressDocuments( + documents: DocumentInterface[], + query: string + ): Promise { + const caller = new AsyncCaller({ + maxConcurrency: this.maxConcurrency, + maxRetries: this.maxRetries, + }); + const inputs = documents.map((document) => ({ + text: document.pageContent, + })); + const { result } = await caller.call(() => + this.service.textRerank({ + ...this.scopeId(), + inputs, + query, + }) + ); + const resultDocuments = result.results.map(({ index, score }) => { + const rankedDocument = documents[index]; + rankedDocument.metadata.relevanceScore = score; + return rankedDocument; + }); + return resultDocuments; + } + + async rerank( + documents: Array< + DocumentInterface | string | Record<"pageContent", string> + >, + query: string, + options?: Partial + ): Promise> { + const inputs = documents.map((document) => { + if (typeof document === "string") { + return { text: document }; + } + return { text: document.pageContent }; + }); + + const caller = new AsyncCaller({ + maxConcurrency: this.maxConcurrency, + maxRetries: this.maxRetries, + }); + const { result } = await caller.call(() => + this.service.textRerank({ + ...this.scopeId(), + inputs, + query, + parameters: this.invocationParams(options), + }) + ); + const response = result.results.map((document) => { + return document?.input + ? 
{ + index: document.index, + relevanceScore: document.score, + input: document?.input, + } + : { + index: document.index, + relevanceScore: document.score, + }; + }); + return response; + } +} diff --git a/libs/langchain-community/src/document_compressors/tests/ibm.int.test.ts b/libs/langchain-community/src/document_compressors/tests/ibm.int.test.ts new file mode 100644 index 0000000000000..e65ea9e1eff3a --- /dev/null +++ b/libs/langchain-community/src/document_compressors/tests/ibm.int.test.ts @@ -0,0 +1,80 @@ +/* eslint-disable no-process-env */ +import { Document } from "@langchain/core/documents"; +import { WatsonxRerank } from "../ibm.js"; + +const query = "What is the capital of the United States?"; +const docs = [ + new Document({ + pageContent: + "Carson City is the capital city of the American state of Nevada. At the 2010 United States Census, Carson City had a population of 55,274.", + }), + new Document({ + pageContent: + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that are a political division controlled by the United States. Its capital is Saipan.", + }), + new Document({ + pageContent: + "Charlotte Amalie is the capital and largest city of the United States Virgin Islands. It has about 20,000 people. The city is on the island of Saint Thomas.", + }), + new Document({ + pageContent: + "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district. The President of the USA and many major national government offices are in the territory. This makes it the political center of the United States of America.", + }), + new Document({ + pageContent: + "Capital punishment (the death penalty) has existed in the United States since before the United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states. 
The federal government (including the United States military) also uses capital punishment.", + }), +]; +describe("Integration tests on WatsonxRerank", () => { + describe(".compressDocuments() method", () => { + test("Basic call", async () => { + const instance = new WatsonxRerank({ + model: "cross-encoder/ms-marco-minilm-l-12-v2", + serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, + version: "2024-05-31", + projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString", + }); + const result = await instance.compressDocuments(docs, query); + expect(result.length).toBe(docs.length); + result.forEach((item) => + expect(typeof item.metadata.relevanceScore).toBe("number") + ); + }); + }); + + describe(".rerank() method", () => { + test("Basic call", async () => { + const instance = new WatsonxRerank({ + model: "cross-encoder/ms-marco-minilm-l-12-v2", + serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, + version: "2024-05-31", + projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString", + }); + const result = await instance.rerank(docs, query); + expect(result.length).toBe(docs.length); + result.forEach((item) => { + expect(typeof item.relevanceScore).toBe("number"); + expect(item.input).toBeUndefined(); + }); + }); + }); + test("Basic call with options", async () => { + const instance = new WatsonxRerank({ + model: "cross-encoder/ms-marco-minilm-l-12-v2", + serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, + version: "2024-05-31", + projectId: process.env.WATSONX_AI_PROJECT_ID ?? 
"testString", + }); + const result = await instance.rerank(docs, query, { + returnOptions: { + topN: 3, + inputs: true, + }, + }); + expect(result.length).toBe(3); + result.forEach((item) => { + expect(typeof item.relevanceScore).toBe("number"); + expect(item.input).toBeDefined(); + }); + }); +}); diff --git a/libs/langchain-community/src/document_compressors/tests/ibm.test.ts b/libs/langchain-community/src/document_compressors/tests/ibm.test.ts new file mode 100644 index 0000000000000..c9332bc8c1b3a --- /dev/null +++ b/libs/langchain-community/src/document_compressors/tests/ibm.test.ts @@ -0,0 +1,159 @@ +/* eslint-disable no-process-env */ +/* eslint-disable @typescript-eslint/no-explicit-any */ +import { WatsonxRerank, WatsonxInputRerank } from "../ibm.js"; + +function getKey(key: K): K { + return key; +} +const testProperties = ( + instance: WatsonxRerank, + testProps: WatsonxInputRerank, + notExTestProps?: { [key: string]: any } +) => { + const checkProperty = ( + testProps: T, + instance: T, + existing = true + ) => { + Object.keys(testProps).forEach((key) => { + const keys = getKey(key); + type Type = Pick; + + if (typeof testProps[key as keyof T] === "object") + checkProperty(testProps[key as keyof T], instance[key], existing); + else { + if (existing) + expect(instance[key as keyof T]).toBe(testProps[key as keyof T]); + else if (instance) expect(instance[key as keyof T]).toBeUndefined(); + } + }); + }; + checkProperty(testProps, instance); + if (notExTestProps) + checkProperty(notExTestProps, instance, false); +}; +const fakeAuthProp = { + watsonxAIAuthType: "iam", + watsonxAIApikey: "fake_key", +}; +describe("Embeddings unit tests", () => { + describe("Positive tests", () => { + test("Basic properties", () => { + const testProps = { + model: "cross-encoder/ms-marco-minilm-l-12-v2", + version: "2024-05-31", + serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, + projectId: process.env.WATSONX_AI_PROJECT_ID || "testString", + }; + const instance = new 
WatsonxRerank({ ...testProps, ...fakeAuthProp }); + testProperties(instance, testProps); + }); + + test("Basic properties", () => { + const testProps: WatsonxInputRerank = { + model: "cross-encoder/ms-marco-minilm-l-12-v2", + version: "2024-05-31", + serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, + projectId: process.env.WATSONX_AI_PROJECT_ID || "testString", + truncateInputTokens: 10, + maxConcurrency: 2, + maxRetries: 2, + returnOptions: { + topN: 5, + inputs: false, + }, + }; + const instance = new WatsonxRerank({ ...testProps, ...fakeAuthProp }); + testProperties(instance, testProps); + }); + }); + + describe("Negative tests", () => { + test("Missing id", async () => { + const testProps = { + model: "cross-encoder/ms-marco-minilm-l-12-v2", + version: "2024-05-31", + serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, + }; + expect( + () => + new WatsonxRerank({ + ...testProps, + ...fakeAuthProp, + }) + ).toThrowError(); + }); + + test("Missing other props", async () => { + // @ts-expect-error Intentionally passing wrong value + const testPropsProjectId: WatsonxInputLLM = { + projectId: process.env.WATSONX_AI_PROJECT_ID || "testString", + }; + expect( + () => + new WatsonxRerank({ + ...testPropsProjectId, + }) + ).toThrowError(); + // @ts-expect-error //Intentionally passing wrong value + const testPropsServiceUrl: WatsonxInputLLM = { + serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, + }; + expect( + () => + new WatsonxRerank({ + ...testPropsServiceUrl, + }) + ).toThrowError(); + const testPropsVersion = { + version: "2024-05-31", + }; + expect( + () => + new WatsonxRerank({ + // @ts-expect-error Intentionally passing wrong props + testPropsVersion, + }) + ).toThrowError(); + }); + + test("Passing more than one id", async () => { + const testProps = { + model: "cross-encoder/ms-marco-minilm-l-12-v2", + version: "2024-05-31", + serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, + projectId: process.env.WATSONX_AI_PROJECT_ID || 
"testString", + spaceId: process.env.WATSONX_AI_PROJECT_ID || "testString", + }; + expect( + () => + new WatsonxRerank({ + ...testProps, + ...fakeAuthProp, + }) + ).toThrowError(); + }); + + test("Invalid properties", () => { + const testProps = { + model: "cross-encoder/ms-marco-minilm-l-12-v2", + version: "2024-05-31", + serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, + projectId: process.env.WATSONX_AI_PROJECT_ID || "testString", + }; + const notExTestProps = { + notExisting: 12, + notExObj: { + notExProp: 12, + }, + }; + const instance = new WatsonxRerank({ + ...testProps, + ...notExTestProps, + ...fakeAuthProp, + }); + + testProperties(instance, testProps, notExTestProps); + }); + }); +}); diff --git a/libs/langchain-community/src/load/import_constants.ts b/libs/langchain-community/src/load/import_constants.ts index 65d40fc1f7d49..7be97f096fde1 100644 --- a/libs/langchain-community/src/load/import_constants.ts +++ b/libs/langchain-community/src/load/import_constants.ts @@ -111,6 +111,7 @@ export const optionalImportEntrypoints: string[] = [ "langchain_community/retrievers/zep_cloud", "langchain_community/graphs/neo4j_graph", "langchain_community/graphs/memgraph_graph", + "langchain_community/document_compressors/ibm", "langchain_community/document_transformers/html_to_text", "langchain_community/document_transformers/mozilla_readability", "langchain_community/storage/cassandra", From 3460b92609d51785d8201b7036cf52ddb8fdaa4f Mon Sep 17 00:00:00 2001 From: Jacob Lee Date: Mon, 25 Nov 2024 09:37:12 -0800 Subject: [PATCH 15/27] fix(core): Fix issue in .d.ts typing for protected type (#7259) --- langchain-core/src/output_parsers/bytes.ts | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/langchain-core/src/output_parsers/bytes.ts b/langchain-core/src/output_parsers/bytes.ts index 222d439481351..b6ebdb4708df2 100644 --- a/langchain-core/src/output_parsers/bytes.ts +++ b/langchain-core/src/output_parsers/bytes.ts @@ -13,7 +13,9 @@ export 
class BytesOutputParser extends BaseTransformOutputParser { lc_serializable = true; - protected textEncoder = new TextEncoder(); + // TODO: Figure out why explicit typing is needed + // eslint-disable-next-line @typescript-eslint/no-explicit-any + protected textEncoder: any = new TextEncoder(); parse(text: string): Promise { return Promise.resolve(this.textEncoder.encode(text)); From 53d8ff5cc04cd8d295c9e21318b1784b0bece878 Mon Sep 17 00:00:00 2001 From: Alex Shan Date: Tue, 26 Nov 2024 01:46:38 +0800 Subject: [PATCH 16/27] fix(community): PrismaVectorStore handle empty array in filter (#7254) Co-authored-by: jacoblee93 --- libs/langchain-community/src/vectorstores/prisma.ts | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/libs/langchain-community/src/vectorstores/prisma.ts b/libs/langchain-community/src/vectorstores/prisma.ts index 669fb51a52628..e61bb861eec20 100644 --- a/libs/langchain-community/src/vectorstores/prisma.ts +++ b/libs/langchain-community/src/vectorstores/prisma.ts @@ -437,6 +437,16 @@ export class PrismaVectorStore< )}` ); } + + if (value.length === 0) { + const isInOperator = OpMap[opNameKey] === OpMap.in; + + // For empty arrays: + // - IN () should return FALSE (nothing can be in an empty set) + // - NOT IN () should return TRUE (everything is not in an empty set) + return this.Prisma.sql`${!isInOperator}`; + } + return this.Prisma.sql`${colRaw} ${opRaw} (${this.Prisma.join( value )})`; From ca44bd7a14a03576aae1794cb3acc058ec3fa3c6 Mon Sep 17 00:00:00 2001 From: Jacob Lee Date: Mon, 25 Nov 2024 09:57:12 -0800 Subject: [PATCH 17/27] fix(core): Move type (#7246) --- langchain-core/src/tools/index.ts | 4 +++- langchain-core/src/types/zod.ts | 4 ---- 2 files changed, 3 insertions(+), 5 deletions(-) delete mode 100644 langchain-core/src/types/zod.ts diff --git a/langchain-core/src/tools/index.ts b/langchain-core/src/tools/index.ts index 844295e605f83..348e851039049 100644 --- a/langchain-core/src/tools/index.ts +++ 
b/langchain-core/src/tools/index.ts @@ -16,7 +16,6 @@ import { } from "../runnables/config.js"; import type { RunnableFunc, RunnableInterface } from "../runnables/base.js"; import { ToolCall, ToolMessage } from "../messages/tool.js"; -import { ZodObjectAny } from "../types/zod.js"; import { MessageContent } from "../messages/base.js"; import { AsyncLocalStorageProviderSingleton } from "../singletons/index.js"; import { _isToolCall, ToolInputParsingException } from "./utils.js"; @@ -32,6 +31,9 @@ type ToolReturnType = any; // eslint-disable-next-line @typescript-eslint/no-explicit-any export type ContentAndArtifact = [MessageContent, any]; +// eslint-disable-next-line @typescript-eslint/no-explicit-any +type ZodObjectAny = z.ZodObject; + /** * Parameters for the Tool classes. */ diff --git a/langchain-core/src/types/zod.ts b/langchain-core/src/types/zod.ts deleted file mode 100644 index faaa92b3bff24..0000000000000 --- a/langchain-core/src/types/zod.ts +++ /dev/null @@ -1,4 +0,0 @@ -import type { z } from "zod"; - -// eslint-disable-next-line @typescript-eslint/no-explicit-any -export type ZodObjectAny = z.ZodObject; From d35e161aed149ec7e0ff4fb0d899fd538945ee5a Mon Sep 17 00:00:00 2001 From: Jacob Lee Date: Mon, 25 Nov 2024 10:02:13 -0800 Subject: [PATCH 18/27] chore(core): Release 0.3.19 (#7260) --- langchain-core/package.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/langchain-core/package.json b/langchain-core/package.json index 8b0650ded8043..15284c5783961 100644 --- a/langchain-core/package.json +++ b/langchain-core/package.json @@ -1,6 +1,6 @@ { "name": "@langchain/core", - "version": "0.3.18", + "version": "0.3.19", "description": "Core LangChain.js abstractions and schemas", "type": "module", "engines": { From 983f57a159f48246400a831144e0800b50717b74 Mon Sep 17 00:00:00 2001 From: Jacob Lee Date: Mon, 25 Nov 2024 10:13:05 -0800 Subject: [PATCH 19/27] chore(community): Release 0.3.16 (#7261) --- libs/langchain-community/.gitignore | 4 
++++ libs/langchain-community/package.json | 15 ++++++++++++++- .../src/load/import_constants.ts | 1 + 3 files changed, 19 insertions(+), 1 deletion(-) diff --git a/libs/langchain-community/.gitignore b/libs/langchain-community/.gitignore index 99f77ac328f8c..e6ae5fa54a4fb 100644 --- a/libs/langchain-community/.gitignore +++ b/libs/langchain-community/.gitignore @@ -122,6 +122,10 @@ agents/toolkits/connery.cjs agents/toolkits/connery.js agents/toolkits/connery.d.ts agents/toolkits/connery.d.cts +agents/toolkits/stagehand.cjs +agents/toolkits/stagehand.js +agents/toolkits/stagehand.d.ts +agents/toolkits/stagehand.d.cts embeddings/alibaba_tongyi.cjs embeddings/alibaba_tongyi.js embeddings/alibaba_tongyi.d.ts diff --git a/libs/langchain-community/package.json b/libs/langchain-community/package.json index be0bd452a0576..7b826ad1e106b 100644 --- a/libs/langchain-community/package.json +++ b/libs/langchain-community/package.json @@ -1,6 +1,6 @@ { "name": "@langchain/community", - "version": "0.3.15", + "version": "0.3.16", "description": "Third-party integrations for LangChain.js", "type": "module", "engines": { @@ -991,6 +991,15 @@ "import": "./agents/toolkits/connery.js", "require": "./agents/toolkits/connery.cjs" }, + "./agents/toolkits/stagehand": { + "types": { + "import": "./agents/toolkits/stagehand.d.ts", + "require": "./agents/toolkits/stagehand.d.cts", + "default": "./agents/toolkits/stagehand.d.ts" + }, + "import": "./agents/toolkits/stagehand.js", + "require": "./agents/toolkits/stagehand.cjs" + }, "./embeddings/alibaba_tongyi": { "types": { "import": "./embeddings/alibaba_tongyi.d.ts", @@ -3216,6 +3225,10 @@ "agents/toolkits/connery.js", "agents/toolkits/connery.d.ts", "agents/toolkits/connery.d.cts", + "agents/toolkits/stagehand.cjs", + "agents/toolkits/stagehand.js", + "agents/toolkits/stagehand.d.ts", + "agents/toolkits/stagehand.d.cts", "embeddings/alibaba_tongyi.cjs", "embeddings/alibaba_tongyi.js", "embeddings/alibaba_tongyi.d.ts", diff --git 
a/libs/langchain-community/src/load/import_constants.ts b/libs/langchain-community/src/load/import_constants.ts index 7be97f096fde1..722dd82e678b2 100644 --- a/libs/langchain-community/src/load/import_constants.ts +++ b/libs/langchain-community/src/load/import_constants.ts @@ -8,6 +8,7 @@ export const optionalImportEntrypoints: string[] = [ "langchain_community/tools/gmail", "langchain_community/tools/google_calendar", "langchain_community/agents/toolkits/aws_sfn", + "langchain_community/agents/toolkits/stagehand", "langchain_community/embeddings/bedrock", "langchain_community/embeddings/cloudflare_workersai", "langchain_community/embeddings/cohere", From 3e37940f39c40247534bb29f7463567dbde9c75f Mon Sep 17 00:00:00 2001 From: Brace Sproul Date: Mon, 25 Nov 2024 16:42:44 -0800 Subject: [PATCH 20/27] fix(google-common): Anthropic util using getType instead of _getType (#7263) --- libs/langchain-google-common/src/utils/anthropic.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/langchain-google-common/src/utils/anthropic.ts b/libs/langchain-google-common/src/utils/anthropic.ts index 72e1f9e57080d..a5ee5004fbc55 100644 --- a/libs/langchain-google-common/src/utils/anthropic.ts +++ b/libs/langchain-google-common/src/utils/anthropic.ts @@ -518,7 +518,7 @@ export function getAnthropicAPI(config?: AnthropicAPIConfig): GoogleAIAPI { function baseToAnthropicMessage( base: BaseMessage ): AnthropicMessage | undefined { - const type = base.getType(); + const type = base._getType(); switch (type) { case "human": return baseRoleToAnthropicMessage(base, "user"); From 1888b32af0bbcb1616c5188466073a33ec038e58 Mon Sep 17 00:00:00 2001 From: Brace Sproul Date: Mon, 25 Nov 2024 16:44:52 -0800 Subject: [PATCH 21/27] release(google-common): 0.1.3 (#7265) --- libs/langchain-google-common/package.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/langchain-google-common/package.json b/libs/langchain-google-common/package.json index 
1564f1cb9db7c..fe3fa5a001ec0 100644 --- a/libs/langchain-google-common/package.json +++ b/libs/langchain-google-common/package.json @@ -1,6 +1,6 @@ { "name": "@langchain/google-common", - "version": "0.1.2", + "version": "0.1.3", "description": "Core types and classes for Google services.", "type": "module", "engines": { From 0a6bc45297abb3705582176c31b8815ae263a43c Mon Sep 17 00:00:00 2001 From: Brace Sproul Date: Mon, 25 Nov 2024 16:46:42 -0800 Subject: [PATCH 22/27] fix(google-gauth/webauth): Bump Google common dep (#7266) --- libs/langchain-google-gauth/package.json | 2 +- libs/langchain-google-webauth/package.json | 2 +- yarn.lock | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/libs/langchain-google-gauth/package.json b/libs/langchain-google-gauth/package.json index f1accfaf6a78c..64e12ba56a379 100644 --- a/libs/langchain-google-gauth/package.json +++ b/libs/langchain-google-gauth/package.json @@ -35,7 +35,7 @@ "author": "LangChain", "license": "MIT", "dependencies": { - "@langchain/google-common": "~0.1.2", + "@langchain/google-common": "~0.1.3", "google-auth-library": "^8.9.0" }, "peerDependencies": { diff --git a/libs/langchain-google-webauth/package.json b/libs/langchain-google-webauth/package.json index 014ae081f3026..bf9563740e420 100644 --- a/libs/langchain-google-webauth/package.json +++ b/libs/langchain-google-webauth/package.json @@ -32,7 +32,7 @@ "author": "LangChain", "license": "MIT", "dependencies": { - "@langchain/google-common": "~0.1.2", + "@langchain/google-common": "~0.1.3", "web-auth-library": "^1.0.3" }, "peerDependencies": { diff --git a/yarn.lock b/yarn.lock index 637c23e532c4d..cb7d5cc268137 100644 --- a/yarn.lock +++ b/yarn.lock @@ -12314,7 +12314,7 @@ __metadata: languageName: unknown linkType: soft -"@langchain/google-common@^0.1.0, @langchain/google-common@workspace:*, @langchain/google-common@workspace:libs/langchain-google-common, @langchain/google-common@~0.1.2": +"@langchain/google-common@^0.1.0, 
@langchain/google-common@workspace:*, @langchain/google-common@workspace:libs/langchain-google-common, @langchain/google-common@~0.1.3": version: 0.0.0-use.local resolution: "@langchain/google-common@workspace:libs/langchain-google-common" dependencies: @@ -12355,7 +12355,7 @@ __metadata: dependencies: "@jest/globals": ^29.5.0 "@langchain/core": "workspace:*" - "@langchain/google-common": ~0.1.2 + "@langchain/google-common": ~0.1.3 "@langchain/scripts": ">=0.1.0 <0.2.0" "@swc/core": ^1.3.90 "@swc/jest": ^0.2.29 @@ -12499,7 +12499,7 @@ __metadata: dependencies: "@jest/globals": ^29.5.0 "@langchain/core": "workspace:*" - "@langchain/google-common": ~0.1.2 + "@langchain/google-common": ~0.1.3 "@langchain/scripts": ">=0.1.0 <0.2.0" "@swc/core": ^1.3.90 "@swc/jest": ^0.2.29 From 0798ca0f28b508b14b1136a21e22c32dd8b2cfb3 Mon Sep 17 00:00:00 2001 From: Brace Sproul Date: Mon, 25 Nov 2024 16:47:50 -0800 Subject: [PATCH 23/27] release(google-webauth): 0.1.3 (#7267) --- libs/langchain-google-webauth/package.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/langchain-google-webauth/package.json b/libs/langchain-google-webauth/package.json index bf9563740e420..953b925c8a498 100644 --- a/libs/langchain-google-webauth/package.json +++ b/libs/langchain-google-webauth/package.json @@ -1,6 +1,6 @@ { "name": "@langchain/google-webauth", - "version": "0.1.2", + "version": "0.1.3", "description": "Web-based authentication support for Google services", "type": "module", "engines": { From 10e4cca5aa990cf7171e3c3accb0c7336328e7f1 Mon Sep 17 00:00:00 2001 From: Brace Sproul Date: Mon, 25 Nov 2024 16:50:37 -0800 Subject: [PATCH 24/27] fix(google-vertexai/web): Bump Google auth deps (#7269) --- .../package.json | 2 +- libs/langchain-google-vertexai/package.json | 2 +- yarn.lock | 20 +++++++++++++++---- 3 files changed, 18 insertions(+), 6 deletions(-) diff --git a/libs/langchain-google-vertexai-web/package.json b/libs/langchain-google-vertexai-web/package.json index 
363b3032a0a6b..85e2efc5fb83d 100644 --- a/libs/langchain-google-vertexai-web/package.json +++ b/libs/langchain-google-vertexai-web/package.json @@ -32,7 +32,7 @@ "author": "LangChain", "license": "MIT", "dependencies": { - "@langchain/google-webauth": "~0.1.2" + "@langchain/google-webauth": "~0.1.3" }, "peerDependencies": { "@langchain/core": ">=0.2.21 <0.4.0" diff --git a/libs/langchain-google-vertexai/package.json b/libs/langchain-google-vertexai/package.json index 9379589af4427..21b28c2c4fedb 100644 --- a/libs/langchain-google-vertexai/package.json +++ b/libs/langchain-google-vertexai/package.json @@ -32,7 +32,7 @@ "author": "LangChain", "license": "MIT", "dependencies": { - "@langchain/google-gauth": "~0.1.2" + "@langchain/google-gauth": "~0.1.3" }, "peerDependencies": { "@langchain/core": ">=0.2.21 <0.4.0" diff --git a/yarn.lock b/yarn.lock index cb7d5cc268137..3c4dd13605590 100644 --- a/yarn.lock +++ b/yarn.lock @@ -12349,7 +12349,19 @@ __metadata: languageName: unknown linkType: soft -"@langchain/google-gauth@workspace:libs/langchain-google-gauth, @langchain/google-gauth@~0.1.2": +"@langchain/google-gauth@npm:~0.1.3": + version: 0.1.3 + resolution: "@langchain/google-gauth@npm:0.1.3" + dependencies: + "@langchain/google-common": ~0.1.3 + google-auth-library: ^8.9.0 + peerDependencies: + "@langchain/core": ">=0.2.21 <0.4.0" + checksum: ac83e180af492068de82284a396842eb9bb1e5eaa428b5270a192499da737bf192ad48a1c90eeca462e31238e37bd5698c8d071eb1d780a4f4c759270f8ab706 + languageName: node + linkType: hard + +"@langchain/google-gauth@workspace:libs/langchain-google-gauth": version: 0.0.0-use.local resolution: "@langchain/google-gauth@workspace:libs/langchain-google-gauth" dependencies: @@ -12428,7 +12440,7 @@ __metadata: "@jest/globals": ^29.5.0 "@langchain/core": "workspace:*" "@langchain/google-common": ^0.1.0 - "@langchain/google-webauth": ~0.1.2 + "@langchain/google-webauth": ~0.1.3 "@langchain/scripts": ">=0.1.0 <0.2.0" "@langchain/standard-tests": 0.0.0 
"@swc/core": ^1.3.90 @@ -12464,7 +12476,7 @@ __metadata: "@jest/globals": ^29.5.0 "@langchain/core": "workspace:*" "@langchain/google-common": ^0.1.0 - "@langchain/google-gauth": ~0.1.2 + "@langchain/google-gauth": ~0.1.3 "@langchain/scripts": ">=0.1.0 <0.2.0" "@langchain/standard-tests": 0.0.0 "@swc/core": ^1.3.90 @@ -12493,7 +12505,7 @@ __metadata: languageName: unknown linkType: soft -"@langchain/google-webauth@workspace:libs/langchain-google-webauth, @langchain/google-webauth@~0.1.2": +"@langchain/google-webauth@workspace:libs/langchain-google-webauth, @langchain/google-webauth@~0.1.3": version: 0.0.0-use.local resolution: "@langchain/google-webauth@workspace:libs/langchain-google-webauth" dependencies: From 70fe615ba4cc1daabccb4140e8a7948bab456776 Mon Sep 17 00:00:00 2001 From: Brace Sproul Date: Mon, 25 Nov 2024 16:53:14 -0800 Subject: [PATCH 25/27] release(google-vertexai): 0.1.3 (#7270) --- libs/langchain-google-vertexai/package.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/langchain-google-vertexai/package.json b/libs/langchain-google-vertexai/package.json index 21b28c2c4fedb..58c6b3a1503f0 100644 --- a/libs/langchain-google-vertexai/package.json +++ b/libs/langchain-google-vertexai/package.json @@ -1,6 +1,6 @@ { "name": "@langchain/google-vertexai", - "version": "0.1.2", + "version": "0.1.3", "description": "LangChain.js support for Google Vertex AI", "type": "module", "engines": { From 7da6e8f30625e51374774ccca0ee7117f019f0a0 Mon Sep 17 00:00:00 2001 From: Brace Sproul Date: Mon, 25 Nov 2024 16:55:40 -0800 Subject: [PATCH 26/27] release(google-gauth): 0.1.3 (#7271) --- libs/langchain-google-gauth/package.json | 2 +- yarn.lock | 14 +------------- 2 files changed, 2 insertions(+), 14 deletions(-) diff --git a/libs/langchain-google-gauth/package.json b/libs/langchain-google-gauth/package.json index 64e12ba56a379..f40f0a0be0293 100644 --- a/libs/langchain-google-gauth/package.json +++ b/libs/langchain-google-gauth/package.json 
@@ -1,6 +1,6 @@ { "name": "@langchain/google-gauth", - "version": "0.1.2", + "version": "0.1.3", "description": "Google auth based authentication support for Google services", "type": "module", "engines": { diff --git a/yarn.lock b/yarn.lock index 3c4dd13605590..5e6b6ed60a57e 100644 --- a/yarn.lock +++ b/yarn.lock @@ -12349,19 +12349,7 @@ __metadata: languageName: unknown linkType: soft -"@langchain/google-gauth@npm:~0.1.3": - version: 0.1.3 - resolution: "@langchain/google-gauth@npm:0.1.3" - dependencies: - "@langchain/google-common": ~0.1.3 - google-auth-library: ^8.9.0 - peerDependencies: - "@langchain/core": ">=0.2.21 <0.4.0" - checksum: ac83e180af492068de82284a396842eb9bb1e5eaa428b5270a192499da737bf192ad48a1c90eeca462e31238e37bd5698c8d071eb1d780a4f4c759270f8ab706 - languageName: node - linkType: hard - -"@langchain/google-gauth@workspace:libs/langchain-google-gauth": +"@langchain/google-gauth@workspace:libs/langchain-google-gauth, @langchain/google-gauth@~0.1.3": version: 0.0.0-use.local resolution: "@langchain/google-gauth@workspace:libs/langchain-google-gauth" dependencies: From 983acd22b5899d7b339e7eb06e6679bdd05d4a81 Mon Sep 17 00:00:00 2001 From: Brace Sproul Date: Mon, 25 Nov 2024 16:57:55 -0800 Subject: [PATCH 27/27] release(google-vertexai-web): 0.1.3 (#7272) --- libs/langchain-google-vertexai-web/package.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/langchain-google-vertexai-web/package.json b/libs/langchain-google-vertexai-web/package.json index 85e2efc5fb83d..737438881e944 100644 --- a/libs/langchain-google-vertexai-web/package.json +++ b/libs/langchain-google-vertexai-web/package.json @@ -1,6 +1,6 @@ { "name": "@langchain/google-vertexai-web", - "version": "0.1.2", + "version": "0.1.3", "description": "LangChain.js support for Google Vertex AI Web", "type": "module", "engines": {