Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

1761 - Support for Conversation Specific Model #2007

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions src/renderer/components/Header.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@ import { useAtomValue, useSetAtom } from 'jotai'
import * as sessionActions from '../stores/sessionActions'
import Toolbar from './Toolbar'
import { cn } from '@/lib/utils'
import { getModelDisplayName } from '../packages/models'

interface Props { }

export default function Header(props: Props) {
const theme = useTheme()
const currentSession = useAtomValue(atoms.currentSessionAtom)
const settings = useAtomValue(atoms.settingsAtom)
const setChatConfigDialogSession = useSetAtom(atoms.chatConfigDialogAtom)

useEffect(() => {
Expand Down Expand Up @@ -52,11 +54,17 @@ export default function Header(props: Props) {
editCurrentSession()
}}
>
{
<Typography variant="h6" noWrap className={cn('max-w-56', 'ml-3')}>
{currentSession.name}
</Typography>
}
<Typography variant="h6" noWrap className={cn('max-w-56', 'ml-3')}>
{currentSession.name}
</Typography>
<Typography
variant="body2"
color="text.secondary"
noWrap
className="ml-2 self-center"
>
{getModelDisplayName({...settings, aiProvider: currentSession.aiProvider || settings.aiProvider}, currentSession.type || 'chat')}
</Typography>
</Typography>
<Toolbar />
</div>
Expand Down
26 changes: 25 additions & 1 deletion src/renderer/pages/ChatConfigWindow.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,27 @@ import {
DialogTitle,
DialogContentText,
TextField,
Box,
} from '@mui/material'
import {
Session,
createMessage,
ModelSettings,
} from '../../shared/types'
import { useTranslation } from 'react-i18next'
import * as sessionActions from '../stores/sessionActions'
import * as atoms from '../stores/atoms'
import { useAtom } from 'jotai'
import { trackingEvent } from '@/packages/event'
import AIProviderSelect from '../components/AIProviderSelect'

interface Props {
}

export default function ChatConfigWindow(props: Props) {
const { t } = useTranslation()
const [chatConfigDialogSession, setChatConfigDialogSession] = useAtom(atoms.chatConfigDialogAtom)
const [settings] = useAtom(atoms.settingsAtom)

const [editingData, setEditingData] = React.useState<Session | null>(chatConfigDialogSession)
useEffect(() => {
Expand Down Expand Up @@ -78,9 +82,23 @@ export default function ChatConfigWindow(props: Props) {
setChatConfigDialogSession(null)
}

// Create a merged settings object that uses the session provider if set
const effectiveSettings = React.useMemo(() => ({
...settings,
aiProvider: editingData?.aiProvider || settings.aiProvider,
}), [settings, editingData?.aiProvider])

if (!chatConfigDialogSession || !editingData) {
return null
}

const handleProviderChange = (newSettings: ModelSettings) => {
setEditingData({
...editingData,
aiProvider: newSettings.aiProvider,
})
}

return (
<Dialog open={!!chatConfigDialogSession} onClose={onCancel} fullWidth>
<DialogTitle>{t('Conversation Settings')}</DialogTitle>
Expand All @@ -95,7 +113,13 @@ export default function ChatConfigWindow(props: Props) {
value={editingData.name}
onChange={(e) => setEditingData({ ...editingData, name: e.target.value })}
/>
<div className='mt-1'>
<Box className='mt-4'>
<AIProviderSelect
settings={effectiveSettings}
setSettings={handleProviderChange}
/>
</Box>
<div className='mt-4'>
<TextField
margin="dense"
label={t('Instruction (System Prompt)')}
Expand Down
18 changes: 13 additions & 5 deletions src/renderer/stores/sessionActions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import {
createMessage,
Message,
Session,
ModelProvider,
} from '../../shared/types'
import * as atoms from './atoms'
import * as promptFormat from '../packages/prompts'
Expand Down Expand Up @@ -180,12 +181,17 @@ export async function generate(sessionId: string, targetMsg: Message) {
return
}
const placeholder = '...'

// Use session aiProvider if specified, otherwise use global setting
const effectiveAIProvider: ModelProvider = session.aiProvider || settings.aiProvider
const effectiveSettings = { ...settings, aiProvider: effectiveAIProvider }

targetMsg = {
...targetMsg,
content: placeholder,
cancel: undefined,
aiProvider: settings.aiProvider,
model: getModelDisplayName(settings, session.type || 'chat'),
aiProvider: effectiveAIProvider,
model: getModelDisplayName(effectiveSettings, session.type || 'chat'),
generating: true,
errorCode: undefined,
error: undefined,
Expand All @@ -197,7 +203,7 @@ export async function generate(sessionId: string, targetMsg: Message) {
let targetMsgIx = messages.findIndex((m) => m.id === targetMsg.id)

try {
const model = getModel(settings, configs)
const model = getModel(effectiveSettings, configs)
switch (session.type) {
case 'chat':
case undefined:
Expand Down Expand Up @@ -237,7 +243,7 @@ export async function generate(sessionId: string, targetMsg: Message) {
errorCode,
error: `${err.message}`,
errorExtra: {
aiProvider: settings.aiProvider,
aiProvider: effectiveAIProvider,
host: err['host'],
},
}
Expand All @@ -254,7 +260,9 @@ async function _generateName(sessionId: string, modifyName: (sessionId: string,
}
const configs = await platform.getConfig()
try {
const model = getModel(settings, configs)
const effectiveAIProvider: ModelProvider = session.aiProvider || settings.aiProvider
const effectiveSettings = { ...settings, aiProvider: effectiveAIProvider }
const model = getModel(effectiveSettings, configs)
let name = await model.chat(promptFormat.nameConversation(
session.messages
.filter(m => m.role !== 'system')
Expand Down
7 changes: 7 additions & 0 deletions src/shared/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ export interface Session {
type?: SessionType
name: string
picUrl?: string
aiProvider?: ModelProvider
messages: Message[]
copilotId?: string
}
Expand All @@ -69,6 +70,7 @@ export enum ModelProvider {
Ollama = 'ollama',
SiliconFlow = 'silicon-flow',
LMStudio = 'lm-studio',
PPIO = 'ppio',
}

export interface ModelSettings {
Expand Down Expand Up @@ -115,6 +117,11 @@ export interface ModelSettings {
siliconCloudKey: string
siliconCloudModel: siliconflow.Model | 'custom-model'

// ppio
ppioHost: string
ppioKey: string
ppioModel: string

temperature: number
topP: number
openaiMaxContextMessageCount: number
Expand Down