diff --git a/docs/core_docs/docs/ecosystem/langserve.mdx b/docs/core_docs/docs/ecosystem/langserve.mdx
index a2cca1db87db..bdda4894cf5e 100644
--- a/docs/core_docs/docs/ecosystem/langserve.mdx
+++ b/docs/core_docs/docs/ecosystem/langserve.mdx
@@ -41,9 +41,17 @@ import Example from "@examples/ecosystem/langsmith.ts";
{Example}
-[`streamLog`](/docs/expression_language/interface) is a lower level method for streaming chain intermediate steps as partial JSONPatch chunks.
+[`streamEvents`](/docs/expression_language/interface) allows you to stream chain intermediate steps as events such as `on_llm_start` and `on_chain_stream`.
+See the [table here](/docs/expression_language/interface#stream-events) for a full list of events you can handle.
This method allows for a few extra options as well to only include or exclude certain named steps:
+import StreamEventsExample from "@examples/ecosystem/langsmith_stream_events.ts";
+
+{StreamEventsExample}
+
+[`streamLog`](/docs/expression_language/interface) is a lower level method for streaming chain intermediate steps as partial JSONPatch chunks.
+Like `streamEvents`, this method also allows a few extra options to only include or exclude certain named steps:
+
import StreamLogExample from "@examples/ecosystem/langsmith_stream_log.ts";
{StreamLogExample}
diff --git a/docs/core_docs/docs/integrations/chat/premai.mdx b/docs/core_docs/docs/integrations/chat/premai.mdx
new file mode 100644
index 000000000000..eb7324f48b2f
--- /dev/null
+++ b/docs/core_docs/docs/integrations/chat/premai.mdx
@@ -0,0 +1,30 @@
+---
+sidebar_label: PremAI
+---
+
+import CodeBlock from "@theme/CodeBlock";
+
+# ChatPrem
+
+## Setup
+
+1. Create a Prem AI account and get your API key [here](https://app.premai.io/accounts/signup/).
+2. Set your API key as an environment variable or pass it in directly. The `ChatPrem` class defaults to `process.env.PREM_API_KEY`.
+
+```bash
+export PREM_API_KEY=your-api-key
+```
+
+You can use models provided by Prem AI as follows:
+
+import IntegrationInstallTooltip from "@mdx_components/integration_install_tooltip.mdx";
+
+<IntegrationInstallTooltip></IntegrationInstallTooltip>
+
+```bash npm2yarn
+npm install @langchain/community
+```
+
+import PremAI from "@examples/models/chat/integration_premai.ts";
+
+<CodeBlock language="typescript">{PremAI}</CodeBlock>
diff --git a/docs/core_docs/docs/integrations/text_embedding/premai.mdx b/docs/core_docs/docs/integrations/text_embedding/premai.mdx
new file mode 100644
index 000000000000..57bb2a3fbaea
--- /dev/null
+++ b/docs/core_docs/docs/integrations/text_embedding/premai.mdx
@@ -0,0 +1,28 @@
+---
+sidebar_label: Prem AI
+---
+
+# Prem AI
+
+The `PremEmbeddings` class uses the Prem AI API to generate embeddings for a given text.
+
+## Setup
+
+In order to use the Prem API you'll need an API key. You can sign up for a Prem account and create an API key [here](https://app.premai.io/accounts/signup/).
+
+You'll first need to install the [`@langchain/community`](https://www.npmjs.com/package/@langchain/community) package:
+
+import IntegrationInstallTooltip from "@mdx_components/integration_install_tooltip.mdx";
+
+<IntegrationInstallTooltip></IntegrationInstallTooltip>
+
+```bash npm2yarn
+npm install @langchain/community
+```
+
+## Usage
+
+import CodeBlock from "@theme/CodeBlock";
+import PremExample from "@examples/embeddings/premai.ts";
+
+<CodeBlock language="typescript">{PremExample}</CodeBlock>
diff --git a/docs/core_docs/docs/langgraph.mdx b/docs/core_docs/docs/langgraph.mdx
index e173a4ebda75..47dd87a930df 100644
--- a/docs/core_docs/docs/langgraph.mdx
+++ b/docs/core_docs/docs/langgraph.mdx
@@ -396,7 +396,10 @@ Let's define the nodes, as well as a function to decide how what conditional edg
```typescript
import { FunctionMessage } from "@langchain/core/messages";
import { AgentAction } from "@langchain/core/agents";
-import type { RunnableConfig } from "@langchain/core/runnables";
+import {
+ ChatPromptTemplate,
+ MessagesPlaceholder,
+} from "@langchain/core/prompts";
// Define the function that determines whether to continue or not
const shouldContinue = (state: { messages: Array<BaseMessage> }) => {
@@ -428,7 +431,7 @@ const _getAction = (state: { messages: Array }): AgentAction => {
// We construct an AgentAction from the function_call
return {
tool: lastMessage.additional_kwargs.function_call.name,
- toolInput: JSON.stringify(
+ toolInput: JSON.parse(
lastMessage.additional_kwargs.function_call.arguments
),
log: "",
@@ -436,25 +439,25 @@ const _getAction = (state: { messages: Array }): AgentAction => {
};
// Define the function that calls the model
-const callModel = async (
-  state: { messages: Array<BaseMessage> },
- config?: RunnableConfig
-) => {
+const callModel = async (state: { messages: Array<BaseMessage> }) => {
const { messages } = state;
- const response = await newModel.invoke(messages, config);
+ // You can use a prompt here to tweak model behavior.
+ // You can also just pass messages to the model directly.
+ const prompt = ChatPromptTemplate.fromMessages([
+ ["system", "You are a helpful assistant."],
+ new MessagesPlaceholder("messages"),
+ ]);
+ const response = await prompt.pipe(newModel).invoke({ messages });
// We return a list, because this will get added to the existing list
return {
messages: [response],
};
};
-const callTool = async (
-  state: { messages: Array<BaseMessage> },
- config?: RunnableConfig
-) => {
+const callTool = async (state: { messages: Array<BaseMessage> }) => {
const action = _getAction(state);
// We call the tool_executor and get back a response
- const response = await toolExecutor.invoke(action, config);
+ const response = await toolExecutor.invoke(action);
// We use the response to create a FunctionMessage
const functionMessage = new FunctionMessage({
content: response,
@@ -532,7 +535,7 @@ const inputs = {
const result = await app.invoke(inputs);
```
-See a LangSmith trace of this run [here](https://smith.langchain.com/public/2562d46e-da94-4c9d-9b14-3759a26aec9b/r).
+See a LangSmith trace of this run [here](https://smith.langchain.com/public/144af8a3-b496-43aa-ba9d-f0d5894196e2/r).
This may take a little bit - it's making a few calls behind the scenes.
In order to start seeing some intermediate results as they happen, we can use streaming - see below for more information on that.
@@ -555,7 +558,7 @@ for await (const output of await app.stream(inputs)) {
}
```
-See a LangSmith trace of this run [here](https://smith.langchain.com/public/9afacb13-b9dc-416e-abbe-6ed2a0811afe/r).
+See a LangSmith trace of this run [here](https://smith.langchain.com/public/968cd1bf-0db2-410f-a5b4-0e73066cf06e/r).
## Running Examples
diff --git a/examples/src/ecosystem/langsmith_stream_events.ts b/examples/src/ecosystem/langsmith_stream_events.ts
new file mode 100644
index 000000000000..a958eec5d85a
--- /dev/null
+++ b/examples/src/ecosystem/langsmith_stream_events.ts
@@ -0,0 +1,267 @@
+import { RemoteRunnable } from "@langchain/core/runnables/remote";
+
+const remoteChain = new RemoteRunnable({
+ url: "https://your_hostname.com/path",
+});
+
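+// streamEvents calls the remote runnable's /stream_events endpoint and yields intermediate events as they are emitted.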
+const logStream = await remoteChain.streamEvents(
+ {
+ question: "What is a document loader?",
+ chat_history: [],
+ },
+ // LangChain runnable config properties
+ {
+ // Version is required for streamEvents since it's a beta API
+ version: "v1",
+ // Optional, chain specific config
+ configurable: {
+ llm: "openai_gpt_3_5_turbo",
+ },
+ metadata: {
+ conversation_id: "other_metadata",
+ },
+ },
+  // Optional additional streamEvents properties for filtering outputs
+ {
+ // includeNames: [],
+ // includeTags: [],
+ // includeTypes: [],
+ // excludeNames: [],
+ // excludeTags: [],
+ // excludeTypes: [],
+ }
+);
+
+for await (const chunk of logStream) {
+ console.log(chunk);
+}
+
+/*
+ {
+ event: 'on_chain_start',
+ name: '/pirate-speak',
+ run_id: undefined,
+ tags: [],
+ metadata: {},
+ data: {
+ input: StringPromptValue {
+ lc_serializable: true,
+ lc_kwargs: [Object],
+ lc_namespace: [Array],
+ value: null
+ }
+ }
+ }
+ {
+ event: 'on_prompt_start',
+ name: 'ChatPromptTemplate',
+ run_id: undefined,
+ tags: [ 'seq:step:1' ],
+ metadata: {},
+ data: {
+ input: StringPromptValue {
+ lc_serializable: true,
+ lc_kwargs: [Object],
+ lc_namespace: [Array],
+ value: null
+ }
+ }
+ }
+ {
+ event: 'on_prompt_end',
+ name: 'ChatPromptTemplate',
+ run_id: undefined,
+ tags: [ 'seq:step:1' ],
+ metadata: {},
+ data: {
+ input: StringPromptValue {
+ lc_serializable: true,
+ lc_kwargs: [Object],
+ lc_namespace: [Array],
+ value: null
+ },
+ output: ChatPromptValue {
+ lc_serializable: true,
+ lc_kwargs: [Object],
+ lc_namespace: [Array],
+ messages: [Array]
+ }
+ }
+ }
+ {
+ event: 'on_chat_model_start',
+ name: 'ChatOpenAI',
+ run_id: undefined,
+ tags: [ 'seq:step:2' ],
+ metadata: {},
+ data: {
+ input: ChatPromptValue {
+ lc_serializable: true,
+ lc_kwargs: [Object],
+ lc_namespace: [Array],
+ messages: [Array]
+ }
+ }
+ }
+ {
+ event: 'on_chat_model_stream',
+ name: 'ChatOpenAI',
+ run_id: undefined,
+ tags: [ 'seq:step:2' ],
+ metadata: {},
+ data: {
+ chunk: AIMessageChunk {
+ lc_serializable: true,
+ lc_kwargs: [Object],
+ lc_namespace: [Array],
+ content: '',
+ name: undefined,
+ additional_kwargs: {},
+ response_metadata: {}
+ }
+ }
+ }
+ {
+ event: 'on_chain_stream',
+ name: '/pirate-speak',
+ run_id: undefined,
+ tags: [],
+ metadata: {},
+ data: {
+ chunk: AIMessageChunk {
+ lc_serializable: true,
+ lc_kwargs: [Object],
+ lc_namespace: [Array],
+ content: '',
+ name: undefined,
+ additional_kwargs: {},
+ response_metadata: {}
+ }
+ }
+ }
+ {
+ event: 'on_chat_model_stream',
+ name: 'ChatOpenAI',
+ run_id: undefined,
+ tags: [ 'seq:step:2' ],
+ metadata: {},
+ data: {
+ chunk: AIMessageChunk {
+ lc_serializable: true,
+ lc_kwargs: [Object],
+ lc_namespace: [Array],
+ content: 'Arr',
+ name: undefined,
+ additional_kwargs: {},
+ response_metadata: {}
+ }
+ }
+ }
+ {
+ event: 'on_chain_stream',
+ name: '/pirate-speak',
+ run_id: undefined,
+ tags: [],
+ metadata: {},
+ data: {
+ chunk: AIMessageChunk {
+ lc_serializable: true,
+ lc_kwargs: [Object],
+ lc_namespace: [Array],
+ content: 'Arr',
+ name: undefined,
+ additional_kwargs: {},
+ response_metadata: {}
+ }
+ }
+ }
+ {
+ event: 'on_chat_model_stream',
+ name: 'ChatOpenAI',
+ run_id: undefined,
+ tags: [ 'seq:step:2' ],
+ metadata: {},
+ data: {
+ chunk: AIMessageChunk {
+ lc_serializable: true,
+ lc_kwargs: [Object],
+ lc_namespace: [Array],
+ content: 'r',
+ name: undefined,
+ additional_kwargs: {},
+ response_metadata: {}
+ }
+ }
+ }
+ {
+ event: 'on_chain_stream',
+ name: '/pirate-speak',
+ run_id: undefined,
+ tags: [],
+ metadata: {},
+ data: {
+ chunk: AIMessageChunk {
+ lc_serializable: true,
+ lc_kwargs: [Object],
+ lc_namespace: [Array],
+ content: 'r',
+ name: undefined,
+ additional_kwargs: {},
+ response_metadata: {}
+ }
+ }
+ }
+ {
+ event: 'on_chat_model_stream',
+ name: 'ChatOpenAI',
+ run_id: undefined,
+ tags: [ 'seq:step:2' ],
+ metadata: {},
+ data: {
+ chunk: AIMessageChunk {
+ lc_serializable: true,
+ lc_kwargs: [Object],
+ lc_namespace: [Array],
+ content: ' mate',
+ name: undefined,
+ additional_kwargs: {},
+ response_metadata: {}
+ }
+ }
+ }
+ ...
+ {
+ event: 'on_chat_model_end',
+ name: 'ChatOpenAI',
+ run_id: undefined,
+ tags: [ 'seq:step:2' ],
+ metadata: {},
+ data: {
+ input: ChatPromptValue {
+ lc_serializable: true,
+ lc_kwargs: [Object],
+ lc_namespace: [Array],
+ messages: [Array]
+ },
+ output: { generations: [Array], llm_output: null, run: null }
+ }
+ }
+ {
+ event: 'on_chain_end',
+ name: '/pirate-speak',
+ run_id: undefined,
+ tags: [],
+ metadata: {},
+ data: {
+ output: AIMessageChunk {
+ lc_serializable: true,
+ lc_kwargs: [Object],
+ lc_namespace: [Array],
+ content: "Arrr matey, why be ye holdin' back on me? Speak up, what be ye wantin' to know?",
+ name: undefined,
+ additional_kwargs: {},
+ response_metadata: {}
+ }
+ }
+ }
+*/
diff --git a/examples/src/ecosystem/langsmith_stream_log.ts b/examples/src/ecosystem/langsmith_stream_log.ts
index 9d3c08524372..486afdc558e0 100644
--- a/examples/src/ecosystem/langsmith_stream_log.ts
+++ b/examples/src/ecosystem/langsmith_stream_log.ts
@@ -9,7 +9,7 @@ const logStream = await remoteChain.streamLog(
{
question: "What is a document loader?",
},
- // LangChain runnable config properties
+ // LangChain runnable config properties, if supported by the chain
{
configurable: {
llm: "openai_gpt_3_5_turbo",
diff --git a/examples/src/embeddings/premai.ts b/examples/src/embeddings/premai.ts
new file mode 100644
index 000000000000..922a40a094f4
--- /dev/null
+++ b/examples/src/embeddings/premai.ts
@@ -0,0 +1,14 @@
+import { PremEmbeddings } from "@langchain/community/embeddings/premai";
+
+const embeddings = new PremEmbeddings({
+ // In Node.js defaults to process.env.PREM_API_KEY
+ apiKey: "YOUR-API-KEY",
+ // In Node.js defaults to process.env.PREM_PROJECT_ID
+ project_id: "YOUR-PROJECT_ID",
+ model: "@cf/baai/bge-small-en-v1.5", // The model to generate the embeddings
+});
+
+const res = await embeddings.embedQuery(
+ "What would be a good company name a company that makes colorful socks?"
+);
+console.log({ res });
diff --git a/examples/src/models/chat/integration_premai.ts b/examples/src/models/chat/integration_premai.ts
new file mode 100644
index 000000000000..fc57d7a3a6f6
--- /dev/null
+++ b/examples/src/models/chat/integration_premai.ts
@@ -0,0 +1,11 @@
+import { ChatPrem } from "@langchain/community/chat_models/premai";
+import { HumanMessage } from "@langchain/core/messages";
+
+const model = new ChatPrem({
+ // In Node.js defaults to process.env.PREM_API_KEY
+ apiKey: "YOUR-API-KEY",
+ // In Node.js defaults to process.env.PREM_PROJECT_ID
+ project_id: "YOUR-PROJECT_ID",
+});
+
+console.log(await model.invoke([new HumanMessage("Hello there!")]));
diff --git a/langchain-core/package.json b/langchain-core/package.json
index 8efd2384c968..5beec6355f1a 100644
--- a/langchain-core/package.json
+++ b/langchain-core/package.json
@@ -1,6 +1,6 @@
{
"name": "@langchain/core",
- "version": "0.1.49",
+ "version": "0.1.50",
"description": "Core LangChain.js abstractions and schemas",
"type": "module",
"engines": {
diff --git a/langchain-core/src/runnables/remote.ts b/langchain-core/src/runnables/remote.ts
index e5283e98d63e..c4ff3e324761 100644
--- a/langchain-core/src/runnables/remote.ts
+++ b/langchain-core/src/runnables/remote.ts
@@ -7,6 +7,7 @@ import {
LogStreamCallbackHandler,
RunLogPatch,
type LogStreamCallbackHandlerInput,
+ type StreamEvent,
} from "../tracers/log_stream.js";
import {
AIMessage,
@@ -24,12 +25,7 @@ import {
isBaseMessage,
} from "../messages/index.js";
import { GenerationChunk, ChatGenerationChunk, RUN_KEY } from "../outputs.js";
-import {
- getBytes,
- getLines,
- getMessages,
- convertEventStreamToIterableReadableDataStream,
-} from "../utils/event_source_parse.js";
+import { convertEventStreamToIterableReadableDataStream } from "../utils/event_source_parse.js";
import { IterableReadableStream } from "../utils/stream.js";
type RemoteRunnableOptions = {
@@ -305,6 +301,9 @@ export class RemoteRunnable<
config: removeCallbacks(config),
kwargs: kwargs ?? {},
});
+ if (!response.ok) {
+ throw new Error(`${response.status} Error: ${await response.text()}`);
+ }
return revive((await response.json()).output) as RunOutput;
}
@@ -345,6 +344,9 @@ export class RemoteRunnable<
.map((config) => ({ ...config, ...batchOptions })),
kwargs,
});
+ if (!response.ok) {
+ throw new Error(`${response.status} Error: ${await response.text()}`);
+ }
const body = await response.json();
if (!body.output) throw new Error("Invalid response from remote runnable");
@@ -416,23 +418,13 @@ export class RemoteRunnable<
"Could not begin remote stream. Please check the given URL and try again."
);
}
- const stream = new ReadableStream({
- async start(controller) {
- const enqueueLine = getMessages((msg) => {
- if (msg.data) controller.enqueue(deserialize(msg.data));
- });
- const onLine = (
- line: Uint8Array,
- fieldLength: number,
- flush?: boolean
- ) => {
- enqueueLine(line, fieldLength, flush);
- if (flush) controller.close();
- };
- await getBytes(body, getLines(onLine));
- },
- });
- return IterableReadableStream.fromReadableStream(stream);
+ const runnableStream = convertEventStreamToIterableReadableDataStream(body);
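+    // Each server-sent event carries a serialized output chunk; deserialize it before yielding to the caller.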
+    async function* wrapper(): AsyncGenerator<RunOutput> {
+ for await (const chunk of runnableStream) {
+ yield deserialize(chunk);
+ }
+ }
+ return IterableReadableStream.fromAsyncGenerator(wrapper());
}
async *streamLog(
@@ -477,7 +469,10 @@ export class RemoteRunnable<
...camelCaseStreamOptions,
diff: false,
});
- const { body } = response;
+ const { body, ok } = response;
+ if (!ok) {
+ throw new Error(`${response.status} Error: ${await response.text()}`);
+ }
if (!body) {
throw new Error(
"Could not begin remote stream log. Please check the given URL and try again."
@@ -489,4 +484,60 @@ export class RemoteRunnable<
yield new RunLogPatch({ ops: chunk.ops });
}
}
+
+ async *streamEvents(
+ input: RunInput,
+    options: Partial<CallOptions> & { version: "v1" },
+    streamOptions?: Omit<LogStreamCallbackHandlerInput, "autoClose">
+  ): AsyncGenerator<StreamEvent> {
+ if (options?.version !== "v1") {
+ throw new Error(
+ `Only version "v1" of the events schema is currently supported.`
+ );
+ }
+ const [config, kwargs] =
+ this._separateRunnableConfigFromCallOptions(options);
+ // The type is in camelCase but the API only accepts snake_case.
+ const camelCaseStreamOptions = {
+ include_names: streamOptions?.includeNames,
+ include_types: streamOptions?.includeTypes,
+ include_tags: streamOptions?.includeTags,
+ exclude_names: streamOptions?.excludeNames,
+ exclude_types: streamOptions?.excludeTypes,
+ exclude_tags: streamOptions?.excludeTags,
+ };
+ const response = await this.post<{
+ input: RunInput;
+ config?: RunnableConfig;
+      kwargs?: Omit<Partial<CallOptions>, keyof RunnableConfig>;
+ diff: false;
+ }>("/stream_events", {
+ input,
+ config: removeCallbacks(config),
+ kwargs,
+ ...camelCaseStreamOptions,
+ diff: false,
+ });
+ const { body, ok } = response;
+ if (!ok) {
+ throw new Error(`${response.status} Error: ${await response.text()}`);
+ }
+ if (!body) {
+ throw new Error(
+ "Could not begin remote stream events. Please check the given URL and try again."
+ );
+ }
+ const runnableStream = convertEventStreamToIterableReadableDataStream(body);
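+    // Each event arrives as a JSON string; revive() restores LangChain objects (e.g. message chunks) before yielding.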
+ for await (const log of runnableStream) {
+ const chunk = revive(JSON.parse(log));
+ yield {
+ event: chunk.event,
+ name: chunk.name,
+ run_id: chunk.id,
+ tags: chunk.tags,
+ metadata: chunk.metadata,
+ data: chunk.data,
+ };
+ }
+ }
}
diff --git a/langchain-core/src/runnables/tests/runnable_remote.int.test.ts b/langchain-core/src/runnables/tests/runnable_remote.int.test.ts
index dfdcfd024b77..2ced65c604e2 100644
--- a/langchain-core/src/runnables/tests/runnable_remote.int.test.ts
+++ b/langchain-core/src/runnables/tests/runnable_remote.int.test.ts
@@ -2,6 +2,58 @@ import { HumanMessage } from "../../messages/index.js";
import { applyPatch } from "../../utils/json_patch.js";
import { RemoteRunnable } from "../remote.js";
+test("invoke hosted langserve", async () => {
+ const remote = new RemoteRunnable({
+ url: `https://chat-langchain-backend.langchain.dev/chat`,
+ });
+ const result = await remote.invoke({
+ question: "What is a document loader?",
+ });
+ console.log(result);
+});
+
+test("invoke hosted langserve error handling", async () => {
+ const remote = new RemoteRunnable({
+ url: `https://chat-langchain-backend.langchain.dev/nonexistent`,
+ });
+ await expect(async () => {
+ await remote.invoke({
+ question: "What is a document loader?",
+ });
+ }).rejects.toThrowError();
+});
+
+test("stream hosted langserve", async () => {
+ const remote = new RemoteRunnable({
+ url: `https://chat-langchain-backend.langchain.dev/chat`,
+ });
+ const result = await remote.stream({
+ question: "What is a document loader?",
+ });
+ let totalByteSize = 0;
+ for await (const chunk of result) {
+ console.log(chunk);
+ const jsonString = JSON.stringify(chunk);
+ const byteSize = Buffer.byteLength(jsonString, "utf-8");
+ totalByteSize += byteSize;
+ }
+ console.log("totalByteSize", totalByteSize);
+});
+
+test("stream error handling hosted langserve", async () => {
+ const remote = new RemoteRunnable({
+ url: `https://chat-langchain-backend.langchain.dev/nonexistent`,
+ });
+ await expect(async () => {
+ const result = await remote.stream({
+ question: "What is a document loader?",
+ });
+ for await (const chunk of result) {
+ console.log(chunk);
+ }
+ }).rejects.toThrowError();
+});
+
test("streamLog hosted langserve", async () => {
const remote = new RemoteRunnable({
url: `https://chat-langchain-backend.langchain.dev/chat`,
@@ -22,6 +74,21 @@ test("streamLog hosted langserve", async () => {
console.log("totalByteSize", totalByteSize);
});
+test("streamLog error handling hosted langserve", async () => {
+ const remote = new RemoteRunnable({
+ url: `https://chat-langchain-backend.langchain.dev/nonexistent`,
+ });
+ const result = await remote.streamLog({
+ question: "What is a document loader?",
+ });
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
+ await expect(async () => {
+ for await (const chunk of result) {
+ console.log(chunk);
+ }
+ }).rejects.toThrowError();
+});
+
test("streamLog hosted langserve with concat syntax", async () => {
const remote = new RemoteRunnable({
url: `https://chat-langchain-backend.langchain.dev/chat`,
@@ -46,6 +113,30 @@ test("streamLog hosted langserve with concat syntax", async () => {
console.log("totalByteSize", totalByteSize);
});
+test.skip("stream events hosted langserve with concat syntax", async () => {
+ const remote = new RemoteRunnable({
+ url: `https://privateurl.com/pirate-speak/`,
+ });
+ const result = await remote.streamEvents(
+ {
+ input: "What is a document loader?",
+ chat_history: [new HumanMessage("What is a document loader?")],
+ },
+ { version: "v1" }
+ );
+ let totalByteSize = 0;
+ const state = [];
+ for await (const chunk of result) {
+ console.log(chunk);
+ state.push(chunk);
+ const jsonString = JSON.stringify(chunk);
+ const byteSize = Buffer.byteLength(jsonString, "utf-8");
+ totalByteSize += byteSize;
+ }
+ // console.log("final state", state);
+ console.log("totalByteSize", totalByteSize);
+});
+
test.skip("streamLog with raw messages", async () => {
const chain = new RemoteRunnable({
url: "https://aimor-deployment-bf1e4ebc87365334b3b8a6b175fb4151-ffoprvkqsa-uc.a.run.app/",
diff --git a/langchain-core/src/utils/event_source_parse.ts b/langchain-core/src/utils/event_source_parse.ts
index 33a5026a7a0d..953e23fb053e 100644
--- a/langchain-core/src/utils/event_source_parse.ts
+++ b/langchain-core/src/utils/event_source_parse.ts
@@ -230,12 +230,19 @@ function newMessage(): EventSourceMessage {
}
export function convertEventStreamToIterableReadableDataStream(
- stream: ReadableStream
+ stream: ReadableStream,
+ onMetadataEvent?: (e: unknown) => unknown
) {
const dataStream = new ReadableStream({
async start(controller) {
const enqueueLine = getMessages((msg) => {
- if (msg.data) controller.enqueue(msg.data);
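+      // Surface server-sent "error" events as thrown errors and pass "metadata" events to the optional callback.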
+ if (msg.event === "error") {
+ throw new Error(msg.data ?? "Unspecified event streaming error.");
+ } else if (msg.event === "metadata") {
+ onMetadataEvent?.(msg);
+ } else {
+ if (msg.data) controller.enqueue(msg.data);
+ }
});
const onLine = (
line: Uint8Array,
diff --git a/langchain/package.json b/langchain/package.json
index 15459cbe27a0..2940ea324a6e 100644
--- a/langchain/package.json
+++ b/langchain/package.json
@@ -1,6 +1,6 @@
{
"name": "langchain",
- "version": "0.1.28",
+ "version": "0.1.29",
"description": "Typescript bindings for langchain",
"type": "module",
"engines": {
diff --git a/langchain/src/agents/agent.ts b/langchain/src/agents/agent.ts
index 0d4e50e4f085..f040c69aa126 100644
--- a/langchain/src/agents/agent.ts
+++ b/langchain/src/agents/agent.ts
@@ -13,6 +13,8 @@ import {
Runnable,
patchConfig,
type RunnableConfig,
+ RunnableSequence,
+ RunnableLike,
} from "@langchain/core/runnables";
import { LLMChain } from "../chains/llm_chain.js";
import type {
@@ -167,6 +169,43 @@ export function isRunnableAgent(x: BaseAgent) {
);
}
+// TODO: Remove in the future. Only for backwards compatibility.
+// Allows for the creation of runnables with properties that will
+// be passed to the agent executor constructor.
+export class AgentRunnableSequence<
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
+ RunInput = any,
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
+ RunOutput = any
+> extends RunnableSequence<RunInput, RunOutput> {
+ streamRunnable?: boolean;
+
+ singleAction: boolean;
+
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
+  static fromRunnables<RunInput = any, RunOutput = any>(
+ [first, ...runnables]: [
+      RunnableLike<RunInput>,
+ ...RunnableLike[],
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
+      RunnableLike<any, RunOutput>
+ ],
+ config: { singleAction: boolean; streamRunnable?: boolean; name?: string }
+  ): AgentRunnableSequence<RunInput, Exclude<RunOutput, Error>> {
+ const sequence = RunnableSequence.from(
+ [first, ...runnables],
+ config.name
+    ) as AgentRunnableSequence<RunInput, Exclude<RunOutput, Error>>;
+ sequence.singleAction = config.singleAction;
+ sequence.streamRunnable = config.streamRunnable;
+ return sequence;
+ }
+
+ static isAgentRunnableSequence(x: Runnable): x is AgentRunnableSequence {
+ return typeof (x as AgentRunnableSequence).singleAction === "boolean";
+ }
+}
+
/**
* Class representing a single-action agent powered by runnables.
* Extends the BaseSingleActionAgent class and provides methods for
@@ -202,7 +241,8 @@ export class RunnableSingleActionAgent extends BaseSingleActionAgent {
constructor(fields: RunnableSingleActionAgentInput) {
super(fields);
this.runnable = fields.runnable;
- this.defaultRunName = fields.defaultRunName ?? this.defaultRunName;
+ this.defaultRunName =
+ fields.defaultRunName ?? this.runnable.name ?? this.defaultRunName;
this.streamRunnable = fields.streamRunnable ?? this.streamRunnable;
}
diff --git a/langchain/src/agents/executor.ts b/langchain/src/agents/executor.ts
index 226ea1689194..ef26d00f9a6a 100644
--- a/langchain/src/agents/executor.ts
+++ b/langchain/src/agents/executor.ts
@@ -21,9 +21,11 @@ import { Serializable } from "@langchain/core/load/serializable";
import { SerializedLLMChain } from "../chains/serde.js";
import { StoppingMethod } from "./types.js";
import {
+ AgentRunnableSequence,
BaseMultiActionAgent,
BaseSingleActionAgent,
RunnableMultiActionAgent,
+ RunnableSingleActionAgent,
isRunnableAgent,
} from "./agent.js";
import { BaseChain, ChainInputs } from "../chains/base.js";
@@ -394,7 +396,21 @@ export class AgentExecutor extends BaseChain {
let agent: BaseSingleActionAgent | BaseMultiActionAgent;
let returnOnlyOutputs = true;
if (Runnable.isRunnable(input.agent)) {
- agent = new RunnableMultiActionAgent({ runnable: input.agent });
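+      // Agents built by creation helpers (e.g. createOpenAIToolsAgent) are AgentRunnableSequences carrying singleAction and streamRunnable flags.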
+ if (AgentRunnableSequence.isAgentRunnableSequence(input.agent)) {
+ if (input.agent.singleAction) {
+ agent = new RunnableSingleActionAgent({
+ runnable: input.agent,
+ streamRunnable: input.agent.streamRunnable,
+ });
+ } else {
+ agent = new RunnableMultiActionAgent({
+ runnable: input.agent,
+ streamRunnable: input.agent.streamRunnable,
+ });
+ }
+ } else {
+ agent = new RunnableMultiActionAgent({ runnable: input.agent });
+ }
// TODO: Update BaseChain implementation on breaking change
returnOnlyOutputs = false;
} else {
diff --git a/langchain/src/agents/openai_functions/index.ts b/langchain/src/agents/openai_functions/index.ts
index 65d8bd6122b3..f637e313ce71 100644
--- a/langchain/src/agents/openai_functions/index.ts
+++ b/langchain/src/agents/openai_functions/index.ts
@@ -5,11 +5,7 @@ import type {
} from "@langchain/core/language_models/base";
import type { StructuredToolInterface } from "@langchain/core/tools";
import type { BaseChatModel } from "@langchain/core/language_models/chat_models";
-import {
- Runnable,
- RunnablePassthrough,
- RunnableSequence,
-} from "@langchain/core/runnables";
+import { Runnable, RunnablePassthrough } from "@langchain/core/runnables";
import { ChatOpenAI, ChatOpenAICallOptions } from "@langchain/openai";
import type {
AgentAction,
@@ -33,7 +29,7 @@ import {
BasePromptTemplate,
} from "@langchain/core/prompts";
import { CallbackManager } from "@langchain/core/callbacks/manager";
-import { Agent, AgentArgs, RunnableSingleActionAgent } from "../agent.js";
+import { Agent, AgentArgs, AgentRunnableSequence } from "../agent.js";
import { AgentInput } from "../types.js";
import { PREFIX } from "./prompt.js";
import { LLMChain } from "../../chains/llm_chain.js";
@@ -351,18 +347,21 @@ export async function createOpenAIFunctionsAgent({
const llmWithTools = llm.bind({
functions: tools.map(convertToOpenAIFunction),
});
- const agent = RunnableSequence.from([
- RunnablePassthrough.assign({
- agent_scratchpad: (input: { steps: AgentStep[] }) =>
- formatToOpenAIFunctionMessages(input.steps),
- }),
- prompt,
- llmWithTools,
- new OpenAIFunctionsAgentOutputParser(),
- ]);
- return new RunnableSingleActionAgent({
- runnable: agent,
- defaultRunName: "OpenAIFunctionsAgent",
- streamRunnable,
- });
+ const agent = AgentRunnableSequence.fromRunnables(
+ [
+ RunnablePassthrough.assign({
+ agent_scratchpad: (input: { steps: AgentStep[] }) =>
+ formatToOpenAIFunctionMessages(input.steps),
+ }),
+ prompt,
+ llmWithTools,
+ new OpenAIFunctionsAgentOutputParser(),
+ ],
+ {
+ name: "OpenAIFunctionsAgent",
+ streamRunnable,
+ singleAction: true,
+ }
+ );
+ return agent;
}
diff --git a/langchain/src/agents/openai_tools/index.ts b/langchain/src/agents/openai_tools/index.ts
index cad545a27846..88ad1ca0391f 100644
--- a/langchain/src/agents/openai_tools/index.ts
+++ b/langchain/src/agents/openai_tools/index.ts
@@ -4,10 +4,7 @@ import type {
BaseChatModelCallOptions,
} from "@langchain/core/language_models/chat_models";
import { ChatPromptTemplate } from "@langchain/core/prompts";
-import {
- RunnablePassthrough,
- RunnableSequence,
-} from "@langchain/core/runnables";
+import { RunnablePassthrough } from "@langchain/core/runnables";
import { OpenAIClient } from "@langchain/openai";
import { convertToOpenAITool } from "@langchain/core/utils/function_calling";
import { formatToOpenAIToolMessages } from "../format_scratchpad/openai_tools.js";
@@ -15,7 +12,7 @@ import {
OpenAIToolsAgentOutputParser,
type ToolsAgentStep,
} from "../openai/output_parser.js";
-import { RunnableMultiActionAgent } from "../agent.js";
+import { AgentRunnableSequence } from "../agent.js";
export { OpenAIToolsAgentOutputParser, type ToolsAgentStep };
@@ -116,18 +113,21 @@ export async function createOpenAIToolsAgent({
);
}
const modelWithTools = llm.bind({ tools: tools.map(convertToOpenAITool) });
- const agent = RunnableSequence.from([
- RunnablePassthrough.assign({
- agent_scratchpad: (input: { steps: ToolsAgentStep[] }) =>
- formatToOpenAIToolMessages(input.steps),
- }),
- prompt,
- modelWithTools,
- new OpenAIToolsAgentOutputParser(),
- ]);
- return new RunnableMultiActionAgent({
- runnable: agent,
- defaultRunName: "OpenAIToolsAgent",
- streamRunnable,
- });
+ const agent = AgentRunnableSequence.fromRunnables(
+ [
+ RunnablePassthrough.assign({
+ agent_scratchpad: (input: { steps: ToolsAgentStep[] }) =>
+ formatToOpenAIToolMessages(input.steps),
+ }),
+ prompt,
+ modelWithTools,
+ new OpenAIToolsAgentOutputParser(),
+ ],
+ {
+ name: "OpenAIToolsAgent",
+ streamRunnable,
+ singleAction: false,
+ }
+ );
+ return agent;
}
diff --git a/langchain/src/agents/react/index.ts b/langchain/src/agents/react/index.ts
index 815a81919773..8e13066a3c54 100644
--- a/langchain/src/agents/react/index.ts
+++ b/langchain/src/agents/react/index.ts
@@ -4,15 +4,12 @@ import type {
BaseLanguageModel,
BaseLanguageModelInterface,
} from "@langchain/core/language_models/base";
-import {
- RunnablePassthrough,
- RunnableSequence,
-} from "@langchain/core/runnables";
+import { RunnablePassthrough } from "@langchain/core/runnables";
import { AgentStep } from "@langchain/core/agents";
import { renderTextDescription } from "../../tools/render.js";
import { formatLogToString } from "../format_scratchpad/log.js";
import { ReActSingleInputOutputParser } from "./output_parser.js";
-import { RunnableSingleActionAgent } from "../agent.js";
+import { AgentRunnableSequence } from "../agent.js";
/**
* Params used by the createXmlAgent function.
@@ -102,20 +99,23 @@ export async function createReactAgent({
const llmWithStop = (llm as BaseLanguageModel).bind({
stop: ["\nObservation:"],
});
- const agent = RunnableSequence.from([
- RunnablePassthrough.assign({
- agent_scratchpad: (input: { steps: AgentStep[] }) =>
- formatLogToString(input.steps),
- }),
- partialedPrompt,
- llmWithStop,
- new ReActSingleInputOutputParser({
- toolNames,
- }),
- ]);
- return new RunnableSingleActionAgent({
- runnable: agent,
- defaultRunName: "ReactAgent",
- streamRunnable,
- });
+ const agent = AgentRunnableSequence.fromRunnables(
+ [
+ RunnablePassthrough.assign({
+ agent_scratchpad: (input: { steps: AgentStep[] }) =>
+ formatLogToString(input.steps),
+ }),
+ partialedPrompt,
+ llmWithStop,
+ new ReActSingleInputOutputParser({
+ toolNames,
+ }),
+ ],
+ {
+ name: "ReactAgent",
+ streamRunnable,
+ singleAction: true,
+ }
+ );
+ return agent;
}
diff --git a/langchain/src/agents/structured_chat/index.ts b/langchain/src/agents/structured_chat/index.ts
index 4c63324417ce..fc284a9ccee2 100644
--- a/langchain/src/agents/structured_chat/index.ts
+++ b/langchain/src/agents/structured_chat/index.ts
@@ -4,10 +4,7 @@ import type {
BaseLanguageModel,
BaseLanguageModelInterface,
} from "@langchain/core/language_models/base";
-import {
- RunnablePassthrough,
- RunnableSequence,
-} from "@langchain/core/runnables";
+import { RunnablePassthrough } from "@langchain/core/runnables";
import type { BasePromptTemplate } from "@langchain/core/prompts";
import {
BaseMessagePromptTemplate,
@@ -22,8 +19,8 @@ import { Optional } from "../../types/type-utils.js";
import {
Agent,
AgentArgs,
+ AgentRunnableSequence,
OutputParserArgs,
- RunnableSingleActionAgent,
} from "../agent.js";
import { AgentInput } from "../types.js";
import { StructuredChatOutputParserWithRetries } from "./outputParser.js";
@@ -336,20 +333,23 @@ export async function createStructuredChatAgent({
const llmWithStop = (llm as BaseLanguageModel).bind({
stop: ["Observation"],
});
- const agent = RunnableSequence.from([
- RunnablePassthrough.assign({
- agent_scratchpad: (input: { steps: AgentStep[] }) =>
- formatLogToString(input.steps),
- }),
- partialedPrompt,
- llmWithStop,
- StructuredChatOutputParserWithRetries.fromLLM(llm, {
- toolNames,
- }),
- ]);
- return new RunnableSingleActionAgent({
- runnable: agent,
- defaultRunName: "StructuredChatAgent",
- streamRunnable,
- });
+ const agent = AgentRunnableSequence.fromRunnables(
+ [
+ RunnablePassthrough.assign({
+ agent_scratchpad: (input: { steps: AgentStep[] }) =>
+ formatLogToString(input.steps),
+ }),
+ partialedPrompt,
+ llmWithStop,
+ StructuredChatOutputParserWithRetries.fromLLM(llm, {
+ toolNames,
+ }),
+ ],
+ {
+ name: "StructuredChatAgent",
+ streamRunnable,
+ singleAction: true,
+ }
+ );
+ return agent;
}
diff --git a/langchain/src/agents/tests/create_xml_agent.int.test.ts b/langchain/src/agents/tests/create_xml_agent.int.test.ts
index 7a62e5a04cab..8f724440a8a7 100644
--- a/langchain/src/agents/tests/create_xml_agent.int.test.ts
+++ b/langchain/src/agents/tests/create_xml_agent.int.test.ts
@@ -10,7 +10,7 @@ const tools = [new TavilySearchResults({ maxResults: 1 })];
test("createXmlAgent works", async () => {
const prompt = await pull("hwchase17/xml-agent-convo");
const llm = new ChatAnthropic({
- modelName: "claude-3-opus-20240229",
+ modelName: "claude-2",
temperature: 0,
});
const agent = await createXmlAgent({
diff --git a/langchain/src/agents/xml/index.ts b/langchain/src/agents/xml/index.ts
index 9f209822a3e9..27a39fa7d4c9 100644
--- a/langchain/src/agents/xml/index.ts
+++ b/langchain/src/agents/xml/index.ts
@@ -3,10 +3,7 @@ import type {
BaseLanguageModelInterface,
} from "@langchain/core/language_models/base";
import type { ToolInterface } from "@langchain/core/tools";
-import {
- RunnablePassthrough,
- RunnableSequence,
-} from "@langchain/core/runnables";
+import { RunnablePassthrough } from "@langchain/core/runnables";
import type { BasePromptTemplate } from "@langchain/core/prompts";
import { AgentStep, AgentAction, AgentFinish } from "@langchain/core/agents";
import { ChainValues } from "@langchain/core/utils/types";
@@ -19,8 +16,8 @@ import { CallbackManager } from "@langchain/core/callbacks/manager";
import { LLMChain } from "../../chains/llm_chain.js";
import {
AgentArgs,
+ AgentRunnableSequence,
BaseSingleActionAgent,
- RunnableSingleActionAgent,
} from "../agent.js";
import { AGENT_INSTRUCTIONS } from "./prompt.js";
import { XMLAgentOutputParser } from "./output_parser.js";
@@ -223,18 +220,21 @@ export async function createXmlAgent({
const llmWithStop = (llm as BaseLanguageModel).bind({
stop: ["", ""],
});
- const agent = RunnableSequence.from([
- RunnablePassthrough.assign({
- agent_scratchpad: (input: { steps: AgentStep[] }) =>
- formatXml(input.steps),
- }),
- partialedPrompt,
- llmWithStop,
- new XMLAgentOutputParser(),
- ]);
- return new RunnableSingleActionAgent({
- runnable: agent,
- defaultRunName: "XMLAgent",
- streamRunnable,
- });
+ const agent = AgentRunnableSequence.fromRunnables(
+ [
+ RunnablePassthrough.assign({
+ agent_scratchpad: (input: { steps: AgentStep[] }) =>
+ formatXml(input.steps),
+ }),
+ partialedPrompt,
+ llmWithStop,
+ new XMLAgentOutputParser(),
+ ],
+ {
+ name: "XMLAgent",
+ streamRunnable,
+ singleAction: true,
+ }
+ );
+ return agent;
}
diff --git a/langchain/src/tests/text_splitter.test.ts b/langchain/src/tests/text_splitter.test.ts
index a15502f1368b..104efd090320 100644
--- a/langchain/src/tests/text_splitter.test.ts
+++ b/langchain/src/tests/text_splitter.test.ts
@@ -337,6 +337,19 @@ Bye!\n\n-H.`;
});
});
+test("Separator length is considered correctly for chunk size", async () => {
+ const text = "aa ab ac ba bb";
+ const splitter = new RecursiveCharacterTextSplitter({
+ keepSeparator: false,
+ chunkSize: 7,
+ chunkOverlap: 3,
+ });
+ const output = await splitter.splitText(text);
+ const expectedOutput = ["aa ab", "ab ac", "ac ba", "ba bb"];
+
+ expect(output).toEqual(expectedOutput);
+});
+
test("Token text splitter", async () => {
const text = "foo bar baz a a";
const splitter = new TokenTextSplitter({
diff --git a/langchain/src/text_splitter.ts b/langchain/src/text_splitter.ts
index 5b69e6716e12..095ea3e796ca 100644
--- a/langchain/src/text_splitter.ts
+++ b/langchain/src/text_splitter.ts
@@ -188,7 +188,7 @@ export abstract class TextSplitter
for (const d of splits) {
const _len = await this.lengthFunction(d);
if (
- total + _len + (currentDoc.length > 0 ? separator.length : 0) >
+ total + _len + currentDoc.length * separator.length >
this.chunkSize
) {
if (total > this.chunkSize) {
@@ -207,7 +207,9 @@ which is longer than the specified ${this.chunkSize}`
// - or if we still have any chunks and the length is long
while (
total > this.chunkOverlap ||
- (total + _len > this.chunkSize && total > 0)
+ (total + _len + currentDoc.length * separator.length >
+ this.chunkSize &&
+ total > 0)
) {
total -= await this.lengthFunction(currentDoc[0]);
currentDoc.shift();
diff --git a/libs/langchain-community/.gitignore b/libs/langchain-community/.gitignore
index c335091b4018..a95b1058a7a4 100644
--- a/libs/langchain-community/.gitignore
+++ b/libs/langchain-community/.gitignore
@@ -162,6 +162,10 @@ embeddings/ollama.cjs
embeddings/ollama.js
embeddings/ollama.d.ts
embeddings/ollama.d.cts
+embeddings/premai.cjs
+embeddings/premai.js
+embeddings/premai.d.ts
+embeddings/premai.d.cts
embeddings/tensorflow.cjs
embeddings/tensorflow.js
embeddings/tensorflow.d.ts
@@ -498,6 +502,10 @@ chat_models/portkey.cjs
chat_models/portkey.js
chat_models/portkey.d.ts
chat_models/portkey.d.cts
+chat_models/premai.cjs
+chat_models/premai.js
+chat_models/premai.d.ts
+chat_models/premai.d.cts
chat_models/togetherai.cjs
chat_models/togetherai.js
chat_models/togetherai.d.ts
diff --git a/libs/langchain-community/langchain.config.js b/libs/langchain-community/langchain.config.js
index 4601bc2dbaf3..e381b6626255 100644
--- a/libs/langchain-community/langchain.config.js
+++ b/libs/langchain-community/langchain.config.js
@@ -69,6 +69,7 @@ export const config = {
"embeddings/llama_cpp": "embeddings/llama_cpp",
"embeddings/minimax": "embeddings/minimax",
"embeddings/ollama": "embeddings/ollama",
+ "embeddings/premai": "embeddings/premai",
"embeddings/tensorflow": "embeddings/tensorflow",
"embeddings/togetherai": "embeddings/togetherai",
"embeddings/voyage": "embeddings/voyage",
@@ -156,6 +157,7 @@ export const config = {
"chat_models/minimax": "chat_models/minimax",
"chat_models/ollama": "chat_models/ollama",
"chat_models/portkey": "chat_models/portkey",
+ "chat_models/premai": "chat_models/premai",
"chat_models/togetherai": "chat_models/togetherai",
"chat_models/yandex": "chat_models/yandex",
"chat_models/zhipuai": "chat_models/zhipuai",
@@ -245,6 +247,7 @@ export const config = {
"embeddings/hf_transformers",
"embeddings/llama_cpp",
"embeddings/gradient_ai",
+ "embeddings/premai",
"embeddings/zhipuai",
"llms/load",
"llms/cohere",
@@ -309,6 +312,7 @@ export const config = {
"chat_models/googlepalm",
"chat_models/llama_cpp",
"chat_models/portkey",
+ "chat_models/premai",
"chat_models/iflytek_xinghuo",
"chat_models/iflytek_xinghuo/web",
"chat_models/zhipuai",
diff --git a/libs/langchain-community/package.json b/libs/langchain-community/package.json
index 8f073f1114be..6e8a82217818 100644
--- a/libs/langchain-community/package.json
+++ b/libs/langchain-community/package.json
@@ -1,6 +1,6 @@
{
"name": "@langchain/community",
- "version": "0.0.41",
+ "version": "0.0.42",
"description": "Third-party integrations for LangChain.js",
"type": "module",
"engines": {
@@ -78,6 +78,7 @@
"@opensearch-project/opensearch": "^2.2.0",
"@pinecone-database/pinecone": "^1.1.0",
"@planetscale/database": "^1.8.0",
+ "@premai/prem-sdk": "^0.3.25",
"@qdrant/js-client-rest": "^1.2.0",
"@raycast/api": "^1.55.2",
"@rockset/client": "^0.9.1",
@@ -203,6 +204,7 @@
"@opensearch-project/opensearch": "*",
"@pinecone-database/pinecone": "*",
"@planetscale/database": "^1.8.0",
+ "@premai/prem-sdk": "^0.3.25",
"@qdrant/js-client-rest": "^1.2.0",
"@raycast/api": "^1.55.2",
"@rockset/client": "^0.9.1",
@@ -344,6 +346,9 @@
"@planetscale/database": {
"optional": true
},
+ "@premai/prem-sdk": {
+ "optional": true
+ },
"@qdrant/js-client-rest": {
"optional": true
},
@@ -910,6 +915,15 @@
"import": "./embeddings/ollama.js",
"require": "./embeddings/ollama.cjs"
},
+ "./embeddings/premai": {
+ "types": {
+ "import": "./embeddings/premai.d.ts",
+ "require": "./embeddings/premai.d.cts",
+ "default": "./embeddings/premai.d.ts"
+ },
+ "import": "./embeddings/premai.js",
+ "require": "./embeddings/premai.cjs"
+ },
"./embeddings/tensorflow": {
"types": {
"import": "./embeddings/tensorflow.d.ts",
@@ -1666,6 +1680,15 @@
"import": "./chat_models/portkey.js",
"require": "./chat_models/portkey.cjs"
},
+ "./chat_models/premai": {
+ "types": {
+ "import": "./chat_models/premai.d.ts",
+ "require": "./chat_models/premai.d.cts",
+ "default": "./chat_models/premai.d.ts"
+ },
+ "import": "./chat_models/premai.js",
+ "require": "./chat_models/premai.cjs"
+ },
"./chat_models/togetherai": {
"types": {
"import": "./chat_models/togetherai.d.ts",
@@ -2356,6 +2379,10 @@
"embeddings/ollama.js",
"embeddings/ollama.d.ts",
"embeddings/ollama.d.cts",
+ "embeddings/premai.cjs",
+ "embeddings/premai.js",
+ "embeddings/premai.d.ts",
+ "embeddings/premai.d.cts",
"embeddings/tensorflow.cjs",
"embeddings/tensorflow.js",
"embeddings/tensorflow.d.ts",
@@ -2692,6 +2719,10 @@
"chat_models/portkey.js",
"chat_models/portkey.d.ts",
"chat_models/portkey.d.cts",
+ "chat_models/premai.cjs",
+ "chat_models/premai.js",
+ "chat_models/premai.d.ts",
+ "chat_models/premai.d.cts",
"chat_models/togetherai.cjs",
"chat_models/togetherai.js",
"chat_models/togetherai.d.ts",
diff --git a/libs/langchain-community/src/chat_models/premai.ts b/libs/langchain-community/src/chat_models/premai.ts
new file mode 100644
index 000000000000..654b6bb4de9c
--- /dev/null
+++ b/libs/langchain-community/src/chat_models/premai.ts
@@ -0,0 +1,472 @@
+import {
+ AIMessage,
+ AIMessageChunk,
+ type BaseMessage,
+ ChatMessage,
+ ChatMessageChunk,
+ HumanMessageChunk,
+} from "@langchain/core/messages";
+import {
+ type BaseLanguageModelCallOptions,
+ TokenUsage,
+} from "@langchain/core/language_models/base";
+
+import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
+import {
+ type BaseChatModelParams,
+ BaseChatModel,
+} from "@langchain/core/language_models/chat_models";
+
+import Prem, {
+ ChatCompletionStreamingCompletionData,
+ CreateChatCompletionRequest,
+ CreateChatCompletionResponse,
+} from "@premai/prem-sdk";
+import { getEnvironmentVariable } from "@langchain/core/utils/env";
+import {
+ ChatGeneration,
+ ChatGenerationChunk,
+ ChatResult,
+} from "@langchain/core/outputs";
+
+import { NewTokenIndices } from "@langchain/core/callbacks/base";
+
+export type RoleEnum = "user" | "assistant";
+
+/**
+ * Input to chat model class.
+ */
+export interface ChatPremInput extends BaseChatModelParams {
+ project_id?: number | string;
+ session_id?: string;
+ messages?: {
+ role: "user" | "assistant";
+ content: string;
+ [k: string]: unknown;
+ }[];
+ model?: string;
+ system_prompt?: string;
+ frequency_penalty?: number;
+ logit_bias?: { [k: string]: unknown };
+ max_tokens?: number;
+ n?: number;
+ presence_penalty?: number;
+ response_format?: { [k: string]: unknown };
+ seed?: number;
+ stop?: string;
+ temperature?: number;
+ top_p?: number;
+ tools?: { [k: string]: unknown }[];
+ user?: string;
+ /**
+ * The Prem API key to use for requests.
+ * @default process.env.PREM_API_KEY
+ */
+ apiKey?: string;
+ streaming?: boolean;
+}
+
+export interface ChatCompletionCreateParamsNonStreaming
+ extends CreateChatCompletionRequest {
+ stream?: false;
+}
+
+export interface ChatCompletionCreateParamsStreaming
+ extends CreateChatCompletionRequest {
+ stream: true;
+}
+
+export type ChatCompletionCreateParams =
+ | ChatCompletionCreateParamsNonStreaming
+ | ChatCompletionCreateParamsStreaming;
+
+function extractGenericMessageCustomRole(message: ChatMessage) {
+ if (message.role !== "assistant" && message.role !== "user") {
+ console.warn(`Unknown message role: ${message.role}`);
+ }
+ return message.role as RoleEnum;
+}
+
+export function messageToPremRole(message: BaseMessage): RoleEnum {
+ const type = message._getType();
+ switch (type) {
+ case "ai":
+ return "assistant";
+ case "human":
+ return "user";
+ case "generic": {
+ if (!ChatMessage.isInstance(message))
+ throw new Error("Invalid generic chat message");
+ return extractGenericMessageCustomRole(message);
+ }
+ default:
+ throw new Error(`Unknown message type: ${type}`);
+ }
+}
+
+function convertMessagesToPremParams(
+ messages: BaseMessage[]
+): Array {
+ return messages.map((message) => {
+ if (typeof message.content !== "string") {
+ throw new Error("Non string message content not supported");
+ }
+ return {
+ role: messageToPremRole(message),
+ content: message.content,
+ name: message.name,
+ function_call: message.additional_kwargs.function_call,
+ };
+ });
+}
+
+function premResponseToChatMessage(
+ message: CreateChatCompletionResponse["choices"]["0"]["message"]
+): BaseMessage {
+ switch (message.role) {
+ case "assistant":
+ return new AIMessage(message.content || "");
+ default:
+ return new ChatMessage(message.content || "", message.role ?? "unknown");
+ }
+}
+
+function _convertDeltaToMessageChunk(
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
+  delta: Record<string, any>
+) {
+ const { role } = delta;
+ const content = delta.content ?? "";
+ let additional_kwargs;
+ if (delta.function_call) {
+ additional_kwargs = {
+ function_call: delta.function_call,
+ };
+ } else {
+ additional_kwargs = {};
+ }
+ if (role === "user") {
+ return new HumanMessageChunk({ content });
+ } else if (role === "assistant") {
+ return new AIMessageChunk({ content, additional_kwargs });
+ } else {
+ return new ChatMessageChunk({ content, role });
+ }
+}
+
+/**
+ * Integration with a chat model.
+ */
+export class ChatPrem<
+ CallOptions extends BaseLanguageModelCallOptions = BaseLanguageModelCallOptions
+ >
+  extends BaseChatModel<CallOptions>
+ implements ChatPremInput
+{
+ client: Prem;
+
+ apiKey?: string;
+
+ project_id: number;
+
+ session_id?: string;
+
+ messages: {
+ [k: string]: unknown;
+ role: "user" | "assistant";
+ content: string;
+ }[];
+
+ model?: string;
+
+ system_prompt?: string;
+
+ frequency_penalty?: number;
+
+ logit_bias?: { [k: string]: unknown };
+
+ max_tokens?: number;
+
+ n?: number;
+
+ presence_penalty?: number;
+
+ response_format?: { [k: string]: unknown };
+
+ seed?: number;
+
+ stop?: string;
+
+ temperature?: number;
+
+ top_p?: number;
+
+ tools?: { [k: string]: unknown }[];
+
+ user?: string;
+
+ streaming = false;
+
+ [k: string]: unknown;
+
+ // Used for tracing, replace with the same name as your class
+ static lc_name() {
+ return "ChatPrem";
+ }
+
+ lc_serializable = true;
+
+ /**
+ * Replace with any secrets this class passes to `super`.
+ * See {@link ../../langchain-cohere/src/chat_model.ts} for
+ * an example.
+ */
+ get lc_secrets(): { [key: string]: string } | undefined {
+ return {
+ apiKey: "PREM_API_KEY",
+ };
+ }
+
+ get lc_aliases(): { [key: string]: string } | undefined {
+ return {
+ apiKey: "PREM_API_KEY",
+ };
+ }
+
+ constructor(fields?: ChatPremInput) {
+ super(fields ?? {});
+ const apiKey = fields?.apiKey ?? getEnvironmentVariable("PREM_API_KEY");
+ if (!apiKey) {
+ throw new Error(
+ `Prem API key not found. Please set the PREM_API_KEY environment variable or provide the key into "apiKey"`
+ );
+ }
+
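+    // Fall back to the PREM_PROJECT_ID environment variable; -1 acts as a "not set" sentinel that fails validation below.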
+ const projectId =
+ fields?.project_id ??
+ parseInt(getEnvironmentVariable("PREM_PROJECT_ID") ?? "-1", 10);
+ if (!projectId || projectId === -1 || typeof projectId !== "number") {
+ throw new Error(
+ `Prem project ID not found. Please set the PREM_PROJECT_ID environment variable or provide the key into "project_id"`
+ );
+ }
+
+ this.client = new Prem({
+ apiKey,
+ });
+
+ this.project_id = projectId;
+ this.session_id = fields?.session_id ?? this.session_id;
+ this.messages = fields?.messages ?? this.messages;
+ this.model = fields?.model ?? this.model;
+ this.system_prompt = fields?.system_prompt ?? this.system_prompt;
+ this.frequency_penalty =
+ fields?.frequency_penalty ?? this.frequency_penalty;
+ this.logit_bias = fields?.logit_bias ?? this.logit_bias;
+ this.max_tokens = fields?.max_tokens ?? this.max_tokens;
+ this.n = fields?.n ?? this.n;
+ this.presence_penalty = fields?.presence_penalty ?? this.presence_penalty;
+ this.response_format = fields?.response_format ?? this.response_format;
+ this.seed = fields?.seed ?? this.seed;
+ this.stop = fields?.stop ?? this.stop;
+ this.temperature = fields?.temperature ?? this.temperature;
+ this.top_p = fields?.top_p ?? this.top_p;
+ this.tools = fields?.tools ?? this.tools;
+ this.user = fields?.user ?? this.user;
+ this.streaming = fields?.streaming ?? this.streaming;
+ }
+
+ // Replace
+ _llmType() {
+ return "prem";
+ }
+
+ async completionWithRetry(
+ request: ChatCompletionCreateParamsStreaming,
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
+ options?: any
+  ): Promise<AsyncIterable<ChatCompletionStreamingCompletionData>>;
+
+ async completionWithRetry(
+ request: ChatCompletionCreateParams,
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
+ options?: any
+  ): Promise<CreateChatCompletionResponse>;
+
+ async completionWithRetry(
+ request: ChatCompletionCreateParamsStreaming,
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
+ options?: any
+ ): Promise<
+    | AsyncIterable<ChatCompletionStreamingCompletionData>
+ | CreateChatCompletionResponse
+ > {
+ return this.caller.call(async () =>
+ this.client.chat.completions.create(request, options)
+ );
+ }
+
+ invocationParams(options: this["ParsedCallOptions"]) {
+ const params = super.invocationParams(options);
+ return {
+ ...params,
+ project_id: this.project_id,
+ session_id: this.session_id,
+ messages: this.messages,
+ model: this.model,
+ system_prompt: this.system_prompt,
+ frequency_penalty: this.frequency_penalty,
+ logit_bias: this.logit_bias,
+ max_tokens: this.max_tokens,
+ n: this.n,
+ presence_penalty: this.presence_penalty,
+ response_format: this.response_format,
+ seed: this.seed,
+ stop: this.stop,
+ temperature: this.temperature,
+ top_p: this.top_p,
+ tools: this.tools,
+ user: this.user,
+ streaming: this.streaming,
+ stream: this.streaming,
+ };
+ }
+
+ /**
+ * Implement to support streaming.
+ * Should yield chunks iteratively.
+ */
+ async *_streamResponseChunks(
+ messages: BaseMessage[],
+ options: this["ParsedCallOptions"],
+ runManager?: CallbackManagerForLLMRun
+  ): AsyncGenerator<ChatGenerationChunk> {
+ const params = this.invocationParams(options);
+ const messagesMapped = convertMessagesToPremParams(messages);
+
+ // All models have a built-in `this.caller` property for retries
+ const stream = await this.caller.call(async () =>
+ this.completionWithRetry(
+ {
+ ...params,
+ messages: messagesMapped,
+ stream: true,
+ },
+ params
+ )
+ );
+
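+    // Convert each streamed delta into a ChatGenerationChunk and notify the run manager of every new token.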
+ for await (const data of stream) {
+ const choice = data?.choices[0];
+ if (!choice) {
+ continue;
+ }
+ const chunk = new ChatGenerationChunk({
+ message: _convertDeltaToMessageChunk(choice.delta ?? {}),
+ text: choice.delta.content ?? "",
+ generationInfo: {
+ finishReason: choice.finish_reason,
+ },
+ });
+ yield chunk;
+ void runManager?.handleLLMNewToken(chunk.text ?? "");
+ }
+ if (options.signal?.aborted) {
+ throw new Error("AbortError");
+ }
+ }
+
+ /** @ignore */
+ _combineLLMOutput() {
+ return [];
+ }
+
+ async _generate(
+ messages: BaseMessage[],
+ options: this["ParsedCallOptions"],
+ runManager?: CallbackManagerForLLMRun
+  ): Promise<ChatResult> {
+ const tokenUsage: TokenUsage = {};
+ const params = this.invocationParams(options);
+ const messagesMapped = convertMessagesToPremParams(messages);
+
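+    // In streaming mode, aggregate the streamed chunks into final generations rather than making a single blocking request.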
+ if (params.streaming) {
+ const stream = this._streamResponseChunks(messages, options, runManager);
+      const finalChunks: Record<number, ChatGenerationChunk> = {};
+ for await (const chunk of stream) {
+ const index =
+ (chunk.generationInfo as NewTokenIndices)?.completion ?? 0;
+ if (finalChunks[index] === undefined) {
+ finalChunks[index] = chunk;
+ } else {
+ finalChunks[index] = finalChunks[index].concat(chunk);
+ }
+ }
+ const generations = Object.entries(finalChunks)
+ .sort(([aKey], [bKey]) => parseInt(aKey, 10) - parseInt(bKey, 10))
+ .map(([_, value]) => value);
+
+ return { generations, llmOutput: { estimatedTokenUsage: tokenUsage } };
+ } else {
+ const data = await this.completionWithRetry(
+ {
+ ...params,
+ stream: false,
+ messages: messagesMapped,
+ },
+ {
+ signal: options?.signal,
+ }
+ );
+
+ if ("usage" in data && data.usage) {
+ const {
+ completion_tokens: completionTokens,
+ prompt_tokens: promptTokens,
+ total_tokens: totalTokens,
+ } = data.usage as CreateChatCompletionResponse["usage"];
+
+ if (completionTokens) {
+ tokenUsage.completionTokens =
+ (tokenUsage.completionTokens ?? 0) + completionTokens;
+ }
+
+ if (promptTokens) {
+ tokenUsage.promptTokens =
+ (tokenUsage.promptTokens ?? 0) + promptTokens;
+ }
+
+ if (totalTokens) {
+ tokenUsage.totalTokens = (tokenUsage.totalTokens ?? 0) + totalTokens;
+ }
+ }
+
+ const generations: ChatGeneration[] = [];
+
+ if ("choices" in data && data.choices) {
+ for (const part of (data as unknown as CreateChatCompletionResponse)
+ .choices) {
+ const text = part.message?.content ?? "";
+ const generation: ChatGeneration = {
+ text,
+ message: premResponseToChatMessage(
+ part.message ?? { role: "assistant" }
+ ),
+ };
+ generation.generationInfo = {
+ ...(part.finish_reason
+ ? { finish_reason: part.finish_reason }
+ : {}),
+ ...(part.logprobs ? { logprobs: part.logprobs } : {}),
+ };
+ generations.push(generation);
+ }
+ }
+
+ return {
+ generations,
+ llmOutput: { tokenUsage },
+ };
+ }
+ }
+}
diff --git a/libs/langchain-community/src/chat_models/tests/chatpremai.int.test.ts b/libs/langchain-community/src/chat_models/tests/chatpremai.int.test.ts
new file mode 100644
index 000000000000..8f77952cf2ec
--- /dev/null
+++ b/libs/langchain-community/src/chat_models/tests/chatpremai.int.test.ts
@@ -0,0 +1,48 @@
+import { describe, test, expect } from "@jest/globals";
+import { ChatMessage, HumanMessage } from "@langchain/core/messages";
+import { ChatPrem } from "../premai.js";
+
+describe.skip("ChatPrem", () => {
+ test("invoke", async () => {
+ const chat = new ChatPrem();
+ const message = new HumanMessage("What color is the sky?");
+ const res = await chat.invoke([message]);
+ expect(res.content.length).toBeGreaterThan(10);
+ });
+
+ test("generate", async () => {
+ const chat = new ChatPrem();
+ const message = new HumanMessage("Hello!");
+ const res = await chat.generate([[message]]);
+ // console.log(JSON.stringify(res, null, 2));
+ expect(res.generations[0][0].text.length).toBeGreaterThan(10);
+ });
+
+ test("custom messages", async () => {
+ const chat = new ChatPrem();
+ const res = await chat.invoke([new ChatMessage("Hello!", "user")]);
+ // console.log({ res });
+ expect(res.content.length).toBeGreaterThan(10);
+ });
+
+ test("custom messages in streaming mode", async () => {
+ const chat = new ChatPrem({ streaming: true });
+ const res = await chat.invoke([new ChatMessage("Hello!", "user")]);
+ // console.log({ res });
+ expect(res.content.length).toBeGreaterThan(10);
+ });
+
+ test("streaming", async () => {
+ const chat = new ChatPrem();
+ const message = new HumanMessage("What color is the sky?");
+ const stream = await chat.stream([message]);
+ let iters = 0;
+ let finalRes = "";
+ for await (const chunk of stream) {
+ iters += 1;
+ finalRes += chunk.content;
+ }
+ console.log({ finalRes, iters });
+ expect(iters).toBeGreaterThan(1);
+ });
+});
diff --git a/libs/langchain-community/src/embeddings/premai.ts b/libs/langchain-community/src/embeddings/premai.ts
new file mode 100644
index 000000000000..07bf4e75eb13
--- /dev/null
+++ b/libs/langchain-community/src/embeddings/premai.ts
@@ -0,0 +1,121 @@
+import { getEnvironmentVariable } from "@langchain/core/utils/env";
+import { Embeddings, EmbeddingsParams } from "@langchain/core/embeddings";
+import { chunkArray } from "@langchain/core/utils/chunk_array";
+import Prem from "@premai/prem-sdk";
+
+/**
+ * Interface for PremEmbeddings parameters. Extends EmbeddingsParams and
+ * defines additional parameters specific to the PremEmbeddings class.
+ */
+export interface PremEmbeddingsParams extends EmbeddingsParams {
+ /**
+ * The Prem API key to use for requests.
+ * @default process.env.PREM_API_KEY
+ */
+ apiKey?: string;
+
+ baseUrl?: string;
+
+ /**
+ * The ID of the project to use.
+ */
+ project_id?: number | string;
+ /**
+ * The model used to generate the embeddings.
+ */
+ model: string;
+
+ encoding_format?: ("float" | "base64") & string;
+
+ batchSize?: number;
+}
+
+/**
+ * Class for generating embeddings using the Prem AI API. Extends the
+ * Embeddings class and implements PremEmbeddingsParams.
+ */
+export class PremEmbeddings extends Embeddings implements PremEmbeddingsParams {
+ client: Prem;
+
+ batchSize = 128;
+
+ apiKey?: string;
+
+ project_id: number;
+
+ model: string;
+
+ encoding_format?: ("float" | "base64") & string;
+
+ constructor(fields: PremEmbeddingsParams) {
+ super(fields);
+ const apiKey = fields?.apiKey || getEnvironmentVariable("PREM_API_KEY");
+ if (!apiKey) {
+ throw new Error(
+ `Prem API key not found. Please set the PREM_API_KEY environment variable or provide it via the "apiKey" field`
+ );
+ }
+
+ const projectId =
+ fields?.project_id ??
+ parseInt(getEnvironmentVariable("PREM_PROJECT_ID") ?? "-1", 10);
+ if (!projectId || projectId === -1 || typeof projectId !== "number") {
+ throw new Error(
+ `Prem project ID not found. Please set the PREM_PROJECT_ID environment variable or provide it via the "project_id" field`
+ );
+ }
+
+ this.client = new Prem({
+ apiKey,
+ });
+ this.project_id = projectId;
+ this.model = fields.model ?? this.model;
+ this.batchSize = fields.batchSize ?? this.batchSize;
+ this.encoding_format = fields.encoding_format ?? this.encoding_format;
+ }
+
+ /**
+ * Method to generate embeddings for an array of documents. Splits the
+ * documents into batches and makes requests to the Prem API to generate
+ * embeddings.
+ * @param texts Array of documents to generate embeddings for.
+ * @returns Promise that resolves to a 2D array of embeddings for each document.
+ */
+ async embedDocuments(texts: string[]): Promise<number[][]> {
+ const mappedTexts = texts.map((text) => text);
+
+ const batches = chunkArray(mappedTexts, this.batchSize);
+
+ const batchRequests = batches.map((batch) =>
+ this.caller.call(async () =>
+ this.client.embeddings.create({
+ input: batch,
+ model: this.model,
+ encoding_format: this.encoding_format,
+ project_id: this.project_id,
+ })
+ )
+ );
+ const batchResponses = await Promise.all(batchRequests);
+
+ const embeddings: number[][] = [];
+ for (let i = 0; i < batchResponses.length; i += 1) {
+ const batch = batches[i];
+ const { data: batchResponse } = batchResponses[i];
+ for (let j = 0; j < batch.length; j += 1) {
+ embeddings.push(batchResponse[j].embedding);
+ }
+ }
+ return embeddings;
+ }
+
+ /**
+ * Method to generate an embedding for a single document. Calls the
+ * embedDocuments method with the document as the input.
+ * @param text Document to generate an embedding for.
+ * @returns Promise that resolves to an embedding for the document.
+ */
+ async embedQuery(text: string): Promise<number[]> {
+ const data = await this.embedDocuments([text]);
+ return data[0];
+ }
+}
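
As with the chat model, a brief usage sketch may help situate the class. The entrypoint path is an assumption based on the file location, and the model name is a placeholder rather than a known Prem model id.

```typescript
import { PremEmbeddings } from "@langchain/community/embeddings/premai";

// Assumes PREM_API_KEY and PREM_PROJECT_ID are set in the environment;
// replace the placeholder model with one available in your Prem project.
const embeddings = new PremEmbeddings({ model: "your-embedding-model" });

// Single query embedding.
const queryVector = await embeddings.embedQuery("Hello world");
console.log(queryVector.length);

// Document embeddings are chunked into batches of `batchSize` (128 by default)
// before being sent to the API.
const docVectors = await embeddings.embedDocuments(["first doc", "second doc"]);
console.log(docVectors.length); // 2
```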
diff --git a/libs/langchain-community/src/embeddings/tests/premai.int.test.ts b/libs/langchain-community/src/embeddings/tests/premai.int.test.ts
new file mode 100644
index 000000000000..1c8e2f34a127
--- /dev/null
+++ b/libs/langchain-community/src/embeddings/tests/premai.int.test.ts
@@ -0,0 +1,11 @@
+import { describe, test, expect } from "@jest/globals";
+import { PremEmbeddings } from "../premai.js";
+
+describe("EmbeddingsPrem", () => {
+ test("Test embedQuery", async () => {
+ const client = new PremEmbeddings({ model: "@cf/baai/bge-small-en-v1.5" });
+ const res = await client.embedQuery("Hello world");
+ // console.log(res);
+ expect(typeof res[0]).toBe("number");
+ });
+});
diff --git a/libs/langchain-community/src/vectorstores/elasticsearch.ts b/libs/langchain-community/src/vectorstores/elasticsearch.ts
index 1060041e4da8..e4199c43b51c 100644
--- a/libs/langchain-community/src/vectorstores/elasticsearch.ts
+++ b/libs/langchain-community/src/vectorstores/elasticsearch.ts
@@ -193,7 +193,8 @@ export class ElasticVectorSearch extends VectorStore {
_index: this.indexName,
},
}));
- await this.client.bulk({ refresh: true, operations });
+ if (operations.length > 0)
+ await this.client.bulk({ refresh: true, operations });
}
/**
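
The guard above is a small robustness fix: when the surrounding method produces no bulk operations, the Elasticsearch bulk API would otherwise be called with an empty body, which the server typically rejects. A minimal sketch of the pattern, assuming a standard v8 `@elastic/elasticsearch` client:

```typescript
import { Client } from "@elastic/elasticsearch";

const client = new Client({ node: "http://localhost:9200" });

// Skip the round trip entirely when there is nothing to index or delete;
// an empty bulk body is rejected by Elasticsearch.
async function bulkIfNonEmpty(operations: Record<string, unknown>[]) {
  if (operations.length === 0) return;
  await client.bulk({ refresh: true, operations });
}
```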
diff --git a/libs/langchain-community/src/vectorstores/neo4j_vector.ts b/libs/langchain-community/src/vectorstores/neo4j_vector.ts
index 72f98f03039d..590c5180cef4 100644
--- a/libs/langchain-community/src/vectorstores/neo4j_vector.ts
+++ b/libs/langchain-community/src/vectorstores/neo4j_vector.ts
@@ -735,7 +735,11 @@ function getSearchIndexQuery(searchType: SearchType): string {
return typeToQueryMap[searchType];
}
-function removeLuceneChars(text: string): string {
+function removeLuceneChars(text: string | null) {
+ if (text === undefined || text === null) {
+ return null;
+ }
+
// Remove Lucene special characters
const specialChars = [
"+",
diff --git a/libs/langchain-mistralai/package.json b/libs/langchain-mistralai/package.json
index e697f3b70c80..66efd74d587a 100644
--- a/libs/langchain-mistralai/package.json
+++ b/libs/langchain-mistralai/package.json
@@ -1,6 +1,6 @@
{
"name": "@langchain/mistralai",
- "version": "0.0.15",
+ "version": "0.0.16",
"description": "MistralAI integration for LangChain.js",
"type": "module",
"engines": {
@@ -62,6 +62,7 @@
"eslint-plugin-prettier": "^4.2.1",
"jest": "^29.5.0",
"jest-environment-node": "^29.6.4",
+ "langchain": "workspace:^",
"prettier": "^2.8.3",
"release-it": "^15.10.1",
"rollup": "^4.5.2",
diff --git a/libs/langchain-mistralai/src/chat_models.ts b/libs/langchain-mistralai/src/chat_models.ts
index 5a210f6cdc6d..2521aa7542c3 100644
--- a/libs/langchain-mistralai/src/chat_models.ts
+++ b/libs/langchain-mistralai/src/chat_models.ts
@@ -5,6 +5,7 @@ import {
ToolChoice as MistralAIToolChoice,
ResponseFormat,
ChatCompletionResponseChunk,
+ ToolType,
} from "@mistralai/mistralai";
import {
MessageType,
@@ -16,6 +17,8 @@ import {
AIMessageChunk,
ToolMessageChunk,
ChatMessageChunk,
+ FunctionMessageChunk,
+ ToolCall,
} from "@langchain/core/messages";
import type {
BaseLanguageModelInput,
@@ -165,6 +168,10 @@ function convertMessagesToMistralMessages(
return "assistant";
case "system":
return "system";
+ case "tool":
+ return "tool";
+ case "function":
+ return "assistant";
default:
throw new Error(`Unknown message type: ${role}`);
}
@@ -183,9 +190,17 @@ function convertMessagesToMistralMessages(
);
};
+ const getTools = (toolCalls: ToolCall[] | undefined): MistralAIToolCalls[] =>
+ toolCalls?.map((toolCall) => ({
+ id: "null",
+ type: "function" as ToolType.function,
+ function: toolCall.function,
+ })) || [];
+
return messages.map((message) => ({
role: getRole(message._getType()),
content: getContent(message.content),
+ tool_calls: getTools(message.additional_kwargs.tool_calls),
}));
}
@@ -235,7 +250,7 @@ function _convertDeltaToMessageChunk(delta: {
if (delta.role) {
role = delta.role;
} else if (toolCallsWithIndex) {
- role = "tool";
+ role = "function";
}
const content = delta.content ?? "";
let additional_kwargs;
@@ -257,6 +272,11 @@ function _convertDeltaToMessageChunk(delta: {
additional_kwargs,
tool_call_id: toolCallsWithIndex?.[0].id ?? "",
});
+ } else if (role === "function") {
+ return new FunctionMessageChunk({
+ content,
+ additional_kwargs,
+ });
} else {
return new ChatMessageChunk({ content, role });
}
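
The net effect of these changes is that assistant messages which previously requested tool calls now round-trip through `convertMessagesToMistralMessages`, and streamed tool-call deltas surface as `FunctionMessageChunk`s. A sketch of the input shape the new `getTools` helper consumes (this mirrors how LangChain stores OpenAI-style tool calls in `additional_kwargs`; it is an illustration, not part of the public API):

```typescript
import { AIMessage } from "@langchain/core/messages";

// An assistant turn that asked for a tool call carries the call in
// additional_kwargs.tool_calls...
const priorToolCallMessage = new AIMessage({
  content: "",
  additional_kwargs: {
    tool_calls: [
      {
        id: "call_abc123",
        type: "function",
        function: {
          name: "get_current_weather",
          arguments: '{"location": "Paris"}',
        },
      },
    ],
  },
});

// ...and getTools forwards each entry to Mistral as
// { id: "null", type: "function", function: { name, arguments } },
// the shape expected on assistant turns that invoked tools.
```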
diff --git a/libs/langchain-mistralai/src/tests/chat_models.int.test.ts b/libs/langchain-mistralai/src/tests/chat_models.int.test.ts
index 2e55e5b8ff41..f5685f8bbb0b 100644
--- a/libs/langchain-mistralai/src/tests/chat_models.int.test.ts
+++ b/libs/langchain-mistralai/src/tests/chat_models.int.test.ts
@@ -1,6 +1,13 @@
import { test } from "@jest/globals";
-import { ChatPromptTemplate } from "@langchain/core/prompts";
-import { StructuredTool } from "@langchain/core/tools";
+import {
+ ChatPromptTemplate,
+ HumanMessagePromptTemplate,
+ MessagesPlaceholder,
+ SystemMessagePromptTemplate,
+} from "@langchain/core/prompts";
+import { AgentExecutor, createOpenAIToolsAgent } from "langchain/agents";
+import { BaseChatModel } from "langchain/chat_models/base";
+import { DynamicStructuredTool, StructuredTool } from "@langchain/core/tools";
import { z } from "zod";
import { AIMessage, BaseMessage } from "@langchain/core/messages";
import { zodToJsonSchema } from "zod-to-json-schema";
@@ -562,6 +569,54 @@ describe("withStructuredOutput", () => {
)
).toBe(true);
});
+
+ test("Model is compatible with OpenAI tools agent and Agent Executor", async () => {
+ const llm: BaseChatModel = new ChatMistralAI({
+ temperature: 0,
+ modelName: "mistral-large-latest",
+ });
+
+ const systemMessage = SystemMessagePromptTemplate.fromTemplate(
+ "You are an agent capable of retrieving current weather information."
+ );
+ const humanMessage = HumanMessagePromptTemplate.fromTemplate("{input}");
+ const agentScratchpad = new MessagesPlaceholder("agent_scratchpad");
+
+ const prompt = ChatPromptTemplate.fromMessages([
+ systemMessage,
+ humanMessage,
+ agentScratchpad,
+ ]);
+
+ const currentWeatherTool = new DynamicStructuredTool({
+ name: "get_current_weather",
+ description: "Get the current weather in a given location",
+ schema: z.object({
+ location: z
+ .string()
+ .describe("The city and state, e.g. San Francisco, CA"),
+ }),
+ func: async () => Promise.resolve("28 °C"),
+ });
+
+ const agent = await createOpenAIToolsAgent({
+ llm,
+ tools: [currentWeatherTool],
+ prompt,
+ });
+
+ const agentExecutor = new AgentExecutor({
+ agent,
+ tools: [currentWeatherTool],
+ });
+
+ const input = "What's the weather like in Paris?";
+ const { output } = await agentExecutor.invoke({ input });
+
+ console.log(output);
+ expect(output).toBeDefined();
+ expect(output).toContain("The current temperature in Paris is 28 °C");
+ });
});
describe("ChatMistralAI aborting", () => {
diff --git a/yarn.lock b/yarn.lock
index d777ee53e7e5..f402d6d6237d 100644
--- a/yarn.lock
+++ b/yarn.lock
@@ -8970,6 +8970,7 @@ __metadata:
"@opensearch-project/opensearch": ^2.2.0
"@pinecone-database/pinecone": ^1.1.0
"@planetscale/database": ^1.8.0
+ "@premai/prem-sdk": ^0.3.25
"@qdrant/js-client-rest": ^1.2.0
"@raycast/api": ^1.55.2
"@rockset/client": ^0.9.1
@@ -9099,6 +9100,7 @@ __metadata:
"@opensearch-project/opensearch": "*"
"@pinecone-database/pinecone": "*"
"@planetscale/database": ^1.8.0
+ "@premai/prem-sdk": ^0.3.25
"@qdrant/js-client-rest": ^1.2.0
"@raycast/api": ^1.55.2
"@rockset/client": ^0.9.1
@@ -9214,6 +9216,8 @@ __metadata:
optional: true
"@planetscale/database":
optional: true
+ "@premai/prem-sdk":
+ optional: true
"@qdrant/js-client-rest":
optional: true
"@raycast/api":
@@ -9598,6 +9602,7 @@ __metadata:
eslint-plugin-prettier: ^4.2.1
jest: ^29.5.0
jest-environment-node: ^29.6.4
+ langchain: "workspace:^"
prettier: ^2.8.3
release-it: ^15.10.1
rollup: ^4.5.2
@@ -10738,6 +10743,15 @@ __metadata:
languageName: node
linkType: hard
+"@premai/prem-sdk@npm:^0.3.25":
+ version: 0.3.25
+ resolution: "@premai/prem-sdk@npm:0.3.25"
+ dependencies:
+ axios: ^1.6.2
+ checksum: e52bd2a1d44df773ca76c6267ef23dbbd94af4ad4f8cd3eb123772179c62358115c4e53bbc1373772529c08bc888df1fdcb994bfd9b57b8ee8206cd4808aa77d
+ languageName: node
+ linkType: hard
+
"@prisma/client@npm:^4.11.0":
version: 4.11.0
resolution: "@prisma/client@npm:4.11.0"
@@ -16791,6 +16805,17 @@ __metadata:
languageName: node
linkType: hard
+"axios@npm:^1.6.2":
+ version: 1.6.8
+ resolution: "axios@npm:1.6.8"
+ dependencies:
+ follow-redirects: ^1.15.6
+ form-data: ^4.0.0
+ proxy-from-env: ^1.1.0
+ checksum: bf007fa4b207d102459300698620b3b0873503c6d47bf5a8f6e43c0c64c90035a4f698b55027ca1958f61ab43723df2781c38a99711848d232cad7accbcdfcdd
+ languageName: node
+ linkType: hard
+
"axios@npm:^1.6.5, axios@npm:^1.6.7":
version: 1.6.7
resolution: "axios@npm:1.6.7"
@@ -22339,6 +22364,16 @@ __metadata:
languageName: node
linkType: hard
+"follow-redirects@npm:^1.15.6":
+ version: 1.15.6
+ resolution: "follow-redirects@npm:1.15.6"
+ peerDependenciesMeta:
+ debug:
+ optional: true
+ checksum: a62c378dfc8c00f60b9c80cab158ba54e99ba0239a5dd7c81245e5a5b39d10f0c35e249c3379eae719ff0285fff88c365dd446fab19dee771f1d76252df1bbf5
+ languageName: node
+ linkType: hard
+
"for-each@npm:^0.3.3":
version: 0.3.3
resolution: "for-each@npm:0.3.3"
@@ -26338,7 +26373,7 @@ __metadata:
languageName: node
linkType: hard
-"langchain@workspace:*, langchain@workspace:langchain":
+"langchain@workspace:*, langchain@workspace:^, langchain@workspace:langchain":
version: 0.0.0-use.local
resolution: "langchain@workspace:langchain"
dependencies: