From 49e56acb8500c01b19ce8618469ad31afe2b4d46 Mon Sep 17 00:00:00 2001
From: Henry Fontanier
Date: Mon, 3 Feb 2025 18:15:25 +0100
Subject: [PATCH] feat: add support for bigquery in core (#10447)

* feat: add support for bigquery in core

* add table schema

* region -> location

* use row count limit

---------

Co-authored-by: Henry Fontanier
---
 core/src/databases/database.rs               |   1 +
 .../databases/remote_databases/bigquery.rs   | 390 ++++++++++++++++++
 .../remote_databases/get_remote_database.rs  |   6 +-
 .../databases/remote_databases/snowflake.rs  |  14 +-
 core/src/lib.rs                              |   1 +
 5 files changed, 404 insertions(+), 8 deletions(-)
 create mode 100644 core/src/databases/remote_databases/bigquery.rs

diff --git a/core/src/databases/database.rs b/core/src/databases/database.rs
index 3684822423d0..0e50fb4db02b 100644
--- a/core/src/databases/database.rs
+++ b/core/src/databases/database.rs
@@ -33,6 +33,7 @@ pub enum QueryDatabaseError {
 pub enum SqlDialect {
     DustSqlite,
     Snowflake,
+    Bigquery,
 }
 
 #[derive(Debug, Deserialize, Serialize, Clone)]
diff --git a/core/src/databases/remote_databases/bigquery.rs b/core/src/databases/remote_databases/bigquery.rs
new file mode 100644
index 000000000000..d21b5b980019
--- /dev/null
+++ b/core/src/databases/remote_databases/bigquery.rs
@@ -0,0 +1,390 @@
+use std::collections::HashSet;
+
+use anyhow::{anyhow, Result};
+use async_trait::async_trait;
+use futures::future::try_join_all;
+use gcp_bigquery_client::{
+    model::{
+        field_type::FieldType, get_query_results_parameters::GetQueryResultsParameters, job::Job,
+        job_configuration::JobConfiguration, job_configuration_query::JobConfigurationQuery,
+        job_reference::JobReference, table_row::TableRow,
+    },
+    yup_oauth2::ServiceAccountKey,
+    Client,
+};
+use serde_json::Value;
+
+use crate::databases::{
+    database::{QueryDatabaseError, QueryResult, SqlDialect},
+    table::Table,
+    table_schema::{TableSchema, TableSchemaColumn, TableSchemaFieldType},
+};
+
+use super::remote_database::RemoteDatabase;
+
+#[derive(Debug)]
+pub struct BigQueryQueryPlan {
+    is_select_query: bool,
+    affected_tables: Vec<String>,
+}
+
+pub struct BigQueryRemoteDatabase {
+    project_id: String,
+    location: String,
+    client: Client,
+}
+
+impl TryFrom<&gcp_bigquery_client::model::table_schema::TableSchema> for TableSchema {
+    type Error = anyhow::Error;
+
+    fn try_from(
+        schema: &gcp_bigquery_client::model::table_schema::TableSchema,
+    ) -> Result<Self, Self::Error> {
+        match &schema.fields {
+            Some(fields) => Ok(TableSchema::from_columns(
+                fields
+                    .iter()
+                    .map(|f| TableSchemaColumn {
+                        name: f.name.clone(),
+                        value_type: match f.r#type {
+                            FieldType::String => TableSchemaFieldType::Text,
+                            FieldType::Integer | FieldType::Int64 => TableSchemaFieldType::Int,
+                            FieldType::Float
+                            | FieldType::Float64
+                            | FieldType::Numeric
+                            | FieldType::Bignumeric => TableSchemaFieldType::Float,
+                            FieldType::Boolean | FieldType::Bool => TableSchemaFieldType::Bool,
+                            FieldType::Timestamp
+                            | FieldType::Datetime
+                            | FieldType::Date
+                            | FieldType::Time => TableSchemaFieldType::DateTime,
+                            FieldType::Bytes
+                            | FieldType::Geography
+                            | FieldType::Json
+                            | FieldType::Record
+                            | FieldType::Struct
+                            | FieldType::Interval => TableSchemaFieldType::Text,
+                        },
+                        possible_values: None,
+                    })
+                    .collect(),
+            )),
+            None => Err(anyhow!("No fields found in schema"))?,
+        }
+    }
+}
+
+pub const MAX_QUERY_RESULT_ROWS: usize = 25_000;
+pub const PAGE_SIZE: i32 = 500;
+
+impl BigQueryRemoteDatabase {
+    pub fn new(
+        project_id: String,
+        location: String,
+        client: Client,
+    ) -> Result<Self> {
+        Ok(Self {
+            project_id,
+            location,
+            client,
+        })
+    }
+
+    pub async fn execute_query(
+        &self,
+        query: &str,
+    ) -> Result<(Vec<QueryResult>, TableSchema), QueryDatabaseError> {
+        let job = Job {
+            configuration: Some(JobConfiguration {
+                query: Some(JobConfigurationQuery {
+                    query: query.to_string(),
+                    use_legacy_sql: Some(false),
+                    ..Default::default()
+                }),
+                ..Default::default()
+            }),
+            ..Default::default()
+        };
+
+        let inserted_job = self
+            .client
+            .job()
+            .insert(&self.project_id, job)
+            .await
+            .map_err(|e| QueryDatabaseError::GenericError(anyhow!("Error inserting job: {}", e)))?;
+
+        let job_id = match inserted_job.job_reference {
+            Some(job_reference) => match job_reference.job_id {
+                Some(job_id) => job_id,
+                None => Err(QueryDatabaseError::GenericError(anyhow!(
+                    "Job reference not found"
+                )))?,
+            },
+            None => Err(QueryDatabaseError::GenericError(anyhow!(
+                "Job reference not found"
+            )))?,
+        };
+
+        let mut query_result_rows: usize = 0;
+        let mut all_rows: Vec<TableRow> = Vec::new();
+        let mut page_token: Option<String> = None;
+        let mut schema: Option<gcp_bigquery_client::model::table_schema::TableSchema> = None;
+
+        'fetch_rows: loop {
+            let res = self
+                .client
+                .job()
+                .get_query_results(
+                    &self.project_id,
+                    &job_id,
+                    GetQueryResultsParameters {
+                        location: Some(self.location.clone()),
+                        page_token: page_token.clone(),
+                        max_results: Some(PAGE_SIZE),
+                        ..Default::default()
+                    },
+                )
+                .await
+                .map_err(|e| {
+                    QueryDatabaseError::GenericError(anyhow!("Error getting query results: {}", e))
+                })?;
+
+            if !res.job_complete.unwrap_or(false) {
+                Err(QueryDatabaseError::GenericError(anyhow!(
+                    "Query job not complete"
+                )))?
+            }
+
+            let rows = res.rows.unwrap_or_default();
+
+            query_result_rows += rows.len();
+
+            if query_result_rows >= MAX_QUERY_RESULT_ROWS {
+                return Err(QueryDatabaseError::ResultTooLarge(format!(
+                    "Query result size exceeds limit of {} rows",
+                    MAX_QUERY_RESULT_ROWS
+                )));
+            }
+
+            page_token = res.page_token;
+            all_rows.extend(rows);
+
+            if let (None, Some(s)) = (&mut schema, res.schema) {
+                schema = Some(s);
+            }
+
+            if page_token.is_none() {
+                break 'fetch_rows;
+            }
+        }
+
+        let fields = match &schema {
+            Some(s) => match &s.fields {
+                Some(f) => f,
+                None => Err(QueryDatabaseError::GenericError(anyhow!(
+                    "Schema not found"
+                )))?,
+            },
+            None => Err(QueryDatabaseError::GenericError(anyhow!(
+                "Schema not found"
+            )))?,
+        };
+
+        let schema = match &schema {
+            Some(s) => TableSchema::try_from(s)?,
+            None => Err(QueryDatabaseError::GenericError(anyhow!(
+                "Schema not found"
+            )))?,
+        };
+
+        let parsed_rows = all_rows
+            .into_iter()
+            .map(|row| {
+                let cols = row.columns.unwrap_or_default();
+                let mut map = serde_json::Map::new();
+                for (c, f) in cols.into_iter().zip(fields) {
+                    map.insert(
+                        f.name.clone(),
+                        match c.value {
+                            Some(v) => match f.r#type {
+                                FieldType::Struct
+                                | FieldType::Record
+                                | FieldType::Json
+                                | FieldType::Geography => match &v {
+                                    Value::String(_) => v,
+                                    _ => Value::String(v.to_string()),
+                                },
+                                _ => v,
+                            },
+                            None => serde_json::Value::Null,
+                        },
+                    );
+                }
+
+                Ok(QueryResult {
+                    value: serde_json::Value::Object(map),
+                })
+            })
+            .collect::<Result<Vec<QueryResult>>>()?;
+
+        Ok((parsed_rows, schema))
+    }
+
+    pub async fn get_query_plan(
+        &self,
+        query: &str,
+    ) -> Result<BigQueryQueryPlan, QueryDatabaseError> {
+        let job = Job {
+            configuration: Some(JobConfiguration {
+                query: Some(JobConfigurationQuery {
+                    query: query.to_string(),
+                    use_legacy_sql: Some(false),
+                    ..Default::default()
+                }),
+                dry_run: Some(true),
+                ..Default::default()
+            }),
+            job_reference: Some(JobReference {
+                location: Some(self.location.clone()),
+                ..Default::default()
+            }),
+            ..Default::default()
+        };
+
+        let job_result = self
+            .client
+            .job()
+            .insert(&self.project_id, job)
+            .await
+            .map_err(|e| QueryDatabaseError::GenericError(anyhow!("Error inserting job: {}", e)))?;
+
+        let query_stats = match job_result.statistics {
+            Some(stats) => match stats.query {
+                Some(stats) => stats,
+                None => Err(QueryDatabaseError::GenericError(anyhow!(
+                    "No statistics found"
+                )))?,
+            },
+            None => Err(QueryDatabaseError::GenericError(anyhow!(
+                "No statistics found"
+            )))?,
+        };
+
+        let is_select_query = match query_stats.statement_type {
+            Some(stmt_type) => stmt_type.to_ascii_uppercase() == "SELECT",
+            None => false,
+        };
+
+        let affected_tables = match query_stats.referenced_tables {
+            Some(tables) => tables,
+            None => Vec::new(),
+        }
+        .iter()
+        .map(|t| format!("{}.{}", t.dataset_id, t.table_id))
+        .collect();
+
+        Ok(BigQueryQueryPlan {
+            is_select_query,
+            affected_tables,
+        })
+    }
+}
+
+#[async_trait]
+impl RemoteDatabase for BigQueryRemoteDatabase {
+    fn dialect(&self) -> SqlDialect {
+        SqlDialect::Bigquery
+    }
+
+    async fn authorize_and_execute_query(
+        &self,
+        tables: &Vec<Table>,
+        query: &str,
+    ) -> Result<(Vec<QueryResult>, TableSchema), QueryDatabaseError> {
+        // Ensure that query is a SELECT query and only uses tables that are allowed.
+        let plan = self.get_query_plan(query).await?;
+
+        if !plan.is_select_query {
+            Err(QueryDatabaseError::ExecutionError(format!(
+                "Query is not a SELECT query"
+            )))?
+        }
+
+        let used_tables: HashSet<&str> = plan
+            .affected_tables
+            .iter()
+            .map(|table| table.as_str())
+            .collect();
+
+        let allowed_tables: HashSet<&str> = tables.iter().map(|table| table.name()).collect();
+
+        let used_forbidden_tables = used_tables
+            .into_iter()
+            .filter(|table| !allowed_tables.contains(*table))
+            .collect::<Vec<&str>>();
+
+        if !used_forbidden_tables.is_empty() {
+            Err(QueryDatabaseError::ExecutionError(format!(
+                "Query uses tables that are not allowed: {}",
+                used_forbidden_tables.join(", ")
+            )))?
+        }
+
+        self.execute_query(query).await
+    }
+
+    async fn get_tables_schema(&self, opaque_ids: &Vec<&str>) -> Result<Vec<TableSchema>> {
+        let bq_tables: Vec<gcp_bigquery_client::model::table::Table> =
+            try_join_all(opaque_ids.iter().map(|opaque_id| async move {
+                let parts: Vec<&str> = opaque_id.split('.').collect();
+                if parts.len() != 2 {
+                    Err(anyhow!("Invalid opaque ID: {}", opaque_id))?
+                }
+                let (dataset_id, table_id) = (parts[0], parts[1]);
+
+                self.client
+                    .table()
+                    .get(&self.project_id, dataset_id, table_id, None)
+                    .await
+                    .map_err(|e| anyhow!("Error getting table metadata: {}", e))
+            }))
+            .await?;
+
+        let schemas: Vec<TableSchema> = bq_tables
+            .into_iter()
+            .map(|table| TableSchema::try_from(&table.schema))
+            .collect::<Result<Vec<TableSchema>>>()?;
+
+        Ok(schemas)
+    }
+}
+
+pub async fn get_bigquery_remote_database(
+    credentials: serde_json::Map<String, serde_json::Value>,
+) -> Result<Box<dyn RemoteDatabase + Sync + Send>> {
+    let location = match credentials.get("location") {
+        Some(serde_json::Value::String(v)) => v.to_string(),
+        _ => Err(anyhow!("Invalid credentials: location not found"))?,
+    };
+    let project_id = match credentials.get("project_id") {
+        Some(serde_json::Value::String(v)) => v.to_string(),
+        _ => Err(anyhow!("Invalid credentials: project_id not found"))?,
+    };
+
+    let sa_key: ServiceAccountKey = serde_json::from_value(serde_json::Value::Object(credentials))
+        .map_err(|e| {
+            QueryDatabaseError::GenericError(anyhow!("Error deserializing credentials: {}", e))
+        })?;
+
+    let client = Client::from_service_account_key(sa_key, false)
+        .await
+        .map_err(|e| {
+            QueryDatabaseError::GenericError(anyhow!("Error creating BigQuery client: {}", e))
+        })?;
+
+    Ok(Box::new(BigQueryRemoteDatabase {
+        project_id,
+        location,
+        client,
+    }))
+}
diff --git a/core/src/databases/remote_databases/get_remote_database.rs b/core/src/databases/remote_databases/get_remote_database.rs
index 6ceea829aedb..25e56445649d 100644
--- a/core/src/databases/remote_databases/get_remote_database.rs
+++ b/core/src/databases/remote_databases/get_remote_database.rs
@@ -5,7 +5,7 @@ use crate::{
     oauth::{client::OauthClient, credential::CredentialProvider},
 };
 
-use super::snowflake::SnowflakeRemoteDatabase;
+use super::{bigquery::get_bigquery_remote_database, snowflake::SnowflakeRemoteDatabase};
 
 pub async fn get_remote_database(
     credential_id: &str,
@@ -17,6 +17,10 @@ pub async fn get_remote_database(
             let db = SnowflakeRemoteDatabase::new(content)?;
             Ok(Box::new(db) as Box<dyn RemoteDatabase + Sync + Send>)
         }
+        CredentialProvider::Bigquery => {
+            let db = get_bigquery_remote_database(content).await?;
+            Ok(db)
+        }
         _ => Err(anyhow!(
             "{:?} is not a supported remote database provider",
             provider
diff --git a/core/src/databases/remote_databases/snowflake.rs b/core/src/databases/remote_databases/snowflake.rs
index 3ff54946d81d..c9c1a342f9c5 100644
--- a/core/src/databases/remote_databases/snowflake.rs
+++ b/core/src/databases/remote_databases/snowflake.rs
@@ -1,4 +1,4 @@
-use std::{collections::HashSet, env, mem};
+use std::{collections::HashSet, env};
 
 use anyhow::{anyhow, Result};
 use async_trait::async_trait;
@@ -44,7 +44,7 @@ struct SnowflakeQueryPlanEntry {
     operation: Option<String>,
 }
 
-pub const MAX_QUERY_RESULT_SIZE_BYTES: usize = 8 * 1024 * 1024; // 8MB
+pub const MAX_QUERY_RESULT_ROWS: usize = 25_000;
 
 pub const FORBIDDEN_OPERATIONS: [&str; 3] = ["UPDATE", "DELETE", "INSERT"];
 
@@ -235,7 +235,7 @@ impl SnowflakeRemoteDatabase {
             ))),
         }?;
 
-        let mut query_result_size: usize = 0;
+        let mut query_result_rows: usize = 0;
         let mut all_rows: Vec<QueryResult> = Vec::new();
 
         // Fetch results chunk by chunk.
@@ -253,11 +253,11 @@ impl SnowflakeRemoteDatabase {
                .collect::<Result<Vec<QueryResult>>>()?;
 
             // Check that total result size so far does not exceed the limit.
-            query_result_size += rows.len() * mem::size_of::<QueryResult>();
-            if query_result_size >= MAX_QUERY_RESULT_SIZE_BYTES {
+            query_result_rows += rows.len();
+            if query_result_rows >= MAX_QUERY_RESULT_ROWS {
                 return Err(QueryDatabaseError::ResultTooLarge(format!(
-                    "Query result size exceeds limit of {} bytes",
-                    MAX_QUERY_RESULT_SIZE_BYTES
+                    "Query result size exceeds limit of {} rows",
+                    MAX_QUERY_RESULT_ROWS
                 )));
             }
 
diff --git a/core/src/lib.rs b/core/src/lib.rs
index 31e6dbca6d6e..da8d7dba480a 100644
--- a/core/src/lib.rs
+++ b/core/src/lib.rs
@@ -28,6 +28,7 @@ pub mod databases {
     pub mod table;
     pub mod table_schema;
     pub mod remote_databases {
+        pub mod bigquery;
        pub mod get_remote_database;
        pub mod remote_database;
        pub mod snowflake;
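
Illustrative usage (editor's sketch, not part of the diff above): a minimal way to exercise the new BigQuery path end to end, assuming the module path added in lib.rs, a placeholder service-account key, and an empty allowed-table list. The helper name bigquery_example and all credential values are made up for illustration.

    use serde_json::json;

    // Assumed module path, matching the `pub mod bigquery` declared in lib.rs above.
    use crate::databases::remote_databases::bigquery::get_bigquery_remote_database;

    async fn bigquery_example() -> anyhow::Result<()> {
        // Placeholder service-account key augmented with "project_id" and "location",
        // the two extra fields get_bigquery_remote_database reads before deserializing
        // the rest into a ServiceAccountKey. Every value here is fake.
        let credentials = match json!({
            "type": "service_account",
            "project_id": "my-gcp-project",
            "location": "EU",
            "private_key_id": "fake-key-id",
            "private_key": "-----BEGIN PRIVATE KEY-----\n...\n-----END PRIVATE KEY-----\n",
            "client_email": "sa@my-gcp-project.iam.gserviceaccount.com",
            "token_uri": "https://oauth2.googleapis.com/token"
        }) {
            serde_json::Value::Object(map) => map,
            _ => unreachable!(),
        };

        let db = get_bigquery_remote_database(credentials).await?;

        // Authorization goes through a dry-run query plan (SELECT only, referenced
        // tables checked against the allowed list), then results are fetched in
        // 500-row pages up to the 25_000-row cap. An empty allowed list is enough
        // here because the query references no tables.
        let (rows, _schema) = db
            .authorize_and_execute_query(&vec![], "SELECT 1 AS one")
            .await?;
        assert_eq!(rows.len(), 1);
        Ok(())
    }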