Skip to content

Commit

Permalink
fix: openai key usage with chat
Browse files Browse the repository at this point in the history
  • Loading branch information
sinedied committed Mar 27, 2024
1 parent 1133495 commit 64348ba
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 23 deletions.
52 changes: 29 additions & 23 deletions libs/langchain-azure-openai/src/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -266,15 +266,27 @@ export class AzureChatOpenAI
(getEnvironmentVariable("AZURE_OPENAI_API_KEY") ||
getEnvironmentVariable("OPENAI_API_KEY"));

const azureCredential =
fields?.credentials ??
(fields?.azureOpenAIApiKey ||
getEnvironmentVariable("AZURE_OPENAI_API_KEY")
? new AzureKeyCredential(this.azureOpenAIApiKey ?? "")
: new OpenAIKeyCredential(this.azureOpenAIApiKey ?? ""));

const isOpenAIApiKey =
fields?.azureOpenAIApiKey ||
// eslint-disable-next-line no-instanceof/no-instanceof
azureCredential instanceof OpenAIKeyCredential;

if (!this.azureOpenAIApiKey && !fields?.credentials) {
throw new Error("Azure OpenAI API key not found");
}

if (!this.azureOpenAIEndpoint) {
if (!this.azureOpenAIEndpoint && !isOpenAIApiKey) {
throw new Error("Azure OpenAI Endpoint not found");
}

if (!this.azureOpenAIApiDeploymentName) {
if (!this.azureOpenAIApiDeploymentName && !isOpenAIApiKey) {
throw new Error("Azure OpenAI Deployment name not found");
}

Expand All @@ -294,28 +306,25 @@ export class AzureChatOpenAI

this.streaming = fields?.streaming ?? false;

const azureCredential =
fields?.credentials ??
(fields?.azureOpenAIApiKey ||
getEnvironmentVariable("AZURE_OPENAI_API_KEY")
? new AzureKeyCredential(this.azureOpenAIApiKey ?? "")
: new OpenAIKeyCredential(this.azureOpenAIApiKey ?? ""));
const options = {
userAgentOptions: { userAgentPrefix: USER_AGENT_PREFIX },
};

if (isTokenCredential(azureCredential)) {
if (isOpenAIApiKey) {
this.client = new AzureOpenAIClient(
azureCredential as OpenAIKeyCredential
);
} else if (isTokenCredential(azureCredential)) {
this.client = new AzureOpenAIClient(
this.azureOpenAIEndpoint ?? "",
azureCredential as TokenCredential,
{
userAgentOptions: { userAgentPrefix: USER_AGENT_PREFIX },
}
options
);
} else {
this.client = new AzureOpenAIClient(
this.azureOpenAIEndpoint ?? "",
azureCredential as KeyCredential,
{
userAgentOptions: { userAgentPrefix: USER_AGENT_PREFIX },
}
options
);
}
}
Expand All @@ -339,11 +348,11 @@ export class AzureChatOpenAI
options: this["ParsedCallOptions"]
): Promise<EventStream<ChatCompletions>> {
return this.caller.call(async () => {
if (!this.azureOpenAIApiDeploymentName) {
throw new Error("Azure OpenAI Deployment name not found");
}
const deploymentName =
this.azureOpenAIApiDeploymentName || this.modelName;

const res = await this.client.streamChatCompletions(
this.azureOpenAIApiDeploymentName,
deploymentName,
azureOpenAIMessages,
{
functions: options?.functions,
Expand Down Expand Up @@ -434,10 +443,7 @@ export class AzureChatOpenAI
options: this["ParsedCallOptions"],
runManager?: CallbackManagerForLLMRun
): Promise<ChatResult> {
if (!this.azureOpenAIApiDeploymentName) {
throw new Error("Azure OpenAI Deployment name not found");
}
const deploymentName = this.azureOpenAIApiDeploymentName;
const deploymentName = this.azureOpenAIApiDeploymentName || this.modelName;
const tokenUsage: TokenUsage = {};
const azureOpenAIMessages: ChatRequestMessage[] =
this.formatMessages(messages);
Expand Down
18 changes: 18 additions & 0 deletions libs/langchain-azure-openai/src/tests/chat_models.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ import {
import { CallbackManager } from "@langchain/core/callbacks/manager";
import { NewTokenIndices } from "@langchain/core/callbacks/base";
import { InMemoryCache } from "@langchain/core/caches";
import { getEnvironmentVariable } from "@langchain/core/utils/env";
import { OpenAIKeyCredential } from "@azure/openai";
import { AzureChatOpenAI } from "../chat_models.js";

test("Test ChatOpenAI", async () => {
Expand Down Expand Up @@ -790,3 +792,19 @@ test("Test ChatOpenAI token usage reporting for streaming calls", async () => {
expect(streamingTokenUsed).toEqual(nonStreamingTokenUsed);
}
});

test("Test ChatOpenAI with OpenAI API key credentials", async () => {
const openAiKey: string = getEnvironmentVariable("OPENAI_API_KEY") ?? "";
const credentials = new OpenAIKeyCredential(openAiKey);

const chat = new AzureChatOpenAI({
modelName: "gpt-3.5-turbo",
maxTokens: 5,
credentials,
azureOpenAIEndpoint: "",
azureOpenAIApiDeploymentName: "",
});
const message = new HumanMessage("Hello!");
const res = await chat.invoke([["system", "Say hi"], message]);
console.log(res);
});

0 comments on commit 64348ba

Please sign in to comment.