Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cassandra 5.0 vector type CREATE/INSERT support #1020

Merged
merged 1 commit into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/cassandra.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ jobs:
run: cargo build --verbose --tests --features "full-serialization"
- name: Run tests on cassandra
run: |
CDC='disabled' RUST_LOG=trace SCYLLA_URI=172.42.0.2:9042 SCYLLA_URI2=172.42.0.3:9042 SCYLLA_URI3=172.42.0.4:9042 cargo test --verbose --features "full-serialization" -- --skip test_views_in_schema_info --skip test_large_batch_statements
CDC='disabled' RUSTFLAGS="--cfg cassandra_tests" RUST_LOG=trace SCYLLA_URI=172.42.0.2:9042 SCYLLA_URI2=172.42.0.3:9042 SCYLLA_URI3=172.42.0.4:9042 cargo test --verbose --features "full-serialization" -- --skip test_views_in_schema_info --skip test_large_batch_statements
- name: Stop the cluster
if: ${{ always() }}
run: docker compose -f test/cluster/cassandra/docker-compose.yml stop
Expand Down
2 changes: 1 addition & 1 deletion scylla/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -98,4 +98,4 @@ harness = false
[lints.rust]
unnameable_types = "warn"
unreachable_pub = "warn"
unexpected_cfgs = { level = "warn", check-cfg = ['cfg(scylla_cloud_tests)'] }
unexpected_cfgs = { level = "warn", check-cfg = ['cfg(scylla_cloud_tests)', 'cfg(cassandra_tests)'] }
105 changes: 105 additions & 0 deletions scylla/src/transport/session_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3165,3 +3165,108 @@ async fn test_api_migration_session_sharing() {
assert!(matched);
}
}

#[cfg(cassandra_tests)]
#[tokio::test]
async fn test_vector_type_metadata() {
setup_tracing();
let session = create_new_session_builder().build().await.unwrap();
let ks = unique_keyspace_name();

session.query_unpaged(format!("CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{'class' : 'NetworkTopologyStrategy', 'replication_factor' : 1}}", ks), &[]).await.unwrap();
session
.query_unpaged(
format!(
"CREATE TABLE IF NOT EXISTS {}.t (a int PRIMARY KEY, b vector<int, 4>, c vector<text, 2>)",
ks
),
&[],
)
.await
.unwrap();

session.refresh_metadata().await.unwrap();
let metadata = session.get_cluster_data();
let columns = &metadata.keyspaces[&ks].tables["t"].columns;
assert_eq!(
columns["b"].type_,
CqlType::Vector {
type_: Box::new(CqlType::Native(NativeType::Int)),
dimensions: 4,
},
);
assert_eq!(
columns["c"].type_,
CqlType::Vector {
type_: Box::new(CqlType::Native(NativeType::Text)),
dimensions: 2,
},
);
}

#[cfg(cassandra_tests)]
#[tokio::test]
async fn test_vector_type_unprepared() {
setup_tracing();
let session = create_new_session_builder().build().await.unwrap();
let ks = unique_keyspace_name();

session.query_unpaged(format!("CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{'class' : 'NetworkTopologyStrategy', 'replication_factor' : 1}}", ks), &[]).await.unwrap();
session
.query_unpaged(
format!(
"CREATE TABLE IF NOT EXISTS {}.t (a int PRIMARY KEY, b vector<int, 4>, c vector<text, 2>)",
ks
),
&[],
)
.await
.unwrap();

session
.query_unpaged(
format!(
"INSERT INTO {}.t (a, b, c) VALUES (1, [1, 2, 3, 4], ['foo', 'bar'])",
ks
),
&[],
)
.await
.unwrap();

// TODO: Implement and test SELECT statements and bind values (`?`)
}

