Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update chat message to include image content as well as text #72

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 140 additions & 41 deletions Sources/OpenAIKit/Chat/Chat.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,68 @@ extension Chat {
/// Declares `Codable` conformance for `Chat.Choice`; the empty body supplies no custom coding logic here.
extension Chat.Choice: Codable {}

extension Chat {
    /// A single chat message: the author's role plus an ordered list of content parts.
    ///
    /// `Message` used to be an enum with `system`, `user` and `assistant` cases that each
    /// carried one `String`. It is a struct so that a message can mix text and image parts.
    public struct Message {
        /// Who authored the message.
        public let role: Role
        /// The parts (text and/or images) that make up the message, in order.
        public let content: [Content]

        /// Creates a message from an explicit role and list of content parts.
        /// - Parameters:
        ///   - role: The author of the message.
        ///   - content: The parts that make up the message body.
        public init(role: Role, content: [Content]) {
            self.role = role
            self.content = content
        }

        // The factories below keep the former enum-case spellings
        // (`.system(content:)`, `.user(content:)`, `.assistant(content:)`) usable
        // for constructing text-only messages.

        /// Creates a text-only message authored by the system.
        public static func system(content: String) -> Message {
            Message(role: .system, content: [.text(content)])
        }

        /// Creates a text-only message authored by the user.
        public static func user(content: String) -> Message {
            Message(role: .user, content: [.text(content)])
        }

        /// Creates a text-only message authored by the assistant.
        public static func assistant(content: String) -> Message {
            Message(role: .assistant, content: [.text(content)])
        }
    }
}

extension Chat.Message {
    /// The author of a chat message.
    ///
    /// The enum is backed by `String`, so each case is coded as its own name
    /// (`"system"`, `"user"`, `"assistant"`).
    public enum Role: String, Codable {
        case system, user, assistant
    }
}

extension Chat.Message {
/// One part of a message body: either text or an image.
public enum Content: Codable {
/// Text content.
case text(String)
/// A URL to an image.
case imageURL(URL)
/// Image data and its associated content type (e.g. `image/png`).
/// `Chat.Message.encode(to:)` base64-encodes the data into a `data:` URL.
case imageData(Data, String)
}
}

extension Chat.Message.Content {
    /// The string payload when this part is `.text`; `nil` for every other case.
    public var text: String? {
        if case .text(let value) = self { return value }
        return nil
    }

    /// The URL when this part is `.imageURL`; `nil` for every other case.
    public var imageURL: URL? {
        if case .imageURL(let value) = self { return value }
        return nil
    }

    /// The data payload when this part is `.imageData`; `nil` for every other case.
    public var imageData: Data? {
        if case .imageData(let value, _) = self { return value }
        return nil
    }

    /// The content type string when this part is `.imageData`; `nil` for every other case.
    public var contentType: String? {
        if case .imageData(_, let value) = self { return value }
        return nil
    }
}

Expand All @@ -38,51 +96,92 @@ extension Chat.Message: Codable {
case content
}

/// Keys of a single element of the `content` array: a `type` discriminator plus
/// the payload key that matches it (`text` or `image_url`).
private enum ContentKeys: String, CodingKey {
    case type, text
    case imageURL = "image_url"
}

/// Keys of the object nested under `image_url`; only `url` is read or written here.
private enum ImageURLKeys: String, CodingKey {
case url
}

/// Decodes a message whose `content` is either a bare string or an array of typed parts.
///
/// A bare string becomes a single `.text` part. In the array form each element carries a
/// `type` of `"text"` or `"image_url"`; an `image_url` whose `url` is a base64 `data:` URL
/// becomes `.imageData`, any other URL becomes `.imageURL`.
///
/// - Parameter decoder: The decoder to read the message from.
/// - Throws: `DecodingError` when a required key is missing or mistyped, when a part has an
///   unknown `type`, or when an image URL cannot be parsed.
public init(from decoder: Decoder) throws {
    let container = try decoder.container(keyedBy: CodingKeys.self)
    role = try container.decode(Chat.Message.Role.self, forKey: .role)

    // Text-only messages carry `content` as a plain string; `try?` is used purely as a
    // type probe before falling back to the array form.
    if let singleTextContent = try? container.decode(String.self, forKey: .content) {
        content = [.text(singleTextContent)]
        return
    }

    var contentsArray = try container.nestedUnkeyedContainer(forKey: .content)
    var contents: [Content] = []

    while !contentsArray.isAtEnd {
        let contentContainer = try contentsArray.nestedContainer(keyedBy: ContentKeys.self)
        let type = try contentContainer.decode(String.self, forKey: .type)

        switch type {
        case "text":
            let text = try contentContainer.decode(String.self, forKey: .text)
            contents.append(.text(text))
        case "image_url":
            let imageURLContainer = try contentContainer.nestedContainer(keyedBy: ImageURLKeys.self, forKey: .imageURL)
            let urlString = try imageURLContainer.decode(String.self, forKey: .url)
            contents.append(try Chat.Message.imageContent(fromURLString: urlString, in: imageURLContainer))
        default:
            throw DecodingError.dataCorruptedError(forKey: .type, in: contentContainer, debugDescription: "Unknown content type: \(type)")
        }
    }
    content = contents
}

/// Converts the `url` string of an `image_url` part into `.imageData` (for `data:` URLs)
/// or `.imageURL` (for everything else).
private static func imageContent(
    fromURLString urlString: String,
    in container: KeyedDecodingContainer<ImageURLKeys>
) throws -> Content {
    let dataScheme = "data:"

    // `data:` URLs have to be recognised first: they also parse as a valid `URL`, so
    // trying `URL` first would never yield `.imageData`.
    guard urlString.hasPrefix(dataScheme) else {
        guard let url = URL(string: urlString) else {
            throw DecodingError.dataCorruptedError(forKey: .url, in: container, debugDescription: "Unknown image URL")
        }
        return .imageURL(url)
    }

    // Expected shape: data:<content type>;base64,<payload>
    // Decode the payload so a later `encode(to:)` does not base64-encode it a second time.
    guard
        let comma = urlString.firstIndex(of: ","),
        let base64Marker = urlString[..<comma].range(of: ";base64"),
        let data = Data(base64Encoded: String(urlString[urlString.index(after: comma)...]))
    else {
        // Surface malformed input instead of silently dropping the part.
        throw DecodingError.dataCorruptedError(forKey: .url, in: container, debugDescription: "Malformed base64 data URL")
    }

    let contentTypeStart = urlString.index(urlString.startIndex, offsetBy: dataScheme.count)
    let contentType = String(urlString[contentTypeStart..<base64Marker.lowerBound])
    return .imageData(data, contentType)
}

/// Encodes the message as `role` plus `content`.
///
/// A message consisting of exactly one text part writes `content` as a bare string.
/// Anything else writes an array of typed parts, with image data emitted as a
/// base64 `data:` URL under `image_url.url`.
///
/// - Parameter encoder: The encoder to write the message to.
/// - Throws: Any error raised by the underlying encoding containers.
public func encode(to encoder: Encoder) throws {
    var container = encoder.container(keyedBy: CodingKeys.self)
    try container.encode(role, forKey: .role)

    // Short form: a lone text part is written as a plain string.
    if content.count == 1, let onlyText = content[0].text {
        try container.encode(onlyText, forKey: .content)
        return
    }

    var parts = container.nestedUnkeyedContainer(forKey: .content)
    for part in content {
        var partContainer = parts.nestedContainer(keyedBy: ContentKeys.self)
        switch part {
        case .text(let text):
            try partContainer.encode("text", forKey: .type)
            try partContainer.encode(text, forKey: .text)
        case .imageURL(let url):
            try partContainer.encode("image_url", forKey: .type)
            var imageContainer = partContainer.nestedContainer(keyedBy: ImageURLKeys.self, forKey: .imageURL)
            try imageContainer.encode(url, forKey: .url)
        case .imageData(let bytes, let contentType):
            try partContainer.encode("image_url", forKey: .type)
            var imageContainer = partContainer.nestedContainer(keyedBy: ImageURLKeys.self, forKey: .imageURL)
            let encodedBytes = bytes.base64EncodedString()
            try imageContainer.encode("data:\(contentType);base64,\(encodedBytes)", forKey: .url)
        }
    }
}
Expand Down