From 0b5f332ee27c08b22c6e6e5af8dddbb64244d127 Mon Sep 17 00:00:00 2001 From: Max Isom Date: Mon, 27 Jan 2025 16:44:17 -0800 Subject: [PATCH] [ENH]: add Rust SQLite impl of log --- Cargo.lock | 454 +++++++++++++++++- Cargo.toml | 2 + rust/log/Cargo.toml | 6 +- rust/log/src/grpc_log.rs | 87 +++- rust/log/src/in_memory_log.rs | 17 +- rust/log/src/lib.rs | 1 + rust/log/src/log.rs | 70 ++- rust/log/src/sqlite_log.rs | 448 +++++++++++++++++ rust/log/src/types.rs | 74 +-- rust/types/src/metadata.rs | 87 +++- rust/types/src/scalar_encoding.rs | 21 + .../src/execution/operators/fetch_log.rs | 4 +- .../src/execution/operators/register.rs | 4 +- 13 files changed, 1150 insertions(+), 125 deletions(-) create mode 100644 rust/log/src/sqlite_log.rs diff --git a/Cargo.lock b/Cargo.lock index 9a3cec370ca7..1a3169fab03a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1016,6 +1016,9 @@ name = "bitflags" version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" +dependencies = [ + "serde", +] [[package]] name = "bitpacking" @@ -1064,9 +1067,9 @@ checksum = "7ff69b9dd49fd426c69a0db9fc04dd934cdb6645ff000864d98f7e2af8830eaa" [[package]] name = "bytemuck" -version = "1.20.0" +version = "1.21.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b37c88a63ffd85d15b406896cc343916d7cf57838a847b3a6f2ca5d39a5695a" +checksum = "ef657dfab802224e671f5818e9a4935f9b1957ed18e58292690cc39e7a4092a3" [[package]] name = "byteorder" @@ -1332,14 +1335,19 @@ name = "chroma-log" version = "0.1.0" dependencies = [ "async-trait", + "bytemuck", "chroma-config", "chroma-error", "chroma-segment", "chroma-types", + "futures", "opentelemetry", "rand", "serde", + "serde_json", + "sqlx", "thiserror 1.0.69", + "tokio", "tonic", "tracing", "tracing-opentelemetry", @@ -1668,6 +1676,21 @@ dependencies = [ "libc", ] +[[package]] +name = "crc" +version = "3.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69e6e4d7b33a94f0991c26729976b10ebde1d34c3ee82408fb536164fa10d636" +dependencies = [ + "crc-catalog", +] + +[[package]] +name = "crc-catalog" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19d374276b40fb8bbdee95aef7c7fa6b5316ec764510eb64b8dd0e2ed0d7e7f5" + [[package]] name = "crc32c" version = "0.6.8" @@ -1752,6 +1775,15 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "crossbeam-queue" +version = "0.3.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f58bbc28f91df819d0aa2a2c00cd19754769c2fad90579b3592b1c9ba7a3115" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "crossbeam-utils" version = "0.8.20" @@ -1907,6 +1939,17 @@ dependencies = [ "zeroize", ] +[[package]] +name = "der" +version = "0.7.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f55bf8e7b65898637379c1b74eb1551107c8294ed26d855ceb9fd1a09cfc9bc0" +dependencies = [ + "const-oid", + "pem-rfc7468", + "zeroize", +] + [[package]] name = "deranged" version = "0.3.11" @@ -1946,6 +1989,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ "block-buffer", + "const-oid", "crypto-common", "subtle", ] @@ -1982,6 +2026,12 @@ dependencies = [ "syn 2.0.89", ] +[[package]] +name = "dotenvy" +version = "0.15.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b" + [[package]] name = "downcast-rs" version = "1.2.1" @@ -2000,10 +2050,10 @@ version = "0.14.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "413301934810f597c1d19ca71c8710e99a3f1ba28a0d2ebc01551a2daeea3c5c" dependencies = [ - "der", + "der 0.6.1", "elliptic-curve", "rfc6979", - "signature", + "signature 1.6.4", ] [[package]] @@ -2011,6 +2061,9 @@ name = "either" version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" +dependencies = [ + "serde", +] [[package]] name = "elliptic-curve" @@ -2020,12 +2073,12 @@ checksum = "e7bb888ab5300a19b8e5bceef25ac745ad065f3c9f7efc6de1b91958110891d3" dependencies = [ "base16ct", "crypto-bigint 0.4.9", - "der", + "der 0.6.1", "digest", "ff", "generic-array", "group", - "pkcs8", + "pkcs8 0.9.0", "rand_core", "sec1", "subtle", @@ -2063,6 +2116,17 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "etcetera" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "136d1b5283a1ab77bd9257427ffd09d8667ced0570b6f938942bc7568ed5b943" +dependencies = [ + "cfg-if", + "home", + "windows-sys 0.48.0", +] + [[package]] name = "event-listener" version = "5.3.1" @@ -2439,6 +2503,17 @@ dependencies = [ "futures-util", ] +[[package]] +name = "futures-intrusive" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d930c203dd0b6ff06e0201a4a2fe9149b43c684fd4420555b26d21b1a02956f" +dependencies = [ + "futures-core", + "lock_api", + "parking_lot", +] + [[package]] name = "futures-io" version = "0.3.31" @@ -2690,6 +2765,15 @@ dependencies = [ "foldhash", ] +[[package]] +name = "hashlink" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7382cf6263419f2d8df38c55d7da83da5c18aef87fc7a7fc1fb1e344edfe14c1" +dependencies = [ + "hashbrown 0.15.2", +] + [[package]] name = "heck" version = "0.4.1" @@ -2720,6 +2804,15 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" +[[package]] +name = "hkdf" +version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b5f8eb2ad728638ea2c7d47a21db23b7b58a72ed6a38256b8a1849f15fbbdf7" +dependencies = [ + "hmac", +] + [[package]] name = "hmac" version = "0.12.1" @@ -3424,6 +3517,9 @@ name = "lazy_static" version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" +dependencies = [ + "spin", +] [[package]] name = "levenshtein_automata" @@ -3517,6 +3613,17 @@ dependencies = [ "libc", ] +[[package]] +name = "libsqlite3-sys" +version = "0.30.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e99fb7a497b1e3339bc746195567ed8d3e24945ecd636e3619d20b9de9e9149" +dependencies = [ + "cc", + "pkg-config", + "vcpkg", +] + [[package]] name = "linux-raw-sys" version = "0.4.14" @@ -3899,6 +4006,23 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-bigint-dig" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc84195820f291c7697304f3cbdadd1cb7199c0efc917ff5eafd71225c136151" +dependencies = [ + "byteorder", + "lazy_static", + "libm", + "num-integer", + "num-iter", + "num-traits", + "rand", 
+ "smallvec", + "zeroize", +] + [[package]] name = "num-complex" version = "0.4.6" @@ -4292,6 +4416,15 @@ dependencies = [ "serde", ] +[[package]] +name = "pem-rfc7468" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88b39c9bfcfc231068454382784bb460aae594343fb030d46e9f50a645418412" +dependencies = [ + "base64ct", +] + [[package]] name = "percent-encoding" version = "2.3.1" @@ -4385,14 +4518,35 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "pkcs1" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8ffb9f10fa047879315e6625af03c164b16962a5368d724ed16323b68ace47f" +dependencies = [ + "der 0.7.9", + "pkcs8 0.10.2", + "spki 0.7.3", +] + [[package]] name = "pkcs8" version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9eca2c590a5f85da82668fa685c09ce2888b9430e83299debf1f34b65fd4a4ba" dependencies = [ - "der", - "spki", + "der 0.6.1", + "spki 0.6.0", +] + +[[package]] +name = "pkcs8" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f950b2377845cebe5cf8b5165cb3cc1a5e0fa5cfa3e1f7f55707d8fd82e0a7b7" +dependencies = [ + "der 0.7.9", + "spki 0.7.3", ] [[package]] @@ -5041,6 +5195,26 @@ dependencies = [ "byteorder", ] +[[package]] +name = "rsa" +version = "0.9.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47c75d7c5c6b673e58bf54d8544a9f432e3a925b0e80f7cd3602ab5c50c55519" +dependencies = [ + "const-oid", + "digest", + "num-bigint-dig", + "num-integer", + "num-traits", + "pkcs1", + "pkcs8 0.10.2", + "rand_core", + "signature 2.2.0", + "spki 0.7.3", + "subtle", + "zeroize", +] + [[package]] name = "rtrb" version = "0.3.1" @@ -5312,9 +5486,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3be24c1842290c45df0a7bf069e0c268a747ad05a192f2fd7dcfdbc1cba40928" dependencies = [ "base16ct", - "der", + "der 0.6.1", "generic-array", - "pkcs8", + "pkcs8 0.9.0", "subtle", "zeroize", ] @@ -5568,6 +5742,16 @@ dependencies = [ "rand_core", ] +[[package]] +name = "signature" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77549399552de45a898a580c1b41d445bf730df867cc44e6c0233bbc4b8329de" +dependencies = [ + "digest", + "rand_core", +] + [[package]] name = "siphasher" version = "1.0.1" @@ -5597,6 +5781,9 @@ name = "smallvec" version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" +dependencies = [ + "serde", +] [[package]] name = "snafu" @@ -5645,7 +5832,204 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "67cf02bbac7a337dc36e4f5a693db6c21e7863f45070f7064577eb4367a3212b" dependencies = [ "base64ct", - "der", + "der 0.6.1", +] + +[[package]] +name = "spki" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d91ed6c858b01f942cd56b37a94b3e0a1798290327d1236e4d9cf4eaca44d29d" +dependencies = [ + "base64ct", + "der 0.7.9", +] + +[[package]] +name = "sqlx" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4410e73b3c0d8442c5f99b425d7a435b5ee0ae4167b3196771dd3f7a01be745f" +dependencies = [ + "sqlx-core", + "sqlx-macros", + "sqlx-mysql", + "sqlx-postgres", + "sqlx-sqlite", 
+] + +[[package]] +name = "sqlx-core" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a007b6936676aa9ab40207cde35daab0a04b823be8ae004368c0793b96a61e0" +dependencies = [ + "bytes", + "crc", + "crossbeam-queue", + "either", + "event-listener", + "futures-core", + "futures-intrusive", + "futures-io", + "futures-util", + "hashbrown 0.15.2", + "hashlink", + "indexmap 2.6.0", + "log", + "memchr", + "once_cell", + "percent-encoding", + "serde", + "serde_json", + "sha2", + "smallvec", + "thiserror 2.0.4", + "tokio", + "tokio-stream", + "tracing", + "url", +] + +[[package]] +name = "sqlx-macros" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3112e2ad78643fef903618d78cf0aec1cb3134b019730edb039b69eaf531f310" +dependencies = [ + "proc-macro2", + "quote", + "sqlx-core", + "sqlx-macros-core", + "syn 2.0.89", +] + +[[package]] +name = "sqlx-macros-core" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e9f90acc5ab146a99bf5061a7eb4976b573f560bc898ef3bf8435448dd5e7ad" +dependencies = [ + "dotenvy", + "either", + "heck 0.5.0", + "hex", + "once_cell", + "proc-macro2", + "quote", + "serde", + "serde_json", + "sha2", + "sqlx-core", + "sqlx-mysql", + "sqlx-postgres", + "sqlx-sqlite", + "syn 2.0.89", + "tempfile", + "tokio", + "url", +] + +[[package]] +name = "sqlx-mysql" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4560278f0e00ce64938540546f59f590d60beee33fffbd3b9cd47851e5fff233" +dependencies = [ + "atoi", + "base64 0.22.1", + "bitflags 2.6.0", + "byteorder", + "bytes", + "crc", + "digest", + "dotenvy", + "either", + "futures-channel", + "futures-core", + "futures-io", + "futures-util", + "generic-array", + "hex", + "hkdf", + "hmac", + "itoa", + "log", + "md-5", + "memchr", + "once_cell", + "percent-encoding", + "rand", + "rsa", + "serde", + "sha1", + "sha2", + "smallvec", + "sqlx-core", + "stringprep", + "thiserror 2.0.4", + "tracing", + "whoami", +] + +[[package]] +name = "sqlx-postgres" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c5b98a57f363ed6764d5b3a12bfedf62f07aa16e1856a7ddc2a0bb190a959613" +dependencies = [ + "atoi", + "base64 0.22.1", + "bitflags 2.6.0", + "byteorder", + "crc", + "dotenvy", + "etcetera", + "futures-channel", + "futures-core", + "futures-util", + "hex", + "hkdf", + "hmac", + "home", + "itoa", + "log", + "md-5", + "memchr", + "once_cell", + "rand", + "serde", + "serde_json", + "sha2", + "smallvec", + "sqlx-core", + "stringprep", + "thiserror 2.0.4", + "tracing", + "whoami", +] + +[[package]] +name = "sqlx-sqlite" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f85ca71d3a5b24e64e1d08dd8fe36c6c95c339a896cc33068148906784620540" +dependencies = [ + "atoi", + "flume", + "futures-channel", + "futures-core", + "futures-executor", + "futures-intrusive", + "futures-util", + "libsqlite3-sys", + "log", + "percent-encoding", + "serde", + "serde_urlencoded", + "sqlx-core", + "tracing", + "url", ] [[package]] @@ -5660,6 +6044,17 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" +[[package]] +name = "stringprep" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b4df3d392d81bd458a8a621b8bffbd2302a12ffe288a9d931670948749463b1" +dependencies 
= [ + "unicode-bidi", + "unicode-normalization", + "unicode-properties", +] + [[package]] name = "strsim" version = "0.10.0" @@ -6487,12 +6882,33 @@ dependencies = [ "version_check", ] +[[package]] +name = "unicode-bidi" +version = "0.3.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c1cb5db39152898a79168971543b1cb5020dff7fe43c8dc468b0885f5e29df5" + [[package]] name = "unicode-ident" version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83" +[[package]] +name = "unicode-normalization" +version = "0.1.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5033c97c4262335cded6d6fc3e5c18ab755e1a3dc96376350f3d8e9f009ad956" +dependencies = [ + "tinyvec", +] + +[[package]] +name = "unicode-properties" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e70f2a8b45122e719eb623c01822704c4e0907e7e426a05927e1a1cfff5b75d0" + [[package]] name = "unicode-width" version = "0.1.14" @@ -6645,6 +7061,12 @@ version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +[[package]] +name = "wasite" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8dad83b4f25e74f184f64c43b150b91efe7647395b42289f38e50566d82855b" + [[package]] name = "wasm-bindgen" version = "0.2.92" @@ -6762,6 +7184,16 @@ dependencies = [ "rustix", ] +[[package]] +name = "whoami" +version = "1.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "372d5b87f58ec45c384ba03563b03544dc5fadc3983e434b286913f5b4a9bb6d" +dependencies = [ + "redox_syscall", + "wasite", +] + [[package]] name = "winapi" version = "0.3.9" diff --git a/Cargo.toml b/Cargo.toml index 2ca0fcdcf512..f3c806914588 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,6 +35,8 @@ tracing-bunyan-formatter = "0.3" tracing-opentelemetry = "0.28.0" tracing-subscriber = { version = "0.3", features = ["env-filter"] } uuid = { version = "1.11.0", features = ["v4", "fast-rng", "macro-diagnostics", "serde"] } +sqlx = { version = "0.8.3", features = ["runtime-tokio", "sqlite"] } +bytemuck = "1.21.0" chroma-benchmark = { path = "rust/benchmark" } chroma-blockstore = { path = "rust/blockstore" } diff --git a/rust/log/Cargo.toml b/rust/log/Cargo.toml index 1c428f87054e..3c0ad6c30771 100644 --- a/rust/log/Cargo.toml +++ b/rust/log/Cargo.toml @@ -15,9 +15,13 @@ tracing = { workspace = true } # Used by tracing tracing-opentelemetry = { workspace = true } uuid = { workspace = true } +sqlx = { workspace = true } +tokio = { workspace = true } +serde_json = { workspace = true } +bytemuck = { workspace = true } +futures = { workspace = true } chroma-config = { workspace = true } chroma-error = { workspace = true } chroma-segment = { workspace = true } chroma-types = { workspace = true } - diff --git a/rust/log/src/grpc_log.rs b/rust/log/src/grpc_log.rs index b87e51aa11b0..87347002f3b1 100644 --- a/rust/log/src/grpc_log.rs +++ b/rust/log/src/grpc_log.rs @@ -1,9 +1,6 @@ use super::config::LogConfig; use crate::tracing::client_interceptor; -use crate::types::{ - CollectionInfo, GetCollectionsWithNewDataError, PullLogsError, UpdateCollectionLogOffsetError, -}; -use crate::PushLogsError; +use crate::types::CollectionInfo; use async_trait::async_trait; use chroma_config::Configurable; use 
chroma_error::{ChromaError, ErrorCodes};
@@ -18,6 +15,72 @@ use tonic::transport::Endpoint;
 use tonic::{Request, Status};
 use uuid::Uuid;
 
