Skip to content

Commit

Permalink
Update chat message to include image content as well as text
Browse files Browse the repository at this point in the history
  • Loading branch information
mpdifran committed Nov 12, 2024
1 parent 1ce94bb commit e847e18
Showing 1 changed file with 140 additions and 41 deletions.
181 changes: 140 additions & 41 deletions Sources/OpenAIKit/Chat/Chat.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,68 @@ extension Chat {
extension Chat.Choice: Codable {}

extension Chat {
// NOTE(review): the next four lines look like the removed side of this diff (the previous
// enum-based `Message`); they share the closing braces below with the struct — confirm.
public enum Message {
case system(content: String)
case user(content: String)
case assistant(content: String)
/// A chat message: a `Role` paired with an array of `Content` items.
public struct Message {
/// The role associated with this message.
public let role: Role
/// The content items of this message, in the order given to the initializer.
public let content: [Content]

/// Creates a message.
/// - Parameters:
///   - role: The role associated with the message.
///   - content: The content items of the message.
public init(role: Role, content: [Content]) {
self.role = role
self.content = content
}
}
}

extension Chat.Message {
    /// The role associated with a chat message.
    ///
    /// Each case's raw value is its own name (`"system"`, `"user"`, `"assistant"`).
    public enum Role: String, Codable {
        case system, user, assistant
    }
}

extension Chat.Message {
/// One item of a message's content: text, an image referenced by URL, or inline image data.
public enum Content: Codable {
/// Text content.
case text(String)
/// A URL to an image.
case imageURL(URL)
/// Inline image data and the associated content type (the value placed after `data:` in the
/// data URL built by `encode(to:)`).
///
/// NOTE(review): `encode(to:)` calls `base64EncodedString()` on this payload itself, which
/// suggests it is expected to hold raw (not already base64-encoded) bytes — confirm.
case imageData(Data, String)
}
}

extension Chat.Message.Content {
    /// The text of a `.text` item; `nil` for every other case.
    public var text: String? {
        guard case .text(let value) = self else { return nil }
        return value
    }

    /// The URL of an `.imageURL` item; `nil` for every other case.
    public var imageURL: URL? {
        if case .imageURL(let location) = self {
            return location
        }
        return nil
    }

    /// The data payload of an `.imageData` item; `nil` for every other case.
    public var imageData: Data? {
        guard case .imageData(let payload, _) = self else { return nil }
        return payload
    }

    /// The content type of an `.imageData` item; `nil` for every other case.
    public var contentType: String? {
        if case .imageData(_, let type) = self {
            return type
        }
        return nil
    }
}

Expand All @@ -38,51 +96,92 @@ extension Chat.Message: Codable {
case content
}

/// Coding keys for a single entry of the `content` array.
private enum ContentKeys: String, CodingKey {
    case type, text
    /// Encoded as `image_url`.
    case imageURL = "image_url"
}

/// Coding keys for the object nested under a content entry's `image_url` key.
private enum ImageURLKeys: String, CodingKey {
/// The image location: decoded either as a `URL` or as a `data:` URL string.
case url
}

/// Decodes a chat message.
///
/// `content` is accepted in two shapes: a bare string, which becomes a single `.text` item, or
/// an array of typed entries whose `type` is `"text"` or `"image_url"`.
///
/// For `"image_url"` entries, a `data:<content-type>;base64,<payload>` URL is decoded into
/// `.imageData` with the base64 payload converted to bytes; anything else that decodes as a
/// `URL` becomes `.imageURL`.
///
/// - Parameter decoder: The decoder to read the message from.
/// - Throws: `DecodingError` when `role` or `content` cannot be decoded, when an entry has an
///   unknown `type`, or when an image URL can be read neither as a data URL nor as a `URL`.
public init(from decoder: Decoder) throws {
    let container = try decoder.container(keyedBy: CodingKeys.self)
    role = try container.decode(Chat.Message.Role.self, forKey: .role)

    if let singleTextContent = try? container.decode(String.self, forKey: .content) {
        content = [.text(singleTextContent)]
    } else {
        var contentsArray = try container.nestedUnkeyedContainer(forKey: .content)
        var contents = [Content]()

        while !contentsArray.isAtEnd {
            let contentContainer = try contentsArray.nestedContainer(keyedBy: ContentKeys.self)
            let type = try contentContainer.decode(String.self, forKey: .type)

            switch type {
            case "text":
                let text = try contentContainer.decode(String.self, forKey: .text)
                contents.append(.text(text))
            case "image_url":
                let imageURLContainer = try contentContainer.nestedContainer(keyedBy: ImageURLKeys.self, forKey: .imageURL)

                // Data URLs must be checked first: a `data:` string can also parse as a `URL`,
                // which would otherwise shadow the inline-image case entirely.
                if
                    let rawURL = try? imageURLContainer.decode(String.self, forKey: .url),
                    let inlineImage = Self.inlineImage(fromDataURL: rawURL)
                {
                    contents.append(.imageData(inlineImage.data, inlineImage.contentType))
                } else if let url = try? imageURLContainer.decode(URL.self, forKey: .url) {
                    contents.append(.imageURL(url))
                } else {
                    throw DecodingError.dataCorruptedError(forKey: .type, in: contentContainer, debugDescription: "Unknown image URL")
                }
            default:
                throw DecodingError.dataCorruptedError(forKey: .type, in: contentContainer, debugDescription: "Unknown content type: \(type)")
            }
        }
        content = contents
    }
}

/// Splits a `data:<content-type>;base64,<payload>` URL into its decoded bytes and content type.
///
/// - Parameter string: The candidate data URL.
/// - Returns: The decoded payload and content type, or `nil` when `string` is not a base64
///   data URL or its payload is not valid base64.
private static func inlineImage(fromDataURL string: String) -> (data: Data, contentType: String)? {
    guard string.hasPrefix("data:"), let comma = string.firstIndex(of: ",") else { return nil }

    let metadata = string[..<comma].dropFirst("data:".count)
    guard metadata.hasSuffix(";base64") else { return nil }
    let contentType = String(metadata.dropLast(";base64".count))

    var payload = string[string.index(after: comma)...]
    // Tolerate a payload wrapped in literal braces (`{…}`), a form this type's encoder has emitted.
    if payload.hasPrefix("{"), payload.hasSuffix("}") {
        payload = payload.dropFirst().dropLast()
    }

    guard let data = Data(base64Encoded: String(payload)) else { return nil }
    return (data, contentType)
}

/// Encodes the chat message.
///
/// A message whose content is exactly one `.text` item is written with `content` as a bare
/// string. Every other message is written with `content` as an array of typed entries:
/// `"text"` entries carry the text, and `"image_url"` entries carry either the image URL or a
/// `data:<content-type>;base64,<payload>` URL built from the inline image data.
///
/// - Parameter encoder: The encoder to write the message to.
/// - Throws: Any error thrown by the encoder's containers.
public func encode(to encoder: Encoder) throws {
    var container = encoder.container(keyedBy: CodingKeys.self)
    try container.encode(role, forKey: .role)

    if content.count == 1, let text = content.first?.text {
        try container.encode(text, forKey: .content)
    } else {
        var contentsArray = container.nestedUnkeyedContainer(forKey: .content)
        for contentItem in content {
            var contentContainer = contentsArray.nestedContainer(keyedBy: ContentKeys.self)
            switch contentItem {
            case .text(let text):
                try contentContainer.encode("text", forKey: .type)
                try contentContainer.encode(text, forKey: .text)
            case .imageURL(let url):
                try contentContainer.encode("image_url", forKey: .type)
                var urlContainer = contentContainer.nestedContainer(keyedBy: ImageURLKeys.self, forKey: .imageURL)
                try urlContainer.encode(url, forKey: .url)
            case .imageData(let data, let contentType):
                try contentContainer.encode("image_url", forKey: .type)
                var urlContainer = contentContainer.nestedContainer(keyedBy: ImageURLKeys.self, forKey: .imageURL)
                // The payload follows the comma directly; wrapping it in braces would make the
                // braces part of the data and yield an invalid base64 data URL.
                let dataURL = "data:\(contentType);base64,\(data.base64EncodedString())"
                try urlContainer.encode(dataURL, forKey: .url)
            }
        }
    }
}
Expand Down

0 comments on commit e847e18

Please sign in to comment.