Skip to content

Commit

Permalink
feat: rerank by fetching vectors in heap table
Browse files Browse the repository at this point in the history
Signed-off-by: usamoi <[email protected]>
  • Loading branch information
usamoi committed Feb 13, 2025
1 parent b0003f8 commit b637c0f
Show file tree
Hide file tree
Showing 14 changed files with 680 additions and 95 deletions.
14 changes: 11 additions & 3 deletions .github/workflows/check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,10 @@ jobs:
steps:
- name: Set up Environment
run: |
sudo apt-get remove -y '^postgres.*' '^libpq.*'
sudo apt-get purge -y '^postgres.*' '^libpq.*'
if [ "${{ matrix.runner }}" = "ubuntu-24.04" ]; then
sudo apt-get remove -y '^postgres.*' '^libpq.*'
sudo apt-get purge -y '^postgres.*' '^libpq.*'
fi
sudo update-alternatives --install /usr/bin/clang clang $(which clang-18) 255
Expand All @@ -118,7 +120,13 @@ jobs:
sudo systemctl stop postgresql
curl -fsSL https://github.com/tensorchord/pgrx/releases/download/v0.12.9/cargo-pgrx-v0.12.9-$(uname -m)-unknown-linux-musl.tar.gz | tar -xOzf - ./cargo-pgrx | install -m 755 /dev/stdin /usr/local/bin/cargo-pgrx
cargo pgrx init --pg${{ matrix.version }}=$(which pg_config)
if [ "${{ matrix.runner }}" = "ubuntu-24.04" ]; then
cargo pgrx init --pg${{ matrix.version }}=$(which pg_config)
fi
if [ "${{ matrix.runner }}" = "ubuntu-24.04-arm" ]; then
mkdir $HOME/.pgrx
echo "configs.pg${{ matrix.version }} = \"$(which pg_config)\"" > $HOME/.pgrx/config.toml
fi
curl -fsSL https://github.com/risinglightdb/sqllogictest-rs/releases/download/v0.26.4/sqllogictest-bin-v0.26.4-$(uname -m)-unknown-linux-musl.tar.gz | tar -xOzf - ./sqllogictest | install -m 755 /dev/stdin /usr/local/bin/sqllogictest
Expand Down
1 change: 1 addition & 0 deletions crates/algorithm/src/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ pub fn build<O: Operator>(
dims,
height_of_root: structures.len() as u32,
is_residual,
rerank_in_heap: vchordrq_options.rerank_in_table,
vectors_first: vectors.first(),
root_mean: pointer_of_means.last().unwrap()[0],
root_first: pointer_of_firsts.last().unwrap()[0],
Expand Down
7 changes: 6 additions & 1 deletion crates/algorithm/src/insert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ pub fn insert<O: Operator>(index: impl RelationWrite, payload: NonZeroU64, vecto
let meta_tuple = meta_guard.get(1).unwrap().pipe(read_tuple::<MetaTuple>);
let dims = meta_tuple.dims();
let is_residual = meta_tuple.is_residual();
let rerank_in_heap = meta_tuple.rerank_in_heap();
let height_of_root = meta_tuple.height_of_root();
assert_eq!(dims, vector.as_borrowed().dims(), "unmatched dimensions");
let root_mean = meta_tuple.root_mean();
Expand All @@ -31,7 +32,11 @@ pub fn insert<O: Operator>(index: impl RelationWrite, payload: NonZeroU64, vecto
None
};

let mean = vectors::append::<O>(index.clone(), vectors_first, vector.as_borrowed(), payload);
let mean = if !rerank_in_heap {
vectors::append::<O>(index.clone(), vectors_first, vector.as_borrowed(), payload)
} else {
IndexPointer::default()
};

type State<O> = (u32, Option<<O as Operator>::Vector>);
let mut state: State<O> = {
Expand Down
8 changes: 8 additions & 0 deletions crates/algorithm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ mod linked_vec;
mod maintain;
mod pipe;
mod prewarm;
mod rerank;
mod search;
mod select_heap;
mod tape;
Expand All @@ -27,6 +28,7 @@ pub use cache::cache;
pub use insert::insert;
pub use maintain::maintain;
pub use prewarm::prewarm;
pub use rerank::{rerank_heap, rerank_index};
pub use search::search;

use std::ops::{Deref, DerefMut};
Expand Down Expand Up @@ -72,3 +74,9 @@ pub trait RelationWrite: RelationRead {
fn extend(&self, tracking_freespace: bool) -> Self::WriteGuard<'_>;
fn search(&self, freespace: usize) -> Option<Self::WriteGuard<'_>>;
}

/// Identifies which reranking routine should be applied to a set of search
/// candidates: one variant per source the full-precision vectors can come from.
#[derive(Debug, Clone, Copy)]
pub enum RerankMethod {
/// Rerank using the vectors that are stored inside the index.
Index,
/// Rerank using vectors fetched from outside the index.
/// NOTE(review): presumably the PostgreSQL heap table — confirm against the caller.
Heap,
}
71 changes: 71 additions & 0 deletions crates/algorithm/src/rerank.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
use crate::operator::*;
use crate::tuples::*;
use crate::{RelationRead, vectors};
use always_equal::AlwaysEqual;
use distance::Distance;
use std::cmp::Reverse;
use std::collections::BinaryHeap;
use std::num::NonZeroU64;
use vector::VectorOwned;

/// Lazily reranks search candidates using the vectors stored in the index.
///
/// `results` holds one entry per candidate: a distance wrapped in `Reverse`,
/// the pointer to the candidate's vector inside `index`, and its payload.
/// NOTE(review): the incoming distance is presumably a rough lower-bound
/// estimate produced by the search phase — confirm against the caller.
///
/// The returned iterator yields `(distance, payload)` pairs where `distance`
/// is the value computed by `vectors::access_0` for the query `vector`.
/// Candidates for which `vectors::access_0` returns `None` are dropped.
pub fn rerank_index<O: Operator>(
    index: impl RelationRead,
    vector: O::Vector,
    results: Vec<(
        Reverse<Distance>,
        AlwaysEqual<IndexPointer>,
        AlwaysEqual<NonZeroU64>,
    )>,
) -> impl Iterator<Item = (Distance, NonZeroU64)> {
    // Both heaps order by `Reverse<Distance>`, so `peek`/`pop` expose the
    // entry with the smallest distance first.
    let mut candidates = BinaryHeap::from(results);
    let mut reranked = BinaryHeap::<(Reverse<Distance>, _)>::new();
    std::iter::from_fn(move || {
        // Keep refining while the most promising unrefined candidate still
        // ranks ahead of the best refined one (or nothing is refined yet).
        while candidates
            .peek()
            .is_some_and(|head| Some(head.0) > reranked.peek().map(|entry| entry.0))
        {
            let (_, AlwaysEqual(pointer), AlwaysEqual(payload)) = candidates.pop().unwrap();
            let relation = index.clone();
            let accessor = LAccess::new(
                O::Vector::elements_and_metadata(vector.as_borrowed()),
                O::DistanceAccessor::default(),
            );
            let refined = vectors::access_0::<O, _>(relation, pointer, payload, accessor);
            if let Some(distance) = refined {
                reranked.push((Reverse(distance), AlwaysEqual(payload)));
            }
        }
        // An empty `reranked` heap here means every candidate was consumed.
        reranked
            .pop()
            .map(|(Reverse(distance), AlwaysEqual(payload))| (distance, payload))
    })
}

/// Lazily reranks search candidates using vectors supplied by `fetch`.
///
/// `results` holds one entry per candidate: a distance wrapped in `Reverse`,
/// an index pointer (ignored here), and the candidate's payload.
/// NOTE(review): the incoming distance is presumably a rough lower-bound
/// estimate produced by the search phase — confirm against the caller.
///
/// `fetch` maps a payload to the candidate's vector; candidates for which it
/// returns `None` are skipped. The returned iterator yields
/// `(distance, payload)` pairs where `distance` is computed by
/// `O::DistanceAccessor` between the query `vector` and the fetched vector.
pub fn rerank_heap<O: Operator, F>(
    vector: O::Vector,
    results: Vec<(
        Reverse<Distance>,
        AlwaysEqual<IndexPointer>,
        AlwaysEqual<NonZeroU64>,
    )>,
    fetch: F,
) -> impl Iterator<Item = (Distance, NonZeroU64)>
where
    F: Fn(NonZeroU64) -> Option<O::Vector>,
{
    // Both heaps order by `Reverse<Distance>`, so `peek`/`pop` expose the
    // entry with the smallest distance first.
    let mut candidates = BinaryHeap::from(results);
    let mut reranked = BinaryHeap::<(Reverse<Distance>, _)>::new();
    std::iter::from_fn(move || {
        // Borrowed view of the query; it must be rebuilt on every call because
        // it borrows from `vector`, which is owned by this closure.
        let query = O::Vector::elements_and_metadata(vector.as_borrowed());
        // Keep refining while the most promising unrefined candidate still
        // ranks ahead of the best refined one (or nothing is refined yet).
        while candidates
            .peek()
            .is_some_and(|head| Some(head.0) > reranked.peek().map(|entry| entry.0))
        {
            let (_, AlwaysEqual(_), AlwaysEqual(payload)) = candidates.pop().unwrap();
            let Some(fetched) = fetch(payload) else {
                // No vector available for this payload: drop the candidate.
                continue;
            };
            let stored = O::Vector::elements_and_metadata(fetched.as_borrowed());
            let mut accessor = O::DistanceAccessor::default();
            accessor.push(query.0, stored.0);
            let distance = accessor.finish(query.1, stored.1);
            reranked.push((Reverse(distance), AlwaysEqual(payload)));
        }
        // An empty `reranked` heap here means every candidate was consumed.
        reranked
            .pop()
            .map(|(Reverse(distance), AlwaysEqual(payload))| (distance, payload))
    })
}
48 changes: 25 additions & 23 deletions crates/algorithm/src/search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::operator::*;
use crate::pipe::Pipe;
use crate::tape::{access_0, access_1};
use crate::tuples::*;
use crate::{Page, RelationRead, vectors};
use crate::{Page, RelationRead, RerankMethod, vectors};
use always_equal::AlwaysEqual;
use distance::Distance;
use std::cmp::Reverse;
Expand All @@ -16,14 +16,28 @@ pub fn search<O: Operator>(
vector: O::Vector,
probes: Vec<u32>,
epsilon: f32,
) -> impl Iterator<Item = (Distance, NonZeroU64)> {
) -> (
RerankMethod,
Vec<(
Reverse<Distance>,
AlwaysEqual<IndexPointer>,
AlwaysEqual<NonZeroU64>,
)>,
) {
let meta_guard = index.read(0);
let meta_tuple = meta_guard.get(1).unwrap().pipe(read_tuple::<MetaTuple>);
let dims = meta_tuple.dims();
let is_residual = meta_tuple.is_residual();
let rerank_in_heap = meta_tuple.rerank_in_heap();
let height_of_root = meta_tuple.height_of_root();
assert_eq!(dims, vector.as_borrowed().dims(), "unmatched dimensions");
assert_eq!(height_of_root as usize, 1 + probes.len(), "invalid probes");
if height_of_root as usize != 1 + probes.len() {
panic!(
"need {} probes, but {} probes provided",
height_of_root - 1,
probes.len()
);
}
let root_mean = meta_tuple.root_mean();
let root_first = meta_tuple.root_first();
drop(meta_guard);
Expand Down Expand Up @@ -145,24 +159,12 @@ pub fn search<O: Operator>(
},
);
}
let mut heap = BinaryHeap::from(results.into_vec());
let mut cache = BinaryHeap::<(Reverse<Distance>, _)>::new();
std::iter::from_fn(move || {
while !heap.is_empty() && heap.peek().map(|x| x.0) > cache.peek().map(|x| x.0) {
let (_, AlwaysEqual(mean), AlwaysEqual(pay_u)) = heap.pop().unwrap();
if let Some(dis_u) = vectors::access_0::<O, _>(
index.clone(),
mean,
pay_u,
LAccess::new(
O::Vector::elements_and_metadata(vector.as_borrowed()),
O::DistanceAccessor::default(),
),
) {
cache.push((Reverse(dis_u), AlwaysEqual(pay_u)));
};
}
let (Reverse(dis_u), AlwaysEqual(pay_u)) = cache.pop()?;
Some((dis_u, pay_u))
})
(
if rerank_in_heap {
RerankMethod::Heap
} else {
RerankMethod::Index
},
results.into_vec(),
)
}
10 changes: 8 additions & 2 deletions crates/algorithm/src/tuples.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use zerocopy_derive::{FromBytes, Immutable, IntoBytes, KnownLayout};
pub const ALIGN: usize = 8;
pub type Tag = u64;
const MAGIC: u64 = u64::from_ne_bytes(*b"vchordrq");
const VERSION: u64 = 1;
const VERSION: u64 = 2;

pub trait Tuple: 'static {
type Reader<'a>: TupleReader<'a, Tuple = Self>;
Expand Down Expand Up @@ -50,7 +50,8 @@ struct MetaTupleHeader {
dims: u32,
height_of_root: u32,
is_residual: Bool,
_padding_0: [ZeroU8; 3],
rerank_in_heap: Bool,
_padding_0: [ZeroU8; 2],
vectors_first: u32,
// raw vector
root_mean: IndexPointer,
Expand All @@ -63,6 +64,7 @@ pub struct MetaTuple {
pub dims: u32,
pub height_of_root: u32,
pub is_residual: bool,
pub rerank_in_heap: bool,
pub vectors_first: u32,
pub root_mean: IndexPointer,
pub root_first: u32,
Expand All @@ -79,6 +81,7 @@ impl Tuple for MetaTuple {
dims: self.dims,
height_of_root: self.height_of_root,
is_residual: self.is_residual.into(),
rerank_in_heap: self.rerank_in_heap.into(),
_padding_0: Default::default(),
vectors_first: self.vectors_first,
root_mean: self.root_mean,
Expand Down Expand Up @@ -125,6 +128,9 @@ impl MetaTupleReader<'_> {
/// Returns the `is_residual` flag stored in the meta tuple header as a `bool`.
pub fn is_residual(self) -> bool {
self.header.is_residual.into()
}
/// Returns the `rerank_in_heap` flag stored in the meta tuple header as a `bool`.
pub fn rerank_in_heap(self) -> bool {
self.header.rerank_in_heap.into()
}
/// Returns the `vectors_first` value stored in the meta tuple header.
/// NOTE(review): presumably the id of the first page holding raw vectors — confirm.
pub fn vectors_first(self) -> u32 {
self.header.vectors_first
}
Expand Down
5 changes: 5 additions & 0 deletions crates/algorithm/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,17 @@ use vector::vect::{VectBorrowed, VectOwned};
pub struct VchordrqIndexOptions {
#[serde(default = "VchordrqIndexOptions::default_residual_quantization")]
pub residual_quantization: bool,
#[serde(default = "VchordrqIndexOptions::default_rerank_in_table")]
pub rerank_in_table: bool,
}

// Default-value providers referenced by the `#[serde(default = "...")]`
// attributes on the fields of `VchordrqIndexOptions`.
impl VchordrqIndexOptions {
/// Serde default for `residual_quantization`: `false`.
fn default_residual_quantization() -> bool {
false
}
/// Serde default for `rerank_in_table`: `false`.
fn default_rerank_in_table() -> bool {
false
}
}

#[derive(Debug, Clone)]
Expand Down
3 changes: 3 additions & 0 deletions crates/k_means/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@ fn quick_centers(
) -> Vec<Vec<f32>> {
let n = samples.len();
assert!(c >= n);
if c == 1 && n == 0 {
return vec![vec![0.0; dims]];
}
let mut rng = rand::thread_rng();
let mut centroids = samples;
for _ in n..c {
Expand Down
10 changes: 6 additions & 4 deletions src/index/am/am_build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,13 +132,15 @@ pub unsafe extern "C" fn ambuild(
}
VchordrqBuildSourceOptions::Internal(internal_build) => {
let mut tuples_total = 0_u64;
let samples = {
let samples = 'a: {
let mut rand = rand::thread_rng();
let max_number_of_samples = internal_build
let Some(max_number_of_samples) = internal_build
.lists
.last()
.unwrap()
.saturating_mul(internal_build.sampling_factor);
.map(|x| x.saturating_mul(internal_build.sampling_factor))
else {
break 'a Vec::new();
};
let mut samples = Vec::new();
let mut number_of_samples = 0_u32;
match opfamily.vector_kind() {
Expand Down
Loading

0 comments on commit b637c0f

Please sign in to comment.