Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Start sketching out batch API #331

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ Encoding: UTF-8
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.3.2
Collate:
'BatchChat.R'
'utils-S7.R'
'types.R'
'content.R'
Expand Down
5 changes: 5 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# ellmer (development version)

* New `$chat_batch()` and `$extract_data_batch()` make it possible to use the
  "batch" API provided by Claude, OpenAI, and Gemini (#143). Batch requests are
  typically 50% cheaper than regular requests but can take up to 24 hours to
  complete.

* New `$chat_parallel()` and `$extract_data_parallel()` make it easier to
  perform multiple actions in parallel (#143). For Claude, note that the number
  of active connections is limited primarily by the output tokens per minute limit
Expand Down
35 changes: 35 additions & 0 deletions R/BatchChat.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Poll a submitted batch until the provider reports it finished, showing a
# live progress bar. Returns the final batch object (or the most recently
# polled one if the user interrupts).
#
# The progress-bar format string interpolates `info$counts$...`; cli
# re-evaluates the glue expressions in this frame on every update, so
# reassigning `info` inside the loop refreshes the displayed counts.
batch_wait <- function(provider, batch) {
  info <- batch_info(provider, batch)
  cli::cli_progress_bar(
    format = paste0(
      "{cli::pb_spin} Processing... {info$counts$processing} -> {cli::col_green({info$counts$succeeded})} / {cli::col_red({info$counts$failed})} ",
      "[{cli::pb_elapsed}]"
    ),
    clear = FALSE
  )
  # Wrap the polling loop so that Ctrl-C stops waiting but still returns the
  # latest `batch` (the loop assigns it in this function's environment), so
  # the caller can resume or retrieve partial state later.
  tryCatch(
    {
      while (info$working) {
        Sys.sleep(1)
        cli::cli_progress_update()
        batch <- batch_poll(provider, batch)
        info <- batch_info(provider, batch)
      }
    },
    interrupt = function(cnd) {}
  )

  batch
}


# Validate that `provider` implements the batch generics; abort with a
# user-facing error (attributed to the caller via `call`) when it does not.
check_has_batch_support <- function(provider, call = caller_env()) {
  if (!has_batch_support(provider)) {
    cli::cli_abort(
      "Batch requests are not currently supported by this provider.",
      call = call
    )
  }
}
49 changes: 47 additions & 2 deletions R/chat.R
Original file line number Diff line number Diff line change
Expand Up @@ -173,12 +173,35 @@ Chat <- R6::R6Class("Chat",

map2(json, turns[ok], function(json, user_turn) {
chat <- self$clone()
turn <- value_turn(private$provider, json)
chat$add_turn(user_turn, turn)
ai_turn <- value_turn(private$provider, json)
chat$add_turn(user_turn, user_turn)
chat
})
},

    #' @description Submit multiple prompts as a single batch request and
    #'   wait for it to complete. Batch requests are typically cheaper than
    #'   regular requests but may take up to 24 hours to finish. Returns a
    #'   list of [Chat] objects, one for each prompt that succeeded.
    #' @param prompts A list of user prompts.
    chat_batch = function(prompts) {
      check_has_batch_support(private$provider)

      turns <- as_user_turns(prompts)
      # Each batch entry replays the existing conversation plus one new user turn
      new_turns <- map(turns, function(new_turn) c(private$.turns, list(new_turn)))

      batch <- batch_submit(private$provider, new_turns)
      batch <- batch_wait(private$provider, batch)
      results <- batch_retrieve(private$provider, batch)

      # Silently drop failed results; only successes become Chat clones
      ok <- map_lgl(results, function(x) batch_result_ok(private$provider, x))

      map2(results[ok], turns[ok], function(result, user_turn) {
        ai_turn <- batch_result_turn(private$provider, result)
        self$clone()$add_turn(user_turn, ai_turn)
      })
    },

#' @description Extract structured data
#' @param ... The input to send to the chatbot. Will typically include
#' the phrase "extract structured data".
Expand Down Expand Up @@ -254,6 +277,28 @@ Chat <- R6::R6Class("Chat",
})
},

