diff --git a/src/scripting/context.rs b/src/scripting/context.rs index aed7aa5..a6e4805 100644 --- a/src/scripting/context.rs +++ b/src/scripting/context.rs @@ -11,6 +11,8 @@ use rand::prelude::ThreadRng; use rand::random; use rune::runtime::{Object, Shared}; use rune::{Any, Value}; +use scylla::batch::{Batch, BatchType}; +use scylla::frame::response::result::CqlValue; use scylla::prepared_statement::PreparedStatement; use scylla::query::Query; use std::collections::HashMap; @@ -655,6 +657,67 @@ impl Context { } } + pub async fn batch_prepared( + &self, + keys: Vec<&str>, + params: Vec, + ) -> Result<(), CassError> { + let keys_len = keys.len(); + let params_len = params.len(); + if keys_len != params_len { + return Err(CassError(CassErrorKind::Error(format!( + "Number of prepared statements ({keys_len}) and values ({params_len}) must be equal" + )))); + } else if keys_len == 0 { + return Err(CassError(CassErrorKind::Error("Empty batch".to_string()))); + } + let mut batch: Batch = Batch::new(BatchType::Logged); + let mut batch_values: Vec> = vec![]; + for (i, key) in enumerate(keys) { + let statement = self.statements.get(key).ok_or_else(|| { + CassError(CassErrorKind::PreparedStatementNotFound(key.to_string())) + })?; + let statement_col_specs = statement.get_variable_col_specs(); + batch.append_statement((**statement).clone()); + batch_values.push(bind::to_scylla_query_params( + params.get(i).expect("REASON"), + statement_col_specs, + )?); + } + match &self.session { + Some(session) => { + let mut current_attempt_num = 0; + while current_attempt_num <= self.retry_number { + let start_time = self.stats.try_lock().unwrap().start_request(); + let rs = session.batch(&batch, batch_values.clone()).await; + let duration = Instant::now() - start_time; + match rs { + Ok(_) => { + self.stats.try_lock().unwrap().complete_request( + duration, + Some(batch_values.len() as u64), + &rs, + ); + return Ok(()); + } + Err(e) => { + let current_error = CassError(CassErrorKind::Error(format!( + "batch execution failed: {e}" + ))); + handle_retry_error(self, current_attempt_num, current_error).await; + current_attempt_num += 1; + continue; + } + } + } + Err(CassError::query_retries_exceeded(self.retry_number)) + } + None => Err(CassError(CassErrorKind::Error( + "'session' is not defined".to_string(), + ))), + } + } + pub fn elapsed_secs(&self) -> f64 { self.start_time.try_lock().unwrap().elapsed().as_secs_f64() } diff --git a/src/scripting/functions.rs b/src/scripting/functions.rs index 0ca6d95..2412b0e 100644 --- a/src/scripting/functions.rs +++ b/src/scripting/functions.rs @@ -259,6 +259,16 @@ pub async fn execute_prepared( ctx.execute_prepared(&key, params).await } +#[rune::function(instance)] +pub async fn batch_prepared( + ctx: Ref, + keys: Vec>, + params: Vec, +) -> Result<(), CassError> { + ctx.batch_prepared(keys.iter().map(|k| k.deref()).collect(), params) + .await +} + #[rune::function(instance)] pub async fn init_partition_row_distribution_preset( mut ctx: Mut, diff --git a/src/scripting/mod.rs b/src/scripting/mod.rs index d69296b..c513cdc 100644 --- a/src/scripting/mod.rs +++ b/src/scripting/mod.rs @@ -28,6 +28,7 @@ fn try_install( context_module.function_meta(functions::execute)?; context_module.function_meta(functions::prepare)?; context_module.function_meta(functions::execute_prepared)?; + context_module.function_meta(functions::batch_prepared)?; context_module.function_meta(functions::init_partition_row_distribution_preset)?; context_module.function_meta(functions::get_partition_idx)?; context_module.function_meta(functions::get_datacenters)?;