diff --git a/Cargo.lock b/Cargo.lock
index 4e882da95e7..7e3cde827c0 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -75,9 +75,9 @@ checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457"

 [[package]]
 name = "arrow"
-version = "50.0.0"
+version = "52.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "aa285343fba4d829d49985bdc541e3789cf6000ed0e84be7c039438df4a4e78c"
+checksum = "7ae9728f104939be6d8d9b368a354b4929b0569160ea1641f0721b55a861ce38"
 dependencies = [
  "arrow-arith",
  "arrow-array",
@@ -96,9 +96,9 @@ dependencies = [

 [[package]]
 name = "arrow-arith"
-version = "50.0.0"
+version = "52.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "753abd0a5290c1bcade7c6623a556f7d1659c5f4148b140b5b63ce7bd1a45705"
+checksum = "a7029a5b3efbeafbf4a12d12dc16b8f9e9bff20a410b8c25c5d28acc089e1043"
 dependencies = [
  "arrow-array",
  "arrow-buffer",
@@ -111,9 +111,9 @@ dependencies = [

 [[package]]
 name = "arrow-array"
-version = "50.0.0"
+version = "52.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d390feeb7f21b78ec997a4081a025baef1e2e0d6069e181939b61864c9779609"
+checksum = "d33238427c60271710695f17742f45b1a5dc5bcfc5c15331c25ddfe7abf70d97"
 dependencies = [
  "ahash",
  "arrow-buffer",
@@ -127,9 +127,9 @@ dependencies = [

 [[package]]
 name = "arrow-buffer"
-version = "50.0.0"
+version = "52.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "69615b061701bcdffbc62756bc7e85c827d5290b472b580c972ebbbf690f5aa4"
+checksum = "fe9b95e825ae838efaf77e366c00d3fc8cca78134c9db497d6bda425f2e7b7c1"
 dependencies = [
  "bytes",
  "half 2.4.1",
@@ -138,27 +138,29 @@ dependencies = [

 [[package]]
 name = "arrow-cast"
-version = "50.0.0"
+version = "52.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e448e5dd2f4113bf5b74a1f26531708f5edcacc77335b7066f9398f4bcf4cdef"
+checksum = "87cf8385a9d5b5fcde771661dd07652b79b9139fea66193eda6a88664400ccab"
 dependencies = [
  "arrow-array",
  "arrow-buffer",
  "arrow-data",
  "arrow-schema",
  "arrow-select",
- "base64 0.21.7",
+ "atoi",
+ "base64 0.22.1",
  "chrono",
  "half 2.4.1",
  "lexical-core",
  "num",
+ "ryu",
 ]

 [[package]]
 name = "arrow-csv"
-version = "50.0.0"
+version = "52.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "46af72211f0712612f5b18325530b9ad1bfbdc87290d5fbfd32a7da128983781"
+checksum = "cea5068bef430a86690059665e40034625ec323ffa4dd21972048eebb0127adc"
 dependencies = [
  "arrow-array",
  "arrow-buffer",
@@ -175,9 +177,9 @@ dependencies = [

 [[package]]
 name = "arrow-data"
-version = "50.0.0"
+version = "52.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "67d644b91a162f3ad3135ce1184d0a31c28b816a581e08f29e8e9277a574c64e"
+checksum = "cb29be98f987bcf217b070512bb7afba2f65180858bca462edf4a39d84a23e10"
 dependencies = [
  "arrow-buffer",
  "arrow-schema",
@@ -187,9 +189,9 @@ dependencies = [

 [[package]]
 name = "arrow-ipc"
-version = "50.0.0"
+version = "52.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "03dea5e79b48de6c2e04f03f62b0afea7105be7b77d134f6c5414868feefb80d"
+checksum = "ffc68f6523970aa6f7ce1dc9a33a7d9284cfb9af77d4ad3e617dbe5d79cc6ec8"
 dependencies = [
  "arrow-array",
  "arrow-buffer",
@@ -201,9 +203,9 @@ dependencies = [

 [[package]]
 name = "arrow-json"
-version = "50.0.0"
+version = "52.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8950719280397a47d37ac01492e3506a8a724b3fb81001900b866637a829ee0f"
+checksum = "2041380f94bd6437ab648e6c2085a045e45a0c44f91a1b9a4fe3fed3d379bfb1"
 dependencies = [
  "arrow-array",
  "arrow-buffer",
@@ -221,9 +223,9 @@ dependencies = [

 [[package]]
 name = "arrow-ord"
-version = "50.0.0"
+version = "52.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1ed9630979034077982d8e74a942b7ac228f33dd93a93b615b4d02ad60c260be"
+checksum = "fcb56ed1547004e12203652f12fe12e824161ff9d1e5cf2a7dc4ff02ba94f413"
 dependencies = [
  "arrow-array",
  "arrow-buffer",
@@ -236,9 +238,9 @@ dependencies = [

 [[package]]
 name = "arrow-row"
-version = "50.0.0"
+version = "52.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "007035e17ae09c4e8993e4cb8b5b96edf0afb927cd38e2dff27189b274d83dcf"
+checksum = "575b42f1fc588f2da6977b94a5ca565459f5ab07b60545e17243fb9a7ed6d43e"
 dependencies = [
  "ahash",
  "arrow-array",
@@ -251,15 +253,15 @@ dependencies = [

 [[package]]
 name = "arrow-schema"
-version = "50.0.0"
+version = "52.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0ff3e9c01f7cd169379d269f926892d0e622a704960350d09d331be3ec9e0029"
+checksum = "32aae6a60458a2389c0da89c9de0b7932427776127da1a738e2efc21d32f3393"

 [[package]]
 name = "arrow-select"
-version = "50.0.0"
+version = "52.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1ce20973c1912de6514348e064829e50947e35977bb9d7fb637dc99ea9ffd78c"
+checksum = "de36abaef8767b4220d7b4a8c2fe5ffc78b47db81b03d77e2136091c3ba39102"
 dependencies = [
  "ahash",
  "arrow-array",
@@ -271,15 +273,16 @@ dependencies = [

 [[package]]
 name = "arrow-string"
-version = "50.0.0"
+version = "52.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "00f3b37f2aeece31a2636d1b037dabb69ef590e03bdc7eb68519b51ec86932a7"
+checksum = "e435ada8409bcafc910bc3e0077f532a4daa20e99060a496685c0e3e53cc2597"
 dependencies = [
  "arrow-array",
  "arrow-buffer",
  "arrow-data",
  "arrow-schema",
  "arrow-select",
+ "memchr",
  "num",
  "regex",
  "regex-syntax 0.8.2",
@@ -324,6 +327,15 @@ dependencies = [
  "syn 2.0.52",
 ]

+[[package]]
+name = "atoi"
+version = "2.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f28d99ec8bfea296261ca1af174f24225171fea9664ba9003cbebee704810528"
+dependencies = [
+ "num-traits",
+]
+
 [[package]]
 name = "atomic"
 version = "0.6.0"
@@ -807,6 +819,12 @@ version = "0.21.7"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567"

+[[package]]
+name = "base64"
+version = "0.22.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
+
 [[package]]
 name = "base64-simd"
 version = "0.8.0"
@@ -1366,9 +1384,9 @@ checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80"

 [[package]]
 name = "flatbuffers"
-version = "23.5.26"
+version = "24.3.25"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4dac53e22462d78c16d64a1cd22371b54cc3fe94aa15e7886a2fa6e5d1ab8640"
+checksum = "8add37afff2d4ffa83bc748a70b4b1370984f6980768554182424ef71447c35f"
 dependencies = [
  "bitflags 1.3.2",
  "rustc_version",
@@ -4659,6 +4677,7 @@ dependencies = [
  "cc",
  "criterion",
  "figment",
+ "flatbuffers",
  "futures",
  "k8s-openapi",
  "kube",
diff --git a/rust/worker/Cargo.toml b/rust/worker/Cargo.toml
index f1862a1e7ff..ece7c9c1646 100644
--- a/rust/worker/Cargo.toml
+++ b/rust/worker/Cargo.toml
@@ -41,7 +41,7 @@ parking_lot = "0.12.1"
 aws-sdk-s3 = "1.5.0"
 aws-smithy-types = "1.1.0"
 aws-config = { version = "1.1.2", features = ["behavior-version-latest"] }
-arrow = "50.0.0"
+arrow = "52.0.0"
 roaring = "0.10.3"
 tantivy = "0.21.1"
 tracing = "0.1"
@@ -55,6 +55,7 @@ opentelemetry = { version = "0.19.0", default-features = false, features = [
 opentelemetry-otlp = "0.12.0"
 shuttle = "0.7.1"
 regex = "1.10.5"
+flatbuffers = "24.3.25"

 [dev-dependencies]
 proptest = "1.4.0"
diff --git a/rust/worker/src/blockstore/arrow/block/delta.rs b/rust/worker/src/blockstore/arrow/block/delta.rs
index 2df5fe2629d..f1061b313de 100644
--- a/rust/worker/src/blockstore/arrow/block/delta.rs
+++ b/rust/worker/src/blockstore/arrow/block/delta.rs
@@ -203,10 +203,26 @@ mod test {
     use roaring::RoaringBitmap;
     use std::collections::HashMap;

+    /// Saves a block to a random file under the given path, then loads the block
+    /// and validates that the loaded block has the same size as the original block.
+    /// ### Returns
+    /// - The loaded block
+    /// ### Notes
+    /// - Assumes that path will be cleaned up by the caller
+    fn test_save_load_size(path: &str, block: &Block) -> Block {
+        let save_path = format!("{}/{}", path, random::<u32>());
+        block.save(&save_path).unwrap();
+        let loaded = Block::load_with_validation(&save_path, block.id).unwrap();
+        assert_eq!(loaded.id, block.id);
+        assert_eq!(block.get_size(), loaded.get_size());
+        loaded
+    }
+
     #[tokio::test]
     async fn test_sizing_int_arr_val() {
         let tmp_dir = tempfile::tempdir().unwrap();
-        let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap()));
+        let path = tmp_dir.path().to_str().unwrap();
+        let storage = Storage::Local(LocalStorage::new(path));
         let block_manager = BlockManager::new(storage);
         let delta = block_manager.create::<&str, &Int32Array>();
@@ -227,12 +243,16 @@ mod test {
         // Semantically, that makes sense, since a delta is unusable after commit
         block_manager.commit::<&str, &Int32Array>(&delta);
         let block = block_manager.get(&delta.id).await.unwrap();
+        // Ensure the delta's estimated size matches the actual size of the block
         assert_eq!(size, block.get_size());
+
+        test_save_load_size(path, &block);
     }

     #[tokio::test]
     async fn test_sizing_string_val() {
         let tmp_dir = tempfile::tempdir().unwrap();
+        let path = tmp_dir.path().to_str().unwrap();
         let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap()));
         let block_manager = BlockManager::new(storage);
         let delta = block_manager.create::<&str, &str>();
@@ -256,11 +276,7 @@ mod test {
         }

         // test save/load
-        block.save("test.arrow").unwrap();
-        let loaded = Block::load("test.arrow", delta_id).unwrap();
-        assert_eq!(loaded.id, delta_id);
-        // TODO: make this sizing work
-        // assert_eq!(block.get_size(), loaded.get_size());
+        let loaded = test_save_load_size(path, &block);
         for i in 0..n {
             let key = format!("key{}", i);
             let read = loaded.get::<&str, &str>("prefix", &key);
@@ -282,7 +298,8 @@ mod test {
     #[tokio::test]
     async fn test_sizing_float_key() {
         let tmp_dir = tempfile::tempdir().unwrap();
-        let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap()));
+        let path = tmp_dir.path().to_str().unwrap();
+        let storage = Storage::Local(LocalStorage::new(path));
         let block_manager = BlockManager::new(storage);
         let delta = block_manager.create::<f32, &str>();
@@ -298,12 +315,16 @@ mod test {
         block_manager.commit::<f32, &str>(&delta);
         let block = block_manager.get(&delta.id).await.unwrap();
         assert_eq!(size, block.get_size());
+
+        // test save/load
+        test_save_load_size(path, &block);
     }

     #[tokio::test]
     async fn test_sizing_roaring_bitmap_val() {
         let tmp_dir = tempfile::tempdir().unwrap();
-        let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap()));
+        let path = tmp_dir.path().to_str().unwrap();
+        let storage = Storage::Local(LocalStorage::new(path));
         let block_manager = BlockManager::new(storage);
         let delta = block_manager.create::<&str, &RoaringBitmap>();
@@ -326,12 +347,16 @@ mod test {
             let expected = RoaringBitmap::from_iter((0..i).map(|x| x as u32));
             assert_eq!(read, Some(expected));
         }
+
+        // test save/load
+        test_save_load_size(path, &block);
     }

     #[tokio::test]
     async fn test_data_record() {
         let tmp_dir = tempfile::tempdir().unwrap();
-        let storage = Storage::Local(LocalStorage::new(tmp_dir.path().to_str().unwrap()));
+        let path = tmp_dir.path().to_str().unwrap();
+        let storage = Storage::Local(LocalStorage::new(path));
         let block_manager = BlockManager::new(storage);
         let ids = vec!["embedding_id_2", "embedding_id_0", "embedding_id_1"];
         let embeddings = vec![
@@ -383,23 +408,33 @@ mod test {
             assert_eq!(read.document, documents[i]);
         }
         assert_eq!(size, block.get_size());
+
+        // test save/load
+        test_save_load_size(path, &block);
     }

-    // #[test]
-    // fn test_sizing_uint_key_val() {
-    //     let block_provider = ArrowBlockProvider::new();
-    //     let block = block_provider.create_block(KeyType::Uint, ValueType::Uint);
-    //     let delta = BlockDelta::from(block.clone());
-
-    //     let n = 2000;
-    //     for i in 0..n {
-    //         let key = BlockfileKey::new("prefix".to_string(), Key::Uint(i as u32));
-    //         let value = Value::UintValue(i as u32);
-    //         delta.add(key, value);
-    //     }
-
-    //     let size = delta.get_size();
-    //     let block_data = BlockData::try_from(&delta).unwrap();
-    //     assert_eq!(size, block_data.get_size());
-    // }
+    #[tokio::test]
+    async fn test_sizing_uint_key_val() {
+        let tmp_dir = tempfile::tempdir().unwrap();
+        let path = tmp_dir.path().to_str().unwrap();
+        let storage = Storage::Local(LocalStorage::new(path));
+        let block_manager = BlockManager::new(storage);
+        let delta = block_manager.create::<u32, &str>();
+
+        let n = 2000;
+        for i in 0..n {
+            let prefix = "prefix";
+            let key = i as u32;
+            let value = format!("value{}", i);
+            delta.add(prefix, key, value.as_str());
+        }
+
+        let size = delta.get_size::<u32, &str>();
+        block_manager.commit::<u32, &str>(&delta);
+        let block = block_manager.get(&delta.id).await.unwrap();
+        assert_eq!(size, block.get_size());
+
+        // test save/load
+        test_save_load_size(path, &block);
+    }
 }
diff --git a/rust/worker/src/blockstore/arrow/block/types.rs b/rust/worker/src/blockstore/arrow/block/types.rs
index 702f1525220..c87b34471a4 100644
--- a/rust/worker/src/blockstore/arrow/block/types.rs
+++ b/rust/worker/src/blockstore/arrow/block/types.rs
@@ -1,12 +1,21 @@
 use super::delta::BlockDelta;
 use crate::blockstore::arrow::types::{ArrowReadableKey, ArrowReadableValue};
-use crate::errors::ChromaError;
+use crate::errors::{ChromaError, ErrorCodes};
+use arrow::array::ArrayData;
+use arrow::buffer::Buffer;
+use arrow::ipc::reader::read_footer_length;
+use arrow::ipc::{root_as_footer, root_as_message, MessageHeader, MetadataVersion};
+use arrow::util::bit_util;
 use arrow::{
     array::{Array, StringArray},
     record_batch::RecordBatch,
 };
+use std::io::SeekFrom;
+use thiserror::Error;
 use uuid::Uuid;

+const ARROW_ALIGNMENT: usize = 64;
+
 /// A block in a blockfile. A block is a sorted collection of data that is immutable once it has been committed.
 /// Blocks are the fundamental unit of storage in the blockstore and are used to store data in the form of (key, value) pairs.
 /// These pairs are stored in an Arrow record batch with the schema (prefix, key, value).
@@ -27,10 +36,12 @@ pub struct Block {
 }

 impl Block {
+    /// Create a concrete block from an id and the underlying record batch of data
     pub fn from_record_batch(id: Uuid, data: RecordBatch) -> Self {
         Self { id, data }
     }

+    /// Converts the block to a block delta for writing to a new block
     pub fn to_block_delta<'me, K: ArrowReadableKey<'me>, V: ArrowReadableValue<'me>>(
         &'me self,
         mut delta: BlockDelta,
@@ -51,6 +62,13 @@ impl Block {
         delta
     }

+    /*
+        ===== Block Queries =====
+    */
+
+    /// Get the value for a given key in the block
+    /// ### Panics
+    /// - If the underlying data types are not the same as the types specified in the function signature
     pub fn get<'me, K: ArrowReadableKey<'me>, V: ArrowReadableValue<'me>>(
         &'me self,
         prefix: &str,
@@ -72,6 +90,9 @@ impl Block {
         None
     }

+    /// Get all the values for a given prefix in the block
+    /// ### Panics
+    /// - If the underlying data types are not the same as the types specified in the function signature
     pub fn get_prefix<'me, K: ArrowReadableKey<'me>, V: ArrowReadableValue<'me>>(
         &'me self,
         prefix: &str,
@@ -96,6 +117,9 @@ impl Block {
         return Some(res);
     }

+    /// Get all the values for a given prefix in the block where the key is greater than the given key
+    /// ### Panics
+    /// - If the underlying data types are not the same as the types specified in the function signature
     pub fn get_gt<'me, K: ArrowReadableKey<'me>, V: ArrowReadableValue<'me>>(
         &'me self,
         prefix: &str,
@@ -118,6 +142,9 @@ impl Block {
         return Some(res);
     }

+    /// Get all the values for a given prefix in the block where the key is less than the given key
+    /// ### Panics
+    /// - If the underlying data types are not the same as the types specified in the function signature
     pub fn get_lt<'me, K: ArrowReadableKey<'me>, V: ArrowReadableValue<'me>>(
         &'me self,
         prefix: &str,
@@ -140,6 +167,9 @@ impl Block {
         return Some(res);
     }

+    /// Get all the values for a given prefix in the block where the key is less than or equal to the given key
+    /// ### Panics
+    /// - If the underlying data types are not the same as the types specified in the function signature
     pub fn get_lte<'me, K: ArrowReadableKey<'me>, V: ArrowReadableValue<'me>>(
         &'me self,
         prefix: &str,
@@ -162,6 +192,9 @@ impl Block {
         return Some(res);
     }

+    /// Get all the values for a given prefix in the block where the key is greater than or equal to the given key
+    /// ### Panics
+    /// - If the underlying data types are not the same as the types specified in the function signature
     pub fn get_gte<'me, K: ArrowReadableKey<'me>, V: ArrowReadableValue<'me>>(
         &'me self,
         prefix: &str,
@@ -184,6 +217,12 @@ impl Block {
         return Some(res);
     }

+    /// Get the (prefix, key, value) tuple at the given index in the block
+    /// ### Notes
+    /// - Returns a tuple of (prefix, key, value)
+    /// - Returns None if the requested index is out of bounds
+    /// ### Panics
+    /// - If the underlying data types are not the same as the types specified in the function signature
     pub fn get_at_index<'me, K: ArrowReadableKey<'me>, V: ArrowReadableValue<'me>>(
         &'me self,
         index: usize,
@@ -203,13 +242,18 @@ impl Block {
         Some((prefix, key, value))
     }

+    /*
+        ===== Block Metadata =====
+    */
+
+    /// Returns the size of the block in bytes
     pub(crate) fn get_size(&self) -> usize {
         let mut total_size = 0;
         for column in self.data.columns() {
-            total_size += column.get_buffer_memory_size();
+            let array_data = column.to_data();
+            total_size += get_size_of_array_data(&array_data);
         }
-        total_size
+        return total_size;
     }

     /// Returns the number of items in the block
@@ -217,182 +261,423 @@ impl Block {
         self.data.num_rows()
     }

-    pub fn save(&self, path: &str) -> Result<(), Box<dyn ChromaError>> {
-        let file = std::fs::File::create(path);
-        let mut file = match file {
+    /*
+        ===== Block Serialization =====
+    */
+
+    /// Save the block in Arrow IPC format to the given path
+    pub fn save(&self, path: &str) -> Result<(), BlockSaveError> {
+        let file = match std::fs::File::create(path) {
             Ok(file) => file,
             Err(e) => {
-                // TODO: Return a proper error
-                panic!("Error creating file: {:?}", e)
+                return Err(BlockSaveError::IOError(e));
             }
         };
+
+        // We force the block to be written with 64 byte alignment;
+        // this is the default, but we are just being defensive
         let mut writer = std::io::BufWriter::new(file);
-        let writer = arrow::ipc::writer::FileWriter::try_new(&mut writer, &self.data.schema());
+        let options = match arrow::ipc::writer::IpcWriteOptions::try_new(
+            ARROW_ALIGNMENT,
+            false,
+            MetadataVersion::V5,
+        ) {
+            Ok(options) => options,
+            Err(e) => {
+                return Err(BlockSaveError::ArrowError(e));
+            }
+        };
+
+        let writer = arrow::ipc::writer::FileWriter::try_new_with_options(
+            &mut writer,
+            &self.data.schema(),
+            options,
+        );
         let mut writer = match writer {
             Ok(writer) => writer,
             Err(e) => {
-                // TODO: Return a proper error
-                panic!("Error creating writer: {:?}", e)
+                return Err(BlockSaveError::ArrowError(e));
             }
         };
         match writer.write(&self.data) {
             Ok(_) => match writer.finish() {
                 Ok(_) => return Ok(()),
                 Err(e) => {
-                    panic!("Error finishing writer: {:?}", e);
+                    return Err(BlockSaveError::ArrowError(e));
                 }
             },
             Err(e) => {
-                panic!("Error writing data: {:?}", e);
+                return Err(BlockSaveError::ArrowError(e));
             }
         }
     }

-    pub fn to_bytes(&self) -> Vec<u8> {
+    /// Convert the block to bytes in Arrow IPC format
+    pub fn to_bytes(&self) -> Result<Vec<u8>, BlockToBytesError> {
         let mut bytes = Vec::new();
         // Scope the writer so that it is dropped before we return the bytes
         {
             let mut writer =
-                arrow::ipc::writer::FileWriter::try_new(&mut bytes, &self.data.schema())
-                    .expect("Error creating writer");
-            writer.write(&self.data).expect("Error writing data");
-            writer.finish().expect("Error finishing writer");
+                match arrow::ipc::writer::FileWriter::try_new(&mut bytes, &self.data.schema()) {
+                    Ok(writer) => writer,
+                    Err(e) => {
+                        return Err(BlockToBytesError::ArrowError(e));
+                    }
+                };
+            match writer.write(&self.data) {
+                Ok(_) => {}
+                Err(e) => {
+                    return Err(BlockToBytesError::ArrowError(e));
+                }
+            }
+            match writer.finish() {
+                Ok(_) => {}
+                Err(e) => {
+                    return Err(BlockToBytesError::ArrowError(e));
+                }
+            }
         }
-        bytes
+        Ok(bytes)
+    }
+
+    /// Load a block from bytes in Arrow IPC format with the given id
+    pub fn from_bytes(bytes: &[u8], id: Uuid) -> Result<Self, BlockLoadError> {
+        return Self::from_bytes_internal(bytes, id, false);
     }

-    pub fn from_bytes(bytes: &[u8], id: Uuid) -> Result<Self, Box<dyn ChromaError>> {
+    /// Load a block from bytes in Arrow IPC format with the given id and validate the layout
+    /// ### Notes
+    /// - This method should be used in tests to ensure that the layout of the IPC file is as expected
+    /// - The validation is not performant and should not be used in production code
+    pub fn from_bytes_with_validation(bytes: &[u8], id: Uuid) -> Result<Self, BlockLoadError> {
+        return Self::from_bytes_internal(bytes, id, true);
+    }
+
+    fn from_bytes_internal(bytes: &[u8], id: Uuid, validate: bool) -> Result<Self, BlockLoadError> {
         let cursor = std::io::Cursor::new(bytes);
-        let mut reader =
-            arrow::ipc::reader::FileReader::try_new(cursor, None).expect("Error creating reader");
-        return Self::load_with_reader(reader, id);
+        return Self::load_with_reader(cursor, id, validate);
     }

-    pub fn load(path: &str, id: Uuid) -> Result<Self, Box<dyn ChromaError>> {
+    /// Load a block from the given path with the given id and validate the layout
+    /// ### Notes
+    /// - This method should be used in tests to ensure that the layout of the IPC file is as expected
+    /// - The validation is not performant and should not be used in production code
+    pub fn load_with_validation(path: &str, id: Uuid) -> Result<Self, BlockLoadError> {
+        return Self::load_internal(path, id, true);
+    }
+
+    /// Load a block from the given path with the given id
+    pub fn load(path: &str, id: Uuid) -> Result<Self, BlockLoadError> {
+        return Self::load_internal(path, id, false);
+    }
+
+    fn load_internal(path: &str, id: Uuid, validate: bool) -> Result<Self, BlockLoadError> {
         let file = std::fs::File::open(path);
         let file = match file {
             Ok(file) => file,
             Err(e) => {
-                // TODO: Return a proper error
-                panic!("Error opening file: {:?}", e)
-            }
-        };
-        let mut reader = std::io::BufReader::new(file);
-        let reader = arrow::ipc::reader::FileReader::try_new(&mut reader, None);
-        let mut reader = match reader {
-            Ok(reader) => reader,
-            Err(e) => {
-                // TODO: Return a proper error
-                panic!("Error creating reader: {:?}", e)
+                return Err(BlockLoadError::IOError(e));
             }
         };
-        return Self::load_with_reader(reader, id);
+        let reader = std::io::BufReader::new(file);
+        return Self::load_with_reader(reader, id, validate);
     }

-    fn load_with_reader<R>(
-        mut reader: arrow::ipc::reader::FileReader<R>,
-        id: Uuid,
-    ) -> Result<Self, Box<dyn ChromaError>>
+    fn load_with_reader<R>(mut reader: R, id: Uuid, validate: bool) -> Result<Self, BlockLoadError>
     where
         R: std::io::Read + std::io::Seek,
     {
-        let batch = reader.next().unwrap();
-        // TODO: how to store / hydrate id?
-        match batch {
-            Ok(batch) => Ok(Self::from_record_batch(id, batch)),
+        if validate {
+            let res = verify_buffers_layout(&mut reader);
+            match res {
+                Ok(_) => {}
+                Err(e) => {
+                    return Err(BlockLoadError::ArrowLayoutVerificationError(e));
+                }
+            }
+        }
+
+        let mut arrow_reader = match arrow::ipc::reader::FileReader::try_new(&mut reader, None) {
+            Ok(arrow_reader) => arrow_reader,
             Err(e) => {
-                panic!("Error reading batch: {:?}", e);
+                return Err(BlockLoadError::ArrowError(e));
             }
+        };
+
+        let batch = match arrow_reader.next() {
+            Some(Ok(batch)) => batch,
+            Some(Err(e)) => {
+                return Err(BlockLoadError::ArrowError(e));
+            }
+            None => {
+                return Err(BlockLoadError::NoRecordBatches);
+            }
+        };
+
+        // TODO: how to store / hydrate id?
+        Ok(Self::from_record_batch(id, batch))
+    }
+}
+
+fn get_size_of_array_data(array_data: &ArrayData) -> usize {
+    let mut total_size = 0;
+    for buffer in array_data.buffers() {
+        // SYSTEM ASSUMPTION: ALL BUFFERS ARE PADDED TO 64 bytes
+        // We maintain this invariant in three places
+        // 1. In the to_arrow methods of delta storage, we allocate
+        //    padded buffers
+        // 2. In calls to load() in tests we validate that the buffers are 64 byte aligned
+        // 3. In writing to the IPC block file we use an option to ensure 64 byte alignment,
+        //    which makes the arrow writer add padding to the buffers
+        // Why do we do this instead of using get_buffer_memory_size()
+        // or using the buffer's capacity?
+        // The reason is that arrow can dramatically overreport the size of buffers
+        // if the underlying buffers are shared and we use something like get_buffer_memory_size()
+        // or capacity, because the buffer may be shared with other arrays.
+        // In the case of Arrow IPC data, all the data is in one buffer,
+        // so get_buffer_memory_size() would overreport the size of the buffer
+        // by the number of columns and also by the number of validity and offset buffers.
+        // This is why we use the buffer.len() method, which gives us the actual size of the
+        // buffer; however, len() excludes the capacity of the buffer, which is why we round up
+        // to the nearest multiple of 64 bytes. We ensure, both when we construct the buffer and
+        // when we write it to disk, that the buffer is exactly buffer.len() + padding to a
+        // 64 byte boundary.
+        // (As an added note, arrow throws away explicit knowledge of this padding,
+        // see verify_buffers_layout() for how we infer the padding based on
+        // the offsets of each buffer)
+        let size = bit_util::round_upto_multiple_of_64(buffer.len());
+        total_size += size;
+    }
+    // List and Struct arrays have child arrays
+    for child in array_data.child_data() {
+        total_size += get_size_of_array_data(child);
+    }
+    // Some data types (like our data record) have null buffers
+    if let Some(buffer) = array_data.nulls() {
+        let size = bit_util::round_upto_multiple_of_64(buffer.len());
+        total_size += size;
+    }
+    return total_size;
+}
+
+/*
+===== Error Types =====
+*/
+
+#[derive(Error, Debug)]
+pub enum BlockSaveError {
+    #[error(transparent)]
+    IOError(#[from] std::io::Error),
+    #[error(transparent)]
+    ArrowError(#[from] arrow::error::ArrowError),
+}
+
+impl ChromaError for BlockSaveError {
+    fn code(&self) -> ErrorCodes {
+        match self {
+            BlockSaveError::IOError(_) => ErrorCodes::Internal,
+            BlockSaveError::ArrowError(_) => ErrorCodes::Internal,
         }
     }
 }

-// #[derive(Error, Debug)]
-// pub enum FinishError {
-//     #[error("Arrow error")]
-//     ArrowError(#[from] arrow::error::ArrowError),
-// }
-
-// impl ChromaError for FinishError {
-//     fn code(&self) -> ErrorCodes {
-//         match self {
-//             FinishError::ArrowError(_) => ErrorCodes::Internal,
-//         }
-//     }
-// }
-
-// // #[cfg(test)]
-// // mod test {
-// //     use super::*;
-// //     use crate::blockstore::types::Key;
-// //     use arrow::array::Int32Array;
-
-// //     #[test]
-// //     fn test_block_builder_can_add() {
-// //         let num_entries = 1000;
-
-// //         let mut keys = Vec::new();
-// //         let mut key_bytes = 0;
-// //         for i in 0..num_entries {
-// //             keys.push(Key::String(format!("{:04}", i)));
-// //             key_bytes += i.to_string().len();
-// //         }
-
-// //         let prefix = "key".to_string();
-// //         let prefix_bytes = prefix.len() * num_entries;
-// //         let mut block_builder = BlockDataBuilder::new(
-// //             KeyType::String,
-// //             ValueType::Int32Array,
-// //             Some(BlockBuilderOptions::new(
-// //                 num_entries,
-// //                 prefix_bytes,
-// //                 key_bytes,
-// //                 num_entries, // 2 int32s per entry
-// //                 num_entries * 2 * 4, // 2 int32s per entry
-// //             )),
-// //         );
-
-// //         for i in 0..num_entries {
-// //             block_builder
-// //                 .add(
-// //                     BlockfileKey::new(prefix.clone(), keys[i].clone()),
-// //                     Value::Int32ArrayValue(Int32Array::from(vec![i as i32, (i + 1) as i32])),
-// //                 )
-// //                 .unwrap();
-// //         }
-
-// //         // Basic sanity check
-// //         let block_data = block_builder.build().unwrap();
-// //         assert_eq!(block_data.data.column(0).len(), num_entries);
-// //         assert_eq!(block_data.data.column(1).len(), num_entries);
-// //         assert_eq!(block_data.data.column(2).len(), num_entries);
-// //     }
-
-// //     #[test]
-// //     fn test_out_of_order_key_fails() {
-// //         let mut block_builder = BlockDataBuilder::new(
-// //             KeyType::String,
-// //             ValueType::Int32Array,
-// //             Some(BlockBuilderOptions::default()),
-// //         );
-// //         block_builder
-// //             .add(
-// //                 BlockfileKey::new("key".to_string(), Key::String("b".to_string())),
-// //                 Value::Int32ArrayValue(Int32Array::from(vec![1, 2])),
-// //             )
-// //             .unwrap();
-
-// //         let result = block_builder.add(
-// //             BlockfileKey::new("key".to_string(), Key::String("a".to_string())),
-// //             Value::Int32ArrayValue(Int32Array::from(vec![1, 2])),
-// //         );
-
-// //         match result {
-// //             Ok(_) => panic!("Expected error"),
-// //             Err(e) => {
-// //                 assert_eq!(e.code(), ErrorCodes::InvalidArgument);
-// //             }
-// //         }
-// //     }
-// // }
+#[derive(Error, Debug)]
+pub enum BlockToBytesError {
+    #[error(transparent)]
+    ArrowError(#[from] arrow::error::ArrowError),
+}
+
+impl ChromaError for BlockToBytesError {
+    fn code(&self) -> ErrorCodes {
+        match self {
+            BlockToBytesError::ArrowError(_) => ErrorCodes::Internal,
+        }
+    }
+}
+
+#[derive(Error, Debug)]
+pub enum BlockLoadError {
+    #[error(transparent)]
+    IOError(#[from] std::io::Error),
+    #[error(transparent)]
+    ArrowError(#[from] arrow::error::ArrowError),
+    #[error(transparent)]
+    ArrowLayoutVerificationError(#[from] ArrowLayoutVerificationError),
+    #[error("No record batches in IPC file")]
+    NoRecordBatches,
+}
+
+/*
+===== Layout Verification =====
+*/
+
+#[derive(Error, Debug)]
+pub enum ArrowLayoutVerificationError {
+    #[error("Buffer length is not 64 byte aligned")]
+    BufferLengthNotAligned,
+    #[error(transparent)]
+    IOError(#[from] std::io::Error),
+    #[error(transparent)]
+    ArrowError(#[from] arrow::error::ArrowError),
+    #[error(transparent)]
+    InvalidFlatbuffer(#[from] flatbuffers::InvalidFlatbuffer),
+    #[error("No record batches in footer")]
+    NoRecordBatches,
+    #[error("More than one record batch in IPC file")]
+    MultipleRecordBatches,
+    #[error("Invalid message type")]
+    InvalidMessageType,
+    #[error("Error decoding record batch message as record batch")]
+    RecordBatchDecodeError,
+}
+
+impl ChromaError for ArrowLayoutVerificationError {
+    fn code(&self) -> ErrorCodes {
+        match self {
+            // All errors are internal for this error type
+            _ => ErrorCodes::Internal,
+        }
+    }
+}
+
+/// Verifies that the buffers in the IPC file are 64 byte aligned
+/// and stored in Arrow in the way we expect.
+/// All non-benchmark test code should use this by loading the block
+/// with verification enabled.
+fn verify_buffers_layout<R>(mut reader: R) -> Result<(), ArrowLayoutVerificationError>
+where
+    R: std::io::Read + std::io::Seek,
+{
+    // Read the IPC file and verify that the buffers are 64 byte aligned
+    // by inspecting the offsets; this is required since our
+    // size calculation assumes that the buffers are 64 byte aligned
+    // Space for ARROW_MAGIC (6 bytes) and length (4 bytes)
+    let mut footer_buffer = [0; 10];
+    match reader.seek(SeekFrom::End(-10)) {
+        Ok(_) => {}
+        Err(e) => {
+            return Err(ArrowLayoutVerificationError::IOError(e));
+        }
+    }
+
+    match reader.read_exact(&mut footer_buffer) {
+        Ok(_) => {}
+        Err(e) => {
+            return Err(ArrowLayoutVerificationError::IOError(e));
+        }
+    }
+
+    let footer_len = read_footer_length(footer_buffer);
+    let footer_len = match footer_len {
+        Ok(footer_len) => footer_len,
+        Err(e) => {
+            return Err(ArrowLayoutVerificationError::ArrowError(e));
+        }
+    };
+
+    // read footer
+    let mut footer_data = vec![0; footer_len];
+    match reader.seek(SeekFrom::End(-10 - footer_len as i64)) {
+        Ok(_) => {}
+        Err(e) => {
+            return Err(ArrowLayoutVerificationError::IOError(e));
+        }
+    }
+    match reader.read_exact(&mut footer_data) {
+        Ok(_) => {}
+        Err(e) => {
+            return Err(ArrowLayoutVerificationError::IOError(e));
+        }
+    }
+
+    let footer = match root_as_footer(&footer_data) {
+        Ok(footer) => footer,
+        Err(e) => {
+            return Err(ArrowLayoutVerificationError::InvalidFlatbuffer(e));
+        }
+    };
+
+    // Read the record batch
+    let record_batch_definitions = match footer.recordBatches() {
+        Some(record_batch_definitions) => record_batch_definitions,
+        None => {
+            return Err(ArrowLayoutVerificationError::NoRecordBatches);
+        }
+    };
+
+    // Ensure there is only ONE record batch, which is how we store data
+    if record_batch_definitions.len() != 1 {
+        return Err(ArrowLayoutVerificationError::MultipleRecordBatches);
+    }
+
+    let record_batch_definition = record_batch_definitions.get(0);
+    let record_batch_len = record_batch_definition.bodyLength() as usize
+        + record_batch_definition.metaDataLength() as usize;
+    let record_batch_body_len = record_batch_definition.bodyLength() as usize;
+
+    // Read the actual record batch
+    let mut file_buffer = vec![0; record_batch_len];
+    match reader.seek(SeekFrom::Start(record_batch_definition.offset() as u64)) {
+        Ok(_) => {}
+        Err(e) => {
+            return Err(ArrowLayoutVerificationError::IOError(e));
+        }
+    }
+    match reader.read_exact(&mut file_buffer) {
+        Ok(_) => {}
+        Err(e) => {
+            return Err(ArrowLayoutVerificationError::IOError(e));
+        }
+    }
+    let buffer = Buffer::from(file_buffer);
+
+    // This is borrowed from arrow-ipc parse_message.rs
+    // https://arrow.apache.org/docs/format/Columnar.html#encapsulated-message-format
+    let buf = match buffer[..4] == [0xff; 4] {
+        true => &buffer[8..],
+        false => &buffer[4..],
+    };
+    let message = match root_as_message(buf) {
+        Ok(message) => message,
+        Err(e) => {
+            return Err(ArrowLayoutVerificationError::InvalidFlatbuffer(e));
+        }
+    };
+
+    match message.header_type() {
+        MessageHeader::RecordBatch => {
+            let record_batch = match message.header_as_record_batch() {
+                Some(record_batch) => record_batch,
+                None => {
+                    return Err(ArrowLayoutVerificationError::RecordBatchDecodeError);
+                }
+            };
+            // Loop over the offsets and ensure the lengths of each buffer are 64 byte aligned
+            let blocks = match record_batch.buffers() {
+                Some(blocks) => blocks,
+                None => {
+                    return Err(ArrowLayoutVerificationError::RecordBatchDecodeError);
+                }
+            };
+
+            let mut prev_offset = blocks.get(0).offset();
+            for block in blocks.iter().skip(1) {
+                let curr_offset = block.offset();
+                let len = (curr_offset - prev_offset) as usize;
+                if len % ARROW_ALIGNMENT != 0 {
+                    return Err(ArrowLayoutVerificationError::BufferLengthNotAligned);
+                }
+                prev_offset = curr_offset;
+            }
+            // Check the remaining buffer length based on the body length
+            let last_buffer_len = record_batch_body_len - prev_offset as usize;
+            if last_buffer_len % ARROW_ALIGNMENT != 0 {
+                return Err(ArrowLayoutVerificationError::BufferLengthNotAligned);
+            }
+        }
+        _ => {
+            return Err(ArrowLayoutVerificationError::InvalidMessageType);
+        }
+    }
+
+    Ok(())
+}
diff --git a/rust/worker/src/blockstore/arrow/provider.rs b/rust/worker/src/blockstore/arrow/provider.rs
index 2ee8e951688..373fd728eb6 100644
--- a/rust/worker/src/blockstore/arrow/provider.rs
+++ b/rust/worker/src/blockstore/arrow/provider.rs
@@ -200,7 +200,13 @@ impl BlockManager {
         match block {
             Some(block) => {
-                let bytes = block.to_bytes();
+                let bytes = match block.to_bytes() {
+                    Ok(bytes) => bytes,
+                    Err(e) => {
+                        return Err(Box::new(e));
+                    }
+                };
+
                 let key = format!("block/{}", id);
                 let res = self.storage.put_bytes(&key, bytes).await;
                 match res {
@@ -334,7 +340,13 @@ impl SparseIndexManager {
         let as_block = index.to_block::<K>();
         match as_block {
             Ok(block) => {
-                let bytes = block.to_bytes();
+                let bytes = match block.to_bytes() {
+                    Ok(bytes) => bytes,
+                    Err(e) => {
+                        return Err(Box::new(e));
+                    }
+                };
+
                 let key = format!("sparse_index/{}", id);
                 let res = self.storage.put_bytes(&key, bytes).await;
                 match res {
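
To make the sizing invariant behind this change concrete, the following is a minimal, self-contained sketch (not part of the diff itself): every Arrow buffer is assumed to be padded to a 64 byte boundary, so a batch's size can be computed by summing round_upto_multiple_of_64(buffer.len()) over all buffers, and writing with an explicit 64 byte alignment keeps the on-disk IPC layout consistent with that computation. The helper name padded_size and the two-column schema are illustrative stand-ins (they do not appear in the diff); the arrow calls (IpcWriteOptions::try_new, FileWriter::try_new_with_options, bit_util::round_upto_multiple_of_64) are the same ones the diff uses.

use std::sync::Arc;

use arrow::array::{Array, ArrayData, ArrayRef, Int32Array, StringArray};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::ipc::writer::{FileWriter, IpcWriteOptions};
use arrow::ipc::MetadataVersion;
use arrow::record_batch::RecordBatch;
use arrow::util::bit_util;

/// Illustrative stand-in for the diff's get_size_of_array_data: sum the
/// 64-byte-padded lengths of all buffers, recursing into child arrays
/// (list/struct columns) and accounting for validity (null) buffers the
/// same way the diff does.
fn padded_size(data: &ArrayData) -> usize {
    let mut total = 0;
    for buffer in data.buffers() {
        total += bit_util::round_upto_multiple_of_64(buffer.len());
    }
    for child in data.child_data() {
        total += padded_size(child);
    }
    if let Some(nulls) = data.nulls() {
        // Mirrors the diff's accounting for null buffers
        total += bit_util::round_upto_multiple_of_64(nulls.len());
    }
    total
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // A hypothetical two-column batch standing in for the (prefix, key, value) layout.
    let schema = Arc::new(Schema::new(vec![
        Field::new("key", DataType::Utf8, false),
        Field::new("value", DataType::Int32, false),
    ]));
    let columns: Vec<ArrayRef> = vec![
        Arc::new(StringArray::from(vec!["a", "b", "c"])),
        Arc::new(Int32Array::from(vec![1, 2, 3])),
    ];
    let batch = RecordBatch::try_new(schema.clone(), columns)?;

    // The padded in-memory size; with the invariant held, this matches the
    // space the buffers occupy in the IPC file body.
    let size: usize = batch
        .columns()
        .iter()
        .map(|column| padded_size(&column.to_data()))
        .sum();
    println!("padded size: {} bytes", size);

    // Write with explicit 64 byte alignment, as Block::save now does.
    let options = IpcWriteOptions::try_new(64, false, MetadataVersion::V5)?;
    let mut writer = FileWriter::try_new_with_options(Vec::new(), &schema, options)?;
    writer.write(&batch)?;
    writer.finish()?;
    Ok(())
}

Because the IPC footer records only buffer offsets, not padding, verify_buffers_layout() infers each buffer's padded length from the gap between consecutive offsets; the rounding in the sketch above mirrors that inference.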