diff --git a/app/api/anthropic/[...path]/route.ts b/app/api/anthropic/[...path]/route.ts index 4264893d93e..78106efa76c 100644 --- a/app/api/anthropic/[...path]/route.ts +++ b/app/api/anthropic/[...path]/route.ts @@ -4,12 +4,13 @@ import { Anthropic, ApiPath, DEFAULT_MODELS, + ServiceProvider, ModelProvider, } from "@/app/constant"; import { prettyObject } from "@/app/utils/format"; import { NextRequest, NextResponse } from "next/server"; import { auth } from "../../auth"; -import { collectModelTable } from "@/app/utils/model"; +import { isModelAvailableInServer } from "@/app/utils/model"; const ALLOWD_PATH = new Set([Anthropic.ChatPath, Anthropic.ChatPath1]); @@ -136,17 +137,19 @@ async function request(req: NextRequest) { // #1815 try to refuse some request to some models if (serverConfig.customModels && req.body) { try { - const modelTable = collectModelTable( - DEFAULT_MODELS, - serverConfig.customModels, - ); const clonedBody = await req.text(); fetchOptions.body = clonedBody; const jsonBody = JSON.parse(clonedBody) as { model?: string }; // not undefined and is false - if (modelTable[jsonBody?.model ?? ""].available === false) { + if ( + isModelAvailableInServer( + serverConfig.customModels, + jsonBody?.model as string, + ServiceProvider.Anthropic as string, + ) + ) { return NextResponse.json( { error: true, diff --git a/app/api/common.ts b/app/api/common.ts index a75f2de5cfa..1454fde2ed1 100644 --- a/app/api/common.ts +++ b/app/api/common.ts @@ -1,7 +1,12 @@ import { NextRequest, NextResponse } from "next/server"; import { getServerSideConfig } from "../config/server"; -import { DEFAULT_MODELS, OPENAI_BASE_URL, GEMINI_BASE_URL } from "../constant"; -import { collectModelTable } from "../utils/model"; +import { + DEFAULT_MODELS, + OPENAI_BASE_URL, + GEMINI_BASE_URL, + ServiceProvider, +} from "../constant"; +import { isModelAvailableInServer } from "../utils/model"; import { makeAzurePath } from "../azure"; const serverConfig = getServerSideConfig(); @@ -83,17 +88,24 @@ export async function requestOpenai(req: NextRequest) { // #1815 try to refuse gpt4 request if (serverConfig.customModels && req.body) { try { - const modelTable = collectModelTable( - DEFAULT_MODELS, - serverConfig.customModels, - ); const clonedBody = await req.text(); fetchOptions.body = clonedBody; const jsonBody = JSON.parse(clonedBody) as { model?: string }; // not undefined and is false - if (modelTable[jsonBody?.model ?? ""].available === false) { + if ( + isModelAvailableInServer( + serverConfig.customModels, + jsonBody?.model as string, + ServiceProvider.OpenAI as string, + ) || + isModelAvailableInServer( + serverConfig.customModels, + jsonBody?.model as string, + ServiceProvider.Azure as string, + ) + ) { return NextResponse.json( { error: true, @@ -112,16 +124,16 @@ export async function requestOpenai(req: NextRequest) { try { const res = await fetch(fetchUrl, fetchOptions); - // Extract the OpenAI-Organization header from the response - const openaiOrganizationHeader = res.headers.get("OpenAI-Organization"); + // Extract the OpenAI-Organization header from the response + const openaiOrganizationHeader = res.headers.get("OpenAI-Organization"); - // Check if serverConfig.openaiOrgId is defined and not an empty string - if (serverConfig.openaiOrgId && serverConfig.openaiOrgId.trim() !== "") { - // If openaiOrganizationHeader is present, log it; otherwise, log that the header is not present - console.log("[Org ID]", openaiOrganizationHeader); - } else { - console.log("[Org ID] is not set up."); - } + // Check if serverConfig.openaiOrgId is defined and not an empty string + if (serverConfig.openaiOrgId && serverConfig.openaiOrgId.trim() !== "") { + // If openaiOrganizationHeader is present, log it; otherwise, log that the header is not present + console.log("[Org ID]", openaiOrganizationHeader); + } else { + console.log("[Org ID] is not set up."); + } // to prevent browser prompt for credentials const newHeaders = new Headers(res.headers); @@ -129,7 +141,6 @@ export async function requestOpenai(req: NextRequest) { // to disable nginx buffering newHeaders.set("X-Accel-Buffering", "no"); - // Conditionally delete the OpenAI-Organization header from the response if [Org ID] is undefined or empty (not setup in ENV) // Also, this is to prevent the header from being sent to the client if (!serverConfig.openaiOrgId || serverConfig.openaiOrgId.trim() === "") { @@ -142,7 +153,6 @@ export async function requestOpenai(req: NextRequest) { // The browser will try to decode the response with brotli and fail newHeaders.delete("content-encoding"); - return new Response(res.body, { status: res.status, statusText: res.statusText, diff --git a/app/store/config.ts b/app/store/config.ts index 94cfcd8ecaa..0e7f43ee6a6 100644 --- a/app/store/config.ts +++ b/app/store/config.ts @@ -116,12 +116,12 @@ export const useAppConfig = createPersistStore( for (const model of oldModels) { model.available = false; - modelMap[model.name] = model; + modelMap[`${model.name}@${model?.provider?.id}`] = model; } for (const model of newModels) { model.available = true; - modelMap[model.name] = model; + modelMap[`${model.name}@${model?.provider?.id}`] = model; } set(() => ({ diff --git a/app/utils/model.ts b/app/utils/model.ts index 056fff2e98d..249987726ad 100644 --- a/app/utils/model.ts +++ b/app/utils/model.ts @@ -1,8 +1,9 @@ +import { DEFAULT_MODELS } from "../constant"; import { LLMModel } from "../client/api"; const customProvider = (modelName: string) => ({ id: modelName, - providerName: "", + providerName: "Custom", providerType: "custom", }); @@ -23,7 +24,8 @@ export function collectModelTable( // default models models.forEach((m) => { - modelTable[m.name] = { + // using @ as fullName + modelTable[`${m.name}@${m?.provider?.id}`] = { ...m, displayName: m.name, // 'provider' is copied over if it exists }; @@ -45,12 +47,27 @@ export function collectModelTable( (model) => (model.available = available), ); } else { - modelTable[name] = { - name, - displayName: displayName || name, - available, - provider: modelTable[name]?.provider ?? customProvider(name), // Use optional chaining - }; + // 1. find model by name(), and set available value + let count = 0; + for (const fullName in modelTable) { + if (fullName.split("@").shift() == name) { + count += 1; + modelTable[fullName]["available"] = available; + if (displayName) { + modelTable[fullName]["displayName"] = displayName; + } + } + } + // 2. if model not exists, create new model with available value + if (count === 0) { + const provider = customProvider(name); + modelTable[`${name}@${provider?.id}`] = { + name, + displayName: displayName || name, + available, + provider, // Use optional chaining + }; + } } }); @@ -100,3 +117,13 @@ export function collectModelsWithDefaultModel( const allModels = Object.values(modelTable); return allModels; } + +export function isModelAvailableInServer( + customModels: string, + modelName: string, + providerName: string, +) { + const fullName = `${modelName}@${providerName}`; + const modelTable = collectModelTable(DEFAULT_MODELS, customModels); + return modelTable[fullName]?.available === false; +}