#[cfg(cassandra_tests)]
#[tokio::test]
async fn test_vector_type_prepared() {
setup_tracing();
let session = create_new_session_builder().build().await.unwrap();
let ks = unique_keyspace_name();

session.query_unpaged(format!("CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{'class' : 'NetworkTopologyStrategy', 'replication_factor' : 1}}", ks), &[]).await.unwrap();
session
.query_unpaged(
format!(
"CREATE TABLE IF NOT EXISTS {}.t (a int PRIMARY KEY, b vector<int, 4>, c vector<text, 2>)",
ks
),
&[],
)
.await
.unwrap();

let prepared_statement = session
.prepare(format!(
"INSERT INTO {}.t (a, b, c) VALUES (?, [11, 12, 13, 14], ['afoo', 'abar'])",
ks
))
.await
.unwrap();
session
.execute_unpaged(&prepared_statement, &(2,))
.await
.unwrap();

// TODO: Implement and test SELECT statements and bind values (`?`)
}
wprzytula marked this conversation as resolved.
Show resolved Hide resolved
47 changes: 47 additions & 0 deletions scylla/src/transport/topology.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,12 @@ enum PreCqlType {
type_: PreCollectionType,
},
Tuple(Vec<PreCqlType>),
Vector {
type_: Box<PreCqlType>,
/// matches the datatype used by the java driver:
/// <https://github.com/apache/cassandra-java-driver/blob/85bb4065098b887d2dda26eb14423ce4fc687045/core/src/main/java/com/datastax/oss/driver/api/core/type/DataTypes.java#L77>
dimensions: i32,
},
UserDefinedType {
frozen: bool,
name: String,
Expand All @@ -211,6 +217,10 @@ impl PreCqlType {
.map(|t| t.into_cql_type(keyspace_name, udts))
.collect(),
),
PreCqlType::Vector { type_, dimensions } => CqlType::Vector {
type_: Box::new(type_.into_cql_type(keyspace_name, udts)),
dimensions,
},
PreCqlType::UserDefinedType { frozen, name } => {
let definition = match udts
.get(keyspace_name)
Expand All @@ -236,6 +246,12 @@ pub enum CqlType {
type_: CollectionType,
},
Tuple(Vec<CqlType>),
Vector {
type_: Box<CqlType>,
/// matches the datatype used by the java driver:
/// <https://github.com/apache/cassandra-java-driver/blob/85bb4065098b887d2dda26eb14423ce4fc687045/core/src/main/java/com/datastax/oss/driver/api/core/type/DataTypes.java#L77>
dimensions: i32,
},
UserDefinedType {
frozen: bool,
// Using Arc here in order not to have many copies of the same definition
Expand Down Expand Up @@ -1137,6 +1153,7 @@ fn topo_sort_udts(udts: &mut Vec<UdtRowWithParsedFieldTypes>) -> Result<(), Quer
PreCqlType::Tuple(types) => types
.iter()
.for_each(|type_| do_with_referenced_udts(what, type_)),
PreCqlType::Vector { type_, .. } => do_with_referenced_udts(what, type_),
PreCqlType::UserDefinedType { name, .. } => what(name),
}
}
Expand Down Expand Up @@ -1637,6 +1654,22 @@ fn parse_cql_type(p: ParserState<'_>) -> ParseResult<(PreCqlType, ParserState<'_
})?;

Ok((PreCqlType::Tuple(types), p))
} else if let Ok(p) = p.accept("vector<") {
let (inner_type, p) = parse_cql_type(p)?;

let p = p.skip_white();
let p = p.accept(",")?;
let p = p.skip_white();
let (size, p) = p.parse_i32()?;
let p = p.skip_white();
let p = p.accept(">")?;

let typ = PreCqlType::Vector {
type_: Box::new(inner_type),
dimensions: size,
};

Ok((typ, p))
} else if let Ok((typ, p)) = parse_native_type(p) {
Ok((PreCqlType::Native(typ), p))
} else if let Ok((name, p)) = parse_user_defined_type(p) {
Expand Down Expand Up @@ -1827,6 +1860,20 @@ mod tests {
PreCqlType::Native(NativeType::Varint),
]),
),
(
"vector<int, 5>",
PreCqlType::Vector {
type_: Box::new(PreCqlType::Native(NativeType::Int)),
dimensions: 5,
},
),
(
"vector<text, 1234>",
PreCqlType::Vector {
type_: Box::new(PreCqlType::Native(NativeType::Text)),
dimensions: 1234,
},
),
(
"com.scylladb.types.AwesomeType",
PreCqlType::UserDefinedType {
Expand Down
15 changes: 15 additions & 0 deletions scylla/src/utils/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,21 @@ impl<'s> ParserState<'s> {
me
}

/// Parses a sequence of digits and '-' as an integer.
/// Consumes characters until it finds a character that is not a digit or '-'.
///
/// An error is returned if:
/// * The first character is not a digit or '-'
/// * The integer is larger than i32
pub(crate) fn parse_i32(self) -> ParseResult<(i32, Self)> {
let (digits, p) = self.take_while(|c| c.is_ascii_digit() || c == '-');
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This has a minor flaw of accepting strings such as 00-21-37, but such strings will be rejected below anyway, so no problem here.

if let Ok(value) = digits.parse() {
Ok((value, p))
} else {
Err(p.error(ParseErrorCause::Other("Expected 32-bit signed integer")))
}
}

/// Skips characters from the beginning while they satisfy given predicate
/// and returns new parser state which
pub(crate) fn take_while(self, mut pred: impl FnMut(char) -> bool) -> (&'s str, Self) {
Expand Down
Loading