Skip to content

Commit

Permalink
fix: add scalar substrait extension
Browse files Browse the repository at this point in the history
  • Loading branch information
Taylor-lagrange committed Apr 17, 2024
1 parent 3787cd3 commit 56208ca
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 7 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ etcd-client = "0.12"
fst = "0.4.7"
futures = "0.3"
futures-util = "0.3"
greptime-proto = { git = "https://github.com/GreptimeTeam/greptime-proto.git", rev = "04d78b6e025ceb518040fdd10858c2a9d9345820" }
greptime-proto = { git = "https://github.com/GreptimeTeam/greptime-proto.git", rev = "e04700efd3b16c2f7c1e2551bb7f63cdda0912df" }
humantime = "2.1"
humantime-serde = "1.1"
itertools = "0.10"
Expand Down
13 changes: 12 additions & 1 deletion src/common/substrait/src/extension_serializer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use datafusion::execution::registry::SerializerRegistry;
use datafusion_common::DataFusionError;
use datafusion_expr::UserDefinedLogicalNode;
use promql::extension_plan::{
EmptyMetric, InstantManipulate, RangeManipulate, SeriesDivide, SeriesNormalize,
EmptyMetric, InstantManipulate, RangeManipulate, ScalarCalculate, SeriesDivide, SeriesNormalize,
};

pub struct ExtensionSerializer;
Expand Down Expand Up @@ -50,6 +50,13 @@ impl SerializerRegistry for ExtensionSerializer {
.expect("Failed to downcast to RangeManipulate");
Ok(range_manipulate.serialize())
}
name if name == ScalarCalculate::name() => {
let scalar_calculate = node
.as_any()
.downcast_ref::<ScalarCalculate>()
.expect("Failed to downcast to ScalarCalculate");
Ok(scalar_calculate.serialize())
}
name if name == SeriesDivide::name() => {
let series_divide = node
.as_any()
Expand Down Expand Up @@ -92,6 +99,10 @@ impl SerializerRegistry for ExtensionSerializer {
let series_divide = SeriesDivide::deserialize(bytes)?;
Ok(Arc::new(series_divide))
}
name if name == ScalarCalculate::name() => {
let scalar_calculate = ScalarCalculate::deserialize(bytes)?;
Ok(Arc::new(scalar_calculate))
}
name if name == EmptyMetric::name() => Err(DataFusionError::Substrait(
"EmptyMetric should not be deserialized".to_string(),
)),
Expand Down
57 changes: 53 additions & 4 deletions src/promql/src/extension_plan/scalar_calculate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@
// limitations under the License.

use std::any::Any;
use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use datafusion::common::{DFField, DFSchema, DFSchemaRef, Result as DataFusionResult, Statistics};
use datafusion::error::DataFusionError;
use datafusion::execution::context::TaskContext;
use datafusion::logical_expr::{LogicalPlan, UserDefinedLogicalNodeCore};
use datafusion::logical_expr::{EmptyRelation, LogicalPlan, UserDefinedLogicalNodeCore};
use datafusion::physical_expr::PhysicalSortExpr;
use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
use datafusion::physical_plan::{
Expand All @@ -30,13 +31,15 @@ use datafusion::physical_plan::{
use datafusion::prelude::Expr;
use datatypes::arrow::array::{Array, Float64Array, StringArray, TimestampMillisecondArray};
use datatypes::arrow::compute::{cast_with_options, concat_batches, CastOptions};
use datatypes::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use datatypes::arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit};
use datatypes::arrow::record_batch::RecordBatch;
use futures::{ready, Stream, StreamExt};
use greptime_proto::substrait_extension as pb;
use prost::Message;
use snafu::ResultExt;

use super::Millisecond;
use crate::error::{ColumnNotFoundSnafu, DataFusionPlanningSnafu, Result};
use crate::error::{ColumnNotFoundSnafu, DataFusionPlanningSnafu, DeserializeSnafu, Result};

/// `ScalarCalculate` is the custom logical plan to calculate
/// [`scalar`](https://prometheus.io/docs/prometheus/latest/querying/functions/#scalar)
Expand Down Expand Up @@ -93,7 +96,8 @@ impl ScalarCalculate {
})
}

