Skip to content

Commit

Permalink
feat: support o1 and o1-high-reasoning for custom assistants
Browse files Browse the repository at this point in the history
  • Loading branch information
Henry Fontanier committed Dec 23, 2024
1 parent 20216a0 commit f26aa54
Show file tree
Hide file tree
Showing 13 changed files with 158 additions and 27 deletions.
12 changes: 8 additions & 4 deletions core/src/providers/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1835,7 +1835,8 @@ impl LLM for OpenAILLM {
Some(self.id.clone()),
prompt,
max_tokens,
temperature,
// [o1] O1 models do not support custom temperature.
if !model_is_o1 { temperature } else { 1.0 },
n,
match top_logprobs {
Some(l) => Some(l),
Expand Down Expand Up @@ -1879,7 +1880,8 @@ impl LLM for OpenAILLM {
Some(self.id.clone()),
prompt,
max_tokens,
temperature,
// [o1] O1 models do not support custom temperature.
if !model_is_o1 { temperature } else { 1.0 },
n,
match top_logprobs {
Some(l) => Some(l),
Expand Down Expand Up @@ -2060,7 +2062,8 @@ impl LLM for OpenAILLM {
&openai_messages,
tools,
tool_choice,
temperature,
// [o1] O1 models do not support custom temperature.
if !model_is_o1 { temperature } else { 1.0 },
match top_p {
Some(t) => t,
None => 1.0,
Expand Down Expand Up @@ -2091,7 +2094,8 @@ impl LLM for OpenAILLM {
&openai_messages,
tools,
tool_choice,
temperature,
// [o1] O1 models do not support custom temperature.
if !model_is_o1 { temperature } else { 1.0 },
match top_p {
Some(t) => t,
None => 1.0,
Expand Down
1 change: 0 additions & 1 deletion front/components/assistant_builder/AssistantBuilder.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,6 @@ export default function AssistantBuilder({
return (
<InstructionScreen
owner={owner}
plan={plan}
builderState={builderState}
setBuilderState={setBuilderState}
setEdited={setEdited}
Expand Down
29 changes: 10 additions & 19 deletions front/components/assistant_builder/InstructionScreen.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import type {
LightAgentConfigurationType,
ModelConfigurationType,
ModelIdType,
PlanType,
Result,
SupportedModel,
WorkspaceType,
Expand All @@ -29,7 +28,6 @@ import {
GPT_4O_MODEL_ID,
MISTRAL_LARGE_MODEL_ID,
} from "@dust-tt/types";
import { isProviderWhitelisted } from "@dust-tt/types";
import {
ASSISTANT_CREATIVITY_LEVEL_DISPLAY_NAMES,
ASSISTANT_CREATIVITY_LEVEL_TEMPERATURES,
Expand All @@ -53,18 +51,15 @@ import React, {
} from "react";

import type { AssistantBuilderState } from "@app/components/assistant_builder/types";
import {
MODEL_PROVIDER_LOGOS,
USED_MODEL_CONFIGS,
} from "@app/components/providers/types";
import { MODEL_PROVIDER_LOGOS } from "@app/components/providers/types";
import { ParagraphExtension } from "@app/components/text_editor/extensions";
import { getSupportedModelConfig } from "@app/lib/assistant";
import {
plainTextFromTipTapContent,
tipTapContentFromPlainText,
} from "@app/lib/client/assistant_builder/instructions";
import { isUpgraded } from "@app/lib/plans/plan_codes";
import { useAgentConfigurationHistory } from "@app/lib/swr/assistants";
import { useModels } from "@app/lib/swr/models";
import { classNames } from "@app/lib/utils";
import { debounce } from "@app/lib/utils/debounce";

Expand Down Expand Up @@ -111,7 +106,6 @@ const useInstructionEditorService = (editor: Editor | null) => {

export function InstructionScreen({
owner,
plan,
builderState,
setBuilderState,
setEdited,
Expand All @@ -123,7 +117,6 @@ export function InstructionScreen({
agentConfigurationId,
}: {
owner: WorkspaceType;
plan: PlanType;
builderState: AssistantBuilderState;
setBuilderState: (
statefn: (state: AssistantBuilderState) => AssistantBuilderState
Expand Down Expand Up @@ -325,7 +318,6 @@ export function InstructionScreen({
<div className="mt-2 self-end">
<AdvancedSettings
owner={owner}
plan={plan}
generationSettings={builderState.generationSettings}
setGenerationSettings={(generationSettings) => {
setEdited(true);
Expand Down Expand Up @@ -400,6 +392,7 @@ function ModelList({ modelConfigs, onClick }: ModelListProps) {
onClick({
modelId: modelConfig.modelId,
providerId: modelConfig.providerId,
reasoningEffort: modelConfig.reasoningEffort,
});
};

Expand All @@ -420,17 +413,21 @@ function ModelList({ modelConfigs, onClick }: ModelListProps) {

export function AdvancedSettings({
owner,
plan,
generationSettings,
setGenerationSettings,
}: {
owner: WorkspaceType;
plan: PlanType;
generationSettings: AssistantBuilderState["generationSettings"];
setGenerationSettings: (
generationSettingsSettings: AssistantBuilderState["generationSettings"]
) => void;
}) {
const { models, isModelsLoading } = useModels({ owner });

if (isModelsLoading) {
return null;
}

const supportedModelConfig = getSupportedModelConfig(
generationSettings.modelSettings
);
Expand All @@ -441,13 +438,7 @@ export function AdvancedSettings({

const bestPerformingModelConfigs: ModelConfigurationType[] = [];
const otherModelConfigs: ModelConfigurationType[] = [];
for (const modelConfig of USED_MODEL_CONFIGS) {
if (
!isProviderWhitelisted(owner, modelConfig.providerId) ||
(modelConfig.largeModel && !isUpgraded(plan))
) {
continue;
}
for (const modelConfig of models) {
if (isBestPerformingModel(modelConfig.modelId)) {
bestPerformingModelConfigs.push(modelConfig);
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,8 @@ export async function submitAssistantBuilderForm({
modelId: builderState.generationSettings.modelSettings.modelId,
providerId: builderState.generationSettings.modelSettings.providerId,
temperature: builderState.generationSettings.temperature,
reasoningEffort:
builderState.generationSettings.modelSettings.reasoningEffort,
},
maxStepsPerRun,
visualizationEnabled: builderState.visualizationEnabled,
Expand Down
6 changes: 6 additions & 0 deletions front/components/providers/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ import {
MISTRAL_CODESTRAL_MODEL_CONFIG,
MISTRAL_LARGE_MODEL_CONFIG,
MISTRAL_SMALL_MODEL_CONFIG,
O1_HIGH_REASONING_MODEL_CONFIG,
O1_MINI_MODEL_CONFIG,
O1_MODEL_CONFIG,
TOGETHERAI_LLAMA_3_3_70B_INSTRUCT_TURBO_MODEL_CONFIG,
TOGETHERAI_QWEN_2_5_CODER_32B_INSTRUCT_MODEL_CONFIG,
TOGETHERAI_QWEN_32B_PREVIEW_MODEL_CONFIG,
Expand All @@ -38,6 +41,9 @@ export const USED_MODEL_CONFIGS: readonly ModelConfig[] = [
GPT_4O_MODEL_CONFIG,
GPT_4O_MINI_MODEL_CONFIG,
GPT_4_TURBO_MODEL_CONFIG,
O1_MODEL_CONFIG,
O1_MINI_MODEL_CONFIG,
O1_HIGH_REASONING_MODEL_CONFIG,
CLAUDE_3_5_SONNET_DEFAULT_MODEL_CONFIG,
CLAUDE_3_5_HAIKU_DEFAULT_MODEL_CONFIG,
MISTRAL_LARGE_MODEL_CONFIG,
Expand Down
1 change: 1 addition & 0 deletions front/lib/api/assistant/configuration.ts
Original file line number Diff line number Diff line change
Expand Up @@ -797,6 +797,7 @@ export async function createAgentConfiguration(
providerId: model.providerId,
modelId: model.modelId,
temperature: model.temperature,
reasoningEffort: model.reasoningEffort,
maxStepsPerRun,
visualizationEnabled,
pictureUrl,
Expand Down
21 changes: 20 additions & 1 deletion front/lib/api/assistant/conversation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ import {
makeMessageRateLimitKeyForWorkspace,
} from "@app/lib/api/assistant/rate_limits";
import { maybeUpsertFileAttachment } from "@app/lib/api/files/utils";
import { Authenticator } from "@app/lib/auth";
import { getSupportedModelConfig } from "@app/lib/assistant";
import { Authenticator, getFeatureFlags } from "@app/lib/auth";
import { AgentMessageContent } from "@app/lib/models/assistant/agent_message_content";
import {
AgentMessage,
Expand Down Expand Up @@ -711,6 +712,8 @@ export async function* postUserMessage(
return;
}

const featureFlags = await getFeatureFlags(owner);

if (!canAccessConversation(auth, conversation)) {
yield {
type: "user_message_error",
Expand Down Expand Up @@ -771,6 +774,22 @@ export async function* postUserMessage(
};
return; // Stop processing if any agent uses a disabled provider
}
const supportedModelConfig = getSupportedModelConfig(agentConfig.model);
if (
supportedModelConfig.featureFlag &&
!featureFlags.includes(supportedModelConfig.featureFlag)
) {
yield {
type: "agent_disabled_error",
created: Date.now(),
configurationId: agentConfig.sId,
error: {
code: "model_not_supported",
message: "The model is not supported.",
},
};
return;
}
}

// In one big transaction creante all Message, UserMessage, AgentMessage and Mention rows.
Expand Down
3 changes: 2 additions & 1 deletion front/lib/assistant.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ export function getSupportedModelConfig(
return SUPPORTED_MODEL_CONFIGS.find(
(m) =>
m.modelId === supportedModel.modelId &&
m.providerId === supportedModel.providerId
m.providerId === supportedModel.providerId &&
m.reasoningEffort === supportedModel.reasoningEffort
) as (typeof SUPPORTED_MODEL_CONFIGS)[number];
}
20 changes: 20 additions & 0 deletions front/lib/swr/models.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import type { LightWorkspaceType } from "@dust-tt/types";
import type { Fetcher } from "swr";

import { fetcher, useSWRWithDefaults } from "@app/lib/swr/swr";
import type { GetAvailableModelsResponseType } from "@app/pages/api/w/[wId]/models";

export function useModels({ owner }: { owner: LightWorkspaceType }) {
const modelsFetcher: Fetcher<GetAvailableModelsResponseType> = fetcher;

const { data, error } = useSWRWithDefaults(
`/api/w/${owner.sId}/models`,
modelsFetcher
);

return {
models: data ? data.models : [],
isModelsLoading: !error && !data,
isModelsError: !!error,
};
}
67 changes: 67 additions & 0 deletions front/pages/api/w/[wId]/models.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import type {
ModelConfigurationType,
WithAPIErrorResponse,
} from "@dust-tt/types";
import { isProviderWhitelisted } from "@dust-tt/types";
import type { NextApiRequest, NextApiResponse } from "next";

import { USED_MODEL_CONFIGS } from "@app/components/providers/types";
import { withSessionAuthenticationForWorkspace } from "@app/lib/api/auth_wrappers";
import type { Authenticator } from "@app/lib/auth";
import { getFeatureFlags } from "@app/lib/auth";
import { isUpgraded } from "@app/lib/plans/plan_codes";
import { apiError } from "@app/logger/withlogging";

export type GetAvailableModelsResponseType = {
models: ModelConfigurationType[];
};

async function handler(
req: NextApiRequest,
res: NextApiResponse<WithAPIErrorResponse<GetAvailableModelsResponseType>>,
auth: Authenticator
): Promise<void> {
const owner = auth.getNonNullableWorkspace();
const plan = auth.plan();

switch (req.method) {
case "GET":
const featureFlags = await getFeatureFlags(owner);

const models: ModelConfigurationType[] = [];
for (const m of USED_MODEL_CONFIGS) {
if (
!isProviderWhitelisted(owner, m.providerId) ||
(m.largeModel && !isUpgraded(plan))
) {
continue;
}

if (m.featureFlag && !featureFlags.includes(m.featureFlag)) {
continue;
}

if (
m.customAssistantFeatureFlag &&
!featureFlags.includes(m.customAssistantFeatureFlag)
) {
continue;
}

models.push(m);
}

return res.status(200).json({ models });

default:
return apiError(req, res, {
status_code: 405,
api_error: {
type: "method_not_supported_error",
message: "The method passed is not supported, GET is expected.",
},
});
}
}

export default withSessionAuthenticationForWorkspace(handler);
7 changes: 7 additions & 0 deletions types/src/front/api_handlers/internal/agent_configuration.ts
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,13 @@ const ModelConfigurationSchema = t.intersection([
}),
// TODO(2024-11-04 flav) Clean up this legacy type.
t.partial(multiActionsCommonFields),
t.partial({
reasoningEffort: t.union([
t.literal("low"),
t.literal("medium"),
t.literal("high"),
]),
}),
]);
const IsSupportedModelSchema = new t.Type<SupportedModel>(
"SupportedModel",
Expand Down
Loading

0 comments on commit f26aa54

Please sign in to comment.