diff --git a/src/context.rs b/src/context.rs deleted file mode 100644 index b60a3bc..0000000 --- a/src/context.rs +++ /dev/null @@ -1,1150 +0,0 @@ -use std::collections::{HashMap, HashSet}; -use std::fmt::{Display, Formatter}; -use std::fs::File; -use std::future::Future; -use std::hash::{Hash, Hasher}; -use std::io; -use std::io::{BufRead, BufReader, ErrorKind, Read}; -use std::net::IpAddr; -use std::ops::Deref; -use std::str::FromStr; -use std::sync::Arc; - -use chrono::Utc; -use itertools::Itertools; -use metrohash::{MetroHash128, MetroHash64}; -use openssl::error::ErrorStack; -use openssl::ssl::{SslContext, SslContextBuilder, SslFiletype, SslMethod}; -use rand::distributions::Distribution; -use rand::prelude::ThreadRng; -use rand::rngs::StdRng; -use rand::{random, Rng, SeedableRng}; -use rune::alloc::fmt::TryWrite; -use rune::macros::{quote, MacroContext, TokenStream}; -use rune::parse::Parser; -use rune::runtime::{Mut, Object, Ref, Shared, TypeInfo, VmError, VmResult}; -use rune::{ast, vm_try, vm_write}; -use rune::{Any, Value}; -use rust_embed::RustEmbed; -use scylla::_macro_internal::ColumnType; -use scylla::frame::response::result::CqlValue; -use scylla::frame::value::CqlTimeuuid; -use scylla::load_balancing::DefaultPolicy; -use scylla::prepared_statement::PreparedStatement; -use scylla::transport::errors::{DbError, NewSessionError, QueryError}; -use scylla::transport::session::PoolSize; -use scylla::{ExecutionProfile, QueryResult, SessionBuilder}; -use statrs::distribution::{Normal, Uniform}; -use tokio::time::{Duration, Instant}; -use tracing::error; -use try_lock::TryLock; -use uuid::{Variant, Version}; - -use crate::config::{ConnectionConf, RetryStrategy}; -use crate::latency::LatencyDistributionRecorder; -use crate::LatteError; - -fn ssl_context(conf: &&ConnectionConf) -> Result, CassError> { - if conf.ssl { - let mut ssl = SslContextBuilder::new(SslMethod::tls())?; - if let Some(path) = &conf.ssl_ca_cert_file { - ssl.set_ca_file(path)?; - } - if let Some(path) = &conf.ssl_cert_file { - ssl.set_certificate_file(path, SslFiletype::PEM)?; - } - if let Some(path) = &conf.ssl_key_file { - ssl.set_private_key_file(path, SslFiletype::PEM)?; - } - Ok(Some(ssl.build())) - } else { - Ok(None) - } -} - -/// Configures connection to Cassandra. -pub async fn connect(conf: &ConnectionConf) -> Result { - let mut policy_builder = DefaultPolicy::builder().token_aware(true); - if let Some(dc) = &conf.datacenter { - policy_builder = policy_builder - .prefer_datacenter(dc.to_owned()) - .permit_dc_failover(true); - } - let profile = ExecutionProfile::builder() - .consistency(conf.consistency.scylla_consistency()) - .load_balancing_policy(policy_builder.build()) - .request_timeout(Some(conf.request_timeout)) - .build(); - - let scylla_session = SessionBuilder::new() - .known_nodes(&conf.addresses) - .pool_size(PoolSize::PerShard(conf.count)) - .user(&conf.user, &conf.password) - .ssl_context(ssl_context(&conf)?) - .default_execution_profile_handle(profile.into_handle()) - .build() - .await - .map_err(|e| CassError(CassErrorKind::FailedToConnect(conf.addresses.clone(), e)))?; - Ok(Context::new(scylla_session, conf.retry_strategy)) -} - -pub struct ClusterInfo { - pub name: String, - pub cassandra_version: String, -} - -/// Transforms a CqlValue object to a string dedicated to be part of CassError message -pub fn cql_value_obj_to_string(v: &CqlValue) -> String { - let no_transformation_size_limit = 32; - match v { - // Replace big string- and bytes-alike object values with its size labels - CqlValue::Text(param) if param.len() > no_transformation_size_limit => { - format!("Text(={})", param.len()) - } - CqlValue::Ascii(param) if param.len() > no_transformation_size_limit => { - format!("Ascii(={})", param.len()) - } - CqlValue::Blob(param) if param.len() > no_transformation_size_limit => { - format!("Blob(={})", param.len()) - } - CqlValue::UserDefinedType { - keyspace, - type_name, - fields, - } => { - let mut result = format!( - "UDT {{ keyspace: \"{}\", type_name: \"{}\", fields: [", - keyspace, type_name, - ); - for (field_name, field_value) in fields { - let field_string = match field_value { - Some(field) => cql_value_obj_to_string(field), - None => String::from("None"), - }; - result.push_str(&format!("(\"{}\", {}), ", field_name, field_string)); - } - if result.len() >= 2 { - result.truncate(result.len() - 2); - } - result.push_str("] }"); - result - } - CqlValue::List(elements) => { - let mut result = String::from("List(["); - for element in elements { - let element_string = cql_value_obj_to_string(element); - result.push_str(&element_string); - result.push_str(", "); - } - if result.len() >= 2 { - result.truncate(result.len() - 2); - } - result.push_str("])"); - result - } - CqlValue::Set(elements) => { - let mut result = String::from("Set(["); - for element in elements { - let element_string = cql_value_obj_to_string(element); - result.push_str(&element_string); - result.push_str(", "); - } - if result.len() >= 2 { - result.truncate(result.len() - 2); - } - result.push_str("])"); - result - } - CqlValue::Map(pairs) => { - let mut result = String::from("Map({"); - for (key, value) in pairs { - let key_string = cql_value_obj_to_string(key); - let value_string = cql_value_obj_to_string(value); - result.push_str(&format!("({}: {}), ", key_string, value_string)); - } - if result.len() >= 2 { - result.truncate(result.len() - 2); - } - result.push_str("})"); - result - } - _ => format!("{v:?}"), - } -} - -#[derive(Any, Debug)] -pub struct CassError(pub CassErrorKind); - -impl CassError { - fn prepare_error(cql: &str, err: QueryError) -> CassError { - CassError(CassErrorKind::Prepare(cql.to_string(), err)) - } - - fn query_execution_error(cql: &str, params: &[CqlValue], err: QueryError) -> CassError { - let query = QueryInfo { - cql: cql.to_string(), - params: params.iter().map(cql_value_obj_to_string).collect(), - }; - let kind = match err { - QueryError::RequestTimeout(_) - | QueryError::TimeoutError - | QueryError::DbError( - DbError::Overloaded | DbError::ReadTimeout { .. } | DbError::WriteTimeout { .. }, - _, - ) => CassErrorKind::Overloaded(query, err), - _ => CassErrorKind::QueryExecution(query, err), - }; - CassError(kind) - } -} - -#[derive(Debug)] -pub struct QueryInfo { - cql: String, - params: Vec, -} - -impl Display for QueryInfo { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!( - f, - "\"{}\" with params [{}]", - self.cql, - self.params.join(", ") - ) - } -} - -#[derive(Debug)] -pub enum CassErrorKind { - SslConfiguration(ErrorStack), - FailedToConnect(Vec, NewSessionError), - PreparedStatementNotFound(String), - QueryRetriesExceeded(String), - QueryParamConversion(String, ColumnType, Option), - ValueOutOfRange(String, ColumnType), - InvalidNumberOfQueryParams, - InvalidQueryParamsObject(TypeInfo), - Prepare(String, QueryError), - Overloaded(QueryInfo, QueryError), - QueryExecution(QueryInfo, QueryError), -} - -impl CassError { - #[rune::function(protocol = STRING_DISPLAY)] - pub fn string_display(&self, f: &mut rune::runtime::Formatter) -> VmResult<()> { - vm_write!(f, "{}", self.to_string()); - VmResult::Ok(()) - } - - pub fn display(&self, buf: &mut String) -> std::fmt::Result { - use std::fmt::Write; - match &self.0 { - CassErrorKind::SslConfiguration(e) => { - write!(buf, "SSL configuration error: {e}") - } - CassErrorKind::FailedToConnect(hosts, e) => { - write!(buf, "Could not connect to {}: {}", hosts.join(","), e) - } - CassErrorKind::PreparedStatementNotFound(s) => { - write!(buf, "Prepared statement not found: {s}") - } - CassErrorKind::QueryRetriesExceeded(s) => { - write!(buf, "QueryRetriesExceeded: {s}") - } - CassErrorKind::ValueOutOfRange(v, t) => { - write!(buf, "Value {v} out of range for Cassandra type {t:?}") - } - CassErrorKind::QueryParamConversion(v, t, None) => { - write!(buf, "Cannot convert value {v} to Cassandra type {t:?}") - } - CassErrorKind::QueryParamConversion(v, t, Some(e)) => { - write!(buf, "Cannot convert value {v} to Cassandra type {t:?}: {e}") - } - CassErrorKind::InvalidNumberOfQueryParams => { - write!(buf, "Incorrect number of query parameters") - } - CassErrorKind::InvalidQueryParamsObject(t) => { - write!(buf, "Value of type {t} cannot by used as query parameters; expected a list or object") - } - CassErrorKind::Prepare(q, e) => { - write!(buf, "Failed to prepare query \"{q}\": {e}") - } - CassErrorKind::Overloaded(q, e) => { - write!(buf, "Overloaded when executing query {q}: {e}") - } - CassErrorKind::QueryExecution(q, e) => { - write!(buf, "Failed to execute query {q}: {e}") - } - } - } -} - -impl Display for CassError { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - let mut buf = String::new(); - self.display(&mut buf).unwrap(); - write!(f, "{buf}") - } -} - -impl From for CassError { - fn from(e: ErrorStack) -> CassError { - CassError(CassErrorKind::SslConfiguration(e)) - } -} - -impl std::error::Error for CassError {} - -#[derive(Clone, Debug)] -pub struct SessionStats { - pub req_count: u64, - pub req_errors: HashSet, - pub req_error_count: u64, - pub req_retry_count: u64, - pub row_count: u64, - pub queue_length: u64, - pub mean_queue_length: f32, - pub resp_times_ns: LatencyDistributionRecorder, -} - -impl SessionStats { - pub fn new() -> SessionStats { - Default::default() - } - - pub fn start_request(&mut self) -> Instant { - if self.req_count > 0 { - self.mean_queue_length += - (self.queue_length as f32 - self.mean_queue_length) / self.req_count as f32; - } - self.queue_length += 1; - Instant::now() - } - - pub fn complete_request( - &mut self, - duration: Duration, - rs: &Result, - retries: u64, - ) { - self.queue_length -= 1; - self.resp_times_ns.record(duration); - self.req_count += 1; - self.req_retry_count += retries; - match rs { - Ok(rs) => self.row_count += rs.rows.as_ref().map(|r| r.len()).unwrap_or(0) as u64, - Err(e) => { - self.req_error_count += 1; - self.req_errors.insert(format!("{e}")); - } - } - } - - /// Resets all accumulators - pub fn reset(&mut self) { - self.req_error_count = 0; - self.row_count = 0; - self.req_count = 0; - self.req_retry_count = 0; - self.mean_queue_length = 0.0; - self.req_errors.clear(); - self.resp_times_ns.clear(); - - // note that current queue_length is *not* reset to zero because there - // might be pending requests and if we set it to zero, that would underflow - } -} - -impl Default for SessionStats { - fn default() -> Self { - SessionStats { - req_count: 0, - req_errors: HashSet::new(), - req_error_count: 0, - req_retry_count: 0, - row_count: 0, - queue_length: 0, - mean_queue_length: 0.0, - resp_times_ns: LatencyDistributionRecorder::default(), - } - } -} - -pub fn get_exponential_retry_interval( - min_interval: Duration, - max_interval: Duration, - current_attempt_num: u64, -) -> Duration { - let min_interval_float: f64 = min_interval.as_secs_f64(); - let mut current_interval: f64 = - min_interval_float * (2u64.pow(current_attempt_num.try_into().unwrap_or(0)) as f64); - - // Add jitter - current_interval += random::() * min_interval_float; - current_interval -= min_interval_float / 2.0; - - Duration::from_secs_f64(current_interval.min(max_interval.as_secs_f64())) -} - -/// This is the main object that a workload script uses to interface with the outside world. -/// It also tracks query execution metrics such as number of requests, rows, response times etc. -#[derive(Any)] -pub struct Context { - start_time: TryLock, - session: Arc, - statements: HashMap>, - stats: TryLock, - retry_strategy: RetryStrategy, - #[rune(get, set, add_assign, copy)] - pub load_cycle_count: u64, - #[rune(get)] - pub data: Value, - pub rng: ThreadRng, -} - -// Needed, because Rune `Value` is !Send, as it may contain some internal pointers. -// Therefore, it is not safe to pass a `Value` to another thread by cloning it, because -// both objects could accidentally share some unprotected, `!Sync` data. -// To make it safe, the same `Context` is never used by more than one thread at once, and -// we make sure in `clone` to make a deep copy of the `data` field by serializing -// and deserializing it, so no pointers could get through. -unsafe impl Send for Context {} -unsafe impl Sync for Context {} - -impl Context { - pub fn new(session: scylla::Session, retry_strategy: RetryStrategy) -> Context { - Context { - start_time: TryLock::new(Instant::now()), - session: Arc::new(session), - statements: HashMap::new(), - stats: TryLock::new(SessionStats::new()), - retry_strategy, - load_cycle_count: 0, - data: Value::Object(Shared::new(Object::new()).unwrap()), - rng: rand::thread_rng(), - } - } - - /// Clones the context for use by another thread. - /// The new clone gets fresh statistics. - /// The user data gets passed through serialization and deserialization to avoid - /// accidental data sharing. - pub fn clone(&self) -> Result { - let serialized = rmp_serde::to_vec(&self.data)?; - let deserialized: Value = rmp_serde::from_slice(&serialized)?; - Ok(Context { - session: self.session.clone(), - statements: self.statements.clone(), - stats: TryLock::new(SessionStats::default()), - data: deserialized, - start_time: TryLock::new(*self.start_time.try_lock().unwrap()), - rng: rand::thread_rng(), - ..*self - }) - } - - /// Returns cluster metadata such as cluster name and cassandra version. - pub async fn cluster_info(&self) -> Result, CassError> { - let cql = "SELECT cluster_name, release_version FROM system.local"; - let rs = self - .session - .query(cql, ()) - .await - .map_err(|e| CassError::query_execution_error(cql, &[], e)); - match rs { - Ok(rs) => { - if let Some(rows) = rs.rows { - if let Some(row) = rows.into_iter().next() { - if let Ok((name, cassandra_version)) = row.into_typed() { - return Ok(Some(ClusterInfo { - name, - cassandra_version, - })); - } - } - } - Ok(None) - } - Err(e) => { - eprintln!("WARNING: {e}", e=e); - Ok(None) - } - } - } - - /// Prepares a statement and stores it in an internal statement map for future use. - pub async fn prepare(&mut self, key: &str, cql: &str) -> Result<(), CassError> { - let statement = self - .session - .prepare(cql) - .await - .map_err(|e| CassError::prepare_error(cql, e))?; - self.statements.insert(key.to_string(), Arc::new(statement)); - Ok(()) - } - - /// Executes an ad-hoc CQL statement with no parameters. Does not prepare. - pub async fn execute(&self, cql: &str) -> Result<(), CassError> { - if let Err(err) = self.execute_inner(|| self.session.query(cql, ())).await { - let err = CassError::query_execution_error(cql, &[], err); - error!("{}", err); - return Err(err); - } - Ok(()) - } - - /// Executes a statement prepared and registered earlier by a call to `prepare`. - pub async fn execute_prepared(&self, key: &str, params: Value) -> Result<(), CassError> { - let statement = self - .statements - .get(key) - .ok_or_else(|| CassError(CassErrorKind::PreparedStatementNotFound(key.to_string())))?; - - let params = bind::to_scylla_query_params(¶ms, statement.get_variable_col_specs())?; - let rs = self - .execute_inner(|| self.session.execute(statement, params.clone())) - .await; - - if let Err(err) = rs { - let err = CassError::query_execution_error(statement.get_statement(), ¶ms, err); - error!("{}", err); - return Err(err); - } - - Ok(()) - } - - async fn execute_inner(&self, f: impl Fn() -> R) -> Result - where - R: Future>, - { - let start_time = self.stats.try_lock().unwrap().start_request(); - - let mut rs: Result = Err(QueryError::TimeoutError); - let mut attempts = 0; - let retry_strategy = &self.retry_strategy; - while attempts <= retry_strategy.retries && Self::should_retry(&rs, retry_strategy) { - if attempts > 0 { - let current_retry_interval = get_exponential_retry_interval( - retry_strategy.retry_delay.min, - retry_strategy.retry_delay.max, - attempts, - ); - tokio::time::sleep(current_retry_interval).await; - } - rs = f().await; - attempts += 1; - } - - let duration = Instant::now() - start_time; - self.stats - .try_lock() - .unwrap() - .complete_request(duration, &rs, attempts - 1); - rs - } - - pub fn elapsed_secs(&self) -> f64 { - self.start_time.try_lock().unwrap().elapsed().as_secs_f64() - } - - fn should_retry(result: &Result, retry_strategy: &RetryStrategy) -> bool { - if !result.is_err() { - return false; - } - if retry_strategy.retry_on_all_errors { - return true; - } - matches!( - result, - Err(QueryError::RequestTimeout(_)) - | Err(QueryError::TimeoutError) - | Err(QueryError::DbError( - DbError::ReadTimeout { .. } - | DbError::WriteTimeout { .. } - | DbError::Overloaded, - _ - )) - ) - } - - /// Returns the current accumulated request stats snapshot and resets the stats. - pub fn take_session_stats(&self) -> SessionStats { - let mut stats = self.stats.try_lock().unwrap(); - let result = stats.clone(); - stats.reset(); - result - } - - /// Resets query and request counters - pub fn reset(&self) { - self.stats.try_lock().unwrap().reset(); - *self.start_time.try_lock().unwrap() = Instant::now(); - } -} - -/// Functions for binding rune values to CQL parameters -mod bind { - use crate::CassErrorKind; - use rune::ToValue; - use scylla::_macro_internal::ColumnType; - use scylla::frame::response::result::{ColumnSpec, CqlValue}; - - use super::*; - - fn to_scylla_value(v: &Value, typ: &ColumnType) -> Result { - // TODO: add support for the following native CQL types: - // 'counter', 'date', 'decimal', 'duration', 'inet', 'time', - // 'timestamp', 'timeuuid' and 'variant'. - // Also, for the 'tuple'. - match (v, typ) { - (Value::Bool(v), ColumnType::Boolean) => Ok(CqlValue::Boolean(*v)), - - (Value::Byte(v), ColumnType::TinyInt) => Ok(CqlValue::TinyInt(*v as i8)), - (Value::Byte(v), ColumnType::SmallInt) => Ok(CqlValue::SmallInt(*v as i16)), - (Value::Byte(v), ColumnType::Int) => Ok(CqlValue::Int(*v as i32)), - (Value::Byte(v), ColumnType::BigInt) => Ok(CqlValue::BigInt(*v as i64)), - - (Value::Integer(v), ColumnType::TinyInt) => { - convert_int(*v, ColumnType::TinyInt, CqlValue::TinyInt) - } - (Value::Integer(v), ColumnType::SmallInt) => { - convert_int(*v, ColumnType::SmallInt, CqlValue::SmallInt) - } - (Value::Integer(v), ColumnType::Int) => convert_int(*v, ColumnType::Int, CqlValue::Int), - (Value::Integer(v), ColumnType::BigInt) => Ok(CqlValue::BigInt(*v)), - (Value::Integer(v), ColumnType::Timestamp) => { - Ok(CqlValue::Timestamp(scylla::frame::value::CqlTimestamp(*v))) - } - - (Value::Float(v), ColumnType::Float) => Ok(CqlValue::Float(*v as f32)), - (Value::Float(v), ColumnType::Double) => Ok(CqlValue::Double(*v)), - - (Value::String(s), ColumnType::Timeuuid) => { - let timeuuid_str = s.borrow_ref().unwrap(); - let timeuuid = CqlTimeuuid::from_str(timeuuid_str.as_str()); - match timeuuid { - Ok(timeuuid) => Ok(CqlValue::Timeuuid(timeuuid)), - Err(e) => Err(CassError(CassErrorKind::QueryParamConversion( - format!("{:?}", v), - ColumnType::Timeuuid, - Some(format!("{}", e)), - ))), - } - } - (Value::String(v), ColumnType::Text | ColumnType::Ascii) => { - Ok(CqlValue::Text(v.borrow_ref().unwrap().as_str().to_string())) - } - (Value::String(s), ColumnType::Inet) => { - let ipaddr_str = s.borrow_ref().unwrap(); - let ipaddr = IpAddr::from_str(ipaddr_str.as_str()); - match ipaddr { - Ok(ipaddr) => Ok(CqlValue::Inet(ipaddr)), - Err(e) => Err(CassError(CassErrorKind::QueryParamConversion( - format!("{:?}", v), - ColumnType::Inet, - Some(format!("{}", e)), - ))), - } - } - (Value::Bytes(v), ColumnType::Blob) => { - Ok(CqlValue::Blob(v.borrow_ref().unwrap().to_vec())) - } - (Value::Option(v), typ) => match v.borrow_ref().unwrap().as_ref() { - Some(v) => to_scylla_value(v, typ), - None => Ok(CqlValue::Empty), - }, - (Value::Vec(v), ColumnType::List(elt)) => { - let v = v.borrow_ref().unwrap(); - let elements = v - .as_ref() - .iter() - .map(|v| to_scylla_value(v, elt)) - .try_collect()?; - Ok(CqlValue::List(elements)) - } - (Value::Vec(v), ColumnType::Set(elt)) => { - let v = v.borrow_ref().unwrap(); - let elements = v - .as_ref() - .iter() - .map(|v| to_scylla_value(v, elt)) - .try_collect()?; - Ok(CqlValue::Set(elements)) - } - (Value::Vec(v), ColumnType::Map(key_elt, value_elt)) => { - let v = v.borrow_ref().unwrap(); - let mut map_vec = Vec::with_capacity(v.len()); - for tuple in v.iter() { - match tuple { - Value::Tuple(tuple) if tuple.borrow_ref().unwrap().len() == 2 => { - let tuple = tuple.borrow_ref().unwrap(); - let key = to_scylla_value(tuple.first().unwrap(), key_elt)?; - let value = to_scylla_value(tuple.get(1).unwrap(), value_elt)?; - map_vec.push((key, value)); - } - _ => { - return Err(CassError(CassErrorKind::QueryParamConversion( - format!("{:?}", tuple), - ColumnType::Tuple(vec![ - key_elt.as_ref().clone(), - value_elt.as_ref().clone(), - ]), - None, - ))); - } - } - } - Ok(CqlValue::Map(map_vec)) - } - (Value::Object(obj), ColumnType::Map(key_elt, value_elt)) => { - let obj = obj.borrow_ref().unwrap(); - let mut map_vec = Vec::with_capacity(obj.keys().len()); - for (k, v) in obj.iter() { - let key = String::from(k.as_str()); - let key = to_scylla_value(&(key.to_value().unwrap()), key_elt)?; - let value = to_scylla_value(v, value_elt)?; - map_vec.push((key, value)); - } - Ok(CqlValue::Map(map_vec)) - } - ( - Value::Object(v), - ColumnType::UserDefinedType { - keyspace, - type_name, - field_types, - }, - ) => { - let obj = v.borrow_ref().unwrap(); - let fields = read_fields(|s| obj.get(s), field_types)?; - Ok(CqlValue::UserDefinedType { - keyspace: keyspace.to_string(), - type_name: type_name.to_string(), - fields, - }) - } - ( - Value::Struct(v), - ColumnType::UserDefinedType { - keyspace, - type_name, - field_types, - }, - ) => { - let obj = v.borrow_ref().unwrap(); - let fields = read_fields(|s| obj.get(s), field_types)?; - Ok(CqlValue::UserDefinedType { - keyspace: keyspace.to_string(), - type_name: type_name.to_string(), - fields, - }) - } - - (Value::Any(obj), ColumnType::Uuid) => { - let obj = obj.borrow_ref().unwrap(); - let h = obj.type_hash(); - if h == Uuid::type_hash() { - let uuid: &Uuid = obj.downcast_borrow_ref().unwrap(); - Ok(CqlValue::Uuid(uuid.0)) - } else { - Err(CassError(CassErrorKind::QueryParamConversion( - format!("{:?}", v), - ColumnType::Uuid, - None, - ))) - } - } - (value, typ) => Err(CassError(CassErrorKind::QueryParamConversion( - format!("{:?}", value), - typ.clone(), - None, - ))), - } - } - - fn convert_int, R>( - value: i64, - typ: ColumnType, - f: impl Fn(T) -> R, - ) -> Result { - let converted = value.try_into().map_err(|_| { - CassError(CassErrorKind::ValueOutOfRange( - value.to_string(), - typ.clone(), - )) - })?; - Ok(f(converted)) - } - - /// Binds parameters passed as a single rune value to the arguments of the statement. - /// The `params` value can be a tuple, a vector, a struct or an object. - pub fn to_scylla_query_params( - params: &Value, - types: &[ColumnSpec], - ) -> Result, CassError> { - Ok(match params { - Value::Tuple(tuple) => { - let mut values = Vec::new(); - let tuple = tuple.borrow_ref().unwrap(); - if tuple.len() != types.len() { - return Err(CassError(CassErrorKind::InvalidNumberOfQueryParams)); - } - for (v, t) in tuple.iter().zip(types) { - values.push(to_scylla_value(v, &t.typ)?); - } - values - } - Value::Vec(vec) => { - let mut values = Vec::new(); - - let vec = vec.borrow_ref().unwrap(); - for (v, t) in vec.iter().zip(types) { - values.push(to_scylla_value(v, &t.typ)?); - } - values - } - Value::Object(obj) => { - let obj = obj.borrow_ref().unwrap(); - read_params(|f| obj.get(f), types)? - } - Value::Struct(obj) => { - let obj = obj.borrow_ref().unwrap(); - read_params(|f| obj.get(f), types)? - } - other => { - return Err(CassError(CassErrorKind::InvalidQueryParamsObject( - other.type_info().unwrap(), - ))); - } - }) - } - - fn read_params<'a, 'b>( - get_value: impl Fn(&str) -> Option<&'a Value>, - params: &[ColumnSpec], - ) -> Result, CassError> { - let mut values = Vec::with_capacity(params.len()); - for column in params { - let value = match get_value(&column.name) { - Some(value) => to_scylla_value(value, &column.typ)?, - None => CqlValue::Empty, - }; - values.push(value) - } - Ok(values) - } - - fn read_fields<'a, 'b>( - get_value: impl Fn(&str) -> Option<&'a Value>, - fields: &[(String, ColumnType)], - ) -> Result)>, CassError> { - let mut values = Vec::with_capacity(fields.len()); - for (field_name, field_type) in fields { - if let Some(value) = get_value(field_name) { - let value = Some(to_scylla_value(value, field_type)?); - values.push((field_name.to_string(), value)) - }; - } - Ok(values) - } -} - -#[derive(RustEmbed)] -#[folder = "resources/"] -struct Resources; - -#[derive(Clone, Debug, Any)] -pub struct Uuid(pub uuid::Uuid); - -impl Uuid { - pub fn new(i: i64) -> Uuid { - let mut hash = MetroHash128::new(); - i.hash(&mut hash); - let (h1, h2) = hash.finish128(); - let h = ((h1 as u128) << 64) | (h2 as u128); - let mut builder = uuid::Builder::from_u128(h); - builder.set_variant(Variant::RFC4122); - builder.set_version(Version::Random); - Uuid(builder.into_uuid()) - } - - #[rune::function(protocol = STRING_DISPLAY)] - pub fn string_display(&self, f: &mut rune::runtime::Formatter) -> VmResult<()> { - vm_write!(f, "{}", self.0); - VmResult::Ok(()) - } -} - -#[derive(Clone, Debug, Any)] -pub struct Int8(pub i8); - -#[derive(Clone, Debug, Any)] -pub struct Int16(pub i16); - -#[derive(Clone, Debug, Any)] -pub struct Int32(pub i32); - -#[derive(Clone, Debug, Any)] -pub struct Float32(pub f32); - -/// Returns the literal value stored in the `params` map under the key given as the first -/// macro arg, and if not found, returns the expression from the second arg. -pub fn param( - ctx: &mut MacroContext, - params: &HashMap, - ts: &TokenStream, -) -> rune::compile::Result { - let mut parser = Parser::from_token_stream(ts, ctx.macro_span()); - let name = parser.parse::()?; - let name = ctx.resolve(name)?.to_string(); - let _ = parser.parse::()?; - let expr = parser.parse::()?; - let rhs = match params.get(&name) { - Some(value) => { - let src_id = ctx.insert_source(&name, value)?; - let value = ctx.parse_source::(src_id)?; - quote!(#value) - } - None => quote!(#expr), - }; - Ok(rhs.into_token_stream(ctx)?) -} - -/// Creates a new UUID for current iteration -#[rune::function] -pub fn uuid(i: i64) -> Uuid { - Uuid::new(i) -} - -#[rune::function] -pub fn float_to_i8(value: f64) -> Option { - Some(Int8((value as i64).try_into().ok()?)) -} - -/// Computes a hash of an integer value `i`. -/// Returns a value in range `0..i64::MAX`. -fn hash_inner(i: i64) -> i64 { - let mut hash = MetroHash64::new(); - i.hash(&mut hash); - (hash.finish() & 0x7FFFFFFFFFFFFFFF) as i64 -} - -/// Computes a hash of an integer value `i`. -/// Returns a value in range `0..i64::MAX`. -#[rune::function] -pub fn hash(i: i64) -> i64 { - hash_inner(i) -} - -/// Computes hash of two integer values. -#[rune::function] -pub fn hash2(a: i64, b: i64) -> i64 { - let mut hash = MetroHash64::new(); - a.hash(&mut hash); - b.hash(&mut hash); - (hash.finish() & 0x7FFFFFFFFFFFFFFF) as i64 -} - -/// Computes a hash of an integer value `i`. -/// Returns a value in range `0..max`. -#[rune::function] -pub fn hash_range(i: i64, max: i64) -> i64 { - hash_inner(i) % max -} - -/// Generates a floating point value with normal distribution -#[rune::function] -pub fn normal(i: i64, mean: f64, std_dev: f64) -> VmResult { - let mut rng = StdRng::seed_from_u64(i as u64); - let distribution = - vm_try!(Normal::new(mean, std_dev).map_err(|e| VmError::panic(format!("{e}")))); - VmResult::Ok(distribution.sample(&mut rng)) -} - -#[rune::function] -pub fn uniform(i: i64, min: f64, max: f64) -> VmResult { - let mut rng = StdRng::seed_from_u64(i as u64); - let distribution = vm_try!(Uniform::new(min, max).map_err(|e| VmError::panic(format!("{e}")))); - VmResult::Ok(distribution.sample(&mut rng)) -} - -/// Generates random blob of data of given length. -/// Parameter `seed` is used to seed the RNG. -#[rune::function] -pub fn blob(seed: i64, len: usize) -> Vec { - let mut rng = StdRng::seed_from_u64(seed as u64); - (0..len).map(|_| rng.gen::()).collect() -} - -/// Generates random string of given length. -/// Parameter `seed` is used to seed -/// the RNG. -#[rune::function] -pub fn text(seed: i64, len: usize) -> String { - let mut rng = StdRng::seed_from_u64(seed as u64); - (0..len) - .map(|_| { - let code_point = rng.gen_range(0x0061u32..=0x007Au32); // Unicode range for 'a-z' - std::char::from_u32(code_point).unwrap() - }) - .collect() -} - -/// Generates 'now' timestamp -#[rune::function] -pub fn now_timestamp() -> i64 { - Utc::now().timestamp() -} - -/// Selects one item from the collection based on the hash of the given value. -#[rune::function] -pub fn hash_select(i: i64, collection: &[Value]) -> Value { - collection[(hash_inner(i) % collection.len() as i64) as usize].clone() -} - -/// Reads a file into a string. -#[rune::function] -pub fn read_to_string(filename: &str) -> io::Result { - let mut file = File::open(filename).expect("no such file"); - - let mut buffer = String::new(); - file.read_to_string(&mut buffer)?; - - Ok(buffer) -} - -/// Reads a file into a vector of lines. -#[rune::function] -pub fn read_lines(filename: &str) -> io::Result> { - let file = File::open(filename).expect("no such file"); - let buf = BufReader::new(file); - let result = buf - .lines() - .map(|l| l.expect("Could not parse line")) - .collect(); - Ok(result) -} - -/// Reads a resource file as a string. -fn read_resource_to_string_inner(path: &str) -> io::Result { - let resource = Resources::get(path).ok_or_else(|| { - io::Error::new(ErrorKind::NotFound, format!("Resource not found: {path}")) - })?; - let contents = std::str::from_utf8(resource.data.as_ref()) - .map_err(|e| io::Error::new(ErrorKind::InvalidData, format!("Invalid UTF8 string: {e}")))?; - Ok(contents.to_string()) -} - -#[rune::function] -pub fn read_resource_to_string(path: &str) -> io::Result { - read_resource_to_string_inner(path) -} - -#[rune::function] -pub fn read_resource_lines(path: &str) -> io::Result> { - Ok(read_resource_to_string_inner(path)? - .split('\n') - .map(|s| s.to_string()) - .collect_vec()) -} - -#[rune::function(instance)] -pub async fn prepare(mut ctx: Mut, key: Ref, cql: Ref) -> Result<(), CassError> { - ctx.prepare(&key, &cql).await -} - -#[rune::function(instance)] -pub async fn execute(ctx: Ref, cql: Ref) -> Result<(), CassError> { - ctx.execute(cql.deref()).await -} - -#[rune::function(instance)] -pub async fn execute_prepared( - ctx: Ref, - key: Ref, - params: Value, -) -> Result<(), CassError> { - ctx.execute_prepared(&key, params).await -} - -#[rune::function(instance)] -pub fn elapsed_secs(ctx: &Context) -> f64 { - ctx.elapsed_secs() -} - -pub mod i64 { - use crate::context::{Float32, Int16, Int32, Int8}; - - /// Converts a Rune integer to i8 (Cassandra tinyint) - #[rune::function(instance)] - pub fn to_i8(value: i64) -> Option { - Some(Int8(value.try_into().ok()?)) - } - - /// Converts a Rune integer to i16 (Cassandra smallint) - #[rune::function(instance)] - pub fn to_i16(value: i64) -> Option { - Some(Int16(value.try_into().ok()?)) - } - - /// Converts a Rune integer to i32 (Cassandra int) - #[rune::function(instance)] - pub fn to_i32(value: i64) -> Option { - Some(Int32(value.try_into().ok()?)) - } - - /// Converts a Rune integer to f32 (Cassandra float) - #[rune::function(instance)] - pub fn to_f32(value: i64) -> Float32 { - Float32(value as f32) - } - - /// Converts a Rune integer to a String - #[rune::function(instance)] - pub fn to_string(value: i64) -> String { - value.to_string() - } - - /// Restricts a value to a certain interval. - #[rune::function(instance)] - pub fn clamp(value: i64, min: i64, max: i64) -> i64 { - value.clamp(min, max) - } -} - -pub mod f64 { - use crate::context::{Float32, Int16, Int32, Int8}; - - #[rune::function(instance)] - pub fn to_i8(value: f64) -> Int8 { - Int8(value as i8) - } - - #[rune::function(instance)] - pub fn to_i16(value: f64) -> Int16 { - Int16(value as i16) - } - - #[rune::function(instance)] - pub fn to_i32(value: f64) -> Int32 { - Int32(value as i32) - } - - #[rune::function(instance)] - pub fn to_f32(value: f64) -> Float32 { - Float32(value as f32) - } - - #[rune::function(instance)] - pub fn to_string(value: f64) -> String { - value.to_string() - } - - /// Restricts a value to a certain interval unless it is NaN. - #[rune::function(instance)] - pub fn clamp(value: f64, min: f64, max: f64) -> f64 { - value.clamp(min, max) - } -} diff --git a/src/context/bind.rs b/src/context/bind.rs new file mode 100644 index 0000000..2e06059 --- /dev/null +++ b/src/context/bind.rs @@ -0,0 +1,267 @@ +//! Functions for binding rune values to CQL parameters + +use crate::context::cass_error::{CassError, CassErrorKind}; +use crate::context::cql_types::Uuid; +use rune::{Any, ToValue, Value}; +use scylla::_macro_internal::ColumnType; +use scylla::frame::response::result::{ColumnSpec, CqlValue}; +use scylla::frame::value::CqlTimeuuid; +use std::net::IpAddr; +use std::str::FromStr; + +use itertools::*; + +fn to_scylla_value(v: &Value, typ: &ColumnType) -> Result { + // TODO: add support for the following native CQL types: + // 'counter', 'date', 'decimal', 'duration', 'inet', 'time', + // 'timestamp', 'timeuuid' and 'variant'. + // Also, for the 'tuple'. + match (v, typ) { + (Value::Bool(v), ColumnType::Boolean) => Ok(CqlValue::Boolean(*v)), + + (Value::Byte(v), ColumnType::TinyInt) => Ok(CqlValue::TinyInt(*v as i8)), + (Value::Byte(v), ColumnType::SmallInt) => Ok(CqlValue::SmallInt(*v as i16)), + (Value::Byte(v), ColumnType::Int) => Ok(CqlValue::Int(*v as i32)), + (Value::Byte(v), ColumnType::BigInt) => Ok(CqlValue::BigInt(*v as i64)), + + (Value::Integer(v), ColumnType::TinyInt) => { + convert_int(*v, ColumnType::TinyInt, CqlValue::TinyInt) + } + (Value::Integer(v), ColumnType::SmallInt) => { + convert_int(*v, ColumnType::SmallInt, CqlValue::SmallInt) + } + (Value::Integer(v), ColumnType::Int) => convert_int(*v, ColumnType::Int, CqlValue::Int), + (Value::Integer(v), ColumnType::BigInt) => Ok(CqlValue::BigInt(*v)), + (Value::Integer(v), ColumnType::Timestamp) => { + Ok(CqlValue::Timestamp(scylla::frame::value::CqlTimestamp(*v))) + } + + (Value::Float(v), ColumnType::Float) => Ok(CqlValue::Float(*v as f32)), + (Value::Float(v), ColumnType::Double) => Ok(CqlValue::Double(*v)), + + (Value::String(s), ColumnType::Timeuuid) => { + let timeuuid_str = s.borrow_ref().unwrap(); + let timeuuid = CqlTimeuuid::from_str(timeuuid_str.as_str()); + match timeuuid { + Ok(timeuuid) => Ok(CqlValue::Timeuuid(timeuuid)), + Err(e) => Err(CassError(CassErrorKind::QueryParamConversion( + format!("{:?}", v), + ColumnType::Timeuuid, + Some(format!("{}", e)), + ))), + } + } + (Value::String(v), ColumnType::Text | ColumnType::Ascii) => { + Ok(CqlValue::Text(v.borrow_ref().unwrap().as_str().to_string())) + } + (Value::String(s), ColumnType::Inet) => { + let ipaddr_str = s.borrow_ref().unwrap(); + let ipaddr = IpAddr::from_str(ipaddr_str.as_str()); + match ipaddr { + Ok(ipaddr) => Ok(CqlValue::Inet(ipaddr)), + Err(e) => Err(CassError(CassErrorKind::QueryParamConversion( + format!("{:?}", v), + ColumnType::Inet, + Some(format!("{}", e)), + ))), + } + } + (Value::Bytes(v), ColumnType::Blob) => Ok(CqlValue::Blob(v.borrow_ref().unwrap().to_vec())), + (Value::Option(v), typ) => match v.borrow_ref().unwrap().as_ref() { + Some(v) => to_scylla_value(v, typ), + None => Ok(CqlValue::Empty), + }, + (Value::Vec(v), ColumnType::List(elt)) => { + let v = v.borrow_ref().unwrap(); + let elements = v + .as_ref() + .iter() + .map(|v| to_scylla_value(v, elt)) + .try_collect()?; + Ok(CqlValue::List(elements)) + } + (Value::Vec(v), ColumnType::Set(elt)) => { + let v = v.borrow_ref().unwrap(); + let elements = v + .as_ref() + .iter() + .map(|v| to_scylla_value(v, elt)) + .try_collect()?; + Ok(CqlValue::Set(elements)) + } + (Value::Vec(v), ColumnType::Map(key_elt, value_elt)) => { + let v = v.borrow_ref().unwrap(); + let mut map_vec = Vec::with_capacity(v.len()); + for tuple in v.iter() { + match tuple { + Value::Tuple(tuple) if tuple.borrow_ref().unwrap().len() == 2 => { + let tuple = tuple.borrow_ref().unwrap(); + let key = to_scylla_value(tuple.first().unwrap(), key_elt)?; + let value = to_scylla_value(tuple.get(1).unwrap(), value_elt)?; + map_vec.push((key, value)); + } + _ => { + return Err(CassError(CassErrorKind::QueryParamConversion( + format!("{:?}", tuple), + ColumnType::Tuple(vec![ + key_elt.as_ref().clone(), + value_elt.as_ref().clone(), + ]), + None, + ))); + } + } + } + Ok(CqlValue::Map(map_vec)) + } + (Value::Object(obj), ColumnType::Map(key_elt, value_elt)) => { + let obj = obj.borrow_ref().unwrap(); + let mut map_vec = Vec::with_capacity(obj.keys().len()); + for (k, v) in obj.iter() { + let key = String::from(k.as_str()); + let key = to_scylla_value(&(key.to_value().unwrap()), key_elt)?; + let value = to_scylla_value(v, value_elt)?; + map_vec.push((key, value)); + } + Ok(CqlValue::Map(map_vec)) + } + ( + Value::Object(v), + ColumnType::UserDefinedType { + keyspace, + type_name, + field_types, + }, + ) => { + let obj = v.borrow_ref().unwrap(); + let fields = read_fields(|s| obj.get(s), field_types)?; + Ok(CqlValue::UserDefinedType { + keyspace: keyspace.to_string(), + type_name: type_name.to_string(), + fields, + }) + } + ( + Value::Struct(v), + ColumnType::UserDefinedType { + keyspace, + type_name, + field_types, + }, + ) => { + let obj = v.borrow_ref().unwrap(); + let fields = read_fields(|s| obj.get(s), field_types)?; + Ok(CqlValue::UserDefinedType { + keyspace: keyspace.to_string(), + type_name: type_name.to_string(), + fields, + }) + } + + (Value::Any(obj), ColumnType::Uuid) => { + let obj = obj.borrow_ref().unwrap(); + let h = obj.type_hash(); + if h == Uuid::type_hash() { + let uuid: &Uuid = obj.downcast_borrow_ref().unwrap(); + Ok(CqlValue::Uuid(uuid.0)) + } else { + Err(CassError(CassErrorKind::QueryParamConversion( + format!("{:?}", v), + ColumnType::Uuid, + None, + ))) + } + } + (value, typ) => Err(CassError(CassErrorKind::QueryParamConversion( + format!("{:?}", value), + typ.clone(), + None, + ))), + } +} + +fn convert_int, R>( + value: i64, + typ: ColumnType, + f: impl Fn(T) -> R, +) -> Result { + let converted = value.try_into().map_err(|_| { + CassError(CassErrorKind::ValueOutOfRange( + value.to_string(), + typ.clone(), + )) + })?; + Ok(f(converted)) +} + +/// Binds parameters passed as a single rune value to the arguments of the statement. +/// The `params` value can be a tuple, a vector, a struct or an object. +pub fn to_scylla_query_params( + params: &Value, + types: &[ColumnSpec], +) -> Result, CassError> { + Ok(match params { + Value::Tuple(tuple) => { + let mut values = Vec::new(); + let tuple = tuple.borrow_ref().unwrap(); + if tuple.len() != types.len() { + return Err(CassError(CassErrorKind::InvalidNumberOfQueryParams)); + } + for (v, t) in tuple.iter().zip(types) { + values.push(to_scylla_value(v, &t.typ)?); + } + values + } + Value::Vec(vec) => { + let mut values = Vec::new(); + + let vec = vec.borrow_ref().unwrap(); + for (v, t) in vec.iter().zip(types) { + values.push(to_scylla_value(v, &t.typ)?); + } + values + } + Value::Object(obj) => { + let obj = obj.borrow_ref().unwrap(); + read_params(|f| obj.get(f), types)? + } + Value::Struct(obj) => { + let obj = obj.borrow_ref().unwrap(); + read_params(|f| obj.get(f), types)? + } + other => { + return Err(CassError(CassErrorKind::InvalidQueryParamsObject( + other.type_info().unwrap(), + ))); + } + }) +} + +fn read_params<'a, 'b>( + get_value: impl Fn(&str) -> Option<&'a Value>, + params: &[ColumnSpec], +) -> Result, CassError> { + let mut values = Vec::with_capacity(params.len()); + for column in params { + let value = match get_value(&column.name) { + Some(value) => to_scylla_value(value, &column.typ)?, + None => CqlValue::Empty, + }; + values.push(value) + } + Ok(values) +} + +fn read_fields<'a, 'b>( + get_value: impl Fn(&str) -> Option<&'a Value>, + fields: &[(String, ColumnType)], +) -> Result)>, CassError> { + let mut values = Vec::with_capacity(fields.len()); + for (field_name, field_type) in fields { + if let Some(value) = get_value(field_name) { + let value = Some(to_scylla_value(value, field_type)?); + values.push((field_name.to_string(), value)) + }; + } + Ok(values) +} diff --git a/src/context/cass_error.rs b/src/context/cass_error.rs new file mode 100644 index 0000000..c826fbe --- /dev/null +++ b/src/context/cass_error.rs @@ -0,0 +1,210 @@ +use openssl::error::ErrorStack; +use rune::alloc::fmt::TryWrite; +use rune::runtime::{TypeInfo, VmResult}; +use rune::{vm_write, Any}; +use scylla::_macro_internal::{ColumnType, CqlValue}; +use scylla::transport::errors::{DbError, NewSessionError, QueryError}; +use std::fmt::{Display, Formatter}; + +#[derive(Any, Debug)] +pub struct CassError(pub CassErrorKind); + +impl CassError { + pub fn prepare_error(cql: &str, err: QueryError) -> CassError { + CassError(CassErrorKind::Prepare(cql.to_string(), err)) + } + + pub fn query_execution_error(cql: &str, params: &[CqlValue], err: QueryError) -> CassError { + let query = QueryInfo { + cql: cql.to_string(), + params: params.iter().map(cql_value_obj_to_string).collect(), + }; + let kind = match err { + QueryError::RequestTimeout(_) + | QueryError::TimeoutError + | QueryError::DbError( + DbError::Overloaded | DbError::ReadTimeout { .. } | DbError::WriteTimeout { .. }, + _, + ) => CassErrorKind::Overloaded(query, err), + _ => CassErrorKind::QueryExecution(query, err), + }; + CassError(kind) + } +} + +#[derive(Debug)] +pub enum CassErrorKind { + SslConfiguration(ErrorStack), + FailedToConnect(Vec, NewSessionError), + PreparedStatementNotFound(String), + QueryRetriesExceeded(String), + QueryParamConversion(String, ColumnType, Option), + ValueOutOfRange(String, ColumnType), + InvalidNumberOfQueryParams, + InvalidQueryParamsObject(TypeInfo), + Prepare(String, QueryError), + Overloaded(QueryInfo, QueryError), + QueryExecution(QueryInfo, QueryError), +} + +#[derive(Debug)] +pub struct QueryInfo { + cql: String, + params: Vec, +} + +impl CassError { + #[rune::function(protocol = STRING_DISPLAY)] + pub fn string_display(&self, f: &mut rune::runtime::Formatter) -> VmResult<()> { + vm_write!(f, "{}", self.to_string()); + VmResult::Ok(()) + } + + pub fn display(&self, buf: &mut String) -> std::fmt::Result { + use std::fmt::Write; + match &self.0 { + CassErrorKind::SslConfiguration(e) => { + write!(buf, "SSL configuration error: {e}") + } + CassErrorKind::FailedToConnect(hosts, e) => { + write!(buf, "Could not connect to {}: {}", hosts.join(","), e) + } + CassErrorKind::PreparedStatementNotFound(s) => { + write!(buf, "Prepared statement not found: {s}") + } + CassErrorKind::QueryRetriesExceeded(s) => { + write!(buf, "QueryRetriesExceeded: {s}") + } + CassErrorKind::ValueOutOfRange(v, t) => { + write!(buf, "Value {v} out of range for Cassandra type {t:?}") + } + CassErrorKind::QueryParamConversion(v, t, None) => { + write!(buf, "Cannot convert value {v} to Cassandra type {t:?}") + } + CassErrorKind::QueryParamConversion(v, t, Some(e)) => { + write!(buf, "Cannot convert value {v} to Cassandra type {t:?}: {e}") + } + CassErrorKind::InvalidNumberOfQueryParams => { + write!(buf, "Incorrect number of query parameters") + } + CassErrorKind::InvalidQueryParamsObject(t) => { + write!(buf, "Value of type {t} cannot by used as query parameters; expected a list or object") + } + CassErrorKind::Prepare(q, e) => { + write!(buf, "Failed to prepare query \"{q}\": {e}") + } + CassErrorKind::Overloaded(q, e) => { + write!(buf, "Overloaded when executing query {q}: {e}") + } + CassErrorKind::QueryExecution(q, e) => { + write!(buf, "Failed to execute query {q}: {e}") + } + } + } +} + +impl Display for CassError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let mut buf = String::new(); + self.display(&mut buf).unwrap(); + write!(f, "{buf}") + } +} + +impl From for CassError { + fn from(e: ErrorStack) -> CassError { + CassError(CassErrorKind::SslConfiguration(e)) + } +} + +impl std::error::Error for CassError {} + +/// Transforms a CqlValue object to a string dedicated to be part of CassError message +pub fn cql_value_obj_to_string(v: &CqlValue) -> String { + let no_transformation_size_limit = 32; + match v { + // Replace big string- and bytes-alike object values with its size labels + CqlValue::Text(param) if param.len() > no_transformation_size_limit => { + format!("Text(={})", param.len()) + } + CqlValue::Ascii(param) if param.len() > no_transformation_size_limit => { + format!("Ascii(={})", param.len()) + } + CqlValue::Blob(param) if param.len() > no_transformation_size_limit => { + format!("Blob(={})", param.len()) + } + CqlValue::UserDefinedType { + keyspace, + type_name, + fields, + } => { + let mut result = format!( + "UDT {{ keyspace: \"{}\", type_name: \"{}\", fields: [", + keyspace, type_name, + ); + for (field_name, field_value) in fields { + let field_string = match field_value { + Some(field) => cql_value_obj_to_string(field), + None => String::from("None"), + }; + result.push_str(&format!("(\"{}\", {}), ", field_name, field_string)); + } + if result.len() >= 2 { + result.truncate(result.len() - 2); + } + result.push_str("] }"); + result + } + CqlValue::List(elements) => { + let mut result = String::from("List(["); + for element in elements { + let element_string = cql_value_obj_to_string(element); + result.push_str(&element_string); + result.push_str(", "); + } + if result.len() >= 2 { + result.truncate(result.len() - 2); + } + result.push_str("])"); + result + } + CqlValue::Set(elements) => { + let mut result = String::from("Set(["); + for element in elements { + let element_string = cql_value_obj_to_string(element); + result.push_str(&element_string); + result.push_str(", "); + } + if result.len() >= 2 { + result.truncate(result.len() - 2); + } + result.push_str("])"); + result + } + CqlValue::Map(pairs) => { + let mut result = String::from("Map({"); + for (key, value) in pairs { + let key_string = cql_value_obj_to_string(key); + let value_string = cql_value_obj_to_string(value); + result.push_str(&format!("({}: {}), ", key_string, value_string)); + } + if result.len() >= 2 { + result.truncate(result.len() - 2); + } + result.push_str("})"); + result + } + _ => format!("{v:?}"), + } +} + +impl Display for QueryInfo { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "\"{}\" with params [{}]", + self.cql, + self.params.join(", ") + ) + } +} diff --git a/src/context/connect.rs b/src/context/connect.rs new file mode 100644 index 0000000..1c6077c --- /dev/null +++ b/src/context/connect.rs @@ -0,0 +1,56 @@ +use crate::config::ConnectionConf; +use crate::context::cass_error::{CassError, CassErrorKind}; +use crate::context::context::Context; +use openssl::ssl::{SslContext, SslContextBuilder, SslFiletype, SslMethod}; +use scylla::load_balancing::DefaultPolicy; +use scylla::transport::session::PoolSize; +use scylla::{ExecutionProfile, SessionBuilder}; + +fn ssl_context(conf: &&ConnectionConf) -> Result, CassError> { + if conf.ssl { + let mut ssl = SslContextBuilder::new(SslMethod::tls())?; + if let Some(path) = &conf.ssl_ca_cert_file { + ssl.set_ca_file(path)?; + } + if let Some(path) = &conf.ssl_cert_file { + ssl.set_certificate_file(path, SslFiletype::PEM)?; + } + if let Some(path) = &conf.ssl_key_file { + ssl.set_private_key_file(path, SslFiletype::PEM)?; + } + Ok(Some(ssl.build())) + } else { + Ok(None) + } +} + +/// Configures connection to Cassandra. +pub async fn connect(conf: &ConnectionConf) -> Result { + let mut policy_builder = DefaultPolicy::builder().token_aware(true); + if let Some(dc) = &conf.datacenter { + policy_builder = policy_builder + .prefer_datacenter(dc.to_owned()) + .permit_dc_failover(true); + } + let profile = ExecutionProfile::builder() + .consistency(conf.consistency.scylla_consistency()) + .load_balancing_policy(policy_builder.build()) + .request_timeout(Some(conf.request_timeout)) + .build(); + + let scylla_session = SessionBuilder::new() + .known_nodes(&conf.addresses) + .pool_size(PoolSize::PerShard(conf.count)) + .user(&conf.user, &conf.password) + .ssl_context(ssl_context(&conf)?) + .default_execution_profile_handle(profile.into_handle()) + .build() + .await + .map_err(|e| CassError(CassErrorKind::FailedToConnect(conf.addresses.clone(), e)))?; + Ok(Context::new(scylla_session, conf.retry_strategy)) +} + +pub struct ClusterInfo { + pub name: String, + pub cassandra_version: String, +} diff --git a/src/context/context.rs b/src/context/context.rs new file mode 100644 index 0000000..93ce2e8 --- /dev/null +++ b/src/context/context.rs @@ -0,0 +1,223 @@ +use crate::config::RetryStrategy; +use crate::context::bind; +use crate::context::cass_error::{CassError, CassErrorKind}; +use crate::context::connect::ClusterInfo; +use crate::error::LatteError; +use crate::stats::session::SessionStats; +use rand::prelude::ThreadRng; +use rand::random; +use rune::runtime::{Object, Shared}; +use rune::{Any, Value}; +use scylla::prepared_statement::PreparedStatement; +use scylla::transport::errors::{DbError, QueryError}; +use scylla::QueryResult; +use std::collections::HashMap; +use std::future::Future; +use std::sync::Arc; +use std::time::Duration; +use tokio::time::Instant; +use tracing::error; +use try_lock::TryLock; + +/// This is the main object that a workload script uses to interface with the outside world. +/// It also tracks query execution metrics such as number of requests, rows, response times etc. +#[derive(Any)] +pub struct Context { + start_time: TryLock, + session: Arc, + statements: HashMap>, + stats: TryLock, + retry_strategy: RetryStrategy, + #[rune(get, set, add_assign, copy)] + pub load_cycle_count: u64, + #[rune(get)] + pub data: Value, + pub rng: ThreadRng, +} + +// Needed, because Rune `Value` is !Send, as it may contain some internal pointers. +// Therefore, it is not safe to pass a `Value` to another thread by cloning it, because +// both objects could accidentally share some unprotected, `!Sync` data. +// To make it safe, the same `Context` is never used by more than one thread at once, and +// we make sure in `clone` to make a deep copy of the `data` field by serializing +// and deserializing it, so no pointers could get through. +unsafe impl Send for Context {} +unsafe impl Sync for Context {} + +impl Context { + pub fn new(session: scylla::Session, retry_strategy: RetryStrategy) -> Context { + Context { + start_time: TryLock::new(Instant::now()), + session: Arc::new(session), + statements: HashMap::new(), + stats: TryLock::new(SessionStats::new()), + retry_strategy, + load_cycle_count: 0, + data: Value::Object(Shared::new(Object::new()).unwrap()), + rng: rand::thread_rng(), + } + } + + /// Clones the context for use by another thread. + /// The new clone gets fresh statistics. + /// The user data gets passed through serialization and deserialization to avoid + /// accidental data sharing. + pub fn clone(&self) -> Result { + let serialized = rmp_serde::to_vec(&self.data)?; + let deserialized: Value = rmp_serde::from_slice(&serialized)?; + Ok(Context { + session: self.session.clone(), + statements: self.statements.clone(), + stats: TryLock::new(SessionStats::default()), + data: deserialized, + start_time: TryLock::new(*self.start_time.try_lock().unwrap()), + rng: rand::thread_rng(), + ..*self + }) + } + + /// Returns cluster metadata such as cluster name and cassandra version. + pub async fn cluster_info(&self) -> Result, CassError> { + let cql = "SELECT cluster_name, release_version FROM system.local"; + let rs = self + .session + .query(cql, ()) + .await + .map_err(|e| CassError::query_execution_error(cql, &[], e))?; + if let Some(rows) = rs.rows { + if let Some(row) = rows.into_iter().next() { + if let Ok((name, cassandra_version)) = row.into_typed() { + return Ok(Some(ClusterInfo { + name, + cassandra_version, + })); + } + } + } + Ok(None) + } + + /// Prepares a statement and stores it in an internal statement map for future use. + pub async fn prepare(&mut self, key: &str, cql: &str) -> Result<(), CassError> { + let statement = self + .session + .prepare(cql) + .await + .map_err(|e| CassError::prepare_error(cql, e))?; + self.statements.insert(key.to_string(), Arc::new(statement)); + Ok(()) + } + + /// Executes an ad-hoc CQL statement with no parameters. Does not prepare. + pub async fn execute(&self, cql: &str) -> Result<(), CassError> { + if let Err(err) = self.execute_inner(|| self.session.query(cql, ())).await { + let err = CassError::query_execution_error(cql, &[], err); + error!("{}", err); + return Err(err); + } + Ok(()) + } + + /// Executes a statement prepared and registered earlier by a call to `prepare`. + pub async fn execute_prepared(&self, key: &str, params: Value) -> Result<(), CassError> { + let statement = self + .statements + .get(key) + .ok_or_else(|| CassError(CassErrorKind::PreparedStatementNotFound(key.to_string())))?; + + let params = bind::to_scylla_query_params(¶ms, statement.get_variable_col_specs())?; + let rs = self + .execute_inner(|| self.session.execute(statement, params.clone())) + .await; + + if let Err(err) = rs { + let err = CassError::query_execution_error(statement.get_statement(), ¶ms, err); + error!("{}", err); + return Err(err); + } + + Ok(()) + } + + async fn execute_inner(&self, f: impl Fn() -> R) -> Result + where + R: Future>, + { + let start_time = self.stats.try_lock().unwrap().start_request(); + + let mut rs: Result = Err(QueryError::TimeoutError); + let mut attempts = 0; + let retry_strategy = &self.retry_strategy; + while attempts <= retry_strategy.retries && should_retry(&rs, retry_strategy) { + if attempts > 0 { + let current_retry_interval = get_exponential_retry_interval( + retry_strategy.retry_delay.min, + retry_strategy.retry_delay.max, + attempts, + ); + tokio::time::sleep(current_retry_interval).await; + } + rs = f().await; + attempts += 1; + } + + let duration = Instant::now() - start_time; + self.stats + .try_lock() + .unwrap() + .complete_request(duration, &rs, attempts - 1); + rs + } + + pub fn elapsed_secs(&self) -> f64 { + self.start_time.try_lock().unwrap().elapsed().as_secs_f64() + } + + /// Returns the current accumulated request stats snapshot and resets the stats. + pub fn take_session_stats(&self) -> SessionStats { + let mut stats = self.stats.try_lock().unwrap(); + let result = stats.clone(); + stats.reset(); + result + } + + /// Resets query and request counters + pub fn reset(&self) { + self.stats.try_lock().unwrap().reset(); + *self.start_time.try_lock().unwrap() = Instant::now(); + } +} + +pub fn get_exponential_retry_interval( + min_interval: Duration, + max_interval: Duration, + current_attempt_num: u64, +) -> Duration { + let min_interval_float: f64 = min_interval.as_secs_f64(); + let mut current_interval: f64 = + min_interval_float * (2u64.pow(current_attempt_num.try_into().unwrap_or(0)) as f64); + + // Add jitter + current_interval += random::() * min_interval_float; + current_interval -= min_interval_float / 2.0; + + Duration::from_secs_f64(current_interval.min(max_interval.as_secs_f64())) +} + +fn should_retry(result: &Result, retry_strategy: &RetryStrategy) -> bool { + if !result.is_err() { + return false; + } + if retry_strategy.retry_on_all_errors { + return true; + } + matches!( + result, + Err(QueryError::RequestTimeout(_)) + | Err(QueryError::TimeoutError) + | Err(QueryError::DbError( + DbError::ReadTimeout { .. } | DbError::WriteTimeout { .. } | DbError::Overloaded, + _ + )) + ) +} diff --git a/src/context/cql_types.rs b/src/context/cql_types.rs new file mode 100644 index 0000000..55e856e --- /dev/null +++ b/src/context/cql_types.rs @@ -0,0 +1,115 @@ +use metrohash::MetroHash128; +use rune::alloc::fmt::TryWrite; +use rune::runtime::VmResult; +use rune::{vm_write, Any}; +use std::hash::Hash; +use uuid::{Variant, Version}; + +#[derive(Clone, Debug, Any)] +pub struct Int8(pub i8); + +#[derive(Clone, Debug, Any)] +pub struct Int16(pub i16); + +#[derive(Clone, Debug, Any)] +pub struct Int32(pub i32); + +#[derive(Clone, Debug, Any)] +pub struct Float32(pub f32); + +#[derive(Clone, Debug, Any)] +pub struct Uuid(pub uuid::Uuid); + +impl Uuid { + pub fn new(i: i64) -> Uuid { + let mut hash = MetroHash128::new(); + i.hash(&mut hash); + let (h1, h2) = hash.finish128(); + let h = ((h1 as u128) << 64) | (h2 as u128); + let mut builder = uuid::Builder::from_u128(h); + builder.set_variant(Variant::RFC4122); + builder.set_version(Version::Random); + Uuid(builder.into_uuid()) + } + + #[rune::function(protocol = STRING_DISPLAY)] + pub fn string_display(&self, f: &mut rune::runtime::Formatter) -> VmResult<()> { + vm_write!(f, "{}", self.0); + VmResult::Ok(()) + } +} + +pub mod i64 { + use crate::context::cql_types::{Float32, Int16, Int32, Int8}; + + /// Converts a Rune integer to i8 (Cassandra tinyint) + #[rune::function(instance)] + pub fn to_i8(value: i64) -> Option { + Some(Int8(value.try_into().ok()?)) + } + + /// Converts a Rune integer to i16 (Cassandra smallint) + #[rune::function(instance)] + pub fn to_i16(value: i64) -> Option { + Some(Int16(value.try_into().ok()?)) + } + + /// Converts a Rune integer to i32 (Cassandra int) + #[rune::function(instance)] + pub fn to_i32(value: i64) -> Option { + Some(Int32(value.try_into().ok()?)) + } + + /// Converts a Rune integer to f32 (Cassandra float) + #[rune::function(instance)] + pub fn to_f32(value: i64) -> Float32 { + Float32(value as f32) + } + + /// Converts a Rune integer to a String + #[rune::function(instance)] + pub fn to_string(value: i64) -> String { + value.to_string() + } + + /// Restricts a value to a certain interval. + #[rune::function(instance)] + pub fn clamp(value: i64, min: i64, max: i64) -> i64 { + value.clamp(min, max) + } +} + +pub mod f64 { + use crate::context::cql_types::{Float32, Int16, Int32, Int8}; + + #[rune::function(instance)] + pub fn to_i8(value: f64) -> Int8 { + Int8(value as i8) + } + + #[rune::function(instance)] + pub fn to_i16(value: f64) -> Int16 { + Int16(value as i16) + } + + #[rune::function(instance)] + pub fn to_i32(value: f64) -> Int32 { + Int32(value as i32) + } + + #[rune::function(instance)] + pub fn to_f32(value: f64) -> Float32 { + Float32(value as f32) + } + + #[rune::function(instance)] + pub fn to_string(value: f64) -> String { + value.to_string() + } + + /// Restricts a value to a certain interval unless it is NaN. + #[rune::function(instance)] + pub fn clamp(value: f64, min: f64, max: f64) -> f64 { + value.clamp(min, max) + } +} diff --git a/src/context/functions.rs b/src/context/functions.rs new file mode 100644 index 0000000..7e26e60 --- /dev/null +++ b/src/context/functions.rs @@ -0,0 +1,205 @@ +use crate::context::cass_error::CassError; +use crate::context::context::Context; +use crate::context::cql_types::{Int8, Uuid}; +use crate::context::Resources; +use chrono::Utc; +use metrohash::MetroHash64; +use rand::distributions::Distribution; +use rand::prelude::StdRng; +use rand::{Rng, SeedableRng}; +use rune::macros::{quote, MacroContext, TokenStream}; +use rune::parse::Parser; +use rune::runtime::{Mut, Ref, VmError, VmResult}; +use rune::{ast, vm_try, Value}; +use statrs::distribution::{Normal, Uniform}; +use std::collections::HashMap; +use std::fs::File; +use std::hash::{Hash, Hasher}; +use std::io; +use std::io::{BufRead, BufReader, ErrorKind, Read}; +use std::ops::Deref; + +/// Returns the literal value stored in the `params` map under the key given as the first +/// macro arg, and if not found, returns the expression from the second arg. +pub fn param( + ctx: &mut MacroContext, + params: &HashMap, + ts: &TokenStream, +) -> rune::compile::Result { + let mut parser = Parser::from_token_stream(ts, ctx.macro_span()); + let name = parser.parse::()?; + let name = ctx.resolve(name)?.to_string(); + let _ = parser.parse::()?; + let expr = parser.parse::()?; + let rhs = match params.get(&name) { + Some(value) => { + let src_id = ctx.insert_source(&name, value)?; + let value = ctx.parse_source::(src_id)?; + quote!(#value) + } + None => quote!(#expr), + }; + Ok(rhs.into_token_stream(ctx)?) +} + +/// Creates a new UUID for current iteration +#[rune::function] +pub fn uuid(i: i64) -> Uuid { + Uuid::new(i) +} + +#[rune::function] +pub fn float_to_i8(value: f64) -> Option { + Some(Int8((value as i64).try_into().ok()?)) +} + +/// Computes a hash of an integer value `i`. +/// Returns a value in range `0..i64::MAX`. +fn hash_inner(i: i64) -> i64 { + let mut hash = MetroHash64::new(); + i.hash(&mut hash); + (hash.finish() & 0x7FFFFFFFFFFFFFFF) as i64 +} + +/// Computes a hash of an integer value `i`. +/// Returns a value in range `0..i64::MAX`. +#[rune::function] +pub fn hash(i: i64) -> i64 { + hash_inner(i) +} + +/// Computes hash of two integer values. +#[rune::function] +pub fn hash2(a: i64, b: i64) -> i64 { + let mut hash = MetroHash64::new(); + a.hash(&mut hash); + b.hash(&mut hash); + (hash.finish() & 0x7FFFFFFFFFFFFFFF) as i64 +} + +/// Computes a hash of an integer value `i`. +/// Returns a value in range `0..max`. +#[rune::function] +pub fn hash_range(i: i64, max: i64) -> i64 { + hash_inner(i) % max +} + +/// Generates a floating point value with normal distribution +#[rune::function] +pub fn normal(i: i64, mean: f64, std_dev: f64) -> VmResult { + let mut rng = StdRng::seed_from_u64(i as u64); + let distribution = + vm_try!(Normal::new(mean, std_dev).map_err(|e| VmError::panic(format!("{e}")))); + VmResult::Ok(distribution.sample(&mut rng)) +} + +#[rune::function] +pub fn uniform(i: i64, min: f64, max: f64) -> VmResult { + let mut rng = StdRng::seed_from_u64(i as u64); + let distribution = vm_try!(Uniform::new(min, max).map_err(|e| VmError::panic(format!("{e}")))); + VmResult::Ok(distribution.sample(&mut rng)) +} + +/// Generates random blob of data of given length. +/// Parameter `seed` is used to seed the RNG. +#[rune::function] +pub fn blob(seed: i64, len: usize) -> Vec { + let mut rng = StdRng::seed_from_u64(seed as u64); + (0..len).map(|_| rng.gen::()).collect() +} + +/// Generates random string of given length. +/// Parameter `seed` is used to seed +/// the RNG. +#[rune::function] +pub fn text(seed: i64, len: usize) -> String { + let mut rng = StdRng::seed_from_u64(seed as u64); + (0..len) + .map(|_| { + let code_point = rng.gen_range(0x0061u32..=0x007Au32); // Unicode range for 'a-z' + std::char::from_u32(code_point).unwrap() + }) + .collect() +} + +/// Generates 'now' timestamp +#[rune::function] +pub fn now_timestamp() -> i64 { + Utc::now().timestamp() +} + +/// Selects one item from the collection based on the hash of the given value. +#[rune::function] +pub fn hash_select(i: i64, collection: &[Value]) -> Value { + collection[(hash_inner(i) % collection.len() as i64) as usize].clone() +} + +/// Reads a file into a string. +#[rune::function] +pub fn read_to_string(filename: &str) -> io::Result { + let mut file = File::open(filename).expect("no such file"); + + let mut buffer = String::new(); + file.read_to_string(&mut buffer)?; + + Ok(buffer) +} + +/// Reads a file into a vector of lines. +#[rune::function] +pub fn read_lines(filename: &str) -> io::Result> { + let file = File::open(filename).expect("no such file"); + let buf = BufReader::new(file); + let result = buf + .lines() + .map(|l| l.expect("Could not parse line")) + .collect(); + Ok(result) +} + +/// Reads a resource file as a string. +fn read_resource_to_string_inner(path: &str) -> io::Result { + let resource = Resources::get(path).ok_or_else(|| { + io::Error::new(ErrorKind::NotFound, format!("Resource not found: {path}")) + })?; + let contents = std::str::from_utf8(resource.data.as_ref()) + .map_err(|e| io::Error::new(ErrorKind::InvalidData, format!("Invalid UTF8 string: {e}")))?; + Ok(contents.to_string()) +} + +#[rune::function] +pub fn read_resource_to_string(path: &str) -> io::Result { + read_resource_to_string_inner(path) +} + +#[rune::function] +pub fn read_resource_lines(path: &str) -> io::Result> { + Ok(read_resource_to_string_inner(path)? + .split('\n') + .map(|s| s.to_string()) + .collect()) +} + +#[rune::function(instance)] +pub async fn prepare(mut ctx: Mut, key: Ref, cql: Ref) -> Result<(), CassError> { + ctx.prepare(&key, &cql).await +} + +#[rune::function(instance)] +pub async fn execute(ctx: Ref, cql: Ref) -> Result<(), CassError> { + ctx.execute(cql.deref()).await +} + +#[rune::function(instance)] +pub async fn execute_prepared( + ctx: Ref, + key: Ref, + params: Value, +) -> Result<(), CassError> { + ctx.execute_prepared(&key, params).await +} + +#[rune::function(instance)] +pub fn elapsed_secs(ctx: &Context) -> f64 { + ctx.elapsed_secs() +} diff --git a/src/context/mod.rs b/src/context/mod.rs new file mode 100644 index 0000000..1ffafa7 --- /dev/null +++ b/src/context/mod.rs @@ -0,0 +1,80 @@ +use crate::context::cass_error::CassError; +use crate::context::context::Context; +use rune::{ContextError, Module}; +use rust_embed::RustEmbed; +use std::collections::HashMap; + +mod bind; +pub mod cass_error; +pub mod connect; +pub mod context; +mod cql_types; +mod functions; + +#[derive(RustEmbed)] +#[folder = "resources/"] +struct Resources; + +pub fn install(rune_ctx: &mut rune::Context, params: HashMap) { + try_install(rune_ctx, params).unwrap() +} + +fn try_install( + rune_ctx: &mut rune::Context, + params: HashMap, +) -> Result<(), ContextError> { + let mut context_module = Module::default(); + context_module.ty::()?; + context_module.function_meta(functions::execute)?; + context_module.function_meta(functions::prepare)?; + context_module.function_meta(functions::execute_prepared)?; + context_module.function_meta(functions::elapsed_secs)?; + + let mut err_module = Module::default(); + err_module.ty::()?; + err_module.function_meta(CassError::string_display)?; + + let mut uuid_module = Module::default(); + uuid_module.ty::()?; + uuid_module.function_meta(cql_types::Uuid::string_display)?; + + let mut latte_module = Module::with_crate("latte")?; + latte_module.macro_("param", move |ctx, ts| functions::param(ctx, ¶ms, ts))?; + + latte_module.function_meta(functions::blob)?; + latte_module.function_meta(functions::text)?; + latte_module.function_meta(functions::now_timestamp)?; + latte_module.function_meta(functions::hash)?; + latte_module.function_meta(functions::hash2)?; + latte_module.function_meta(functions::hash_range)?; + latte_module.function_meta(functions::hash_select)?; + latte_module.function_meta(functions::uuid)?; + latte_module.function_meta(functions::normal)?; + latte_module.function_meta(functions::uniform)?; + + latte_module.function_meta(cql_types::i64::to_i32)?; + latte_module.function_meta(cql_types::i64::to_i16)?; + latte_module.function_meta(cql_types::i64::to_i8)?; + latte_module.function_meta(cql_types::i64::to_f32)?; + latte_module.function_meta(cql_types::i64::clamp)?; + + latte_module.function_meta(cql_types::f64::to_i8)?; + latte_module.function_meta(cql_types::f64::to_i16)?; + latte_module.function_meta(cql_types::f64::to_i32)?; + latte_module.function_meta(cql_types::f64::to_f32)?; + latte_module.function_meta(cql_types::f64::clamp)?; + + let mut fs_module = Module::with_crate("fs")?; + fs_module.function_meta(functions::read_to_string)?; + fs_module.function_meta(functions::read_lines)?; + fs_module.function_meta(functions::read_resource_to_string)?; + fs_module.function_meta(functions::read_resource_lines)?; + + rune_ctx.install(&context_module)?; + rune_ctx.install(&err_module)?; + rune_ctx.install(&uuid_module)?; + rune_ctx.install(&latte_module)?; + rune_ctx.install(&fs_module)?; + + Ok(()) +} diff --git a/src/error.rs b/src/error.rs index 0927151..358f390 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,4 +1,4 @@ -use crate::context::CassError; +use crate::context::cass_error::CassError; use err_derive::*; use hdrhistogram::serialization::interval_log::IntervalLogWriterError; use hdrhistogram::serialization::V2DeflateSerializeError; diff --git a/src/chunks.rs b/src/exec/chunks.rs similarity index 99% rename from src/chunks.rs rename to src/exec/chunks.rs index 9736ce7..290682b 100644 --- a/src/chunks.rs +++ b/src/exec/chunks.rs @@ -171,7 +171,7 @@ where #[cfg(test)] mod test { - use crate::chunks::{ChunksAggregated, ChunksExt}; + use crate::exec::chunks::{ChunksAggregated, ChunksExt}; use futures::{stream, FutureExt, StreamExt}; use std::time::Duration; use tokio::time::interval; diff --git a/src/cycle.rs b/src/exec/cycle.rs similarity index 98% rename from src/cycle.rs rename to src/exec/cycle.rs index 559bc83..343191b 100644 --- a/src/cycle.rs +++ b/src/exec/cycle.rs @@ -117,7 +117,7 @@ impl BoundedCycleCounter { #[cfg(test)] mod test { - use crate::cycle::{CycleCounter, BATCH_SIZE}; + use crate::exec::cycle::{CycleCounter, BATCH_SIZE}; use itertools::Itertools; use std::collections::BTreeSet; diff --git a/src/exec.rs b/src/exec/mod.rs similarity index 99% rename from src/exec.rs rename to src/exec/mod.rs index 6f3f287..0a54357 100644 --- a/src/exec.rs +++ b/src/exec/mod.rs @@ -16,11 +16,16 @@ use tokio::signal::ctrl_c; use tokio::time::MissedTickBehavior; use tokio_stream::wrappers::IntervalStream; -use crate::chunks::ChunksExt; use crate::error::{LatteError, Result}; use crate::{ BenchmarkStats, BoundedCycleCounter, Interval, Progress, Recorder, Workload, WorkloadStats, }; +use chunks::ChunksExt; + +mod chunks; +pub mod cycle; +pub mod progress; +pub mod workload; /// Returns a stream emitting `rate` events per second. fn interval_stream(rate: f64) -> IntervalStream { diff --git a/src/progress.rs b/src/exec/progress.rs similarity index 100% rename from src/progress.rs rename to src/exec/progress.rs diff --git a/src/workload.rs b/src/exec/workload.rs similarity index 84% rename from src/workload.rs rename to src/exec/workload.rs index 6f4b434..e560ae9 100644 --- a/src/workload.rs +++ b/src/exec/workload.rs @@ -7,6 +7,11 @@ use std::sync::Arc; use std::time::Duration; use std::time::Instant; +use crate::context::cass_error::{CassError, CassErrorKind}; +use crate::context::context::Context; +use crate::error::LatteError; +use crate::stats::latency::LatencyDistributionRecorder; +use crate::stats::session::SessionStats; use rand::distributions::{Distribution, WeightedIndex}; use rand::rngs::SmallRng; use rand::{Rng, SeedableRng}; @@ -15,14 +20,10 @@ use rune::compile::meta::Kind; use rune::compile::{CompileVisitor, MetaError, MetaRef}; use rune::runtime::{AnyObj, Args, RuntimeContext, Shared, VmError, VmResult}; use rune::termcolor::{ColorChoice, StandardStream}; -use rune::{vm_try, Any, Diagnostics, Module, Source, Sources, ToValue, Unit, Value, Vm}; +use rune::{vm_try, Any, Diagnostics, Source, Sources, ToValue, Unit, Value, Vm}; use serde::{Deserialize, Serialize}; use try_lock::TryLock; -use crate::error::LatteError; -use crate::latency::LatencyDistributionRecorder; -use crate::{context, CassError, CassErrorKind, Context, SessionStats}; - /// Wraps a reference to Session that can be converted to a Rune `Value` /// and passed as one of `Args` arguments to a function. struct SessionRef<'a> { @@ -115,69 +116,8 @@ impl Program { /// - `script`: source code in Rune language /// - `params`: parameter values that will be exposed to the script by the `params!` macro pub fn new(source: Source, params: HashMap) -> Result { - let mut context_module = Module::default(); - context_module.ty::().unwrap(); - context_module.function_meta(context::execute).unwrap(); - context_module.function_meta(context::prepare).unwrap(); - context_module - .function_meta(context::execute_prepared) - .unwrap(); - context_module.function_meta(context::elapsed_secs).unwrap(); - - let mut err_module = Module::default(); - err_module.ty::().unwrap(); - err_module.function_meta(CassError::string_display).unwrap(); - - let mut uuid_module = Module::default(); - uuid_module.ty::().unwrap(); - uuid_module - .function_meta(context::Uuid::string_display) - .unwrap(); - - let mut latte_module = Module::with_crate("latte").unwrap(); - latte_module - .macro_("param", move |ctx, ts| context::param(ctx, ¶ms, ts)) - .unwrap(); - - latte_module.function_meta(context::blob).unwrap(); - latte_module.function_meta(context::text).unwrap(); - latte_module.function_meta(context::now_timestamp).unwrap(); - latte_module.function_meta(context::hash).unwrap(); - latte_module.function_meta(context::hash2).unwrap(); - latte_module.function_meta(context::hash_range).unwrap(); - latte_module.function_meta(context::hash_select).unwrap(); - latte_module.function_meta(context::uuid).unwrap(); - latte_module.function_meta(context::normal).unwrap(); - latte_module.function_meta(context::uniform).unwrap(); - - latte_module.function_meta(context::i64::to_i32).unwrap(); - latte_module.function_meta(context::i64::to_i16).unwrap(); - latte_module.function_meta(context::i64::to_i8).unwrap(); - latte_module.function_meta(context::i64::to_f32).unwrap(); - latte_module.function_meta(context::i64::clamp).unwrap(); - - latte_module.function_meta(context::f64::to_i8).unwrap(); - latte_module.function_meta(context::f64::to_i16).unwrap(); - latte_module.function_meta(context::f64::to_i32).unwrap(); - latte_module.function_meta(context::f64::to_f32).unwrap(); - latte_module.function_meta(context::f64::clamp).unwrap(); - - let mut fs_module = Module::with_crate("fs").unwrap(); - fs_module.function_meta(context::read_to_string).unwrap(); - fs_module.function_meta(context::read_lines).unwrap(); - fs_module - .function_meta(context::read_resource_to_string) - .unwrap(); - fs_module - .function_meta(context::read_resource_lines) - .unwrap(); - let mut context = rune::Context::with_default_modules().unwrap(); - context.install(&context_module).unwrap(); - context.install(&err_module).unwrap(); - context.install(&uuid_module).unwrap(); - context.install(&latte_module).unwrap(); - context.install(&fs_module).unwrap(); + crate::context::install(&mut context, params); let mut options = rune::Options::default(); options.debug_info(true); diff --git a/src/main.rs b/src/main.rs index c5b4e34..b2c32dd 100644 --- a/src/main.rs +++ b/src/main.rs @@ -27,35 +27,24 @@ use crate::config::{ AppConfig, Command, ConnectionConf, EditCommand, HdrCommand, Interval, ListCommand, LoadCommand, SchemaCommand, ShowCommand, }; -use crate::context::*; -use crate::context::{CassError, CassErrorKind, Context, SessionStats}; -use crate::cycle::BoundedCycleCounter; +use crate::context::connect::ClusterInfo; +use crate::context::context::Context; use crate::error::{LatteError, Result}; use crate::exec::{par_execute, ExecutionOptions}; -use crate::plot::plot_graph; -use crate::progress::Progress; use crate::report::{PathAndSummary, Report, RunConfigCmp}; use crate::stats::{BenchmarkCmp, BenchmarkStats, Recorder}; -use crate::table::{Alignment, Table}; -use crate::workload::{FnRef, Program, Workload, WorkloadStats, LOAD_FN}; +use exec::cycle::BoundedCycleCounter; +use exec::progress::Progress; +use exec::workload::{FnRef, Program, Workload, WorkloadStats, LOAD_FN}; +use report::plot::plot_graph; +use report::table::{Alignment, Table}; -mod chunks; mod config; mod context; -mod cycle; mod error; mod exec; -mod histogram; -mod latency; -mod percentiles; -mod plot; -mod progress; mod report; mod stats; -mod table; -mod throughput; -mod timeseries; -mod workload; const VERSION: &str = env!("CARGO_PKG_VERSION"); @@ -122,7 +111,7 @@ fn find_workload(workload: &Path) -> PathBuf { /// Connects to the server and returns the session async fn connect(conf: &ConnectionConf) -> Result<(Context, Option)> { eprintln!("info: Connecting to {:?}... ", conf.addresses); - let session = context::connect(conf).await?; + let session = context::connect::connect(conf).await?; let cluster_info = session.cluster_info().await?; eprintln!( "info: Connected to {} running Cassandra version {}", diff --git a/src/report.rs b/src/report/mod.rs similarity index 99% rename from src/report.rs rename to src/report/mod.rs index 48c13f5..a698d73 100644 --- a/src/report.rs +++ b/src/report/mod.rs @@ -1,7 +1,6 @@ use crate::config::{RunCommand, WeightedFunction}; -use crate::percentiles::Percentile; +use crate::stats::percentiles::Percentile; use crate::stats::{BenchmarkCmp, BenchmarkStats, Mean, Sample, Significance}; -use crate::table::Row; use chrono::{DateTime, Local, TimeZone}; use console::{pad_str, style, Alignment}; use core::fmt; @@ -15,6 +14,10 @@ use std::num::NonZeroUsize; use std::path::{Path, PathBuf}; use std::{fs, io}; use strum::IntoEnumIterator; +use table::Row; + +pub mod plot; +pub mod table; /// A standard error is multiplied by this factor to get the error margin. /// For a normally distributed random variable, diff --git a/src/plot.rs b/src/report/plot.rs similarity index 99% rename from src/plot.rs rename to src/report/plot.rs index 9252480..e35bf12 100644 --- a/src/plot.rs +++ b/src/report/plot.rs @@ -1,6 +1,6 @@ use crate::config::PlotCommand; use crate::load_report_or_abort; -use crate::plot::SeriesKind::{ResponseTime, Throughput}; +use crate::report::plot::SeriesKind::{ResponseTime, Throughput}; use crate::report::Report; use crate::Result; use itertools::Itertools; diff --git a/src/table.rs b/src/report/table.rs similarity index 98% rename from src/table.rs rename to src/report/table.rs index ad5e6a9..a073fd5 100644 --- a/src/table.rs +++ b/src/report/table.rs @@ -121,7 +121,7 @@ impl Display for Table { #[cfg(test)] mod test { - use crate::table::{Alignment, Row, Table}; + use crate::report::table::{Alignment, Row, Table}; #[test] fn render_table() { diff --git a/src/histogram.rs b/src/stats/histogram.rs similarity index 100% rename from src/histogram.rs rename to src/stats/histogram.rs diff --git a/src/latency.rs b/src/stats/latency.rs similarity index 94% rename from src/latency.rs rename to src/stats/latency.rs index fce1ce3..fcbd9e4 100644 --- a/src/latency.rs +++ b/src/stats/latency.rs @@ -1,7 +1,7 @@ -use crate::histogram::SerializableHistogram; -use crate::percentiles::Percentiles; +use crate::stats::histogram::SerializableHistogram; +use crate::stats::percentiles::Percentiles; +use crate::stats::timeseries::TimeSeriesStats; use crate::stats::Mean; -use crate::timeseries::TimeSeriesStats; use hdrhistogram::Histogram; use serde::{Deserialize, Serialize}; use std::time::Duration; diff --git a/src/stats.rs b/src/stats/mod.rs similarity index 98% rename from src/stats.rs rename to src/stats/mod.rs index 2c13716..f5d6a34 100644 --- a/src/stats.rs +++ b/src/stats/mod.rs @@ -4,14 +4,21 @@ use std::num::NonZeroUsize; use std::ops::Mul; use std::time::{Instant, SystemTime}; -use crate::latency::{LatencyDistribution, LatencyDistributionRecorder}; -use crate::percentiles::Percentile; -use crate::throughput::ThroughputMeter; -use crate::timeseries::TimeSeriesStats; -use crate::workload::WorkloadStats; +use crate::exec::workload::WorkloadStats; +use crate::stats::latency::{LatencyDistribution, LatencyDistributionRecorder}; use cpu_time::ProcessTime; +use percentiles::Percentile; use serde::{Deserialize, Serialize}; use statrs::distribution::{ContinuousCDF, StudentsT}; +use throughput::ThroughputMeter; +use timeseries::TimeSeriesStats; + +pub mod histogram; +pub mod latency; +pub mod percentiles; +pub mod session; +pub mod throughput; +pub mod timeseries; /// Holds a mean and its error together. /// Makes it more convenient to compare means, and it also reduces the number diff --git a/src/percentiles.rs b/src/stats/percentiles.rs similarity index 99% rename from src/percentiles.rs rename to src/stats/percentiles.rs index 571da30..35306a8 100644 --- a/src/percentiles.rs +++ b/src/stats/percentiles.rs @@ -156,7 +156,7 @@ fn percentiles(hist: &Histogram, scale: f64) -> [f64; Percentile::COUNT] { #[cfg(test)] mod test { - use crate::percentiles::{Percentile, Percentiles}; + use crate::stats::percentiles::{Percentile, Percentiles}; use assert_approx_eq::assert_approx_eq; use hdrhistogram::Histogram; use rand::{thread_rng, Rng}; diff --git a/src/stats/session.rs b/src/stats/session.rs new file mode 100644 index 0000000..be0d626 --- /dev/null +++ b/src/stats/session.rs @@ -0,0 +1,81 @@ +use crate::stats::latency::LatencyDistributionRecorder; +use scylla::transport::errors::QueryError; +use scylla::QueryResult; +use std::collections::HashSet; +use std::time::Duration; +use tokio::time::Instant; + +#[derive(Clone, Debug)] +pub struct SessionStats { + pub req_count: u64, + pub req_errors: HashSet, + pub req_error_count: u64, + pub req_retry_count: u64, + pub row_count: u64, + pub queue_length: u64, + pub mean_queue_length: f32, + pub resp_times_ns: LatencyDistributionRecorder, +} + +impl SessionStats { + pub fn new() -> SessionStats { + Default::default() + } + + pub fn start_request(&mut self) -> Instant { + if self.req_count > 0 { + self.mean_queue_length += + (self.queue_length as f32 - self.mean_queue_length) / self.req_count as f32; + } + self.queue_length += 1; + Instant::now() + } + + pub fn complete_request( + &mut self, + duration: Duration, + rs: &Result, + retries: u64, + ) { + self.queue_length -= 1; + self.resp_times_ns.record(duration); + self.req_count += 1; + self.req_retry_count += retries; + match rs { + Ok(rs) => self.row_count += rs.rows.as_ref().map(|r| r.len()).unwrap_or(0) as u64, + Err(e) => { + self.req_error_count += 1; + self.req_errors.insert(format!("{e}")); + } + } + } + + /// Resets all accumulators + pub fn reset(&mut self) { + self.req_error_count = 0; + self.row_count = 0; + self.req_count = 0; + self.req_retry_count = 0; + self.mean_queue_length = 0.0; + self.req_errors.clear(); + self.resp_times_ns.clear(); + + // note that current queue_length is *not* reset to zero because there + // might be pending requests and if we set it to zero, that would underflow + } +} + +impl Default for SessionStats { + fn default() -> Self { + SessionStats { + req_count: 0, + req_errors: HashSet::new(), + req_error_count: 0, + req_retry_count: 0, + row_count: 0, + queue_length: 0, + mean_queue_length: 0.0, + resp_times_ns: LatencyDistributionRecorder::default(), + } + } +} diff --git a/src/throughput.rs b/src/stats/throughput.rs similarity index 94% rename from src/throughput.rs rename to src/stats/throughput.rs index e819679..928eba7 100644 --- a/src/throughput.rs +++ b/src/stats/throughput.rs @@ -1,5 +1,5 @@ +use crate::stats::timeseries::TimeSeriesStats; use crate::stats::Mean; -use crate::timeseries::TimeSeriesStats; use std::time::Instant; pub struct ThroughputMeter { diff --git a/src/timeseries.rs b/src/stats/timeseries.rs similarity index 99% rename from src/timeseries.rs rename to src/stats/timeseries.rs index 2abbddf..4f481b2 100644 --- a/src/timeseries.rs +++ b/src/stats/timeseries.rs @@ -214,7 +214,7 @@ impl Stats { #[cfg(test)] mod test { - use crate::timeseries::{Stats, TimeSeriesStats}; + use crate::stats::timeseries::{Stats, TimeSeriesStats}; use assert_approx_eq::assert_approx_eq; use more_asserts::{assert_gt, assert_le}; use rand::rngs::SmallRng;