From 88435570e9984787b64e81fed4105d13bcc87b99 Mon Sep 17 00:00:00 2001 From: Hadley Wickham Date: Tue, 18 Feb 2025 16:58:51 -0600 Subject: [PATCH 1/8] Start sketching out batch API Part of #143 --- R/BatchChat.R | 70 +++++++++++++++++++++++++++++ R/provider-claude.R | 105 +++++++++++++++++++++++++++++++++++++++----- R/provider.R | 41 +++++++++++++++++ 3 files changed, 205 insertions(+), 11 deletions(-) create mode 100644 R/BatchChat.R diff --git a/R/BatchChat.R b/R/BatchChat.R new file mode 100644 index 00000000..394d4a4d --- /dev/null +++ b/R/BatchChat.R @@ -0,0 +1,70 @@ +# Make print method call `$run()`? + +BatchChat <- R6::R6Class( + "BatchChat", + public = list( + + # Need to take existing turns and then add user turns to them + # TODO: optionally automatically cache all previous the turns + initialize = function(id, provider, turns, type = NULL) { + private$provider <- provider + private$turns <- turns + private$type <- type + }, + + run = function() { + if (status == "new") { + private$submitted <- batch_submit( + private$provider, + turns = private$turns, + type = private$type + ) + private$poll() + } + + while (status != "completed") { + # TODO: update progress spinner + tryCatch( + { + result <- batch_poll(private$provider, private$submitted) + private$polled <- result$body + }, + interrupt = function(cnd) { + cli::cli_inform(c( + x = "Interrupted by user.", + i = "Use `$run()` to resume." + )) + break + } + ) + + if (result$done) { + break + } + } + + if (status == "completed") { + cli::cli_inform(c( + v = "Completed", + i = "Use `$results()` to get results." + )) + } + }, + + results = function() { + batch_results(private$provider, private$polled) + } + ), + private = list( + id = NULL, + provider = NULL, + turns = NULL, + type = NULL, + + # "new" | "submitted" | "completed" + status = NULL, + + submitted = NULL, + polled = NULL + ) +) diff --git a/R/provider-claude.R b/R/provider-claude.R index 59aa692c..3987aa99 100644 --- a/R/provider-claude.R +++ b/R/provider-claude.R @@ -72,15 +72,8 @@ anthropic_key_exists <- function() { key_exists("ANTHROPIC_API_KEY") } -method(chat_request, ProviderClaude) <- function(provider, - stream = TRUE, - turns = list(), - tools = list(), - type = NULL) { - +method(base_request, ProviderClaude) <- function(provider) { req <- request(provider@base_url) - # https://docs.anthropic.com/en/api/messages - req <- req_url_path_append(req, "/messages") # req <- req_headers(req, `anthropic-version` = "2023-06-01") # @@ -102,6 +95,37 @@ method(chat_request, ProviderClaude) <- function(provider, } }) + req +} + +# Chat ------------------------------------------------------------------------ + +method(chat_request, ProviderClaude) <- function(provider, + stream = TRUE, + turns = list(), + tools = list(), + type = NULL) { + req <- base_request(provider) + # https://docs.anthropic.com/en/api/messages + req <- req_url_path_append(req, "/messages") + + body <- chat_body( + provider, + stream = stream, + turns = turns, + tools = tools, + type = type + ) + req <- req_body_json(req, body) + req +} + +method(chat_body, ProviderClaude) <- function(provider, + stream = TRUE, + turns = list(), + tools = list(), + type = NULL) { + if (length(turns) >= 1 && is_system_prompt(turns[[1]])) { system <- turns[[1]]@text } else { @@ -134,10 +158,69 @@ method(chat_request, ProviderClaude) <- function(provider, tools = tools, tool_choice = tool_choice, )) - body <- modify_list(body, provider@extra_args) - req <- req_body_json(req, body) + modify_list(body, provider@extra_args) +} - req +# Batch chat ------------------------------------------------------------------- + +# https://docs.anthropic.com/en/api/creating-message-batches +method(batch_submit, ProviderClaude) <- function(provider, turns, type = NULL) { + req <- provider_request(provider) + req <- req_url_path_append(req, "/messages/batches") + + requests <- map(seq_along(turns), function(i) { + params <- chat_body( + provider, + stream = FALSE, + turns = turns[[i]], + type = type + ) + list( + custom_id = paste0("chat-", i), + params = body + ) + }) + req <- req_body_json(req, list(requests = requests)) + + resp <- req_perform(req) + resp_body_json(resp) +} + +# https://docs.anthropic.com/en/api/retrieving-message-batches +method(batch_poll, ProviderClaude) <- function(provider, batch) { + req <- provider_request(provider) + req <- req_url_path_append(req, "/messages/batches/", batch$id) + resp <- req_perform(req) + body <- resp_body_json(resp) + + list( + done = body$processing_status == "ended", + status = list( + processing = body$request_counts$processing, + succeeded = body$request_counts$completed, + failed = body$request_counts$errored + + body$request_counts$cancelled + + body$request_counts$expired + ), + body = body + ) +} + +# https://docs.anthropic.com/en/api/retrieving-message-batch-results +method(batch_retrieve, ProviderClaude) <- function(provider, batch) { + req <- provider_request(provider) + req <- req_url_path(batch$results_url) + req <- req_progress(req, "down") + + path <- withr::local_tempfile() + req <- req_perform(req, path = path) + + # + jsonlite::stream_in(path, function(page) { + # Parse json a page at a time + }) + + # Re-align to match inputs } # Claude -> ellmer -------------------------------------------------------------- diff --git a/R/provider.R b/R/provider.R index cf53adba..bc012dcd 100644 --- a/R/provider.R +++ b/R/provider.R @@ -28,12 +28,27 @@ Provider <- new_class( # Create a request------------------------------------ +base_request <- new_generic("base_request", "provider", + function(provider) { + S7_dispatch() + } +) + chat_request <- new_generic("chat_request", "provider", function(provider, stream = TRUE, turns = list(), tools = list(), type = NULL) { S7_dispatch() } ) +chat_body <- new_generic( + "chat_body", + "provider", + function(provider, stream = TRUE, turns = list(), tools = list(), type = NULL) { + S7_dispatch() + } +) + + chat_resp_stream <- new_generic("chat_resp_stream", "provider", function(provider, resp) { S7_dispatch() @@ -75,3 +90,29 @@ method(as_json, list(Provider, class_list)) <- function(provider, x) { method(as_json, list(Provider, ContentJson)) <- function(provider, x) { as_json(provider, ContentText("")) } + +# Batch API --------------------------------------------------------------- + +batch_submit <- new_generic( + "batch_submit", + "provider", + function(provider, turns, type = NULL) { + S7_dispatch() + } +) + +batch_poll <- new_generic( + "batch_poll", + "provider", + function(provider, batch) { + S7_dispatch() + } +) + +batch_retrieve <- new_generic( + "batch_retrieve", + "provider", + function(provider, batch) { + S7_dispatch() + } +) From 30a73b300b5c776a49e1c9070cdb80303e916c40 Mon Sep 17 00:00:00 2001 From: Hadley Wickham Date: Wed, 26 Feb 2025 09:32:15 -0600 Subject: [PATCH 2/8] Sketch out basic batch interface --- R/BatchChat.R | 87 +++++++++++---------------------------------- R/chat.R | 27 ++++++++++++-- R/provider-claude.R | 41 ++++++++++----------- R/provider.R | 8 +++++ 4 files changed, 74 insertions(+), 89 deletions(-) diff --git a/R/BatchChat.R b/R/BatchChat.R index 394d4a4d..bda4f7ed 100644 --- a/R/BatchChat.R +++ b/R/BatchChat.R @@ -1,70 +1,23 @@ -# Make print method call `$run()`? - -BatchChat <- R6::R6Class( - "BatchChat", - public = list( - - # Need to take existing turns and then add user turns to them - # TODO: optionally automatically cache all previous the turns - initialize = function(id, provider, turns, type = NULL) { - private$provider <- provider - private$turns <- turns - private$type <- type - }, - - run = function() { - if (status == "new") { - private$submitted <- batch_submit( - private$provider, - turns = private$turns, - type = private$type - ) - private$poll() - } - - while (status != "completed") { - # TODO: update progress spinner - tryCatch( - { - result <- batch_poll(private$provider, private$submitted) - private$polled <- result$body - }, - interrupt = function(cnd) { - cli::cli_inform(c( - x = "Interrupted by user.", - i = "Use `$run()` to resume." - )) - break - } - ) - - if (result$done) { - break - } - } - - if (status == "completed") { - cli::cli_inform(c( - v = "Completed", - i = "Use `$results()` to get results." - )) +batch_wait <- function(provider, batch) { + info <- batch_info(provider, batch) + cli::cli_progress_bar( + format = paste0( + "{cli::pb_spin} ", + "{info$counts$processing} -> {cli::col_green({info$counts$succeeded})} / {cli::col_red({info$counts$failed})} ", + "[{cli::pb_elapsed}]" + ), + clear = FALSE + ) + tryCatch({ + while (info$working) { + for (i in 1:10) { + Sys.sleep(0.1) + cli::cli_progress_update() } - }, - - results = function() { - batch_results(private$provider, private$polled) + batch <- batch_poll(provider, batch) + info <- batch_info(provider, batch) } - ), - private = list( - id = NULL, - provider = NULL, - turns = NULL, - type = NULL, - - # "new" | "submitted" | "completed" - status = NULL, + }, interrupt = function(cnd) {}) - submitted = NULL, - polled = NULL - ) -) + batch +} diff --git a/R/chat.R b/R/chat.R index 2db3089b..8427de3b 100644 --- a/R/chat.R +++ b/R/chat.R @@ -173,8 +173,31 @@ Chat <- R6::R6Class("Chat", map2(json, turns[ok], function(json, user_turn) { chat <- self$clone() - turn <- value_turn(private$provider, json) - chat$add_turn(user_turn, turn) + ai_turn <- value_turn(private$provider, json) + chat$add_turn(user_turn, user_turn) + chat + }) + }, + + #' @description Submit multiple prompts in parallel. Returns a list of + #' [Chat] objects, one for each prompt. + #' @param prompts A list of user prompts. + #' @param max_active The maximum number of simultaenous requests to send. + #' @param rpm Maximum number of requests per minute. + chat_batch = function(prompts) { + turns <- as_user_turns(prompts) + new_turns <- map(turns, function(new_turn) c(private$.turns, list(new_turn))) + + batch <- batch_submit(private$provider, new_turns) + batch <- batch_wait(provider, batch) + results <- batch_retrieve(provider, batch) + + ok <- map_lgl(results, function(x) x$type == "succeeded") + + map2(results[ok], turns[ok], function(result, user_turn) { + chat <- self$clone() + ai_turn <- value_turn(private$provider, result$message) + chat$add_turn(user_turn, ai_turn) chat }) }, diff --git a/R/provider-claude.R b/R/provider-claude.R index 3987aa99..955986c6 100644 --- a/R/provider-claude.R +++ b/R/provider-claude.R @@ -165,7 +165,7 @@ method(chat_body, ProviderClaude) <- function(provider, # https://docs.anthropic.com/en/api/creating-message-batches method(batch_submit, ProviderClaude) <- function(provider, turns, type = NULL) { - req <- provider_request(provider) + req <- base_request(provider) req <- req_url_path_append(req, "/messages/batches") requests <- map(seq_along(turns), function(i) { @@ -177,7 +177,7 @@ method(batch_submit, ProviderClaude) <- function(provider, turns, type = NULL) { ) list( custom_id = paste0("chat-", i), - params = body + params = params ) }) req <- req_body_json(req, list(requests = requests)) @@ -188,39 +188,40 @@ method(batch_submit, ProviderClaude) <- function(provider, turns, type = NULL) { # https://docs.anthropic.com/en/api/retrieving-message-batches method(batch_poll, ProviderClaude) <- function(provider, batch) { - req <- provider_request(provider) + req <- base_request(provider) req <- req_url_path_append(req, "/messages/batches/", batch$id) resp <- req_perform(req) - body <- resp_body_json(resp) + resp_body_json(resp) +} + +method(batch_info, ProviderClaude) <- function(provider, batch) { + counts <- batch$request_counts list( - done = body$processing_status == "ended", - status = list( - processing = body$request_counts$processing, - succeeded = body$request_counts$completed, - failed = body$request_counts$errored + - body$request_counts$cancelled + - body$request_counts$expired - ), - body = body + working = batch$processing_status != "ended", + counts = list( + processing = counts$processing, + succeeded = counts$succeeded, + failed = counts$errored + counts$canceled + counts$expired + ) ) } # https://docs.anthropic.com/en/api/retrieving-message-batch-results method(batch_retrieve, ProviderClaude) <- function(provider, batch) { - req <- provider_request(provider) - req <- req_url_path(batch$results_url) + req <- base_request(provider) + req <- req_url(req, batch$results_url) req <- req_progress(req, "down") path <- withr::local_tempfile() req <- req_perform(req, path = path) - # - jsonlite::stream_in(path, function(page) { - # Parse json a page at a time - }) + lines <- readLines(path, warn = FALSE) + json <- lapply(lines, jsonlite::fromJSON, simplifyVector = FALSE) - # Re-align to match inputs + ids <- as.numeric(gsub("chat-", "", map_chr(json, "[[", "custom_id"))) + results <- lapply(json, "[[", "result") + results[order(ids)] } # Claude -> ellmer -------------------------------------------------------------- diff --git a/R/provider.R b/R/provider.R index bc012dcd..e115e74a 100644 --- a/R/provider.R +++ b/R/provider.R @@ -116,3 +116,11 @@ batch_retrieve <- new_generic( S7_dispatch() } ) + +batch_info <- new_generic( + "batch_info", + "provider", + function(provider, batch) { + S7_dispatch() + } +) From 12ef33fdfbd3832bd5472c8c71f6b1239f64f53f Mon Sep 17 00:00:00 2001 From: Hadley Wickham Date: Wed, 26 Feb 2025 11:42:49 -0600 Subject: [PATCH 3/8] Redocument --- DESCRIPTION | 1 + man/Chat.Rd | 23 +++++++++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/DESCRIPTION b/DESCRIPTION index f2935cd4..d07f00f2 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -54,6 +54,7 @@ Encoding: UTF-8 Roxygen: list(markdown = TRUE) RoxygenNote: 7.3.2 Collate: + 'BatchChat.R' 'utils-S7.R' 'types.R' 'content.R' diff --git a/man/Chat.Rd b/man/Chat.Rd index 41c6097e..e304601f 100644 --- a/man/Chat.Rd +++ b/man/Chat.Rd @@ -37,6 +37,7 @@ chat$chat("Tell me a funny joke") \item \href{#method-Chat-last_turn}{\code{Chat$last_turn()}} \item \href{#method-Chat-chat}{\code{Chat$chat()}} \item \href{#method-Chat-chat_parallel}{\code{Chat$chat_parallel()}} +\item \href{#method-Chat-chat_batch}{\code{Chat$chat_batch()}} \item \href{#method-Chat-extract_data}{\code{Chat$extract_data()}} \item \href{#method-Chat-extract_data_parallel}{\code{Chat$extract_data_parallel()}} \item \href{#method-Chat-extract_data_async}{\code{Chat$extract_data_async()}} @@ -246,6 +247,28 @@ Submit multiple prompts in parallel. Returns a list of \item{\code{max_active}}{The maximum number of simultaenous requests to send.} +\item{\code{rpm}}{Maximum number of requests per minute.} +} +\if{html}{\out{}} +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-Chat-chat_batch}{}}} +\subsection{Method \code{chat_batch()}}{ +Submit multiple prompts in parallel. Returns a list of +\link{Chat} objects, one for each prompt. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{Chat$chat_batch(prompts)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{prompts}}{A list of user prompts.} + +\item{\code{max_active}}{The maximum number of simultaenous requests to send.} + \item{\code{rpm}}{Maximum number of requests per minute.} } \if{html}{\out{
}} From 5fa5f95a358995ddf58f669657f7338c17197f34 Mon Sep 17 00:00:00 2001 From: Hadley Wickham Date: Wed, 26 Feb 2025 11:44:00 -0600 Subject: [PATCH 4/8] Nice error if provider doesn't support batch requests --- R/BatchChat.R | 12 ++++++++++++ R/chat.R | 2 ++ R/provider-claude.R | 4 ++++ R/provider.R | 11 +++++++++++ 4 files changed, 29 insertions(+) diff --git a/R/BatchChat.R b/R/BatchChat.R index bda4f7ed..f1740291 100644 --- a/R/BatchChat.R +++ b/R/BatchChat.R @@ -21,3 +21,15 @@ batch_wait <- function(provider, batch) { batch } + + +check_has_batch_support <- function(provider, call = caller_env()) { + if (has_batch_support(provider)) { + return() + } + + cli::cli_abort( + "Batch requests are not currently supported by this provider.", + call = call + ) +} diff --git a/R/chat.R b/R/chat.R index 8427de3b..02ee24a3 100644 --- a/R/chat.R +++ b/R/chat.R @@ -185,6 +185,8 @@ Chat <- R6::R6Class("Chat", #' @param max_active The maximum number of simultaenous requests to send. #' @param rpm Maximum number of requests per minute. chat_batch = function(prompts) { + check_has_batch_support(private$provider) + turns <- as_user_turns(prompts) new_turns <- map(turns, function(new_turn) c(private$.turns, list(new_turn))) diff --git a/R/provider-claude.R b/R/provider-claude.R index 955986c6..c932da6c 100644 --- a/R/provider-claude.R +++ b/R/provider-claude.R @@ -163,6 +163,10 @@ method(chat_body, ProviderClaude) <- function(provider, # Batch chat ------------------------------------------------------------------- +method(has_batch_support, ProviderClaude) <- function(provider) { + TRUE +} + # https://docs.anthropic.com/en/api/creating-message-batches method(batch_submit, ProviderClaude) <- function(provider, turns, type = NULL) { req <- base_request(provider) diff --git a/R/provider.R b/R/provider.R index e115e74a..f92629f9 100644 --- a/R/provider.R +++ b/R/provider.R @@ -93,6 +93,17 @@ method(as_json, list(Provider, ContentJson)) <- function(provider, x) { # Batch API --------------------------------------------------------------- +has_batch_support <- new_generic( + "has_batch_support", + "provider", + function(provider) { + S7_dispatch() + } +) +method(has_batch_support, class_any) <- function(provider) { + FALSE +} + batch_submit <- new_generic( "batch_submit", "provider", From 1b3c751570b339952f9f678212de376e084737e0 Mon Sep 17 00:00:00 2001 From: Hadley Wickham Date: Wed, 26 Feb 2025 11:48:36 -0600 Subject: [PATCH 5/8] Polishing progress bars --- R/BatchChat.R | 20 ++++++++++---------- R/provider-gemini-upload.R | 4 ++-- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/R/BatchChat.R b/R/BatchChat.R index f1740291..c00130a0 100644 --- a/R/BatchChat.R +++ b/R/BatchChat.R @@ -2,22 +2,22 @@ batch_wait <- function(provider, batch) { info <- batch_info(provider, batch) cli::cli_progress_bar( format = paste0( - "{cli::pb_spin} ", - "{info$counts$processing} -> {cli::col_green({info$counts$succeeded})} / {cli::col_red({info$counts$failed})} ", + "{cli::pb_spin} Processing... {info$counts$processing} -> {cli::col_green({info$counts$succeeded})} / {cli::col_red({info$counts$failed})} ", "[{cli::pb_elapsed}]" ), clear = FALSE ) - tryCatch({ - while (info$working) { - for (i in 1:10) { - Sys.sleep(0.1) + tryCatch( + { + while (info$working) { + Sys.sleep(1) cli::cli_progress_update() + batch <- batch_poll(provider, batch) + info <- batch_info(provider, batch) } - batch <- batch_poll(provider, batch) - info <- batch_info(provider, batch) - } - }, interrupt = function(cnd) {}) + }, + interrupt = function(cnd) {} + ) batch } diff --git a/R/provider-gemini-upload.R b/R/provider-gemini-upload.R index 1badfed1..20706e3d 100644 --- a/R/provider-gemini-upload.R +++ b/R/provider-gemini-upload.R @@ -100,12 +100,12 @@ gemini_upload_status <- function(uri, credentials) { } gemini_upload_wait <- function(status, credentials) { - cli::cli_progress_bar(format = "{cli::pb_spin} Processing [{cli::pb_elapsed}] ") + cli::cli_progress_bar(format = "{cli::pb_spin} Processing... [{cli::pb_elapsed}] ") while (status$state == "PROCESSING") { cli::cli_progress_update() status <- gemini_upload_status(status$uri, credentials) - Sys.sleep(0.5) + Sys.sleep(1) } if (status$state == "FAILED") { cli::cli_abort("Upload failed: {status$error$message}") From 78b1101a5994e543d42bc3e8f8569f2718d42db6 Mon Sep 17 00:00:00 2001 From: Hadley Wickham Date: Wed, 26 Feb 2025 11:55:30 -0600 Subject: [PATCH 6/8] Slap in a data extraction method --- R/chat.R | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/R/chat.R b/R/chat.R index 02ee24a3..f6d48077 100644 --- a/R/chat.R +++ b/R/chat.R @@ -279,6 +279,28 @@ Chat <- R6::R6Class("Chat", }) }, + extract_data_batch = function(prompts, type, convert = TRUE) { + check_has_batch_support(private$provider) + turns <- as_user_turns(prompts) + check_bool(convert) + + needs_wrapper <- S7_inherits(private$provider, ProviderOpenAI) + if (needs_wrapper) { + type <- type_object(wrapper = type) + } + + new_turns <- map(turns, function(new_turn) c(private$.turns, list(new_turn))) + batch <- batch_submit(private$provider, new_turns, type = type) + batch <- batch_wait(provider, batch) + results <- batch_retrieve(provider, batch) + + ok <- map_lgl(results, function(x) x$type == "succeeded") + map2(results[ok], turns[ok], function(result, user_turn) { + turn <- value_turn(private$provider, result$message, has_type = TRUE) + extract_data(turn, type, convert = convert, needs_wrapper = needs_wrapper) + }) + }, + #' @description Extract structured data, asynchronously. Returns a promise #' that resolves to an object matching the type specification. #' @param ... The input to send to the chatbot. Will typically include From 30df8b7eed3ee6bde90c0edeef2098a78ed43af3 Mon Sep 17 00:00:00 2001 From: Hadley Wickham Date: Wed, 26 Feb 2025 17:04:11 -0600 Subject: [PATCH 7/8] Add OpenAI batching --- R/chat.R | 16 +++-- R/provider-claude.R | 8 +++ R/provider-openai.R | 140 ++++++++++++++++++++++++++++++++++++++++---- R/provider.R | 15 +++++ 4 files changed, 160 insertions(+), 19 deletions(-) diff --git a/R/chat.R b/R/chat.R index f6d48077..729378aa 100644 --- a/R/chat.R +++ b/R/chat.R @@ -191,16 +191,14 @@ Chat <- R6::R6Class("Chat", new_turns <- map(turns, function(new_turn) c(private$.turns, list(new_turn))) batch <- batch_submit(private$provider, new_turns) - batch <- batch_wait(provider, batch) - results <- batch_retrieve(provider, batch) + batch <- batch_wait(private$provider, batch) + results <- batch_retrieve(private$provider, batch) - ok <- map_lgl(results, function(x) x$type == "succeeded") + ok <- map_lgl(results, function(x) batch_result_ok(private$provider, x)) map2(results[ok], turns[ok], function(result, user_turn) { - chat <- self$clone() - ai_turn <- value_turn(private$provider, result$message) - chat$add_turn(user_turn, ai_turn) - chat + ai_turn <- batch_result_turn(private$provider, result) + self$clone()$add_turn(user_turn, ai_turn) }) }, @@ -294,9 +292,9 @@ Chat <- R6::R6Class("Chat", batch <- batch_wait(provider, batch) results <- batch_retrieve(provider, batch) - ok <- map_lgl(results, function(x) x$type == "succeeded") + ok <- map_lgl(results, function(x) batch_result_ok(private$provider, x)) map2(results[ok], turns[ok], function(result, user_turn) { - turn <- value_turn(private$provider, result$message, has_type = TRUE) + turn <- batch_result_turn(private$provider, result, has_type = TRUE) extract_data(turn, type, convert = convert, needs_wrapper = needs_wrapper) }) }, diff --git a/R/provider-claude.R b/R/provider-claude.R index c932da6c..e27f8395 100644 --- a/R/provider-claude.R +++ b/R/provider-claude.R @@ -228,6 +228,14 @@ method(batch_retrieve, ProviderClaude) <- function(provider, batch) { results[order(ids)] } +method(batch_result_ok, ProviderClaude) <- function(provider, result) { + result$type == "succeeded" +} + +method(batch_result_turn, ProviderClaude) <- function(provider, result, has_type = FALSE) { + value_turn(provider, result$message, has_type = has_type) +} + # Claude -> ellmer -------------------------------------------------------------- method(stream_parse, ProviderClaude) <- function(provider, event) { diff --git a/R/provider-openai.R b/R/provider-openai.R index c703779d..cdd23446 100644 --- a/R/provider-openai.R +++ b/R/provider-openai.R @@ -95,15 +95,8 @@ openai_key <- function() { key_get("OPENAI_API_KEY") } -# https://platform.openai.com/docs/api-reference/chat/create -method(chat_request, ProviderOpenAI) <- function(provider, - stream = TRUE, - turns = list(), - tools = list(), - type = NULL) { - +method(base_request, ProviderOpenAI) <- function(provider) { req <- request(provider@base_url) - req <- req_url_path_append(req, "/chat/completions") req <- req_auth_bearer_token(req, provider@api_key) req <- req_retry(req, max_tries = 2) req <- ellmer_req_timeout(req, stream) @@ -116,6 +109,36 @@ method(chat_request, ProviderOpenAI) <- function(provider, } }) + req +} + +# https://platform.openai.com/docs/api-reference/chat/create +method(chat_request, ProviderOpenAI) <- function(provider, + stream = TRUE, + turns = list(), + tools = list(), + type = NULL) { + + req <- base_request(provider) + req <- req_url_path_append(req, "/chat/completions") + + body <- chat_body(provider, + stream = stream, + turns = turns, + tools = tools, + type = type + ) + req <- req_body_json(req, body) + + req +} + +method(chat_body, ProviderOpenAI) <- function(provider, + stream = TRUE, + turns = list(), + tools = list(), + type = NULL) { + messages <- compact(unlist(as_json(provider, turns), recursive = FALSE)) tools <- as_json(provider, unname(tools)) @@ -142,11 +165,108 @@ method(chat_request, ProviderOpenAI) <- function(provider, response_format = response_format )) body <- utils::modifyList(body, provider@extra_args) - req <- req_body_json(req, body) - req + body +} + +# Batched requests ------------------------------------------------------------- + +method(has_batch_support, ProviderOpenAI) <- function(provider) { + TRUE +} + +# https://platform.openai.com/docs/api-reference/batch +method(batch_submit, ProviderOpenAI) <- function(provider, turns, type = NULL) { + path <- withr::local_tempfile() + + # First put the requests in a file + # https://platform.openai.com/docs/api-reference/batch/request-input + requests <- map(seq_along(turns), function(i) { + body <- chat_body(provider, stream = FALSE, turns = turns[[i]], type = type) + + list( + custom_id = paste0("chat-", i), + method = "POST", + url = "/v1/chat/completions", + body = body + ) + }) + json <- map_chr(requests, jsonlite::toJSON, auto_unbox = TRUE) + writeLines(json, path) + # Then upload it + uploaded <- openai_upload(provider, path) + + # Now we can submit the + req <- base_request(provider) + req <- req_url_path_append(req, "/batches") + req <- req_body_json(req, list( + input_file_id = uploaded$id, + endpoint = "/v1/chat/completions", + completion_window = "24h" + )) + + resp <- req_perform(req) + resp_body_json(resp) +} + +# https://platform.openai.com/docs/api-reference/batch/retrieve +openai_upload <- function(provider, path, purpose = "batch") { + req <- base_request(provider) + req <- req_url_path_append(req, "/files") + req <- req_body_multipart(req, purpose = purpose, file = curl::form_file(path)) + req <- req_progress(req, "up") + + resp <- req_perform(req) + resp_body_json(resp) +} + +# https://docs.anthropic.com/en/api/retrieving-message-batches +method(batch_poll, ProviderOpenAI) <- function(provider, batch) { + req <- base_request(provider) + req <- req_url_path_append(req, "/batches/", batch$id) + + resp <- req_perform(req) + resp_body_json(resp) +} + +method(batch_info, ProviderOpenAI) <- function(provider, batch) { + counts <- batch$request_counts + + list( + working = batch$status != "completed", + counts = list( + processing = counts$total - counts$completed, + succeeded = counts$completed, + failed = counts$failed + ) + ) +} + +# https://docs.anthropic.com/en/api/retrieving-message-batch-results +method(batch_retrieve, ProviderOpenAI) <- function(provider, batch) { + path <- withr::local_tempfile() + + req <- base_request(provider) + req <- req_url_path_append(req, "/files/", batch$output_file_id, "/content") + req <- req_progress(req, "down") + resp <- req_perform(req, path = path) + + lines <- readLines(path, warn = FALSE) + json <- lapply(lines, jsonlite::fromJSON, simplifyVector = FALSE) + + ids <- as.numeric(gsub("chat-", "", map_chr(json, "[[", "custom_id"))) + results <- lapply(json, "[[", "response") + results[order(ids)] } + +method(batch_result_ok, ProviderOpenAI) <- function(provider, result) { + result$status_code == 200 +} + +method(batch_result_turn, ProviderOpenAI) <- function(provider, result, has_type = FALSE) { + value_turn(provider, result$body, has_type = has_type) +} # OpenAI -> ellmer -------------------------------------------------------------- method(stream_parse, ProviderOpenAI) <- function(provider, event) { diff --git a/R/provider.R b/R/provider.R index f92629f9..55abfd80 100644 --- a/R/provider.R +++ b/R/provider.R @@ -135,3 +135,18 @@ batch_info <- new_generic( S7_dispatch() } ) + +batch_result_ok <- new_generic( + "batch_result_ok", + "provider", + function(provider, result) { + S7_dispatch() + } +) +batch_result_turn <- new_generic( + "batch_result_turn", + "provider", + function(provider, result, has_type = FALSE) { + S7_dispatch() + } +) From c9bb8140dcf1d79da785828f752f36ef59120885 Mon Sep 17 00:00:00 2001 From: Hadley Wickham Date: Wed, 26 Feb 2025 17:05:40 -0600 Subject: [PATCH 8/8] Initial news bullet --- NEWS.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/NEWS.md b/NEWS.md index ed38b01d..e0145d02 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,10 @@ # ellmer (development version) +* New `$chat_batch()` and `$extract_data_batch()` make it possible to use the + "batch" API provided by Claude, OpenAI, and Gemini (#143). Batch request are + typically 50% cheaper than regular requests but can take up to 24 hours to + complete. + * New `$chat_parallel()` and `$extract_data_parallel()` make it easier to perform multiple actions in parallel (#143). For Claude, note that the number of active connections is limited primarily by the output tokens per limit