diff --git a/Sources/OpenAIKit/Chat/Chat.swift b/Sources/OpenAIKit/Chat/Chat.swift index f0dde58..af7b636 100644 --- a/Sources/OpenAIKit/Chat/Chat.swift +++ b/Sources/OpenAIKit/Chat/Chat.swift @@ -25,10 +25,68 @@ extension Chat { extension Chat.Choice: Codable {} extension Chat { - public enum Message { - case system(content: String) - case user(content: String) - case assistant(content: String) + public struct Message { + public let role: Role + public let content: [Content] + + public init(role: Role, content: [Content]) { + self.role = role + self.content = content + } + } +} + +extension Chat.Message { + public enum Role: String, Codable { + case system + case user + case assistant + } +} + +extension Chat.Message { + public enum Content: Codable { + /// Text content. + case text(String) + /// A URL to an image. + case imageURL(URL) + /// Base 64 ecoded image and the associated content type. + case imageData(Data, String) + } +} + +extension Chat.Message.Content { + public var text: String? { + switch self { + case .text(let text): + return text + default: + return nil + } + } + public var imageURL: URL? { + switch self { + case .imageURL(let url): + return url + default: + return nil + } + } + public var imageData: Data? { + switch self { + case .imageData(let data, _): + return data + default: + return nil + } + } + public var contentType: String? { + switch self { + case .imageData(_, let contentType): + return contentType + default: + return nil + } } } @@ -38,51 +96,92 @@ extension Chat.Message: Codable { case content } + private enum ContentKeys: String, CodingKey { + case type + case text + case imageURL = "image_url" + } + + private enum ImageURLKeys: String, CodingKey { + case url + } + public init(from decoder: Decoder) throws { let container = try decoder.container(keyedBy: CodingKeys.self) - let role = try container.decode(String.self, forKey: .role) - let content = try container.decode(String.self, forKey: .content) - switch role { - case "system": - self = .system(content: content) - case "user": - self = .user(content: content) - case "assistant": - self = .assistant(content: content) - default: - throw DecodingError.dataCorruptedError(forKey: .role, in: container, debugDescription: "Invalid type") + role = try container.decode(Chat.Message.Role.self, forKey: .role) + + if let singleTextContent = try? container.decode(String.self, forKey: .content) { + content = [.text(singleTextContent)] + } else { + var contentsArray = try container.nestedUnkeyedContainer(forKey: .content) + var contents = [Content]() + + while !contentsArray.isAtEnd { + let contentContainer = try contentsArray.nestedContainer(keyedBy: ContentKeys.self) + let type = try contentContainer.decode(String.self, forKey: .type) + + switch type { + case "text": + let text = try contentContainer.decode(String.self, forKey: .text) + contents.append(.text(text)) + case "image_url": + let imageURLContainer = try contentContainer.nestedContainer(keyedBy: ImageURLKeys.self, forKey: .imageURL) + + if let url = try? imageURLContainer.decode(URL.self, forKey: .url) { + contents.append(.imageURL(url)) + } else if let imageDataURL = try? imageURLContainer.decode(String.self, forKey: .url) { + let components = imageDataURL.components(separatedBy: ",") + if components.count == 2 { + let metadata = components[0] + let base64Component = components[1] + + if + let base64Data = base64Component.data(using: .utf8), + let contentTypeRange = metadata.range(of: "data:")?.upperBound, + let base64IndicatorRange = metadata.range(of: ";base64")?.lowerBound + { + let contentType = String(metadata[contentTypeRange..