diff --git a/core/src/blocks/block.rs b/core/src/blocks/block.rs index a266e182a2db..e98389ea36b1 100644 --- a/core/src/blocks/block.rs +++ b/core/src/blocks/block.rs @@ -1,7 +1,7 @@ use crate::blocks::{ browser::Browser, chat::Chat, code::Code, curl::Curl, data::Data, data_source::DataSource, - database_schema::DatabaseSchema, end::End, input::Input, llm::LLM, map::Map, r#while::While, - reduce::Reduce, search::Search, + database::Database, database_schema::DatabaseSchema, end::End, input::Input, llm::LLM, + map::Map, r#while::While, reduce::Reduce, search::Search, }; use crate::data_sources::qdrant::QdrantClients; use crate::project::Project; @@ -75,6 +75,7 @@ pub enum BlockType { While, End, DatabaseSchema, + Database, } impl ToString for BlockType { @@ -94,6 +95,7 @@ impl ToString for BlockType { BlockType::While => String::from("while"), BlockType::End => String::from("end"), BlockType::DatabaseSchema => String::from("database_schema"), + BlockType::Database => String::from("database"), } } } @@ -196,6 +198,7 @@ pub fn parse_block(t: BlockType, block_pair: Pair) -> Result Ok(Box::new(While::parse(block_pair)?)), BlockType::End => Ok(Box::new(End::parse(block_pair)?)), BlockType::DatabaseSchema => Ok(Box::new(DatabaseSchema::parse(block_pair)?)), + BlockType::Database => Ok(Box::new(Database::parse(block_pair)?)), } } diff --git a/core/src/blocks/database.rs b/core/src/blocks/database.rs new file mode 100644 index 000000000000..9bbf9cab5e1e --- /dev/null +++ b/core/src/blocks/database.rs @@ -0,0 +1,145 @@ +use crate::blocks::block::{parse_pair, Block, BlockResult, BlockType, Env}; +use crate::Rule; +use anyhow::{anyhow, Result}; +use async_trait::async_trait; + +use pest::iterators::Pair; +use serde_json::{json, Value}; +use tokio::sync::mpsc::UnboundedSender; + +use super::block::replace_variables_in_string; +use super::helpers::get_data_source_project; + +#[derive(Clone)] +pub struct Database { + query: String, + workspace_id: String, + data_source_id: String, + database_id: String, +} + +impl Database { + pub fn parse(block_pair: Pair) -> Result { + let mut query: Option = None; + let mut workspace_id: Option = None; + let mut data_source_id: Option = None; + let mut database_id: Option = None; + + for pair in block_pair.into_inner() { + match pair.as_rule() { + Rule::pair => { + let (key, value) = parse_pair(pair)?; + match key.as_str() { + "query" => query = Some(value), + "workspace_id" => workspace_id = Some(value), + "data_source_id" => data_source_id = Some(value), + "database_id" => database_id = Some(value), + _ => Err(anyhow!("Unexpected `{}` in `database` block", key))?, + } + } + Rule::expected => Err(anyhow!( + "`expected` is not yet supported in `database` block" + ))?, + _ => unreachable!(), + } + } + + if !query.is_some() { + Err(anyhow!("Missing required `query` in `database` block"))?; + } + if !workspace_id.is_some() { + Err(anyhow!( + "Missing required `workspace_id` in `database` block" + ))?; + } + if !data_source_id.is_some() { + Err(anyhow!( + "Missing required `data_source_id` in `database` block" + ))?; + } + if !database_id.is_some() { + Err(anyhow!( + "Missing required `database_id` in `database` block" + ))?; + } + + Ok(Database { + query: query.unwrap(), + workspace_id: workspace_id.unwrap(), + data_source_id: data_source_id.unwrap(), + database_id: database_id.unwrap(), + }) + } +} + +#[async_trait] +impl Block for Database { + fn block_type(&self) -> BlockType { + BlockType::Database + } + + fn inner_hash(&self) -> String { + let mut hasher = blake3::Hasher::new(); + hasher.update("database_schema".as_bytes()); + hasher.update(self.query.as_bytes()); + hasher.update(self.data_source_id.as_bytes()); + hasher.update(self.database_id.as_bytes()); + format!("{}", hasher.finalize().to_hex()) + } + + async fn execute( + &self, + _name: &str, + env: &Env, + _event_sender: Option>, + ) -> Result { + let workspace_id = replace_variables_in_string(&self.workspace_id, "workspace_id", env)?; + let data_source_id = + replace_variables_in_string(&self.data_source_id, "data_source_id", env)?; + let database_id = replace_variables_in_string(&self.database_id, "database_id", env)?; + + let project = get_data_source_project(&workspace_id, &data_source_id, env).await?; + + let database = match env + .store + .load_database(&project, &data_source_id, &database_id) + .await? + { + Some(d) => d, + None => Err(anyhow!( + "Database `{}` not found in data source `{}`", + database_id, + data_source_id + ))?, + }; + + let (rows, schema) = match database + .query(&project, env.store.clone(), &self.query) + .await + { + Ok(r) => r, + Err(e) => Err(anyhow!( + "Error querying database `{}` in data source `{}`: {}", + database_id, + data_source_id, + e + ))?, + }; + + Ok(BlockResult { + value: json!({ + "rows": rows, + "schema": schema, + }), + meta: None, + }) + } + + fn clone_box(&self) -> Box { + Box::new(self.clone()) + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } +} diff --git a/core/src/blocks/database_schema.rs b/core/src/blocks/database_schema.rs index 7b6cc1fb3dd2..dc3baad40311 100644 --- a/core/src/blocks/database_schema.rs +++ b/core/src/blocks/database_schema.rs @@ -4,7 +4,7 @@ use anyhow::{anyhow, Result}; use async_trait::async_trait; use futures::future::try_join_all; use pest::iterators::Pair; -use serde_json::Value; +use serde_json::{json, Value}; use tokio::sync::mpsc::UnboundedSender; use super::helpers::get_data_source_project; @@ -78,8 +78,19 @@ impl Block for DatabaseSchema { )) .await?; + let results = std::iter::zip(schemas, databases) + .map(|(s, (w, d, db))| { + json!({ + "workspace_id": w, + "data_source_id": d, + "database_id": db, + "schema": s, + }) + }) + .collect::>(); + Ok(BlockResult { - value: serde_json::to_value(schemas)?, + value: serde_json::to_value(results)?, meta: None, }) } diff --git a/core/src/lib.rs b/core/src/lib.rs index 2d23f50cdb43..ba15f8b131bd 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -48,6 +48,7 @@ pub mod blocks { pub mod curl; pub mod data; pub mod data_source; + pub mod database; pub mod database_schema; pub mod end; pub mod helpers; diff --git a/core/src/run.rs b/core/src/run.rs index 7ad888e5fb1b..12bdb14a1d21 100644 --- a/core/src/run.rs +++ b/core/src/run.rs @@ -57,6 +57,7 @@ impl RunConfig { BlockType::While => 64, BlockType::End => 64, BlockType::DatabaseSchema => 8, + BlockType::Database => 8, } } }