From eec22c9dd311763c97a43607c4b98cd735af5583 Mon Sep 17 00:00:00 2001 From: Pascal Seitz Date: Mon, 28 Oct 2024 08:51:08 +0800 Subject: [PATCH] doc mapper: convert number handling to deserialization change number deserialization in docmapper from json to generic deserialization. This improves code reuse between different code paths, e.g. serialization and validation. --- quickwit/Cargo.lock | 4 +- quickwit/Cargo.toml | 2 +- .../src/doc_mapper/deser_num.rs | 226 ++++++++++++++++++ .../src/doc_mapper/doc_mapper_impl.rs | 2 +- .../src/doc_mapper/mapping_tree.rs | 140 ++--------- .../quickwit-doc-mapper/src/doc_mapper/mod.rs | 1 + 6 files changed, 252 insertions(+), 123 deletions(-) create mode 100644 quickwit/quickwit-doc-mapper/src/doc_mapper/deser_num.rs diff --git a/quickwit/Cargo.lock b/quickwit/Cargo.lock index d39e91442fd..7dbfaf2ee82 100644 --- a/quickwit/Cargo.lock +++ b/quickwit/Cargo.lock @@ -7541,9 +7541,9 @@ dependencies = [ [[package]] name = "serde_json_borrow" -version = "0.5.1" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a60291362be3646d15fb0b5a5bddfd8003ebf013b2186a3c60a534fd35d6a26" +checksum = "44c8dc27b181f9294b9cd937ae4375414cd0a77f542a34e063ced1e47ed2ceaa" dependencies = [ "serde", "serde_json", diff --git a/quickwit/Cargo.toml b/quickwit/Cargo.toml index 82419148e2f..b71261d9933 100644 --- a/quickwit/Cargo.toml +++ b/quickwit/Cargo.toml @@ -212,7 +212,7 @@ sea-query-binder = { version = "0.5", features = [ # ^1.0.184 due to serde-rs/serde#2538 serde = { version = "1.0.184", features = ["derive", "rc"] } serde_json = "1.0" -serde_json_borrow = "0.5" +serde_json_borrow = "0.7" serde_qs = { version = "0.12", features = ["warp"] } serde_with = "3.9.0" serde_yaml = "0.9" diff --git a/quickwit/quickwit-doc-mapper/src/doc_mapper/deser_num.rs b/quickwit/quickwit-doc-mapper/src/doc_mapper/deser_num.rs new file mode 100644 index 00000000000..9be92d5fd94 --- /dev/null +++ 
b/quickwit/quickwit-doc-mapper/src/doc_mapper/deser_num.rs
@@ -0,0 +1,259 @@
+// Copyright (C) 2024 Quickwit, Inc.
+//
+// Quickwit is offered under the AGPL v3.0 and as commercial software.
+// For commercial licensing, contact us at hello@quickwit.io.
+//
+// AGPL:
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as
+// published by the Free Software Foundation, either version 3 of the
+// License, or (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+use std::fmt::{self, Display};
+
+use serde::de::{self, Deserializer, IntoDeserializer, Visitor};
+use serde::Deserialize;
+use serde_json::Value;
+
+/// Deserializes a number of type `T`.
+///
+/// Numbers are converted through `T`'s own `Deserialize` implementation. If `coerce` is `true`,
+/// a string is also accepted and parsed with `T::from_str`; otherwise a string is rejected with
+/// a hint to enable coercion in the field mapping.
+///
+/// Objects, arrays, booleans and `null` are always rejected. Errors raised by the deserializer
+/// are returned as their `String` representation.
+fn deserialize_num_with_coerce<'de, T, D>(deserializer: D, coerce: bool) -> Result<T, String>
+where
+    T: std::str::FromStr + Deserialize<'de>,
+    T::Err: fmt::Display,
+    D: Deserializer<'de>,
+{
+    struct CoerceVisitor<T> {
+        coerce: bool,
+        marker: std::marker::PhantomData<T>,
+    }
+
+    impl<'de, T> Visitor<'de> for CoerceVisitor<T>
+    where
+        T: std::str::FromStr + Deserialize<'de>,
+        T::Err: fmt::Display,
+    {
+        type Value = T;
+
+        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
+            if self.coerce {
+                formatter
+                    .write_str("any number of i64, u64, or f64 or a string that can be coerced")
+            } else {
+                formatter.write_str("any number of i64, u64, or f64")
+            }
+        }
+
+        fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
+        where E: de::Error {
+            if self.coerce {
+                v.parse::<T>().map_err(|_e| {
+                    de::Error::custom(format!(
+                        "failed to coerce JSON string `\"{}\"` to {}",
+                        v,
+                        std::any::type_name::<T>(),
+                    ))
+                })
+            } else {
+                Err(de::Error::custom(format!(
+                    "expected JSON number, got string `\"{}\"`. enable coercion to {} with the \
+                     `coerce` parameter in the field mapping",
+                    v,
+                    std::any::type_name::<T>()
+                )))
+            }
+        }
+
+        // The numeric visitors delegate to `T::deserialize`, so a number that `T` cannot
+        // represent (e.g. a negative value for `u64`) is reported as inconvertible.
+        fn visit_i64<E>(self, v: i64) -> Result<T, E>
+        where E: de::Error {
+            T::deserialize(v.into_deserializer()).map_err(|_: E| {
+                de::Error::custom(format!(
+                    "expected {}, got inconvertible JSON number `{}`",
+                    std::any::type_name::<T>(),
+                    v
+                ))
+            })
+        }
+
+        fn visit_u64<E>(self, v: u64) -> Result<T, E>
+        where E: de::Error {
+            T::deserialize(v.into_deserializer()).map_err(|_: E| {
+                de::Error::custom(format!(
+                    "expected {}, got inconvertible JSON number `{}`",
+                    std::any::type_name::<T>(),
+                    v
+                ))
+            })
+        }
+
+        fn visit_f64<E>(self, v: f64) -> Result<T, E>
+        where E: de::Error {
+            T::deserialize(v.into_deserializer()).map_err(|_: E| {
+                de::Error::custom(format!(
+                    "expected {}, got inconvertible JSON number `{}`",
+                    std::any::type_name::<T>(),
+                    v
+                ))
+            })
+        }
+
+        // Objects and arrays are never numbers: they are materialized as a `serde_json::Value`
+        // only so that the offending input can be printed in the error message.
+        fn visit_map<M>(self, mut map: M) -> Result<T, M::Error>
+        where M: de::MapAccess<'de> {
+            let json_value: Value =
+                Deserialize::deserialize(de::value::MapAccessDeserializer::new(&mut map))?;
+            Err(de::Error::custom(error_message(json_value, self.coerce)))
+        }
+
+        fn visit_seq<S>(self, mut seq: S) -> Result<T, S::Error>
+        where S: de::SeqAccess<'de> {
+            let json_value: Value =
+                Deserialize::deserialize(de::value::SeqAccessDeserializer::new(&mut seq))?;
+            Err(de::Error::custom(error_message(json_value, self.coerce)))
+        }
+
+        // `serde_json` reports a JSON `null` through `visit_unit`, not `visit_none`.
+        fn visit_unit<E>(self) -> Result<T, E>
+        where E: de::Error {
+            Err(de::Error::custom(error_message("null", self.coerce)))
+        }
+
+        fn visit_none<E>(self) -> Result<T, E>
+        where E: de::Error {
+            Err(de::Error::custom(error_message("null", self.coerce)))
+        }
+
+        fn visit_bool<E>(self, v: bool) -> Result<T, E>
+        where E: de::Error {
+            Err(de::Error::custom(error_message(v, self.coerce)))
+        }
+    }
+
+    deserializer
+        .deserialize_any(CoerceVisitor {
+            coerce,
+            marker: std::marker::PhantomData,
+        })
+        .map_err(|err| err.to_string())
+}
+
+/// Builds the error message for a JSON value that is neither a number nor a string.
+fn error_message<T: Display>(got: T, coerce: bool) -> String {
+    if coerce {
+        format!("expected JSON number or string, got `{}`", got)
+    } else {
+        format!("expected JSON number, got `{}`", got)
+    }
+}
+
+/// Deserializes an `i64`, accepting numeric strings when `coerce` is `true`.
+pub fn deserialize_i64<'de, D>(deserializer: D, coerce: bool) -> Result<i64, String>
+where D: Deserializer<'de> {
+    deserialize_num_with_coerce(deserializer, coerce)
+}
+
+/// Deserializes a `u64`, accepting numeric strings when `coerce` is `true`.
+pub fn deserialize_u64<'de, D>(deserializer: D, coerce: bool) -> Result<u64, String>
+where D: Deserializer<'de> {
+    deserialize_num_with_coerce(deserializer, coerce)
+}
+
+/// Deserializes an `f64`, accepting numeric strings when `coerce` is `true`.
+pub fn deserialize_f64<'de, D>(deserializer: D, coerce: bool) -> Result<f64, String>
+where D: Deserializer<'de> {
+    deserialize_num_with_coerce(deserializer, coerce)
+}
+
+#[cfg(test)]
+mod tests {
+    use serde_json::json;
+
+    use super::*;
+
+    #[test]
+    fn test_deserialize_i64_with_coercion() {
+        let json_data = json!("-123");
+        let result: i64 = deserialize_i64(json_data.into_deserializer(), true).unwrap();
+        assert_eq!(result, -123);
+
+        let json_data = json!("456");
+        let result: i64 = deserialize_i64(json_data.into_deserializer(), true).unwrap();
+        assert_eq!(result, 456);
+    }
+
+    #[test]
+    fn test_deserialize_u64_with_coercion() {
+        let json_data = json!("789");
+        let result: u64 = deserialize_u64(json_data.into_deserializer(), true).unwrap();
+        assert_eq!(result, 789);
+
+        let json_data = json!(123);
+        let result: u64 = deserialize_u64(json_data.into_deserializer(), false).unwrap();
+        assert_eq!(result, 123);
+    }
+
+    #[test]
+    fn test_deserialize_f64_with_coercion() {
+        let json_data = json!("78.9");
+        let result: f64 = deserialize_f64(json_data.into_deserializer(), true).unwrap();
+        assert_eq!(result, 78.9);
+
+        let json_data = json!(45.6);
+        let result: f64 = deserialize_f64(json_data.into_deserializer(), false).unwrap();
+        assert_eq!(result, 45.6);
+    }
+
+    #[test]
+    fn test_deserialize_invalid_string_coercion() {
+        let json_data = json!("abc");
+        let result: Result<i64, String> = deserialize_i64(json_data.into_deserializer(), true);
+        assert!(result.is_err());
+
+        let err_msg = result.unwrap_err().to_string();
+        assert_eq!(err_msg, "failed to coerce JSON string `\"abc\"` to i64");
+    }
+
+    #[test]
+    fn test_deserialize_json_object() {
+        let json_data = json!({ "key": "value" });
+        let result: Result<i64, String> = deserialize_i64(json_data.into_deserializer(), true);
+        assert!(result.is_err());
+
+        let err_msg = result.unwrap_err().to_string();
+        assert_eq!(
+            err_msg,
+            "expected JSON number or string, got `{\"key\":\"value\"}`"
+        );
+    }
+
+    #[test]
+    fn test_deserialize_json_null() {
+        let json_data = json!(null);
+        let result: Result<i64, String> = deserialize_i64(json_data.into_deserializer(), false);
+        assert_eq!(result.unwrap_err(), "expected JSON number, got `null`");
+    }
+
+    #[test]
+    fn test_deserialize_bool_without_coercion() {
+        let json_data = json!(true);
+        let result: Result<u64, String> = deserialize_u64(json_data.into_deserializer(), false);
+        assert_eq!(result.unwrap_err(), "expected JSON number, got `true`");
+    }
+}
diff --git a/quickwit/quickwit-doc-mapper/src/doc_mapper/doc_mapper_impl.rs b/quickwit/quickwit-doc-mapper/src/doc_mapper/doc_mapper_impl.rs index 42f233013a4..68bceb75a6b 100644 --- a/quickwit/quickwit-doc-mapper/src/doc_mapper/doc_mapper_impl.rs +++ b/quickwit/quickwit-doc-mapper/src/doc_mapper/doc_mapper_impl.rs @@ -1846,7 +1846,7 @@ mod tests { }"#, "concat", r#"{"some_int": 25}"#, - vec![25_u64.into()], + vec![25_i64.into()], ); } diff --git a/quickwit/quickwit-doc-mapper/src/doc_mapper/mapping_tree.rs b/quickwit/quickwit-doc-mapper/src/doc_mapper/mapping_tree.rs index c10ab9699fb..be2a949ab57 100644 --- a/quickwit/quickwit-doc-mapper/src/doc_mapper/mapping_tree.rs +++ 
b/quickwit/quickwit-doc-mapper/src/doc_mapper/mapping_tree.rs @@ -17,7 +17,6 @@ // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . -use std::any::type_name; use std::collections::BTreeMap; use std::net::IpAddr; use std::str::FromStr; @@ -33,6 +32,7 @@ use tantivy::schema::{ use tantivy::TantivyDocument as Document; use super::date_time_type::QuickwitDateTimeOptions; +use super::deser_num::{deserialize_f64, deserialize_i64, deserialize_u64}; use super::field_mapping_entry::QuickwitBoolOptions; use super::tantivy_val_to_json::formatted_tantivy_value_to_json; use crate::doc_mapper::field_mapping_entry::{ @@ -149,13 +149,11 @@ pub(crate) fn map_primitive_json_to_tantivy(value: JsonValue) -> Option None, JsonValue::String(text) => Some(TantivyValue::Str(text)), JsonValue::Bool(val) => Some((val).into()), - JsonValue::Number(number) => { - if let Some(val) = u64::from_json_number(&number) { - Some((val).into()) - } else { - i64::from_json_number(&number).map(|val| (val).into()) - } - } + JsonValue::Number(number) => number + .as_i64() + .map(Into::into) + .or(number.as_u64().map(Into::into)) + .or(number.as_f64().map(Into::into)), } } @@ -170,13 +168,13 @@ impl LeafType { } } LeafType::I64(numeric_options) => { - i64::validate_json(json_val, numeric_options.coerce).map(|_| ()) + deserialize_i64(json_val, numeric_options.coerce).map(|_| ()) } LeafType::U64(numeric_options) => { - u64::validate_json(json_val, numeric_options.coerce).map(|_| ()) + deserialize_u64(json_val, numeric_options.coerce).map(|_| ()) } LeafType::F64(numeric_options) => { - f64::validate_json(json_val, numeric_options.coerce).map(|_| ()) + deserialize_f64(json_val, numeric_options.coerce).map(|_| ()) } LeafType::Bool(_) => { if json_val.is_bool() { @@ -226,9 +224,15 @@ impl LeafType { Err(format!("expected string, got `{json_val}`")) } } - LeafType::I64(numeric_options) => i64::from_json(json_val, numeric_options.coerce), - 
LeafType::U64(numeric_options) => u64::from_json(json_val, numeric_options.coerce), - LeafType::F64(numeric_options) => f64::from_json(json_val, numeric_options.coerce), + LeafType::I64(numeric_options) => { + deserialize_i64(json_val, numeric_options.coerce).map(i64::into) + } + LeafType::U64(numeric_options) => { + deserialize_u64(json_val, numeric_options.coerce).map(u64::into) + } + LeafType::F64(numeric_options) => { + deserialize_f64(json_val, numeric_options.coerce).map(f64::into) + } LeafType::Bool(_) => { if let JsonValue::Bool(val) = json_val { Ok(TantivyValue::Bool(val)) @@ -276,15 +280,15 @@ impl LeafType { } } LeafType::I64(numeric_options) => { - let val = i64::from_json_to_self(&json_val, numeric_options.coerce)?; + let val = deserialize_i64(&json_val, numeric_options.coerce)?; Ok(OneOrIter::one((val).into())) } LeafType::U64(numeric_options) => { - let val = u64::from_json_to_self(&json_val, numeric_options.coerce)?; + let val = deserialize_u64(&json_val, numeric_options.coerce)?; Ok(OneOrIter::one((val).into())) } LeafType::F64(numeric_options) => { - let val = f64::from_json_to_self(&json_val, numeric_options.coerce)?; + let val = deserialize_f64(&json_val, numeric_options.coerce)?; Ok(OneOrIter::one((val).into())) } LeafType::Bool(_) => { @@ -628,108 +632,6 @@ fn insert_json_val( doc_json.insert(last_field_name.to_string(), json_val); } -pub(crate) trait NumVal: Sized + FromStr + ToString + Into { - fn from_json_number(num: &serde_json::Number) -> Option; - - fn validate_json(json_val: &BorrowedJsonValue, coerce: bool) -> Result<(), String> { - match json_val { - BorrowedJsonValue::Number(num_val) => { - let num_val = serde_json::Number::from(*num_val); - Self::from_json_number(&num_val).ok_or_else(|| { - format!( - "expected {}, got inconvertible JSON number `{}`", - type_name::(), - num_val - ) - })?; - Ok(()) - } - BorrowedJsonValue::Str(str_val) => { - if coerce { - str_val.parse::().map_err(|_| { - format!( - "failed to coerce JSON string 
`\"{str_val}\"` to {}", - type_name::() - ) - })?; - Ok(()) - } else { - Err(format!( - "expected JSON number, got string `\"{str_val}\"`. enable coercion to {} \ - with the `coerce` parameter in the field mapping", - type_name::() - )) - } - } - _ => { - let message = if coerce { - format!("expected JSON number or string, got `{json_val}`") - } else { - format!("expected JSON number, got `{json_val}`") - }; - Err(message) - } - } - } - - fn from_json_to_self(json_val: &JsonValue, coerce: bool) -> Result { - match json_val { - JsonValue::Number(num_val) => Self::from_json_number(num_val).ok_or_else(|| { - format!( - "expected {}, got inconvertible JSON number `{}`", - type_name::(), - num_val - ) - }), - JsonValue::String(str_val) => { - if coerce { - str_val.parse::().map_err(|_| { - format!( - "failed to coerce JSON string `\"{str_val}\"` to {}", - type_name::() - ) - }) - } else { - Err(format!( - "expected JSON number, got string `\"{str_val}\"`. enable coercion to {} \ - with the `coerce` parameter in the field mapping", - type_name::() - )) - } - } - _ => { - let message = if coerce { - format!("expected JSON number or string, got `{json_val}`") - } else { - format!("expected JSON number, got `{json_val}`") - }; - Err(message) - } - } - } - - fn from_json(json_val: JsonValue, coerce: bool) -> Result { - Self::from_json_to_self(&json_val, coerce).map(Self::into) - } -} - -impl NumVal for u64 { - fn from_json_number(num: &serde_json::Number) -> Option { - num.as_u64() - } -} - -impl NumVal for i64 { - fn from_json_number(num: &serde_json::Number) -> Option { - num.as_i64() - } -} -impl NumVal for f64 { - fn from_json_number(num: &serde_json::Number) -> Option { - num.as_f64() - } -} - #[derive(Clone, Default)] pub(crate) struct MappingNode { pub branches: fnv::FnvHashMap, diff --git a/quickwit/quickwit-doc-mapper/src/doc_mapper/mod.rs b/quickwit/quickwit-doc-mapper/src/doc_mapper/mod.rs index 146c2f1f51c..a1b6ae7ee10 100644 --- 
a/quickwit/quickwit-doc-mapper/src/doc_mapper/mod.rs +++ b/quickwit/quickwit-doc-mapper/src/doc_mapper/mod.rs @@ -18,6 +18,7 @@ // along with this program. If not, see <http://www.gnu.org/licenses/>. mod date_time_type; +mod deser_num; mod doc_mapper_builder; mod doc_mapper_impl; mod field_mapping_entry;