+#[derive(Error, Debug)]
+pub enum GrpcPullLogsError {
+    #[error("Failed to fetch")]
+    FailedToPullLogs(#[from] tonic::Status),
+    #[error("Failed to convert proto embedding record into EmbeddingRecord")]
+    ConversionError(#[from] RecordConversionError),
+}
+
+impl ChromaError for GrpcPullLogsError {
+    fn code(&self) -> ErrorCodes {
+        match self {
+            GrpcPullLogsError::FailedToPullLogs(_) => ErrorCodes::Internal,
+            GrpcPullLogsError::ConversionError(_) => ErrorCodes::Internal,
+        }
+    }
+}
+
+#[derive(Error, Debug)]
+pub enum GrpcPushLogsError {
+    #[error("Failed to push logs")]
+    FailedToPushLogs(#[from] tonic::Status),
+    #[error("Failed to convert records to proto")]
+    ConversionError(#[from] RecordConversionError),
+}
+
+impl ChromaError for GrpcPushLogsError {
+    fn code(&self) -> ErrorCodes {
+        match self {
+            GrpcPushLogsError::FailedToPushLogs(_) => ErrorCodes::Internal,
+            GrpcPushLogsError::ConversionError(_) => ErrorCodes::Internal,
+        }
+    }
+}
+
+#[derive(Error, Debug)]
+pub enum GrpcGetCollectionsWithNewDataError {
+    #[error("Failed to fetch")]
+    FailedGetCollectionsWithNewData(#[from] tonic::Status),
+}
+
+impl ChromaError for GrpcGetCollectionsWithNewDataError {
+    fn code(&self) -> ErrorCodes {
+        match self {
+            GrpcGetCollectionsWithNewDataError::FailedGetCollectionsWithNewData(_) => {
+                ErrorCodes::Internal
+            }
+        }
+    }
+}
+
+#[derive(Error, Debug)]
+pub enum GrpcUpdateCollectionLogOffsetError {
+    #[error("Failed to update collection log offset")]
+    FailedToUpdateCollectionLogOffset(#[from] tonic::Status),
+}
+
+impl ChromaError for GrpcUpdateCollectionLogOffsetError {
+    fn code(&self) -> ErrorCodes {
+        match self {
+            GrpcUpdateCollectionLogOffsetError::FailedToUpdateCollectionLogOffset(_) => {
+                ErrorCodes::Internal
+            }
+        }
+    }
+}
+
 #[derive(Clone, Debug)]
 pub struct GrpcLog {
     #[allow(clippy::type_complexity)]
@@ -100,7 +163,7 @@ impl GrpcLog {
         offset: i64,
         batch_size: i32,
         end_timestamp: Option<i64>,
-    ) -> Result<Vec<LogRecord>, PullLogsError> {
+    ) -> Result<Vec<LogRecord>, GrpcPullLogsError> {
         let end_timestamp = match end_timestamp {
             Some(end_timestamp) => end_timestamp,
             None => i64::MAX,
@@ -124,7 +187,7 @@ impl GrpcLog {
                         result.push(log_record);
                     }
                     Err(err) => {
-                        return Err(PullLogsError::ConversionError(err));
+                        return Err(GrpcPullLogsError::ConversionError(err));
                     }
                 }
             }
@@ -132,7 +195,7 @@ impl GrpcLog {
             }
             Err(e) => {
                 tracing::error!("Failed to pull logs: {}", e);
-                Err(PullLogsError::FailedToPullLogs(e))
+                Err(GrpcPullLogsError::FailedToPullLogs(e))
             }
         }
     }
@@ -141,7 +204,7 @@ impl GrpcLog {
         &mut self,
         collection_id: CollectionUuid,
         records: Vec<OperationRecord>,
-    ) -> Result<(), PushLogsError> {
+    ) -> Result<(), GrpcPushLogsError> {
         let request = chroma_proto::PushLogsRequest {
             collection_id: collection_id.0.to_string(),
@@ -160,7 +223,7 @@ impl GrpcLog {
     pub(super) async fn get_collections_with_new_data(
         &mut self,
         min_compaction_size: u64,
-    ) -> Result<Vec<CollectionInfo>, GetCollectionsWithNewDataError> {
+    ) -> Result<Vec<CollectionInfo>, GrpcGetCollectionsWithNewDataError> {
         let response = self
             .client
             .get_all_collection_info_to_compact(
@@ -196,7 +259,7 @@ impl GrpcLog {
             }
             Err(e) => {
                 tracing::error!("Failed to get collections: {}", e);
-                Err(GetCollectionsWithNewDataError::FailedGetCollectionsWithNewData(e))
+                Err(GrpcGetCollectionsWithNewDataError::FailedGetCollectionsWithNewData(e))
             }
         }
     }
@@ -205,7 +268,7 @@ impl GrpcLog {
         &mut self,
         collection_id: CollectionUuid,
         new_offset: i64,
-    ) -> Result<(), UpdateCollectionLogOffsetError> {
+    ) -> Result<(), GrpcUpdateCollectionLogOffsetError> {
         let request = self.client.update_collection_log_offset(
             chroma_proto::UpdateCollectionLogOffsetRequest {
                 // NOTE(rescrv): Use the untyped string representation of the collection ID.
@@ -216,7 +279,7 @@ impl GrpcLog {
         let response = request.await;
         match response {
             Ok(_) => Ok(()),
-            Err(e) => Err(UpdateCollectionLogOffsetError::FailedToUpdateCollectionLogOffset(e)),
+            Err(e) => Err(GrpcUpdateCollectionLogOffsetError::FailedToUpdateCollectionLogOffset(e)),
         }
     }
 }
diff --git a/rust/log/src/in_memory_log.rs b/rust/log/src/in_memory_log.rs
index 9c32e2774b59..ea27bac60b9a 100644
--- a/rust/log/src/in_memory_log.rs
+++ b/rust/log/src/in_memory_log.rs
@@ -1,6 +1,4 @@
-use crate::types::{
-    CollectionInfo, GetCollectionsWithNewDataError, PullLogsError, UpdateCollectionLogOffsetError,
-};
+use crate::types::CollectionInfo;
 use chroma_types::{CollectionUuid, LogRecord};
 use std::collections::HashMap;
 use std::fmt::Debug;
@@ -63,7 +61,7 @@ impl InMemoryLog {
         offset: i64,
         batch_size: i32,
         end_timestamp: Option<i64>,
-    ) -> Result<Vec<LogRecord>, PullLogsError> {
+    ) -> Vec<LogRecord> {
         let end_timestamp = match end_timestamp {
             Some(end_timestamp) => end_timestamp,
             None => i64::MAX,
@@ -71,7 +69,7 @@ impl InMemoryLog {
         let logs = match self.collection_to_log.get(&collection_id) {
             Some(logs) => logs,
-            None => return Ok(Vec::new()),
+            None => return Vec::new(),
         };
         let mut result = Vec::new();
         for i in offset..(offset + batch_size as i64) {
@@ -79,13 +77,13 @@ impl InMemoryLog {
                 result.push(logs[i as usize].record.clone());
             }
         }
-        Ok(result)
+        result
     }
 
     pub(super) async fn get_collections_with_new_data(
         &mut self,
         min_compaction_size: u64,
-    ) -> Result<Vec<CollectionInfo>, GetCollectionsWithNewDataError> {
+    ) -> Vec<CollectionInfo> {
         let mut collections = Vec::new();
         for (collection_id, log_records) in self.collection_to_log.iter() {
             if log_records.is_empty() {
@@ -115,16 +113,15 @@ impl InMemoryLog {
                 first_log_ts: logs[0].log_ts,
             });
         }
-        Ok(collections)
+        collections
     }
 
     pub(super) async fn update_collection_log_offset(
         &mut self,
         collection_id: CollectionUuid,
         new_offset: i64,
-    ) -> Result<(), UpdateCollectionLogOffsetError> {
+    ) {
         self.offsets.insert(collection_id, new_offset);
-        Ok(())
     }
 }
diff --git a/rust/log/src/lib.rs b/rust/log/src/lib.rs
index 77cf2621c399..a7ea4621beac 100644
--- a/rust/log/src/lib.rs
+++ b/rust/log/src/lib.rs
@@ -3,6 +3,7 @@ pub mod grpc_log;
 pub mod in_memory_log;
 #[allow(clippy::module_inception)]
 mod log;
+pub mod sqlite_log;
 pub mod test;
 pub mod tracing;
 pub mod types;
diff --git a/rust/log/src/log.rs b/rust/log/src/log.rs
index 6b5cdf04d4fd..e2e6a8a07f64 100644
--- a/rust/log/src/log.rs
+++ b/rust/log/src/log.rs
@@ -1,9 +1,8 @@
 use crate::grpc_log::GrpcLog;
 use crate::in_memory_log::InMemoryLog;
-use crate::types::{
-    CollectionInfo, GetCollectionsWithNewDataError, PullLogsError, UpdateCollectionLogOffsetError,
-};
-use crate::PushLogsError;
+use crate::sqlite_log::SqliteLog;
+use crate::types::CollectionInfo;
+use chroma_error::ChromaError;
 use chroma_types::{CollectionUuid, LogRecord, OperationRecord};
 use std::fmt::Debug;
 
@@ -20,6 +19,7 @@ pub struct CollectionRecord {
 
 #[derive(Clone, Debug)]
 pub enum Log {
+    Sqlite(SqliteLog),
     Grpc(GrpcLog),
     #[allow(dead_code)]
     InMemory(InMemoryLog),
@@ -32,16 +32,19 @@ impl Log {
         offset: i64,
         batch_size: i32,
         end_timestamp: Option<i64>,
-    ) -> Result<Vec<LogRecord>, PullLogsError> {
+    ) -> Result<Vec<LogRecord>, Box<dyn ChromaError>> {
         match self {
-            Log::Grpc(log) => {
-                log.read(collection_id, offset, batch_size, end_timestamp)
-                    .await
-            }
-            Log::InMemory(log) => {
-                log.read(collection_id, offset, batch_size, end_timestamp)
-                    .await
-            }
+            Log::Sqlite(log) => log
+                .read(collection_id, offset, batch_size, end_timestamp)
+                .await
+                .map_err(|e| Box::new(e) as Box<dyn ChromaError>),
+            Log::Grpc(log) => log
+                .read(collection_id, offset, batch_size, end_timestamp)
+                .await
+                .map_err(|e| Box::new(e) as Box<dyn ChromaError>),
+            Log::InMemory(log) => Ok(log
+                .read(collection_id, offset, batch_size, end_timestamp)
+                .await),
         }
     }
 
@@ -49,9 +52,16 @@ impl Log {
         &mut self,
         collection_id: CollectionUuid,
         records: Vec<OperationRecord>,
-    ) -> Result<(), PushLogsError> {
+    ) -> Result<(), Box<dyn ChromaError>> {
         match self {
-            Log::Grpc(log) => log.push_logs(collection_id, records).await,
+            Log::Sqlite(log) => log
+                .push_logs(collection_id, records)
+                .await
+                .map_err(|e| Box::new(e) as Box<dyn ChromaError>),
+            Log::Grpc(log) => log
+                .push_logs(collection_id, records)
+                .await
+                .map_err(|e| Box::new(e) as Box<dyn ChromaError>),
             Log::InMemory(_) => unimplemented!(),
         }
     }
@@ -59,10 +69,17 @@ impl Log {
     pub async fn get_collections_with_new_data(
         &mut self,
         min_compaction_size: u64,
-    ) -> Result<Vec<CollectionInfo>, GetCollectionsWithNewDataError> {
+    ) -> Result<Vec<CollectionInfo>, Box<dyn ChromaError>> {
         match self {
-            Log::Grpc(log) => log.get_collections_with_new_data(min_compaction_size).await,
-            Log::InMemory(log) => log.get_collections_with_new_data(min_compaction_size).await,
+            Log::Sqlite(log) => log
+                .get_collections_with_new_data(min_compaction_size)
+                .await
+                .map_err(|e| Box::new(e) as Box<dyn ChromaError>),
+            Log::Grpc(log) => log
+                .get_collections_with_new_data(min_compaction_size)
+                .await
+                .map_err(|e| Box::new(e) as Box<dyn ChromaError>),
+            Log::InMemory(log) => Ok(log.get_collections_with_new_data(min_compaction_size).await),
         }
     }
 
@@ -70,15 +87,20 @@ impl Log {
         &mut self,
         collection_id: CollectionUuid,
         new_offset: i64,
-    ) -> Result<(), UpdateCollectionLogOffsetError> {
+    ) -> Result<(), Box<dyn ChromaError>> {
         match self {
-            Log::Grpc(log) => {
-                log.update_collection_log_offset(collection_id, new_offset)
-                    .await
-            }
+            Log::Sqlite(log) => log
+                .update_collection_log_offset(collection_id, new_offset)
+                .await
+                .map_err(|e| Box::new(e) as Box<dyn ChromaError>),
+            Log::Grpc(log) => log
+                .update_collection_log_offset(collection_id, new_offset)
+                .await
+                .map_err(|e| Box::new(e) as Box<dyn ChromaError>),
             Log::InMemory(log) => {
                 log.update_collection_log_offset(collection_id, new_offset)
-                    .await
+                    .await;
+                Ok(())
             }
         }
     }
 }
diff --git a/rust/log/src/sqlite_log.rs b/rust/log/src/sqlite_log.rs
new file mode 100644
index 000000000000..19772188ce7b
--- /dev/null
+++ b/rust/log/src/sqlite_log.rs
@@ -0,0 +1,448 @@
+use crate::{CollectionInfo, WrappedSqlxError};
+use chroma_error::{ChromaError, ErrorCodes};
+use chroma_types::{
+    CollectionUuid, LogRecord, Operation, OperationRecord, ScalarEncoding,
+    ScalarEncodingConversionError, UpdateMetadata, UpdateMetadataValue,
+};
+use futures::TryStreamExt;
+use sqlx::sqlite::SqlitePool;
+use sqlx::{QueryBuilder, Row};
+use std::str::FromStr;
+use thiserror::Error;
+
+#[derive(Error, Debug)]
+pub enum SqlitePullLogsError {
+    #[error("Query error: {0}")]
+    QueryError(#[from] WrappedSqlxError),
+    #[error("Failed to parse embedding encoding")]
+    InvalidEncoding(#[from] ScalarEncodingConversionError),
+    #[error("Failed to parse embedding: {0}")]
+    InvalidEmbedding(bytemuck::PodCastError),
+    #[error("Failed to parse metadata: {0}")]
+    InvalidMetadata(#[from] serde_json::Error),
+}
+
+impl ChromaError for SqlitePullLogsError {
+    fn code(&self) -> ErrorCodes {
+        match self {
+            SqlitePullLogsError::QueryError(err) => err.code(),
+            SqlitePullLogsError::InvalidEncoding(_) => ErrorCodes::InvalidArgument,
+            SqlitePullLogsError::InvalidEmbedding(_) => ErrorCodes::InvalidArgument,
+            SqlitePullLogsError::InvalidMetadata(_) => ErrorCodes::InvalidArgument,
+        }
+    }
+}
+
+#[derive(Error, Debug)]
+pub enum SqlitePushLogsError {
+    #[error("Query error: {0}")]
+    QueryError(#[from] WrappedSqlxError),
+    #[error("Failed to serialize metadata: {0}")]
+    InvalidMetadata(#[from] serde_json::Error),
+}
+
+impl ChromaError for SqlitePushLogsError {
+    fn code(&self) -> ErrorCodes {
+        match self {
+            SqlitePushLogsError::QueryError(err) => err.code(),
+            SqlitePushLogsError::InvalidMetadata(_) => ErrorCodes::Internal,
+        }
+    }
+}
+
+#[derive(Error, Debug)]
+pub enum SqliteGetCollectionsWithNewDataError {
+    #[error("Query error: {0}")]
+    QueryError(#[from] WrappedSqlxError),
+}
+
+impl ChromaError for SqliteGetCollectionsWithNewDataError {
+    fn code(&self) -> ErrorCodes {
+        match self {
+            SqliteGetCollectionsWithNewDataError::QueryError(err) => err.code(),
+        }
+    }
+}
+
+#[derive(Error, Debug)]
+pub enum SqliteUpdateCollectionLogOffsetError {
+    #[error("Query error: {0}")]
+    QueryError(#[from] WrappedSqlxError),
+}
+
+impl ChromaError for SqliteUpdateCollectionLogOffsetError {
+    fn code(&self) -> ErrorCodes {
+        match self {
+            SqliteUpdateCollectionLogOffsetError::QueryError(err) => err.code(),
+        }
+    }
+}
+
+#[derive(Clone, Debug)]
+pub struct SqliteLog {
+    pool: SqlitePool,
+    tenant_id: String,
+    topic_namespace: String,
+}
+
+impl SqliteLog {
+    pub(super) async fn read(
+        &mut self,
+        collection_id: CollectionUuid,
+        offset: i64,
+        batch_size: i32,
+        end_timestamp_ns: Option<i64>,
+    ) -> Result<Vec<LogRecord>, SqlitePullLogsError> {
+        let topic = get_topic_name(&self.tenant_id, &self.topic_namespace, collection_id);
+
+        let end_timestamp_ns = end_timestamp_ns.unwrap_or(i64::MAX);
+
+        let mut logs = sqlx::query(
+            r#"
+            SELECT
+                seq_id,
+                id,
+                operation,
+                vector,
+                encoding,
+                metadata
+            FROM embeddings_queue
+            WHERE topic = ?
+            AND seq_id >= ?
+            AND CAST(strftime('%s', created_at) AS INTEGER) <= (? / 1000000000)
+            ORDER BY seq_id ASC
+            LIMIT ?
+            "#,
+        )
+        .bind(topic)
+        .bind(offset)
+        .bind(end_timestamp_ns)
+        .bind(batch_size)
+        .fetch(&self.pool);
+
+        let mut records = Vec::new();
+        while let Some(row) = logs.try_next().await.map_err(WrappedSqlxError)? {
+            let log_offset: i64 = row.get("seq_id");
+            let id: String = row.get("id");
+            let embedding_bytes = row.get::<Option<&[u8]>, _>("vector");
+            let encoding = row
+                .get::<Option<&str>, _>("encoding")
+                .map(ScalarEncoding::try_from)
+                .transpose()?;
+            let metadata_str = row.get::<Option<&str>, _>("metadata");
+
+            // Parse embedding
+            let embedding = embedding_bytes
+                .map(
+                    |embedding_bytes| -> Result<Option<Vec<f32>>, SqlitePullLogsError> {
+                        match encoding {
+                            Some(ScalarEncoding::FLOAT32) => {
+                                let slice: &[f32] = bytemuck::try_cast_slice(embedding_bytes)
+                                    .map_err(SqlitePullLogsError::InvalidEmbedding)?;
+                                Ok(Some(slice.to_vec()))
+                            }
+                            Some(ScalarEncoding::INT32) => {
+                                unimplemented!()
+                            }
+                            None => Ok(None),
+                        }
+                    },
+                )
+                .transpose()?
+                .flatten();
+
+            // Parse metadata
+            let parsed_metadata_and_document: Option<(UpdateMetadata, Option<String>)> =
+                metadata_str
+                    .map(|metadata_str| {
+                        let mut parsed: UpdateMetadata = serde_json::from_str(metadata_str)?;
+
+                        let document = match parsed.remove("chroma:document") {
+                            Some(UpdateMetadataValue::Str(document)) => Some(document),
+                            None => None,
+                            _ => panic!("Document not found in metadata"),
+                        };
+
+                        Ok::<_, SqlitePullLogsError>((parsed, document))
+                    })
+                    .transpose()?;
+            let document = parsed_metadata_and_document
+                .as_ref()
+                .and_then(|(_, document)| document.clone());
+            let metadata = parsed_metadata_and_document.map(|(metadata, _)| metadata);
+
+            let operation = operation_from_code(row.get("operation"));
+
+            records.push(LogRecord {
+                log_offset,
+                record: OperationRecord {
+                    id,
+                    embedding,
+                    encoding,
+                    metadata,
+                    document,
+                    operation,
+                },
+            });
+        }
+
+        Ok(records)
+    }
+
+    pub(super) async fn push_logs(
+        &mut self,
+        collection_id: CollectionUuid,
+        records: Vec<OperationRecord>,
+    ) -> Result<(), SqlitePushLogsError> {
+        let topic = get_topic_name(&self.tenant_id, &self.topic_namespace, collection_id);
+
+        let records_and_serialized_metadatas = records
+            .into_iter()
+            .map(|mut record| {
+                let mut empty_metadata = UpdateMetadata::new();
+
+                let metadata = record.metadata.as_mut().unwrap_or(&mut empty_metadata);
+                if let Some(ref document) = record.document {
+                    metadata.insert(
+                        "chroma:document".to_string(),
+                        UpdateMetadataValue::Str(document.clone()),
+                    );
+                }
+
+                let serialized = serde_json::to_string(&metadata)?;
+                Ok::<_, SqlitePushLogsError>((record, serialized))
+            })
+            .collect::<Result<Vec<_>, SqlitePushLogsError>>()?;
+
+        let mut query_builder = QueryBuilder::new(
+            "INSERT INTO embeddings_queue (topic, id, operation, vector, encoding, metadata) ",
+        );
+        query_builder.push_values(
+            records_and_serialized_metadatas,
+            |mut builder, (record, serialized_metadata)| {
+                builder.push_bind(&topic);
+                builder.push_bind(record.id);
+                builder.push_bind(operation_to_code(record.operation));
+                builder.push_bind::<Option<Vec<u8>>>(
+                    record
+                        .embedding
+                        .map(|e| bytemuck::cast_slice(e.as_slice()).to_vec()),
+                );
+                builder.push_bind(record.encoding.map(String::from));
+                builder.push_bind::<String>(serialized_metadata);
+            },
+        );
+        let query = query_builder.build();
+        query.execute(&self.pool).await.map_err(WrappedSqlxError)?;
+
+        Ok(())
+    }
+
+    pub(super) async fn get_collections_with_new_data(
+        &mut self,
+        min_compaction_size: u64,
+    ) -> Result<Vec<CollectionInfo>, SqliteGetCollectionsWithNewDataError> {
+        let mut results = sqlx::query(
+            r#"
+            SELECT
+                collections.id AS collection_id,
+                MIN(COALESCE(CAST(max_seq_id.seq_id AS INTEGER), 0)) AS first_log_offset,
+                CAST(strftime('%s', MIN(created_at)) AS INTEGER) * 1000000000 AS first_log_ts
+            FROM collections
+            INNER JOIN segments ON segments.collection = collections.id
+            INNER JOIN embeddings_queue ON embeddings_queue.topic = CONCAT('persistent://', ?, '/', ?, '/', collections.id)
+            LEFT JOIN max_seq_id ON max_seq_id.segment_id = segments.id
+            WHERE embeddings_queue.seq_id > COALESCE(CAST(max_seq_id.seq_id AS INTEGER), 0)
+            GROUP BY
+                collections.id
+            HAVING
+                COUNT(*) >= ?
+            ORDER BY first_log_ts ASC
+            "#,
+        )
+        .bind(&self.tenant_id)
+        .bind(&self.topic_namespace)
+        .bind(min_compaction_size as i64) // (SQLite doesn't support u64)
+        .fetch(&self.pool);
+
+        let mut infos = Vec::new();
+        while let Some(row) = results.try_next().await.map_err(WrappedSqlxError)?
{ + infos.push(CollectionInfo { + collection_id: CollectionUuid::from_str(row.get::<&str, _>("collection_id")) + .unwrap(), + first_log_offset: row.get("first_log_offset"), + first_log_ts: row.get("first_log_ts"), + }); + } + + Ok(infos) + } + + pub async fn update_collection_log_offset( + &mut self, + collection_id: CollectionUuid, + new_offset: i64, + ) -> Result<(), SqliteUpdateCollectionLogOffsetError> { + sqlx::query( + r#" + INSERT OR REPLACE INTO max_seq_id (seq_id, segment_id) + SELECT ?, id + FROM segments + WHERE + collection = ? + "#, + ) + .bind(new_offset) + .bind(collection_id.0.to_string()) + .execute(&self.pool) + .await + .map_err(WrappedSqlxError)?; + + Ok(()) + } +} + +fn get_topic_name(tenant: &str, namespace: &str, collection_id: CollectionUuid) -> String { + format!("persistent://{}/{}/{}", tenant, namespace, collection_id) +} + +fn operation_from_code(code: u32) -> Operation { + // chromadb/db/mixins/embeddings_queue.py + match code { + 0 => Operation::Add, + 1 => Operation::Update, + 2 => Operation::Upsert, + 3 => Operation::Delete, + _ => panic!("Invalid operation code"), + } +} + +fn operation_to_code(operation: Operation) -> u32 { + match operation { + Operation::Add => 0, + Operation::Update => 1, + Operation::Upsert => 2, + Operation::Delete => 3, + } +} + +#[cfg(test)] +mod tests { + use std::str::FromStr; + + use super::*; + use chroma_types::CollectionUuid; + use sqlx::sqlite::SqlitePoolOptions; + + #[tokio::test] + async fn test_pull_logs() { + let pool = SqlitePoolOptions::new() + .connect("sqlite:///Users/maxisom/git/chroma/chroma_data/chroma.sqlite3") + .await + .unwrap(); + + let mut log = SqliteLog { + pool, + tenant_id: "default".to_string(), + topic_namespace: "default".to_string(), + }; + + let collection_id = + CollectionUuid::from_str("f54138fd-4f8a-4fb1-afec-830f0684c3fb").unwrap(); + let offset = 0; + let batch_size = 10; + let end_timestamp_ns = None; + + let logs = log + .read(collection_id, offset, batch_size, end_timestamp_ns) + .await + .unwrap(); + + for log in logs { + println!("{:?}", log); + } + } + + #[tokio::test] + async fn test_push_logs() { + let pool = SqlitePoolOptions::new() + .connect("sqlite:///Users/maxisom/git/chroma/chroma_data/chroma_mut.sqlite3") + .await + .unwrap(); + + let mut log = SqliteLog { + pool, + tenant_id: "default".to_string(), + topic_namespace: "default".to_string(), + }; + + let collection_id = + CollectionUuid::from_str("f54138fd-4f8a-4fb1-afec-830f0684c3fb").unwrap(); + let mut metadata = UpdateMetadata::new(); + metadata.insert( + "foo".to_string(), + UpdateMetadataValue::Str("bar".to_string()), + ); + + let record_to_add = OperationRecord { + id: "foo".to_string(), + embedding: Some(vec![1.0, 2.0, 3.0]), + encoding: Some(ScalarEncoding::FLOAT32), + metadata: Some(metadata), + document: Some("bar".to_string()), + operation: Operation::Add, + }; + + log.push_logs(collection_id, vec![record_to_add.clone()]) + .await + .unwrap(); + + let logs = log.read(collection_id, 0, 100, None).await.unwrap(); + let added_log = logs.iter().find(|log| log.record.id == "foo").unwrap(); + + assert_eq!(added_log.record.id, record_to_add.id); + assert_eq!(added_log.record.embedding, record_to_add.embedding); + assert_eq!(added_log.record.encoding, record_to_add.encoding); + assert_eq!(added_log.record.metadata, record_to_add.metadata); + assert_eq!(added_log.record.document, record_to_add.document); + assert_eq!(added_log.record.operation, record_to_add.operation); + } + + #[tokio::test] + async fn test_foo() { + let pool = 
SqlitePoolOptions::new() + .connect("sqlite:///Users/maxisom/git/chroma/chroma_data/chroma.sqlite3") + .await + .unwrap(); + + let mut log = SqliteLog { + pool, + tenant_id: "default".to_string(), + topic_namespace: "default".to_string(), + }; + + println!("{:?}", log.get_collections_with_new_data(0).await.unwrap()); + } + + #[tokio::test] + async fn test_bar() { + let pool = SqlitePoolOptions::new() + .connect("sqlite:///Users/maxisom/git/chroma/chroma_data/chroma_mut.sqlite3") + .await + .unwrap(); + + let mut log = SqliteLog { + pool, + tenant_id: "default".to_string(), + topic_namespace: "default".to_string(), + }; + let collection_id = + CollectionUuid::from_str("f54138fd-4f8a-4fb1-afec-830f0684c3fb").unwrap(); + + log.update_collection_log_offset(collection_id, 10) + .await + .unwrap(); + + // println!("{:?}", log.get_collections_with_new_data(100).await); + } +} diff --git a/rust/log/src/types.rs b/rust/log/src/types.rs index 948c5349b9d1..a2eb9456363c 100644 --- a/rust/log/src/types.rs +++ b/rust/log/src/types.rs @@ -1,5 +1,5 @@ use chroma_error::{ChromaError, ErrorCodes}; -use chroma_types::{CollectionUuid, RecordConversionError}; +use chroma_types::CollectionUuid; use thiserror::Error; /// CollectionInfo is a struct that contains information about a collection for the @@ -15,68 +15,18 @@ pub struct CollectionInfo { pub first_log_ts: i64, } -#[derive(Error, Debug)] -pub enum PullLogsError { - #[error("Failed to fetch")] - FailedToPullLogs(#[from] tonic::Status), - #[error("Failed to convert proto embedding record into EmbeddingRecord")] - ConversionError(#[from] RecordConversionError), -} - -impl ChromaError for PullLogsError { - fn code(&self) -> ErrorCodes { - match self { - PullLogsError::FailedToPullLogs(_) => ErrorCodes::Internal, - PullLogsError::ConversionError(_) => ErrorCodes::Internal, - } - } -} - -#[derive(Error, Debug)] -pub enum PushLogsError { - #[error("Failed to push logs")] - FailedToPushLogs(#[from] tonic::Status), - #[error("Failed to convert records to proto")] - ConversionError(#[from] RecordConversionError), -} - -impl ChromaError for PushLogsError { - fn code(&self) -> ErrorCodes { - match self { - PushLogsError::FailedToPushLogs(_) => ErrorCodes::Internal, - PushLogsError::ConversionError(_) => ErrorCodes::Internal, - } - } -} - -#[derive(Error, Debug)] -pub enum GetCollectionsWithNewDataError { - #[error("Failed to fetch")] - FailedGetCollectionsWithNewData(#[from] tonic::Status), -} - -impl ChromaError for GetCollectionsWithNewDataError { - fn code(&self) -> ErrorCodes { - match self { - GetCollectionsWithNewDataError::FailedGetCollectionsWithNewData(_) => { - ErrorCodes::Internal - } - } - } -} - -#[derive(Error, Debug)] -pub enum UpdateCollectionLogOffsetError { - #[error("Failed to update collection log offset")] - FailedToUpdateCollectionLogOffset(#[from] tonic::Status), -} +/// Implements `ChromaError` for `sqlx::Error`. 
+#[derive(Debug, Error)]
+#[error("Database error: {0}")]
+pub struct WrappedSqlxError(pub sqlx::Error);
 
-impl ChromaError for UpdateCollectionLogOffsetError {
-    fn code(&self) -> ErrorCodes {
-        match self {
-            UpdateCollectionLogOffsetError::FailedToUpdateCollectionLogOffset(_) => {
-                ErrorCodes::Internal
-            }
+impl ChromaError for WrappedSqlxError {
+    fn code(&self) -> chroma_error::ErrorCodes {
+        match self.0 {
+            sqlx::Error::RowNotFound => ErrorCodes::NotFound,
+            sqlx::Error::PoolTimedOut => ErrorCodes::ResourceExhausted,
+            sqlx::Error::PoolClosed => ErrorCodes::Unavailable,
+            _ => ErrorCodes::Internal,
         }
     }
 }
diff --git a/rust/types/src/metadata.rs b/rust/types/src/metadata.rs
index 190feec3804d..34d43efaa837 100644
--- a/rust/types/src/metadata.rs
+++ b/rust/types/src/metadata.rs
@@ -1,5 +1,5 @@
 use chroma_error::{ChromaError, ErrorCodes};
-use serde::{Deserialize, Serialize};
+use serde::{de::Visitor, Deserialize, Deserializer, Serialize, Serializer};
 use std::{
     cmp::Ordering,
     collections::{HashMap, HashSet},
@@ -92,6 +92,91 @@ impl TryFrom<&UpdateMetadataValue> for MetadataValue {
     }
 }
 
+impl Serialize for UpdateMetadataValue {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        match self {
+            UpdateMetadataValue::Bool(b) => serializer.serialize_bool(*b),
+            UpdateMetadataValue::Int(i) => serializer.serialize_i64(*i),
+            UpdateMetadataValue::Float(f) => serializer.serialize_f64(*f),
+            UpdateMetadataValue::Str(s) => serializer.serialize_str(s),
+            UpdateMetadataValue::None => serializer.serialize_unit(),
+        }
+    }
+}
+
+impl<'de> Deserialize<'de> for UpdateMetadataValue {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        struct UpdateMetadataValueVisitor;
+
+        impl Visitor<'_> for UpdateMetadataValueVisitor {
+            type Value = UpdateMetadataValue;
+
+            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
+                formatter.write_str("a bool, an integer, a float, a string, or null")
+            }
+
+            fn visit_bool<E>(self, value: bool) -> Result<Self::Value, E> {
+                Ok(UpdateMetadataValue::Bool(value))
+            }
+
+            fn visit_i64<E>(self, value: i64) -> Result<Self::Value, E> {
+                Ok(UpdateMetadataValue::Int(value))
+            }
+
+            fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
+            where
+                E: serde::de::Error,
+            {
+                // Because Serde may parse some integers as u64,
+                // convert to i64 if it fits, or fail otherwise:
+                if value <= i64::MAX as u64 {
+                    Ok(UpdateMetadataValue::Int(value as i64))
+                } else {
+                    Err(E::invalid_value(
+                        serde::de::Unexpected::Unsigned(value),
+                        &self,
+                    ))
+                }
+            }
+
+            fn visit_f64<E>(self, value: f64) -> Result<Self::Value, E> {
+                Ok(UpdateMetadataValue::Float(value))
+            }
+
+            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
+            where
+                E: serde::de::Error,
+            {
+                Ok(UpdateMetadataValue::Str(value.to_owned()))
+            }
+
+            fn visit_string<E>(self, value: String) -> Result<Self::Value, E>
+            where
+                E: serde::de::Error,
+            {
+                Ok(UpdateMetadataValue::Str(value))
+            }
+
+            fn visit_none<E>(self) -> Result<Self::Value, E> {
+                Ok(UpdateMetadataValue::None)
+            }
+
+            fn visit_unit<E>(self) -> Result<Self::Value, E> {
+                // null in JSON is represented as unit in Serde
+                Ok(UpdateMetadataValue::None)
+            }
+        }
+
+        deserializer.deserialize_any(UpdateMetadataValueVisitor)
+    }
+}
+
 /*
 ===========================================
 MetadataValue
diff --git a/rust/types/src/scalar_encoding.rs b/rust/types/src/scalar_encoding.rs
index 7c1a45f6eee3..49d1a2fea03e 100644
--- a/rust/types/src/scalar_encoding.rs
+++ b/rust/types/src/scalar_encoding.rs
@@ -49,6 +49,27 @@ impl TryFrom for ScalarEncoding {
     }
 }
 
+impl TryFrom<&str> for ScalarEncoding {
+    type Error = ScalarEncodingConversionError;
+
+    fn try_from(encoding: &str) -> Result<Self, Self::Error> {
+        match encoding {
+            "FLOAT32" => Ok(ScalarEncoding::FLOAT32),
+            "INT32" => Ok(ScalarEncoding::INT32),
+            _ => Err(ScalarEncodingConversionError::InvalidEncoding),
+        }
+    }
+}
+
+impl From<ScalarEncoding> for String {
+    fn from(encoding: ScalarEncoding) -> String {
+        match encoding {
+            ScalarEncoding::FLOAT32 => "FLOAT32".to_string(),
+            ScalarEncoding::INT32 => "INT32".to_string(),
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/rust/worker/src/execution/operators/fetch_log.rs b/rust/worker/src/execution/operators/fetch_log.rs
index 1518bbc6c5c0..d984858cb497 100644
--- a/rust/worker/src/execution/operators/fetch_log.rs
+++ b/rust/worker/src/execution/operators/fetch_log.rs
@@ -2,7 +2,7 @@ use std::time::{SystemTime, SystemTimeError, UNIX_EPOCH};
 
 use async_trait::async_trait;
 use chroma_error::{ChromaError, ErrorCodes};
-use chroma_log::{Log, PullLogsError};
+use chroma_log::Log;
 use chroma_system::{Operator, OperatorType};
 use chroma_types::{Chunk, CollectionUuid, LogRecord};
 use thiserror::Error;
@@ -43,7 +43,7 @@ pub type FetchLogOutput = Chunk<LogRecord>;
 #[derive(Error, Debug)]
 pub enum FetchLogError {
     #[error("Error when pulling log: {0}")]
-    PullLog(#[from] PullLogsError),
+    PullLog(#[from] Box<dyn ChromaError>),
     #[error("Error when capturing system time: {0}")]
     SystemTime(#[from] SystemTimeError),
 }
diff --git a/rust/worker/src/execution/operators/register.rs b/rust/worker/src/execution/operators/register.rs
index 94f64dcaf37b..7531cbd36c09 100644
--- a/rust/worker/src/execution/operators/register.rs
+++ b/rust/worker/src/execution/operators/register.rs
@@ -1,6 +1,6 @@
 use async_trait::async_trait;
 use chroma_error::{ChromaError, ErrorCodes};
-use chroma_log::{Log, UpdateCollectionLogOffsetError};
+use chroma_log::Log;
 use chroma_sysdb::FlushCompactionError;
 use chroma_sysdb::SysDb;
 use chroma_system::Operator;
@@ -86,7 +86,7 @@ pub enum RegisterError {
     #[error("Flush compaction error: {0}")]
     FlushCompactionError(#[from] FlushCompactionError),
     #[error("Update log offset error: {0}")]
-    UpdateLogOffsetError(#[from] UpdateCollectionLogOffsetError),
+    UpdateLogOffsetError(#[from] Box<dyn ChromaError>),
 }
 
 impl ChromaError for RegisterError {