Skip to content

Commit

Permalink
fix: respect aliasing rule by not reading past of reference (#169)
Browse files Browse the repository at this point in the history
Signed-off-by: usamoi <[email protected]>
  • Loading branch information
usamoi authored Jan 20, 2025
1 parent 6af8bdd commit 345e3bf
Show file tree
Hide file tree
Showing 11 changed files with 531 additions and 514 deletions.
4 changes: 2 additions & 2 deletions src/algorithm/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ impl Structure {
let mut parents = BTreeMap::new();
let mut vectors = BTreeMap::new();
pgrx::spi::Spi::connect(|client| {
use crate::datatype::memory_pgvector_vector::PgvectorVectorOutput;
use crate::datatype::memory_vector::VectorOutput;
use pgrx::pg_sys::panic::ErrorReportable;
use vector::VectorBorrowed;
let schema_query = "SELECT n.nspname::TEXT
Expand All @@ -237,7 +237,7 @@ impl Structure {
for row in centroids {
let id: Option<i32> = row.get_by_name("id").unwrap();
let parent: Option<i32> = row.get_by_name("parent").unwrap();
let vector: Option<PgvectorVectorOutput> = row.get_by_name("vector").unwrap();
let vector: Option<VectorOutput> = row.get_by_name("vector").unwrap();
let id = id.expect("external build: id could not be NULL");
let vector = vector.expect("external build: vector could not be NULL");
let pop = parents.insert(id, parent);
Expand Down
8 changes: 4 additions & 4 deletions src/datatype/functions_scalar8.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use crate::datatype::memory_pgvector_halfvec::PgvectorHalfvecInput;
use crate::datatype::memory_pgvector_vector::PgvectorVectorInput;
use crate::datatype::memory_halfvec::HalfvecInput;
use crate::datatype::memory_scalar8::Scalar8Output;
use crate::datatype::memory_vector::VectorInput;
use half::f16;
use simd::Floating;
use vector::scalar8::Scalar8Borrowed;

#[pgrx::pg_extern(sql = "")]
fn _vchord_vector_quantize_to_scalar8(vector: PgvectorVectorInput) -> Scalar8Output {
fn _vchord_vector_quantize_to_scalar8(vector: VectorInput) -> Scalar8Output {
let vector = vector.as_borrowed();
let sum_of_x2 = f32::reduce_sum_of_x2(vector.slice());
let (k, b, code) =
Expand All @@ -16,7 +16,7 @@ fn _vchord_vector_quantize_to_scalar8(vector: PgvectorVectorInput) -> Scalar8Out
}

#[pgrx::pg_extern(sql = "")]
fn _vchord_halfvec_quantize_to_scalar8(vector: PgvectorHalfvecInput) -> Scalar8Output {
fn _vchord_halfvec_quantize_to_scalar8(vector: HalfvecInput) -> Scalar8Output {
let vector = vector.as_borrowed();
let sum_of_x2 = f16::reduce_sum_of_x2(vector.slice());
let (k, b, code) =
Expand Down
210 changes: 210 additions & 0 deletions src/datatype/memory_halfvec.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
use half::f16;
use pgrx::datum::FromDatum;
use pgrx::datum::IntoDatum;
use pgrx::pg_sys::Datum;
use pgrx::pg_sys::Oid;
use pgrx::pgrx_sql_entity_graph::metadata::ArgumentError;
use pgrx::pgrx_sql_entity_graph::metadata::Returns;
use pgrx::pgrx_sql_entity_graph::metadata::ReturnsError;
use pgrx::pgrx_sql_entity_graph::metadata::SqlMapping;
use pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable;
use std::marker::PhantomData;
use std::ptr::NonNull;
use vector::VectorBorrowed;
use vector::vect::VectBorrowed;

#[repr(C, align(8))]
pub struct HalfvecHeader {
varlena: u32,
dims: u16,
unused: u16,
elements: [f16; 0],
}

impl HalfvecHeader {
fn size_of(len: usize) -> usize {
if len > 65535 {
panic!("vector is too large");
}
(size_of::<Self>() + size_of::<f16>() * len).next_multiple_of(8)
}
pub unsafe fn as_borrowed<'a>(this: NonNull<Self>) -> VectBorrowed<'a, f16> {
unsafe {
let this = this.as_ptr();
VectBorrowed::new_unchecked(std::slice::from_raw_parts(
(&raw const (*this).elements).cast(),
(&raw const (*this).dims).read() as usize,
))
}
}
}

pub struct HalfvecInput<'a>(NonNull<HalfvecHeader>, PhantomData<&'a ()>, bool);

impl HalfvecInput<'_> {
unsafe fn from_ptr(p: NonNull<HalfvecHeader>) -> Self {
let q = unsafe {
NonNull::new(pgrx::pg_sys::pg_detoast_datum(p.as_ptr().cast()).cast()).unwrap()
};
HalfvecInput(q, PhantomData, p != q)
}
pub fn as_borrowed(&self) -> VectBorrowed<'_, f16> {
unsafe { HalfvecHeader::as_borrowed(self.0) }
}
}

impl Drop for HalfvecInput<'_> {
fn drop(&mut self) {
if self.2 {
unsafe {
pgrx::pg_sys::pfree(self.0.as_ptr().cast());
}
}
}
}

pub struct HalfvecOutput(NonNull<HalfvecHeader>);

impl HalfvecOutput {
unsafe fn from_ptr(p: NonNull<HalfvecHeader>) -> Self {
let q = unsafe {
NonNull::new(pgrx::pg_sys::pg_detoast_datum_copy(p.as_ptr().cast()).cast()).unwrap()
};
Self(q)
}
#[allow(dead_code)]
pub fn new(vector: VectBorrowed<'_, f16>) -> Self {
unsafe {
let slice = vector.slice();
let size = HalfvecHeader::size_of(slice.len());

let ptr = pgrx::pg_sys::palloc0(size) as *mut HalfvecHeader;
(&raw mut (*ptr).varlena).write((size << 2) as u32);
(&raw mut (*ptr).dims).write(vector.dims() as _);
(&raw mut (*ptr).unused).write(0);
std::ptr::copy_nonoverlapping(
slice.as_ptr(),
(&raw mut (*ptr).elements).cast(),
slice.len(),
);
Self(NonNull::new(ptr).unwrap())
}
}
pub fn as_borrowed(&self) -> VectBorrowed<'_, f16> {
unsafe { HalfvecHeader::as_borrowed(self.0) }
}
pub fn into_raw(self) -> *mut HalfvecHeader {
let result = self.0.as_ptr();
std::mem::forget(self);
result
}
}

impl Drop for HalfvecOutput {
fn drop(&mut self) {
unsafe {
pgrx::pg_sys::pfree(self.0.as_ptr().cast());
}
}
}

// FromDatum

impl FromDatum for HalfvecInput<'_> {
unsafe fn from_polymorphic_datum(datum: Datum, is_null: bool, _typoid: Oid) -> Option<Self> {
if is_null {
None
} else {
let ptr = NonNull::new(datum.cast_mut_ptr()).unwrap();
unsafe { Some(Self::from_ptr(ptr)) }
}
}
}

impl FromDatum for HalfvecOutput {
unsafe fn from_polymorphic_datum(datum: Datum, is_null: bool, _typoid: Oid) -> Option<Self> {
if is_null {
None
} else {
let ptr = NonNull::new(datum.cast_mut_ptr()).unwrap();
unsafe { Some(Self::from_ptr(ptr)) }
}
}
}

// IntoDatum

impl IntoDatum for HalfvecOutput {
fn into_datum(self) -> Option<Datum> {
Some(Datum::from(self.into_raw()))
}

fn type_oid() -> Oid {
Oid::INVALID
}

fn is_compatible_with(_: Oid) -> bool {
true
}
}

// UnboxDatum

unsafe impl pgrx::datum::UnboxDatum for HalfvecOutput {
type As<'src> = HalfvecOutput;
#[inline]
unsafe fn unbox<'src>(datum: pgrx::datum::Datum<'src>) -> Self::As<'src>
where
Self: 'src,
{
let datum = datum.sans_lifetime();
let ptr = NonNull::new(datum.cast_mut_ptr()).unwrap();
unsafe { Self::from_ptr(ptr) }
}
}

// SqlTranslatable

unsafe impl SqlTranslatable for HalfvecInput<'_> {
fn argument_sql() -> Result<SqlMapping, ArgumentError> {
Ok(SqlMapping::As(String::from("halfvec")))
}
fn return_sql() -> Result<Returns, ReturnsError> {
Ok(Returns::One(SqlMapping::As(String::from("halfvec"))))
}
}

unsafe impl SqlTranslatable for HalfvecOutput {
fn argument_sql() -> Result<SqlMapping, ArgumentError> {
Ok(SqlMapping::As(String::from("halfvec")))
}
fn return_sql() -> Result<Returns, ReturnsError> {
Ok(Returns::One(SqlMapping::As(String::from("halfvec"))))
}
}

// ArgAbi

unsafe impl<'fcx> pgrx::callconv::ArgAbi<'fcx> for HalfvecInput<'fcx> {
unsafe fn unbox_arg_unchecked(arg: pgrx::callconv::Arg<'_, 'fcx>) -> Self {
let index = arg.index();
unsafe {
arg.unbox_arg_using_from_datum()
.unwrap_or_else(|| panic!("argument {index} must not be null"))
}
}
}

// BoxRet

unsafe impl pgrx::callconv::BoxRet for HalfvecOutput {
unsafe fn box_into<'fcx>(
self,
fcinfo: &mut pgrx::callconv::FcInfo<'fcx>,
) -> pgrx::datum::Datum<'fcx> {
match self.into_datum() {
Some(datum) => unsafe { fcinfo.return_raw_datum(datum) },
None => fcinfo.return_null(),
}
}
}
Loading

0 comments on commit 345e3bf

Please sign in to comment.