From 9a0675c07d41bdda8f364e8dba6f226cff67dc58 Mon Sep 17 00:00:00 2001 From: Brace Sproul Date: Tue, 11 Jun 2024 16:18:59 -0700 Subject: [PATCH] anthropic[patch]: Stream tokens (#5730) * anthropic[patch]: Stream tokens * chore: lint files * added normal test * add streamUsage field to control streaming token counts --- libs/langchain-anthropic/src/chat_models.ts | 43 ++++++++++++++++++- .../src/tests/chat_models.int.test.ts | 29 ++++++++++++- .../tests/chat_models.standard.int.test.ts | 6 --- libs/langchain-cloudflare/src/chat_models.ts | 2 +- 4 files changed, 70 insertions(+), 10 deletions(-) diff --git a/libs/langchain-anthropic/src/chat_models.ts b/libs/langchain-anthropic/src/chat_models.ts index aba95cee3a00..036b21515671 100644 --- a/libs/langchain-anthropic/src/chat_models.ts +++ b/libs/langchain-anthropic/src/chat_models.ts @@ -11,6 +11,7 @@ import { ToolMessage, isAIMessage, MessageContent, + UsageMetadata, } from "@langchain/core/messages"; import { ChatGeneration, @@ -67,7 +68,9 @@ type AnthropicToolChoice = } | "any" | "auto"; -export interface ChatAnthropicCallOptions extends BaseLanguageModelCallOptions { +export interface ChatAnthropicCallOptions + extends BaseLanguageModelCallOptions, + Pick { tools?: (StructuredToolInterface | AnthropicTool)[]; /** * Whether or not to specify what tool the model should use @@ -211,6 +214,12 @@ export interface AnthropicInput { * `anthropic.messages`} that are not explicitly specified on this class. */ invocationKwargs?: Kwargs; + + /** + * Whether or not to include token usage data in streamed chunks. + * @default true + */ + streamUsage?: boolean; } /** @@ -485,6 +494,8 @@ export class ChatAnthropicMessages< // Used for streaming requests protected streamingClient: Anthropic; + streamUsage = true; + constructor(fields?: Partial & BaseChatModelParams) { super(fields ?? {}); @@ -516,12 +527,13 @@ export class ChatAnthropicMessages< this.streaming = fields?.streaming ?? false; this.clientOptions = fields?.clientOptions ?? {}; + this.streamUsage = fields?.streamUsage ?? this.streamUsage; } getLsParams(options: this["ParsedCallOptions"]): LangSmithParams { const params = this.invocationParams(options); return { - ls_provider: "openai", + ls_provider: "anthropic", ls_model_name: this.model, ls_model_type: "chat", ls_temperature: params.temperature ?? undefined, @@ -691,18 +703,36 @@ export class ChatAnthropicMessages< } } usageData = usage; + let usageMetadata: UsageMetadata | undefined; + if (this.streamUsage || options.streamUsage) { + usageMetadata = { + input_tokens: usage.input_tokens, + output_tokens: usage.output_tokens, + total_tokens: usage.input_tokens + usage.output_tokens, + }; + } yield new ChatGenerationChunk({ message: new AIMessageChunk({ content: "", additional_kwargs: filteredAdditionalKwargs, + usage_metadata: usageMetadata, }), text: "", }); } else if (data.type === "message_delta") { + let usageMetadata: UsageMetadata | undefined; + if (this.streamUsage || options.streamUsage) { + usageMetadata = { + input_tokens: data.usage.output_tokens, + output_tokens: 0, + total_tokens: data.usage.output_tokens, + }; + } yield new ChatGenerationChunk({ message: new AIMessageChunk({ content: "", additional_kwargs: { ...data.delta }, + usage_metadata: usageMetadata, }), text: "", }); @@ -723,10 +753,19 @@ export class ChatAnthropicMessages< } } } + let usageMetadata: UsageMetadata | undefined; + if (this.streamUsage || options.streamUsage) { + usageMetadata = { + input_tokens: usageData.input_tokens, + output_tokens: usageData.output_tokens, + total_tokens: usageData.input_tokens + usageData.output_tokens, + }; + } yield new ChatGenerationChunk({ message: new AIMessageChunk({ content: "", additional_kwargs: { usage: usageData }, + usage_metadata: usageMetadata, }), text: "", }); diff --git a/libs/langchain-anthropic/src/tests/chat_models.int.test.ts b/libs/langchain-anthropic/src/tests/chat_models.int.test.ts index b56aa2149076..9cc706ec6502 100644 --- a/libs/langchain-anthropic/src/tests/chat_models.int.test.ts +++ b/libs/langchain-anthropic/src/tests/chat_models.int.test.ts @@ -1,7 +1,7 @@ /* eslint-disable no-process-env */ import { expect, test } from "@jest/globals"; -import { HumanMessage } from "@langchain/core/messages"; +import { AIMessageChunk, HumanMessage } from "@langchain/core/messages"; import { ChatPromptValue } from "@langchain/core/prompt_values"; import { PromptTemplate, @@ -318,3 +318,30 @@ test("Test ChatAnthropic multimodal", async () => { ]); console.log(res); }); + +test("Stream tokens", async () => { + const model = new ChatAnthropic({ + model: "claude-3-haiku-20240307", + temperature: 0, + }); + let res: AIMessageChunk | null = null; + for await (const chunk of await model.stream( + "Why is the sky blue? Be concise." + )) { + if (!res) { + res = chunk; + } else { + res = res.concat(chunk); + } + } + console.log(res); + expect(res?.usage_metadata).toBeDefined(); + if (!res?.usage_metadata) { + return; + } + expect(res.usage_metadata.input_tokens).toBe(34); + expect(res.usage_metadata.output_tokens).toBeGreaterThan(10); + expect(res.usage_metadata.total_tokens).toBe( + res.usage_metadata.input_tokens + res.usage_metadata.output_tokens + ); +}); diff --git a/libs/langchain-anthropic/src/tests/chat_models.standard.int.test.ts b/libs/langchain-anthropic/src/tests/chat_models.standard.int.test.ts index 45eff821ce9c..1980680bf019 100644 --- a/libs/langchain-anthropic/src/tests/chat_models.standard.int.test.ts +++ b/libs/langchain-anthropic/src/tests/chat_models.standard.int.test.ts @@ -23,12 +23,6 @@ class ChatAnthropicStandardIntegrationTests extends ChatModelIntegrationTests< }, }); } - - async testUsageMetadataStreaming() { - console.warn( - "Skipping testUsageMetadataStreaming, not implemented in ChatAnthropic." - ); - } } const testClass = new ChatAnthropicStandardIntegrationTests(); diff --git a/libs/langchain-cloudflare/src/chat_models.ts b/libs/langchain-cloudflare/src/chat_models.ts index bdae5020d6ba..cc7f7f174f3e 100644 --- a/libs/langchain-cloudflare/src/chat_models.ts +++ b/libs/langchain-cloudflare/src/chat_models.ts @@ -84,7 +84,7 @@ export class ChatCloudflareWorkersAI getLsParams(options: this["ParsedCallOptions"]): LangSmithParams { return { - ls_provider: "openai", + ls_provider: "cloudflare", ls_model_name: this.model, ls_model_type: "chat", ls_stop: options.stop,