diff --git a/scylla-rust-wrapper/Cargo.lock b/scylla-rust-wrapper/Cargo.lock index fa9476fa..c99be91e 100644 --- a/scylla-rust-wrapper/Cargo.lock +++ b/scylla-rust-wrapper/Cargo.lock @@ -989,12 +989,14 @@ name = "scylla-cpp-driver-rust" version = "0.0.1" dependencies = [ "assert_matches", + "bigdecimal", "bindgen", "chrono", "lazy_static", "libc", "machine-uid", "ntest", + "num-bigint", "num-derive", "num-traits", "openssl", diff --git a/scylla-rust-wrapper/Cargo.toml b/scylla-rust-wrapper/Cargo.toml index c630237c..f941e0d6 100644 --- a/scylla-rust-wrapper/Cargo.toml +++ b/scylla-rust-wrapper/Cargo.toml @@ -18,6 +18,8 @@ machine-uid = "0.2.0" rand = "0.8.4" num-traits = "0.2" num-derive = "0.3" +bigdecimal = "0.2.0" +num-bigint = "0.3" libc = "0.2.108" openssl-sys = "0.9.75" openssl = "0.10.32" diff --git a/scylla-rust-wrapper/src/binding.rs b/scylla-rust-wrapper/src/binding.rs index baca9d6e..f2c7c627 100644 --- a/scylla-rust-wrapper/src/binding.rs +++ b/scylla-rust-wrapper/src/binding.rs @@ -139,8 +139,6 @@ macro_rules! make_appender { // TODO: Types for which binding is not implemented yet: // custom - Not implemented in Rust driver? -// decimal -// duration - DURATION not implemented in Rust Driver macro_rules! invoke_binder_maker_macro_with_type { (null, $macro_name:ident, $this:ty, $consume_v:expr, $fn:ident) => { @@ -275,6 +273,32 @@ macro_rules! invoke_binder_maker_macro_with_type { [v @ crate::inet::CassInet] ); }; + (decimal, $macro_name:ident, $this:ty, $consume_v:expr, $fn:ident) => { + $macro_name!( + $this, + $consume_v, + $fn, + |v, v_size, scale| { + // The value is copied, the memory pointed to by this parameter can be freed after this call. + let val = std::slice::from_raw_parts(v as *const u8, v_size as usize).to_vec(); + let int_value = num_bigint::BigInt::from_signed_bytes_be(val.as_slice()); + Ok(Some(Decimal(bigdecimal::BigDecimal::from((int_value, scale as i64))))) + }, + [v @ *const cass_byte_t, v_size @ size_t, scale @ cass_int32_t] + ); + }; + (duration, $macro_name:ident, $this:ty, $consume_v:expr, $fn:ident) => { + $macro_name!( + $this, + $consume_v, + $fn, + |m, d, n| { + use scylla::frame::value::CqlDuration; + Ok(Some(Duration(CqlDuration {months: m, days: d, nanoseconds: n,}))) + }, + [m @ cass_int32_t, d @ cass_int32_t, n @ cass_int64_t] + ); + }; (collection, $macro_name:ident, $this:ty, $consume_v:expr, $fn:ident) => { $macro_name!( $this, diff --git a/scylla-rust-wrapper/src/cass_types.rs b/scylla-rust-wrapper/src/cass_types.rs index 0c7d6ffc..94e3e92d 100644 --- a/scylla-rust-wrapper/src/cass_types.rs +++ b/scylla-rust-wrapper/src/cass_types.rs @@ -285,6 +285,7 @@ pub fn get_column_type(column_type: &ColumnType) -> CassDataType { ColumnType::Blob => CassDataType::Value(CassValueType::CASS_VALUE_TYPE_BLOB), ColumnType::Counter => CassDataType::Value(CassValueType::CASS_VALUE_TYPE_COUNTER), ColumnType::Decimal => CassDataType::Value(CassValueType::CASS_VALUE_TYPE_DECIMAL), + ColumnType::Duration => CassDataType::Value(CassValueType::CASS_VALUE_TYPE_DURATION), ColumnType::Date => CassDataType::Value(CassValueType::CASS_VALUE_TYPE_DATE), ColumnType::Double => CassDataType::Value(CassValueType::CASS_VALUE_TYPE_DOUBLE), ColumnType::Float => CassDataType::Value(CassValueType::CASS_VALUE_TYPE_FLOAT), @@ -326,7 +327,6 @@ pub fn get_column_type(column_type: &ColumnType) -> CassDataType { ), ColumnType::Uuid => CassDataType::Value(CassValueType::CASS_VALUE_TYPE_UUID), ColumnType::Varint => CassDataType::Value(CassValueType::CASS_VALUE_TYPE_VARINT), - _ => CassDataType::Value(CassValueType::CASS_VALUE_TYPE_UNKNOWN), } } diff --git a/scylla-rust-wrapper/src/collection.rs b/scylla-rust-wrapper/src/collection.rs index dc3a45c6..2c440a91 100644 --- a/scylla-rust-wrapper/src/collection.rs +++ b/scylla-rust-wrapper/src/collection.rs @@ -87,6 +87,8 @@ make_binders!(string_n, cass_collection_append_string_n); make_binders!(bytes, cass_collection_append_bytes); make_binders!(uuid, cass_collection_append_uuid); make_binders!(inet, cass_collection_append_inet); +make_binders!(decimal, cass_collection_append_decimal); +make_binders!(duration, cass_collection_append_duration); make_binders!(collection, cass_collection_append_collection); make_binders!(tuple, cass_collection_append_tuple); make_binders!(user_type, cass_collection_append_user_type); diff --git a/scylla-rust-wrapper/src/query_result.rs b/scylla-rust-wrapper/src/query_result.rs index 6c2a8f2c..30153f65 100644 --- a/scylla-rust-wrapper/src/query_result.rs +++ b/scylla-rust-wrapper/src/query_result.rs @@ -12,7 +12,7 @@ use num_traits::Zero; use scylla::frame::frame_errors::ParseError; use scylla::frame::response::result::{ColumnSpec, ColumnType}; use scylla::frame::types; -use scylla::frame::value::Date; +use scylla::frame::value::{CqlDuration, Date}; use scylla::types::deserialize::row::ColumnIterator; use scylla::types::deserialize::value::{ DeserializeCql, MapIterator, SequenceIterator, UdtIterator, @@ -1233,7 +1233,40 @@ cass_value_get_numeric_type!( ); // other numeric types -// TODO: add decimal +#[no_mangle] +pub unsafe extern "C" fn cass_value_get_decimal( + value: *const CassValue, + varint: *mut *const cass_byte_t, + varint_size: *mut size_t, + scale: *mut cass_int32_t, +) -> CassError { + if !cass_value_is_null(value).is_zero() { + return CassError::CASS_ERROR_LIB_NULL_VALUE; + } + + match cass_value_type(value) { + CassValueType::CASS_VALUE_TYPE_DECIMAL => {} + _ => return CassError::CASS_ERROR_LIB_INVALID_VALUE_TYPE, + } + + let cass_value: &CassValue = ptr_to_ref(value); + if let Some(frame) = cass_value.frame_slice { + let mut val = frame.as_slice(); + let scale_res = types::read_int(&mut val); + + if let Ok(s) = scale_res { + let decimal_len = val.len(); + + *scale = s; + *varint_size = decimal_len as size_t; + *varint = val.as_ptr(); + + return CassError::CASS_OK; + } + } + + CassError::CASS_ERROR_LIB_NOT_ENOUGH_DATA +} // string cass_value_get_strict_type!( @@ -1269,7 +1302,26 @@ cass_value_get_strict_type!( ); // date and time types -// TODO: add duration +cass_value_get_strict_type!( + cass_value_get_duration, + CqlDuration, + cass_int32_t, + CassValueType::CASS_VALUE_TYPE_DURATION, + ColumnType::Duration, + |_value: *const CassValue, + months: *mut cass_int32_t, + days: *mut cass_int32_t, + nanos: *mut cass_int64_t, + val: CqlDuration| { + std::ptr::write(months, val.months); + std::ptr::write(days, val.days); + std::ptr::write(nanos, val.nanoseconds); + CassError::CASS_OK + }, + days: *mut cass_int32_t, // additional arguments + nanos: *mut cass_int64_t +); + cass_value_get_strict_type!( cass_value_get_uint32, Date, diff --git a/scylla-rust-wrapper/src/statement.rs b/scylla-rust-wrapper/src/statement.rs index a017f35e..6fe324d6 100644 --- a/scylla-rust-wrapper/src/statement.rs +++ b/scylla-rust-wrapper/src/statement.rs @@ -433,6 +433,18 @@ make_binders!( cass_statement_bind_inet_by_name, cass_statement_bind_inet_by_name_n ); +make_binders!( + decimal, + cass_statement_bind_decimal, + cass_statement_bind_decimal_by_name, + cass_statement_bind_decimal_by_name_n +); +make_binders!( + duration, + cass_statement_bind_duration, + cass_statement_bind_duration_by_name, + cass_statement_bind_duration_by_name_n +); make_binders!( collection, cass_statement_bind_collection, diff --git a/scylla-rust-wrapper/src/tuple.rs b/scylla-rust-wrapper/src/tuple.rs index 6f582284..528c5722 100644 --- a/scylla-rust-wrapper/src/tuple.rs +++ b/scylla-rust-wrapper/src/tuple.rs @@ -107,6 +107,8 @@ make_binders!(string_n, cass_tuple_set_string_n); make_binders!(bytes, cass_tuple_set_bytes); make_binders!(uuid, cass_tuple_set_uuid); make_binders!(inet, cass_tuple_set_inet); +make_binders!(decimal, cass_tuple_set_decimal); +make_binders!(duration, cass_tuple_set_duration); make_binders!(collection, cass_tuple_set_collection); make_binders!(tuple, cass_tuple_set_tuple); make_binders!(user_type, cass_tuple_set_user_type); diff --git a/scylla-rust-wrapper/src/user_type.rs b/scylla-rust-wrapper/src/user_type.rs index 5313445a..f2c32da2 100644 --- a/scylla-rust-wrapper/src/user_type.rs +++ b/scylla-rust-wrapper/src/user_type.rs @@ -182,6 +182,18 @@ make_binders!( cass_user_type_set_inet_by_name, cass_user_type_set_inet_by_name_n ); +make_binders!( + decimal, + cass_user_type_set_decimal, + cass_user_type_set_decimal_by_name, + cass_user_type_set_decimal_by_name_n +); +make_binders!( + duration, + cass_user_type_set_duration, + cass_user_type_set_duration_by_name, + cass_user_type_set_duration_by_name_n +); make_binders!( collection, cass_user_type_set_collection, diff --git a/src/testing_unimplemented.cpp b/src/testing_unimplemented.cpp index f66cc759..edea95fb 100644 --- a/src/testing_unimplemented.cpp +++ b/src/testing_unimplemented.cpp @@ -162,20 +162,6 @@ cass_collection_append_custom(CassCollection* collection, size_t value_size){ throw std::runtime_error("UNIMPLEMENTED cass_collection_append_custom\n"); } -CASS_EXPORT CassError -cass_collection_append_decimal(CassCollection* collection, - const cass_byte_t* varint, - size_t varint_size, - cass_int32_t scale){ - throw std::runtime_error("UNIMPLEMENTED cass_collection_append_decimal\n"); -} -CASS_EXPORT CassError -cass_collection_append_duration(CassCollection* collection, - cass_int32_t months, - cass_int32_t days, - cass_int64_t nanos){ - throw std::runtime_error("UNIMPLEMENTED cass_collection_append_duration\n"); -} CASS_EXPORT const CassValue* cass_column_meta_field_by_name(const CassColumnMeta* column_meta, const char* name){ @@ -357,38 +343,6 @@ cass_statement_bind_custom_by_name(CassStatement* statement, throw std::runtime_error("UNIMPLEMENTED cass_statement_bind_custom_by_name\n"); } CASS_EXPORT CassError -cass_statement_bind_decimal(CassStatement* statement, - size_t index, - const cass_byte_t* varint, - size_t varint_size, - cass_int32_t scale){ - throw std::runtime_error("UNIMPLEMENTED cass_statement_bind_decimal\n"); -} -CASS_EXPORT CassError -cass_statement_bind_decimal_by_name(CassStatement* statement, - const char* name, - const cass_byte_t* varint, - size_t varint_size, - cass_int32_t scale){ - throw std::runtime_error("UNIMPLEMENTED cass_statement_bind_decimal_by_name\n"); -} -CASS_EXPORT CassError -cass_statement_bind_duration(CassStatement* statement, - size_t index, - cass_int32_t months, - cass_int32_t days, - cass_int64_t nanos){ - throw std::runtime_error("UNIMPLEMENTED cass_statement_bind_duration\n"); -} -CASS_EXPORT CassError -cass_statement_bind_duration_by_name(CassStatement* statement, - const char* name, - cass_int32_t months, - cass_int32_t days, - cass_int64_t nanos){ - throw std::runtime_error("UNIMPLEMENTED cass_statement_bind_duration_by_name\n"); -} -CASS_EXPORT CassError cass_statement_set_custom_payload(CassStatement* statement, const CassCustomPayload* payload){ throw std::runtime_error("UNIMPLEMENTED cass_statement_set_custom_payload\n"); @@ -447,22 +401,6 @@ cass_tuple_set_custom(CassTuple* tuple, throw std::runtime_error("UNIMPLEMENTED cass_tuple_set_custom\n"); } CASS_EXPORT CassError -cass_tuple_set_decimal(CassTuple* tuple, - size_t index, - const cass_byte_t* varint, - size_t varint_size, - cass_int32_t scale){ - throw std::runtime_error("UNIMPLEMENTED cass_tuple_set_decimal\n"); -} -CASS_EXPORT CassError -cass_tuple_set_duration(CassTuple* tuple, - size_t index, - cass_int32_t months, - cass_int32_t days, - cass_int64_t nanos){ - throw std::runtime_error("UNIMPLEMENTED cass_tuple_set_duration\n"); -} -CASS_EXPORT CassError cass_user_type_set_custom(CassUserType* user_type, size_t index, const char* class_name, @@ -477,34 +415,4 @@ cass_user_type_set_custom_by_name(CassUserType* user_type, const cass_byte_t* value, size_t value_size){ throw std::runtime_error("UNIMPLEMENTED cass_user_type_set_custom_by_name\n"); -} -CASS_EXPORT CassError -cass_user_type_set_decimal_by_name(CassUserType* user_type, - const char* name, - const cass_byte_t* varint, - size_t varint_size, - int scale){ - throw std::runtime_error("UNIMPLEMENTED cass_user_type_set_decimal_by_name\n"); -} -CASS_EXPORT CassError -cass_user_type_set_duration_by_name(CassUserType* user_type, - const char* name, - cass_int32_t months, - cass_int32_t days, - cass_int64_t nanos){ - throw std::runtime_error("UNIMPLEMENTED cass_user_type_set_duration_by_name\n"); -} -CASS_EXPORT CassError -cass_value_get_decimal(const CassValue* value, - const cass_byte_t** varint, - size_t* varint_size, - cass_int32_t* scale){ - throw std::runtime_error("UNIMPLEMENTED cass_value_get_decimal\n"); -} -CASS_EXPORT CassError -cass_value_get_duration(const CassValue* value, - cass_int32_t* months, - cass_int32_t* days, - cass_int64_t* nanos){ - throw std::runtime_error("UNIMPLEMENTED cass_value_get_duration\n"); } \ No newline at end of file