From 3ad02a86c199427501d7cad02a9985377703d2b7 Mon Sep 17 00:00:00 2001
From: Bei Chu <914745487@qq.com>
Date: Fri, 17 Nov 2023 17:29:19 +0800
Subject: [PATCH] feat: Add `JavaScript` UDF (#2221)

* feat: Add `JavaScript` UDF

* refactor: Make `Runtime::call_function` accept `&mut self` to improve
  ergonomics
---
 Cargo.lock                                      |  40 +-
 dozer-cli/src/lib.rs                            |   8 +-
 dozer-cli/src/live/state.rs                     |   3 +-
 dozer-cli/src/pipeline/builder.rs               |  22 +-
 dozer-cli/src/pipeline/connector_source.rs      |   7 +-
 dozer-cli/src/simple/orchestrator.rs            |   5 +-
 dozer-core/src/builder_dag.rs                   |  39 +-
 dozer-core/src/dag_schemas.rs                   |   7 +-
 dozer-core/src/executor/mod.rs                  |   6 +-
 dozer-core/src/node.rs                          |   6 +-
 .../src/tests/dag_base_create_errors.rs         |   6 +-
 dozer-core/src/tests/dag_base_errors.rs         |   6 +-
 dozer-core/src/tests/dag_base_run.rs            |  11 +-
 dozer-core/src/tests/dag_ports.rs               |   6 +-
 dozer-core/src/tests/dag_schemas.rs             |  34 +-
 dozer-core/src/tests/processors.rs              |  12 +-
 dozer-deno/src/runtime/mod.rs                   |  28 +-
 dozer-lambda/src/js/worker/mod.rs               |  19 +-
 dozer-sql/Cargo.toml                            |   1 +
 dozer-sql/expression/Cargo.toml                 |   3 +
 dozer-sql/expression/src/builder.rs             | 431 +++++++++++-------
 dozer-sql/expression/src/error.rs               |   3 +
 dozer-sql/expression/src/execution.rs           |   4 +
 .../expression/src/javascript/evaluate.rs       | 115 +++++
 dozer-sql/expression/src/javascript/mod.rs      |   5 +
 .../expression/src/javascript/validate.rs       |  27 ++
 dozer-sql/expression/src/lib.rs                 |   1 +
 dozer-sql/src/aggregation/factory.rs            |  22 +-
 .../tests/aggregation_test_planner.rs           |   9 +-
 .../tests/aggregation_tests_utils.rs            |   9 +-
 dozer-sql/src/builder.rs                        | 187 ++++----
 dozer-sql/src/expression/tests/execution.rs     |  32 +-
 .../tests/expression_builder_test.rs            | 261 +++++------
 dozer-sql/src/expression/tests/test_common.rs   |  30 +-
 .../src/pipeline_builder/from_builder.rs        |   1 +
 .../src/pipeline_builder/join_builder.rs        |   1 +
 dozer-sql/src/planner/projection.rs             |  49 +-
 .../src/planner/tests/projection_tests.rs       |   9 +-
 dozer-sql/src/planner/tests/schema_tests.rs     |  16 +-
 dozer-sql/src/product/join/factory.rs           |   6 +-
 dozer-sql/src/product/join/processor.rs         |  21 +-
 dozer-sql/src/product/set/set_factory.rs        |   6 +-
 dozer-sql/src/product/table/factory.rs          |   7 +-
 dozer-sql/src/projection/factory.rs             |  73 ++-
 dozer-sql/src/selection/factory.rs              |  29 +-
 dozer-sql/src/table_operator/factory.rs         |  63 ++-
 dozer-sql/src/tests/builder_test.rs             |  27 +-
 dozer-sql/src/tests/utils.rs                    |  12 +
 dozer-sql/src/window/factory.rs                 |   7 +-
 dozer-tests/src/sql_tests/helper/pipeline.rs    |  16 +-
 dozer-tests/src/sql_tests/logic_test.rs         |  43 +-
 dozer-types/src/models/udf_config.rs            |   8 +
 json_schemas/dozer.json                         |  25 +
 53 files changed, 1195 insertions(+), 629 deletions(-)
 create mode 100644 dozer-sql/expression/src/javascript/evaluate.rs
 create mode 100644 dozer-sql/expression/src/javascript/mod.rs
 create mode 100644 dozer-sql/expression/src/javascript/validate.rs

diff --git a/Cargo.lock b/Cargo.lock
index 4f0e55b12f..a12fcd61ee 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -236,7 +236,7 @@ dependencies = [
  "serde_json",
  "serde_urlencoded",
  "smallvec",
- "socket2 0.5.3",
+ "socket2 0.5.5",
  "time 0.3.20",
  "url",
 ]
@@ -829,6 +829,17 @@ dependencies = [
  "zstd-safe",
 ]

+[[package]]
+name = "async-recursion"
+version = "1.0.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5fd55a5ba1179988837d24ab4c7cc8ed6efdeff578ede0416b4225a5fca35bd0"
+dependencies = [
+ "proc-macro2 1.0.63",
+ "quote 1.0.30",
+ "syn 2.0.29",
+]
+
 [[package]]
 name = "async-stream"
 version = "0.3.5"
@@ -3169,7 +3180,7 @@ dependencies = [
  "log",
"pin-project", "serde", - "socket2 0.5.3", + "socket2 0.5.5", "tokio", "trust-dns-proto 0.22.0", "trust-dns-resolver 0.22.0", @@ -3965,8 +3976,10 @@ dependencies = [ name = "dozer-sql-expression" version = "0.3.0" dependencies = [ + "async-recursion", "bigdecimal", "bincode", + "dozer-deno", "dozer-types", "half 2.3.1", "jsonpath", @@ -3976,6 +3989,7 @@ dependencies = [ "ort", "proptest", "sqlparser 0.35.0", + "tokio", ] [[package]] @@ -6485,14 +6499,14 @@ dependencies = [ [[package]] name = "mio" -version = "0.8.6" +version = "0.8.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b9d9a46eff5b4ff64b45a9e316a6d1e0bc719ef429cbec4dc630684212bfdf9" +checksum = "3dce281c5e46beae905d4de1870d8b1509a9142b62eedf18b443b011ca8343d0" dependencies = [ "libc", "log", "wasi 0.11.0+wasi-snapshot-preview1", - "windows-sys 0.45.0", + "windows-sys 0.48.0", ] [[package]] @@ -6600,7 +6614,7 @@ dependencies = [ "rustls-pemfile", "serde", "serde_json", - "socket2 0.5.3", + "socket2 0.5.5", "thiserror", "tokio", "tokio-rustls 0.24.0", @@ -9884,9 +9898,9 @@ dependencies = [ [[package]] name = "socket2" -version = "0.5.3" +version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2538b18701741680e0322a2302176d3253a35388e2e62f172f64f4f16605f877" +checksum = "7b5fac59a5cb5dd637972e5fca70daf0523c9067fcdc4842f053dae04a18f8e9" dependencies = [ "libc", "windows-sys 0.48.0", @@ -10792,9 +10806,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.33.0" +version = "1.34.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f38200e3ef7995e5ef13baec2f432a6da0aa9ac495b2c0e8f3b7eec2c92d653" +checksum = "d0c014766411e834f7af5b8f4cf46257aab4036ca95e9d2c144a10f59ad6f5b9" dependencies = [ "backtrace", "bytes", @@ -10804,7 +10818,7 @@ dependencies = [ "parking_lot", "pin-project-lite", "signal-hook-registry", - "socket2 0.5.3", + "socket2 0.5.5", "tokio-macros", "tracing", "windows-sys 0.48.0", @@ -10822,9 +10836,9 @@ dependencies = [ [[package]] name = "tokio-macros" -version = "2.1.0" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e" +checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" dependencies = [ "proc-macro2 1.0.63", "quote 1.0.30", diff --git a/dozer-cli/src/lib.rs b/dozer-cli/src/lib.rs index eb27d03c77..3b460fff9e 100644 --- a/dozer-cli/src/lib.rs +++ b/dozer-cli/src/lib.rs @@ -4,8 +4,7 @@ pub mod live; pub mod pipeline; pub mod simple; use dozer_api::shutdown::ShutdownSender; -use dozer_core::{app::AppPipeline, errors::ExecutionError}; -use dozer_sql::{builder::statement_to_pipeline, errors::PipelineError}; +use dozer_core::errors::ExecutionError; use dozer_types::log::debug; use errors::OrchestrationError; @@ -27,11 +26,6 @@ pub use dozer_ingestion::{ errors::ConnectorError, {get_connector, TableInfo}, }; -pub use dozer_sql::builder::QueryContext; -pub fn wrapped_statement_to_pipeline(sql: &str) -> Result { - let mut pipeline = AppPipeline::new_with_default_flags(); - statement_to_pipeline(sql, &mut pipeline, None, vec![]) -} pub use dozer_types::models::connection::Connection; use dozer_types::tracing::error; diff --git a/dozer-cli/src/live/state.rs b/dozer-cli/src/live/state.rs index 08bf2cdce5..7b55f1b4a6 100644 --- a/dozer-cli/src/live/state.rs +++ b/dozer-cli/src/live/state.rs @@ -324,7 +324,7 @@ fn 
get_contract(dozer_and_contract: &Option) -> Result<&Contra pub async fn create_contract(dozer: SimpleOrchestrator) -> Result { let dag = create_dag(&dozer).await?; let version = dozer.config.version; - let schemas = DagSchemas::new(dag)?; + let schemas = DagSchemas::new(dag).await?; let contract = Contract::new( version as usize, &schemas, @@ -393,6 +393,7 @@ fn get_dozer_run_instance( &mut AppPipeline::new(dozer.config.flags.clone().into()), None, dozer.config.udfs.clone(), + dozer.runtime.clone(), ) .map_err(LiveError::PipelineError)?; diff --git a/dozer-cli/src/pipeline/builder.rs b/dozer-cli/src/pipeline/builder.rs index 29bdf5350b..a1659e2e58 100644 --- a/dozer-cli/src/pipeline/builder.rs +++ b/dozer-cli/src/pipeline/builder.rs @@ -155,7 +155,10 @@ impl<'a> PipelineBuilder<'a> { // This function is used to figure out the sources that are used in the pipeline // based on the SQL and API Endpoints - pub fn calculate_sources(&self) -> Result { + pub fn calculate_sources( + &self, + runtime: Arc, + ) -> Result { let mut original_sources = vec![]; let mut query_ctx = None; @@ -164,8 +167,9 @@ impl<'a> PipelineBuilder<'a> { let mut transformed_sources = vec![]; if let Some(sql) = &self.sql { - let query_context = statement_to_pipeline(sql, &mut pipeline, None, self.udfs.to_vec()) - .map_err(OrchestrationError::PipelineError)?; + let query_context = + statement_to_pipeline(sql, &mut pipeline, None, self.udfs.to_vec(), runtime) + .map_err(OrchestrationError::PipelineError)?; query_ctx = Some(query_context.clone()); @@ -202,7 +206,7 @@ impl<'a> PipelineBuilder<'a> { runtime: &Arc, shutdown: ShutdownReceiver, ) -> Result { - let calculated_sources = self.calculate_sources()?; + let calculated_sources = self.calculate_sources(runtime.clone())?; debug!("Used Sources: {:?}", calculated_sources.original_sources); let grouped_connections = self @@ -229,8 +233,14 @@ impl<'a> PipelineBuilder<'a> { } if let Some(sql) = &self.sql { - let query_context = statement_to_pipeline(sql, &mut pipeline, None, self.udfs.to_vec()) - .map_err(OrchestrationError::PipelineError)?; + let query_context = statement_to_pipeline( + sql, + &mut pipeline, + None, + self.udfs.to_vec(), + runtime.clone(), + ) + .map_err(OrchestrationError::PipelineError)?; for (name, table_info) in query_context.output_tables_map { available_output_tables diff --git a/dozer-cli/src/pipeline/connector_source.rs b/dozer-cli/src/pipeline/connector_source.rs index 4e095a5dbd..7b2c9d0656 100644 --- a/dozer-cli/src/pipeline/connector_source.rs +++ b/dozer-cli/src/pipeline/connector_source.rs @@ -11,7 +11,7 @@ use dozer_ingestion::{IngestionConfig, Ingestor}; use dozer_tracing::LabelsAndProgress; use dozer_types::errors::internal::BoxedError; use dozer_types::indicatif::ProgressBar; -use dozer_types::log::info; +use dozer_types::log::{error, info}; use dozer_types::models::connection::Connection; use dozer_types::models::ingestion_types::IngestionMessage; use dozer_types::parking_lot::Mutex; @@ -286,7 +286,10 @@ impl Source for ConnectorSource { .await; match result { Ok(Ok(_)) => {} - Ok(Err(e)) => std::panic::panic_any(e), + Ok(Err(e)) => { + error!("{}", e); + std::panic::panic_any(e) + } // Aborted means we are shutting down Err(Aborted) => (), } diff --git a/dozer-cli/src/simple/orchestrator.rs b/dozer-cli/src/simple/orchestrator.rs index dad7dd661e..e6e9014653 100644 --- a/dozer-cli/src/simple/orchestrator.rs +++ b/dozer-cli/src/simple/orchestrator.rs @@ -377,7 +377,7 @@ impl SimpleOrchestrator { .runtime 
.block_on(builder.build(&self.runtime, shutdown))?; // Populate schemas. - let dag_schemas = DagSchemas::new(dag)?; + let dag_schemas = self.runtime.block_on(DagSchemas::new(dag))?; // Get current contract. let enable_token = self.config.api.api_security.is_some(); @@ -470,12 +470,13 @@ impl SimpleOrchestrator { } } -pub fn validate_sql(sql: String) -> Result<(), PipelineError> { +pub fn validate_sql(sql: String, runtime: Arc) -> Result<(), PipelineError> { statement_to_pipeline( &sql, &mut AppPipeline::new_with_default_flags(), None, vec![], + runtime, ) .map_or_else( |e| { diff --git a/dozer-core/src/builder_dag.rs b/dozer-core/src/builder_dag.rs index 0c91bce4f2..0ef79afa33 100644 --- a/dozer-core/src/builder_dag.rs +++ b/dozer-core/src/builder_dag.rs @@ -1,6 +1,9 @@ use std::{collections::HashMap, fmt::Debug}; -use daggy::petgraph::visit::{IntoNodeIdentifiers, IntoNodeReferences}; +use daggy::{ + petgraph::visit::{IntoNodeIdentifiers, IntoNodeReferences}, + NodeIndex, +}; use dozer_types::node::NodeHandle; use crate::{ @@ -62,8 +65,12 @@ impl BuilderDag { } // Build the nodes. - let graph = dag_schemas.into_graph().try_map( - |node_index, node| match node.kind { + let mut graph = daggy::Dag::new(); + let (nodes, edges) = dag_schemas.into_graph().into_graph().into_nodes_edges(); + for (node_index, node) in nodes.into_iter().enumerate() { + let node_index = NodeIndex::new(node_index); + let node = node.weight; + let node = match node.kind { DagNodeKind::Source(source) => { let mut last_checkpoint_by_name = checkpoint.get_source_state(&node.handle)?; let mut last_checkpoint = HashMap::new(); @@ -87,13 +94,13 @@ impl BuilderDag { ) .map_err(ExecutionError::Factory)?; - Ok::<_, ExecutionError>(NodeType { + NodeType { handle: node.handle, kind: NodeKind::Source { source, last_checkpoint, }, - }) + } } DagNodeKind::Processor(processor) => { let processor = processor @@ -109,11 +116,12 @@ impl BuilderDag { .remove(&node_index) .expect("we collected all processor checkpoint data"), ) + .await .map_err(ExecutionError::Factory)?; - Ok(NodeType { + NodeType { handle: node.handle, kind: NodeKind::Processor(processor), - }) + } } DagNodeKind::Sink(sink) => { let sink = sink @@ -123,14 +131,21 @@ impl BuilderDag { .expect("we collected all input schemas"), ) .map_err(ExecutionError::Factory)?; - Ok(NodeType { + NodeType { handle: node.handle, kind: NodeKind::Sink(sink), - }) + } } - }, - |_, edge| Ok(edge), - )?; + }; + graph.add_node(node); + } + + // Connect the edges. + for edge in edges { + graph + .add_edge(edge.source(), edge.target(), edge.weight) + .expect("we know there's no loop"); + } Ok(BuilderDag { graph }) } diff --git a/dozer-core/src/dag_schemas.rs b/dozer-core/src/dag_schemas.rs index 4ba6cda542..0b3f6b82bf 100644 --- a/dozer-core/src/dag_schemas.rs +++ b/dozer-core/src/dag_schemas.rs @@ -89,10 +89,10 @@ impl DagSchemas { impl DagSchemas { /// Validate and populate the schemas, the resultant DAG will have the exact same structure as the input DAG, /// with validated schema information on the edges. - pub fn new(dag: Dag) -> Result { + pub async fn new(dag: Dag) -> Result { validate_connectivity(&dag); - match populate_schemas(dag.into_graph()) { + match populate_schemas(dag.into_graph()).await { Ok(graph) => { info!("[pipeline] Validation completed"); Ok(Self { graph }) @@ -188,7 +188,7 @@ fn validate_connectivity(dag: &Dag) { } /// In topological order, pass output schemas to downstream nodes' input schemas. 
-fn populate_schemas( +async fn populate_schemas( dag: daggy::Dag, ) -> Result, ExecutionError> { let mut edges = vec![None; dag.graph().edge_count()]; @@ -226,6 +226,7 @@ fn populate_schemas( for edge in dag.graph().edges(node_index) { let schema = processor .get_output_schema(&edge.weight().from, &input_schemas) + .await .map_err(ExecutionError::Factory)?; create_edge(&mut edges, edge, EdgeKind::FromProcessor, schema); } diff --git a/dozer-core/src/executor/mod.rs b/dozer-core/src/executor/mod.rs index 5962bfdb62..cd0608fff8 100644 --- a/dozer-core/src/executor/mod.rs +++ b/dozer-core/src/executor/mod.rs @@ -78,7 +78,7 @@ impl DagExecutor { checkpoint: OptionCheckpoint, options: ExecutorOptions, ) -> Result { - let dag_schemas = DagSchemas::new(dag)?; + let dag_schemas = DagSchemas::new(dag).await?; let builder_dag = BuilderDag::new(&checkpoint, dag_schemas).await?; @@ -89,8 +89,8 @@ impl DagExecutor { }) } - pub fn validate(dag: Dag) -> Result<(), ExecutionError> { - DagSchemas::new(dag)?; + pub async fn validate(dag: Dag) -> Result<(), ExecutionError> { + DagSchemas::new(dag).await?; Ok(()) } diff --git a/dozer-core/src/node.rs b/dozer-core/src/node.rs index 5169e9129e..7c9510fd08 100644 --- a/dozer-core/src/node.rs +++ b/dozer-core/src/node.rs @@ -7,6 +7,7 @@ use dozer_log::storage::{Object, Queue}; use dozer_types::errors::internal::BoxedError; use dozer_types::node::OpIdentifier; use dozer_types::serde::{Deserialize, Serialize}; +use dozer_types::tonic::async_trait; use dozer_types::types::Schema; use std::collections::HashMap; use std::fmt::{Debug, Display, Formatter}; @@ -63,15 +64,16 @@ pub trait Source: Send + Sync + Debug { ) -> Result<(), BoxedError>; } +#[async_trait] pub trait ProcessorFactory: Send + Sync + Debug { - fn get_output_schema( + async fn get_output_schema( &self, output_port: &PortHandle, input_schemas: &HashMap, ) -> Result; fn get_input_ports(&self) -> Vec; fn get_output_ports(&self) -> Vec; - fn build( + async fn build( &self, input_schemas: HashMap, output_schemas: HashMap, diff --git a/dozer-core/src/tests/dag_base_create_errors.rs b/dozer-core/src/tests/dag_base_create_errors.rs index 3b33f089e9..8d146f153b 100644 --- a/dozer-core/src/tests/dag_base_create_errors.rs +++ b/dozer-core/src/tests/dag_base_create_errors.rs @@ -13,6 +13,7 @@ use dozer_log::tokio; use dozer_recordstore::ProcessorRecordStoreDeserializer; use dozer_types::errors::internal::BoxedError; use dozer_types::node::NodeHandle; +use dozer_types::tonic::async_trait; use dozer_types::types::{FieldDefinition, FieldType, Schema, SourceDefinition}; use std::collections::HashMap; @@ -169,12 +170,13 @@ impl CreateErrProcessorFactory { } } +#[async_trait] impl ProcessorFactory for CreateErrProcessorFactory { fn type_name(&self) -> String { "CreateErr".to_owned() } - fn get_output_schema( + async fn get_output_schema( &self, _port: &PortHandle, _input_schemas: &HashMap, @@ -200,7 +202,7 @@ impl ProcessorFactory for CreateErrProcessorFactory { vec![DEFAULT_PORT_HANDLE] } - fn build( + async fn build( &self, _input_schemas: HashMap, _output_schemas: HashMap, diff --git a/dozer-core/src/tests/dag_base_errors.rs b/dozer-core/src/tests/dag_base_errors.rs index 6e7ac86ee3..7fae7b9303 100644 --- a/dozer-core/src/tests/dag_base_errors.rs +++ b/dozer-core/src/tests/dag_base_errors.rs @@ -17,6 +17,7 @@ use dozer_recordstore::{ProcessorRecordStore, ProcessorRecordStoreDeserializer}; use dozer_types::errors::internal::BoxedError; use dozer_types::models::ingestion_types::IngestionMessage; use 
dozer_types::node::{NodeHandle, OpIdentifier}; +use dozer_types::tonic::async_trait; use dozer_types::types::{ Field, FieldDefinition, FieldType, Operation, Record, Schema, SourceDefinition, }; @@ -35,12 +36,13 @@ struct ErrorProcessorFactory { panic: bool, } +#[async_trait] impl ProcessorFactory for ErrorProcessorFactory { fn type_name(&self) -> String { "Error".to_owned() } - fn get_output_schema( + async fn get_output_schema( &self, _output_port: &PortHandle, input_schemas: &HashMap, @@ -56,7 +58,7 @@ impl ProcessorFactory for ErrorProcessorFactory { vec![DEFAULT_PORT_HANDLE] } - fn build( + async fn build( &self, _input_schemas: HashMap, _output_schemas: HashMap, diff --git a/dozer-core/src/tests/dag_base_run.rs b/dozer-core/src/tests/dag_base_run.rs index fb3a6ebbfe..e27b070375 100644 --- a/dozer-core/src/tests/dag_base_run.rs +++ b/dozer-core/src/tests/dag_base_run.rs @@ -16,6 +16,7 @@ use dozer_log::tokio; use dozer_recordstore::{ProcessorRecordStore, ProcessorRecordStoreDeserializer}; use dozer_types::errors::internal::BoxedError; use dozer_types::node::NodeHandle; +use dozer_types::tonic::async_trait; use dozer_types::types::Schema; use std::collections::HashMap; @@ -27,12 +28,13 @@ use std::time::Duration; #[derive(Debug)] pub(crate) struct NoopProcessorFactory {} +#[async_trait] impl ProcessorFactory for NoopProcessorFactory { fn type_name(&self) -> String { "Noop".to_owned() } - fn get_output_schema( + async fn get_output_schema( &self, _output_port: &PortHandle, input_schemas: &HashMap, @@ -48,7 +50,7 @@ impl ProcessorFactory for NoopProcessorFactory { vec![DEFAULT_PORT_HANDLE] } - fn build( + async fn build( &self, _input_schemas: HashMap, _output_schemas: HashMap, @@ -188,12 +190,13 @@ pub(crate) struct NoopJoinProcessorFactory {} pub const NOOP_JOIN_LEFT_INPUT_PORT: u16 = 1; pub const NOOP_JOIN_RIGHT_INPUT_PORT: u16 = 2; +#[async_trait] impl ProcessorFactory for NoopJoinProcessorFactory { fn type_name(&self) -> String { "NoopJoin".to_owned() } - fn get_output_schema( + async fn get_output_schema( &self, _output_port: &PortHandle, input_schemas: &HashMap, @@ -209,7 +212,7 @@ impl ProcessorFactory for NoopJoinProcessorFactory { vec![DEFAULT_PORT_HANDLE] } - fn build( + async fn build( &self, _input_schemas: HashMap, _output_schemas: HashMap, diff --git a/dozer-core/src/tests/dag_ports.rs b/dozer-core/src/tests/dag_ports.rs index b096c7ecb7..b07bee9e4a 100644 --- a/dozer-core/src/tests/dag_ports.rs +++ b/dozer-core/src/tests/dag_ports.rs @@ -4,6 +4,7 @@ use crate::node::{ use crate::{Dag, Endpoint, DEFAULT_PORT_HANDLE}; use dozer_recordstore::ProcessorRecordStoreDeserializer; use dozer_types::errors::internal::BoxedError; +use dozer_types::tonic::async_trait; use dozer_types::{node::NodeHandle, types::Schema}; use std::collections::HashMap; @@ -57,8 +58,9 @@ impl DynPortsProcessorFactory { } } +#[async_trait] impl ProcessorFactory for DynPortsProcessorFactory { - fn get_output_schema( + async fn get_output_schema( &self, _output_port: &PortHandle, _input_schemas: &HashMap, @@ -74,7 +76,7 @@ impl ProcessorFactory for DynPortsProcessorFactory { self.output_ports.clone() } - fn build( + async fn build( &self, _input_schemas: HashMap, _output_schemas: HashMap, diff --git a/dozer-core/src/tests/dag_schemas.rs b/dozer-core/src/tests/dag_schemas.rs index 545b7504ae..b774d366f8 100644 --- a/dozer-core/src/tests/dag_schemas.rs +++ b/dozer-core/src/tests/dag_schemas.rs @@ -5,18 +5,14 @@ use crate::node::{ }; use crate::{Dag, Endpoint, DEFAULT_PORT_HANDLE}; +use dozer_log::tokio; use 
dozer_recordstore::ProcessorRecordStoreDeserializer; use dozer_types::errors::internal::BoxedError; use dozer_types::node::NodeHandle; +use dozer_types::tonic::async_trait; use dozer_types::types::{FieldDefinition, FieldType, Schema, SourceDefinition}; use std::collections::HashMap; -macro_rules! chk { - ($stmt:expr) => { - $stmt.unwrap_or_else(|e| panic!("{}", e.to_string())) - }; -} - #[derive(Debug)] struct TestUsersSourceFactory {} @@ -121,8 +117,9 @@ impl SourceFactory for TestCountriesSourceFactory { #[derive(Debug)] struct TestJoinProcessorFactory {} +#[async_trait] impl ProcessorFactory for TestJoinProcessorFactory { - fn get_output_schema( + async fn get_output_schema( &self, _output_port: &PortHandle, input_schemas: &HashMap, @@ -144,7 +141,7 @@ impl ProcessorFactory for TestJoinProcessorFactory { vec![DEFAULT_PORT_HANDLE] } - fn build( + async fn build( &self, _input_schemas: HashMap, _output_schemas: HashMap, @@ -183,8 +180,8 @@ impl SinkFactory for TestSinkFactory { } } -#[test] -fn test_extract_dag_schemas() { +#[tokio::test] +async fn test_extract_dag_schemas() { let mut dag = Dag::new(); let users_handle = NodeHandle::new(Some(1), 1.to_string()); @@ -200,20 +197,23 @@ fn test_extract_dag_schemas() { let join_index = dag.add_processor(join_handle.clone(), Box::new(TestJoinProcessorFactory {})); let sink_index = dag.add_sink(sink_handle.clone(), Box::new(TestSinkFactory {})); - chk!(dag.connect( + dag.connect( Endpoint::new(users_handle, DEFAULT_PORT_HANDLE), Endpoint::new(join_handle.clone(), 1), - )); - chk!(dag.connect( + ) + .unwrap(); + dag.connect( Endpoint::new(countries_handle, DEFAULT_PORT_HANDLE), Endpoint::new(join_handle.clone(), 2), - )); - chk!(dag.connect( + ) + .unwrap(); + dag.connect( Endpoint::new(join_handle, DEFAULT_PORT_HANDLE), Endpoint::new(sink_handle, DEFAULT_PORT_HANDLE), - )); + ) + .unwrap(); - let dag_schemas = chk!(DagSchemas::new(dag)); + let dag_schemas = DagSchemas::new(dag).await.unwrap(); let users_output = dag_schemas.get_node_output_schemas(users_index); assert_eq!( diff --git a/dozer-core/src/tests/processors.rs b/dozer-core/src/tests/processors.rs index b3ebdfad3b..c0343785f9 100644 --- a/dozer-core/src/tests/processors.rs +++ b/dozer-core/src/tests/processors.rs @@ -1,7 +1,7 @@ use std::collections::HashMap; use dozer_recordstore::ProcessorRecordStoreDeserializer; -use dozer_types::{errors::internal::BoxedError, types::Schema}; +use dozer_types::{errors::internal::BoxedError, tonic::async_trait, types::Schema}; use crate::{ node::{PortHandle, Processor, ProcessorFactory}, @@ -11,6 +11,7 @@ use crate::{ #[derive(Debug)] pub struct ConnectivityTestProcessorFactory; +#[async_trait] impl ProcessorFactory for ConnectivityTestProcessorFactory { fn type_name(&self) -> String { "ConnectivityTest".to_owned() @@ -23,7 +24,7 @@ impl ProcessorFactory for ConnectivityTestProcessorFactory { vec![DEFAULT_PORT_HANDLE] } - fn get_output_schema( + async fn get_output_schema( &self, _output_port: &PortHandle, _input_schemas: &HashMap, @@ -33,7 +34,7 @@ impl ProcessorFactory for ConnectivityTestProcessorFactory { ) } - fn build( + async fn build( &self, _input_schemas: HashMap, _output_schemas: HashMap, @@ -53,6 +54,7 @@ impl ProcessorFactory for ConnectivityTestProcessorFactory { #[derive(Debug)] pub struct NoInputPortProcessorFactory; +#[async_trait] impl ProcessorFactory for NoInputPortProcessorFactory { fn get_input_ports(&self) -> Vec { vec![] @@ -62,7 +64,7 @@ impl ProcessorFactory for NoInputPortProcessorFactory { vec![DEFAULT_PORT_HANDLE] } - fn 
get_output_schema( + async fn get_output_schema( &self, _output_port: &PortHandle, _input_schemas: &HashMap, @@ -72,7 +74,7 @@ impl ProcessorFactory for NoInputPortProcessorFactory { ) } - fn build( + async fn build( &self, _input_schemas: HashMap, _output_schemas: HashMap, diff --git a/dozer-deno/src/runtime/mod.rs b/dozer-deno/src/runtime/mod.rs index dacc5a5877..687d80facd 100644 --- a/dozer-deno/src/runtime/mod.rs +++ b/dozer-deno/src/runtime/mod.rs @@ -37,7 +37,7 @@ use tokio::{ #[derive(Debug)] pub struct Runtime { work_sender: mpsc::Sender, - handle: JoinHandle<()>, + handle: Option>, } #[derive(Debug, thiserror::Error)] @@ -112,17 +112,17 @@ impl Runtime { Ok(( Self { work_sender, - handle, + handle: Some(handle), }, functions, )) } pub async fn call_function( - self, + &mut self, id: NonZeroI32, args: Vec, - ) -> (Self, Result) { + ) -> Result { let (return_sender, return_receiver) = oneshot::channel(); if self .work_sender @@ -134,16 +134,22 @@ impl Runtime { .await .is_err() { - // Propagate the panic. - self.handle.await.unwrap(); - unreachable!("we should have panicked"); + return self.propagate_panic().await; } let Ok(result) = return_receiver.await else { - // Propagate the panic. - self.handle.await.unwrap(); - unreachable!("we should have panicked"); + return self.propagate_panic().await; }; - (self, result) + result + } + + // Return type is actually `!` + async fn propagate_panic(&mut self) -> Result { + self.handle + .take() + .expect("runtime panicked before and cannot be used again") + .await + .unwrap(); + unreachable!("we should have panicked"); } } diff --git a/dozer-lambda/src/js/worker/mod.rs b/dozer-lambda/src/js/worker/mod.rs index f5f4b006cb..17d4f01bee 100644 --- a/dozer-lambda/src/js/worker/mod.rs +++ b/dozer-lambda/src/js/worker/mod.rs @@ -10,8 +10,7 @@ use dozer_types::{ #[derive(Debug)] pub struct Worker { - /// Always `Some`. 
- runtime: Option, + runtime: dozer_deno::Runtime, } impl Worker { @@ -20,12 +19,7 @@ impl Worker { modules: Vec, ) -> Result<(Self, Vec), dozer_deno::RuntimeError> { let (runtime, lambdas) = dozer_deno::Runtime::new(runtime, modules).await?; - Ok(( - Self { - runtime: Some(runtime), - }, - lambdas, - )) + Ok((Self { runtime }, lambdas)) } pub async fn call_lambda( @@ -46,14 +40,7 @@ impl Worker { "new": create_record_json_value(field_names.clone(), new_values), "old": old_values.map(|old_values| create_record_json_value(field_names, old_values)), }); - let result = self - .runtime - .take() - .unwrap() - .call_function(func, vec![arg]) - .await; - self.runtime = Some(result.0); - if let Err(e) = result.1 { + if let Err(e) = self.runtime.call_function(func, vec![arg]).await { error!("error calling lambda: {}", e); } } diff --git a/dozer-sql/Cargo.toml b/dozer-sql/Cargo.toml index 3841091d77..d3048f8a19 100644 --- a/dozer-sql/Cargo.toml +++ b/dozer-sql/Cargo.toml @@ -21,6 +21,7 @@ linked-hash-map = { version = "0.5.6", features = ["serde_impl"] } metrics = "0.21.0" multimap = "0.9.0" regex = "1.10.2" +tokio = { version = "1", features = ["rt", "macros"] } [dev-dependencies] tempdir = "0.3.7" diff --git a/dozer-sql/expression/Cargo.toml b/dozer-sql/expression/Cargo.toml index 7fc1c40b62..ffea5c348a 100644 --- a/dozer-sql/expression/Cargo.toml +++ b/dozer-sql/expression/Cargo.toml @@ -6,6 +6,7 @@ authors = ["getdozer/dozer-dev"] [dependencies] dozer-types = { path = "../../dozer-types" } +dozer-deno = { path = "../../dozer-deno" } num-traits = "0.2.16" sqlparser = { git = "https://github.com/getdozer/sqlparser-rs.git" } bigdecimal = { version = "0.3", features = ["serde"], optional = true } @@ -15,6 +16,8 @@ half = { version = "2.3.1", optional = true } like = "0.3.1" jsonpath = { path = "../jsonpath" } bincode = { workspace = true } +tokio = "1.34.0" +async-recursion = "1.0.5" [dev-dependencies] proptest = "1.2.0" diff --git a/dozer-sql/expression/src/builder.rs b/dozer-sql/expression/src/builder.rs index c00d6d4e15..495cf4d9fe 100644 --- a/dozer-sql/expression/src/builder.rs +++ b/dozer-sql/expression/src/builder.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use crate::aggregate::AggregateFunctionType; use crate::conditional::ConditionalExpressionType; use crate::datetime::DateTimeFunctionType; @@ -13,6 +15,7 @@ use sqlparser::ast::{ FunctionArg, FunctionArgExpr, Ident, Interval, TrimWhereField, UnaryOperator as SqlUnaryOperator, Value as SqlValue, }; +use tokio::runtime::Runtime; use crate::execution::Expression; use crate::execution::Expression::{ConditionalExpression, GeoFunction, Now, ScalarFunction}; @@ -24,29 +27,32 @@ use crate::scalar::string::TrimType; use super::cast::CastOperatorType; -#[derive(Clone, PartialEq, Debug)] +#[derive(Clone, Debug)] pub struct ExpressionBuilder { // Must be an aggregation function pub aggregations: Vec, pub offset: usize, + runtime: Arc, } impl ExpressionBuilder { - pub fn new(offset: usize) -> Self { + pub fn new(offset: usize, runtime: Arc) -> Self { Self { aggregations: Vec::new(), offset, + runtime, } } - pub fn from(offset: usize, aggregations: Vec) -> Self { + pub fn from(offset: usize, aggregations: Vec, runtime: Arc) -> Self { Self { aggregations, offset, + runtime, } } - pub fn build( + pub async fn build( &mut self, parse_aggregations: bool, sql_expression: &SqlExpr, @@ -54,9 +60,11 @@ impl ExpressionBuilder { udfs: &[UdfConfig], ) -> Result { self.parse_sql_expression(parse_aggregations, sql_expression, schema, udfs) + .await } - pub fn 
parse_sql_expression( + #[async_recursion::async_recursion] + pub async fn parse_sql_expression( &mut self, parse_aggregations: bool, expression: &SqlExpr, @@ -68,14 +76,17 @@ impl ExpressionBuilder { expr, trim_where, trim_what, - } => self.parse_sql_trim_function( - parse_aggregations, - expr, - trim_where, - trim_what, - schema, - udfs, - ), + } => { + self.parse_sql_trim_function( + parse_aggregations, + expr, + trim_where, + trim_what, + schema, + udfs, + ) + .await + } SqlExpr::Identifier(ident) => Self::parse_sql_column(&[ident.clone()], schema), SqlExpr::CompoundIdentifier(ident) => Self::parse_sql_column(ident, schema), SqlExpr::Value(SqlValue::Number(n, _)) => Self::parse_sql_number(n), @@ -85,48 +96,60 @@ impl ExpressionBuilder { } SqlExpr::UnaryOp { expr, op } => { self.parse_sql_unary_op(parse_aggregations, op, expr, schema, udfs) + .await } SqlExpr::BinaryOp { left, op, right } => { self.parse_sql_binary_op(parse_aggregations, left, op, right, schema, udfs) + .await } SqlExpr::Nested(expr) => { self.parse_sql_expression(parse_aggregations, expr, schema, udfs) + .await } SqlExpr::Function(sql_function) => { self.parse_sql_function(parse_aggregations, sql_function, schema, udfs) + .await } SqlExpr::Like { negated, expr, pattern, escape_char, - } => self.parse_sql_like_operator( - parse_aggregations, - negated, - expr, - pattern, - escape_char, - schema, - udfs, - ), + } => { + self.parse_sql_like_operator( + parse_aggregations, + negated, + expr, + pattern, + escape_char, + schema, + udfs, + ) + .await + } SqlExpr::InList { expr, list, negated, - } => self.parse_sql_in_list_operator( - parse_aggregations, - expr, - list, - *negated, - schema, - udfs, - ), + } => { + self.parse_sql_in_list_operator( + parse_aggregations, + expr, + list, + *negated, + schema, + udfs, + ) + .await + } SqlExpr::Cast { expr, data_type } => { self.parse_sql_cast_operator(parse_aggregations, expr, data_type, schema, udfs) + .await } SqlExpr::Extract { field, expr } => { self.parse_sql_extract_operator(parse_aggregations, field, expr, schema, udfs) + .await } SqlExpr::Interval(Interval { value, @@ -134,27 +157,33 @@ impl ExpressionBuilder { leading_precision: _, last_field: _, fractional_seconds_precision: _, - }) => self.parse_sql_interval_expression( - parse_aggregations, - value, - leading_field, - schema, - udfs, - ), + }) => { + self.parse_sql_interval_expression( + parse_aggregations, + value, + leading_field, + schema, + udfs, + ) + .await + } SqlExpr::Case { operand, conditions, results, else_result, - } => self.parse_sql_case_expression( - parse_aggregations, - operand, - conditions, - results, - else_result, - schema, - udfs, - ), + } => { + self.parse_sql_case_expression( + parse_aggregations, + operand, + conditions, + results, + else_result, + schema, + udfs, + ) + .await + } _ => Err(Error::UnsupportedExpression(expression.clone())), } } @@ -233,7 +262,7 @@ impl ExpressionBuilder { } } - fn parse_sql_trim_function( + async fn parse_sql_trim_function( &mut self, parse_aggregations: bool, expr: &Expr, @@ -242,14 +271,15 @@ impl ExpressionBuilder { schema: &Schema, udfs: &[UdfConfig], ) -> Result { - let arg = Box::new(self.parse_sql_expression(parse_aggregations, expr, schema, udfs)?); + let arg = Box::new( + self.parse_sql_expression(parse_aggregations, expr, schema, udfs) + .await?, + ); let what = match trim_what { - Some(e) => Some(Box::new(self.parse_sql_expression( - parse_aggregations, - e, - schema, - udfs, - )?)), + Some(e) => Some(Box::new( + 
self.parse_sql_expression(parse_aggregations, e, schema, udfs) + .await?, + )), _ => None, }; let typ = trim_where.as_ref().map(|e| match e { @@ -260,7 +290,7 @@ impl ExpressionBuilder { Ok(Expression::Trim { arg, what, typ }) } - fn aggr_function_check( + async fn aggr_function_check( &mut self, function_name: String, parse_aggregations: bool, @@ -276,7 +306,10 @@ impl ExpressionBuilder { let mut arg_expr: Vec = Vec::new(); for arg in &sql_function.args { - let aggregation = self.parse_sql_function_arg(true, arg, schema, udfs).ok()?; + let aggregation = self + .parse_sql_function_arg(true, arg, schema, udfs) + .await + .ok()?; arg_expr.push(aggregation); } let measure = Expression::AggregateFunction { @@ -300,7 +333,7 @@ impl ExpressionBuilder { }) } - fn scalar_function_check( + async fn scalar_function_check( &mut self, function_name: String, parse_aggregations: bool, @@ -312,6 +345,7 @@ impl ExpressionBuilder { for arg in &sql_function.args { function_args.push( self.parse_sql_function_arg(parse_aggregations, arg, schema, udfs) + .await .ok()?, ); } @@ -323,7 +357,7 @@ impl ExpressionBuilder { }) } - fn geo_expr_check( + async fn geo_expr_check( &mut self, function_name: String, parse_aggregations: bool, @@ -335,6 +369,7 @@ impl ExpressionBuilder { for arg in &sql_function.args { function_args.push( self.parse_sql_function_arg(parse_aggregations, arg, schema, udfs) + .await .ok()?, ); } @@ -351,7 +386,7 @@ impl ExpressionBuilder { Some(Now { fun: dtf }) } - fn json_func_check( + async fn json_func_check( &mut self, function_name: String, parse_aggregations: bool, @@ -363,6 +398,7 @@ impl ExpressionBuilder { for arg in &sql_function.args { function_args.push( self.parse_sql_function_arg(parse_aggregations, arg, schema, udfs) + .await .ok()?, ); } @@ -374,7 +410,7 @@ impl ExpressionBuilder { }) } - fn conditional_expr_check( + async fn conditional_expr_check( &mut self, function_name: String, parse_aggregations: bool, @@ -386,6 +422,7 @@ impl ExpressionBuilder { for arg in &sql_function.args { function_args.push( self.parse_sql_function_arg(parse_aggregations, arg, schema, udfs) + .await .ok()?, ); } @@ -397,7 +434,7 @@ impl ExpressionBuilder { }) } - fn parse_sql_function( + async fn parse_sql_function( &mut self, parse_aggregations: bool, sql_function: &Function, @@ -410,46 +447,60 @@ impl ExpressionBuilder { if function_name.starts_with("py_") { // The function is from python udf. 
let udf_name = function_name.strip_prefix("py_").unwrap(); - return self.parse_python_udf(udf_name, sql_function, schema, udfs); + return self + .parse_python_udf(udf_name, sql_function, schema, udfs) + .await; } - if let Some(aggr_check) = self.aggr_function_check( - function_name.clone(), - parse_aggregations, - sql_function, - schema, - udfs, - ) { + if let Some(aggr_check) = self + .aggr_function_check( + function_name.clone(), + parse_aggregations, + sql_function, + schema, + udfs, + ) + .await + { return Ok(aggr_check); } - if let Some(scalar_check) = self.scalar_function_check( - function_name.clone(), - parse_aggregations, - sql_function, - schema, - udfs, - ) { + if let Some(scalar_check) = self + .scalar_function_check( + function_name.clone(), + parse_aggregations, + sql_function, + schema, + udfs, + ) + .await + { return Ok(scalar_check); } - if let Some(geo_check) = self.geo_expr_check( - function_name.clone(), - parse_aggregations, - sql_function, - schema, - udfs, - ) { + if let Some(geo_check) = self + .geo_expr_check( + function_name.clone(), + parse_aggregations, + sql_function, + schema, + udfs, + ) + .await + { return Ok(geo_check); } - if let Some(conditional_check) = self.conditional_expr_check( - function_name.clone(), - parse_aggregations, - sql_function, - schema, - udfs, - ) { + if let Some(conditional_check) = self + .conditional_expr_check( + function_name.clone(), + parse_aggregations, + sql_function, + schema, + udfs, + ) + .await + { return Ok(conditional_check); } @@ -457,13 +508,16 @@ impl ExpressionBuilder { return Ok(datetime_check); } - if let Some(json_check) = self.json_func_check( - function_name.clone(), - parse_aggregations, - sql_function, - schema, - udfs, - ) { + if let Some(json_check) = self + .json_func_check( + function_name.clone(), + parse_aggregations, + sql_function, + schema, + udfs, + ) + .await + { return Ok(json_check); } @@ -481,6 +535,7 @@ impl ExpressionBuilder { schema, udfs, ) + .await } #[cfg(not(feature = "onnx"))] @@ -489,13 +544,23 @@ impl ExpressionBuilder { Err(Error::OnnxNotEnabled) } } + UdfType::JavaScript(config) => { + self.parse_javascript_udf( + function_name.clone(), + config, + sql_function, + schema, + udfs, + ) + .await + } }; } Err(Error::UnknownFunction(function_name.clone())) } - fn parse_sql_function_arg( + async fn parse_sql_function_arg( &mut self, parse_aggregations: bool, argument: &FunctionArg, @@ -506,13 +571,17 @@ impl ExpressionBuilder { FunctionArg::Named { name: _, arg: FunctionArgExpr::Expr(arg), - } => self.parse_sql_expression(parse_aggregations, arg, schema, udfs), + } => { + self.parse_sql_expression(parse_aggregations, arg, schema, udfs) + .await + } FunctionArg::Named { name: _, arg: FunctionArgExpr::Wildcard, } => Ok(Expression::Literal(Field::Null)), FunctionArg::Unnamed(FunctionArgExpr::Expr(arg)) => { self.parse_sql_expression(parse_aggregations, arg, schema, udfs) + .await } FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => Ok(Expression::Literal(Field::Null)), _ => Err(Error::UnsupportedFunctionArg(argument.clone())), @@ -520,7 +589,7 @@ impl ExpressionBuilder { } #[allow(clippy::too_many_arguments)] - fn parse_sql_case_expression( + async fn parse_sql_case_expression( &mut self, parse_aggregations: bool, operand: &Option>, @@ -531,29 +600,31 @@ impl ExpressionBuilder { udfs: &[UdfConfig], ) -> Result { let op = match operand { - Some(o) => Some(Box::new(self.parse_sql_expression( - parse_aggregations, - o, - schema, - udfs, - )?)), + Some(o) => Some(Box::new( + 
self.parse_sql_expression(parse_aggregations, o, schema, udfs) + .await?, + )), None => None, }; - let conds = conditions - .iter() - .map(|cond| self.parse_sql_expression(parse_aggregations, cond, schema, udfs)) - .collect::, Error>>()?; - let res = results - .iter() - .map(|r| self.parse_sql_expression(parse_aggregations, r, schema, udfs)) - .collect::, Error>>()?; + let mut conds = vec![]; + for cond in conditions { + conds.push( + self.parse_sql_expression(parse_aggregations, cond, schema, udfs) + .await?, + ); + } + let mut res = vec![]; + for r in results { + res.push( + self.parse_sql_expression(parse_aggregations, r, schema, udfs) + .await?, + ); + } let else_res = match else_result { - Some(r) => Some(Box::new(self.parse_sql_expression( - parse_aggregations, - r, - schema, - udfs, - )?)), + Some(r) => Some(Box::new( + self.parse_sql_expression(parse_aggregations, r, schema, udfs) + .await?, + )), None => None, }; @@ -565,7 +636,7 @@ impl ExpressionBuilder { }) } - fn parse_sql_interval_expression( + async fn parse_sql_interval_expression( &mut self, parse_aggregations: bool, value: &Expr, @@ -573,7 +644,9 @@ impl ExpressionBuilder { schema: &Schema, udfs: &[UdfConfig], ) -> Result { - let right = self.parse_sql_expression(parse_aggregations, value, schema, udfs)?; + let right = self + .parse_sql_expression(parse_aggregations, value, schema, udfs) + .await?; if let Some(leading_field) = leading_field { Ok(Expression::DateTimeFunction { fun: DateTimeFunctionType::Interval { @@ -586,7 +659,7 @@ impl ExpressionBuilder { } } - fn parse_sql_unary_op( + async fn parse_sql_unary_op( &mut self, parse_aggregations: bool, op: &SqlUnaryOperator, @@ -594,7 +667,10 @@ impl ExpressionBuilder { schema: &Schema, udfs: &[UdfConfig], ) -> Result { - let arg = Box::new(self.parse_sql_expression(parse_aggregations, expr, schema, udfs)?); + let arg = Box::new( + self.parse_sql_expression(parse_aggregations, expr, schema, udfs) + .await?, + ); let operator = match op { SqlUnaryOperator::Not => UnaryOperatorType::Not, SqlUnaryOperator::Plus => UnaryOperatorType::Plus, @@ -605,7 +681,7 @@ impl ExpressionBuilder { Ok(Expression::UnaryOperator { operator, arg }) } - fn parse_sql_binary_op( + async fn parse_sql_binary_op( &mut self, parse_aggregations: bool, left: &SqlExpr, @@ -614,8 +690,12 @@ impl ExpressionBuilder { schema: &Schema, udfs: &[UdfConfig], ) -> Result { - let left_op = self.parse_sql_expression(parse_aggregations, left, schema, udfs)?; - let right_op = self.parse_sql_expression(parse_aggregations, right, schema, udfs)?; + let left_op = self + .parse_sql_expression(parse_aggregations, left, schema, udfs) + .await?; + let right_op = self + .parse_sql_expression(parse_aggregations, right, schema, udfs) + .await?; let operator = match op { SqlBinaryOperator::Gt => BinaryOperatorType::Gt, @@ -666,7 +746,7 @@ impl ExpressionBuilder { } #[allow(clippy::too_many_arguments)] - fn parse_sql_like_operator( + async fn parse_sql_like_operator( &mut self, parse_aggregations: bool, negated: &bool, @@ -676,8 +756,12 @@ impl ExpressionBuilder { schema: &Schema, udfs: &[UdfConfig], ) -> Result { - let arg = self.parse_sql_expression(parse_aggregations, expr, schema, udfs)?; - let pattern = self.parse_sql_expression(parse_aggregations, pattern, schema, udfs)?; + let arg = self + .parse_sql_expression(parse_aggregations, expr, schema, udfs) + .await?; + let pattern = self + .parse_sql_expression(parse_aggregations, pattern, schema, udfs) + .await?; let like_expression = Expression::Like { arg: Box::new(arg), 
pattern: Box::new(pattern), @@ -693,7 +777,7 @@ impl ExpressionBuilder { } } - fn parse_sql_extract_operator( + async fn parse_sql_extract_operator( &mut self, parse_aggregations: bool, field: &sqlparser::ast::DateTimeField, @@ -701,14 +785,16 @@ impl ExpressionBuilder { schema: &Schema, udfs: &[UdfConfig], ) -> Result { - let right = self.parse_sql_expression(parse_aggregations, expr, schema, udfs)?; + let right = self + .parse_sql_expression(parse_aggregations, expr, schema, udfs) + .await?; Ok(Expression::DateTimeFunction { fun: DateTimeFunctionType::Extract { field: *field }, arg: Box::new(right), }) } - fn parse_sql_cast_operator( + async fn parse_sql_cast_operator( &mut self, parse_aggregations: bool, expr: &Expr, @@ -716,7 +802,9 @@ impl ExpressionBuilder { schema: &Schema, udfs: &[UdfConfig], ) -> Result { - let expression = self.parse_sql_expression(parse_aggregations, expr, schema, udfs)?; + let expression = self + .parse_sql_expression(parse_aggregations, expr, schema, udfs) + .await?; let cast_to = match data_type { DataType::Decimal(_) => CastOperatorType(FieldType::Decimal), DataType::Binary(_) => CastOperatorType(FieldType::Binary), @@ -770,7 +858,7 @@ impl ExpressionBuilder { } #[cfg(feature = "python")] - fn parse_python_udf( + async fn parse_python_udf( &mut self, name: &str, function: &Function, @@ -781,11 +869,13 @@ impl ExpressionBuilder { // First, get python function define by name. // Then, transfer python function to Expression::PythonUDF - let args = function - .args - .iter() - .map(|argument| self.parse_sql_function_arg(false, argument, schema, udfs)) - .collect::, Error>>()?; + let mut args = vec![]; + for argument in &function.args { + let arg = self + .parse_sql_function_arg(false, argument, schema, udfs) + .await?; + args.push(arg); + } let return_type = { let ident = function @@ -804,7 +894,7 @@ impl ExpressionBuilder { } #[cfg(feature = "onnx")] - fn parse_onnx_udf( + async fn parse_onnx_udf( &mut self, name: String, config: &dozer_types::models::udf_config::OnnxConfig, @@ -821,11 +911,13 @@ impl ExpressionBuilder { use ort::{Environment, GraphOptimizationLevel, LoggingLevel, SessionBuilder}; use std::path::Path; - let args = function - .args - .iter() - .map(|argument| self.parse_sql_function_arg(false, argument, schema, udfs)) - .collect::, Error>>()?; + let mut args = vec![]; + for argument in &function.args { + let arg = self + .parse_sql_function_arg(false, argument, schema, udfs) + .await?; + args.push(arg); + } let environment = Environment::builder() .with_name("dozer_onnx") @@ -855,7 +947,35 @@ impl ExpressionBuilder { }) } - fn parse_sql_in_list_operator( + async fn parse_javascript_udf( + &mut self, + name: String, + config: &dozer_types::models::udf_config::JavaScriptConfig, + function: &Function, + schema: &Schema, + udfs: &[UdfConfig], + ) -> Result { + let mut args = vec![]; + for argument in &function.args { + let arg = self + .parse_sql_function_arg(false, argument, schema, udfs) + .await?; + args.push(arg); + } + + use crate::javascript::{validate_args, Udf}; + validate_args(name.clone(), &args, schema)?; + let udf = Udf::new( + self.runtime.clone(), + name, + config.module.clone(), + args.remove(0), + ) + .await?; + Ok(Expression::JavaScriptUdf(udf)) + } + + async fn parse_sql_in_list_operator( &mut self, parse_aggregations: bool, expr: &Expr, @@ -864,14 +984,19 @@ impl ExpressionBuilder { schema: &Schema, udfs: &[UdfConfig], ) -> Result { - let expr = self.parse_sql_expression(parse_aggregations, expr, schema, udfs)?; - let list = list 
-            .iter()
-            .map(|expr| self.parse_sql_expression(parse_aggregations, expr, schema, udfs))
-            .collect::<Result<Vec<_>, Error>>()?;
+        let expr = self
+            .parse_sql_expression(parse_aggregations, expr, schema, udfs)
+            .await?;
+        let mut list_expressions = vec![];
+        for expr in list {
+            list_expressions.push(
+                self.parse_sql_expression(parse_aggregations, expr, schema, udfs)
+                    .await?,
+            );
+        }
         let in_list_expression = Expression::InList {
             expr: Box::new(expr),
-            list,
+            list: list_expressions,
             negated,
         };
diff --git a/dozer-sql/expression/src/error.rs b/dozer-sql/expression/src/error.rs
index 786fe30005..7e4ee3572e 100644
--- a/dozer-sql/expression/src/error.rs
+++ b/dozer-sql/expression/src/error.rs
@@ -100,6 +100,9 @@ pub enum Error {
     #[error("ONNX UDF is not enabled")]
     OnnxNotEnabled,

+    #[error("JavaScript UDF error: {0}")]
+    JavaScript(#[from] crate::javascript::Error),
+
     // Legacy error types.
     #[error("Sql error: {0}")]
     SqlError(#[source] OperationError),
diff --git a/dozer-sql/expression/src/execution.rs b/dozer-sql/expression/src/execution.rs
index 71c0316e8e..1508520826 100644
--- a/dozer-sql/expression/src/execution.rs
+++ b/dozer-sql/expression/src/execution.rs
@@ -96,6 +96,7 @@ pub enum Expression {
         session: crate::onnx::DozerSession,
         args: Vec<Expression>,
     },
+    JavaScriptUdf(crate::javascript::Udf),
 }

 impl Expression {
@@ -274,6 +275,7 @@
                     .as_str()
                     + ")"
             }
+            Expression::JavaScriptUdf(udf) => udf.to_string(schema),
         }
     }
 }
@@ -361,6 +363,7 @@
                 results,
                 else_result,
             } => evaluate_case(schema, operand, conditions, results, else_result, record),
+            Expression::JavaScriptUdf(udf) => udf.evaluate(record, schema),
         }
@@ -468,6 +471,7 @@
                 SourceDefinition::Dynamic,
                 false,
             )),
+            Expression::JavaScriptUdf(udf) => Ok(udf.get_type()),
         }
     }
 }
diff --git a/dozer-sql/expression/src/javascript/evaluate.rs b/dozer-sql/expression/src/javascript/evaluate.rs
new file mode 100644
index 0000000000..a9be6761b2
--- /dev/null
+++ b/dozer-sql/expression/src/javascript/evaluate.rs
@@ -0,0 +1,115 @@
+use std::{num::NonZeroI32, sync::Arc};
+
+use dozer_deno::deno_runtime::deno_core::error::AnyError;
+use dozer_types::{
+    errors::types::DeserializationError,
+    json_types::{json_value_to_serde_json, serde_json_to_json_value},
+    thiserror,
+    types::{Field, FieldType, Record, Schema, SourceDefinition},
+};
+use tokio::{runtime::Runtime, sync::Mutex};
+
+use crate::execution::{Expression, ExpressionType};
+
+#[derive(Debug, Clone)]
+pub struct Udf {
+    function_name: String,
+    arg: Box<Expression>,
+    tokio_runtime: Arc<Runtime>,
+    /// `Arc` to enable `Clone`. Not sure why `Expression` should be `Clone`.
+    deno_runtime: Arc<Mutex<dozer_deno::Runtime>>,
+    function: NonZeroI32,
+}
+
+impl PartialEq for Udf {
+    fn eq(&self, other: &Self) -> bool {
+        // This is obviously wrong. We have to lift the `PartialEq` constraint.
+        self.function_name == other.function_name && self.arg == other.arg
+    }
+}
+
+#[derive(Debug, thiserror::Error)]
+pub enum Error {
+    #[error("failed to create deno runtime: {0}")]
+    CreateRuntime(#[from] dozer_deno::RuntimeError),
+    #[error("failed to evaluate udf: {0}")]
+    Evaluate(#[source] AnyError),
+    #[error("failed to convert json: {0}")]
+    JsonConversion(#[source] DeserializationError),
+}
+
+impl Udf {
+    pub async fn new(
+        tokio_runtime: Arc<Runtime>,
+        function_name: String,
+        module: String,
+        arg: Expression,
+    ) -> Result<Self, Error> {
+        let (deno_runtime, functions) =
+            dozer_deno::Runtime::new(tokio_runtime.clone(), vec![module]).await?;
+        let function = functions[0];
+        Ok(Self {
+            function_name,
+            arg: Box::new(arg),
+            tokio_runtime,
+            deno_runtime: Arc::new(Mutex::new(deno_runtime)),
+            function,
+        })
+    }
+
+    pub fn get_type(&self) -> ExpressionType {
+        ExpressionType {
+            return_type: FieldType::Json,
+            nullable: false,
+            source: SourceDefinition::Dynamic,
+            is_primary_key: false,
+        }
+    }
+
+    pub fn evaluate(
+        &mut self,
+        record: &Record,
+        schema: &Schema,
+    ) -> Result<Field, crate::error::Error> {
+        self.tokio_runtime.block_on(evaluate_impl(
+            self.function_name.clone(),
+            &mut self.arg,
+            &self.deno_runtime,
+            self.function,
+            record,
+            schema,
+        ))
+    }
+
+    pub fn to_string(&self, schema: &Schema) -> String {
+        format!("{}({})", self.function_name, self.arg.to_string(schema))
+    }
+}
+
+async fn evaluate_impl(
+    function_name: String,
+    arg: &mut Expression,
+    runtime: &Arc<Mutex<dozer_deno::Runtime>>,
+    function: NonZeroI32,
+    record: &Record,
+    schema: &Schema,
+) -> Result<Field, crate::error::Error> {
+    let arg = arg.evaluate(record, schema)?;
+    let Field::Json(arg) = arg else {
+        return Err(crate::error::Error::InvalidFunctionArgument {
+            function_name,
+            argument_index: 0,
+            argument: arg,
+        });
+    };
+
+    let mut runtime = runtime.lock().await;
+    let result = runtime
+        .call_function(function, vec![json_value_to_serde_json(&arg)])
+        .await
+        .map_err(Error::Evaluate)?;
+    drop(runtime);
+
+    let result = serde_json_to_json_value(result).map_err(Error::JsonConversion)?;
+    Ok(Field::Json(result))
+}
diff --git a/dozer-sql/expression/src/javascript/mod.rs b/dozer-sql/expression/src/javascript/mod.rs
new file mode 100644
index 0000000000..10a48e1756
--- /dev/null
+++ b/dozer-sql/expression/src/javascript/mod.rs
@@ -0,0 +1,5 @@
+mod evaluate;
+mod validate;
+
+pub use evaluate::{Error, Udf};
+pub use validate::validate_args;
diff --git a/dozer-sql/expression/src/javascript/validate.rs b/dozer-sql/expression/src/javascript/validate.rs
new file mode 100644
index 0000000000..b6359ae78a
--- /dev/null
+++ b/dozer-sql/expression/src/javascript/validate.rs
@@ -0,0 +1,27 @@
+use dozer_types::types::{FieldType, Schema};
+
+use crate::{error::Error, execution::Expression};
+
+pub fn validate_args(
+    function_name: String,
+    args: &[Expression],
+    schema: &Schema,
+) -> Result<(), Error> {
+    if args.len() != 1 {
+        return Err(Error::InvalidNumberOfArguments {
+            function_name,
+            expected: 1..2,
+            actual: args.len(),
+        });
+    }
+    let typ = args[0].get_type(schema)?;
+    if typ.return_type != FieldType::Json {
+        return Err(Error::InvalidFunctionArgumentType {
+            function_name,
+            argument_index: 0,
+            expected: vec![FieldType::Json],
+            actual: typ.return_type,
+        });
+    }
+    Ok(())
+}
diff --git a/dozer-sql/expression/src/lib.rs b/dozer-sql/expression/src/lib.rs
index 8a2895779b..63600f018c 100644
--- a/dozer-sql/expression/src/lib.rs
+++ b/dozer-sql/expression/src/lib.rs
@@ -16,6 +16,7 @@ mod mathematical;
 pub mod operator;
 pub mod scalar;

+mod javascript;
 #[cfg(feature = "onnx")]
 mod onnx;
#[cfg(feature = "python")] diff --git a/dozer-sql/src/aggregation/factory.rs b/dozer-sql/src/aggregation/factory.rs index 19a4f72d6a..657f26a829 100644 --- a/dozer-sql/src/aggregation/factory.rs +++ b/dozer-sql/src/aggregation/factory.rs @@ -10,8 +10,11 @@ use dozer_sql_expression::sqlparser::ast::Select; use dozer_types::errors::internal::BoxedError; use dozer_types::models::udf_config::UdfConfig; use dozer_types::parking_lot::Mutex; +use dozer_types::tonic::async_trait; use dozer_types::types::Schema; use std::collections::HashMap; +use std::sync::Arc; +use tokio::runtime::Runtime; #[derive(Debug)] pub struct AggregationProcessorFactory { @@ -20,6 +23,7 @@ pub struct AggregationProcessorFactory { _stateful: bool, enable_probabilistic_optimizations: bool, udfs: Vec, + runtime: Arc, /// Type name can only be determined after schema propagation. type_name: Mutex>, @@ -32,6 +36,7 @@ impl AggregationProcessorFactory { stateful: bool, enable_probabilistic_optimizations: bool, udfs: Vec, + runtime: Arc, ) -> Self { Self { id, @@ -39,17 +44,20 @@ impl AggregationProcessorFactory { _stateful: stateful, enable_probabilistic_optimizations, udfs, + runtime, type_name: Mutex::new(None), } } - fn get_planner(&self, input_schema: Schema) -> Result { - let mut projection_planner = CommonPlanner::new(input_schema, self.udfs.as_slice()); - projection_planner.plan(self.projection.clone())?; + async fn get_planner(&self, input_schema: Schema) -> Result { + let mut projection_planner = + CommonPlanner::new(input_schema, self.udfs.as_slice(), self.runtime.clone()); + projection_planner.plan(self.projection.clone()).await?; Ok(projection_planner) } } +#[async_trait] impl ProcessorFactory for AggregationProcessorFactory { fn type_name(&self) -> String { self.type_name @@ -66,7 +74,7 @@ impl ProcessorFactory for AggregationProcessorFactory { vec![DEFAULT_PORT_HANDLE] } - fn get_output_schema( + async fn get_output_schema( &self, _output_port: &PortHandle, input_schemas: &HashMap, @@ -75,7 +83,7 @@ impl ProcessorFactory for AggregationProcessorFactory { .get(&DEFAULT_PORT_HANDLE) .ok_or(PipelineError::InvalidPortHandle(DEFAULT_PORT_HANDLE))?; - let planner = self.get_planner(input_schema.clone())?; + let planner = self.get_planner(input_schema.clone()).await?; *self.type_name.lock() = Some( if is_projection(&planner) { @@ -89,7 +97,7 @@ impl ProcessorFactory for AggregationProcessorFactory { Ok(planner.post_projection_schema) } - fn build( + async fn build( &self, input_schemas: HashMap, _output_schemas: HashMap, @@ -100,7 +108,7 @@ impl ProcessorFactory for AggregationProcessorFactory { .get(&DEFAULT_PORT_HANDLE) .ok_or(PipelineError::InvalidPortHandle(DEFAULT_PORT_HANDLE))?; - let planner = self.get_planner(input_schema.clone())?; + let planner = self.get_planner(input_schema.clone()).await?; let processor: Box = if is_projection(&planner) { Box::new(ProjectionProcessor::new( diff --git a/dozer-sql/src/aggregation/tests/aggregation_test_planner.rs b/dozer-sql/src/aggregation/tests/aggregation_test_planner.rs index 0ad53a17e8..beb343e302 100644 --- a/dozer-sql/src/aggregation/tests/aggregation_test_planner.rs +++ b/dozer-sql/src/aggregation/tests/aggregation_test_planner.rs @@ -1,6 +1,6 @@ -use crate::aggregation::processor::AggregationProcessor; use crate::planner::projection::CommonPlanner; use crate::tests::utils::get_select; +use crate::{aggregation::processor::AggregationProcessor, tests::utils::create_test_runtime}; use dozer_types::types::{ Field, FieldDefinition, FieldType, Operation, Record, Schema, 
SourceDefinition, }; @@ -71,10 +71,13 @@ fn test_planner_with_aggregator() { ) .clone(); - let mut projection_planner = CommonPlanner::new(schema.clone(), &[]); + let runtime = create_test_runtime(); + let mut projection_planner = CommonPlanner::new(schema.clone(), &[], runtime.clone()); let statement = get_select(sql).unwrap(); - projection_planner.plan(*statement).unwrap(); + runtime + .block_on(projection_planner.plan(*statement)) + .unwrap(); let mut processor = AggregationProcessor::new( "".to_string(), diff --git a/dozer-sql/src/aggregation/tests/aggregation_tests_utils.rs b/dozer-sql/src/aggregation/tests/aggregation_tests_utils.rs index b2f368ca13..0b1e34bace 100644 --- a/dozer-sql/src/aggregation/tests/aggregation_tests_utils.rs +++ b/dozer-sql/src/aggregation/tests/aggregation_tests_utils.rs @@ -8,7 +8,7 @@ use std::collections::HashMap; use crate::aggregation::processor::AggregationProcessor; use crate::errors::PipelineError; use crate::planner::projection::CommonPlanner; -use crate::tests::utils::get_select; +use crate::tests::utils::{create_test_runtime, get_select}; use dozer_types::arrow::datatypes::ArrowNativeTypeOp; use dozer_types::chrono::{DateTime, NaiveDate, TimeZone, Utc}; use dozer_types::ordered_float::OrderedFloat; @@ -23,10 +23,13 @@ pub(crate) fn init_processor( .get(&DEFAULT_PORT_HANDLE) .unwrap_or_else(|| panic!("Error getting Input Schema")); - let mut projection_planner = CommonPlanner::new(input_schema.clone(), &[]); + let runtime = create_test_runtime(); + let mut projection_planner = CommonPlanner::new(input_schema.clone(), &[], runtime.clone()); let statement = get_select(sql).unwrap(); - projection_planner.plan(*statement).unwrap(); + runtime + .block_on(projection_planner.plan(*statement)) + .unwrap(); let processor = AggregationProcessor::new( "".to_string(), diff --git a/dozer-sql/src/builder.rs b/dozer-sql/src/builder.rs index 0cdebf4c08..8b6dc6673d 100644 --- a/dozer-sql/src/builder.rs +++ b/dozer-sql/src/builder.rs @@ -19,6 +19,8 @@ use dozer_sql_expression::sqlparser::{ }; use std::collections::HashMap; use std::collections::HashSet; +use std::sync::Arc; +use tokio::runtime::Runtime; use super::errors::UnsupportedSqlError; use super::pipeline_builder::from_builder::insert_from_to_pipeline; @@ -42,7 +44,7 @@ pub struct TableInfo { pub is_derived: bool, } /// The struct contains some contexts during query to pipeline. 
-#[derive(Debug, Clone, Default)]
+#[derive(Debug, Clone)]
 pub struct QueryContext {
     // Internal tables map, used to store the tables that are created by the queries
     pub pipeline_map: HashMap<(usize, String), OutputNodeInfo>,
@@ -61,6 +63,9 @@ pub struct QueryContext {
 
     // Udf related configs
     pub udfs: Vec<UdfConfig>,
+
+    // The tokio runtime
+    pub runtime: Arc<Runtime>,
 }
 
 impl QueryContext {
@@ -69,10 +74,15 @@ impl QueryContext {
         self.processor_counter
     }
 
-    pub fn new(udfs: Vec<UdfConfig>) -> Self {
+    pub fn new(udfs: Vec<UdfConfig>, runtime: Arc<Runtime>) -> Self {
         QueryContext {
+            pipeline_map: Default::default(),
+            output_tables_map: Default::default(),
+            used_sources: Default::default(),
+            processors_list: Default::default(),
+            processor_counter: Default::default(),
             udfs,
-            ..Default::default()
+            runtime,
         }
     }
 }
@@ -87,9 +97,10 @@ pub fn statement_to_pipeline(
     pipeline: &mut AppPipeline,
     override_name: Option<String>,
     udfs: Vec<UdfConfig>,
+    runtime: Arc<Runtime>,
 ) -> Result<QueryContext, PipelineError> {
     let dialect = DozerDialect {};
-    let mut ctx = QueryContext::new(udfs);
+    let mut ctx = QueryContext::new(udfs, runtime);
     let is_top_select = true;
     let ast = Parser::parse_sql(&dialect, sql)
         .map_err(|err| PipelineError::InternalError(Box::new(err)))?;
@@ -198,7 +209,7 @@ fn query_to_pipeline(
         }
         SetExpr::Query(query) => {
             let query_name = format!("subquery_{}", query_ctx.get_next_processor_id());
-            let mut ctx = QueryContext::new(query_ctx.udfs.clone());
+            let mut ctx = QueryContext::new(query_ctx.udfs.clone(), query_ctx.runtime.clone());
             query_to_pipeline(
                 &TableInfo {
                     name: NameOrAlias(query_name, None),
@@ -306,6 +317,7 @@ fn select_to_pipeline(
             .in_aggregations
             .unwrap_or(false),
         query_ctx.udfs.clone(),
+        query_ctx.runtime.clone(),
     );
 
     pipeline.add_processor(Box::new(aggregation), &gen_agg_name, vec![]);
@@ -316,6 +328,7 @@ fn select_to_pipeline(
             gen_selection_name.to_owned(),
             selection,
             query_ctx.udfs.clone(),
+            query_ctx.runtime.clone(),
         );
 
         pipeline.add_processor(Box::new(selection), &gen_selection_name, vec![]);
@@ -663,17 +676,19 @@ pub fn get_from_source(
 #[cfg(test)]
 mod tests {
     use super::statement_to_pipeline;
-    use crate::errors::PipelineError;
+    use crate::{errors::PipelineError, tests::utils::create_test_runtime};
     use dozer_core::app::AppPipeline;
     #[test]
     #[should_panic]
     fn disallow_zero_outgoing_ndes() {
         let sql = "select * from film";
+        let runtime = create_test_runtime();
         statement_to_pipeline(
             sql,
             &mut AppPipeline::new_with_default_flags(),
             None,
             vec![],
+            runtime,
         )
         .unwrap();
     }
@@ -681,11 +696,13 @@ mod tests {
     #[test]
     fn test_duplicate_into_clause() {
         let sql = "select * into table1 from film1 ; select * into table1 from film2";
+        let runtime = create_test_runtime();
         let result = statement_to_pipeline(
             sql,
             &mut AppPipeline::new_with_default_flags(),
             None,
             vec![],
+            runtime,
         );
         assert!(matches!(
             result,
@@ -750,11 +767,13 @@ mod tests {
             from stocks join tbl on tbl.id = stocks.id;
             "#;
+        let runtime = create_test_runtime();
         let context = statement_to_pipeline(
             sql,
             &mut AppPipeline::new_with_default_flags(),
             None,
             vec![],
+            runtime,
         )
         .unwrap();
 
@@ -772,87 +791,99 @@ mod tests {
         expected_keys.sort();
         assert_eq!(output_keys, expected_keys);
     }
-}
 
-#[test]
-fn test_missing_into_in_simple_from_clause() {
-    let sql = r#"SELECT a FROM B "#;
-    let result = statement_to_pipeline(
-        sql,
-        &mut AppPipeline::new_with_default_flags(),
-        None,
-        vec![],
-    );
-    //check if the result is an error
-    assert!(matches!(result, Err(PipelineError::MissingIntoClause)))
-}
+    #[test]
+    fn test_missing_into_in_simple_from_clause() {
+        let sql = r#"SELECT a FROM B "#;
+        let runtime = create_test_runtime();
+        let result = statement_to_pipeline(
+            sql,
+            &mut AppPipeline::new_with_default_flags(),
+            None,
+            vec![],
+            runtime,
+        );
+        //check if the result is an error
+        assert!(matches!(result, Err(PipelineError::MissingIntoClause)))
+    }
 
-#[test]
-fn test_correct_into_clause() {
-    let sql = r#"SELECT a INTO C FROM B"#;
-    let result = statement_to_pipeline(
-        sql,
-        &mut AppPipeline::new_with_default_flags(),
-        None,
-        vec![],
-    );
-    //check if the result is ok
-    assert!(result.is_ok());
-}
+    #[test]
+    fn test_correct_into_clause() {
+        let sql = r#"SELECT a INTO C FROM B"#;
+        let runtime = create_test_runtime();
+        let result = statement_to_pipeline(
+            sql,
+            &mut AppPipeline::new_with_default_flags(),
+            None,
+            vec![],
+            runtime,
+        );
+        //check if the result is ok
+        assert!(result.is_ok());
+    }
 
-#[test]
-fn test_missing_into_in_nested_from_clause() {
-    let sql = r#"SELECT a FROM (SELECT a from b)"#;
-    let result = statement_to_pipeline(
-        sql,
-        &mut AppPipeline::new_with_default_flags(),
-        None,
-        vec![],
-    );
-    //check if the result is an error
-    assert!(matches!(result, Err(PipelineError::MissingIntoClause)))
-}
+    #[test]
+    fn test_missing_into_in_nested_from_clause() {
+        let sql = r#"SELECT a FROM (SELECT a from b)"#;
+        let runtime = create_test_runtime();
+        let result = statement_to_pipeline(
+            sql,
+            &mut AppPipeline::new_with_default_flags(),
+            None,
+            vec![],
+            runtime,
+        );
+        //check if the result is an error
+        assert!(matches!(result, Err(PipelineError::MissingIntoClause)))
+    }
 
-#[test]
-fn test_correct_into_in_nested_from() {
-    let sql = r#"SELECT a INTO c FROM (SELECT a from b)"#;
-    let result = statement_to_pipeline(
-        sql,
-        &mut AppPipeline::new_with_default_flags(),
-        None,
-        vec![],
-    );
-    //check if the result is ok
-    assert!(result.is_ok());
-}
+    #[test]
+    fn test_correct_into_in_nested_from() {
+        let sql = r#"SELECT a INTO c FROM (SELECT a from b)"#;
+        let runtime = create_test_runtime();
+        let result = statement_to_pipeline(
+            sql,
+            &mut AppPipeline::new_with_default_flags(),
+            None,
+            vec![],
+            runtime,
+        );
+        //check if the result is ok
+        assert!(result.is_ok());
+    }
 
-#[test]
-fn test_missing_into_in_with_clause() {
-    let sql = r#"WITH tbl as (select a from B)
+    #[test]
+    fn test_missing_into_in_with_clause() {
+        let sql = r#"WITH tbl as (select a from B)
     select B from tbl;"#;
-    let result = statement_to_pipeline(
-        sql,
-        &mut AppPipeline::new_with_default_flags(),
-        None,
-        vec![],
-    );
-    //check if the result is an error
-    assert!(matches!(result, Err(PipelineError::MissingIntoClause)))
-}
+        let runtime = create_test_runtime();
+        let result = statement_to_pipeline(
+            sql,
+            &mut AppPipeline::new_with_default_flags(),
+            None,
+            vec![],
+            runtime,
+        );
+        //check if the result is an error
+        assert!(matches!(result, Err(PipelineError::MissingIntoClause)))
+    }
 
-#[test]
-fn test_correct_into_in_with_clause() {
-    let sql = r#"WITH tbl as (select a from B)
+    #[test]
+    fn test_correct_into_in_with_clause() {
+        let sql = r#"WITH tbl as (select a from B)
     select B into C from tbl;"#;
-    let result = statement_to_pipeline(
-        sql,
-        &mut AppPipeline::new_with_default_flags(),
-        None,
-        vec![],
-    );
-    //check if the result is ok
-    assert!(result.is_ok());
+        let runtime = create_test_runtime();
+        let result = statement_to_pipeline(
+            sql,
+            &mut AppPipeline::new_with_default_flags(),
+            None,
+            vec![],
+            runtime,
+        );
+        //check if the result is ok
+        assert!(result.is_ok());
+    }
 }
diff --git a/dozer-sql/src/expression/tests/execution.rs b/dozer-sql/src/expression/tests/execution.rs
index 57acaa2486..42b697815b 100644
--- a/dozer-sql/src/expression/tests/execution.rs
+++ b/dozer-sql/src/expression/tests/execution.rs
@@ -1,5 +1,5 @@
 use crate::projection::factory::ProjectionProcessorFactory;
-use crate::tests::utils::get_select;
+use crate::tests::utils::{create_test_runtime, get_select};
 use dozer_core::node::ProcessorFactory;
 use dozer_core::DEFAULT_PORT_HANDLE;
 use dozer_sql_expression::execution::Expression;
@@ -137,13 +137,18 @@ fn test_alias() {
     .clone();
 
     let select = get_select("SELECT count(fn) AS alias1, ln as alias2 FROM t1").unwrap();
-    let processor_factory =
-        ProjectionProcessorFactory::_new("projection_id".to_owned(), select.projection, vec![]);
-    let r = processor_factory
-        .get_output_schema(
+    let runtime = create_test_runtime();
+    let processor_factory = ProjectionProcessorFactory::_new(
+        "projection_id".to_owned(),
+        select.projection,
+        vec![],
+        runtime.clone(),
+    );
+    let r = runtime
+        .block_on(processor_factory.get_output_schema(
             &DEFAULT_PORT_HANDLE,
             &[(DEFAULT_PORT_HANDLE, schema)].into_iter().collect(),
-        )
+        ))
         .unwrap();
 
     assert_eq!(
@@ -195,13 +200,18 @@ fn test_wildcard() {
     .clone();
 
     let select = get_select("SELECT * FROM t1").unwrap();
-    let processor_factory =
-        ProjectionProcessorFactory::_new("projection_id".to_owned(), select.projection, vec![]);
-    let r = processor_factory
-        .get_output_schema(
+    let runtime = create_test_runtime();
+    let processor_factory = ProjectionProcessorFactory::_new(
+        "projection_id".to_owned(),
+        select.projection,
+        vec![],
+        runtime.clone(),
+    );
+    let r = runtime
+        .block_on(processor_factory.get_output_schema(
             &DEFAULT_PORT_HANDLE,
             &[(DEFAULT_PORT_HANDLE, schema)].into_iter().collect(),
-        )
+        ))
         .unwrap();
 
     assert_eq!(
diff --git a/dozer-sql/src/expression/tests/expression_builder_test.rs b/dozer-sql/src/expression/tests/expression_builder_test.rs
index 5c44840762..5fd3e38771 100644
--- a/dozer-sql/src/expression/tests/expression_builder_test.rs
+++ b/dozer-sql/src/expression/tests/expression_builder_test.rs
@@ -1,4 +1,4 @@
-use crate::tests::utils::get_select;
+use crate::tests::utils::{create_test_runtime, get_select};
 use dozer_sql_expression::execution::Expression;
 use dozer_sql_expression::operator::BinaryOperatorType;
 use dozer_sql_expression::scalar::common::ScalarFunctionType;
@@ -31,19 +31,17 @@ fn test_simple_function() {
     )
     .to_owned();
 
-    let mut builder = ExpressionBuilder::new(schema.fields.len());
+    let runtime = create_test_runtime();
+    let mut builder = ExpressionBuilder::new(schema.fields.len(), runtime.clone());
     let e = match &get_select(sql).unwrap().projection[0] {
-        SelectItem::UnnamedExpr(e) => builder.build(true, e, &schema, &[]).unwrap(),
+        SelectItem::UnnamedExpr(e) => runtime
+            .block_on(builder.build(true, e, &schema, &[]))
+            .unwrap(),
         _ => panic!("Invalid expr"),
     };
 
-    assert_eq!(
-        builder,
-        ExpressionBuilder {
-            offset: schema.fields.len(),
-            aggregations: vec![]
-        }
-    );
+    assert_eq!(builder.offset, schema.fields.len());
+    assert_eq!(builder.aggregations, vec![]);
     assert_eq!(
         e,
         Expression::ScalarFunction {
@@ -71,21 +69,22 @@ fn test_simple_aggr_function() {
     )
     .to_owned();
 
-    let mut builder = ExpressionBuilder::new(schema.fields.len());
+    let runtime = create_test_runtime();
+    let mut builder = ExpressionBuilder::new(schema.fields.len(), runtime.clone());
     let e = match &get_select(sql).unwrap().projection[0] {
-        SelectItem::UnnamedExpr(e) => builder.build(true, e, &schema, &[]).unwrap(),
+        SelectItem::UnnamedExpr(e) => runtime
+            .block_on(builder.build(true, e, &schema, &[]))
+            .unwrap(),
         _ => panic!("Invalid expr"),
     };
 
+    assert_eq!(builder.offset, schema.fields.len());
     assert_eq!(
-        builder,
-        ExpressionBuilder {
-            offset: schema.fields.len(),
-            aggregations: vec![Expression::AggregateFunction {
-                fun: AggregateFunctionType::Sum,
-                args: vec![Expression::Column { index: 0 }]
-            }]
-        }
+        builder.aggregations,
+        vec![Expression::AggregateFunction {
+            fun: AggregateFunctionType::Sum,
+            args: vec![Expression::Column { index: 0 }]
+        }]
     );
     assert_eq!(e, Expression::Column { index: 1 });
 }
@@ -114,27 +113,28 @@ fn test_2_nested_aggr_function() {
     )
     .to_owned();
 
-    let mut builder = ExpressionBuilder::new(schema.fields.len());
+    let runtime = create_test_runtime();
+    let mut builder = ExpressionBuilder::new(schema.fields.len(), runtime.clone());
     let e = match &get_select(sql).unwrap().projection[0] {
-        SelectItem::UnnamedExpr(e) => builder.build(true, e, &schema, &[]).unwrap(),
+        SelectItem::UnnamedExpr(e) => runtime
+            .block_on(builder.build(true, e, &schema, &[]))
+            .unwrap(),
         _ => panic!("Invalid expr"),
     };
 
+    assert_eq!(builder.offset, schema.fields.len());
     assert_eq!(
-        builder,
-        ExpressionBuilder {
-            offset: schema.fields.len(),
-            aggregations: vec![Expression::AggregateFunction {
-                fun: AggregateFunctionType::Sum,
-                args: vec![Expression::ScalarFunction {
-                    fun: ScalarFunctionType::Round,
-                    args: vec![
-                        Expression::Column { index: 1 },
-                        Expression::Literal(Field::Int(2))
-                    ]
-                }]
+        builder.aggregations,
+        vec![Expression::AggregateFunction {
+            fun: AggregateFunctionType::Sum,
+            args: vec![Expression::ScalarFunction {
+                fun: ScalarFunctionType::Round,
+                args: vec![
+                    Expression::Column { index: 1 },
+                    Expression::Literal(Field::Int(2))
+                ]
             }]
-        }
+        }]
     );
     assert_eq!(e, Expression::Column { index: 2 });
 }
@@ -163,27 +163,28 @@ fn test_3_nested_aggr_function() {
     )
     .to_owned();
 
-    let mut builder = ExpressionBuilder::new(schema.fields.len());
+    let runtime = create_test_runtime();
+    let mut builder = ExpressionBuilder::new(schema.fields.len(), runtime.clone());
     let e = match &get_select(sql).unwrap().projection[0] {
-        SelectItem::UnnamedExpr(e) => builder.build(true, e, &schema, &[]).unwrap(),
+        SelectItem::UnnamedExpr(e) => runtime
+            .block_on(builder.build(true, e, &schema, &[]))
+            .unwrap(),
         _ => panic!("Invalid expr"),
     };
 
+    assert_eq!(builder.offset, schema.fields.len());
     assert_eq!(
-        builder,
-        ExpressionBuilder {
-            offset: schema.fields.len(),
-            aggregations: vec![Expression::AggregateFunction {
-                fun: AggregateFunctionType::Sum,
-                args: vec![Expression::ScalarFunction {
-                    fun: ScalarFunctionType::Round,
-                    args: vec![
-                        Expression::Column { index: 1 },
-                        Expression::Literal(Field::Int(2))
-                    ]
-                }]
+        builder.aggregations,
+        vec![Expression::AggregateFunction {
+            fun: AggregateFunctionType::Sum,
+            args: vec![Expression::ScalarFunction {
+                fun: ScalarFunctionType::Round,
+                args: vec![
+                    Expression::Column { index: 1 },
+                    Expression::Literal(Field::Int(2))
+                ]
             }]
-        }
+        }]
     );
     assert_eq!(
         e,
@@ -218,27 +219,28 @@ fn test_3_nested_aggr_function_dup() {
     )
     .to_owned();
 
-    let mut builder = ExpressionBuilder::new(schema.fields.len());
+    let runtime = create_test_runtime();
+    let mut builder = ExpressionBuilder::new(schema.fields.len(), runtime.clone());
     let e = match &get_select(sql).unwrap().projection[0] {
-        SelectItem::UnnamedExpr(e) => builder.build(true, e, &schema, &[]).unwrap(),
+        SelectItem::UnnamedExpr(e) => runtime
+            .block_on(builder.build(true, e, &schema, &[]))
+            .unwrap(),
         _ => panic!("Invalid expr"),
     };
 
+    assert_eq!(builder.offset, schema.fields.len());
     assert_eq!(
-        builder,
-        ExpressionBuilder {
-            offset: schema.fields.len(),
-            aggregations: vec![Expression::AggregateFunction {
-                fun: AggregateFunctionType::Sum,
-                args: vec![Expression::ScalarFunction {
-                    fun: ScalarFunctionType::Round,
-                    args: vec![
-                        Expression::Column { index: 1 },
-                        Expression::Literal(Field::Int(2))
-                    ]
-                }]
+        builder.aggregations,
+        vec![Expression::AggregateFunction {
+            fun: AggregateFunctionType::Sum,
+            args: vec![Expression::ScalarFunction {
+                fun: ScalarFunctionType::Round,
+                args: vec![
+                    Expression::Column { index: 1 },
+                    Expression::Literal(Field::Int(2))
+                ]
             }]
-        }
+        }]
     );
     assert_eq!(
         e,
@@ -276,33 +278,34 @@ fn test_3_nested_aggr_function_and_sum() {
     )
     .to_owned();
 
-    let mut builder = ExpressionBuilder::new(schema.fields.len());
+    let runtime = create_test_runtime();
+    let mut builder = ExpressionBuilder::new(schema.fields.len(), runtime.clone());
     let e = match &get_select(sql).unwrap().projection[0] {
-        SelectItem::UnnamedExpr(e) => builder.build(true, e, &schema, &[]).unwrap(),
+        SelectItem::UnnamedExpr(e) => runtime
+            .block_on(builder.build(true, e, &schema, &[]))
+            .unwrap(),
        _ => panic!("Invalid expr"),
     };
 
+    assert_eq!(builder.offset, schema.fields.len());
     assert_eq!(
-        builder,
-        ExpressionBuilder {
-            offset: schema.fields.len(),
-            aggregations: vec![
-                Expression::AggregateFunction {
-                    fun: AggregateFunctionType::Sum,
-                    args: vec![Expression::ScalarFunction {
-                        fun: ScalarFunctionType::Round,
-                        args: vec![
-                            Expression::Column { index: 1 },
-                            Expression::Literal(Field::Int(2))
-                        ]
-                    }]
-                },
-                Expression::AggregateFunction {
-                    fun: AggregateFunctionType::Sum,
-                    args: vec![Expression::Column { index: 0 }]
-                }
-            ]
-        }
+        builder.aggregations,
+        vec![
+            Expression::AggregateFunction {
+                fun: AggregateFunctionType::Sum,
+                args: vec![Expression::ScalarFunction {
+                    fun: ScalarFunctionType::Round,
+                    args: vec![
+                        Expression::Column { index: 1 },
+                        Expression::Literal(Field::Int(2))
+                    ]
+                }]
+            },
+            Expression::AggregateFunction {
+                fun: AggregateFunctionType::Sum,
+                args: vec![Expression::Column { index: 0 }]
+            }
+        ]
     );
     assert_eq!(
         e,
@@ -341,33 +344,34 @@ fn test_3_nested_aggr_function_and_sum_3() {
     )
     .to_owned();
 
-    let mut builder = ExpressionBuilder::new(schema.fields.len());
+    let runtime = create_test_runtime();
+    let mut builder = ExpressionBuilder::new(schema.fields.len(), runtime.clone());
     let e = match &get_select(sql).unwrap().projection[0] {
-        SelectItem::UnnamedExpr(e) => builder.build(true, e, &schema, &[]).unwrap(),
+        SelectItem::UnnamedExpr(e) => runtime
+            .block_on(builder.build(true, e, &schema, &[]))
+            .unwrap(),
         _ => panic!("Invalid expr"),
     };
 
+    assert_eq!(builder.offset, schema.fields.len());
     assert_eq!(
-        builder,
-        ExpressionBuilder {
-            offset: schema.fields.len(),
-            aggregations: vec![
-                Expression::AggregateFunction {
-                    fun: AggregateFunctionType::Sum,
-                    args: vec![Expression::ScalarFunction {
-                        fun: ScalarFunctionType::Round,
-                        args: vec![
-                            Expression::Column { index: 1 },
-                            Expression::Literal(Field::Int(2))
-                        ]
-                    }]
-                },
-                Expression::AggregateFunction {
-                    fun: AggregateFunctionType::Sum,
-                    args: vec![Expression::Column { index: 0 }]
-                }
-            ]
-        }
+        builder.aggregations,
+        vec![
+            Expression::AggregateFunction {
+                fun: AggregateFunctionType::Sum,
+                args: vec![Expression::ScalarFunction {
+                    fun: ScalarFunctionType::Round,
+                    args: vec![
+                        Expression::Column { index: 1 },
+                        Expression::Literal(Field::Int(2))
+                    ]
+                }]
+            },
+            Expression::AggregateFunction {
+                fun: AggregateFunctionType::Sum,
+                args: vec![Expression::Column { index: 0 }]
+            }
+        ]
     );
     assert_eq!(
         e,
@@ -403,9 +407,12 @@ fn test_wrong_nested_aggregations() {
     )
     .to_owned();
 
-    let mut builder = ExpressionBuilder::new(schema.fields.len());
+    let runtime = create_test_runtime();
+    let mut builder = ExpressionBuilder::new(schema.fields.len(), runtime.clone());
     let _e = match &get_select(sql).unwrap().projection[0] {
-        SelectItem::UnnamedExpr(e) => builder.build(true, e, &schema, &[]).unwrap(),
+        SelectItem::UnnamedExpr(e) => runtime
+            .block_on(builder.build(true, e, &schema, &[]))
+            .unwrap(),
         _ => panic!("Invalid expr"),
     };
 }
@@ -440,19 +447,17 @@ fn test_name_resolution() {
     )
     .to_owned();
 
-    let mut builder = ExpressionBuilder::new(schema.fields.len());
+    let runtime = create_test_runtime();
+    let mut builder = ExpressionBuilder::new(schema.fields.len(), runtime.clone());
     let e = match &get_select(sql).unwrap().projection[0] {
-        SelectItem::UnnamedExpr(e) => builder.build(true, e, &schema, &[]).unwrap(),
+        SelectItem::UnnamedExpr(e) => runtime
+            .block_on(builder.build(true, e, &schema, &[]))
+            .unwrap(),
         _ => panic!("Invalid expr"),
     };
 
-    assert_eq!(
-        builder,
-        ExpressionBuilder {
-            offset: schema.fields.len(),
-            aggregations: vec![]
-        }
-    );
+    assert_eq!(builder.offset, schema.fields.len());
+    assert_eq!(builder.aggregations, vec![]);
     assert_eq!(
         e,
         Expression::ScalarFunction {
@@ -483,19 +488,17 @@ fn test_alias_resolution() {
     )
     .to_owned();
 
-    let mut builder = ExpressionBuilder::new(schema.fields.len());
+    let runtime = create_test_runtime();
+    let mut builder = ExpressionBuilder::new(schema.fields.len(), runtime.clone());
     let e = match &get_select(sql).unwrap().projection[0] {
-        SelectItem::UnnamedExpr(e) => builder.build(true, e, &schema, &[]).unwrap(),
+        SelectItem::UnnamedExpr(e) => runtime
+            .block_on(builder.build(true, e, &schema, &[]))
+            .unwrap(),
         _ => panic!("Invalid expr"),
     };
 
-    assert_eq!(
-        builder,
-        ExpressionBuilder {
-            offset: schema.fields.len(),
-            aggregations: vec![]
-        }
-    );
+    assert_eq!(builder.offset, schema.fields.len());
+    assert_eq!(builder.aggregations, vec![]);
     assert_eq!(
         e,
         Expression::ScalarFunction {
diff --git a/dozer-sql/src/expression/tests/test_common.rs b/dozer-sql/src/expression/tests/test_common.rs
index 1ed845f599..80fea26bd7 100644
--- a/dozer-sql/src/expression/tests/test_common.rs
+++ b/dozer-sql/src/expression/tests/test_common.rs
@@ -1,3 +1,4 @@
+use crate::tests::utils::create_test_runtime;
 use crate::{projection::factory::ProjectionProcessorFactory, tests::utils::get_select};
 use dozer_core::channels::ProcessorChannelForwarder;
 use dozer_core::executor_operation::ProcessorOperation;
@@ -22,24 +23,31 @@ pub(crate) fn run_fct(sql: &str, schema: Schema, input: Vec<Field>) -> Field {
     let record_store = ProcessorRecordStoreDeserializer::new(Default::default()).unwrap();
 
     let select = get_select(sql).unwrap();
-    let processor_factory =
-        ProjectionProcessorFactory::_new("projection_id".to_owned(), select.projection, vec![]);
-    processor_factory
-        .get_output_schema(
-            &DEFAULT_PORT_HANDLE,
-            &[(DEFAULT_PORT_HANDLE, schema.clone())]
-                .into_iter()
-                .collect(),
+    let runtime = create_test_runtime();
+    let processor_factory = ProjectionProcessorFactory::_new(
+        "projection_id".to_owned(),
+        select.projection,
+        vec![],
+        runtime.clone(),
+    );
+    runtime
+        .block_on(
+            processor_factory.get_output_schema(
+                &DEFAULT_PORT_HANDLE,
+                &[(DEFAULT_PORT_HANDLE, schema.clone())]
+                    .into_iter()
+                    .collect(),
+            ),
         )
         .unwrap();
 
-    let mut processor = processor_factory
-        .build(
+    let mut processor = runtime
+        .block_on(processor_factory.build(
             HashMap::from([(DEFAULT_PORT_HANDLE, schema)]),
             HashMap::new(),
             &record_store,
             None,
-        )
+        ))
        .unwrap();
 
     let record_store = record_store.into_record_store();
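Every test helper above follows the same two-step shape: create one runtime, hand a clone of it to the factory (that handle is what JavaScript UDF evaluation runs on), then `block_on` the async factory calls. One detail worth noting: `block_on` panics if invoked from inside a running Tokio runtime, which is presumably why these tests stay on `#[test]` with an explicitly created runtime rather than `#[tokio::test]`. A condensed, illustrative version of the pattern (fixture names follow the helpers above; error handling mirrors their `unwrap` style):

    use std::{collections::HashMap, sync::Arc};
    use tokio::runtime::Runtime;

    fn build_projection_blocking(sql: &str, schema: Schema) -> Box<dyn Processor> {
        let runtime: Arc<Runtime> = create_test_runtime();
        let select = get_select(sql).unwrap();
        let factory = ProjectionProcessorFactory::_new(
            "projection_id".to_owned(),
            select.projection,
            vec![],          // no UDF configs in this sketch
            runtime.clone(), // handle the factory uses for async work
        );
        let record_store = ProcessorRecordStoreDeserializer::new(Default::default()).unwrap();
        // Must run outside any Tokio runtime; block_on from within one panics.
        runtime
            .block_on(factory.build(
                HashMap::from([(DEFAULT_PORT_HANDLE, schema)]),
                HashMap::new(),
                &record_store,
                None,
            ))
            .unwrap()
    }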
diff --git a/dozer-sql/src/pipeline_builder/from_builder.rs b/dozer-sql/src/pipeline_builder/from_builder.rs
index c1199863ba..99ca8d55a3 100644
--- a/dozer-sql/src/pipeline_builder/from_builder.rs
+++ b/dozer-sql/src/pipeline_builder/from_builder.rs
@@ -159,6 +159,7 @@ pub fn insert_table_operator_processor_to_pipeline(
         processor_name.clone(),
         operator.clone(),
         query_context.udfs.to_owned(),
+        query_context.runtime.clone(),
     );
 
     if let Some(table) = operator.args.get(0) {
diff --git a/dozer-sql/src/pipeline_builder/join_builder.rs b/dozer-sql/src/pipeline_builder/join_builder.rs
index a3a5d7ed4b..b5e3ef812d 100644
--- a/dozer-sql/src/pipeline_builder/join_builder.rs
+++ b/dozer-sql/src/pipeline_builder/join_builder.rs
@@ -204,6 +204,7 @@ fn insert_table_operator_to_pipeline(
             processor_name.clone(),
             operator.clone(),
             query_context.udfs.to_owned(),
+            query_context.runtime.clone(),
         );
 
         if let Some(table) = operator.args.get(0) {
diff --git a/dozer-sql/src/planner/projection.rs b/dozer-sql/src/planner/projection.rs
index 3d62001ba7..924d3cceb2 100644
--- a/dozer-sql/src/planner/projection.rs
+++ b/dozer-sql/src/planner/projection.rs
@@ -1,4 +1,6 @@
 #![allow(dead_code)]
+use std::sync::Arc;
+
 use crate::errors::PipelineError;
 use crate::pipeline_builder::from_builder::string_from_sql_object_name;
 use dozer_sql_expression::builder::ExpressionBuilder;
@@ -6,6 +8,7 @@ use dozer_sql_expression::execution::Expression;
 use dozer_sql_expression::sqlparser::ast::{Expr, Ident, Select, SelectItem};
 use dozer_types::models::udf_config::UdfConfig;
 use dozer_types::types::{FieldDefinition, Schema};
+use tokio::runtime::Runtime;
 
 #[derive(Clone, Copy)]
 pub enum PrimaryKeyAction {
@@ -24,6 +27,7 @@ pub struct CommonPlanner<'a> {
     pub groupby: Vec<Expression>,
     pub projection_output: Vec<Expression>,
     pub udfs: &'a [UdfConfig],
+    pub runtime: Arc<Runtime>,
 }
 
 impl<'a> CommonPlanner<'_> {
@@ -44,7 +48,7 @@ impl<'a> CommonPlanner<'_> {
         Ok(())
     }
 
-    fn add_select_item(&mut self, item: SelectItem) -> Result<(), PipelineError> {
+    async fn add_select_item(&mut self, item: SelectItem) -> Result<(), PipelineError> {
         let expr_items: Vec<(Expr, Option<String>)> = match item {
             SelectItem::UnnamedExpr(expr) => vec![(expr, None)],
             SelectItem::ExprWithAlias { expr, alias } => vec![(expr, Some(alias.value))],
@@ -74,9 +78,11 @@ impl<'a> CommonPlanner<'_> {
         for (expr, alias) in expr_items {
             let mut builder = ExpressionBuilder::new(
                 self.input_schema.fields.len() + self.aggregation_output.len(),
+                self.runtime.clone(),
             );
-            let projection_expression =
-                builder.build(true, &expr, &self.input_schema, self.udfs)?;
+            let projection_expression = builder
+                .build(true, &expr, &self.input_schema, self.udfs)
+                .await?;
 
             for new_aggr in builder.aggregations {
                 Self::append_to_schema(
@@ -100,7 +106,7 @@ impl<'a> CommonPlanner<'_> {
         Ok(())
     }
 
-    fn add_join_item(&mut self, item: SelectItem) -> Result<(), PipelineError> {
+    async fn add_join_item(&mut self, item: SelectItem) -> Result<(), PipelineError> {
         let expr_items: Vec<(Expr, Option<String>)> = match item {
             SelectItem::UnnamedExpr(expr) => vec![(expr, None)],
             SelectItem::ExprWithAlias { expr, alias } => vec![(expr, Some(alias.value))],
@@ -111,9 +117,11 @@ impl<'a> CommonPlanner<'_> {
         for (expr, alias) in expr_items {
             let mut builder = ExpressionBuilder::new(
                 self.input_schema.fields.len() + self.aggregation_output.len(),
+                self.runtime.clone(),
             );
-            let projection_expression =
-                builder.build(true, &expr, &self.input_schema, self.udfs)?;
+            let projection_expression = builder
+                .build(true, &expr, &self.input_schema, self.udfs)
+                .await?;
 
             for new_aggr in builder.aggregations {
                 Self::append_to_schema(
@@ -137,12 +145,15 @@ impl<'a> CommonPlanner<'_> {
         Ok(())
     }
 
-    fn add_having_item(&mut self, expr: Expr) -> Result<(), PipelineError> {
+    async fn add_having_item(&mut self, expr: Expr) -> Result<(), PipelineError> {
         let mut builder = ExpressionBuilder::from(
             self.input_schema.fields.len(),
             self.aggregation_output.clone(),
+            self.runtime.clone(),
         );
-        let having_expression = builder.build(true, &expr, &self.input_schema, self.udfs)?;
+        let having_expression = builder
+            .build(true, &expr, &self.input_schema, self.udfs)
+            .await?;
 
         let mut post_aggregation_schema = self.input_schema.clone();
         let mut aggregation_output = Vec::new();
@@ -164,14 +175,17 @@ impl<'a> CommonPlanner<'_> {
         Ok(())
     }
 
-    fn add_groupby_items(&mut self, expr_items: Vec<Expr>) -> Result<(), PipelineError> {
+    async fn add_groupby_items(&mut self, expr_items: Vec<Expr>) -> Result<(), PipelineError> {
         let mut indexes = vec![];
         let mut set_pk = true;
         for expr in expr_items {
             let mut builder = ExpressionBuilder::new(
                 self.input_schema.fields.len() + self.aggregation_output.len(),
+                self.runtime.clone(),
             );
-            let groupby_expression = builder.build(false, &expr, &self.input_schema, self.udfs)?;
+            let groupby_expression = builder
+                .build(false, &expr, &self.input_schema, self.udfs)
+                .await?;
             self.groupby.push(groupby_expression.clone());
 
             if let Some(e) = self
@@ -194,22 +208,26 @@ impl<'a> CommonPlanner<'_> {
         Ok(())
     }
 
-    pub fn plan(&mut self, select: Select) -> Result<(), PipelineError> {
+    pub async fn plan(&mut self, select: Select) -> Result<(), PipelineError> {
         for expr in select.clone().projection {
-            self.add_select_item(expr)?;
+            self.add_select_item(expr).await?;
         }
 
         if !select.group_by.is_empty() {
-            self.add_groupby_items(select.group_by)?;
+            self.add_groupby_items(select.group_by).await?;
         }
 
         if let Some(having) = select.having {
-            self.add_having_item(having)?;
+            self.add_having_item(having).await?;
         }
 
         Ok(())
     }
 
-    pub fn new(input_schema: Schema, udfs: &'a [UdfConfig]) -> CommonPlanner<'a> {
+    pub fn new(
+        input_schema: Schema,
+        udfs: &'a [UdfConfig],
+        runtime: Arc<Runtime>,
+    ) -> CommonPlanner<'a> {
         CommonPlanner {
             input_schema: input_schema.clone(),
             post_aggregation_schema: input_schema,
@@ -219,6 +237,7 @@ impl<'a> CommonPlanner<'_> {
             groupby: Vec::new(),
             projection_output: Vec::new(),
             udfs,
+            runtime,
         }
     }
 }
diff --git a/dozer-sql/src/planner/tests/projection_tests.rs b/dozer-sql/src/planner/tests/projection_tests.rs
index 525832aea8..0ca127d87b 100644
--- a/dozer-sql/src/planner/tests/projection_tests.rs
+++ b/dozer-sql/src/planner/tests/projection_tests.rs
@@ -1,6 +1,6 @@
 use dozer_sql_expression::aggregate::AggregateFunctionType;
 
-use crate::planner::projection::CommonPlanner;
+use crate::{planner::projection::CommonPlanner, tests::utils::create_test_runtime};
 use dozer_sql_expression::execution::Expression;
 use dozer_sql_expression::operator::BinaryOperatorType;
 use dozer_sql_expression::scalar::common::ScalarFunctionType;
@@ -39,10 +39,13 @@ fn test_basic_projection() {
     )
     .to_owned();
 
-    let mut projection_planner = CommonPlanner::new(schema, &[]);
+    let runtime = create_test_runtime();
+    let mut projection_planner = CommonPlanner::new(schema, &[], runtime.clone());
     let statement = get_select(sql).unwrap();
 
-    projection_planner.plan(*statement).unwrap();
+    runtime
+        .block_on(projection_planner.plan(*statement))
+        .unwrap();
 
     assert_eq!(
         projection_planner.aggregation_output,
diff --git a/dozer-sql/src/planner/tests/schema_tests.rs b/dozer-sql/src/planner/tests/schema_tests.rs
index fa5e20762d..360cbb4f8e 100644
--- a/dozer-sql/src/planner/tests/schema_tests.rs
+++ b/dozer-sql/src/planner/tests/schema_tests.rs
@@ -1,5 +1,5 @@
-use crate::planner::projection::CommonPlanner;
 use crate::tests::utils::get_select;
+use crate::{planner::projection::CommonPlanner, tests::utils::create_test_runtime};
 use dozer_types::types::{FieldDefinition, FieldType, Schema, SourceDefinition};
 
 #[test]
@@ -44,10 +44,13 @@ fn test_schema_index_partial_group_by() {
     )
     .to_owned();
 
-    let mut projection_planner = CommonPlanner::new(schema, &[]);
+    let runtime = create_test_runtime();
+    let mut projection_planner = CommonPlanner::new(schema, &[], runtime.clone());
     let statement = get_select(sql).unwrap();
 
-    projection_planner.plan(*statement).unwrap();
+    runtime
+        .block_on(projection_planner.plan(*statement))
+        .unwrap();
 
     assert!(projection_planner
         .post_projection_schema
@@ -97,10 +100,13 @@ fn test_schema_index_full_group_by() {
     )
     .to_owned();
 
-    let mut projection_planner = CommonPlanner::new(schema, &[]);
+    let runtime = create_test_runtime();
+    let mut projection_planner = CommonPlanner::new(schema, &[], runtime.clone());
     let statement = get_select(sql).unwrap();
 
-    projection_planner.plan(*statement).unwrap();
+    runtime
+        .block_on(projection_planner.plan(*statement))
+        .unwrap();
 
     assert_eq!(
         projection_planner.post_projection_schema.primary_index,
diff --git a/dozer-sql/src/product/join/factory.rs b/dozer-sql/src/product/join/factory.rs
index 0694001bdb..08c2c85beb 100644
--- a/dozer-sql/src/product/join/factory.rs
+++ b/dozer-sql/src/product/join/factory.rs
@@ -15,6 +15,7 @@ use dozer_sql_expression::{
 use dozer_recordstore::ProcessorRecordStoreDeserializer;
 use dozer_types::{
     errors::internal::BoxedError,
+    tonic::async_trait,
     types::{FieldDefinition, Schema},
 };
 
@@ -57,6 +58,7 @@ impl JoinProcessorFactory {
     }
 }
 
+#[async_trait]
 impl ProcessorFactory for JoinProcessorFactory {
     fn id(&self) -> String {
         self.id.clone()
@@ -73,7 +75,7 @@ impl ProcessorFactory for JoinProcessorFactory {
         vec![DEFAULT_PORT_HANDLE]
     }
 
-    fn get_output_schema(
+    async fn get_output_schema(
         &self,
         _output_port: &PortHandle,
         input_schemas: &HashMap<PortHandle, Schema>,
@@ -105,7 +107,7 @@ impl ProcessorFactory for JoinProcessorFactory {
         Ok(output_schema)
     }
 
-    fn build(
+    async fn build(
         &self,
         input_schemas: HashMap<PortHandle, Schema>,
         _output_schemas: HashMap<PortHandle, Schema>,
diff --git a/dozer-sql/src/product/join/processor.rs b/dozer-sql/src/product/join/processor.rs
index df1bf9cddb..de8b28cef4 100644
--- a/dozer-sql/src/product/join/processor.rs
+++ b/dozer-sql/src/product/join/processor.rs
@@ -251,7 +251,7 @@ mod tests {
     }
 
     impl Executor {
-        fn new(kind: JoinType) -> Self {
+        async fn new(kind: JoinType) -> Self {
             let record_store =
                 ProcessorRecordStoreDeserializer::new(RecordStore::InMemory).unwrap();
 
             let left_schema = create_schema("left");
@@ -287,6 +287,7 @@ mod tests {
                 .collect();
             let processor = factory
                 .build(schemas, HashMap::new(), &record_store, None)
+                .await
                 .unwrap();
 
             let record_store = record_store.into_record_store();
@@ -352,9 +353,9 @@ mod tests {
         }
     }
 
-    #[test]
-    fn test_inner_join() {
-        let mut exec = Executor::new(JoinType::Inner);
+    #[tokio::test]
+    async fn test_inner_join() {
+        let mut exec = Executor::new(JoinType::Inner).await;
 
         let (left_record, ops) =
            exec.insert(JoinSide::Left, &[Field::UInt(0), Field::UInt(1)]);
         assert_eq!(ops, &[]);
@@ -400,9 +401,9 @@ mod tests {
         );
     }
 
-    #[test]
-    fn test_left_outer_join() {
-        let mut exec = Executor::new(JoinType::LeftOuter);
+    #[tokio::test]
+    async fn test_left_outer_join() {
+        let mut exec = Executor::new(JoinType::LeftOuter).await;
 
         let null_record = exec
             .record_store
@@ -485,9 +486,9 @@ mod tests {
         );
     }
 
-    #[test]
-    fn test_right_outer_join() {
-        let mut exec = Executor::new(JoinType::RightOuter);
+    #[tokio::test]
+    async fn test_right_outer_join() {
+        let mut exec = Executor::new(JoinType::RightOuter).await;
 
         let null_record = exec
             .record_store
diff --git a/dozer-sql/src/product/set/set_factory.rs b/dozer-sql/src/product/set/set_factory.rs
index ec7fec6599..0c51961c83 100644
--- a/dozer-sql/src/product/set/set_factory.rs
+++ b/dozer-sql/src/product/set/set_factory.rs
@@ -10,6 +10,7 @@ use dozer_core::{
 use dozer_recordstore::ProcessorRecordStoreDeserializer;
 use dozer_sql_expression::sqlparser::ast::{SetOperator, SetQuantifier};
 use dozer_types::errors::internal::BoxedError;
+use dozer_types::tonic::async_trait;
 use dozer_types::types::{FieldDefinition, Schema, SourceDefinition};
 
 use super::operator::SetOperation;
@@ -37,6 +38,7 @@ impl SetProcessorFactory {
     }
 }
 
+#[async_trait]
 impl ProcessorFactory for SetProcessorFactory {
     fn id(&self) -> String {
         self.id.clone()
@@ -53,7 +55,7 @@ impl ProcessorFactory for SetProcessorFactory {
         vec![DEFAULT_PORT_HANDLE]
     }
 
-    fn get_output_schema(
+    async fn get_output_schema(
         &self,
         _output_port: &PortHandle,
         input_schemas: &HashMap<PortHandle, Schema>,
@@ -73,7 +75,7 @@ impl ProcessorFactory for SetProcessorFactory {
         Ok(output_schema)
     }
 
-    fn build(
+    async fn build(
         &self,
         _input_schemas: HashMap<PortHandle, Schema>,
         _output_schemas: HashMap<PortHandle, Schema>,
diff --git a/dozer-sql/src/product/table/factory.rs b/dozer-sql/src/product/table/factory.rs
index 7d36f16ab5..a0bd872391 100644
--- a/dozer-sql/src/product/table/factory.rs
+++ b/dozer-sql/src/product/table/factory.rs
@@ -9,7 +9,7 @@ use dozer_sql_expression::{
     builder::{extend_schema_source_def, NameOrAlias},
     sqlparser::ast::TableFactor,
 };
-use dozer_types::{errors::internal::BoxedError, types::Schema};
+use dozer_types::{errors::internal::BoxedError, tonic::async_trait, types::Schema};
 
 use crate::errors::{PipelineError, ProductError};
 use crate::window::builder::string_from_sql_object_name;
@@ -28,6 +28,7 @@ impl TableProcessorFactory {
     }
 }
 
+#[async_trait]
 impl ProcessorFactory for TableProcessorFactory {
     fn id(&self) -> String {
         self.id.clone()
@@ -45,7 +46,7 @@ impl ProcessorFactory for TableProcessorFactory {
         vec![DEFAULT_PORT_HANDLE]
     }
 
-    fn get_output_schema(
+    async fn get_output_schema(
         &self,
         _output_port: &PortHandle,
         input_schemas: &HashMap<PortHandle, Schema>,
@@ -59,7 +60,7 @@ impl ProcessorFactory for TableProcessorFactory {
         }
     }
 
-    fn build(
+    async fn build(
         &self,
         _input_schemas: HashMap<PortHandle, Schema>,
         _output_schemas: HashMap<PortHandle, Schema>,
diff --git a/dozer-sql/src/projection/factory.rs b/dozer-sql/src/projection/factory.rs
index e50240ab66..d99ad0893a 100644
--- a/dozer-sql/src/projection/factory.rs
+++ b/dozer-sql/src/projection/factory.rs
@@ -1,4 +1,4 @@
-use std::collections::HashMap;
+use std::{collections::HashMap, sync::Arc};
 
 use dozer_core::{
     node::{PortHandle, Processor, ProcessorFactory},
@@ -10,11 +10,12 @@ use dozer_sql_expression::{
     execution::Expression,
     sqlparser::ast::{Expr, Ident, SelectItem},
 };
-use dozer_types::models::udf_config::UdfConfig;
 use dozer_types::{
     errors::internal::BoxedError,
     types::{FieldDefinition, Schema},
 };
+use dozer_types::{models::udf_config::UdfConfig, tonic::async_trait};
+use tokio::runtime::Runtime;
 
 use crate::errors::PipelineError;
 
@@ -25,15 +26,27 @@ pub struct ProjectionProcessorFactory {
     select: Vec<SelectItem>,
     id: String,
     udfs: Vec<UdfConfig>,
+    runtime: Arc<Runtime>,
 }
 
 impl ProjectionProcessorFactory {
     /// Creates a new [`ProjectionProcessorFactory`].
-    pub fn _new(id: String, select: Vec<SelectItem>, udfs: Vec<UdfConfig>) -> Self {
-        Self { select, id, udfs }
+    pub fn _new(
+        id: String,
+        select: Vec<SelectItem>,
+        udfs: Vec<UdfConfig>,
+        runtime: Arc<Runtime>,
+    ) -> Self {
+        Self {
+            select,
+            id,
+            udfs,
+            runtime,
+        }
+    }
 }
 
+#[async_trait]
 impl ProcessorFactory for ProjectionProcessorFactory {
     fn id(&self) -> String {
         self.id.clone()
@@ -50,7 +63,7 @@ impl ProcessorFactory for ProjectionProcessorFactory {
         vec![DEFAULT_PORT_HANDLE]
     }
 
-    fn get_output_schema(
+    async fn get_output_schema(
         &self,
         _output_port: &PortHandle,
         input_schemas: &HashMap<PortHandle, Schema>,
@@ -71,13 +84,23 @@ impl ProcessorFactory for ProjectionProcessorFactory {
                     })
                     .collect();
                 for f in fields {
-                    if let Ok(res) = parse_sql_select_item(&f, input_schema, &self.udfs) {
+                    if let Ok(res) = parse_sql_select_item(
+                        &f,
+                        input_schema,
+                        &self.udfs,
+                        self.runtime.clone(),
+                    )
+                    .await
+                    {
                         select_expr.push(res)
                     }
                 }
             }
             _ => {
-                if let Ok(res) = parse_sql_select_item(s, input_schema, &self.udfs) {
+                if let Ok(res) =
+                    parse_sql_select_item(s, input_schema, &self.udfs, self.runtime.clone())
+                        .await
+                {
                     select_expr.push(res)
                 }
             }
@@ -101,7 +124,7 @@ impl ProcessorFactory for ProjectionProcessorFactory {
         Ok(output_schema)
     }
 
-    fn build(
+    async fn build(
         &self,
         input_schemas: HashMap<PortHandle, Schema>,
         _output_schemas: HashMap<PortHandle, Schema>,
@@ -113,35 +136,37 @@ impl ProcessorFactory for ProjectionProcessorFactory {
             None => Err(PipelineError::InvalidPortHandle(DEFAULT_PORT_HANDLE)),
         }?;
 
-        match self
-            .select
-            .iter()
-            .map(|item| parse_sql_select_item(item, schema, &self.udfs))
-            .collect::<Result<Vec<_>, PipelineError>>()
-        {
-            Ok(expressions) => Ok(Box::new(ProjectionProcessor::new(
-                schema.clone(),
-                expressions.into_iter().map(|e| e.1).collect(),
-                checkpoint_data,
-            ))),
-            Err(error) => Err(error.into()),
+        let mut expressions = vec![];
+        for select in &self.select {
+            expressions.push(
+                parse_sql_select_item(select, schema, &self.udfs, self.runtime.clone()).await?,
+            );
         }
+        Ok(Box::new(ProjectionProcessor::new(
+            schema.clone(),
+            expressions.into_iter().map(|e| e.1).collect(),
+            checkpoint_data,
+        )))
     }
 }
 
-pub(crate) fn parse_sql_select_item(
+pub(crate) async fn parse_sql_select_item(
     sql: &SelectItem,
     schema: &Schema,
     udfs: &[UdfConfig],
+    runtime: Arc<Runtime>,
 ) -> Result<(String, Expression), PipelineError> {
     match sql {
         SelectItem::UnnamedExpr(sql_expr) => {
-            let expr =
-                ExpressionBuilder::new(0).parse_sql_expression(true, sql_expr, schema, udfs)?;
+            let expr = ExpressionBuilder::new(0, runtime)
+                .parse_sql_expression(true, sql_expr, schema, udfs)
+                .await?;
             Ok((sql_expr.to_string(), expr))
         }
         SelectItem::ExprWithAlias { expr, alias } => {
-            let expr = ExpressionBuilder::new(0).parse_sql_expression(true, expr, schema, udfs)?;
+            let expr = ExpressionBuilder::new(0, runtime)
+                .parse_sql_expression(true, expr, schema, udfs)
+                .await?;
             Ok((alias.value.clone(), expr))
         }
         SelectItem::Wildcard(_) => Err(PipelineError::InvalidOperator("*".to_string())),
diff --git a/dozer-sql/src/selection/factory.rs b/dozer-sql/src/selection/factory.rs
index fd9a6a6fc9..26d2f1ac49 100644
--- a/dozer-sql/src/selection/factory.rs
+++ b/dozer-sql/src/selection/factory.rs
@@ -1,4 +1,4 @@
-use std::collections::HashMap;
+use std::{collections::HashMap, sync::Arc};
 
 use crate::errors::PipelineError;
 use dozer_core::{
@@ -8,8 +8,9 @@ use dozer_core::{
 use dozer_recordstore::ProcessorRecordStoreDeserializer;
 use dozer_sql_expression::builder::ExpressionBuilder;
 use dozer_sql_expression::sqlparser::ast::Expr as SqlExpr;
-use dozer_types::models::udf_config::UdfConfig;
 use dozer_types::{errors::internal::BoxedError, types::Schema};
+use dozer_types::{models::udf_config::UdfConfig, tonic::async_trait};
+use tokio::runtime::Runtime;
 
 use super::processor::SelectionProcessor;
 
@@ -18,19 +19,27 @@ pub struct SelectionProcessorFactory {
     statement: SqlExpr,
     id: String,
     udfs: Vec<UdfConfig>,
+    runtime: Arc<Runtime>,
 }
 
 impl SelectionProcessorFactory {
     /// Creates a new [`SelectionProcessorFactory`].
-    pub fn new(id: String, statement: SqlExpr, udf_config: Vec<UdfConfig>) -> Self {
+    pub fn new(
+        id: String,
+        statement: SqlExpr,
+        udf_config: Vec<UdfConfig>,
+        runtime: Arc<Runtime>,
+    ) -> Self {
         Self {
             statement,
             id,
             udfs: udf_config,
+            runtime,
         }
     }
 }
 
+#[async_trait]
 impl ProcessorFactory for SelectionProcessorFactory {
     fn id(&self) -> String {
         self.id.clone()
@@ -46,7 +55,7 @@ impl ProcessorFactory for SelectionProcessorFactory {
         vec![DEFAULT_PORT_HANDLE]
     }
 
-    fn get_output_schema(
+    async fn get_output_schema(
         &self,
         _output_port: &PortHandle,
         input_schemas: &HashMap<PortHandle, Schema>,
@@ -57,7 +66,7 @@ impl ProcessorFactory for SelectionProcessorFactory {
         Ok(schema.clone())
     }
 
-    fn build(
+    async fn build(
        &self,
        input_schemas: HashMap<PortHandle, Schema>,
        _output_schemas: HashMap<PortHandle, Schema>,
@@ -68,12 +77,10 @@ impl ProcessorFactory for SelectionProcessorFactory {
             .get(&DEFAULT_PORT_HANDLE)
             .ok_or(PipelineError::InvalidPortHandle(DEFAULT_PORT_HANDLE))?;
 
-        match ExpressionBuilder::new(schema.fields.len()).build(
-            false,
-            &self.statement,
-            schema,
-            &self.udfs,
-        ) {
+        match ExpressionBuilder::new(schema.fields.len(), self.runtime.clone())
+            .build(false, &self.statement, schema, &self.udfs)
+            .await
+        {
             Ok(expression) => Ok(Box::new(SelectionProcessor::new(
                 schema.clone(),
                 expression,
diff --git a/dozer-sql/src/table_operator/factory.rs b/dozer-sql/src/table_operator/factory.rs
index caf09923ee..435ace253a 100644
--- a/dozer-sql/src/table_operator/factory.rs
+++ b/dozer-sql/src/table_operator/factory.rs
@@ -1,4 +1,4 @@
-use std::{collections::HashMap, time::Duration};
+use std::{collections::HashMap, sync::Arc, time::Duration};
 
 use dozer_core::{
     node::{PortHandle, Processor, ProcessorFactory},
@@ -10,8 +10,9 @@ use dozer_sql_expression::{
     execution::Expression,
     sqlparser::ast::{Expr, FunctionArg, FunctionArgExpr, Value},
 };
-use dozer_types::models::udf_config::UdfConfig;
 use dozer_types::{errors::internal::BoxedError, types::Schema};
+use dozer_types::{models::udf_config::UdfConfig, tonic::async_trait};
+use tokio::runtime::Runtime;
 
 use crate::{
     errors::{PipelineError, TableOperatorError},
@@ -32,19 +33,27 @@ pub struct TableOperatorProcessorFactory {
     table: TableOperatorDescriptor,
     name: String,
     udfs: Vec<UdfConfig>,
+    runtime: Arc<Runtime>,
 }
 
 impl TableOperatorProcessorFactory {
-    pub fn new(id: String, table: TableOperatorDescriptor, udfs: Vec<UdfConfig>) -> Self {
+    pub fn new(
+        id: String,
+        table: TableOperatorDescriptor,
+        udfs: Vec<UdfConfig>,
+        runtime: Arc<Runtime>,
+    ) -> Self {
         Self {
             id: id.clone(),
             table,
             name: id,
             udfs,
+            runtime,
        }
     }
 }
 
+#[async_trait]
 impl ProcessorFactory for TableOperatorProcessorFactory {
     fn id(&self) -> String {
         self.id.clone()
@@ -61,7 +70,7 @@ impl ProcessorFactory for TableOperatorProcessorFactory {
         vec![DEFAULT_PORT_HANDLE]
     }
 
-    fn get_output_schema(
+    async fn get_output_schema(
         &self,
         _output_port: &PortHandle,
         input_schemas: &HashMap<PortHandle, Schema>,
@@ -71,7 +80,14 @@ impl ProcessorFactory for TableOperatorProcessorFactory {
             .ok_or(PipelineError::InvalidPortHandle(DEFAULT_PORT_HANDLE))?;
 
         let output_schema =
-            match operator_from_descriptor(&self.table, input_schema, &self.udfs)? {
+            match operator_from_descriptor(
+                &self.table,
+                input_schema,
+                &self.udfs,
+                self.runtime.clone(),
+            )
+            .await?
+            {
                 Some(operator) => operator
                     .get_output_schema(input_schema)
                     .map_err(PipelineError::TableOperatorError)?,
@@ -86,7 +102,7 @@ impl ProcessorFactory for TableOperatorProcessorFactory {
         Ok(output_schema)
     }
 
-    fn build(
+    async fn build(
         &self,
         input_schemas: HashMap<PortHandle, Schema>,
         _output_schemas: HashMap<PortHandle, Schema>,
@@ -100,7 +116,9 @@ impl ProcessorFactory for TableOperatorProcessorFactory {
             ))?
             .clone();
 
-        match operator_from_descriptor(&self.table, &input_schema, &self.udfs)? {
+        match operator_from_descriptor(&self.table, &input_schema, &self.udfs, self.runtime.clone())
+            .await?
+        {
             Some(operator) => Ok(Box::new(TableOperatorProcessor::new(
                 self.id.clone(),
                 operator,
@@ -117,13 +135,14 @@ impl ProcessorFactory for TableOperatorProcessorFactory {
     }
 }
 
-pub(crate) fn operator_from_descriptor(
+pub(crate) async fn operator_from_descriptor(
     descriptor: &TableOperatorDescriptor,
     schema: &Schema,
     udfs: &[UdfConfig],
+    runtime: Arc<Runtime>,
 ) -> Result<Option<TableOperator>, PipelineError> {
     if &descriptor.name.to_uppercase() == "TTL" {
-        let operator = lifetime_from_descriptor(descriptor, schema, udfs)?;
+        let operator = lifetime_from_descriptor(descriptor, schema, udfs, runtime).await?;
 
         Ok(Some(operator.into()))
     } else {
@@ -131,10 +150,11 @@ pub(crate) async fn operator_from_descriptor(
     }
 }
 
-fn lifetime_from_descriptor(
+async fn lifetime_from_descriptor(
     descriptor: &TableOperatorDescriptor,
     schema: &Schema,
     udfs: &[UdfConfig],
+    runtime: Arc<Runtime>,
 ) -> Result<LifetimeTableOperator, TableOperatorError> {
     let table_expression_arg = descriptor
@@ -168,7 +188,14 @@ async fn lifetime_from_descriptor(
         ));
     };
 
-    let expression = get_expression(descriptor.name.to_owned(), expression_arg, schema, udfs)?;
+    let expression = get_expression(
+        descriptor.name.to_owned(),
+        expression_arg,
+        schema,
+        udfs,
+        runtime,
+    )
+    .await?;
     let duration = get_interval(descriptor.name.to_owned(), duration_arg)?;
 
     let operator = LifetimeTableOperator::new(None, expression, duration);
@@ -214,11 +241,12 @@ fn get_interval(
     }
 }
 
-fn get_expression(
+async fn get_expression(
     function_name: String,
     interval_arg: &FunctionArg,
     schema: &Schema,
     udfs: &[UdfConfig],
+    runtime: Arc<Runtime>,
 ) -> Result<Expression, TableOperatorError> {
     match interval_arg {
         FunctionArg::Named { name, arg: _ } => {
@@ -230,10 +258,13 @@ async fn get_expression(
         }
         FunctionArg::Unnamed(arg_expr) => match arg_expr {
             FunctionArgExpr::Expr(expr) => {
-                let mut builder = ExpressionBuilder::new(schema.fields.len());
-                let expression = builder.build(false, expr, schema, udfs).map_err(|_| {
-                    TableOperatorError::InvalidReference(expr.to_string(), function_name)
-                })?;
+                let mut builder = ExpressionBuilder::new(schema.fields.len(), runtime);
+                let expression = builder
+                    .build(false, expr, schema, udfs)
+                    .await
+                    .map_err(|_| {
+                        TableOperatorError::InvalidReference(expr.to_string(), function_name)
+                    })?;
 
                 Ok(expression)
             }
diff --git a/dozer-sql/src/tests/builder_test.rs b/dozer-sql/src/tests/builder_test.rs
index 6c7494efa0..b5aa6efdee 100644
--- a/dozer-sql/src/tests/builder_test.rs
+++ b/dozer-sql/src/tests/builder_test.rs
@@ -28,6 +28,7 @@ use std::sync::atomic::AtomicBool;
 use std::sync::Arc;
 
 use crate::builder::statement_to_pipeline;
+use crate::tests::utils::create_test_runtime;
 
 /// Test Source
 #[derive(Debug)]
@@ -190,9 +191,9 @@ impl Sink for TestSink {
     }
 }
 
-#[tokio::test]
-async fn test_pipeline_builder() {
+#[test]
+fn test_pipeline_builder() {
     let mut pipeline = AppPipeline::new_with_default_flags();
+    let runtime = create_test_runtime();
     let context = statement_to_pipeline(
         "SELECT t.Spending \
     FROM TTL(TUMBLE(users, timestamp, '5 MINUTES'), timestamp, '1 MINUTE') t JOIN users u on t.CustomerID=u.CustomerID \
@@ -200,6 +202,7 @@ fn test_pipeline_builder() {
         &mut pipeline,
         Some("results".to_string()),
         vec![],
+        runtime.clone()
     )
     .unwrap();
@@ -236,15 +239,17 @@ fn test_pipeline_builder() {
 
     let now = std::time::Instant::now();
 
-    let (_temp_dir, checkpoint) = create_checkpoint_for_test().await;
-    DagExecutor::new(dag, checkpoint, Default::default())
-        .await
-        .unwrap()
-        .start(Arc::new(AtomicBool::new(true)), Default::default())
-        .await
-        .unwrap()
-        .join()
-        .unwrap();
+    runtime.block_on(async move {
+        let (_temp_dir, checkpoint) = create_checkpoint_for_test().await;
+        DagExecutor::new(dag, checkpoint, Default::default())
+            .await
+            .unwrap()
+            .start(Arc::new(AtomicBool::new(true)), Default::default())
+            .await
+            .unwrap()
+            .join()
+            .unwrap();
+    });
 
     let elapsed = now.elapsed();
     debug!("Elapsed: {:.2?}", elapsed);
diff --git a/dozer-sql/src/tests/utils.rs b/dozer-sql/src/tests/utils.rs
index a123fbfce1..f5cade9909 100644
--- a/dozer-sql/src/tests/utils.rs
+++ b/dozer-sql/src/tests/utils.rs
@@ -1,9 +1,12 @@
+use std::sync::Arc;
+
 use crate::errors::PipelineError;
 use dozer_sql_expression::sqlparser::{
     ast::{Query, Select, SetExpr, Statement},
     dialect::DozerDialect,
     parser::Parser,
 };
+use tokio::runtime::Runtime;
 
 pub fn get_select(sql: &str) -> Result<Box<Select>, PipelineError> {
     let dialect = DozerDialect {};
@@ -25,3 +28,12 @@ pub fn get_query_select(query: &Query) -> Box<Select>
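The remainder of this hunk introduces the `create_test_runtime` helper shared by the tests above; judging from its call sites it returns an `Arc<Runtime>` that can be cloned into factories and `block_on`ed from `#[test]` functions. A sketch consistent with that usage (the builder options here are assumptions, not taken from the patch):

    pub fn create_test_runtime() -> Arc<Runtime> {
        Arc::new(
            tokio::runtime::Builder::new_multi_thread()
                .enable_all()
                .build()
                .expect("failed to build Tokio runtime for tests"),
        )
    }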