Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(structured data): query DB v0 #2499

Merged
merged 22 commits into from
Nov 16, 2023
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 59 additions & 2 deletions core/bin/dust_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2032,14 +2032,14 @@ async fn databases_schema_retrieve(
&format!("No database found for id `{}`", database_id),
None,
),
Ok(Some(db)) => match db.get_schema(&project, state.store.clone()).await {
Ok(Some(db)) => match db.get_schema(&project, state.store.clone(), false).await {
Err(e) => error_response(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_server_error",
"Failed to retrieve database schema",
Some(e),
),
Ok(schema) => (
Ok((schema, _)) => (
StatusCode::OK,
Json(APIResponse {
error: None,
Expand All @@ -2052,6 +2052,59 @@ async fn databases_schema_retrieve(
}
}

#[derive(serde::Deserialize)]
struct DatabaseQueryRunPayload {
query: String,
}

async fn databases_query_run(
extract::Path((project_id, data_source_id, database_id)): extract::Path<(i64, String, String)>,
extract::Json(payload): extract::Json<DatabaseQueryRunPayload>,
extract::Extension(state): extract::Extension<Arc<APIState>>,
) -> (StatusCode, Json<APIResponse>) {
let project = project::Project::new_from_id(project_id);

match state
.store
.load_database(&project, &data_source_id, &database_id)
.await
{
Err(e) => error_response(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_server_error",
"Failed to retrieve database",
Some(e),
),
Ok(None) => error_response(
StatusCode::NOT_FOUND,
"database_not_found",
&format!("No database found for id `{}`", database_id),
None,
),
Ok(Some(db)) => match db
.query(&project, state.store.clone(), &payload.query)
.await
{
Err(e) => error_response(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_server_error",
"Failed to run query",
Some(e),
),
Ok((rows, schema)) => (
StatusCode::OK,
Json(APIResponse {
error: None,
response: Some(json!({
"schema": schema,
"rows": rows,
})),
}),
),
},
}
}

// Misc

#[derive(serde::Deserialize)]
Expand Down Expand Up @@ -2273,6 +2326,10 @@ fn main() {
"/projects/:project_id/data_sources/:data_source_id/databases/:database_id/schema",
get(databases_schema_retrieve),
)
.route(
"/projects/:project_id/data_sources/:data_source_id/databases/:database_id/query",
post(databases_query_run),
)
// Misc
.route("/tokenize", post(tokenize))

Expand Down
244 changes: 231 additions & 13 deletions core/src/databases/database.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::{project::Project, stores::store::Store, utils};
use anyhow::{anyhow, Result};

use crate::{project::Project, stores::store::Store};
use rayon::prelude::*;
use rusqlite::{Connection, ToSql};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
Expand Down Expand Up @@ -48,7 +48,8 @@ impl Database {
&self,
project: &Project,
store: Box<dyn Store + Sync + Send>,
) -> Result<DatabaseSchema> {
return_rows: bool,
) -> Result<(DatabaseSchema, Option<HashMap<String, Vec<DatabaseRow>>>)> {
match self.db_type {
DatabaseType::REMOTE => Err(anyhow!("Remote DB not implemented.")),
DatabaseType::LOCAL => {
Expand Down Expand Up @@ -83,20 +84,227 @@ impl Database {
.into_iter()
.collect::<Vec<_>>();

Ok(DatabaseSchema(
rows.into_par_iter()
.map(|(table, rows)| {
Ok((
table.table_id().to_string(),
DatabaseSchemaTable::new(table, TableSchema::from_rows(&rows)?),
))
})
.collect::<Result<HashMap<_, _>>>()?,
let returned_rows = match return_rows {
fontanierh marked this conversation as resolved.
Show resolved Hide resolved
true => Some(
rows.clone()
fontanierh marked this conversation as resolved.
Show resolved Hide resolved
.into_iter()
.map(|(table, rows)| (table.table_id().to_string(), rows))
.collect::<HashMap<_, _>>(),
),
false => None,
};

Ok((
DatabaseSchema(
rows.into_par_iter()
.map(|(table, r)| {
Ok((
table.table_id().to_string(),
DatabaseSchemaTable::new(table, TableSchema::from_rows(&r)?),
))
})
.collect::<Result<HashMap<_, _>>>()?,
),
returned_rows,
))
}
}
}

pub async fn create_in_memory_sqlite_conn(
&self,
project: &Project,
store: Box<dyn Store + Sync + Send>,
) -> Result<Connection> {
match self.db_type {
DatabaseType::REMOTE => Err(anyhow!(
"Cannot build an in-memory SQLite DB for a remote database."
)),
DatabaseType::LOCAL => {
let time_build_db_start = utils::now();
let (schema, rows_by_table) = self.get_schema(project, store.clone(), true).await?;
let rows_by_table = match rows_by_table {
Some(rows) => rows,
None => return Err(anyhow!("No rows found")),
};
utils::done(&format!(
"DSSTRUCTSTAT Finished retrieving schema: duration={}ms",
utils::now() - time_build_db_start
));

let table_schemas: HashMap<String, TableSchema> = schema
fontanierh marked this conversation as resolved.
Show resolved Hide resolved
.iter()
.filter(|(_, table)| !table.schema.is_empty())
fontanierh marked this conversation as resolved.
Show resolved Hide resolved
.map(|(table_name, table)| (table_name.clone(), table.schema.clone()))
.collect();

let generate_create_table_sql_start = utils::now();
let create_tables_sql: String = schema
.iter()
.filter(|(_, table)| !table.schema.is_empty())
.map(|(table_name, table)| {
table
.schema
.get_create_table_sql_string(table_name.as_str())
})
.collect::<Vec<_>>()
.join("\n");
utils::done(&format!(
"DSSTRUCTSTAT Finished generating create table SQL: duration={}ms",
utils::now() - generate_create_table_sql_start
));

let conn = rusqlite::Connection::open_in_memory()?;

let create_tables_execute_start = utils::now();
conn.execute_batch(&create_tables_sql)?;
utils::done(&format!(
"DSSTRUCTSTAT Finished creating tables: duration={}ms",
utils::now() - create_tables_execute_start
));

let insert_execute_start = utils::now();
rows_by_table
fontanierh marked this conversation as resolved.
Show resolved Hide resolved
.iter()
.filter(|(_, rows)| !rows.is_empty())
.map(|(table_name, rows)| {
let table_schema = table_schemas
.get(table_name)
.ok_or_else(|| anyhow!("No schema found for table {}", table_name))?;

rows.iter()
.map(|row| {
match table_schema
.get_insert_row_sql_string(table_name, &row.content)
{
Ok((query, boxed_params)) => {
let params_refs: Vec<&dyn ToSql> = boxed_params
.iter()
.map(|param| &**param as &dyn ToSql)
.collect();

match conn.execute(&query, params_refs.as_slice()) {
Ok(res) => Ok(res),
Err(e) => Err(anyhow!("Error: {}", e)),
}
}
Err(e) => Err(anyhow!("Error: {}", e)),
}
})
.collect::<Result<Vec<_>>>()
})
.collect::<Result<Vec<_>>>()?;
utils::done(&format!(
"DSSTRUCTSTAT Finished inserting rows: duration={}ms",
utils::now() - insert_execute_start
));

Ok(conn)
}
}
}

pub async fn query(
&self,
project: &Project,
store: Box<dyn Store + Sync + Send>,
query: &str,
) -> Result<(Vec<DatabaseRow>, TableSchema)> {
match self.db_type {
DatabaseType::REMOTE => Err(anyhow!("Remote DB not implemented.")),
DatabaseType::LOCAL => {
let conn = self
.create_in_memory_sqlite_conn(project, store.clone())
.await?;

let time_query_start = utils::now();

let mut stmt = conn.prepare(query)?;

// copy the column names into a vector of strings
let column_names = stmt
.column_names()
.into_iter()
.map(|x| x.to_string())
.collect::<Vec<String>>();

// Execute the query and collect the results in a vector of serde_json::Value objects.
let result_rows = stmt
.query_and_then([], |row| {
column_names
.iter()
.enumerate()
.map(|(i, column_name)| {
Ok((
column_name.clone(),
match row.get(i) {
Err(e) => {
return Err(anyhow!(
fontanierh marked this conversation as resolved.
Show resolved Hide resolved
"Failed to retrieve value for column {}: {}",
column_name,
e
))
}
Ok(v) => match v {
rusqlite::types::Value::Integer(i) => {
Ok(serde_json::Value::Number(i.into()))
}
rusqlite::types::Value::Real(f) => {
match serde_json::Number::from_f64(f) {
Some(n) => Ok(serde_json::Value::Number(n)),
None => Err(anyhow!(
"Invalid float value for column {}",
column_name
)),
}
}
rusqlite::types::Value::Text(t) => {
Ok(serde_json::Value::String(t.clone()))
}
rusqlite::types::Value::Blob(b) => {
match String::from_utf8(b.clone()) {
Err(_) => Err(anyhow!(
"Invalid UTF-8 sequence for column {}",
column_name
)),
Ok(s) => Ok(serde_json::Value::String(s)),
}
}
rusqlite::types::Value::Null => {
Ok(serde_json::Value::Null)
}
},
}?,
))
})
.collect::<Result<serde_json::Value>>()
})?
.collect::<Result<Vec<_>>>()?
fontanierh marked this conversation as resolved.
Show resolved Hide resolved
.into_par_iter()
.map(|v| DatabaseRow::new(utils::now(), None, &v))
.collect::<Vec<_>>();
utils::done(&format!(
"DSSTRUCTSTAT Finished executing user query: duration={}ms",
utils::now() - time_query_start
));

let infer_result_schema_start = utils::now();
let table_schema = TableSchema::from_rows(&result_rows)?;
utils::done(&format!(
"DSSTRUCTSTAT Finished inferring schema: duration={}ms",
utils::now() - infer_result_schema_start
));

utils::done(&format!(
"DSSTRUCTSTAT Finished query database: duration={}ms",
utils::now() - time_query_start
));

Ok((result_rows, table_schema))
}
}
}

// Getters
pub fn created(&self) -> u64 {
self.created
Expand Down Expand Up @@ -182,7 +390,7 @@ impl DatabaseRow {
}

#[derive(Debug, Serialize)]
struct DatabaseSchemaTable {
pub struct DatabaseSchemaTable {
table: DatabaseTable,
schema: TableSchema,
}
Expand All @@ -191,7 +399,17 @@ impl DatabaseSchemaTable {
pub fn new(table: DatabaseTable, schema: TableSchema) -> Self {
DatabaseSchemaTable { table, schema }
}

pub fn is_empty(&self) -> bool {
self.schema.is_empty()
}
}

#[derive(Debug, Serialize)]
pub struct DatabaseSchema(HashMap<String, DatabaseSchemaTable>);

impl DatabaseSchema {
pub fn iter(&self) -> std::collections::hash_map::Iter<String, DatabaseSchemaTable> {
self.0.iter()
}
}
Loading
Loading