Skip to content

Commit

Permalink
Preserve the routing from 32k to 4
Browse files Browse the repository at this point in the history
  • Loading branch information
spolu committed Nov 13, 2023
1 parent 9c530bc commit 015c602
Showing 1 changed file with 27 additions and 8 deletions.
35 changes: 27 additions & 8 deletions front/lib/api/assistant/generation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@ import {
renderRetrievalActionForModel,
retrievalMetaPrompt,
} from "@app/lib/api/assistant/actions/retrieval";
import { getSupportedModelConfig } from "@app/lib/assistant";
import {
getSupportedModelConfig,
GPT_4_32K_MODEL_ID,
GPT_4_MODEL_CONFIG,
} from "@app/lib/assistant";
import { Authenticator } from "@app/lib/auth";
import { CoreAPI } from "@app/lib/core_api";
import { redisClient } from "@app/lib/redis";
Expand Down Expand Up @@ -328,13 +332,15 @@ export async function* runGeneration(
return;
}

const contextSize = getSupportedModelConfig(c.model).contextSize;
let model = c.model;

const contextSize = getSupportedModelConfig(model).contextSize;

const MIN_GENERATION_TOKENS = 2048;

if (contextSize < MIN_GENERATION_TOKENS) {
throw new Error(
`Model contextSize unexpectedly small for model: ${c.model.providerId} ${c.model.modelId}`
`Model contextSize unexpectedly small for model: ${model.providerId} ${model.modelId}`
);
}

Expand All @@ -343,7 +349,7 @@ export async function* runGeneration(
// Turn the conversation into a digest that can be presented to the model.
const modelConversationRes = await renderConversationForModel({
conversation,
model: c.model,
model,
prompt,
allowedTokenCount: contextSize - MIN_GENERATION_TOKENS,
});
Expand All @@ -356,17 +362,30 @@ export async function* runGeneration(
messageId: agentMessage.sId,
error: {
code: "internal_server_error",
message: `Failed tokenization for ${c.model.providerId} ${c.model.modelId}: ${modelConversationRes.error.message}`,
message: `Failed tokenization for ${model.providerId} ${model.modelId}: ${modelConversationRes.error.message}`,
},
};
return;
}

// If model is gpt4-32k but tokens used is less than GPT_4_CONTEXT_SIZE-MIN_GENERATION_TOKENS,
// then we override the model to gpt4 standard (8k context, cheaper).
if (
model.modelId === GPT_4_32K_MODEL_ID &&
modelConversationRes.value.tokensUsed <
GPT_4_MODEL_CONFIG.contextSize - MIN_GENERATION_TOKENS
) {
model = {
modelId: GPT_4_MODEL_CONFIG.modelId,
providerId: GPT_4_MODEL_CONFIG.providerId,
};
}

const config = cloneBaseConfig(
DustProdActionRegistry["assistant-v2-generator"].config
);
config.MODEL.provider_id = c.model.providerId;
config.MODEL.model_id = c.model.modelId;
config.MODEL.provider_id = model.providerId;
config.MODEL.model_id = model.modelId;
config.MODEL.temperature = c.temperature;

// This is the console.log you want to uncomment to generate inputs for the generator app.
Expand All @@ -381,7 +400,7 @@ export async function* runGeneration(
{
workspaceId: conversation.owner.sId,
conversationId: conversation.sId,
model: c.model,
model: model,
temperature: c.temperature,
},
"[ASSISTANT_TRACE] Generation exection"
Expand Down

0 comments on commit 015c602

Please sign in to comment.