const fn name() -> &'static str {
/// The name of this custom plan
pub const fn name() -> &'static str {
"ScalarCalculate"
}

Expand Down Expand Up @@ -127,6 +131,51 @@ impl ScalarCalculate {
metric: ExecutionPlanMetricsSet::new(),
}))
}

/// Encode this plan's parameters into protobuf bytes so the plan can be
/// shipped through the substrait extension mechanism.
///
/// Only the scalar parameters are encoded; the logical input plan is not
/// part of the wire format (it is restored as a placeholder on decode).
pub fn serialize(&self) -> Vec<u8> {
    let proto = pb::ScalarCalculate {
        start: self.start,
        end: self.end,
        interval: self.interval,
        time_index: self.time_index.clone(),
        tag_columns: self.tag_columns.clone(),
        field_column: self.field_column.clone(),
    };
    proto.encode_to_vec()
}

pub fn deserialize(bytes: &[u8]) -> Result<Self> {
let pb_scalar_caculate = pb::ScalarCalculate::decode(bytes).context(DeserializeSnafu)?;

Check warning on line 148 in src/promql/src/extension_plan/scalar_calculate.rs

View workflow job for this annotation

GitHub Actions / Check typos and docs

"caculate" should be "calculate".
let placeholder_plan = LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: false,
schema: Arc::new(DFSchema::empty()),
});
// TODO(Taylor-lagrange): Supports timestamps of different precisions
let ts_field = DFField::new_unqualified(
&pb_scalar_caculate.time_index,

Check warning on line 155 in src/promql/src/extension_plan/scalar_calculate.rs

View workflow job for this annotation

GitHub Actions / Check typos and docs

"caculate" should be "calculate".
DataType::Timestamp(TimeUnit::Millisecond, None),
true,
);
let val_field = DFField::new_unqualified(
&format!("scalar({})", pb_scalar_caculate.field_column),

Check warning on line 160 in src/promql/src/extension_plan/scalar_calculate.rs

View workflow job for this annotation

GitHub Actions / Check typos and docs

"caculate" should be "calculate".
DataType::Float64,
true,
);
let schema =
DFSchema::new_with_metadata(vec![ts_field.clone(), val_field.clone()], HashMap::new())
.context(DataFusionPlanningSnafu)?;

Ok(Self {
start: pb_scalar_caculate.start,

Check warning on line 169 in src/promql/src/extension_plan/scalar_calculate.rs

View workflow job for this annotation

GitHub Actions / Check typos and docs

"caculate" should be "calculate".
end: pb_scalar_caculate.end,

Check warning on line 170 in src/promql/src/extension_plan/scalar_calculate.rs

View workflow job for this annotation

GitHub Actions / Check typos and docs

"caculate" should be "calculate".
interval: pb_scalar_caculate.interval,

Check warning on line 171 in src/promql/src/extension_plan/scalar_calculate.rs

View workflow job for this annotation

GitHub Actions / Check typos and docs

"caculate" should be "calculate".
time_index: pb_scalar_caculate.time_index,

Check warning on line 172 in src/promql/src/extension_plan/scalar_calculate.rs

View workflow job for this annotation

GitHub Actions / Check typos and docs

"caculate" should be "calculate".
tag_columns: pb_scalar_caculate.tag_columns,

Check warning on line 173 in src/promql/src/extension_plan/scalar_calculate.rs

View workflow job for this annotation

GitHub Actions / Check typos and docs

"caculate" should be "calculate".
field_column: pb_scalar_caculate.field_column,

Check warning on line 174 in src/promql/src/extension_plan/scalar_calculate.rs

View workflow job for this annotation

GitHub Actions / Check typos and docs

"caculate" should be "calculate".
output_schema: Arc::new(schema),
input: placeholder_plan,
})
}
}

impl UserDefinedLogicalNodeCore for ScalarCalculate {
Expand Down

0 comments on commit 56208ca

Please sign in to comment.