diff --git a/lib/llm/base_llm_client.dart b/lib/llm/base_llm_client.dart
index 70ae896..0920290 100644
--- a/lib/llm/base_llm_client.dart
+++ b/lib/llm/base_llm_client.dart
@@ -15,7 +15,7 @@ abstract class BaseLLMClient {
     final response = await chatCompletion(
       CompletionRequest(
-        model: ProviderManager.chatProvider.currentModel,
+        model: ProviderManager.chatModelProvider.currentModel,
         messages: [ChatMessage(role: MessageRole.user, content: content)],
         tools: openaiTools,
       ),
@@ -43,4 +43,6 @@ abstract class BaseLLMClient {
   }
 
   Future<String> genTitle(List<ChatMessage> messages);
+
+  Future<List<String>> models();
 }
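Both concrete implementations below return an empty list from `models()` when the listing request fails, so callers need their own fallback. A minimal consumer-side sketch (the helper name and default are illustrative, not part of this PR):

```dart
import 'base_llm_client.dart'; // assumes the repo-relative import

/// Hypothetical helper: pick a usable default from whatever the client reports.
Future<String> pickDefaultModel(BaseLLMClient client) async {
  final ids = await client.models(); // [] when the /models call failed
  return ids.isNotEmpty ? ids.first : 'gpt-4o-mini';
}
```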
diff --git a/lib/llm/claude_client.dart b/lib/llm/claude_client.dart
index 69bd45c..22ae29f 100644
--- a/lib/llm/claude_client.dart
+++ b/lib/llm/claude_client.dart
@@ -26,32 +26,34 @@ class ClaudeClient extends BaseLLMClient {
   final Dio _dio;
 
   ClaudeClient({
-    required this.apiKey,
+    required this.apiKey,
     String? baseUrl,
     Dio? dio,
-  })  : baseUrl = (baseUrl == null || baseUrl.isEmpty)
-            ? 'https://api.anthropic.com/v1/messages'
+  })  : baseUrl = (baseUrl == null || baseUrl.isEmpty)
+            ? 'https://api.anthropic.com/v1'
             : baseUrl,
-        _dio = dio ?? Dio(BaseOptions(
-          headers: {
-            'Content-Type': 'application/json',
-            'x-api-key': apiKey,
-            'anthropic-version': '2023-06-01',
-          },
-          responseType: ResponseType.stream,
-        ));
+        _dio = dio ??
+            Dio(BaseOptions(
+              headers: {
+                'Content-Type': 'application/json',
+                'x-api-key': apiKey,
+                'anthropic-version': '2023-06-01',
+              },
+            ));
 
   @override
   Future<LLMResponse> chatCompletion(CompletionRequest request) async {
-    final messages = request.messages.map((m) => {
-          'role': m.role == MessageRole.user ? 'user' : 'assistant',
-          'content': [
-            {
-              'type': 'text',
-              'text': m.content ?? '',
-            }
-          ],
-        }).toList();
+    final messages = request.messages
+        .map((m) => {
+              'role': m.role == MessageRole.user ? 'user' : 'assistant',
+              'content': [
+                {
+                  'type': 'text',
+                  'text': m.content ?? '',
+                }
+              ],
+            })
+        .toList();
 
     final body = {
       'model': request.model,
@@ -69,46 +71,51 @@ class ClaudeClient extends BaseLLMClient {
 
     try {
       final response = await _dio.post(
-        baseUrl,
+        '$baseUrl/messages',
        data: jsonEncode(body),
       );
 
-      final buffer = StringBuffer();
-      await for (final chunk in response.data.stream) {
-        buffer.write(utf8.decode(chunk));
+      var json;
+      if (response.data is ResponseBody) {
+        final responseBody = response.data as ResponseBody;
+        final responseStr = await utf8.decodeStream(responseBody.stream);
+        json = jsonDecode(responseStr);
+      } else {
+        json = response.data;
       }
-      final json = jsonDecode(buffer.toString());
 
       final content = json['content'][0]['text'];
 
       // Parse tool calls if present
-      final toolCalls = json['tool_calls']?.map((t) => ToolCall(
-            id: t['id'],
-            type: t['type'],
-            function: FunctionCall(
-              name: t['function']['name'],
-              arguments: t['function']['arguments'],
-            ),
-          ))?.toList();
+      final toolCalls = json['tool_calls']
+          ?.map((t) => ToolCall(
+                id: t['id'],
+                type: t['type'],
+                function: FunctionCall(
+                  name: t['function']['name'],
+                  arguments: t['function']['arguments'],
+                ),
+              ))
+          ?.toList();
 
       return LLMResponse(
         content: content,
         toolCalls: toolCalls,
       );
     } catch (e) {
-      final tips = "Claude API call failed: $baseUrl body: $body error: $e";
-      Logger.root.severe(tips);
-      throw Exception(tips);
+      throw Exception(
+          "Claude API call failed: $baseUrl/messages body: ${jsonEncode(body)} error: $e");
     }
   }
 
   @override
   Stream<LLMResponse> chatStreamCompletion(CompletionRequest request) async* {
-    final messages = request.messages.map((m) => {
-          'role': m.role == MessageRole.user ? 'user' : 'assistant',
-          'content': m.content ?? '',
-        }).toList();
+    final messages = request.messages
+        .map((m) => {
+              'role': m.role == MessageRole.user ? 'user' : 'assistant',
+              'content': m.content ?? '',
+            })
+        .toList();
 
     final body = {
       'model': request.model,
@@ -127,15 +134,16 @@ class ClaudeClient extends BaseLLMClient {
     }
 
     try {
+      _dio.options.responseType = ResponseType.stream;
       final response = await _dio.post(
-        baseUrl,
+        '$baseUrl/messages',
         data: jsonEncode(body),
       );
 
       String buffer = '';
       String currentContent = '';
       List<ToolCall>? currentToolCalls;
 
       await for (final chunk in response.data.stream) {
         final decodedChunk = utf8.decode(chunk);
         buffer += decodedChunk;
@@ -145,9 +153,9 @@ class ClaudeClient extends BaseLLMClient {
           final line = buffer.substring(0, index).trim();
           buffer = buffer.substring(index + 1);
 
-          if (!line.startsWith('data: ')) continue;
-
-          final jsonStr = line.substring(6).trim();
+          if (!line.startsWith('data:')) continue;
+
+          final jsonStr = line.substring(5).trim();
           if (jsonStr.isEmpty) continue;
 
           try {
@@ -155,6 +163,7 @@ class ClaudeClient extends BaseLLMClient {
             final eventType = event['type'];
 
             switch (eventType) {
+              case 'content_block_start':
               case 'content_block_delta':
                 final delta = event['delta'];
                 if (delta['type'] == 'text_delta') {
@@ -190,9 +199,8 @@ class ClaudeClient extends BaseLLMClient {
         }
       }
     } catch (e) {
-      final error =
-          "Claude streaming API call failed: $baseUrl body: $body error: $e";
-      Logger.root.severe(error);
-      throw Exception(error);
+      throw Exception(
+          "Claude streaming API call failed: $baseUrl/messages body: ${jsonEncode(body)} error: $e");
     }
   }
 
@@ -205,7 +213,8 @@ class ClaudeClient extends BaseLLMClient {
 
     final prompt = ChatMessage(
       role: MessageRole.user,
-      content: """Generate a concise title (max 20 characters) for the following conversation.
+      content:
+          """Generate a concise title (max 20 characters) for the following conversation.
 The title should summarize the main topic. Return only the title without any explanation or extra punctuation.
 
 Conversation:
@@ -213,10 +222,10 @@ $conversationText""",
     );
 
     final response = await chatCompletion(CompletionRequest(
-      model: "claude-3-5-haiku-latest",
+      model: "claude-3-5-haiku-20241022",
       messages: [prompt],
     ));
 
     return response.content?.trim() ?? "New Chat";
   }
 
@@ -226,35 +235,38 @@ $conversationText""",
     Map<String, List<Map<String, dynamic>>> toolsResponse,
   ) async {
     // Convert tools to Claude's format
-    final tools = toolsResponse.entries.map((entry) {
-      return entry.value.map((tool) {
-        final parameters = tool['parameters'];
-        if (parameters is! Map) {
-          return {
-            'name': tool['name'],
-            'description': tool['description'],
-            'input_schema': {
-              'type': 'object',
-              'properties': {},
-              'required': [],
-            },
-          };
-        }
+    final tools = toolsResponse.entries
+        .map((entry) {
+          return entry.value.map((tool) {
+            final parameters = tool['parameters'];
+            if (parameters is! Map) {
+              return {
+                'name': tool['name'],
+                'description': tool['description'],
+                'input_schema': {
+                  'type': 'object',
+                  'properties': {},
+                  'required': [],
+                },
+              };
+            }
 
-        return {
-          'name': tool['name'],
-          'description': tool['description'],
-          'input_schema': {
-            'type': 'object',
-            'properties': parameters['properties'] ?? {},
-            'required': parameters['required'] ?? [],
-          },
-        };
-      }).toList();
-    }).expand((x) => x).toList();
+            return {
+              'name': tool['name'],
+              'description': tool['description'],
+              'input_schema': {
+                'type': 'object',
+                'properties': parameters['properties'] ?? {},
+                'required': parameters['required'] ?? [],
+              },
+            };
+          }).toList();
+        })
+        .expand((x) => x)
+        .toList();
 
     final body = {
-      'model': ProviderManager.chatProvider.currentModel,
+      'model': ProviderManager.chatModelProvider.currentModel,
       'messages': [
         {
           'role': 'user',
@@ -267,19 +279,21 @@ $conversationText""",
 
     try {
       final response = await _dio.post(
-        baseUrl,
+        '$baseUrl/messages',
         data: jsonEncode(body),
       );
 
-      final buffer = StringBuffer();
-      await for (final chunk in response.data.stream) {
-        buffer.write(utf8.decode(chunk));
+      var jsonData;
+      if (response.data is ResponseBody) {
+        final responseBody = response.data as ResponseBody;
+        final responseStr = await utf8.decodeStream(responseBody.stream);
+        jsonData = jsonDecode(responseStr);
+      } else {
+        jsonData = response.data;
       }
-      final json = jsonDecode(buffer.toString());
 
       // Check if response contains tool calls in the content array
-      final contentBlocks = json['content'] as List?;
+      final contentBlocks = jsonData['content'] as List?;
       if (contentBlocks == null || contentBlocks.isEmpty) {
         return {
           'need_tool_call': false,
@@ -288,9 +302,9 @@ $conversationText""",
       }
 
       // Look for tool_calls in the response
       final toolUseBlocks = contentBlocks.where((block) =>
           block['type'] == 'tool_calls' || block['type'] == 'tool_use');
 
       if (toolUseBlocks.isEmpty) {
         // Get text content from the first text block
         final textBlock = contentBlocks.firstWhere(
@@ -304,11 +318,13 @@ $conversationText""",
       }
 
       // Extract tool calls
-      final toolCalls = toolUseBlocks.map((block) => {
-          'id': block['id'],
-          'name': block['name'],
-          'arguments': block['input'],
-        }).toList();
+      final toolCalls = toolUseBlocks
+          .map((block) => {
+                'id': block['id'],
+                'name': block['name'],
+                'arguments': block['input'],
+              })
+          .toList();
 
       // Get any accompanying text content
       final textBlock = contentBlocks.firstWhere(
@@ -321,10 +337,29 @@ $conversationText""",
         'content': textBlock['text'] ?? '',
         'tool_calls': toolCalls,
       };
     } catch (e) {
-      Logger.root.severe(
-          'Claude tool call check failed: $baseUrl body: $body error: $e');
-      throw Exception('Failed to check tool calls: $e');
+      throw Exception(
+          'Claude tool call check failed: $baseUrl/messages body: ${jsonEncode(body)} error: $e');
     }
   }
+
+  @override
+  Future<List<String>> models() async {
+    try {
+      final response = await _dio.get("$baseUrl/models");
+
+      final data = response.data;
+
+      final models = (data['data'] as List)
+          .map((m) => m['id'].toString())
+          .where((id) => id.contains('claude'))
+          .toList();
+
+      return models;
+    } catch (e, trace) {
+      Logger.root.severe('Failed to fetch the model list: $e, trace: $trace');
+      // Fall back to an empty model list
+      return [];
+    }
+  }
 }
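The streaming changes above do two things worth calling out: the parser now matches `data:` with or without a trailing space (`substring(5)` instead of `substring(6)`), and `content_block_start` events share the delta-handling path. A standalone sketch of just the SSE framing, with an illustrative helper name:

```dart
/// Extracts the JSON payloads from a buffer of SSE lines, accepting both
/// "data:{...}" and "data: {...}" framing, mirroring the parser above.
Iterable<String> extractSsePayloads(String buffer) sync* {
  for (final raw in buffer.split('\n')) {
    final line = raw.trim();
    if (!line.startsWith('data:')) continue;
    final jsonStr = line.substring(5).trim(); // drop "data:", keep payload
    if (jsonStr.isNotEmpty) yield jsonStr;
  }
}
```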
diff --git a/lib/llm/llm_factory.dart b/lib/llm/llm_factory.dart
index d68a969..78a8fae 100644
--- a/lib/llm/llm_factory.dart
+++ b/lib/llm/llm_factory.dart
@@ -1,6 +1,8 @@
 import 'openai_client.dart';
 import 'claude_client.dart';
 import 'base_llm_client.dart';
+import 'package:ChatMcp/provider/provider_manager.dart';
+import 'package:logging/logging.dart';
 
 enum LLMProvider { openAI, claude, llama }
 
@@ -17,3 +19,45 @@ class LLMFactory {
     }
   }
 }
+
+class LLMFactoryHelper {
+  static BaseLLMClient createFromModel(String currentModel) {
+    // Infer the provider from the model name
+    final provider = currentModel.startsWith('gpt') ? 'openai' : 'claude';
+
+    // Read the provider's settings
+    final apiKey =
+        ProviderManager.settingsProvider.apiSettings[provider]?.apiKey ?? '';
+    final baseUrl =
+        ProviderManager.settingsProvider.apiSettings[provider]?.apiEndpoint ??
+            '';
+
+    Logger.root.fine(
+        'Using API Key: $apiKey for provider: $provider model: $currentModel');
+
+    // Create the LLM client
+    return LLMFactory.create(
+        provider == 'openai' ? LLMProvider.openAI : LLMProvider.claude,
+        apiKey: apiKey,
+        baseUrl: baseUrl);
+  }
+
+  static Future<List<String>> getAvailableModels() async {
+    List<String> providers = ["openai", "claude"];
+    List<String> models = [];
+    for (var provider in providers) {
+      final apiKey =
+          ProviderManager.settingsProvider.apiSettings[provider]?.apiKey ?? '';
+      final baseUrl =
+          ProviderManager.settingsProvider.apiSettings[provider]?.apiEndpoint ??
+              '';
+      final client = LLMFactory.create(
+          provider == "openai" ? LLMProvider.openAI : LLMProvider.claude,
+          apiKey: apiKey,
+          baseUrl: baseUrl);
+      models.addAll(await client.models());
+    }
+
+    return models;
+  }
+}
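Taken together, the two helpers give call sites a one-liner each. A usage sketch (the surrounding function is illustrative, not part of this PR):

```dart
Future<void> demo() async {
  // Enumerate every model id both configured providers expose...
  final available = await LLMFactoryHelper.getAvailableModels();
  if (available.isEmpty) return; // both listings failed or nothing configured
  // ...then build a client for whichever one the user picks.
  final client = LLMFactoryHelper.createFromModel(available.first);
  print('created $client for ${available.first}');
}
```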
diff --git a/lib/llm/model.dart b/lib/llm/model.dart
index 981c324..761ecae 100644
--- a/lib/llm/model.dart
+++ b/lib/llm/model.dart
@@ -140,6 +140,9 @@ class Model {
     required this.name,
     required this.label,
   });
+
+  @override
+  String toString() => 'Model(name: $name, label: $label)';
 }
 
 class CompletionRequest {
diff --git a/lib/llm/openai_client.dart b/lib/llm/openai_client.dart
index 36077b5..e35590c 100644
--- a/lib/llm/openai_client.dart
+++ b/lib/llm/openai_client.dart
@@ -41,13 +41,12 @@ class OpenAIClient extends BaseLLMClient {
             'Content-Type': 'application/json',
             'Authorization': 'Bearer $apiKey',
           },
-          responseType: ResponseType.stream,
         ));
 
   @override
   Future<LLMResponse> chatCompletion(CompletionRequest request) async {
     final body = {
-      'model': 'gpt-4o-mini',
+      'model': request.model,
       'messages': request.messages.map((m) => m.toJson()).toList(),
     };
 
@@ -64,17 +63,17 @@ class OpenAIClient extends BaseLLMClient {
         data: bodyStr,
       );
 
-      // Process the streamed response data
-      final buffer = StringBuffer();
-
-      await for (final chunk in response.data.stream) {
-        buffer.write(utf8.decode(chunk));
+      // Handle a response that arrives as a ResponseBody
+      var jsonData;
+      if (response.data is ResponseBody) {
+        final responseBody = response.data as ResponseBody;
+        final responseStr = await utf8.decodeStream(responseBody.stream);
+        jsonData = jsonDecode(responseStr);
+      } else {
+        jsonData = response.data;
       }
 
-      final responseBody = buffer.toString();
-      final json = jsonDecode(responseBody);
-
-      final message = json['choices'][0]['message'];
+      final message = jsonData['choices'][0]['message'];
 
       // Parse tool calls
       final toolCalls = message['tool_calls']
@@ -95,7 +94,6 @@ class OpenAIClient extends BaseLLMClient {
     } catch (e) {
       final tips =
           "call openai chatCompletion failed: endpoint: $baseUrl/chat/completions body: $body $e";
-      Logger.root.severe(tips);
       throw Exception(tips);
     }
   }
@@ -109,6 +107,7 @@ class OpenAIClient extends BaseLLMClient {
     };
 
     try {
+      _dio.options.responseType = ResponseType.stream;
      final response = await _dio.post(
         "$baseUrl/chat/completions",
         data: jsonEncode(body),
@@ -195,4 +194,24 @@ $conversationText""",
     ));
     return response.content?.trim() ?? "New Chat";
   }
+
+  @override
+  Future<List<String>> models() async {
+    try {
+      final response = await _dio.get("$baseUrl/models");
+
+      final data = response.data;
+
+      final models = (data['data'] as List)
+          .map((m) => m['id'].toString())
+          .where((id) => id.contains('gpt') || id.contains('o1'))
+          .toList();
+
+      return models;
+    } catch (e, trace) {
+      Logger.root.severe('Failed to fetch the model list: $e, trace: $trace');
+      // Fall back to an empty model list
+      return [];
+    }
+  }
 }
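The decode-or-passthrough block above now appears three times (both Claude call sites and this one). A possible shared helper, assuming one wanted to factor it out (hypothetical, not in this PR):

```dart
import 'dart:convert';

import 'package:dio/dio.dart';

/// Decodes a Dio response that may arrive either as already-parsed JSON or
/// as a raw ResponseBody (when the Dio instance was left in
/// ResponseType.stream mode by an earlier streaming call).
Future<dynamic> decodeDioResponse(Response response) async {
  final data = response.data;
  if (data is ResponseBody) {
    return jsonDecode(await utf8.decodeStream(data.stream));
  }
  return data;
}
```

Whether this is worth sharing depends on whether the clients keep mutating `_dio.options.responseType` in place; resetting it after each streaming call would remove the need for the branch entirely.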
"New Chat"; } + + @override + Future> models() async { + try { + final response = await _dio.get("$baseUrl/models"); + + final data = response.data; + + final models = (data['data'] as List) + .map((m) => m['id'].toString()) + .where((id) => id.contains('gpt') || id.contains('o1')) + .toList(); + + return models; + } catch (e, trace) { + Logger.root.severe('获取模型列表失败: $e, trace: $trace'); + // 返回预定义的模型列表作为后备 + return []; + } + } } diff --git a/lib/page/layout/chat_page/chat_page.dart b/lib/page/layout/chat_page/chat_page.dart index 71917f0..11d092e 100644 --- a/lib/page/layout/chat_page/chat_page.dart +++ b/lib/page/layout/chat_page/chat_page.dart @@ -1,6 +1,5 @@ import 'package:flutter/material.dart'; import 'package:ChatMcp/llm/model.dart'; -import 'package:ChatMcp/llm/prompt.dart'; import 'package:ChatMcp/llm/llm_factory.dart'; import 'package:ChatMcp/llm/base_llm_client.dart'; import 'package:logging/logging.dart'; @@ -32,34 +31,23 @@ class _ChatPageState extends State { super.initState(); _initializeLLMClient(); - // Add settings change listener + // Add settings change listener, when settings changed, we need to reinitialize LLM client ProviderManager.settingsProvider.addListener(_onSettingsChanged); - // Add chat change listener + // Add chat model change listener, when model changed, we need to reinitialize LLM client + ProviderManager.chatModelProvider.addListener(_initializeLLMClient); + // Add chat change listener, when chat changed, we need to reinitialize history messages ProviderManager.chatProvider.addListener(_onChatProviderChanged); _initializeHistoryMessages(); } void _onChatProviderChanged() { - _initializeLLMClient(); _initializeHistoryMessages(); } void _initializeLLMClient() { - final currentModel = ProviderManager.chatProvider.currentModel; - final provider = currentModel.startsWith('gpt') ? 'openai' : 'claude'; - - final apiKey = ProviderManager.settingsProvider.apiSettings[provider]?.apiKey ?? ''; - final baseUrl = ProviderManager.settingsProvider.apiSettings[provider]?.apiEndpoint ?? ''; - - Logger.root.fine('Using API Key: [HIDDEN] for provider: $provider'); - _llmClient = LLMFactory.create( - provider == 'openai' ? 
diff --git a/lib/page/layout/layout.dart b/lib/page/layout/layout.dart
index 5a9b0e2..b18574f 100644
--- a/lib/page/layout/layout.dart
+++ b/lib/page/layout/layout.dart
@@ -1,8 +1,10 @@
 import 'package:flutter/material.dart';
+import 'package:provider/provider.dart';
 
 import './chat_page/chat_page.dart';
 import './chat_history.dart';
 import 'package:ChatMcp/provider/provider_manager.dart';
+import 'package:ChatMcp/provider/chat_model_provider.dart';
 
 class LayoutPage extends StatefulWidget {
   const LayoutPage({super.key});
@@ -14,86 +16,99 @@ class LayoutPage extends StatefulWidget {
 class _LayoutPageState extends State<LayoutPage> {
   bool hideChatHistory = false;
 
+  @override
+  void initState() {
+    super.initState();
+    WidgetsBinding.instance.addPostFrameCallback((_) {
+      ProviderManager.chatModelProvider.loadAvailableModels();
+    });
+  }
+
   @override
   Widget build(BuildContext context) {
-    return Scaffold(
-      body: Row(
-        children: [
-          if (!hideChatHistory)
-            Container(
-              width: 250,
-              color: Colors.grey[200],
-              child: ChatHistoryPanel(
-                onToggle: () => setState(() {
-                  hideChatHistory = !hideChatHistory;
-                }),
+    return Consumer<ChatModelProvider>(
+        builder: (context, chatModelProvider, child) {
+      return Scaffold(
+        body: Row(
+          children: [
+            if (!hideChatHistory)
+              Container(
+                width: 250,
+                color: Colors.grey[200],
+                child: ChatHistoryPanel(
+                  onToggle: () => setState(() {
+                    hideChatHistory = !hideChatHistory;
+                  }),
+                ),
               ),
-            ),
-          Expanded(
-            child: Column(
-              children: [
-                Container(
-                  height: 50,
-                  color: Colors.grey[100],
-                  padding:
-                      EdgeInsets.fromLTRB(hideChatHistory ? 70 : 0, 0, 16, 0),
-                  child: Row(
-                    children: [
-                      if (hideChatHistory)
-                        IconButton(
-                          icon: const Icon(Icons.menu),
-                          onPressed: () => setState(() {
-                            hideChatHistory = !hideChatHistory;
-                          }),
-                        ),
-                      // model select
-                      DropdownButtonHideUnderline(
-                        child: ButtonTheme(
-                          alignedDropdown: true,
-                          child: DropdownButton<String>(
-                            value: ProviderManager.chatProvider.currentModel,
-                            items: ProviderManager.chatProvider.availableModels
-                                .map((model) => DropdownMenuItem(
-                                      value: model.name,
-                                      child: Text(model.label),
-                                    ))
-                                .toList(),
-                            onChanged: (String? value) {
-                              if (value != null) {
-                                setState(() {
-                                  ProviderManager.chatProvider.currentModel =
-                                      value;
-                                });
-                              }
-                            },
-                            menuMaxHeight: 200,
-                            elevation: 20,
-                            isDense: true,
-                            underline: Container(
-                              height: 0,
+            Expanded(
+              child: Column(
+                children: [
+                  Container(
+                    height: 50,
+                    color: Colors.grey[100],
+                    padding:
+                        EdgeInsets.fromLTRB(hideChatHistory ? 70 : 0, 0, 16, 0),
+                    child: Row(
+                      children: [
+                        if (hideChatHistory)
+                          IconButton(
+                            icon: const Icon(Icons.menu),
+                            onPressed: () => setState(() {
+                              hideChatHistory = !hideChatHistory;
+                            }),
+                          ),
+                        // model select
+                        DropdownButtonHideUnderline(
+                          child: ButtonTheme(
+                            alignedDropdown: true,
+                            child: DropdownButton<String>(
+                              value: chatModelProvider.currentModel,
+                              items: ProviderManager
+                                  .chatModelProvider.availableModels
+                                  .map((model) => DropdownMenuItem(
+                                        value: model.name,
+                                        child: Text(model.label),
+                                      ))
+                                  .toList(),
+                              onChanged: (String? value) {
+                                if (value != null) {
+                                  setState(() {
+                                    ProviderManager
+                                        .chatModelProvider.currentModel = value;
+                                  });
+                                }
+                              },
+                              menuMaxHeight: 200,
+                              elevation: 20,
+                              isDense: true,
+                              underline: Container(
+                                height: 0,
+                              ),
+                              isExpanded: false,
+                              alignment: AlignmentDirectional.centerStart,
                             ),
-                            isExpanded: false,
-                            alignment: AlignmentDirectional.centerStart,
                           ),
                         ),
-                      ),
-                      const Spacer(),
-                      IconButton(
-                        icon: const Icon(Icons.add),
-                        onPressed: () {
-                          ProviderManager.chatProvider.clearActiveChat();
-                        },
-                      ),
-                    ],
+                        const Spacer(),
+                        IconButton(
+                          icon: const Icon(Icons.add),
+                          onPressed: () {
+                            ProviderManager.chatProvider.clearActiveChat();
+                          },
+                        ),
+                      ],
+                    ),
                   ),
-                ),
-                const Expanded(
-                  child: ChatPage(),
-                ),
-              ],
+                  const Expanded(
+                    child: ChatPage(),
+                  ),
+                ],
+              ),
             ),
-          ),
-        ],
-      ),
-    );
+          ],
+        ),
+      );
+    });
   }
 }
diff --git a/lib/provider/chat_model_provider.dart b/lib/provider/chat_model_provider.dart
new file mode 100644
index 0000000..77f6af7
--- /dev/null
+++ b/lib/provider/chat_model_provider.dart
@@ -0,0 +1,63 @@
+import 'package:flutter/material.dart';
+import 'package:ChatMcp/llm/model.dart' as llmModel;
+import 'package:shared_preferences/shared_preferences.dart';
+import 'package:logging/logging.dart';
+import 'package:ChatMcp/llm/llm_factory.dart';
+
+class ChatModelProvider extends ChangeNotifier {
+  static final ChatModelProvider _instance = ChatModelProvider._internal();
+  factory ChatModelProvider() => _instance;
+  ChatModelProvider._internal() {
+    _loadSavedModel();
+  }
+
+  bool _isInitialized = false;
+
+  Future<void> loadAvailableModels() async {
+    if (_isInitialized) return;
+
+    final models = await LLMFactoryHelper.getAvailableModels();
+
+    _availableModels.clear(); // Clear the list before refilling it
+    _availableModels.addAll(models
+        .map((model) => llmModel.Model(name: model, label: model))
+        .toList());
+
+    // Make sure the currently selected model is still in the available list
+    if (_availableModels.isNotEmpty &&
+        !_availableModels.any((model) => model.name == _currentModel)) {
+      // If it is not, fall back to the first available model
+      _currentModel = _availableModels.first.name;
+      _saveSavedModel();
+    }
+
+    _isInitialized = true;
+    notifyListeners();
+  }
+
+  final List<llmModel.Model> _availableModels = [];
+
+  List<llmModel.Model> get availableModels => _availableModels;
+
+  static const String _modelKey = 'current_model';
+  String _currentModel = "gpt-4o-mini";
+
+  String get currentModel => _currentModel;
+
+  set currentModel(String model) {
+    _currentModel = model;
+    _saveSavedModel();
+    notifyListeners();
+  }
+
+  Future<void> _loadSavedModel() async {
+    final prefs = await SharedPreferences.getInstance();
+    _currentModel = prefs.getString(_modelKey) ?? "gpt-4o-mini";
+    Logger.root.info(
+        'load model: ${prefs.getString(_modelKey)} currentModel: $_currentModel');
+    notifyListeners();
+  }
+
+  Future<void> _saveSavedModel() async {
+    final prefs = await SharedPreferences.getInstance();
+    await prefs.setString(_modelKey, _currentModel);
+  }
+}
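A usage sketch for the new provider (the wrapper function is illustrative): `loadAvailableModels()` is a no-op after its first successful run, and the `currentModel` setter both persists the choice and notifies listeners, which is what re-creates the LLM client in ChatPage.

```dart
Future<void> switchToFirstAvailableModel() async {
  final provider = ProviderManager.chatModelProvider;
  await provider.loadAvailableModels(); // guarded by _isInitialized
  if (provider.availableModels.isNotEmpty) {
    // Persists to SharedPreferences and notifies listeners.
    provider.currentModel = provider.availableModels.first.name;
  }
}
```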
"gpt-4o-mini"; + Logger.root.info( + 'load model: ${prefs.getString(_modelKey)} currentModel: $_currentModel'); + notifyListeners(); + } + + Future _saveSavedModel() async { + final prefs = await SharedPreferences.getInstance(); + await prefs.setString(_modelKey, _currentModel); + } +} diff --git a/lib/provider/chat_provider.dart b/lib/provider/chat_provider.dart index 6c84d0d..c85eda6 100644 --- a/lib/provider/chat_provider.dart +++ b/lib/provider/chat_provider.dart @@ -3,16 +3,13 @@ import 'package:ChatMcp/dao/chat.dart'; import 'package:ChatMcp/dao/chat_message.dart'; import 'package:logging/logging.dart'; import 'package:ChatMcp/llm/model.dart' as llmModel; -import 'package:shared_preferences/shared_preferences.dart'; import 'package:ChatMcp/llm/openai_client.dart' as openai; import 'package:ChatMcp/llm/claude_client.dart' as claude; class ChatProvider extends ChangeNotifier { static final ChatProvider _instance = ChatProvider._internal(); factory ChatProvider() => _instance; - ChatProvider._internal() { - _loadSavedModel(); - } + ChatProvider._internal() {} Chat? _activeChat; List _chats = []; @@ -20,32 +17,10 @@ class ChatProvider extends ChangeNotifier { Chat? get activeChat => _activeChat; List get chats => _chats; - static const String _modelKey = 'current_model'; - String _currentModel = "gpt-4o-mini"; - - String get currentModel => _currentModel; - List get availableModels => [ - ...openai.models, - ...claude.models, - ]; - - set currentModel(String model) { - _currentModel = model; - _saveSavedModel(); - notifyListeners(); - } - - Future _loadSavedModel() async { - final prefs = await SharedPreferences.getInstance(); - _currentModel = prefs.getString(_modelKey) ?? "gpt-4o-mini"; - notifyListeners(); - } - - Future _saveSavedModel() async { - final prefs = await SharedPreferences.getInstance(); - await prefs.setString(_modelKey, _currentModel); - } + ...openai.models, + ...claude.models, + ]; Future loadChats() async { final chatDao = ChatDao(); diff --git a/lib/provider/provider_manager.dart b/lib/provider/provider_manager.dart index 5c483aa..cddab93 100644 --- a/lib/provider/provider_manager.dart +++ b/lib/provider/provider_manager.dart @@ -3,6 +3,7 @@ import 'package:provider/provider.dart'; import 'settings_provider.dart'; import 'mcp_server_provider.dart'; import 'chat_provider.dart'; +import 'chat_model_provider.dart'; class ProviderManager { static List providers = [ @@ -15,6 +16,9 @@ class ProviderManager { ChangeNotifierProvider( create: (_) => ChatProvider(), ), + ChangeNotifierProvider( + create: (_) => ChatModelProvider(), + ), // 在这里添加其他 Provider ]; @@ -39,6 +43,13 @@ class ProviderManager { return _chatProvider!; } + static ChatModelProvider? _chatModelProvider; + + static ChatModelProvider get chatModelProvider { + _chatModelProvider ??= ChatModelProvider(); + return _chatModelProvider!; + } + static Future init() async { await SettingsProvider().loadSettings(); _mcpServerProvider = McpServerProvider(); diff --git a/lib/widgets/markit.dart b/lib/widgets/markit.dart index ea3d552..764d590 100644 --- a/lib/widgets/markit.dart +++ b/lib/widgets/markit.dart @@ -133,6 +133,7 @@ class HighlightView extends StatelessWidget { HighlightView( String input, { + super.key, this.language, this.theme = const {}, this.padding, @@ -145,7 +146,7 @@ class HighlightView extends StatelessWidget { var currentSpans = spans; List> stack = []; - _traverse(Node node) { + traverse(Node node) { if (node.value != null) { currentSpans.add(node.className == null ? 
diff --git a/pubspec.lock b/pubspec.lock
index a23b39d..2622dfe 100644
--- a/pubspec.lock
+++ b/pubspec.lock
@@ -182,6 +182,14 @@ packages:
       url: "https://pub.flutter-io.cn"
     source: hosted
     version: "1.0.8"
+  curl_logger_dio_interceptor:
+    dependency: "direct main"
+    description:
+      name: curl_logger_dio_interceptor
+      sha256: f20d89187a321d2150e1412bca30ebf4d89130bafc648ce21bd4f1ef4062b214
+      url: "https://pub.flutter-io.cn"
+    source: hosted
+    version: "1.0.0"
   dart_style:
     dependency: transitive
     description:
diff --git a/pubspec.yaml b/pubspec.yaml
index a17a1ac..0864f3b 100644
--- a/pubspec.yaml
+++ b/pubspec.yaml
@@ -53,6 +53,7 @@ dependencies:
   highlighter: ^0.1.1
   url_launcher: ^6.3.1
   window_manager: ^0.4.3
+  curl_logger_dio_interceptor: ^1.0.0
 
 dev_dependencies:
   flutter_test:
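The new `curl_logger_dio_interceptor` dependency is only declared in this diff; no call site is shown. Wiring it up would presumably look like the following sketch — `CurlLoggerDioInterceptor` and its `printOnSuccess` flag are from the package's documented API, but treat the snippet as unverified against this codebase:

```dart
import 'package:curl_logger_dio_interceptor/curl_logger_dio_interceptor.dart';
import 'package:dio/dio.dart';

Dio buildDebugDio() {
  final dio = Dio();
  // Logs each request as a copy-pasteable curl command, handy for
  // reproducing the Claude/OpenAI calls outside the app.
  dio.interceptors.add(CurlLoggerDioInterceptor(printOnSuccess: true));
  return dio;
}
```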