Skip to content

Commit

Permalink
Refactor database insertions to use prepared
Browse files Browse the repository at this point in the history
statements
  • Loading branch information
fontanierh committed Nov 14, 2023
1 parent f8cd90d commit e6f6646
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 44 deletions.
1 change: 1 addition & 0 deletions core/src/databases/database.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use anyhow::{anyhow, Result};
use rusqlite::ToSql;

use crate::{project::Project, stores::store::Store, utils};
use rayon::prelude::*;
Expand Down
97 changes: 53 additions & 44 deletions core/src/databases/table_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::collections::HashMap;

use anyhow::{anyhow, Result};

use rusqlite::ToSql;
use serde::{Deserialize, Serialize};
use serde_json::Value;

Expand Down Expand Up @@ -101,50 +102,45 @@ impl TableSchema {
&self,
table_name: &str,
row_content: &Value,
) -> Result<String> {
let row_content = row_content
) -> Result<(String, Vec<Box<dyn ToSql>>)> {
let row_content_map = row_content
.as_object()
.ok_or_else(|| anyhow!("Row content is not an object"))?;

let mut insert_row = format!("INSERT INTO \"{}\" (", table_name);

for (name, _) in &self.0 {
insert_row.push_str(&format!("\"{}\", ", name));
}

// Remove the trailing comma and space, then close the parentheses.
let len = insert_row.len();
insert_row.truncate(len - 2);
insert_row.push_str(") VALUES (");

for (name, _) in &self.0 {
let value = row_content.get(name);

// if the value is not present, it's a null value
let value = value.unwrap_or(&Value::Null);

let sql_value = match value {
Value::Null => "NULL".to_string(),
Value::Bool(x) => x.to_string(),
Value::Number(x) => x.to_string(),
Value::String(x) => format!("\"{}\"", x),
Value::Object(_) | Value::Array(_) => {
return Err(anyhow!(
"Row content field {} is not a primitive type",
name
))
}
};

insert_row.push_str(&format!("{}, ", sql_value));
}

// Remove the trailing comma and space, then close the parentheses.
let len = insert_row.len();
insert_row.truncate(len - 2);
insert_row.push_str(");");
let field_names: Vec<&String> = self.0.keys().collect();
let fields = field_names
.iter()
.map(|name| format!("\"{}\"", name))
.collect::<Vec<String>>()
.join(", ");
let placeholders = field_names
.iter()
.enumerate()
.map(|(i, _)| format!("?{}", i + 1))
.collect::<Vec<String>>()
.join(", ");

let params: Vec<Box<dyn ToSql>> = field_names
.iter()
.map(|name| match row_content_map.get(*name) {
Some(Value::Bool(b)) => Ok(Box::new(*b) as Box<dyn ToSql>),
Some(Value::Number(n)) => n
.as_i64()
.map(|i| Box::new(i) as Box<dyn ToSql>)
.or_else(|| n.as_f64().map(|f| Box::new(f) as Box<dyn ToSql>))
.ok_or_else(|| anyhow!("Invalid number value for field {}", name)),
Some(Value::String(s)) => Ok(Box::new(s.clone()) as Box<dyn ToSql>),
Some(Value::Null) | None => Ok(Box::new(rusqlite::types::Null) as Box<dyn ToSql>),
Some(_) => Err(anyhow!("Unsupported value type for field {}", name)),
})
.collect::<Result<Vec<Box<dyn ToSql>>>>()?;

let insert_row = format!(
"INSERT INTO \"{}\" ({}) VALUES ({});",
table_name, fields, placeholders
);

Ok(insert_row)
Ok((insert_row, params))
}
}

Expand Down Expand Up @@ -293,8 +289,15 @@ mod tests {
"field4": true
});

let sql_string = schema.get_insert_row_sql_string("test_table", &row_content)?;
conn.execute(&sql_string, [])?;
let (sql_string, boxed_params) =
schema.get_insert_row_sql_string("test_table", &row_content)?;

let params_refs: Vec<&dyn ToSql> = boxed_params
.iter()
.map(|param| &**param as &dyn ToSql)
.collect();

conn.execute(&sql_string, params_refs.as_slice())?;

let mut stmt = conn.prepare("SELECT * FROM test_table;")?;
let mut rows = stmt.query([])?;
Expand Down Expand Up @@ -327,8 +330,14 @@ mod tests {
// Missing field3 and field4
});

let sql_string = schema.get_insert_row_sql_string("test_table", &row_content)?;
conn.execute(&sql_string, [])?;
let (sql_string, boxed_params) =
schema.get_insert_row_sql_string("test_table", &row_content)?;
let params_refs: Vec<&dyn ToSql> = boxed_params
.iter()
.map(|param| &**param as &dyn ToSql)
.collect();

conn.execute(&sql_string, params_refs.as_slice())?;

let mut stmt = conn.prepare("SELECT * FROM test_table;")?;
let mut rows = stmt.query([])?;
Expand Down

0 comments on commit e6f6646

Please sign in to comment.