diff --git a/Cargo.lock b/Cargo.lock index 649c965..70a4d77 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "aho-corasick" @@ -78,6 +78,15 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" +[[package]] +name = "bincode" +version = "1.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad" +dependencies = [ + "serde", +] + [[package]] name = "bumpalo" version = "3.16.0" @@ -430,6 +439,7 @@ dependencies = [ "portable-atomic-util", "rawpointer", "rayon", + "serde", ] [[package]] @@ -824,9 +834,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.133" +version = "1.0.134" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7fceb2473b9166b2294ef05efcb65a3db80803f0b03ef86a5fc88a2b85ee377" +checksum = "d00f4175c42ee48b15416f6193a959ba3a0d67fc699a0db9ad12df9f83991c7d" dependencies = [ "itoa", "memchr", @@ -904,6 +914,7 @@ name = "vector_quantizer" version = "0.0.2" dependencies = [ "anyhow", + "bincode", "criterion", "env_logger", "log", @@ -916,6 +927,7 @@ dependencies = [ "rand_distr", "rayon", "serde", + "serde_json", "thiserror", ] diff --git a/Cargo.toml b/Cargo.toml index 953c568..80b3859 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,7 +12,7 @@ categories = ["algorithms"] [dependencies] anyhow = "1.0.93" -ndarray = { version = "0.16.1", features = ["rayon"] } +ndarray = { version = "0.16.1", features = ["rayon", "serde"] } ndarray-stats = "0.6.0" rand = "0.9.0-alpha.2" ndarray-rand = "0.15.0" @@ -24,6 +24,8 @@ env_logger = "0.11.5" serde = { version = "1.0.215", features = ["derive"] } numpy = "0.23.0" num-traits = "0.2.19" +serde_json = "1.0.134" +bincode = "1.3.3" [dev-dependencies] criterion = { version = "0.5", features = ["html_reports"] } diff --git a/src/errors.rs b/src/errors.rs index ede0fdb..fa9d4f9 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -69,4 +69,13 @@ pub enum PQError { #[error("Unsupported initialization method")] InvalidInitMethod, + + #[error("Unable to create dump file on disk")] + IoError, + + #[error("Unable to serialize trained quantizer to disk")] + SerializationError, + + #[error("Unable to deserialize trained quantizer from disk")] + DeserializationError, } diff --git a/src/pq.rs b/src/pq.rs index de22c21..42e2d32 100644 --- a/src/pq.rs +++ b/src/pq.rs @@ -1,16 +1,21 @@ use crate::errors::PQError; use crate::utils::{determine_code_type, euclidean_distance, kmeans2}; +use bincode; use log::{info, trace, warn}; use ndarray::parallel::prelude::*; use ndarray::{s, Array2, Array3, Axis}; +use serde::{Deserialize, Serialize}; +use std::fs::File; +use std::io::{BufReader, BufWriter}; -#[derive(Clone, Copy, PartialEq, Eq, Debug)] +#[derive(Clone, Copy, PartialEq, Eq, Debug, Serialize, Deserialize)] pub enum CodeType { U8, U16, U32, } +#[derive(Serialize, Deserialize)] pub struct PQ { pub m: usize, pub code_dtype: CodeType, @@ -199,13 +204,28 @@ impl PQ { let codes = self.encode(vecs)?; self.decode(&codes) } + + pub fn dump(&self, path: &str) -> Result<(), PQError> { + let file = File::create(path).map_err(|_| PQError::IoError)?; + let writer = BufWriter::new(file); + bincode::serialize_into(writer, &self).map_err(|_| PQError::SerializationError)?; + Ok(()) + } + + pub fn load(path: &str) -> Result { + let file = File::open(path).map_err(|_| PQError::IoError)?; + let reader = BufReader::new(file); + let pq = bincode::deserialize_from(reader).map_err(|_| PQError::DeserializationError)?; + Ok(pq) + } } #[cfg(test)] mod tests { use super::*; use crate::utils::create_random_vectors; - use ndarray::Array2; + use ndarray::{array, Array2}; + use rand::Rng; fn create_dummy_vectors(num_vectors: usize, dimension: usize) -> Array2 { Array2::::zeros((num_vectors, dimension)) @@ -387,4 +407,99 @@ mod tests { "Encode should succeed with valid u8 code values" ); } + + #[test] + fn test_dump_and_load() { + let mut pq = PQ::try_new(4, 4).unwrap(); + let vecs = array![ + [1.0, 2.0, 3.0, 4.0], + [5.0, 6.0, 7.0, 8.0], + [9.0, 10.0, 11.0, 12.0], + [13.0, 14.0, 15.0, 16.0] + ]; + pq.fit(&vecs, 10).unwrap(); + + let path = "test_pq.json"; + pq.dump(path).unwrap(); + + let loaded_pq = PQ::load(path).unwrap(); + + assert_eq!(pq.m, loaded_pq.m); + assert_eq!(pq.ks, loaded_pq.ks); + assert_eq!(pq.code_dtype, loaded_pq.code_dtype); + assert_eq!(pq.ds, loaded_pq.ds); + assert_eq!(pq.dim, loaded_pq.dim); + + if let (Some(ref original), Some(ref loaded)) = (&pq.codewords, &loaded_pq.codewords) { + assert_eq!( + original.shape(), + loaded.shape(), + "Codewords shapes do not match" + ); + for (orig, load) in original.iter().zip(loaded.iter()) { + assert!( + (orig - load).abs() < 1e-6, + "Codewords do not match: original={} loaded={}", + orig, + load + ); + } + } else { + panic!("Codewords are missing after loading"); + } + + std::fs::remove_file(path).unwrap(); + } + + #[test] + fn test_dump_and_load_large() { + fn create_random_vectors(num_vectors: usize, dimension: usize) -> Array2 { + let mut rng = rand::thread_rng(); + let data: Vec = (0..num_vectors * dimension) + .map(|_| rng.gen_range(0.0..1.0)) + .collect(); + Array2::from_shape_vec((num_vectors, dimension), data).unwrap() + } + + let num_vectors = 100_000; + let dimension = 384; + let subspaces = 8; + let centroids = 256; + + let vecs = create_random_vectors(num_vectors, dimension); + + let mut pq = PQ::try_new(subspaces, centroids).unwrap(); + pq.fit(&vecs, 1).unwrap(); + + let path = "test_pq_large.bin"; + pq.dump(path).unwrap(); + + let loaded_pq = PQ::load(path).unwrap(); + + assert_eq!(pq.m, loaded_pq.m); + assert_eq!(pq.ks, loaded_pq.ks); + assert_eq!(pq.code_dtype, loaded_pq.code_dtype); + assert_eq!(pq.ds, loaded_pq.ds); + assert_eq!(pq.dim, loaded_pq.dim); + + if let (Some(ref original), Some(ref loaded)) = (&pq.codewords, &loaded_pq.codewords) { + assert_eq!( + original.shape(), + loaded.shape(), + "Codewords shapes do not match" + ); + for (orig, load) in original.iter().zip(loaded.iter()) { + assert!( + (orig - load).abs() < 1e-6, + "Codewords do not match: original={} loaded={}", + orig, + load + ); + } + } else { + panic!("Codewords are missing after loading"); + } + + std::fs::remove_file(path).unwrap(); + } }