Skip to content

Commit

Permalink
batch_upsert validation + save a bunch of clones
Browse files Browse the repository at this point in the history
  • Loading branch information
spolu committed Nov 22, 2023
1 parent aa36405 commit 5c0f595
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 46 deletions.
50 changes: 32 additions & 18 deletions core/bin/dust_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1939,31 +1939,45 @@ async fn databases_rows_upsert(

match state
.store
.batch_upsert_database_rows(
&project,
&data_source_id,
&database_id,
&table_id,
&payload.contents,
truncate,
)
.load_database(&project, &data_source_id, &database_id)
.await
{
Err(e) => error_response(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_server_error",
"Failed to upsert database rows",
"Failed to retrieve database",
Some(e),
),
Ok(()) => (
StatusCode::OK,
Json(APIResponse {
error: None,
response: Some(json!({
"success": true
})),
}),
),
Ok(db) => match db {
None => error_response(
StatusCode::NOT_FOUND,
"database_not_found",
&format!("No database found for id `{}`", database_id),
None,
),
Some(db) => {
match db
.batch_upsert_rows(state.store.clone(), &table_id, &payload.contents, truncate)
.await
{
Err(e) => error_response(
StatusCode::BAD_REQUEST,
"invalid_database_rows_content",
"The rows content is invalid",
Some(e),
),
Ok(()) => (
StatusCode::OK,
Json(APIResponse {
error: None,
response: Some(json!({
"success": true
})),
}),
),
}
}
},
}
}

Expand Down
34 changes: 31 additions & 3 deletions core/src/databases/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,34 @@ impl Database {
}
}

pub async fn batch_upsert_rows(
&self,
store: Box<dyn Store + Sync + Send>,
table_id: &str,
contents: HashMap<String, Value>,
truncate: bool,
) -> Result<()> {
let rows = contents
.into_iter()
.map(|(row_id, content)| DatabaseRow::new(utils::now(), Some(row_id), content))
.collect::<Vec<_>>();

// This will be used to update the schema incrementally once we store schemas. For now this
// is a way to validate the content of the rows (only primitive types).
let _ = TableSchema::from_rows(&rows)?;

store
.batch_upsert_database_rows(
&self.project,
&self.data_source_id,
&self.database_id,
table_id,
&rows,
truncate,
)
.await
}

