From 56208cab08ab7682c938e1499d86ace8fd9bb6b0 Mon Sep 17 00:00:00 2001 From: WUJingdi Date: Wed, 17 Apr 2024 11:32:14 +0800 Subject: [PATCH] fix: add scalar substrait extension --- Cargo.lock | 2 +- Cargo.toml | 2 +- .../substrait/src/extension_serializer.rs | 13 ++++- .../src/extension_plan/scalar_calculate.rs | 57 +++++++++++++++++-- 4 files changed, 67 insertions(+), 7 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d4f44c1d8374..c4b25761ef74 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3775,7 +3775,7 @@ checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" [[package]] name = "greptime-proto" version = "0.1.0" -source = "git+https://github.com/GreptimeTeam/greptime-proto.git?rev=04d78b6e025ceb518040fdd10858c2a9d9345820#04d78b6e025ceb518040fdd10858c2a9d9345820" +source = "git+https://github.com/GreptimeTeam/greptime-proto.git?rev=e04700efd3b16c2f7c1e2551bb7f63cdda0912df#e04700efd3b16c2f7c1e2551bb7f63cdda0912df" dependencies = [ "prost 0.12.3", "serde", diff --git a/Cargo.toml b/Cargo.toml index 788bc68798e0..4731e8e5187e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -104,7 +104,7 @@ etcd-client = "0.12" fst = "0.4.7" futures = "0.3" futures-util = "0.3" -greptime-proto = { git = "https://github.com/GreptimeTeam/greptime-proto.git", rev = "04d78b6e025ceb518040fdd10858c2a9d9345820" } +greptime-proto = { git = "https://github.com/GreptimeTeam/greptime-proto.git", rev = "e04700efd3b16c2f7c1e2551bb7f63cdda0912df" } humantime = "2.1" humantime-serde = "1.1" itertools = "0.10" diff --git a/src/common/substrait/src/extension_serializer.rs b/src/common/substrait/src/extension_serializer.rs index 813c525843b4..89944db508f9 100644 --- a/src/common/substrait/src/extension_serializer.rs +++ b/src/common/substrait/src/extension_serializer.rs @@ -19,7 +19,7 @@ use datafusion::execution::registry::SerializerRegistry; use datafusion_common::DataFusionError; use datafusion_expr::UserDefinedLogicalNode; use promql::extension_plan::{ - EmptyMetric, 
InstantManipulate, RangeManipulate, SeriesDivide, SeriesNormalize, + EmptyMetric, InstantManipulate, RangeManipulate, ScalarCalculate, SeriesDivide, SeriesNormalize, }; pub struct ExtensionSerializer; @@ -50,6 +50,13 @@ impl SerializerRegistry for ExtensionSerializer { .expect("Failed to downcast to RangeManipulate"); Ok(range_manipulate.serialize()) } + name if name == ScalarCalculate::name() => { + let scalar_calculate = node + .as_any() + .downcast_ref::<ScalarCalculate>() + .expect("Failed to downcast to ScalarCalculate"); + Ok(scalar_calculate.serialize()) + } name if name == SeriesDivide::name() => { let series_divide = node .as_any() @@ -92,6 +99,10 @@ impl SerializerRegistry for ExtensionSerializer { let series_divide = SeriesDivide::deserialize(bytes)?; Ok(Arc::new(series_divide)) } + name if name == ScalarCalculate::name() => { + let scalar_calculate = ScalarCalculate::deserialize(bytes)?; + Ok(Arc::new(scalar_calculate)) + } name if name == EmptyMetric::name() => Err(DataFusionError::Substrait( "EmptyMetric should not be deserialized".to_string(), )), diff --git a/src/promql/src/extension_plan/scalar_calculate.rs b/src/promql/src/extension_plan/scalar_calculate.rs index b3249d26d9cf..41a0bdd7e504 100644 --- a/src/promql/src/extension_plan/scalar_calculate.rs +++ b/src/promql/src/extension_plan/scalar_calculate.rs @@ -13,6 +13,7 @@ // limitations under the License. 
use std::any::Any; +use std::collections::HashMap; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; @@ -20,7 +21,7 @@ use std::task::{Context, Poll}; use datafusion::common::{DFField, DFSchema, DFSchemaRef, Result as DataFusionResult, Statistics}; use datafusion::error::DataFusionError; use datafusion::execution::context::TaskContext; -use datafusion::logical_expr::{LogicalPlan, UserDefinedLogicalNodeCore}; +use datafusion::logical_expr::{EmptyRelation, LogicalPlan, UserDefinedLogicalNodeCore}; use datafusion::physical_expr::PhysicalSortExpr; use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; use datafusion::physical_plan::{ @@ -30,13 +31,15 @@ use datafusion::physical_plan::{ use datafusion::prelude::Expr; use datatypes::arrow::array::{Array, Float64Array, StringArray, TimestampMillisecondArray}; use datatypes::arrow::compute::{cast_with_options, concat_batches, CastOptions}; -use datatypes::arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use datatypes::arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit}; use datatypes::arrow::record_batch::RecordBatch; use futures::{ready, Stream, StreamExt}; +use greptime_proto::substrait_extension as pb; +use prost::Message; use snafu::ResultExt; use super::Millisecond; -use crate::error::{ColumnNotFoundSnafu, DataFusionPlanningSnafu, Result}; +use crate::error::{ColumnNotFoundSnafu, DataFusionPlanningSnafu, DeserializeSnafu, Result}; /// `ScalarCalculate` is the custom logical plan to calculate /// [`scalar`](https://prometheus.io/docs/prometheus/latest/querying/functions/#scalar) @@ -93,7 +96,8 @@ impl ScalarCalculate { }) } - const fn name() -> &'static str { + /// The name of this custom plan + pub const fn name() -> &'static str { "ScalarCalculate" } @@ -127,6 +131,51 @@ impl ScalarCalculate { metric: ExecutionPlanMetricsSet::new(), })) } + + pub fn serialize(&self) -> Vec<u8> { + pb::ScalarCalculate { + start: self.start, + end: self.end, 
+ interval: self.interval, + time_index: self.time_index.clone(), + tag_columns: self.tag_columns.clone(), + field_column: self.field_column.clone(), + } + .encode_to_vec() + } + + pub fn deserialize(bytes: &[u8]) -> Result<Self> { + let pb_scalar_calculate = pb::ScalarCalculate::decode(bytes).context(DeserializeSnafu)?; + let placeholder_plan = LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: Arc::new(DFSchema::empty()), + }); + // TODO(Taylor-lagrange): Supports timestamps of different precisions + let ts_field = DFField::new_unqualified( + &pb_scalar_calculate.time_index, + DataType::Timestamp(TimeUnit::Millisecond, None), + true, + ); + let val_field = DFField::new_unqualified( + &format!("scalar({})", pb_scalar_calculate.field_column), + DataType::Float64, + true, + ); + let schema = + DFSchema::new_with_metadata(vec![ts_field.clone(), val_field.clone()], HashMap::new()) + .context(DataFusionPlanningSnafu)?; + + Ok(Self { + start: pb_scalar_calculate.start, + end: pb_scalar_calculate.end, + interval: pb_scalar_calculate.interval, + time_index: pb_scalar_calculate.time_index, + tag_columns: pb_scalar_calculate.tag_columns, + field_column: pb_scalar_calculate.field_column, + output_schema: Arc::new(schema), + input: placeholder_plan, + }) + } } impl UserDefinedLogicalNodeCore for ScalarCalculate {