Skip to content

Commit

Permalink
Refactor database insertions to use prepared
Browse files Browse the repository at this point in the history
statements
  • Loading branch information
fontanierh committed Nov 13, 2023
1 parent 391a5eb commit 3421e9f
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 49 deletions.
14 changes: 9 additions & 5 deletions core/src/databases/database.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use anyhow::{anyhow, Result};
use rusqlite::ToSql;

use crate::{project::Project, stores::store::Store, utils};
use rayon::prelude::*;
Expand Down Expand Up @@ -218,14 +219,17 @@ impl Database {
.get(&table_name)
.ok_or_else(|| anyhow!("No schema found for table {}", table_name))?;

let mut insert_sql = "".to_string();
for row in rows {
let insert_row_sql =
let (query, boxed_params) =
table_schema.get_insert_row_sql_string(&table_name, row.content())?;
insert_sql += &insert_row_sql;
}

conn.execute_batch(&insert_sql)?;
let params_refs: Vec<&dyn ToSql> = boxed_params
.iter()
.map(|param| &**param as &dyn ToSql)
.collect();

conn.execute(&query, params_refs.as_slice())?;
}
}
utils::done(&format!(
"DSSTRUCTSTAT Finished inserting rows: duration={}ms",
Expand Down
97 changes: 53 additions & 44 deletions core/src/databases/table_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::collections::HashMap;

use anyhow::{anyhow, Result};

use rusqlite::ToSql;
use serde::{Deserialize, Serialize};
use serde_json::Value;

Expand Down Expand Up @@ -98,50 +99,45 @@ impl TableSchema {
&self,
table_name: &str,
row_content: &Value,
) -> Result<String> {
let row_content = row_content
) -> Result<(String, Vec<Box<dyn ToSql>>)> {
let row_content_map = row_content
.as_object()
.ok_or_else(|| anyhow!("Row content is not an object"))?;

let mut insert_row = format!("INSERT INTO \"{}\" (", table_name);

for (name, _) in &self.0 {
insert_row.push_str(&format!("\"{}\", ", name));
}

// Remove the trailing comma and space, then close the parentheses.
let len = insert_row.len();
insert_row.truncate(len - 2);
insert_row.push_str(") VALUES (");

for (name, _) in &self.0 {
let value = row_content.get(name);

// if the value is not present, it's a null value
let value = value.unwrap_or(&Value::Null);

let sql_value = match value {
Value::Null => "NULL".to_string(),
Value::Bool(x) => x.to_string(),
Value::Number(x) => x.to_string(),
Value::String(x) => format!("\"{}\"", x),
Value::Object(_) | Value::Array(_) => {
return Err(anyhow!(
"Row content field {} is not a primitive type",
name
))
}
};

insert_row.push_str(&format!("{}, ", sql_value));
}

// Remove the trailing comma and space, then close the parentheses.
let len = insert_row.len();
insert_row.truncate(len - 2);
insert_row.push_str(");");
let field_names: Vec<&String> = self.0.keys().collect();
let fields = field_names
.iter()
.map(|name| format!("\"{}\"", name))
.collect::<Vec<String>>()
.join(", ");
let placeholders = field_names
.iter()
.enumerate()
.map(|(i, _)| format!("?{}", i + 1))
.collect::<Vec<String>>()
.join(", ");

let params: Vec<Box<dyn ToSql>> = field_names
.iter()
.map(|name| match row_content_map.get(*name) {
Some(Value::Bool(b)) => Ok(Box::new(*b) as Box<dyn ToSql>),
Some(Value::Number(n)) => n
.as_i64()
.map(|i| Box::new(i) as Box<dyn ToSql>)
.or_else(|| n.as_f64().map(|f| Box::new(f) as Box<dyn ToSql>))
.ok_or_else(|| anyhow!("Invalid number value for field {}", name)),
Some(Value::String(s)) => Ok(Box::new(s.clone()) as Box<dyn ToSql>),
Some(Value::Null) | None => Ok(Box::new(rusqlite::types::Null) as Box<dyn ToSql>),
Some(_) => Err(anyhow!("Unsupported value type for field {}", name)),
})
.collect::<Result<Vec<Box<dyn ToSql>>>>()?;

let insert_row = format!(
"INSERT INTO \"{}\" ({}) VALUES ({});",
table_name, fields, placeholders
);

Ok(insert_row)
Ok((insert_row, params))
}
}

Expand Down Expand Up @@ -282,8 +278,15 @@ mod tests {
"field4": true
});

let sql_string = schema.get_insert_row_sql_string("test_table", &row_content)?;
conn.execute(&sql_string, [])?;
let (sql_string, boxed_params) =
schema.get_insert_row_sql_string("test_table", &row_content)?;

let params_refs: Vec<&dyn ToSql> = boxed_params
.iter()
.map(|param| &**param as &dyn ToSql)
.collect();

conn.execute(&sql_string, params_refs.as_slice())?;

let mut stmt = conn.prepare("SELECT * FROM test_table;")?;
let mut rows = stmt.query([])?;
Expand Down Expand Up @@ -316,8 +319,14 @@ mod tests {
// Missing field3 and field4
});

let sql_string = schema.get_insert_row_sql_string("test_table", &row_content)?;
conn.execute(&sql_string, [])?;
let (sql_string, boxed_params) =
schema.get_insert_row_sql_string("test_table", &row_content)?;
let params_refs: Vec<&dyn ToSql> = boxed_params
.iter()
.map(|param| &**param as &dyn ToSql)
.collect();

conn.execute(&sql_string, params_refs.as_slice())?;

let mut stmt = conn.prepare("SELECT * FROM test_table;")?;
let mut rows = stmt.query([])?;
Expand Down

0 comments on commit 3421e9f

Please sign in to comment.