diff --git a/src/algorithm/build.rs b/src/algorithm/build.rs index a0810a7..4c893b7 100644 --- a/src/algorithm/build.rs +++ b/src/algorithm/build.rs @@ -217,7 +217,7 @@ impl Structure { let mut parents = BTreeMap::new(); let mut vectors = BTreeMap::new(); pgrx::spi::Spi::connect(|client| { - use crate::datatype::memory_pgvector_vector::PgvectorVectorOutput; + use crate::datatype::memory_vector::VectorOutput; use pgrx::pg_sys::panic::ErrorReportable; use vector::VectorBorrowed; let schema_query = "SELECT n.nspname::TEXT @@ -237,7 +237,7 @@ impl Structure { for row in centroids { let id: Option = row.get_by_name("id").unwrap(); let parent: Option = row.get_by_name("parent").unwrap(); - let vector: Option = row.get_by_name("vector").unwrap(); + let vector: Option = row.get_by_name("vector").unwrap(); let id = id.expect("external build: id could not be NULL"); let vector = vector.expect("external build: vector could not be NULL"); let pop = parents.insert(id, parent); diff --git a/src/datatype/functions_scalar8.rs b/src/datatype/functions_scalar8.rs index adcdb30..bc73bd8 100644 --- a/src/datatype/functions_scalar8.rs +++ b/src/datatype/functions_scalar8.rs @@ -1,12 +1,12 @@ -use crate::datatype::memory_pgvector_halfvec::PgvectorHalfvecInput; -use crate::datatype::memory_pgvector_vector::PgvectorVectorInput; +use crate::datatype::memory_halfvec::HalfvecInput; use crate::datatype::memory_scalar8::Scalar8Output; +use crate::datatype::memory_vector::VectorInput; use half::f16; use simd::Floating; use vector::scalar8::Scalar8Borrowed; #[pgrx::pg_extern(sql = "")] -fn _vchord_vector_quantize_to_scalar8(vector: PgvectorVectorInput) -> Scalar8Output { +fn _vchord_vector_quantize_to_scalar8(vector: VectorInput) -> Scalar8Output { let vector = vector.as_borrowed(); let sum_of_x2 = f32::reduce_sum_of_x2(vector.slice()); let (k, b, code) = @@ -16,7 +16,7 @@ fn _vchord_vector_quantize_to_scalar8(vector: PgvectorVectorInput) -> Scalar8Out } #[pgrx::pg_extern(sql = "")] -fn 
_vchord_halfvec_quantize_to_scalar8(vector: PgvectorHalfvecInput) -> Scalar8Output { +fn _vchord_halfvec_quantize_to_scalar8(vector: HalfvecInput) -> Scalar8Output { let vector = vector.as_borrowed(); let sum_of_x2 = f16::reduce_sum_of_x2(vector.slice()); let (k, b, code) = diff --git a/src/datatype/memory_halfvec.rs b/src/datatype/memory_halfvec.rs new file mode 100644 index 0000000..b60f6c5 --- /dev/null +++ b/src/datatype/memory_halfvec.rs @@ -0,0 +1,210 @@ +use half::f16; +use pgrx::datum::FromDatum; +use pgrx::datum::IntoDatum; +use pgrx::pg_sys::Datum; +use pgrx::pg_sys::Oid; +use pgrx::pgrx_sql_entity_graph::metadata::ArgumentError; +use pgrx::pgrx_sql_entity_graph::metadata::Returns; +use pgrx::pgrx_sql_entity_graph::metadata::ReturnsError; +use pgrx::pgrx_sql_entity_graph::metadata::SqlMapping; +use pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable; +use std::marker::PhantomData; +use std::ptr::NonNull; +use vector::VectorBorrowed; +use vector::vect::VectBorrowed; + +#[repr(C, align(8))] +pub struct HalfvecHeader { + varlena: u32, + dims: u16, + unused: u16, + elements: [f16; 0], +} + +impl HalfvecHeader { + fn size_of(len: usize) -> usize { + if len > 65535 { + panic!("vector is too large"); + } + (size_of::() + size_of::() * len).next_multiple_of(8) + } + pub unsafe fn as_borrowed<'a>(this: NonNull) -> VectBorrowed<'a, f16> { + unsafe { + let this = this.as_ptr(); + VectBorrowed::new_unchecked(std::slice::from_raw_parts( + (&raw const (*this).elements).cast(), + (&raw const (*this).dims).read() as usize, + )) + } + } +} + +pub struct HalfvecInput<'a>(NonNull, PhantomData<&'a ()>, bool); + +impl HalfvecInput<'_> { + unsafe fn from_ptr(p: NonNull) -> Self { + let q = unsafe { + NonNull::new(pgrx::pg_sys::pg_detoast_datum(p.as_ptr().cast()).cast()).unwrap() + }; + HalfvecInput(q, PhantomData, p != q) + } + pub fn as_borrowed(&self) -> VectBorrowed<'_, f16> { + unsafe { HalfvecHeader::as_borrowed(self.0) } + } +} + +impl Drop for HalfvecInput<'_> { + fn 
drop(&mut self) { + if self.2 { + unsafe { + pgrx::pg_sys::pfree(self.0.as_ptr().cast()); + } + } + } +} + +pub struct HalfvecOutput(NonNull); + +impl HalfvecOutput { + unsafe fn from_ptr(p: NonNull) -> Self { + let q = unsafe { + NonNull::new(pgrx::pg_sys::pg_detoast_datum_copy(p.as_ptr().cast()).cast()).unwrap() + }; + Self(q) + } + #[allow(dead_code)] + pub fn new(vector: VectBorrowed<'_, f16>) -> Self { + unsafe { + let slice = vector.slice(); + let size = HalfvecHeader::size_of(slice.len()); + + let ptr = pgrx::pg_sys::palloc0(size) as *mut HalfvecHeader; + (&raw mut (*ptr).varlena).write((size << 2) as u32); + (&raw mut (*ptr).dims).write(vector.dims() as _); + (&raw mut (*ptr).unused).write(0); + std::ptr::copy_nonoverlapping( + slice.as_ptr(), + (&raw mut (*ptr).elements).cast(), + slice.len(), + ); + Self(NonNull::new(ptr).unwrap()) + } + } + pub fn as_borrowed(&self) -> VectBorrowed<'_, f16> { + unsafe { HalfvecHeader::as_borrowed(self.0) } + } + pub fn into_raw(self) -> *mut HalfvecHeader { + let result = self.0.as_ptr(); + std::mem::forget(self); + result + } +} + +impl Drop for HalfvecOutput { + fn drop(&mut self) { + unsafe { + pgrx::pg_sys::pfree(self.0.as_ptr().cast()); + } + } +} + +// FromDatum + +impl FromDatum for HalfvecInput<'_> { + unsafe fn from_polymorphic_datum(datum: Datum, is_null: bool, _typoid: Oid) -> Option { + if is_null { + None + } else { + let ptr = NonNull::new(datum.cast_mut_ptr()).unwrap(); + unsafe { Some(Self::from_ptr(ptr)) } + } + } +} + +impl FromDatum for HalfvecOutput { + unsafe fn from_polymorphic_datum(datum: Datum, is_null: bool, _typoid: Oid) -> Option { + if is_null { + None + } else { + let ptr = NonNull::new(datum.cast_mut_ptr()).unwrap(); + unsafe { Some(Self::from_ptr(ptr)) } + } + } +} + +// IntoDatum + +impl IntoDatum for HalfvecOutput { + fn into_datum(self) -> Option { + Some(Datum::from(self.into_raw())) + } + + fn type_oid() -> Oid { + Oid::INVALID + } + + fn is_compatible_with(_: Oid) -> bool { + true + 
} +} + +// UnboxDatum + +unsafe impl pgrx::datum::UnboxDatum for HalfvecOutput { + type As<'src> = HalfvecOutput; + #[inline] + unsafe fn unbox<'src>(datum: pgrx::datum::Datum<'src>) -> Self::As<'src> + where + Self: 'src, + { + let datum = datum.sans_lifetime(); + let ptr = NonNull::new(datum.cast_mut_ptr()).unwrap(); + unsafe { Self::from_ptr(ptr) } + } +} + +// SqlTranslatable + +unsafe impl SqlTranslatable for HalfvecInput<'_> { + fn argument_sql() -> Result { + Ok(SqlMapping::As(String::from("halfvec"))) + } + fn return_sql() -> Result { + Ok(Returns::One(SqlMapping::As(String::from("halfvec")))) + } +} + +unsafe impl SqlTranslatable for HalfvecOutput { + fn argument_sql() -> Result { + Ok(SqlMapping::As(String::from("halfvec"))) + } + fn return_sql() -> Result { + Ok(Returns::One(SqlMapping::As(String::from("halfvec")))) + } +} + +// ArgAbi + +unsafe impl<'fcx> pgrx::callconv::ArgAbi<'fcx> for HalfvecInput<'fcx> { + unsafe fn unbox_arg_unchecked(arg: pgrx::callconv::Arg<'_, 'fcx>) -> Self { + let index = arg.index(); + unsafe { + arg.unbox_arg_using_from_datum() + .unwrap_or_else(|| panic!("argument {index} must not be null")) + } + } +} + +// BoxRet + +unsafe impl pgrx::callconv::BoxRet for HalfvecOutput { + unsafe fn box_into<'fcx>( + self, + fcinfo: &mut pgrx::callconv::FcInfo<'fcx>, + ) -> pgrx::datum::Datum<'fcx> { + match self.into_datum() { + Some(datum) => unsafe { fcinfo.return_raw_datum(datum) }, + None => fcinfo.return_null(), + } + } +} diff --git a/src/datatype/memory_pgvector_halfvec.rs b/src/datatype/memory_pgvector_halfvec.rs deleted file mode 100644 index 1c065d9..0000000 --- a/src/datatype/memory_pgvector_halfvec.rs +++ /dev/null @@ -1,205 +0,0 @@ -use half::f16; -use pgrx::datum::FromDatum; -use pgrx::datum::IntoDatum; -use pgrx::pg_sys::Datum; -use pgrx::pg_sys::Oid; -use pgrx::pgrx_sql_entity_graph::metadata::ArgumentError; -use pgrx::pgrx_sql_entity_graph::metadata::Returns; -use pgrx::pgrx_sql_entity_graph::metadata::ReturnsError; -use 
pgrx::pgrx_sql_entity_graph::metadata::SqlMapping; -use pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable; -use std::ops::Deref; -use std::ptr::NonNull; -use vector::VectorBorrowed; -use vector::vect::VectBorrowed; - -#[repr(C, align(8))] -pub struct PgvectorHalfvecHeader { - varlena: u32, - dims: u16, - unused: u16, - phantom: [f16; 0], -} - -impl PgvectorHalfvecHeader { - fn size_of(len: usize) -> usize { - if len > 65535 { - panic!("vector is too large"); - } - (size_of::() + size_of::() * len).next_multiple_of(8) - } - pub fn as_borrowed(&self) -> VectBorrowed<'_, f16> { - unsafe { - VectBorrowed::new_unchecked(std::slice::from_raw_parts( - self.phantom.as_ptr(), - self.dims as usize, - )) - } - } -} - -pub enum PgvectorHalfvecInput<'a> { - Owned(PgvectorHalfvecOutput), - Borrowed(&'a PgvectorHalfvecHeader), -} - -impl PgvectorHalfvecInput<'_> { - unsafe fn new(p: NonNull) -> Self { - let q = unsafe { - NonNull::new(pgrx::pg_sys::pg_detoast_datum(p.cast().as_ptr()).cast()).unwrap() - }; - if p != q { - PgvectorHalfvecInput::Owned(PgvectorHalfvecOutput(q)) - } else { - unsafe { PgvectorHalfvecInput::Borrowed(p.as_ref()) } - } - } -} - -impl Deref for PgvectorHalfvecInput<'_> { - type Target = PgvectorHalfvecHeader; - - fn deref(&self) -> &Self::Target { - match self { - PgvectorHalfvecInput::Owned(x) => x, - PgvectorHalfvecInput::Borrowed(x) => x, - } - } -} - -pub struct PgvectorHalfvecOutput(NonNull); - -impl PgvectorHalfvecOutput { - pub fn new(vector: VectBorrowed<'_, f16>) -> PgvectorHalfvecOutput { - unsafe { - let slice = vector.slice(); - let size = PgvectorHalfvecHeader::size_of(slice.len()); - - let ptr = pgrx::pg_sys::palloc0(size) as *mut PgvectorHalfvecHeader; - (&raw mut (*ptr).varlena).write((size << 2) as u32); - (&raw mut (*ptr).dims).write(vector.dims() as _); - (&raw mut (*ptr).unused).write(0); - std::ptr::copy_nonoverlapping(slice.as_ptr(), (*ptr).phantom.as_mut_ptr(), slice.len()); - PgvectorHalfvecOutput(NonNull::new(ptr).unwrap()) - 
} - } - pub fn into_raw(self) -> *mut PgvectorHalfvecHeader { - let result = self.0.as_ptr(); - std::mem::forget(self); - result - } -} - -impl Deref for PgvectorHalfvecOutput { - type Target = PgvectorHalfvecHeader; - - fn deref(&self) -> &Self::Target { - unsafe { self.0.as_ref() } - } -} - -impl Drop for PgvectorHalfvecOutput { - fn drop(&mut self) { - unsafe { - pgrx::pg_sys::pfree(self.0.as_ptr() as _); - } - } -} - -impl FromDatum for PgvectorHalfvecInput<'_> { - unsafe fn from_polymorphic_datum(datum: Datum, is_null: bool, _typoid: Oid) -> Option { - if is_null { - None - } else { - let ptr = NonNull::new(datum.cast_mut_ptr::()).unwrap(); - unsafe { Some(PgvectorHalfvecInput::new(ptr)) } - } - } -} - -impl IntoDatum for PgvectorHalfvecOutput { - fn into_datum(self) -> Option { - Some(Datum::from(self.into_raw() as *mut ())) - } - - fn type_oid() -> Oid { - Oid::INVALID - } - - fn is_compatible_with(_: Oid) -> bool { - true - } -} - -impl FromDatum for PgvectorHalfvecOutput { - unsafe fn from_polymorphic_datum(datum: Datum, is_null: bool, _typoid: Oid) -> Option { - if is_null { - None - } else { - let p = NonNull::new(datum.cast_mut_ptr::())?; - let q = - unsafe { NonNull::new(pgrx::pg_sys::pg_detoast_datum(p.cast().as_ptr()).cast())? 
}; - if p != q { - Some(PgvectorHalfvecOutput(q)) - } else { - let header = p.as_ptr(); - let vector = unsafe { (*header).as_borrowed() }; - Some(PgvectorHalfvecOutput::new(vector)) - } - } - } -} - -unsafe impl pgrx::datum::UnboxDatum for PgvectorHalfvecOutput { - type As<'src> = PgvectorHalfvecOutput; - #[inline] - unsafe fn unbox<'src>(d: pgrx::datum::Datum<'src>) -> Self::As<'src> - where - Self: 'src, - { - let p = NonNull::new(d.sans_lifetime().cast_mut_ptr::()).unwrap(); - let q = unsafe { - NonNull::new(pgrx::pg_sys::pg_detoast_datum(p.cast().as_ptr()).cast()).unwrap() - }; - if p != q { - PgvectorHalfvecOutput(q) - } else { - let header = p.as_ptr(); - let vector = unsafe { (*header).as_borrowed() }; - PgvectorHalfvecOutput::new(vector) - } - } -} - -unsafe impl SqlTranslatable for PgvectorHalfvecInput<'_> { - fn argument_sql() -> Result { - Ok(SqlMapping::As(String::from("halfvec"))) - } - fn return_sql() -> Result { - Ok(Returns::One(SqlMapping::As(String::from("halfvec")))) - } -} - -unsafe impl SqlTranslatable for PgvectorHalfvecOutput { - fn argument_sql() -> Result { - Ok(SqlMapping::As(String::from("halfvec"))) - } - fn return_sql() -> Result { - Ok(Returns::One(SqlMapping::As(String::from("halfvec")))) - } -} - -unsafe impl<'fcx> pgrx::callconv::ArgAbi<'fcx> for PgvectorHalfvecInput<'fcx> { - unsafe fn unbox_arg_unchecked(arg: pgrx::callconv::Arg<'_, 'fcx>) -> Self { - unsafe { arg.unbox_arg_using_from_datum().unwrap() } - } -} - -unsafe impl pgrx::callconv::BoxRet for PgvectorHalfvecOutput { - unsafe fn box_into<'fcx>( - self, - fcinfo: &mut pgrx::callconv::FcInfo<'fcx>, - ) -> pgrx::datum::Datum<'fcx> { - unsafe { fcinfo.return_raw_datum(Datum::from(self.into_raw() as *mut ())) } - } -} diff --git a/src/datatype/memory_pgvector_vector.rs b/src/datatype/memory_pgvector_vector.rs deleted file mode 100644 index e3ab9f9..0000000 --- a/src/datatype/memory_pgvector_vector.rs +++ /dev/null @@ -1,204 +0,0 @@ -use pgrx::datum::FromDatum; -use 
pgrx::datum::IntoDatum; -use pgrx::pg_sys::Datum; -use pgrx::pg_sys::Oid; -use pgrx::pgrx_sql_entity_graph::metadata::ArgumentError; -use pgrx::pgrx_sql_entity_graph::metadata::Returns; -use pgrx::pgrx_sql_entity_graph::metadata::ReturnsError; -use pgrx::pgrx_sql_entity_graph::metadata::SqlMapping; -use pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable; -use std::ops::Deref; -use std::ptr::NonNull; -use vector::VectorBorrowed; -use vector::vect::VectBorrowed; - -#[repr(C, align(8))] -pub struct PgvectorVectorHeader { - varlena: u32, - dims: u16, - unused: u16, - phantom: [f32; 0], -} - -impl PgvectorVectorHeader { - fn size_of(len: usize) -> usize { - if len > 65535 { - panic!("vector is too large"); - } - (size_of::() + size_of::() * len).next_multiple_of(8) - } - pub fn as_borrowed(&self) -> VectBorrowed<'_, f32> { - unsafe { - VectBorrowed::new_unchecked(std::slice::from_raw_parts( - self.phantom.as_ptr(), - self.dims as usize, - )) - } - } -} - -pub enum PgvectorVectorInput<'a> { - Owned(PgvectorVectorOutput), - Borrowed(&'a PgvectorVectorHeader), -} - -impl PgvectorVectorInput<'_> { - unsafe fn new(p: NonNull) -> Self { - let q = unsafe { - NonNull::new(pgrx::pg_sys::pg_detoast_datum(p.cast().as_ptr()).cast()).unwrap() - }; - if p != q { - PgvectorVectorInput::Owned(PgvectorVectorOutput(q)) - } else { - unsafe { PgvectorVectorInput::Borrowed(p.as_ref()) } - } - } -} - -impl Deref for PgvectorVectorInput<'_> { - type Target = PgvectorVectorHeader; - - fn deref(&self) -> &Self::Target { - match self { - PgvectorVectorInput::Owned(x) => x, - PgvectorVectorInput::Borrowed(x) => x, - } - } -} - -pub struct PgvectorVectorOutput(NonNull); - -impl PgvectorVectorOutput { - pub fn new(vector: VectBorrowed<'_, f32>) -> PgvectorVectorOutput { - unsafe { - let slice = vector.slice(); - let size = PgvectorVectorHeader::size_of(slice.len()); - - let ptr = pgrx::pg_sys::palloc0(size) as *mut PgvectorVectorHeader; - (&raw mut (*ptr).varlena).write((size << 2) as u32); - 
(&raw mut (*ptr).dims).write(vector.dims() as _); - (&raw mut (*ptr).unused).write(0); - std::ptr::copy_nonoverlapping(slice.as_ptr(), (*ptr).phantom.as_mut_ptr(), slice.len()); - PgvectorVectorOutput(NonNull::new(ptr).unwrap()) - } - } - pub fn into_raw(self) -> *mut PgvectorVectorHeader { - let result = self.0.as_ptr(); - std::mem::forget(self); - result - } -} - -impl Deref for PgvectorVectorOutput { - type Target = PgvectorVectorHeader; - - fn deref(&self) -> &Self::Target { - unsafe { self.0.as_ref() } - } -} - -impl Drop for PgvectorVectorOutput { - fn drop(&mut self) { - unsafe { - pgrx::pg_sys::pfree(self.0.as_ptr() as _); - } - } -} - -impl FromDatum for PgvectorVectorInput<'_> { - unsafe fn from_polymorphic_datum(datum: Datum, is_null: bool, _typoid: Oid) -> Option { - if is_null { - None - } else { - let ptr = NonNull::new(datum.cast_mut_ptr::()).unwrap(); - unsafe { Some(PgvectorVectorInput::new(ptr)) } - } - } -} - -impl IntoDatum for PgvectorVectorOutput { - fn into_datum(self) -> Option { - Some(Datum::from(self.into_raw() as *mut ())) - } - - fn type_oid() -> Oid { - Oid::INVALID - } - - fn is_compatible_with(_: Oid) -> bool { - true - } -} - -impl FromDatum for PgvectorVectorOutput { - unsafe fn from_polymorphic_datum(datum: Datum, is_null: bool, _typoid: Oid) -> Option { - if is_null { - None - } else { - let p = NonNull::new(datum.cast_mut_ptr::())?; - let q = - unsafe { NonNull::new(pgrx::pg_sys::pg_detoast_datum(p.cast().as_ptr()).cast())? 
}; - if p != q { - Some(PgvectorVectorOutput(q)) - } else { - let header = p.as_ptr(); - let vector = unsafe { (*header).as_borrowed() }; - Some(PgvectorVectorOutput::new(vector)) - } - } - } -} - -unsafe impl pgrx::datum::UnboxDatum for PgvectorVectorOutput { - type As<'src> = PgvectorVectorOutput; - #[inline] - unsafe fn unbox<'src>(d: pgrx::datum::Datum<'src>) -> Self::As<'src> - where - Self: 'src, - { - let p = NonNull::new(d.sans_lifetime().cast_mut_ptr::()).unwrap(); - let q = unsafe { - NonNull::new(pgrx::pg_sys::pg_detoast_datum(p.cast().as_ptr()).cast()).unwrap() - }; - if p != q { - PgvectorVectorOutput(q) - } else { - let header = p.as_ptr(); - let vector = unsafe { (*header).as_borrowed() }; - PgvectorVectorOutput::new(vector) - } - } -} - -unsafe impl SqlTranslatable for PgvectorVectorInput<'_> { - fn argument_sql() -> Result { - Ok(SqlMapping::As(String::from("vector"))) - } - fn return_sql() -> Result { - Ok(Returns::One(SqlMapping::As(String::from("vector")))) - } -} - -unsafe impl SqlTranslatable for PgvectorVectorOutput { - fn argument_sql() -> Result { - Ok(SqlMapping::As(String::from("vector"))) - } - fn return_sql() -> Result { - Ok(Returns::One(SqlMapping::As(String::from("vector")))) - } -} - -unsafe impl<'fcx> pgrx::callconv::ArgAbi<'fcx> for PgvectorVectorInput<'fcx> { - unsafe fn unbox_arg_unchecked(arg: pgrx::callconv::Arg<'_, 'fcx>) -> Self { - unsafe { arg.unbox_arg_using_from_datum().unwrap() } - } -} - -unsafe impl pgrx::callconv::BoxRet for PgvectorVectorOutput { - unsafe fn box_into<'fcx>( - self, - fcinfo: &mut pgrx::callconv::FcInfo<'fcx>, - ) -> pgrx::datum::Datum<'fcx> { - unsafe { fcinfo.return_raw_datum(Datum::from(self.into_raw() as *mut ())) } - } -} diff --git a/src/datatype/memory_scalar8.rs b/src/datatype/memory_scalar8.rs index 1641c63..4f30654 100644 --- a/src/datatype/memory_scalar8.rs +++ b/src/datatype/memory_scalar8.rs @@ -7,7 +7,7 @@ use pgrx::pgrx_sql_entity_graph::metadata::Returns; use 
pgrx::pgrx_sql_entity_graph::metadata::ReturnsError; use pgrx::pgrx_sql_entity_graph::metadata::SqlMapping; use pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable; -use std::ops::Deref; +use std::marker::PhantomData; use std::ptr::NonNull; use vector::VectorBorrowed; use vector::scalar8::Scalar8Borrowed; @@ -21,7 +21,7 @@ pub struct Scalar8Header { k: f32, b: f32, sum_of_code: f32, - phantom: [u8; 0], + elements: [u8; 0], } impl Scalar8Header { @@ -31,44 +31,43 @@ impl Scalar8Header { } (size_of::() + size_of::() * len).next_multiple_of(8) } - pub fn as_borrowed(&self) -> Scalar8Borrowed<'_> { + pub unsafe fn as_borrowed<'a>(this: NonNull) -> Scalar8Borrowed<'a> { unsafe { + let this = this.as_ptr(); Scalar8Borrowed::new_unchecked( - self.sum_of_x2, - self.k, - self.b, - self.sum_of_code, - std::slice::from_raw_parts(self.phantom.as_ptr(), self.dims as usize), + (&raw const (*this).sum_of_x2).read(), + (&raw const (*this).k).read(), + (&raw const (*this).b).read(), + (&raw const (*this).sum_of_code).read(), + std::slice::from_raw_parts( + (&raw const (*this).elements).cast(), + (&raw const (*this).dims).read() as usize, + ), ) } } } -pub enum Scalar8Input<'a> { - Owned(Scalar8Output), - Borrowed(&'a Scalar8Header), -} +pub struct Scalar8Input<'a>(NonNull, PhantomData<&'a ()>, bool); impl Scalar8Input<'_> { - unsafe fn new(p: NonNull) -> Self { + unsafe fn from_ptr(p: NonNull) -> Self { let q = unsafe { - NonNull::new(pgrx::pg_sys::pg_detoast_datum(p.cast().as_ptr()).cast()).unwrap() + NonNull::new(pgrx::pg_sys::pg_detoast_datum(p.as_ptr().cast()).cast()).unwrap() }; - if p != q { - Scalar8Input::Owned(Scalar8Output(q)) - } else { - unsafe { Scalar8Input::Borrowed(p.as_ref()) } - } + Scalar8Input(q, PhantomData, p != q) + } + pub fn as_borrowed(&self) -> Scalar8Borrowed<'_> { + unsafe { Scalar8Header::as_borrowed(self.0) } } } -impl Deref for Scalar8Input<'_> { - type Target = Scalar8Header; - - fn deref(&self) -> &Self::Target { - match self { - 
Scalar8Input::Owned(x) => x, - Scalar8Input::Borrowed(x) => x, +impl Drop for Scalar8Input<'_> { + fn drop(&mut self) { + if self.2 { + unsafe { + pgrx::pg_sys::pfree(self.0.as_ptr().cast()); + } } } } @@ -76,7 +75,13 @@ impl Deref for Scalar8Input<'_> { pub struct Scalar8Output(NonNull); impl Scalar8Output { - pub fn new(vector: Scalar8Borrowed<'_>) -> Scalar8Output { + unsafe fn from_ptr(p: NonNull) -> Self { + let q = unsafe { + NonNull::new(pgrx::pg_sys::pg_detoast_datum_copy(p.as_ptr().cast()).cast()).unwrap() + }; + Self(q) + } + pub fn new(vector: Scalar8Borrowed<'_>) -> Self { unsafe { let code = vector.code(); let size = Scalar8Header::size_of(code.len()); @@ -89,10 +94,17 @@ impl Scalar8Output { (&raw mut (*ptr).k).write(vector.k()); (&raw mut (*ptr).b).write(vector.b()); (&raw mut (*ptr).sum_of_code).write(vector.sum_of_code()); - std::ptr::copy_nonoverlapping(code.as_ptr(), (*ptr).phantom.as_mut_ptr(), code.len()); - Scalar8Output(NonNull::new(ptr).unwrap()) + std::ptr::copy_nonoverlapping( + code.as_ptr(), + (&raw mut (*ptr).elements).cast(), + code.len(), + ); + Self(NonNull::new(ptr).unwrap()) } } + pub fn as_borrowed(&self) -> Scalar8Borrowed<'_> { + unsafe { Scalar8Header::as_borrowed(self.0) } + } pub fn into_raw(self) -> *mut Scalar8Header { let result = self.0.as_ptr(); std::mem::forget(self); @@ -100,36 +112,43 @@ impl Scalar8Output { } } -impl Deref for Scalar8Output { - type Target = Scalar8Header; - - fn deref(&self) -> &Self::Target { - unsafe { self.0.as_ref() } - } -} - impl Drop for Scalar8Output { fn drop(&mut self) { unsafe { - pgrx::pg_sys::pfree(self.0.as_ptr() as _); + pgrx::pg_sys::pfree(self.0.as_ptr().cast()); } } } +// FromDatum + impl FromDatum for Scalar8Input<'_> { unsafe fn from_polymorphic_datum(datum: Datum, is_null: bool, _typoid: Oid) -> Option { if is_null { None } else { - let ptr = NonNull::new(datum.cast_mut_ptr::()).unwrap(); - unsafe { Some(Scalar8Input::new(ptr)) } + let ptr = 
NonNull::new(datum.cast_mut_ptr()).unwrap(); + unsafe { Some(Self::from_ptr(ptr)) } + } + } +} + +impl FromDatum for Scalar8Output { + unsafe fn from_polymorphic_datum(datum: Datum, is_null: bool, _typoid: Oid) -> Option { + if is_null { + None + } else { + let ptr = NonNull::new(datum.cast_mut_ptr()).unwrap(); + unsafe { Some(Self::from_ptr(ptr)) } } } } +// IntoDatum + impl IntoDatum for Scalar8Output { fn into_datum(self) -> Option { - Some(Datum::from(self.into_raw() as *mut ())) + Some(Datum::from(self.into_raw())) } fn type_oid() -> Oid { @@ -141,46 +160,23 @@ impl IntoDatum for Scalar8Output { } } -impl FromDatum for Scalar8Output { - unsafe fn from_polymorphic_datum(datum: Datum, is_null: bool, _typoid: Oid) -> Option { - if is_null { - None - } else { - let p = NonNull::new(datum.cast_mut_ptr::())?; - let q = - unsafe { NonNull::new(pgrx::pg_sys::pg_detoast_datum(p.cast().as_ptr()).cast())? }; - if p != q { - Some(Scalar8Output(q)) - } else { - let header = p.as_ptr(); - let vector = unsafe { (*header).as_borrowed() }; - Some(Scalar8Output::new(vector)) - } - } - } -} +// UnboxDatum unsafe impl pgrx::datum::UnboxDatum for Scalar8Output { type As<'src> = Scalar8Output; #[inline] - unsafe fn unbox<'src>(d: pgrx::datum::Datum<'src>) -> Self::As<'src> + unsafe fn unbox<'src>(datum: pgrx::datum::Datum<'src>) -> Self::As<'src> where Self: 'src, { - let p = NonNull::new(d.sans_lifetime().cast_mut_ptr::()).unwrap(); - let q = unsafe { - NonNull::new(pgrx::pg_sys::pg_detoast_datum(p.cast().as_ptr()).cast()).unwrap() - }; - if p != q { - Scalar8Output(q) - } else { - let header = p.as_ptr(); - let vector = unsafe { (*header).as_borrowed() }; - Scalar8Output::new(vector) - } + let datum = datum.sans_lifetime(); + let ptr = NonNull::new(datum.cast_mut_ptr()).unwrap(); + unsafe { Self::from_ptr(ptr) } } } +// SqlTranslatable + unsafe impl SqlTranslatable for Scalar8Input<'_> { fn argument_sql() -> Result { Ok(SqlMapping::As(String::from("scalar8"))) @@ -199,17 +195,28 
@@ unsafe impl SqlTranslatable for Scalar8Output { } } +// ArgAbi + unsafe impl<'fcx> pgrx::callconv::ArgAbi<'fcx> for Scalar8Input<'fcx> { unsafe fn unbox_arg_unchecked(arg: pgrx::callconv::Arg<'_, 'fcx>) -> Self { - unsafe { arg.unbox_arg_using_from_datum().unwrap() } + let index = arg.index(); + unsafe { + arg.unbox_arg_using_from_datum() + .unwrap_or_else(|| panic!("argument {index} must not be null")) + } } } +// BoxRet + unsafe impl pgrx::callconv::BoxRet for Scalar8Output { unsafe fn box_into<'fcx>( self, fcinfo: &mut pgrx::callconv::FcInfo<'fcx>, ) -> pgrx::datum::Datum<'fcx> { - unsafe { fcinfo.return_raw_datum(Datum::from(self.into_raw() as *mut ())) } + match self.into_datum() { + Some(datum) => unsafe { fcinfo.return_raw_datum(datum) }, + None => fcinfo.return_null(), + } } } diff --git a/src/datatype/memory_vector.rs b/src/datatype/memory_vector.rs new file mode 100644 index 0000000..de70ba1 --- /dev/null +++ b/src/datatype/memory_vector.rs @@ -0,0 +1,209 @@ +use pgrx::datum::FromDatum; +use pgrx::datum::IntoDatum; +use pgrx::pg_sys::Datum; +use pgrx::pg_sys::Oid; +use pgrx::pgrx_sql_entity_graph::metadata::ArgumentError; +use pgrx::pgrx_sql_entity_graph::metadata::Returns; +use pgrx::pgrx_sql_entity_graph::metadata::ReturnsError; +use pgrx::pgrx_sql_entity_graph::metadata::SqlMapping; +use pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable; +use std::marker::PhantomData; +use std::ptr::NonNull; +use vector::VectorBorrowed; +use vector::vect::VectBorrowed; + +#[repr(C, align(8))] +pub struct VectorHeader { + varlena: u32, + dims: u16, + unused: u16, + elements: [f32; 0], +} + +impl VectorHeader { + fn size_of(len: usize) -> usize { + if len > 65535 { + panic!("vector is too large"); + } + (size_of::() + size_of::() * len).next_multiple_of(8) + } + pub unsafe fn as_borrowed<'a>(this: NonNull) -> VectBorrowed<'a, f32> { + unsafe { + let this = this.as_ptr(); + VectBorrowed::new_unchecked(std::slice::from_raw_parts( + (&raw const 
(*this).elements).cast(), + (&raw const (*this).dims).read() as usize, + )) + } + } +} + +pub struct VectorInput<'a>(NonNull, PhantomData<&'a ()>, bool); + +impl VectorInput<'_> { + unsafe fn from_ptr(p: NonNull) -> Self { + let q = unsafe { + NonNull::new(pgrx::pg_sys::pg_detoast_datum(p.as_ptr().cast()).cast()).unwrap() + }; + VectorInput(q, PhantomData, p != q) + } + pub fn as_borrowed(&self) -> VectBorrowed<'_, f32> { + unsafe { VectorHeader::as_borrowed(self.0) } + } +} + +impl Drop for VectorInput<'_> { + fn drop(&mut self) { + if self.2 { + unsafe { + pgrx::pg_sys::pfree(self.0.as_ptr().cast()); + } + } + } +} + +pub struct VectorOutput(NonNull); + +impl VectorOutput { + unsafe fn from_ptr(p: NonNull) -> Self { + let q = unsafe { + NonNull::new(pgrx::pg_sys::pg_detoast_datum_copy(p.as_ptr().cast()).cast()).unwrap() + }; + Self(q) + } + #[allow(dead_code)] + pub fn new(vector: VectBorrowed<'_, f32>) -> Self { + unsafe { + let slice = vector.slice(); + let size = VectorHeader::size_of(slice.len()); + + let ptr = pgrx::pg_sys::palloc0(size) as *mut VectorHeader; + (&raw mut (*ptr).varlena).write((size << 2) as u32); + (&raw mut (*ptr).dims).write(vector.dims() as _); + (&raw mut (*ptr).unused).write(0); + std::ptr::copy_nonoverlapping( + slice.as_ptr(), + (&raw mut (*ptr).elements).cast(), + slice.len(), + ); + Self(NonNull::new(ptr).unwrap()) + } + } + pub fn as_borrowed(&self) -> VectBorrowed<'_, f32> { + unsafe { VectorHeader::as_borrowed(self.0) } + } + pub fn into_raw(self) -> *mut VectorHeader { + let result = self.0.as_ptr(); + std::mem::forget(self); + result + } +} + +impl Drop for VectorOutput { + fn drop(&mut self) { + unsafe { + pgrx::pg_sys::pfree(self.0.as_ptr().cast()); + } + } +} + +// FromDatum + +impl FromDatum for VectorInput<'_> { + unsafe fn from_polymorphic_datum(datum: Datum, is_null: bool, _typoid: Oid) -> Option { + if is_null { + None + } else { + let ptr = NonNull::new(datum.cast_mut_ptr()).unwrap(); + unsafe { 
Some(Self::from_ptr(ptr)) } + } + } +} + +impl FromDatum for VectorOutput { + unsafe fn from_polymorphic_datum(datum: Datum, is_null: bool, _typoid: Oid) -> Option { + if is_null { + None + } else { + let ptr = NonNull::new(datum.cast_mut_ptr()).unwrap(); + unsafe { Some(Self::from_ptr(ptr)) } + } + } +} + +// IntoDatum + +impl IntoDatum for VectorOutput { + fn into_datum(self) -> Option { + Some(Datum::from(self.into_raw())) + } + + fn type_oid() -> Oid { + Oid::INVALID + } + + fn is_compatible_with(_: Oid) -> bool { + true + } +} + +// UnboxDatum + +unsafe impl pgrx::datum::UnboxDatum for VectorOutput { + type As<'src> = VectorOutput; + #[inline] + unsafe fn unbox<'src>(datum: pgrx::datum::Datum<'src>) -> Self::As<'src> + where + Self: 'src, + { + let datum = datum.sans_lifetime(); + let ptr = NonNull::new(datum.cast_mut_ptr()).unwrap(); + unsafe { Self::from_ptr(ptr) } + } +} + +// SqlTranslatable + +unsafe impl SqlTranslatable for VectorInput<'_> { + fn argument_sql() -> Result { + Ok(SqlMapping::As(String::from("vector"))) + } + fn return_sql() -> Result { + Ok(Returns::One(SqlMapping::As(String::from("vector")))) + } +} + +unsafe impl SqlTranslatable for VectorOutput { + fn argument_sql() -> Result { + Ok(SqlMapping::As(String::from("vector"))) + } + fn return_sql() -> Result { + Ok(Returns::One(SqlMapping::As(String::from("vector")))) + } +} + +// ArgAbi + +unsafe impl<'fcx> pgrx::callconv::ArgAbi<'fcx> for VectorInput<'fcx> { + unsafe fn unbox_arg_unchecked(arg: pgrx::callconv::Arg<'_, 'fcx>) -> Self { + let index = arg.index(); + unsafe { + arg.unbox_arg_using_from_datum() + .unwrap_or_else(|| panic!("argument {index} must not be null")) + } + } +} + +// BoxRet + +unsafe impl pgrx::callconv::BoxRet for VectorOutput { + unsafe fn box_into<'fcx>( + self, + fcinfo: &mut pgrx::callconv::FcInfo<'fcx>, + ) -> pgrx::datum::Datum<'fcx> { + match self.into_datum() { + Some(datum) => unsafe { fcinfo.return_raw_datum(datum) }, + None => fcinfo.return_null(), + } + } 
+} diff --git a/src/datatype/mod.rs b/src/datatype/mod.rs index 98b6650..cd08970 100644 --- a/src/datatype/mod.rs +++ b/src/datatype/mod.rs @@ -1,10 +1,10 @@ pub mod binary_scalar8; pub mod functions_scalar8; -pub mod memory_pgvector_halfvec; -pub mod memory_pgvector_vector; +pub mod memory_halfvec; pub mod memory_scalar8; -pub mod operators_pgvector_halfvec; -pub mod operators_pgvector_vector; +pub mod memory_vector; +pub mod operators_halfvec; pub mod operators_scalar8; +pub mod operators_vector; pub mod text_scalar8; pub mod typmod; diff --git a/src/datatype/operators_pgvector_halfvec.rs b/src/datatype/operators_halfvec.rs similarity index 83% rename from src/datatype/operators_pgvector_halfvec.rs rename to src/datatype/operators_halfvec.rs index fb0492a..907d415 100644 --- a/src/datatype/operators_pgvector_halfvec.rs +++ b/src/datatype/operators_halfvec.rs @@ -1,14 +1,14 @@ -use crate::datatype::memory_pgvector_halfvec::{PgvectorHalfvecInput, PgvectorHalfvecOutput}; +use crate::datatype::memory_halfvec::{HalfvecInput, HalfvecOutput}; use std::num::NonZero; use vector::VectorBorrowed; use vector::vect::VectBorrowed; #[pgrx::pg_extern(immutable, strict, parallel_safe)] fn _vchord_halfvec_sphere_l2_in( - lhs: PgvectorHalfvecInput<'_>, + lhs: HalfvecInput<'_>, rhs: pgrx::composite_type!("sphere_halfvec"), ) -> bool { - let center: PgvectorHalfvecOutput = match rhs.get_by_index(NonZero::new(1).unwrap()) { + let center: HalfvecOutput = match rhs.get_by_index(NonZero::new(1).unwrap()) { Ok(Some(s)) => s, Ok(None) => pgrx::error!("Bad input: empty center at sphere"), Err(_) => unreachable!(), @@ -29,10 +29,10 @@ fn _vchord_halfvec_sphere_l2_in( #[pgrx::pg_extern(immutable, strict, parallel_safe)] fn _vchord_halfvec_sphere_ip_in( - lhs: PgvectorHalfvecInput<'_>, + lhs: HalfvecInput<'_>, rhs: pgrx::composite_type!("sphere_halfvec"), ) -> bool { - let center: PgvectorHalfvecOutput = match rhs.get_by_index(NonZero::new(1).unwrap()) { + let center: HalfvecOutput = match 
rhs.get_by_index(NonZero::new(1).unwrap()) { Ok(Some(s)) => s, Ok(None) => pgrx::error!("Bad input: empty center at sphere"), Err(_) => unreachable!(), @@ -53,10 +53,10 @@ fn _vchord_halfvec_sphere_ip_in( #[pgrx::pg_extern(immutable, strict, parallel_safe)] fn _vchord_halfvec_sphere_cosine_in( - lhs: PgvectorHalfvecInput<'_>, + lhs: HalfvecInput<'_>, rhs: pgrx::composite_type!("sphere_halfvec"), ) -> bool { - let center: PgvectorHalfvecOutput = match rhs.get_by_index(NonZero::new(1).unwrap()) { + let center: HalfvecOutput = match rhs.get_by_index(NonZero::new(1).unwrap()) { Ok(Some(s)) => s, Ok(None) => pgrx::error!("Bad input: empty center at sphere"), Err(_) => unreachable!(), diff --git a/src/datatype/operators_pgvector_vector.rs b/src/datatype/operators_vector.rs similarity index 83% rename from src/datatype/operators_pgvector_vector.rs rename to src/datatype/operators_vector.rs index e6b4be1..b011c3a 100644 --- a/src/datatype/operators_pgvector_vector.rs +++ b/src/datatype/operators_vector.rs @@ -1,14 +1,14 @@ -use crate::datatype::memory_pgvector_vector::{PgvectorVectorInput, PgvectorVectorOutput}; +use crate::datatype::memory_vector::{VectorInput, VectorOutput}; use std::num::NonZero; use vector::VectorBorrowed; use vector::vect::VectBorrowed; #[pgrx::pg_extern(immutable, strict, parallel_safe)] fn _vchord_vector_sphere_l2_in( - lhs: PgvectorVectorInput<'_>, + lhs: VectorInput<'_>, rhs: pgrx::composite_type!("sphere_vector"), ) -> bool { - let center: PgvectorVectorOutput = match rhs.get_by_index(NonZero::new(1).unwrap()) { + let center: VectorOutput = match rhs.get_by_index(NonZero::new(1).unwrap()) { Ok(Some(s)) => s, Ok(None) => pgrx::error!("Bad input: empty center at sphere"), Err(_) => unreachable!(), @@ -29,10 +29,10 @@ fn _vchord_vector_sphere_l2_in( #[pgrx::pg_extern(immutable, strict, parallel_safe)] fn _vchord_vector_sphere_ip_in( - lhs: PgvectorVectorInput<'_>, + lhs: VectorInput<'_>, rhs: pgrx::composite_type!("sphere_vector"), ) -> bool { - let 
center: PgvectorVectorOutput = match rhs.get_by_index(NonZero::new(1).unwrap()) { + let center: VectorOutput = match rhs.get_by_index(NonZero::new(1).unwrap()) { Ok(Some(s)) => s, Ok(None) => pgrx::error!("Bad input: empty center at sphere"), Err(_) => unreachable!(), @@ -53,10 +53,10 @@ fn _vchord_vector_sphere_ip_in( #[pgrx::pg_extern(immutable, strict, parallel_safe)] fn _vchord_vector_sphere_cosine_in( - lhs: PgvectorVectorInput<'_>, + lhs: VectorInput<'_>, rhs: pgrx::composite_type!("sphere_vector"), ) -> bool { - let center: PgvectorVectorOutput = match rhs.get_by_index(NonZero::new(1).unwrap()) { + let center: VectorOutput = match rhs.get_by_index(NonZero::new(1).unwrap()) { Ok(Some(s)) => s, Ok(None) => pgrx::error!("Bad input: empty center at sphere"), Err(_) => unreachable!(), diff --git a/src/index/am_options.rs b/src/index/am_options.rs index 34c76a9..06ca8e5 100644 --- a/src/index/am_options.rs +++ b/src/index/am_options.rs @@ -1,7 +1,7 @@ -use crate::datatype::memory_pgvector_halfvec::PgvectorHalfvecInput; -use crate::datatype::memory_pgvector_halfvec::PgvectorHalfvecOutput; -use crate::datatype::memory_pgvector_vector::PgvectorVectorInput; -use crate::datatype::memory_pgvector_vector::PgvectorVectorOutput; +use crate::datatype::memory_halfvec::HalfvecInput; +use crate::datatype::memory_halfvec::HalfvecOutput; +use crate::datatype::memory_vector::VectorInput; +use crate::datatype::memory_vector::VectorOutput; use crate::datatype::typmod::Typmod; use crate::types::{BorrowedVector, OwnedVector}; use crate::types::{DistanceKind, VectorKind}; @@ -133,11 +133,11 @@ impl Opfamily { } let vector = match self.vector { VectorKind::Vecf32 => { - let vector = unsafe { PgvectorVectorInput::from_datum(datum, false).unwrap() }; + let vector = unsafe { VectorInput::from_datum(datum, false).unwrap() }; self.preprocess(BorrowedVector::Vecf32(vector.as_borrowed())) } VectorKind::Vecf16 => { - let vector = unsafe { PgvectorHalfvecInput::from_datum(datum, false).unwrap() 
}; + let vector = unsafe { HalfvecInput::from_datum(datum, false).unwrap() }; self.preprocess(BorrowedVector::Vecf16(vector.as_borrowed())) } }; @@ -154,11 +154,11 @@ impl Opfamily { let tuple = unsafe { PgHeapTuple::from_composite_datum(datum) }; let center = match self.vector { VectorKind::Vecf32 => tuple - .get_by_index::(NonZero::new(1).unwrap()) + .get_by_index::(NonZero::new(1).unwrap()) .unwrap() .map(|vector| self.preprocess(BorrowedVector::Vecf32(vector.as_borrowed()))), VectorKind::Vecf16 => tuple - .get_by_index::(NonZero::new(1).unwrap()) + .get_by_index::(NonZero::new(1).unwrap()) .unwrap() .map(|vector| self.preprocess(BorrowedVector::Vecf16(vector.as_borrowed()))), };