Skip to content

Commit

Permalink
Merge pull request #9 from oramasearch/feat/dump
Browse files Browse the repository at this point in the history
feat: save and restore trained quantizer
  • Loading branch information
micheleriva authored Dec 26, 2024
2 parents ab0abc5 + 7f27c0d commit 0447325
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 6 deletions.
18 changes: 15 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ categories = ["algorithms"]

[dependencies]
anyhow = "1.0.93"
ndarray = { version = "0.16.1", features = ["rayon"] }
ndarray = { version = "0.16.1", features = ["rayon", "serde"] }
ndarray-stats = "0.6.0"
rand = "0.9.0-alpha.2"
ndarray-rand = "0.15.0"
Expand All @@ -24,6 +24,8 @@ env_logger = "0.11.5"
serde = { version = "1.0.215", features = ["derive"] }
numpy = "0.23.0"
num-traits = "0.2.19"
serde_json = "1.0.134"
bincode = "1.3.3"

[dev-dependencies]
criterion = { version = "0.5", features = ["html_reports"] }
Expand Down
9 changes: 9 additions & 0 deletions src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,13 @@ pub enum PQError {

#[error("Unsupported initialization method")]
InvalidInitMethod,

#[error("Unable to create dump file on disk")]
IoError,

#[error("Unable to serialize trained quantizer to disk")]
SerializationError,

#[error("Unable to deserialize trained quantizer from disk")]
DeserializationError,
}
119 changes: 117 additions & 2 deletions src/pq.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
use crate::errors::PQError;
use crate::utils::{determine_code_type, euclidean_distance, kmeans2};
use bincode;
use log::{info, trace, warn};
use ndarray::parallel::prelude::*;
use ndarray::{s, Array2, Array3, Axis};
use serde::{Deserialize, Serialize};
use std::fs::File;
use std::io::{BufReader, BufWriter};

#[derive(Clone, Copy, PartialEq, Eq, Debug)]
#[derive(Clone, Copy, PartialEq, Eq, Debug, Serialize, Deserialize)]
pub enum CodeType {
U8,
U16,
U32,
}

#[derive(Serialize, Deserialize)]
pub struct PQ {
pub m: usize,
pub code_dtype: CodeType,
Expand Down Expand Up @@ -199,13 +204,28 @@ impl PQ {
let codes = self.encode(vecs)?;
self.decode(&codes)
}

pub fn dump(&self, path: &str) -> Result<(), PQError> {
let file = File::create(path).map_err(|_| PQError::IoError)?;
let writer = BufWriter::new(file);
bincode::serialize_into(writer, &self).map_err(|_| PQError::SerializationError)?;
Ok(())
}

pub fn load(path: &str) -> Result<Self, PQError> {
let file = File::open(path).map_err(|_| PQError::IoError)?;
let reader = BufReader::new(file);
let pq = bincode::deserialize_from(reader).map_err(|_| PQError::DeserializationError)?;
Ok(pq)
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::utils::create_random_vectors;
use ndarray::Array2;
use ndarray::{array, Array2};
use rand::Rng;

fn create_dummy_vectors(num_vectors: usize, dimension: usize) -> Array2<f32> {
Array2::<f32>::zeros((num_vectors, dimension))
Expand Down Expand Up @@ -387,4 +407,99 @@ mod tests {
"Encode should succeed with valid u8 code values"
);
}

#[test]
fn test_dump_and_load() {
let mut pq = PQ::try_new(4, 4).unwrap();
let vecs = array![
[1.0, 2.0, 3.0, 4.0],
[5.0, 6.0, 7.0, 8.0],
[9.0, 10.0, 11.0, 12.0],
[13.0, 14.0, 15.0, 16.0]
];
pq.fit(&vecs, 10).unwrap();

let path = "test_pq.json";
pq.dump(path).unwrap();

let loaded_pq = PQ::load(path).unwrap();

assert_eq!(pq.m, loaded_pq.m);
assert_eq!(pq.ks, loaded_pq.ks);
assert_eq!(pq.code_dtype, loaded_pq.code_dtype);
assert_eq!(pq.ds, loaded_pq.ds);
assert_eq!(pq.dim, loaded_pq.dim);

if let (Some(ref original), Some(ref loaded)) = (&pq.codewords, &loaded_pq.codewords) {
assert_eq!(
original.shape(),
loaded.shape(),
"Codewords shapes do not match"
);
for (orig, load) in original.iter().zip(loaded.iter()) {
assert!(
(orig - load).abs() < 1e-6,
"Codewords do not match: original={} loaded={}",
orig,
load
);
}
} else {
panic!("Codewords are missing after loading");
}

std::fs::remove_file(path).unwrap();
}

#[test]
fn test_dump_and_load_large() {
fn create_random_vectors(num_vectors: usize, dimension: usize) -> Array2<f32> {
let mut rng = rand::thread_rng();
let data: Vec<f32> = (0..num_vectors * dimension)
.map(|_| rng.gen_range(0.0..1.0))
.collect();
Array2::from_shape_vec((num_vectors, dimension), data).unwrap()
}

let num_vectors = 100_000;
let dimension = 384;
let subspaces = 8;
let centroids = 256;

let vecs = create_random_vectors(num_vectors, dimension);

let mut pq = PQ::try_new(subspaces, centroids).unwrap();
pq.fit(&vecs, 1).unwrap();

let path = "test_pq_large.bin";
pq.dump(path).unwrap();

let loaded_pq = PQ::load(path).unwrap();

assert_eq!(pq.m, loaded_pq.m);
assert_eq!(pq.ks, loaded_pq.ks);
assert_eq!(pq.code_dtype, loaded_pq.code_dtype);
assert_eq!(pq.ds, loaded_pq.ds);
assert_eq!(pq.dim, loaded_pq.dim);

if let (Some(ref original), Some(ref loaded)) = (&pq.codewords, &loaded_pq.codewords) {
assert_eq!(
original.shape(),
loaded.shape(),
"Codewords shapes do not match"
);
for (orig, load) in original.iter().zip(loaded.iter()) {
assert!(
(orig - load).abs() < 1e-6,
"Codewords do not match: original={} loaded={}",
orig,
load
);
}
} else {
panic!("Codewords are missing after loading");
}

std::fs::remove_file(path).unwrap();
}
}

0 comments on commit 0447325

Please sign in to comment.