Skip to content

Commit

Permalink
scaffolding
Browse files Browse the repository at this point in the history
  • Loading branch information
fontanierh committed Nov 13, 2023
1 parent a17b5b9 commit 0cff1e8
Show file tree
Hide file tree
Showing 4 changed files with 352 additions and 13 deletions.
61 changes: 59 additions & 2 deletions core/bin/dust_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2032,14 +2032,14 @@ async fn databases_schema_retrieve(
&format!("No database found for id `{}`", database_id),
None,
),
Ok(Some(db)) => match db.get_schema(&project, state.store.clone()).await {
Ok(Some(db)) => match db.get_schema(&project, state.store.clone(), false).await {
Err(e) => error_response(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_server_error",
"Failed to retrieve database schema",
Some(e),
),
Ok(schema) => (
Ok((schema, _)) => (
StatusCode::OK,
Json(APIResponse {
error: None,
Expand All @@ -2052,6 +2052,59 @@ async fn databases_schema_retrieve(
}
}

#[derive(serde::Deserialize)]
struct DatabaseQueryRunPayload {
query: String,
}

async fn databases_query_run(
extract::Path((project_id, data_source_id, database_id)): extract::Path<(i64, String, String)>,
extract::Json(payload): extract::Json<DatabaseQueryRunPayload>,
extract::Extension(state): extract::Extension<Arc<APIState>>,
) -> (StatusCode, Json<APIResponse>) {
let project = project::Project::new_from_id(project_id);

match state
.store
.load_database(&project, &data_source_id, &database_id)
.await
{
Err(e) => error_response(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_server_error",
"Failed to retrieve database",
Some(e),
),
Ok(None) => error_response(
StatusCode::NOT_FOUND,
"database_not_found",
&format!("No database found for id `{}`", database_id),
None,
),
Ok(Some(db)) => match db
.query(&project, state.store.clone(), &payload.query)
.await
{
Err(e) => error_response(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_server_error",
"Failed to run query",
Some(e),
),
Ok((rows, schema)) => (
StatusCode::OK,
Json(APIResponse {
error: None,
response: Some(json!({
"schema": schema,
"rows": rows,
})),
}),
),
},
}
}

// Misc

#[derive(serde::Deserialize)]
Expand Down Expand Up @@ -2273,6 +2326,10 @@ fn main() {
"/projects/:project_id/data_sources/:data_source_id/databases/:database_id/schema",
get(databases_schema_retrieve),
)
.route(
"/projects/:project_id/data_sources/:data_source_id/databases/:database_id/query",
post(databases_query_run),
)
// Misc
.route("/tokenize", post(tokenize))

Expand Down
114 changes: 110 additions & 4 deletions core/src/databases/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use anyhow::{anyhow, Result};
use crate::{project::Project, stores::store::Store};
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use serde_json::{Number, Value};
use std::collections::HashMap;

use super::table_schema::TableSchema;
Expand Down Expand Up @@ -48,7 +48,8 @@ impl Database {
&self,
project: &Project,
store: Box<dyn Store + Sync + Send>,
) -> Result<DatabaseSchema> {
return_rows: bool,
) -> Result<(DatabaseSchema, Option<HashMap<String, Vec<DatabaseRow>>>)> {
match self.db_type {
DatabaseType::REMOTE => Err(anyhow!("Remote DB not implemented.")),
DatabaseType::LOCAL => {
Expand Down Expand Up @@ -134,7 +135,99 @@ impl Database {
}
}

Ok(DatabaseSchema(schema))
let rows = match return_rows {
true => Some(table_rows),
false => None,
};
Ok((DatabaseSchema(schema), rows))
}
}
}

pub async fn query(
&self,
project: &Project,
store: Box<dyn Store + Sync + Send>,
query: &str,
) -> Result<(Vec<Value>, TableSchema)> {
match self.db_type {
DatabaseType::REMOTE => Err(anyhow!("Remote DB not implemented.")),
DatabaseType::LOCAL => {
// Retrieve the DB schema and construct a SQL string.
let (schema, rows_by_table) = self.get_schema(project, store.clone(), true).await?;
let mut create_tables_sql = "".to_string();
// TODO: maybe we can // ?
let mut table_schemas = HashMap::new();
for (table_name, table) in schema.into_iter() {
if table.schema.is_empty() {
continue;
}
table_schemas.insert(table_name.clone(), table.schema.clone());
create_tables_sql += &table
.schema
.get_create_table_sql_string(table_name.as_str());
create_tables_sql += "\n";
}

// Build the in-memory SQLite DB with the schema.
let conn = rusqlite::Connection::open_in_memory()?;
conn.execute_batch(&create_tables_sql)?;

let mut stmt = conn.prepare(query)?;

let column_names = stmt
.column_names()
.into_iter()
.map(|x| x.to_string())
.collect::<Vec<String>>();
let column_count = stmt.column_count();

// insert the rows in the DB
for (table_name, rows) in rows_by_table.expect("No rows found") {
if rows.is_empty() {
continue;
}

let table_schema = table_schemas
.get(&table_name)
.expect("No schema found for table");

let mut insert_sql = "".to_string();
for row in rows {
let insert_row_sql =
table_schema.get_insert_row_sql_string(&table_name, row.content())?;
insert_sql += &insert_row_sql;
}
conn.execute_batch(&insert_sql)?;
}

let rows = stmt.query_map([], |row| {
let mut map = serde_json::Map::new();
for i in 0..column_count {
let column_name = column_names.get(i).expect("Invalid column name");
let value = match row.get(i).expect("Invalid value") {
rusqlite::types::Value::Integer(i) => Value::Number(i.into()),
rusqlite::types::Value::Real(f) => {
Value::Number(Number::from_f64(f).expect("invalid float value"))
}
rusqlite::types::Value::Text(t) => Value::String(t),
// convert blob into string
rusqlite::types::Value::Blob(b) => {
Value::String(String::from_utf8(b).expect("Invalid UTF-8 sequence"))
}

rusqlite::types::Value::Null => Value::Null,
};
map.insert(column_name.to_string(), value);
}
Ok(Value::Object(map))
})?;

let results = rows.collect::<Result<Vec<Value>, rusqlite::Error>>()?;
let results_refs = results.iter().collect::<Vec<&Value>>();
let table_schema = TableSchema::from_rows(&results_refs)?;

Ok((results, table_schema))
}
}
}
Expand Down Expand Up @@ -229,7 +322,7 @@ impl DatabaseRow {
}

#[derive(Debug, Serialize)]
struct DatabaseSchemaTable {
pub struct DatabaseSchemaTable {
table: DatabaseTable,
schema: TableSchema,
}
Expand All @@ -242,7 +335,20 @@ impl DatabaseSchemaTable {
pub fn table(&self) -> &DatabaseTable {
&self.table
}

pub fn is_empty(&self) -> bool {
self.schema.is_empty()
}
}

#[derive(Debug, Serialize)]
pub struct DatabaseSchema(HashMap<String, DatabaseSchemaTable>);

impl IntoIterator for DatabaseSchema {
type Item = (String, DatabaseSchemaTable);
type IntoIter = std::collections::hash_map::IntoIter<String, DatabaseSchemaTable>;

fn into_iter(self) -> Self::IntoIter {
self.0.into_iter()
}
}
Loading

0 comments on commit 0cff1e8

Please sign in to comment.