diff --git a/core/src/databases/database.rs b/core/src/databases/database.rs index eaafc1cfed024..91f57c61847cf 100644 --- a/core/src/databases/database.rs +++ b/core/src/databases/database.rs @@ -1,4 +1,5 @@ use anyhow::{anyhow, Result}; +use rusqlite::ToSql; use crate::{project::Project, stores::store::Store, utils}; use rayon::prelude::*; @@ -218,14 +219,17 @@ impl Database { .get(&table_name) .ok_or_else(|| anyhow!("No schema found for table {}", table_name))?; - let mut insert_sql = "".to_string(); for row in rows { - let insert_row_sql = + let (query, boxed_params) = table_schema.get_insert_row_sql_string(&table_name, row.content())?; - insert_sql += &insert_row_sql; - } - conn.execute_batch(&insert_sql)?; + let params_refs: Vec<&dyn ToSql> = boxed_params + .iter() + .map(|param| &**param as &dyn ToSql) + .collect(); + + conn.execute(&query, params_refs.as_slice())?; + } } utils::done(&format!( "DSSTRUCTSTAT Finished inserting rows: duration={}ms", diff --git a/core/src/databases/table_schema.rs b/core/src/databases/table_schema.rs index 1b5d988378b8d..684ca826a186c 100644 --- a/core/src/databases/table_schema.rs +++ b/core/src/databases/table_schema.rs @@ -2,6 +2,7 @@ use std::collections::HashMap; use anyhow::{anyhow, Result}; +use rusqlite::ToSql; use serde::{Deserialize, Serialize}; use serde_json::Value; @@ -98,50 +99,45 @@ impl TableSchema { &self, table_name: &str, row_content: &Value, - ) -> Result { - let row_content = row_content + ) -> Result<(String, Vec>)> { + let row_content_map = row_content .as_object() .ok_or_else(|| anyhow!("Row content is not an object"))?; - let mut insert_row = format!("INSERT INTO \"{}\" (", table_name); - - for (name, _) in &self.0 { - insert_row.push_str(&format!("\"{}\", ", name)); - } - - // Remove the trailing comma and space, then close the parentheses. - let len = insert_row.len(); - insert_row.truncate(len - 2); - insert_row.push_str(") VALUES ("); - - for (name, _) in &self.0 { - let value = row_content.get(name); - - // if the value is not present, it's a null value - let value = value.unwrap_or(&Value::Null); - - let sql_value = match value { - Value::Null => "NULL".to_string(), - Value::Bool(x) => x.to_string(), - Value::Number(x) => x.to_string(), - Value::String(x) => format!("\"{}\"", x), - Value::Object(_) | Value::Array(_) => { - return Err(anyhow!( - "Row content field {} is not a primitive type", - name - )) - } - }; - - insert_row.push_str(&format!("{}, ", sql_value)); - } - - // Remove the trailing comma and space, then close the parentheses. - let len = insert_row.len(); - insert_row.truncate(len - 2); - insert_row.push_str(");"); + let field_names: Vec<&String> = self.0.keys().collect(); + let fields = field_names + .iter() + .map(|name| format!("\"{}\"", name)) + .collect::>() + .join(", "); + let placeholders = field_names + .iter() + .enumerate() + .map(|(i, _)| format!("?{}", i + 1)) + .collect::>() + .join(", "); + + let params: Vec> = field_names + .iter() + .map(|name| match row_content_map.get(*name) { + Some(Value::Bool(b)) => Ok(Box::new(*b) as Box), + Some(Value::Number(n)) => n + .as_i64() + .map(|i| Box::new(i) as Box) + .or_else(|| n.as_f64().map(|f| Box::new(f) as Box)) + .ok_or_else(|| anyhow!("Invalid number value for field {}", name)), + Some(Value::String(s)) => Ok(Box::new(s.clone()) as Box), + Some(Value::Null) | None => Ok(Box::new(rusqlite::types::Null) as Box), + Some(_) => Err(anyhow!("Unsupported value type for field {}", name)), + }) + .collect::>>>()?; + + let insert_row = format!( + "INSERT INTO \"{}\" ({}) VALUES ({});", + table_name, fields, placeholders + ); - Ok(insert_row) + Ok((insert_row, params)) } } @@ -282,8 +278,15 @@ mod tests { "field4": true }); - let sql_string = schema.get_insert_row_sql_string("test_table", &row_content)?; - conn.execute(&sql_string, [])?; + let (sql_string, boxed_params) = + schema.get_insert_row_sql_string("test_table", &row_content)?; + + let params_refs: Vec<&dyn ToSql> = boxed_params + .iter() + .map(|param| &**param as &dyn ToSql) + .collect(); + + conn.execute(&sql_string, params_refs.as_slice())?; let mut stmt = conn.prepare("SELECT * FROM test_table;")?; let mut rows = stmt.query([])?; @@ -316,8 +319,14 @@ mod tests { // Missing field3 and field4 }); - let sql_string = schema.get_insert_row_sql_string("test_table", &row_content)?; - conn.execute(&sql_string, [])?; + let (sql_string, boxed_params) = + schema.get_insert_row_sql_string("test_table", &row_content)?; + let params_refs: Vec<&dyn ToSql> = boxed_params + .iter() + .map(|param| &**param as &dyn ToSql) + .collect(); + + conn.execute(&sql_string, params_refs.as_slice())?; let mut stmt = conn.prepare("SELECT * FROM test_table;")?; let mut rows = stmt.query([])?;