diff --git a/libs/langchain-azure-openai/src/chat_models.ts b/libs/langchain-azure-openai/src/chat_models.ts index 6acda2300b23..ffa7c8030adb 100644 --- a/libs/langchain-azure-openai/src/chat_models.ts +++ b/libs/langchain-azure-openai/src/chat_models.ts @@ -266,15 +266,27 @@ export class AzureChatOpenAI (getEnvironmentVariable("AZURE_OPENAI_API_KEY") || getEnvironmentVariable("OPENAI_API_KEY")); + const azureCredential = + fields?.credentials ?? + (fields?.azureOpenAIApiKey || + getEnvironmentVariable("AZURE_OPENAI_API_KEY") + ? new AzureKeyCredential(this.azureOpenAIApiKey ?? "") + : new OpenAIKeyCredential(this.azureOpenAIApiKey ?? "")); + + const isOpenAIApiKey = + fields?.azureOpenAIApiKey || + // eslint-disable-next-line no-instanceof/no-instanceof + azureCredential instanceof OpenAIKeyCredential; + if (!this.azureOpenAIApiKey && !fields?.credentials) { throw new Error("Azure OpenAI API key not found"); } - if (!this.azureOpenAIEndpoint) { + if (!this.azureOpenAIEndpoint && !isOpenAIApiKey) { throw new Error("Azure OpenAI Endpoint not found"); } - if (!this.azureOpenAIApiDeploymentName) { + if (!this.azureOpenAIApiDeploymentName && !isOpenAIApiKey) { throw new Error("Azure OpenAI Deployment name not found"); } @@ -294,28 +306,25 @@ export class AzureChatOpenAI this.streaming = fields?.streaming ?? false; - const azureCredential = - fields?.credentials ?? - (fields?.azureOpenAIApiKey || - getEnvironmentVariable("AZURE_OPENAI_API_KEY") - ? new AzureKeyCredential(this.azureOpenAIApiKey ?? "") - : new OpenAIKeyCredential(this.azureOpenAIApiKey ?? "")); + const options = { + userAgentOptions: { userAgentPrefix: USER_AGENT_PREFIX }, + }; - if (isTokenCredential(azureCredential)) { + if (isOpenAIApiKey) { + this.client = new AzureOpenAIClient( + azureCredential as OpenAIKeyCredential + ); + } else if (isTokenCredential(azureCredential)) { this.client = new AzureOpenAIClient( this.azureOpenAIEndpoint ?? "", azureCredential as TokenCredential, - { - userAgentOptions: { userAgentPrefix: USER_AGENT_PREFIX }, - } + options ); } else { this.client = new AzureOpenAIClient( this.azureOpenAIEndpoint ?? "", azureCredential as KeyCredential, - { - userAgentOptions: { userAgentPrefix: USER_AGENT_PREFIX }, - } + options ); } } @@ -339,11 +348,11 @@ export class AzureChatOpenAI options: this["ParsedCallOptions"] ): Promise> { return this.caller.call(async () => { - if (!this.azureOpenAIApiDeploymentName) { - throw new Error("Azure OpenAI Deployment name not found"); - } + const deploymentName = + this.azureOpenAIApiDeploymentName || this.modelName; + const res = await this.client.streamChatCompletions( - this.azureOpenAIApiDeploymentName, + deploymentName, azureOpenAIMessages, { functions: options?.functions, @@ -434,10 +443,7 @@ export class AzureChatOpenAI options: this["ParsedCallOptions"], runManager?: CallbackManagerForLLMRun ): Promise { - if (!this.azureOpenAIApiDeploymentName) { - throw new Error("Azure OpenAI Deployment name not found"); - } - const deploymentName = this.azureOpenAIApiDeploymentName; + const deploymentName = this.azureOpenAIApiDeploymentName || this.modelName; const tokenUsage: TokenUsage = {}; const azureOpenAIMessages: ChatRequestMessage[] = this.formatMessages(messages); diff --git a/libs/langchain-azure-openai/src/tests/chat_models.int.test.ts b/libs/langchain-azure-openai/src/tests/chat_models.int.test.ts index 2a4d7e01d717..e5c48c1e13fc 100644 --- a/libs/langchain-azure-openai/src/tests/chat_models.int.test.ts +++ b/libs/langchain-azure-openai/src/tests/chat_models.int.test.ts @@ -16,6 +16,8 @@ import { import { CallbackManager } from "@langchain/core/callbacks/manager"; import { NewTokenIndices } from "@langchain/core/callbacks/base"; import { InMemoryCache } from "@langchain/core/caches"; +import { getEnvironmentVariable } from "@langchain/core/utils/env"; +import { OpenAIKeyCredential } from "@azure/openai"; import { AzureChatOpenAI } from "../chat_models.js"; test("Test ChatOpenAI", async () => { @@ -790,3 +792,19 @@ test("Test ChatOpenAI token usage reporting for streaming calls", async () => { expect(streamingTokenUsed).toEqual(nonStreamingTokenUsed); } }); + +test("Test ChatOpenAI with OpenAI API key credentials", async () => { + const openAiKey: string = getEnvironmentVariable("OPENAI_API_KEY") ?? ""; + const credentials = new OpenAIKeyCredential(openAiKey); + + const chat = new AzureChatOpenAI({ + modelName: "gpt-3.5-turbo", + maxTokens: 5, + credentials, + azureOpenAIEndpoint: "", + azureOpenAIApiDeploymentName: "", + }); + const message = new HumanMessage("Hello!"); + const res = await chat.invoke([["system", "Say hi"], message]); + console.log(res); +});