extract_data_batch = function(prompts, type, convert = TRUE) {
check_has_batch_support(private$provider)
turns <- as_user_turns(prompts)
check_bool(convert)

needs_wrapper <- S7_inherits(private$provider, ProviderOpenAI)
if (needs_wrapper) {
type <- type_object(wrapper = type)
}

new_turns <- map(turns, function(new_turn) c(private$.turns, list(new_turn)))
batch <- batch_submit(private$provider, new_turns, type = type)
batch <- batch_wait(provider, batch)
results <- batch_retrieve(provider, batch)

ok <- map_lgl(results, function(x) batch_result_ok(private$provider, x))
map2(results[ok], turns[ok], function(result, user_turn) {
turn <- batch_result_turn(private$provider, result, has_type = TRUE)
extract_data(turn, type, convert = convert, needs_wrapper = needs_wrapper)
})
},

#' @description Extract structured data, asynchronously. Returns a promise
#' that resolves to an object matching the type specification.
#' @param ... The input to send to the chatbot. Will typically include
Expand Down
118 changes: 107 additions & 11 deletions R/provider-claude.R
Original file line number Diff line number Diff line change
Expand Up @@ -72,15 +72,8 @@ anthropic_key_exists <- function() {
key_exists("ANTHROPIC_API_KEY")
}

method(chat_request, ProviderClaude) <- function(provider,
stream = TRUE,
turns = list(),
tools = list(),
type = NULL) {

method(base_request, ProviderClaude) <- function(provider) {
req <- request(provider@base_url)
# https://docs.anthropic.com/en/api/messages
req <- req_url_path_append(req, "/messages")
# <https://docs.anthropic.com/en/api/versioning>
req <- req_headers(req, `anthropic-version` = "2023-06-01")
# <https://docs.anthropic.com/en/api/getting-started#authentication>
Expand All @@ -102,6 +95,37 @@ method(chat_request, ProviderClaude) <- function(provider,
}
})

req
}

# Chat ------------------------------------------------------------------------

# Build the httr2 request for a single (possibly streaming) chat completion.
# Shared headers/auth/retry come from base_request(); this adds the messages
# endpoint and the JSON payload produced by chat_body().
method(chat_request, ProviderClaude) <- function(provider,
                                                 stream = TRUE,
                                                 turns = list(),
                                                 tools = list(),
                                                 type = NULL) {
  # https://docs.anthropic.com/en/api/messages
  req <- req_url_path_append(base_request(provider), "/messages")

  req_body_json(req, chat_body(
    provider,
    stream = stream,
    turns = turns,
    tools = tools,
    type = type
  ))
}

method(chat_body, ProviderClaude) <- function(provider,
stream = TRUE,
turns = list(),
tools = list(),
type = NULL) {

if (length(turns) >= 1 && is_system_prompt(turns[[1]])) {
system <- turns[[1]]@text
} else {
Expand Down Expand Up @@ -134,10 +158,82 @@ method(chat_request, ProviderClaude) <- function(provider,
tools = tools,
tool_choice = tool_choice,
))
body <- modify_list(body, provider@extra_args)
req <- req_body_json(req, body)
modify_list(body, provider@extra_args)
}

req
# Batch chat -------------------------------------------------------------------

# Claude implements the Message Batches API, so opt this provider in to the
# batch_* generics (checked by check_has_batch_support()).
method(has_batch_support, ProviderClaude) <- function(provider) {
  TRUE
}

