Skip to content

Commit

Permalink
refactor: replace [String: String] with struct ChatMessage
Browse files Browse the repository at this point in the history
  • Loading branch information
tisfeng committed Nov 3, 2024
1 parent 246e44d commit 0d80b97
Show file tree
Hide file tree
Showing 9 changed files with 387 additions and 341 deletions.
4 changes: 4 additions & 0 deletions Easydict.xcodeproj/project.pbxproj
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
0310C8272A94F5DF00B1D81E /* apple-dictionary.html in Resources */ = {isa = PBXBuildFile; fileRef = 0310C8262A94EFA100B1D81E /* apple-dictionary.html */; };
0313F8702AD5577400A5CFB0 /* EasydictTests.m in Sources */ = {isa = PBXBuildFile; fileRef = 0313F86F2AD5577400A5CFB0 /* EasydictTests.m */; };
0315D3E02C4E64A500AC0442 /* QueryService+Translate.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0315D3DF2C4E64A500AC0442 /* QueryService+Translate.swift */; };
031CBA642CD76F1500364437 /* ChatMessage.swift in Sources */ = {isa = PBXBuildFile; fileRef = 031CBA632CD76F1500364437 /* ChatMessage.swift */; };
031DBD792AE01E130071CF85 /* easydict in Resources */ = {isa = PBXBuildFile; fileRef = 031DBD782AE01E130071CF85 /* easydict */; };
0320C5872B29F35700861B3D /* QueryServiceRecord.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0320C5862B29F35700861B3D /* QueryServiceRecord.swift */; };
0320DFF72C54A11300C516A7 /* LLMStreamService+Stream.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0320DFF62C54A11300C516A7 /* LLMStreamService+Stream.swift */; };
Expand Down Expand Up @@ -395,6 +396,7 @@
0313F86D2AD5577400A5CFB0 /* EasydictTests.xctest */ = {isa = PBXFileReference; explicitFileType = wrapper.cfbundle; includeInIndex = 0; path = EasydictTests.xctest; sourceTree = BUILT_PRODUCTS_DIR; };
0313F86F2AD5577400A5CFB0 /* EasydictTests.m */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.objc; path = EasydictTests.m; sourceTree = "<group>"; };
0315D3DF2C4E64A500AC0442 /* QueryService+Translate.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = "QueryService+Translate.swift"; sourceTree = "<group>"; };
031CBA632CD76F1500364437 /* ChatMessage.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ChatMessage.swift; sourceTree = "<group>"; };
031DBD782AE01E130071CF85 /* easydict */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text.script.sh; path = easydict; sourceTree = "<group>"; };
0320C5862B29F35700861B3D /* QueryServiceRecord.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = QueryServiceRecord.swift; sourceTree = "<group>"; };
0320DFF62C54A11300C516A7 /* LLMStreamService+Stream.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = "LLMStreamService+Stream.swift"; sourceTree = "<group>"; };
Expand Down Expand Up @@ -1349,6 +1351,7 @@
0396DE542BB5844A009FD2A5 /* BaseOpenAIService.swift */,
03779F0B2BB256A7008D3C42 /* OpenAIService.swift */,
03779F0C2BB256A7008D3C42 /* Prompt.swift */,
031CBA632CD76F1500364437 /* ChatMessage.swift */,
);
path = OpenAI;
sourceTree = "<group>";
Expand Down Expand Up @@ -3032,6 +3035,7 @@
278540342B3DE04F004E9488 /* GeneralTab.swift in Sources */,
03BDA7BC2A26DA280079D04F /* XPMArgumentSignature.m in Sources */,
03B0230229231FA6001C7E63 /* EZWordResultView.m in Sources */,
031CBA642CD76F1500364437 /* ChatMessage.swift in Sources */,
0399C6A529A747E600B4AFCC /* EZDeepLTranslateResponse.m in Sources */,
0340D3912C8EEEE3004C9910 /* Data+Extension.swift in Sources */,
17BCAEF82B0DFF9000A7D372 /* EZNiuTransTranslate.m in Sources */,
Expand Down
2 changes: 1 addition & 1 deletion Easydict/Swift/Service/AITool/PolishingService.swift
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class PolishingService: AIToolService {

// MARK: Internal

override func chatMessageDicts(_ chatQuery: ChatQueryParam) -> [[String: String]] {
override func chatMessageDicts(_ chatQuery: ChatQueryParam) -> [ChatMessage] {
polishingMessages(chatQuery)
}
}
2 changes: 1 addition & 1 deletion Easydict/Swift/Service/AITool/SummaryService.swift
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class SummaryService: AIToolService {

// MARK: Internal

override func chatMessageDicts(_ chatQuery: ChatQueryParam) -> [[String: String]] {
override func chatMessageDicts(_ chatQuery: ChatQueryParam) -> [ChatMessage] {
summaryMessages(chatQuery)
}
}
8 changes: 4 additions & 4 deletions Easydict/Swift/Service/CustomOpenAI/CustomOpenAIService.swift
Original file line number Diff line number Diff line change
Expand Up @@ -47,21 +47,21 @@ class CustomOpenAIService: BaseOpenAIService {
)
}

override func chatMessageDicts(_ chatQuery: ChatQueryParam) -> [[String: String]] {
override func chatMessageDicts(_ chatQuery: ChatQueryParam) -> [ChatMessage] {
if enableCustomPrompt {
var chatMessages: [[String: String]] = []
var chatMessages: [ChatMessage] = []
let systemPrompt = replaceCustomPromptWithVariable(systemPrompt)
var userPrompt = replaceCustomPromptWithVariable(userPrompt)

if !systemPrompt.isEmpty {
chatMessages.append(chatMessage(role: .system, content: systemPrompt))
chatMessages.append(.init(role: .system, content: systemPrompt))
}

// If user prompt is empty, use query text as user prompt
if userPrompt.isEmpty {
userPrompt = queryModel.queryText
}
chatMessages.append(chatMessage(role: .user, content: userPrompt))
chatMessages.append(.init(role: .user, content: userPrompt))

return chatMessages
}
Expand Down
14 changes: 7 additions & 7 deletions Easydict/Swift/Service/Gemini/GeminiService.swift
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,13 @@ public final class GeminiService: LLMStreamService {

override func serviceChatMessageModels(_ chatQuery: ChatQueryParam) -> [Any] {
var chatModels: [ModelContent] = []
for prompt in chatMessageDicts(chatQuery) {
if let openAIRole = prompt["role"],
let parts = prompt["content"] {
let role = getGeminiRole(from: openAIRole)
let chat = ModelContent(role: role, parts: parts)
chatModels.append(chat)
}
for message in chatMessageDicts(chatQuery) {
let openAIRole = message.role.rawValue
let parts = message.content

let role = getGeminiRole(from: openAIRole)
let chat = ModelContent(role: role, parts: parts)
chatModels.append(chat)
}
return chatModels
}
Expand Down
20 changes: 10 additions & 10 deletions Easydict/Swift/Service/OpenAI/BaseOpenAIService.swift
Original file line number Diff line number Diff line change
Expand Up @@ -63,22 +63,22 @@ public class BaseOpenAIService: LLMStreamService {

// MARK: Internal

typealias ChatMessage = ChatQuery.ChatCompletionMessageParam
typealias OpenAIChatMessage = ChatQuery.ChatCompletionMessageParam

let control = StreamControl()

override func serviceChatMessageModels(_ chatQuery: ChatQueryParam) -> [Any] {
var chatModels: [ChatMessage] = []
var chatMessages: [OpenAIChatMessage] = []
for message in chatMessageDicts(chatQuery) {
if let roleRawValue = message["role"],
let role = ChatMessage.Role(rawValue: roleRawValue),
let content = message["content"] {
if let chat = ChatMessage(role: role, content: content) {
chatModels.append(chat)
}
let openAIRole = message.role.rawValue
let content = message.content

if let role = OpenAIChatMessage.Role(rawValue: openAIRole),
let chat = OpenAIChatMessage(role: role, content: content) {
chatMessages.append(chat)
}
}
return chatModels
return chatMessages
}

override func cancelStream() {
Expand Down Expand Up @@ -111,7 +111,7 @@ public class BaseOpenAIService: LLMStreamService {
)

let chatHistory = serviceChatMessageModels(chatQueryParam)
guard let chatHistory = chatHistory as? [ChatMessage] else {
guard let chatHistory = chatHistory as? [OpenAIChatMessage] else {
return AsyncThrowingStream { continuation in
continuation.finish(throwing: invalidURLError)
}
Expand Down
49 changes: 49 additions & 0 deletions Easydict/Swift/Service/OpenAI/ChatMessage.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
//
// ChatMessage.swift
// Easydict
//
// Created by tisfeng on 2024/11/3.
// Copyright © 2024 izual. All rights reserved.
//

import Foundation

func systemMessage(queryType: EZQueryTextType) -> ChatMessage {
switch queryType {
case .dictionary:
.init(role: .system, content: LLMStreamService.dictSystemPrompt)
default:
.init(role: .system, content: LLMStreamService.translationSystemPrompt)
}
}

func chatMessagePair(userContent: String, assistantContent: String) -> [ChatMessage] {
[
.init(role: .user, content: userContent),
.init(role: .assistant, content: assistantContent),
]
}

// MARK: - ChatMessage

struct ChatMessage {
// MARK: - ChatRole

enum ChatRole: String, Codable, Equatable, CaseIterable {
case system
case user
case assistant
case tool
case model // Gemini role, equal to OpenAI assistant role.
}

let role: ChatRole
let content: String
}

// MARK: - AIToolType

enum AIToolType {
case polishing
case summary
}
3 changes: 1 addition & 2 deletions Easydict/Swift/Service/OpenAI/LLMStreamService.swift
Original file line number Diff line number Diff line change
Expand Up @@ -284,8 +284,7 @@ public class LLMStreamService: QueryService {
}
}

func chatMessageDicts(_ chatQuery: ChatQueryParam)
-> [[String: String]] {
func chatMessageDicts(_ chatQuery: ChatQueryParam) -> [ChatMessage] {
switch chatQuery.queryType {
case .dictionary:
dictMessages(chatQuery)
Expand Down
Loading

0 comments on commit 0d80b97

Please sign in to comment.