From 40a80c3739b7189aa645c1a12f20e0521075b000 Mon Sep 17 00:00:00 2001 From: Abdulla Abdurakhmanov Date: Fri, 12 Apr 2024 15:53:06 +0200 Subject: [PATCH] Initial FirestoreVector type support and low level API implementation --- examples/query.rs | 12 + src/db/aggregated_query.rs | 2 +- src/db/listen_changes.rs | 2 +- src/db/query.rs | 94 +++--- src/db/query_models.rs | 87 ++++- src/firestore_serde/mod.rs | 8 + src/firestore_serde/serializer.rs | 7 +- src/firestore_serde/vector_serializers.rs | 376 ++++++++++++++++++++++ 8 files changed, 536 insertions(+), 52 deletions(-) create mode 100644 src/firestore_serde/vector_serializers.rs diff --git a/examples/query.rs b/examples/query.rs index 34219ea..b89b9d5 100644 --- a/examples/query.rs +++ b/examples/query.rs @@ -65,7 +65,19 @@ async fn main() -> Result<(), Box> { println!("Querying a test collection as a stream using Fluent API"); + // Simple query into vector // Query as a stream our data + let as_vec: Vec = db + .fluent() + .select() + .from(TEST_COLLECTION_NAME) + .obj() + .query() + .await?; + + println!("{:?}", as_vec); + + // Query as a stream our data with filters and ordering let object_stream: BoxStream> = db .fluent() .select() diff --git a/src/db/aggregated_query.rs b/src/db/aggregated_query.rs index ef8ca42..93c3ae7 100644 --- a/src/db/aggregated_query.rs +++ b/src/db/aggregated_query.rs @@ -289,7 +289,7 @@ impl FirestoreDb { query_type: Some(run_aggregation_query_request::QueryType::StructuredAggregationQuery( StructuredAggregationQuery { aggregations: params.aggregations.iter().map(|agg| agg.into()).collect(), - query_type: Some(gcloud_sdk::google::firestore::v1::structured_aggregation_query::QueryType::StructuredQuery(params.query_params.into())), + query_type: Some(gcloud_sdk::google::firestore::v1::structured_aggregation_query::QueryType::StructuredQuery(params.query_params.try_into()?)), } )), explain_options: None, diff --git a/src/db/listen_changes.rs b/src/db/listen_changes.rs index 721002e..b8566e5 100644 --- a/src/db/listen_changes.rs +++ b/src/db/listen_changes.rs @@ -191,7 +191,7 @@ impl FirestoreDb { .unwrap_or_else(|| self.get_documents_path()) .clone(), query_type: Some(target::query_target::QueryType::StructuredQuery( - query_params.into(), + query_params.try_into()?, )), }) } diff --git a/src/db/query.rs b/src/db/query.rs index 3c34ba2..c9fd50b 100644 --- a/src/db/query.rs +++ b/src/db/query.rs @@ -103,7 +103,9 @@ impl FirestoreDb { .as_ref() .map(|eo| eo.try_into()) .transpose()?, - query_type: Some(run_query_request::QueryType::StructuredQuery(params.into())), + query_type: Some(run_query_request::QueryType::StructuredQuery( + params.try_into()?, + )), })) } @@ -415,49 +417,59 @@ impl FirestoreQuerySupport for FirestoreDb { Some((params, consistency_selector)), move |maybe_params| async move { if let Some((params, maybe_consistency_selector)) = maybe_params { - let request = gcloud_sdk::tonic::Request::new(PartitionQueryRequest { - page_size: params.page_size as i32, - partition_count: params.partition_count as i64, - parent: params - .query_params - .parent - .as_ref() - .unwrap_or_else(|| self.get_documents_path()) - .clone(), - consistency_selector: maybe_consistency_selector.clone(), - query_type: Some( - partition_query_request::QueryType::StructuredQuery( - params.query_params.clone().into(), - ), - ), - page_token: params.page_token.clone().unwrap_or_default(), - }); - - match self.client().get().partition_query(request).await { - Ok(response) => { - let partition_response = response.into_inner(); - let firestore_cursors: Vec = - partition_response - .partitions - .into_iter() - .map(|e| e.into()) - .collect(); - - if !partition_response.next_page_token.is_empty() { - Some(( - Ok(firestore_cursors), - Some(( - params.with_page_token( - partition_response.next_page_token, + match params.query_params.clone().try_into() { + Ok(query_params) => { + let request = + gcloud_sdk::tonic::Request::new(PartitionQueryRequest { + page_size: params.page_size as i32, + partition_count: params.partition_count as i64, + parent: params + .query_params + .parent + .as_ref() + .unwrap_or_else(|| self.get_documents_path()) + .clone(), + consistency_selector: maybe_consistency_selector + .clone(), + query_type: Some( + partition_query_request::QueryType::StructuredQuery( + query_params, ), - maybe_consistency_selector, - )), - )) - } else { - Some((Ok(firestore_cursors), None)) + ), + page_token: params + .page_token + .clone() + .unwrap_or_default(), + }); + + match self.client().get().partition_query(request).await { + Ok(response) => { + let partition_response = response.into_inner(); + let firestore_cursors: Vec = + partition_response + .partitions + .into_iter() + .map(|e| e.into()) + .collect(); + + if !partition_response.next_page_token.is_empty() { + Some(( + Ok(firestore_cursors), + Some(( + params.with_page_token( + partition_response.next_page_token, + ), + maybe_consistency_selector, + )), + )) + } else { + Some((Ok(firestore_cursors), None)) + } + } + Err(err) => Some((Err(FirestoreError::from(err)), None)), } } - Err(err) => Some((Err(FirestoreError::from(err)), None)), + Err(err) => Some((Err(err), None)), } } else { None diff --git a/src/db/query_models.rs b/src/db/query_models.rs index 94e82e7..943ac31 100644 --- a/src/db/query_models.rs +++ b/src/db/query_models.rs @@ -1,7 +1,9 @@ #![allow(clippy::derive_partial_eq_without_eq)] // Since we may not be able to implement Eq for the changes coming from Firestore protos -use crate::errors::FirestoreError; -use crate::FirestoreValue; +use crate::errors::{ + FirestoreError, FirestoreInvalidParametersError, FirestoreInvalidParametersPublicDetails, +}; +use crate::{FirestoreValue, FirestoreVector}; use gcloud_sdk::google::firestore::v1::*; use rsb_derive::Builder; @@ -39,13 +41,16 @@ pub struct FirestoreQueryParams { pub start_at: Option, pub end_at: Option, pub explain_options: Option, + pub find_nearest: Option, } -impl From for StructuredQuery { - fn from(params: FirestoreQueryParams) -> Self { +impl TryFrom for StructuredQuery { + type Error = FirestoreError; + + fn try_from(params: FirestoreQueryParams) -> Result { let query_filter = params.filter.map(|f| f.into()); - StructuredQuery { + Ok(StructuredQuery { select: params.return_only_fields.map(|select_only_fields| { structured_query::Projection { fields: select_only_fields @@ -79,9 +84,12 @@ impl From for StructuredQuery { }) .collect(), }, - find_nearest: None, + find_nearest: params + .find_nearest + .map(|find_nearest| find_nearest.try_into()) + .transpose()?, r#where: query_filter, - } + }) } } @@ -425,3 +433,68 @@ impl TryFrom<&FirestoreExplainOptions> for gcloud_sdk::google::firestore::v1::Ex }) } } + +#[derive(Debug, PartialEq, Clone, Builder)] +pub struct FirestoreFindNearestOptions { + pub field_name: String, + pub query_vector: FirestoreVector, + pub distance_measure: FirestoreFindNearestDistanceMeasure, + pub neighbors_limit: u32, +} + +impl TryFrom + for gcloud_sdk::google::firestore::v1::structured_query::FindNearest +{ + type Error = FirestoreError; + + fn try_from(options: FirestoreFindNearestOptions) -> Result { + Ok(structured_query::FindNearest { + vector_field: Some(structured_query::FieldReference { + field_path: options.field_name, + }), + query_vector: Some(Into::::into(options.query_vector).value), + distance_measure: { + let distance_measure: structured_query::find_nearest::DistanceMeasure = options.distance_measure.try_into()?; + distance_measure.into() + }, + limit: Some(options.neighbors_limit.try_into().map_err(|e| FirestoreError::InvalidParametersError( + FirestoreInvalidParametersError::new(FirestoreInvalidParametersPublicDetails::new( + "neighbors_limit".to_string(), + format!( + "Invalid value for neighbors_limit: {}. Maximum allowed value is {}. Error: {}", + options.neighbors_limit, + i32::MAX, + e + ), + ))) + )?), + }) + } +} + +#[derive(Debug, PartialEq, Clone)] +pub enum FirestoreFindNearestDistanceMeasure { + Euclidean, + Cosine, + DotProduct, +} + +impl TryFrom + for structured_query::find_nearest::DistanceMeasure +{ + type Error = FirestoreError; + + fn try_from(measure: FirestoreFindNearestDistanceMeasure) -> Result { + match measure { + FirestoreFindNearestDistanceMeasure::Euclidean => { + Ok(structured_query::find_nearest::DistanceMeasure::Euclidean) + } + FirestoreFindNearestDistanceMeasure::Cosine => { + Ok(structured_query::find_nearest::DistanceMeasure::Cosine) + } + FirestoreFindNearestDistanceMeasure::DotProduct => { + Ok(structured_query::find_nearest::DistanceMeasure::DotProduct) + } + } + } +} diff --git a/src/firestore_serde/mod.rs b/src/firestore_serde/mod.rs index 2245d73..6b311a4 100644 --- a/src/firestore_serde/mod.rs +++ b/src/firestore_serde/mod.rs @@ -2,17 +2,25 @@ mod deserializer; mod serializer; mod timestamp_serializers; + pub use timestamp_serializers::*; mod null_serializers; + pub use null_serializers::*; mod latlng_serializers; + pub use latlng_serializers::*; mod reference_serializers; + pub use reference_serializers::*; +mod vector_serializers; + +pub use vector_serializers::*; + use crate::FirestoreValue; use gcloud_sdk::google::firestore::v1::Value; diff --git a/src/firestore_serde/serializer.rs b/src/firestore_serde/serializer.rs index 022b877..2c22321 100644 --- a/src/firestore_serde/serializer.rs +++ b/src/firestore_serde/serializer.rs @@ -17,8 +17,8 @@ impl FirestoreValueSerializer { } pub struct SerializeVec { - none_as_null: bool, - vec: Vec, + pub none_as_null: bool, + pub vec: Vec, } pub struct SerializeTupleVariant { @@ -232,6 +232,9 @@ impl serde::Serializer for FirestoreValueSerializer { value, false, ) } + crate::firestore_serde::vector_serializers::FIRESTORE_VECTOR_TYPE_TAG_TYPE => { + crate::firestore_serde::vector_serializers::serialize_vector_for_firestore(value) + } _ => value.serialize(self), } } diff --git a/src/firestore_serde/vector_serializers.rs b/src/firestore_serde/vector_serializers.rs new file mode 100644 index 0000000..c86837c --- /dev/null +++ b/src/firestore_serde/vector_serializers.rs @@ -0,0 +1,376 @@ +use crate::errors::{FirestoreError, FirestoreSerializationError}; +use crate::firestore_serde::serializer::SerializeVec; +use crate::FirestoreValue; +use serde::de::{MapAccess, Visitor}; +use serde::{Deserializer, Serialize, Serializer}; + +pub(crate) const FIRESTORE_VECTOR_TYPE_TAG_TYPE: &str = "FirestoreVector"; + +#[derive(Serialize, Clone, Debug, PartialEq, PartialOrd, Default)] +pub struct FirestoreVector(pub Vec); + +impl FirestoreVector { + pub fn new(vec: Vec) -> Self { + FirestoreVector(vec) + } +} + +impl From for FirestoreVector +where + I: IntoIterator, +{ + fn from(vec: I) -> Self { + FirestoreVector(vec.into_iter().collect()) + } +} + +pub fn serialize_vector_for_firestore( + value: &T, +) -> Result { + struct VectorSerializer; + + impl Serializer for VectorSerializer { + type Ok = FirestoreValue; + type Error = FirestoreError; + type SerializeSeq = crate::firestore_serde::serializer::SerializeVec; + type SerializeTuple = crate::firestore_serde::serializer::SerializeVec; + type SerializeTupleStruct = crate::firestore_serde::serializer::SerializeVec; + type SerializeTupleVariant = crate::firestore_serde::serializer::SerializeTupleVariant; + type SerializeMap = crate::firestore_serde::serializer::SerializeMap; + type SerializeStruct = crate::firestore_serde::serializer::SerializeMap; + type SerializeStructVariant = crate::firestore_serde::serializer::SerializeStructVariant; + + fn serialize_bool(self, _v: bool) -> Result { + Err(FirestoreError::SerializeError( + FirestoreSerializationError::from_message( + "Vector serializer doesn't support this type", + ), + )) + } + + fn serialize_i8(self, _v: i8) -> Result { + Err(FirestoreError::SerializeError( + FirestoreSerializationError::from_message( + "Vector serializer doesn't support this type", + ), + )) + } + + fn serialize_i16(self, _v: i16) -> Result { + Err(FirestoreError::SerializeError( + FirestoreSerializationError::from_message( + "Vector serializer doesn't support this type", + ), + )) + } + + fn serialize_i32(self, _v: i32) -> Result { + Err(FirestoreError::SerializeError( + FirestoreSerializationError::from_message( + "Vector serializer doesn't support this type", + ), + )) + } + + fn serialize_i64(self, _v: i64) -> Result { + Err(FirestoreError::SerializeError( + FirestoreSerializationError::from_message( + "Vector serializer doesn't support this type", + ), + )) + } + + fn serialize_u8(self, _v: u8) -> Result { + Err(FirestoreError::SerializeError( + FirestoreSerializationError::from_message( + "Vector serializer doesn't support this type", + ), + )) + } + + fn serialize_u16(self, _v: u16) -> Result { + Err(FirestoreError::SerializeError( + FirestoreSerializationError::from_message( + "Vector serializer doesn't support this type", + ), + )) + } + + fn serialize_u32(self, _v: u32) -> Result { + Err(FirestoreError::SerializeError( + FirestoreSerializationError::from_message( + "Vector serializer doesn't support this type", + ), + )) + } + + fn serialize_u64(self, _v: u64) -> Result { + Err(FirestoreError::SerializeError( + FirestoreSerializationError::from_message( + "Vector serializer doesn't support this type", + ), + )) + } + + fn serialize_f32(self, _v: f32) -> Result { + Err(FirestoreError::SerializeError( + FirestoreSerializationError::from_message( + "Vector serializer doesn't support this type", + ), + )) + } + + fn serialize_f64(self, _v: f64) -> Result { + Err(FirestoreError::SerializeError( + FirestoreSerializationError::from_message( + "Vector serializer doesn't support this type", + ), + )) + } + + fn serialize_char(self, _v: char) -> Result { + Err(FirestoreError::SerializeError( + FirestoreSerializationError::from_message( + "Vector serializer doesn't support this type", + ), + )) + } + + fn serialize_str(self, _v: &str) -> Result { + Err(FirestoreError::SerializeError( + FirestoreSerializationError::from_message( + "Vector serializer doesn't support this type", + ), + )) + } + + fn serialize_bytes(self, _v: &[u8]) -> Result { + Err(FirestoreError::SerializeError( + FirestoreSerializationError::from_message( + "Vector serializer doesn't support this type", + ), + )) + } + + fn serialize_none(self) -> Result { + Ok(FirestoreValue::from( + gcloud_sdk::google::firestore::v1::Value { value_type: None }, + )) + } + + fn serialize_some(self, value: &T) -> Result + where + T: Serialize, + { + value.serialize(self) + } + + fn serialize_unit(self) -> Result { + Ok(FirestoreValue::from( + gcloud_sdk::google::firestore::v1::Value { value_type: None }, + )) + } + + fn serialize_unit_struct(self, _name: &'static str) -> Result { + self.serialize_unit() + } + + fn serialize_unit_variant( + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + ) -> Result { + self.serialize_str(variant) + } + + fn serialize_newtype_struct( + self, + _name: &'static str, + value: &T, + ) -> Result + where + T: Serialize, + { + value.serialize(self) + } + + fn serialize_newtype_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _value: &T, + ) -> Result + where + T: Serialize, + { + Err(FirestoreError::SerializeError( + FirestoreSerializationError::from_message( + "Vector serializer doesn't support this type: newtype_variant", + ), + )) + } + + fn serialize_seq(self, len: Option) -> Result { + Ok(SerializeVec { + none_as_null: false, + vec: Vec::with_capacity(len.unwrap_or(0)), + }) + } + + fn serialize_tuple(self, _len: usize) -> Result { + Err(FirestoreError::SerializeError( + FirestoreSerializationError::from_message( + "Vector serializer doesn't support this type: tuple", + ), + )) + } + + fn serialize_tuple_struct( + self, + _name: &'static str, + _len: usize, + ) -> Result { + Err(FirestoreError::SerializeError( + FirestoreSerializationError::from_message( + "Vector serializer doesn't support this type: tuple_struct", + ), + )) + } + + fn serialize_tuple_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + Err(FirestoreError::SerializeError( + FirestoreSerializationError::from_message( + "Vector serializer doesn't support this type: tuple_variant", + ), + )) + } + + fn serialize_map(self, _len: Option) -> Result { + Err(FirestoreError::SerializeError( + FirestoreSerializationError::from_message( + "Vector serializer doesn't support this type: map", + ), + )) + } + + fn serialize_struct( + self, + _name: &'static str, + _len: usize, + ) -> Result { + Err(FirestoreError::SerializeError( + FirestoreSerializationError::from_message( + "Vector serializer doesn't support this type: struct", + ), + )) + } + + fn serialize_struct_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + Err(FirestoreError::SerializeError( + FirestoreSerializationError::from_message( + "Vector serializer doesn't support this type: struct_variant", + ), + )) + } + } + + let value_with_array = value.serialize(VectorSerializer {})?; + + Ok(FirestoreValue::from( + gcloud_sdk::google::firestore::v1::Value { + value_type: Some(gcloud_sdk::google::firestore::v1::value::ValueType::MapValue( + gcloud_sdk::google::firestore::v1::MapValue { + fields: vec![ + ( + "__type__".to_string(), + gcloud_sdk::google::firestore::v1::Value { + value_type: Some(gcloud_sdk::google::firestore::v1::value::ValueType::StringValue( + "__vector__".to_string() + )), + } + ), + ( + "value".to_string(), + value_with_array.value + )].into_iter().collect() + } + )) + }), + ) +} + +struct FirestoreVectorVisitor; + +impl<'de> Visitor<'de> for FirestoreVectorVisitor { + type Value = FirestoreVector; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a FirestoreVector") + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: serde::de::SeqAccess<'de>, + { + let mut vec = Vec::new(); + + while let Some(value) = seq.next_element()? { + vec.push(value); + } + + Ok(FirestoreVector(vec)) + } + + fn visit_map(self, mut map: A) -> Result + where + A: MapAccess<'de>, + { + while let Some(field) = map.next_key::()? { + match field.as_str() { + "__type__" => { + let value = map.next_value::()?; + if value != "__vector__" { + return Err(serde::de::Error::custom( + "Expected __vector__ for FirestoreVector", + )); + } + } + "value" => { + let value = map.next_value::>()?; + return Ok(FirestoreVector(value)); + } + _ => { + return Err(serde::de::Error::custom( + "Unknown field for FirestoreVector", + )); + } + } + } + Err(serde::de::Error::custom( + "Unknown structure for FirestoreVector", + )) + } +} + +impl<'de> serde::Deserialize<'de> for FirestoreVector { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_any(FirestoreVectorVisitor) + } +}