From 5c0f59580573dbbdd37078f20024b54c3278e969 Mon Sep 17 00:00:00 2001 From: Stanislas Polu Date: Wed, 22 Nov 2023 18:14:42 +0100 Subject: [PATCH] batch_upsert validation + save a bunch of clones --- core/bin/dust_api.rs | 50 +++++++++++++++++++----------- core/src/databases/database.rs | 34 ++++++++++++++++++-- core/src/databases/table_schema.rs | 42 +++++++++++++++---------- core/src/stores/postgres.rs | 21 ++++++++----- core/src/stores/store.rs | 3 +- 5 files changed, 104 insertions(+), 46 deletions(-) diff --git a/core/bin/dust_api.rs b/core/bin/dust_api.rs index a0493d48d4845..ad67fe76037cb 100644 --- a/core/bin/dust_api.rs +++ b/core/bin/dust_api.rs @@ -1939,31 +1939,45 @@ async fn databases_rows_upsert( match state .store - .batch_upsert_database_rows( - &project, - &data_source_id, - &database_id, - &table_id, - &payload.contents, - truncate, - ) + .load_database(&project, &data_source_id, &database_id) .await { Err(e) => error_response( StatusCode::INTERNAL_SERVER_ERROR, "internal_server_error", - "Failed to upsert database rows", + "Failed to retrieve database", Some(e), ), - Ok(()) => ( - StatusCode::OK, - Json(APIResponse { - error: None, - response: Some(json!({ - "success": true - })), - }), - ), + Ok(db) => match db { + None => error_response( + StatusCode::NOT_FOUND, + "database_not_found", + &format!("No database found for id `{}`", database_id), + None, + ), + Some(db) => { + match db + .batch_upsert_rows(state.store.clone(), &table_id, &payload.contents, truncate) + .await + { + Err(e) => error_response( + StatusCode::BAD_REQUEST, + "invalid_database_rows_content", + "The rows content is invalid", + Some(e), + ), + Ok(()) => ( + StatusCode::OK, + Json(APIResponse { + error: None, + response: Some(json!({ + "success": true + })), + }), + ), + } + } + }, } } diff --git a/core/src/databases/database.rs b/core/src/databases/database.rs index bff04be3c5f86..4024364ebc843 100644 --- a/core/src/databases/database.rs +++ b/core/src/databases/database.rs @@ -74,6 +74,34 @@ impl Database { } } + pub async fn batch_upsert_rows( + &self, + store: Box, + table_id: &str, + contents: HashMap, + truncate: bool, + ) -> Result<()> { + let rows = contents + .into_iter() + .map(|(row_id, content)| DatabaseRow::new(utils::now(), Some(row_id), content)) + .collect::>(); + + // This will be used to update the schema incrementally once we store schemas. For now this + // is a way to validate the content of the rows (only primitive types). + let _ = TableSchema::from_rows(&rows)?; + + store + .batch_upsert_database_rows( + &self.project, + &self.data_source_id, + &self.database_id, + table_id, + &rows, + truncate, + ) + .await + } + pub async fn create_in_memory_sqlite_conn( &self, store: Box, @@ -244,7 +272,7 @@ impl Database { })? .collect::>>()? .into_par_iter() - .map(|v| DatabaseRow::new(utils::now(), None, &v)) + .map(|v| DatabaseRow::new(utils::now(), None, v)) .collect::>(); utils::done(&format!( "DSSTRUCTSTAT Finished executing user query: duration={}ms", @@ -369,11 +397,11 @@ pub struct DatabaseRow { } impl DatabaseRow { - pub fn new(created: u64, row_id: Option, content: &Value) -> Self { + pub fn new(created: u64, row_id: Option, content: Value) -> Self { DatabaseRow { created: created, row_id: row_id, - content: content.clone(), + content, } } diff --git a/core/src/databases/table_schema.rs b/core/src/databases/table_schema.rs index 3e19f84142e42..8a569b065ace3 100644 --- a/core/src/databases/table_schema.rs +++ b/core/src/databases/table_schema.rs @@ -56,10 +56,14 @@ impl TableSchema { let mut schema = HashMap::new(); for (row_index, row) in rows.iter().enumerate() { - let object = row - .content() - .as_object() - .ok_or_else(|| anyhow!("Row {} is not an object", row_index))?; + let object = match row.content().as_object() { + Some(object) => object, + None => Err(anyhow!( + "Row [{}] {:?} is not an object", + row_index, + row.row_id() + ))?, + }; for (k, v) in object { if v.is_null() { @@ -75,18 +79,24 @@ impl TableSchema { TableSchemaFieldType::Float } } - Value::String(_) | Value::Object(_) | Value::Array(_) => { - TableSchemaFieldType::Text - } - _ => unreachable!(), + Value::String(_) => TableSchemaFieldType::Text, + Value::Object(_) | Value::Array(_) => Err(anyhow!( + "Field {} is not a primitive type on row [{}] {:?} \ + (object and arrays are not supported)", + k, + row_index, + row.row_id() + ))?, + Value::Null => unreachable!(), }; if let Some(existing_type) = schema.get(k) { if existing_type != &value_type { return Err(anyhow!( - "Field {} has conflicting types on row {}: {:?} and {:?}", + "Field {} has conflicting types on row [{}] {:?}: {:?} and {:?}", k, row_index, + row.row_id(), existing_type, value_type )); @@ -198,8 +208,8 @@ mod tests { "field7": {"anotherKey": "anotherValue"} }); let rows = &vec![ - DatabaseRow::new(utils::now(), Some("1".to_string()), &row_1), - DatabaseRow::new(utils::now(), Some("2".to_string()), &row_2), + DatabaseRow::new(utils::now(), Some("1".to_string()), row_1), + DatabaseRow::new(utils::now(), Some("2".to_string()), row_2), ]; let schema = TableSchema::from_rows(rows)?; @@ -246,9 +256,9 @@ mod tests { "field1": "now it's a text field", }); let rows = &vec![ - DatabaseRow::new(utils::now(), Some("1".to_string()), &row_1), - DatabaseRow::new(utils::now(), Some("2".to_string()), &row_2), - DatabaseRow::new(utils::now(), Some("3".to_string()), &row_3), + DatabaseRow::new(utils::now(), Some("1".to_string()), row_1), + DatabaseRow::new(utils::now(), Some("2".to_string()), row_2), + DatabaseRow::new(utils::now(), Some("3".to_string()), row_3), ]; let schema = TableSchema::from_rows(rows); @@ -312,7 +322,7 @@ mod tests { let row = DatabaseRow::new( utils::now(), None, - &json!({ + json!({ "field1": 1, "field2": 2.4, "field3": "text", @@ -363,7 +373,7 @@ mod tests { let (sql, field_names) = schema.get_insert_sql("test_table"); let params = params_from_iter(schema.get_insert_params( &field_names, - &DatabaseRow::new(utils::now(), Some("1".to_string()), &row_content), + &DatabaseRow::new(utils::now(), Some("1".to_string()), row_content), )?); let mut stmt = conn.prepare(&sql)?; stmt.execute(params)?; diff --git a/core/src/stores/postgres.rs b/core/src/stores/postgres.rs index 30ab2b7024558..f434ebc51db6e 100644 --- a/core/src/stores/postgres.rs +++ b/core/src/stores/postgres.rs @@ -2210,7 +2210,7 @@ impl Store for PostgresStore { Some((created, row_id, data)) => Ok(Some(DatabaseRow::new( created as u64, Some(row_id), - &Value::from_str(&data)?, + Value::from_str(&data)?, ))), } } @@ -2293,7 +2293,7 @@ impl Store for PostgresStore { let row_id: String = row.get(1); let data: String = row.get(2); let content: Value = serde_json::from_str(&data)?; - Ok(DatabaseRow::new(created as u64, Some(row_id), &content)) + Ok(DatabaseRow::new(created as u64, Some(row_id), content)) }) .collect::>>()?; @@ -2320,7 +2320,7 @@ impl Store for PostgresStore { data_source_id: &str, database_id: &str, table_id: &str, - contents: &HashMap, + rows: &Vec, truncate: bool, ) -> Result<()> { let project_id = project.project_id(); @@ -2390,12 +2390,19 @@ impl Store for PostgresStore { ) .await?; - for (row_id, content) in contents { - let row_created = utils::now(); - let row_data = content.to_string(); + for row in rows { + let row_id = match row.row_id() { + Some(row_id) => row_id.to_string(), + None => unreachable!(), + }; c.execute( &stmt, - &[&table_row_id, &(row_created as i64), &row_id, &row_data], + &[ + &table_row_id, + &(row.created() as i64), + &row_id, + &row.content().to_string(), + ], ) .await?; } diff --git a/core/src/stores/store.rs b/core/src/stores/store.rs index d6e322c6c3209..a71a067164b30 100644 --- a/core/src/stores/store.rs +++ b/core/src/stores/store.rs @@ -11,7 +11,6 @@ use crate::providers::llm::{LLMChatGeneration, LLMChatRequest, LLMGeneration, LL use crate::run::{Run, RunStatus, RunType}; use anyhow::Result; use async_trait::async_trait; -use serde_json::Value; use std::collections::HashMap; #[async_trait] @@ -202,7 +201,7 @@ pub trait Store { data_source_id: &str, database_id: &str, table_id: &str, - contents: &HashMap, + rows: &Vec, truncate: bool, ) -> Result<()>; async fn load_database_row(