pub async fn create_in_memory_sqlite_conn(
&self,
store: Box<dyn Store + Sync + Send>,
Expand Down Expand Up @@ -244,7 +272,7 @@ impl Database {
})?
.collect::<Result<Vec<_>>>()?
.into_par_iter()
.map(|v| DatabaseRow::new(utils::now(), None, &v))
.map(|v| DatabaseRow::new(utils::now(), None, v))
.collect::<Vec<_>>();
utils::done(&format!(
"DSSTRUCTSTAT Finished executing user query: duration={}ms",
Expand Down Expand Up @@ -369,11 +397,11 @@ pub struct DatabaseRow {
}

impl DatabaseRow {
pub fn new(created: u64, row_id: Option<String>, content: &Value) -> Self {
pub fn new(created: u64, row_id: Option<String>, content: Value) -> Self {
DatabaseRow {
created: created,
row_id: row_id,
content: content.clone(),
content,
}
}

Expand Down
42 changes: 26 additions & 16 deletions core/src/databases/table_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,14 @@ impl TableSchema {
let mut schema = HashMap::new();

for (row_index, row) in rows.iter().enumerate() {
let object = row
.content()
.as_object()
.ok_or_else(|| anyhow!("Row {} is not an object", row_index))?;
let object = match row.content().as_object() {
Some(object) => object,
None => Err(anyhow!(
"Row [{}] {:?} is not an object",
row_index,
row.row_id()
))?,
};

for (k, v) in object {
if v.is_null() {
Expand All @@ -75,18 +79,24 @@ impl TableSchema {
TableSchemaFieldType::Float
}
}
Value::String(_) | Value::Object(_) | Value::Array(_) => {
TableSchemaFieldType::Text
}
_ => unreachable!(),
Value::String(_) => TableSchemaFieldType::Text,
Value::Object(_) | Value::Array(_) => Err(anyhow!(
"Field {} is not a primitive type on row [{}] {:?} \
(object and arrays are not supported)",
k,
row_index,
row.row_id()
))?,
Value::Null => unreachable!(),
};

if let Some(existing_type) = schema.get(k) {
if existing_type != &value_type {
return Err(anyhow!(
"Field {} has conflicting types on row {}: {:?} and {:?}",
"Field {} has conflicting types on row [{}] {:?}: {:?} and {:?}",
k,
row_index,
row.row_id(),
existing_type,
value_type
));
Expand Down Expand Up @@ -198,8 +208,8 @@ mod tests {
"field7": {"anotherKey": "anotherValue"}
});
let rows = &vec![
DatabaseRow::new(utils::now(), Some("1".to_string()), &row_1),
DatabaseRow::new(utils::now(), Some("2".to_string()), &row_2),
DatabaseRow::new(utils::now(), Some("1".to_string()), row_1),
DatabaseRow::new(utils::now(), Some("2".to_string()), row_2),
];

let schema = TableSchema::from_rows(rows)?;
Expand Down Expand Up @@ -246,9 +256,9 @@ mod tests {
"field1": "now it's a text field",
});
let rows = &vec![
DatabaseRow::new(utils::now(), Some("1".to_string()), &row_1),
DatabaseRow::new(utils::now(), Some("2".to_string()), &row_2),
DatabaseRow::new(utils::now(), Some("3".to_string()), &row_3),
DatabaseRow::new(utils::now(), Some("1".to_string()), row_1),
DatabaseRow::new(utils::now(), Some("2".to_string()), row_2),
DatabaseRow::new(utils::now(), Some("3".to_string()), row_3),
];

let schema = TableSchema::from_rows(rows);
Expand Down Expand Up @@ -312,7 +322,7 @@ mod tests {
let row = DatabaseRow::new(
utils::now(),
None,
&json!({
json!({
"field1": 1,
"field2": 2.4,
"field3": "text",
Expand Down Expand Up @@ -363,7 +373,7 @@ mod tests {
let (sql, field_names) = schema.get_insert_sql("test_table");
let params = params_from_iter(schema.get_insert_params(
&field_names,
&DatabaseRow::new(utils::now(), Some("1".to_string()), &row_content),
&DatabaseRow::new(utils::now(), Some("1".to_string()), row_content),
)?);
let mut stmt = conn.prepare(&sql)?;
stmt.execute(params)?;
Expand Down
21 changes: 14 additions & 7 deletions core/src/stores/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2210,7 +2210,7 @@ impl Store for PostgresStore {
Some((created, row_id, data)) => Ok(Some(DatabaseRow::new(
created as u64,
Some(row_id),
&Value::from_str(&data)?,
Value::from_str(&data)?,
))),
}
}
Expand Down Expand Up @@ -2293,7 +2293,7 @@ impl Store for PostgresStore {
let row_id: String = row.get(1);
let data: String = row.get(2);
let content: Value = serde_json::from_str(&data)?;
Ok(DatabaseRow::new(created as u64, Some(row_id), &content))
Ok(DatabaseRow::new(created as u64, Some(row_id), content))
})
.collect::<Result<Vec<_>>>()?;

Expand All @@ -2320,7 +2320,7 @@ impl Store for PostgresStore {
data_source_id: &str,
database_id: &str,
table_id: &str,
contents: &HashMap<String, Value>,
rows: &Vec<DatabaseRow>,
truncate: bool,
) -> Result<()> {
let project_id = project.project_id();
Expand Down Expand Up @@ -2390,12 +2390,19 @@ impl Store for PostgresStore {
)
.await?;

for (row_id, content) in contents {
let row_created = utils::now();
let row_data = content.to_string();
for row in rows {
let row_id = match row.row_id() {
Some(row_id) => row_id.to_string(),
None => unreachable!(),
};
c.execute(
&stmt,
&[&table_row_id, &(row_created as i64), &row_id, &row_data],
&[
&table_row_id,
&(row.created() as i64),
&row_id,
&row.content().to_string(),
],
)
.await?;
}
Expand Down
3 changes: 1 addition & 2 deletions core/src/stores/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ use crate::providers::llm::{LLMChatGeneration, LLMChatRequest, LLMGeneration, LL
use crate::run::{Run, RunStatus, RunType};
use anyhow::Result;
use async_trait::async_trait;
use serde_json::Value;
use std::collections::HashMap;

#[async_trait]
Expand Down Expand Up @@ -202,7 +201,7 @@ pub trait Store {
data_source_id: &str,
database_id: &str,
table_id: &str,
contents: &HashMap<String, Value>,
rows: &Vec<DatabaseRow>,
truncate: bool,
) -> Result<()>;
async fn load_database_row(
Expand Down

0 comments on commit 5c0f595

Please sign in to comment.