Skip to content

Commit

Permalink
refactor: extract implicit conversion helper functions of vector
Browse files Browse the repository at this point in the history
Signed-off-by: Zhenchi <[email protected]>
  • Loading branch information
zhongzc committed Dec 9, 2024
1 parent 19373d8 commit e92af70
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 124 deletions.
1 change: 1 addition & 0 deletions src/common/function/src/scalars/vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

mod convert;
mod distance;
pub(crate) mod impl_conv;

use std::sync::Arc;

Expand Down
132 changes: 8 additions & 124 deletions src/common/function/src/scalars/vector/distance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,17 @@ mod l2sq;

use std::borrow::Cow;
use std::fmt::Display;
use std::sync::Arc;

use common_query::error::{InvalidFuncArgsSnafu, Result};
use common_query::prelude::Signature;
use datatypes::prelude::ConcreteDataType;
use datatypes::scalars::ScalarVectorBuilder;
use datatypes::value::ValueRef;
use datatypes::vectors::{Float32VectorBuilder, MutableVector, Vector, VectorRef};
use datatypes::vectors::{Float32VectorBuilder, MutableVector, VectorRef};
use snafu::ensure;

use crate::function::{Function, FunctionContext};
use crate::helper;
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const};

macro_rules! define_distance_function {
($StructName:ident, $display_name:expr, $similarity_method:path) => {
Expand Down Expand Up @@ -80,17 +79,17 @@ macro_rules! define_distance_function {
return Ok(result.to_vector());
}

let arg0_const = parse_if_constant_string(arg0)?;
let arg1_const = parse_if_constant_string(arg1)?;
let arg0_const = as_veclit_if_const(arg0)?;
let arg1_const = as_veclit_if_const(arg1)?;

for i in 0..size {
let vec0 = match arg0_const.as_ref() {
Some(a) => Some(Cow::Borrowed(a.as_slice())),
None => as_vector(arg0.get_ref(i))?,
Some(a) => Some(Cow::Borrowed(a.as_ref())),
None => as_veclit(arg0.get_ref(i))?,
};
let vec1 = match arg1_const.as_ref() {
Some(b) => Some(Cow::Borrowed(b.as_slice())),
None => as_vector(arg1.get_ref(i))?,
Some(b) => Some(Cow::Borrowed(b.as_ref())),
None => as_veclit(arg1.get_ref(i))?,
};

if let (Some(vec0), Some(vec1)) = (vec0, vec1) {
Expand Down Expand Up @@ -129,98 +128,6 @@ define_distance_function!(CosDistanceFunction, "vec_cos_distance", cos::cos);
define_distance_function!(L2SqDistanceFunction, "vec_l2sq_distance", l2sq::l2sq);
define_distance_function!(DotProductFunction, "vec_dot_product", dot::dot);

/// Parse a vector value if the value is a constant string.
fn parse_if_constant_string(arg: &Arc<dyn Vector>) -> Result<Option<Vec<f32>>> {
if !arg.is_const() {
return Ok(None);
}
if arg.data_type() != ConcreteDataType::string_datatype() {
return Ok(None);
}
arg.get_ref(0)
.as_string()
.unwrap() // Safe: checked if it is a string
.map(parse_f32_vector_from_string)
.transpose()
}

/// Convert a value to a vector value.
/// Supported data types are binary and string.
fn as_vector(arg: ValueRef<'_>) -> Result<Option<Cow<'_, [f32]>>> {
match arg.data_type() {
ConcreteDataType::Binary(_) => arg
.as_binary()
.unwrap() // Safe: checked if it is a binary
.map(binary_as_vector)
.transpose(),
ConcreteDataType::String(_) => arg
.as_string()
.unwrap() // Safe: checked if it is a string
.map(|s| Ok(Cow::Owned(parse_f32_vector_from_string(s)?)))
.transpose(),
ConcreteDataType::Null(_) => Ok(None),
_ => InvalidFuncArgsSnafu {
err_msg: format!("Unsupported data type: {:?}", arg.data_type()),
}
.fail(),
}
}

/// Convert a u8 slice to a vector value.
fn binary_as_vector(bytes: &[u8]) -> Result<Cow<'_, [f32]>> {
if bytes.len() % std::mem::size_of::<f32>() != 0 {
return InvalidFuncArgsSnafu {
err_msg: format!("Invalid binary length of vector: {}", bytes.len()),
}
.fail();
}

if cfg!(target_endian = "little") {
Ok(unsafe {
let vec = std::slice::from_raw_parts(
bytes.as_ptr() as *const f32,
bytes.len() / std::mem::size_of::<f32>(),
);
Cow::Borrowed(vec)
})
} else {
let v = bytes
.chunks_exact(std::mem::size_of::<f32>())
.map(|chunk| f32::from_le_bytes(chunk.try_into().unwrap()))
.collect::<Vec<f32>>();
Ok(Cow::Owned(v))
}
}

/// Parse a string to a vector value.
/// Valid inputs are strings like "[1.0, 2.0, 3.0]".
fn parse_f32_vector_from_string(s: &str) -> Result<Vec<f32>> {
let trimmed = s.trim();
if !trimmed.starts_with('[') || !trimmed.ends_with(']') {
return InvalidFuncArgsSnafu {
err_msg: format!(
"Failed to parse {s} to Vector value: not properly enclosed in brackets"
),
}
.fail();
}
let content = trimmed[1..trimmed.len() - 1].trim();
if content.is_empty() {
return Ok(Vec::new());
}

content
.split(',')
.map(|s| s.trim().parse::<f32>())
.collect::<std::result::Result<_, _>>()
.map_err(|e| {
InvalidFuncArgsSnafu {
err_msg: format!("Failed to parse {s} to Vector value: {e}"),
}
.build()
})
}

#[cfg(test)]
mod tests {
use std::sync::Arc;
Expand Down Expand Up @@ -456,27 +363,4 @@ mod tests {
assert!(result.is_err());
}
}

#[test]
fn test_parse_vector_from_string() {
let result = parse_f32_vector_from_string("[1.0, 2.0, 3.0]").unwrap();
assert_eq!(result, vec![1.0, 2.0, 3.0]);

let result = parse_f32_vector_from_string("[]").unwrap();
assert_eq!(result, Vec::<f32>::new());

let result = parse_f32_vector_from_string("[1.0, a, 3.0]");
assert!(result.is_err());
}

#[test]
fn test_binary_as_vector() {
let bytes = [0, 0, 128, 63];
let result = binary_as_vector(&bytes).unwrap();
assert_eq!(result.as_ref(), &[1.0]);

let invalid_bytes = [0, 0, 128];
let result = binary_as_vector(&invalid_bytes);
assert!(result.is_err());
}
}
156 changes: 156 additions & 0 deletions src/common/function/src/scalars/vector/impl_conv.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::borrow::Cow;
use std::sync::Arc;

use common_query::error::{InvalidFuncArgsSnafu, Result};
use datatypes::prelude::ConcreteDataType;
use datatypes::value::ValueRef;
use datatypes::vectors::Vector;

/// Convert a constant string or binary literal to a vector literal.
pub fn as_veclit_if_const(arg: &Arc<dyn Vector>) -> Result<Option<Cow<'_, [f32]>>> {
if !arg.is_const() {
return Ok(None);
}
if arg.data_type() != ConcreteDataType::string_datatype()
&& arg.data_type() != ConcreteDataType::binary_datatype()
{
return Ok(None);
}
as_veclit(arg.get_ref(0))
}

/// Convert a string or binary literal to a vector literal.
pub fn as_veclit(arg: ValueRef<'_>) -> Result<Option<Cow<'_, [f32]>>> {
match arg.data_type() {
ConcreteDataType::Binary(_) => arg
.as_binary()
.unwrap() // Safe: checked if it is a binary
.map(binlit_as_veclit)
.transpose(),
ConcreteDataType::String(_) => arg
.as_string()
.unwrap() // Safe: checked if it is a string
.map(|s| Ok(Cow::Owned(parse_veclit_from_strlit(s)?)))
.transpose(),
ConcreteDataType::Null(_) => Ok(None),
_ => InvalidFuncArgsSnafu {
err_msg: format!("Unsupported data type: {:?}", arg.data_type()),
}
.fail(),
}
}

/// Convert a u8 slice to a vector literal.
pub fn binlit_as_veclit(bytes: &[u8]) -> Result<Cow<'_, [f32]>> {
if bytes.len() % std::mem::size_of::<f32>() != 0 {
return InvalidFuncArgsSnafu {
err_msg: format!("Invalid binary length of vector: {}", bytes.len()),
}
.fail();
}

if cfg!(target_endian = "little") {
Ok(unsafe {
let vec = std::slice::from_raw_parts(
bytes.as_ptr() as *const f32,
bytes.len() / std::mem::size_of::<f32>(),
);
Cow::Borrowed(vec)
})
} else {
let v = bytes
.chunks_exact(std::mem::size_of::<f32>())
.map(|chunk| f32::from_le_bytes(chunk.try_into().unwrap()))
.collect::<Vec<f32>>();
Ok(Cow::Owned(v))
}
}

/// Parse a string literal to a vector literal.
/// Valid inputs are strings like "[1.0, 2.0, 3.0]".
pub fn parse_veclit_from_strlit(s: &str) -> Result<Vec<f32>> {
let trimmed = s.trim();
if !trimmed.starts_with('[') || !trimmed.ends_with(']') {
return InvalidFuncArgsSnafu {
err_msg: format!(
"Failed to parse {s} to Vector value: not properly enclosed in brackets"
),
}
.fail();
}
let content = trimmed[1..trimmed.len() - 1].trim();
if content.is_empty() {
return Ok(Vec::new());
}

content
.split(',')
.map(|s| s.trim().parse::<f32>())
.collect::<std::result::Result<_, _>>()
.map_err(|e| {
InvalidFuncArgsSnafu {
err_msg: format!("Failed to parse {s} to Vector value: {e}"),
}
.build()
})
}

#[allow(unused)]
/// Convert a vector literal to a binary literal.
pub fn veclit_to_binlit(vec: &[f32]) -> Vec<u8> {
if cfg!(target_endian = "little") {
unsafe {
std::slice::from_raw_parts(vec.as_ptr() as *const u8, std::mem::size_of_val(vec))
.to_vec()
}
} else {
let mut bytes = Vec::with_capacity(std::mem::size_of_val(vec));
for e in vec {
bytes.extend_from_slice(&e.to_le_bytes());
}
bytes
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_parse_veclit_from_strlit() {
let result = parse_veclit_from_strlit("[1.0, 2.0, 3.0]").unwrap();
assert_eq!(result, vec![1.0, 2.0, 3.0]);

let result = parse_veclit_from_strlit("[]").unwrap();
assert_eq!(result, Vec::<f32>::new());

let result = parse_veclit_from_strlit("[1.0, a, 3.0]");
assert!(result.is_err());
}

#[test]
fn test_binlit_as_veclit() {
let vec = &[1.0, 2.0, 3.0];
let bytes = veclit_to_binlit(vec);
let result = binlit_as_veclit(&bytes).unwrap();
assert_eq!(result.as_ref(), vec);

let invalid_bytes = [0, 0, 128];
let result = binlit_as_veclit(&invalid_bytes);
assert!(result.is_err());
}
}

0 comments on commit e92af70

Please sign in to comment.