# Submit one batch containing an entry per conversation, and return the
# provider's parsed batch object (used later for polling/retrieval).
# https://docs.anthropic.com/en/api/creating-message-batches
method(batch_submit, ProviderClaude) <- function(provider, turns, type = NULL) {
  # Tag each entry "chat-<i>" so results can be matched back to their
  # submission order after retrieval.
  entries <- map(seq_along(turns), function(i) {
    list(
      custom_id = paste0("chat-", i),
      params = chat_body(
        provider,
        stream = FALSE,
        turns = turns[[i]],
        type = type
      )
    )
  })

  req <- base_request(provider)
  req <- req_url_path_append(req, "/messages/batches")
  req <- req_body_json(req, list(requests = entries))

  resp_body_json(req_perform(req))
}

# Fetch the current status of a batch; returns the refreshed batch object.
# https://docs.anthropic.com/en/api/retrieving-message-batches
method(batch_poll, ProviderClaude) <- function(provider, batch) {
  req <- base_request(provider)
  # Pass the id as a separate path component (no trailing slash on the
  # endpoint) so req_url_path_append() joins the pieces with exactly one "/",
  # consistent with batch_submit().
  req <- req_url_path_append(req, "/messages/batches", batch$id)
  resp <- req_perform(req)
  resp_body_json(resp)
}

# Normalise Claude's batch status into the provider-agnostic shape consumed
# by batch_wait(): a `working` flag plus processing/succeeded/failed counts.
method(batch_info, ProviderClaude) <- function(provider, batch) {
  n <- batch$request_counts
  # Claude distinguishes errored/canceled/expired; batch_wait() only needs a
  # single "failed" total.
  failed <- n$errored + n$canceled + n$expired

  list(
    working = batch$processing_status != "ended",
    counts = list(
      processing = n$processing,
      succeeded = n$succeeded,
      failed = failed
    )
  )
}

# Download the completed batch's results (JSONL, one result per line) and
# return the list of result objects in submission order.
# https://docs.anthropic.com/en/api/retrieving-message-batch-results
method(batch_retrieve, ProviderClaude) <- function(provider, batch) {
  # Stream to a temp file (auto-cleaned by withr) with a download progress bar.
  path <- withr::local_tempfile()

  req <- base_request(provider)
  req <- req_url(req, batch$results_url)
  req <- req_progress(req, "down")
  req_perform(req, path = path)

  parsed <- lapply(
    readLines(path, warn = FALSE),
    jsonlite::fromJSON,
    simplifyVector = FALSE
  )

  # Results can arrive in any order; custom_id is "chat-<i>", so recover i
  # and sort back into submission order.
  idx <- as.numeric(gsub("chat-", "", map_chr(parsed, "[[", "custom_id")))
  lapply(parsed, "[[", "result")[order(idx)]
}

# Did this individual batch result succeed? Other `type` values are
# "errored", "canceled", and "expired".
method(batch_result_ok, ProviderClaude) <- function(provider, result) {
  # identical() is robust when `type` is absent (NULL): `==` would return
  # logical(0) and break the map_lgl() call in Chat$chat_batch().
  identical(result$type, "succeeded")
}

# Convert one successful batch result into an assistant turn by reusing the
# regular (non-batch) response parser on the embedded message.
method(batch_result_turn, ProviderClaude) <- function(provider, result, has_type = FALSE) {
  value_turn(provider, result$message, has_type = has_type)
}

# Claude -> ellmer --------------------------------------------------------------
Expand Down
4 changes: 2 additions & 2 deletions R/provider-gemini-upload.R
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,12 @@ gemini_upload_status <- function(uri, credentials) {
}

gemini_upload_wait <- function(status, credentials) {
cli::cli_progress_bar(format = "{cli::pb_spin} Processing [{cli::pb_elapsed}] ")
cli::cli_progress_bar(format = "{cli::pb_spin} Processing... [{cli::pb_elapsed}] ")

while (status$state == "PROCESSING") {
cli::cli_progress_update()
status <- gemini_upload_status(status$uri, credentials)
Sys.sleep(0.5)
Sys.sleep(1)
}
if (status$state == "FAILED") {
cli::cli_abort("Upload failed: {status$error$message}")
Expand Down
Loading
Loading