Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add reasoning model #750

Merged
merged 10 commits into from
Feb 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .env.example
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# Get your OpenAI API Key here: https://platform.openai.com/account/api-keys
# Get your OpenAI API Key here for chat models: https://platform.openai.com/account/api-keys
OPENAI_API_KEY=****

# Get your Fireworks AI API Key here for reasoning models: https://fireworks.ai/account/api-keys
FIREWORKS_API_KEY=****

# Generate a random secret: https://generate-secret.vercel.app/32 or `openssl rand -base64 32`
AUTH_SECRET=****

Expand Down
10 changes: 5 additions & 5 deletions app/(chat)/actions.ts
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
'use server';

import { type CoreUserMessage, generateText, Message } from 'ai';
import { generateText, Message } from 'ai';
import { cookies } from 'next/headers';

import { customModel } from '@/lib/ai';
import {
deleteMessagesByChatIdAfterTimestamp,
getMessageById,
updateChatVisiblityById,
} from '@/lib/db/queries';
import { VisibilityType } from '@/components/visibility-selector';
import { myProvider } from '@/lib/ai/models';

export async function saveModelId(model: string) {
export async function saveChatModelAsCookie(model: string) {
const cookieStore = await cookies();
cookieStore.set('model-id', model);
cookieStore.set('chat-model', model);
}

export async function generateTitleFromUserMessage({
Expand All @@ -22,7 +22,7 @@ export async function generateTitleFromUserMessage({
message: Message;
}) {
const { text: title } = await generateText({
model: customModel('gpt-4o-mini'),
model: myProvider.languageModel('title-model'),
system: `\n
- you will generate a short title based on the first message a user begins a conversation with
- ensure it is not more than 80 characters long
Expand Down
56 changes: 27 additions & 29 deletions app/(chat)/api/chat/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ import {
createDataStreamResponse,
smoothStream,
streamText,
wrapLanguageModel,
} from 'ai';

import { auth } from '@/app/(auth)/auth';
import { customModel } from '@/lib/ai';
import { models } from '@/lib/ai/models';
import { myProvider } from '@/lib/ai/models';
import { systemPrompt } from '@/lib/ai/prompts';
import {
deleteChatById,
Expand Down Expand Up @@ -48,8 +48,8 @@ export async function POST(request: Request) {
const {
id,
messages,
modelId,
}: { id: string; messages: Array<Message>; modelId: string } =
selectedChatModel,
}: { id: string; messages: Array<Message>; selectedChatModel: string } =
await request.json();

const session = await auth();
Expand All @@ -58,12 +58,6 @@ export async function POST(request: Request) {
return new Response('Unauthorized', { status: 401 });
}

const model = models.find((model) => model.id === modelId);

if (!model) {
return new Response('Model not found', { status: 404 });
}

const userMessage = getMostRecentUserMessage(messages);

if (!userMessage) {
Expand All @@ -84,7 +78,7 @@ export async function POST(request: Request) {
return createDataStreamResponse({
execute: (dataStream) => {
const result = streamText({
model: customModel(model.apiIdentifier),
model: myProvider.languageModel(selectedChatModel),
system: systemPrompt,
messages,
maxSteps: 5,
Expand All @@ -93,32 +87,31 @@ export async function POST(request: Request) {
experimental_generateMessageId: generateUUID,
tools: {
getWeather,
createDocument: createDocument({ session, dataStream, model }),
updateDocument: updateDocument({ session, dataStream, model }),
createDocument: createDocument({ session, dataStream }),
updateDocument: updateDocument({ session, dataStream }),
requestSuggestions: requestSuggestions({
session,
dataStream,
model,
}),
},
onFinish: async ({ response }) => {
onFinish: async ({ response, reasoning }) => {
if (session.user?.id) {
try {
const responseMessagesWithoutIncompleteToolCalls =
sanitizeResponseMessages(response.messages);
const sanitizedResponseMessages = sanitizeResponseMessages({
messages: response.messages,
reasoning,
});

await saveMessages({
messages: responseMessagesWithoutIncompleteToolCalls.map(
(message) => {
return {
id: message.id,
chatId: id,
role: message.role,
content: message.content,
createdAt: new Date(),
};
},
),
messages: sanitizedResponseMessages.map((message) => {
return {
id: message.id,
chatId: id,
role: message.role,
content: message.content,
createdAt: new Date(),
};
}),
});
} catch (error) {
console.error('Failed to save chat');
Expand All @@ -131,7 +124,12 @@ export async function POST(request: Request) {
},
});

result.mergeIntoDataStream(dataStream);
result.mergeIntoDataStream(dataStream, {
sendReasoning: true,
});
},
onError: (error) => {
return 'Oops, an error occured!';
},
});
}
Expand Down
24 changes: 18 additions & 6 deletions app/(chat)/chat/[id]/page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ import { notFound } from 'next/navigation';

import { auth } from '@/app/(auth)/auth';
import { Chat } from '@/components/chat';
import { DEFAULT_MODEL_NAME, models } from '@/lib/ai/models';
import { getChatById, getMessagesByChatId } from '@/lib/db/queries';
import { convertToUIMessages } from '@/lib/utils';
import { DataStreamHandler } from '@/components/data-stream-handler';
import { DEFAULT_CHAT_MODEL } from '@/lib/ai/models';

export default async function Page(props: { params: Promise<{ id: string }> }) {
const params = await props.params;
Expand Down Expand Up @@ -34,17 +34,29 @@ export default async function Page(props: { params: Promise<{ id: string }> }) {
});

const cookieStore = await cookies();
const modelIdFromCookie = cookieStore.get('model-id')?.value;
const selectedModelId =
models.find((model) => model.id === modelIdFromCookie)?.id ||
DEFAULT_MODEL_NAME;
const chatModelFromCookie = cookieStore.get('chat-model');

if (!chatModelFromCookie) {
return (
<>
<Chat
id={chat.id}
initialMessages={convertToUIMessages(messagesFromDb)}
selectedChatModel={DEFAULT_CHAT_MODEL}
selectedVisibilityType={chat.visibility}
isReadonly={session?.user?.id !== chat.userId}
/>
<DataStreamHandler id={id} />
</>
);
}

return (
<>
<Chat
id={chat.id}
initialMessages={convertToUIMessages(messagesFromDb)}
selectedModelId={selectedModelId}
selectedChatModel={chatModelFromCookie.value}
selectedVisibilityType={chat.visibility}
isReadonly={session?.user?.id !== chat.userId}
/>
Expand Down
24 changes: 18 additions & 6 deletions app/(chat)/page.tsx
Original file line number Diff line number Diff line change
@@ -1,27 +1,39 @@
import { cookies } from 'next/headers';

import { Chat } from '@/components/chat';
import { DEFAULT_MODEL_NAME, models } from '@/lib/ai/models';
import { DEFAULT_CHAT_MODEL } from '@/lib/ai/models';
import { generateUUID } from '@/lib/utils';
import { DataStreamHandler } from '@/components/data-stream-handler';

export default async function Page() {
const id = generateUUID();

const cookieStore = await cookies();
const modelIdFromCookie = cookieStore.get('model-id')?.value;
const modelIdFromCookie = cookieStore.get('chat-model');

const selectedModelId =
models.find((model) => model.id === modelIdFromCookie)?.id ||
DEFAULT_MODEL_NAME;
if (!modelIdFromCookie) {
return (
<>
<Chat
key={id}
id={id}
initialMessages={[]}
selectedChatModel={DEFAULT_CHAT_MODEL}
selectedVisibilityType="private"
isReadonly={false}
/>
<DataStreamHandler id={id} />
</>
);
}

return (
<>
<Chat
key={id}
id={id}
initialMessages={[]}
selectedModelId={selectedModelId}
selectedChatModel={modelIdFromCookie.value}
selectedVisibilityType="private"
isReadonly={false}
/>
Expand Down
13 changes: 9 additions & 4 deletions components/chat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,18 @@ import { MultimodalInput } from './multimodal-input';
import { Messages } from './messages';
import { VisibilityType } from './visibility-selector';
import { useBlockSelector } from '@/hooks/use-block';
import { toast } from 'sonner';

export function Chat({
id,
initialMessages,
selectedModelId,
selectedChatModel,
selectedVisibilityType,
isReadonly,
}: {
id: string;
initialMessages: Array<Message>;
selectedModelId: string;
selectedChatModel: string;
selectedVisibilityType: VisibilityType;
isReadonly: boolean;
}) {
Expand All @@ -42,14 +43,18 @@ export function Chat({
reload,
} = useChat({
id,
body: { id, modelId: selectedModelId },
body: { id, selectedChatModel: selectedChatModel },
initialMessages,
experimental_throttle: 100,
sendExtraMessageFields: true,
generateId: generateUUID,
onFinish: () => {
mutate('/api/history');
},
onError: (error) => {
console.log(error);
toast.error('An error occured, please try again!');
},
});

const { data: votes } = useSWR<Array<Vote>>(
Expand All @@ -65,7 +70,7 @@ export function Chat({
<div className="flex flex-col min-w-0 h-dvh bg-background">
<ChatHeader
chatId={id}
selectedModelId={selectedModelId}
selectedModelId={selectedChatModel}
selectedVisibilityType={selectedVisibilityType}
isReadonly={isReadonly}
/>
Expand Down
2 changes: 1 addition & 1 deletion components/markdown.tsx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import Link from 'next/link';
import React, { memo, useMemo, useState } from 'react';
import React, { memo } from 'react';
import ReactMarkdown, { type Components } from 'react-markdown';
import remarkGfm from 'remark-gfm';
import { CodeBlock } from './code-block';
Expand Down
75 changes: 75 additions & 0 deletions components/message-reasoning.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
'use client';

import { useState } from 'react';
import { ChevronDownIcon, LoaderIcon } from './icons';
import { motion, AnimatePresence } from 'framer-motion';
import { Markdown } from './markdown';

interface MessageReasoningProps {
isLoading: boolean;
reasoning: string;
}

export function MessageReasoning({
isLoading,
reasoning,
}: MessageReasoningProps) {
const [isExpanded, setIsExpanded] = useState(true);

const variants = {
collapsed: {
height: 0,
opacity: 0,
marginTop: 0,
marginBottom: 0,
},
expanded: {
height: 'auto',
opacity: 1,
marginTop: '1rem',
marginBottom: '0.5rem',
},
};

return (
<div className="flex flex-col">
{isLoading ? (
<div className="flex flex-row gap-2 items-center">
<div className="font-medium">Reasoning</div>
<div className="animate-spin">
<LoaderIcon />
</div>
</div>
) : (
<div className="flex flex-row gap-2 items-center">
<div className="font-medium">Reasoned for a few seconds</div>
<div
className="cursor-pointer"
onClick={() => {
setIsExpanded(!isExpanded);
}}
>
<ChevronDownIcon />
</div>
</div>
)}

<AnimatePresence initial={false}>
{isExpanded && (
<motion.div
key="content"
initial="collapsed"
animate="expanded"
exit="collapsed"
variants={variants}
transition={{ duration: 0.2, ease: 'easeInOut' }}
style={{ overflow: 'hidden' }}
className="pl-4 text-zinc-600 dark:text-zinc-400 border-l flex flex-col gap-4"
>
<Markdown>{reasoning}</Markdown>
</motion.div>
)}
</AnimatePresence>
</div>
);
}
Loading