diff --git a/front/lib/api/assistant/generation.ts b/front/lib/api/assistant/generation.ts
index 59f4450046ce..a79954f8dc9f 100644
--- a/front/lib/api/assistant/generation.ts
+++ b/front/lib/api/assistant/generation.ts
@@ -9,7 +9,11 @@ import {
   renderRetrievalActionForModel,
   retrievalMetaPrompt,
 } from "@app/lib/api/assistant/actions/retrieval";
-import { getSupportedModelConfig } from "@app/lib/assistant";
+import {
+  getSupportedModelConfig,
+  GPT_4_32K_MODEL_ID,
+  GPT_4_MODEL_CONFIG,
+} from "@app/lib/assistant";
 import { Authenticator } from "@app/lib/auth";
 import { CoreAPI } from "@app/lib/core_api";
 import { redisClient } from "@app/lib/redis";
@@ -328,13 +332,15 @@ export async function* runGeneration(
     return;
   }
 
-  const contextSize = getSupportedModelConfig(c.model).contextSize;
+  let model = c.model;
+
+  const contextSize = getSupportedModelConfig(model).contextSize;
 
   const MIN_GENERATION_TOKENS = 2048;
 
   if (contextSize < MIN_GENERATION_TOKENS) {
     throw new Error(
-      `Model contextSize unexpectedly small for model: ${c.model.providerId} ${c.model.modelId}`
+      `Model contextSize unexpectedly small for model: ${model.providerId} ${model.modelId}`
     );
   }
 
@@ -343,7 +349,7 @@
   // Turn the conversation into a digest that can be presented to the model.
   const modelConversationRes = await renderConversationForModel({
     conversation,
-    model: c.model,
+    model,
     prompt,
     allowedTokenCount: contextSize - MIN_GENERATION_TOKENS,
   });
@@ -356,17 +362,30 @@
       messageId: agentMessage.sId,
       error: {
         code: "internal_server_error",
-        message: `Failed tokenization for ${c.model.providerId} ${c.model.modelId}: ${modelConversationRes.error.message}`,
+        message: `Failed tokenization for ${model.providerId} ${model.modelId}: ${modelConversationRes.error.message}`,
       },
     };
     return;
   }
 
+  // If model is gpt4-32k but tokens used is less than GPT_4_CONTEXT_SIZE-MIN_GENERATION_TOKENS,
+  // then we override the model to gpt4 standard (8k context, cheaper).
+  if (
+    model.modelId === GPT_4_32K_MODEL_ID &&
+    modelConversationRes.value.tokensUsed <
+      GPT_4_MODEL_CONFIG.contextSize - MIN_GENERATION_TOKENS
+  ) {
+    model = {
+      modelId: GPT_4_MODEL_CONFIG.modelId,
+      providerId: GPT_4_MODEL_CONFIG.providerId,
+    };
+  }
+
   const config = cloneBaseConfig(
     DustProdActionRegistry["assistant-v2-generator"].config
   );
-  config.MODEL.provider_id = c.model.providerId;
-  config.MODEL.model_id = c.model.modelId;
+  config.MODEL.provider_id = model.providerId;
+  config.MODEL.model_id = model.modelId;
   config.MODEL.temperature = c.temperature;
 
   // This is the console.log you want to uncomment to generate inputs for the generator app.
@@ -381,7 +400,7 @@
     {
       workspaceId: conversation.owner.sId,
       conversationId: conversation.sId,
-      model: c.model,
+      model: model,
       temperature: c.temperature,
     },
     "[ASSISTANT_TRACE] Generation exection"
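The core of this patch is the downgrade heuristic: if the conversation was rendered for gpt-4-32k but fits comfortably inside the standard gpt-4 window, the cheaper model is used instead. Below is a minimal standalone sketch of that logic, in TypeScript. The config values are illustrative assumptions (the real `GPT_4_MODEL_CONFIG` and `GPT_4_32K_MODEL_ID` come from `@app/lib/assistant`), and `pickModel` is a hypothetical helper written for this example, not part of the patch:

```ts
// Minimal sketch of the downgrade heuristic; constants are assumed values,
// not the actual ones exported by "@app/lib/assistant".

type ModelId = { providerId: string; modelId: string };

const GPT_4_32K_MODEL_ID = "gpt-4-32k"; // assumed value
const GPT_4_MODEL_CONFIG = {
  providerId: "openai", // assumed value
  modelId: "gpt-4", // assumed value
  contextSize: 8192, // assumed value
};
const MIN_GENERATION_TOKENS = 2048;

// Hypothetical helper: keep the requested model unless the rendered
// conversation plus a MIN_GENERATION_TOKENS completion budget fits in
// the standard gpt-4 window, in which case downgrade to it.
function pickModel(model: ModelId, tokensUsed: number): ModelId {
  if (
    model.modelId === GPT_4_32K_MODEL_ID &&
    tokensUsed < GPT_4_MODEL_CONFIG.contextSize - MIN_GENERATION_TOKENS
  ) {
    return {
      providerId: GPT_4_MODEL_CONFIG.providerId,
      modelId: GPT_4_MODEL_CONFIG.modelId,
    };
  }
  return model;
}

// A 3,000-token conversation fits in 8192 - 2048 = 6144 tokens, so it
// downgrades to standard gpt-4; a 7,000-token one stays on gpt-4-32k.
console.log(pickModel({ providerId: "openai", modelId: GPT_4_32K_MODEL_ID }, 3000));
console.log(pickModel({ providerId: "openai", modelId: GPT_4_32K_MODEL_ID }, 7000));
```

Note that the comparison is strict, so a conversation using exactly `contextSize - MIN_GENERATION_TOKENS` tokens stays on the 32k model, preserving the full generation budget.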