diff --git a/src/database_components/table.rs b/src/database_components/table.rs index 2ac3c30..b4dc6f8 100644 --- a/src/database_components/table.rs +++ b/src/database_components/table.rs @@ -71,6 +71,38 @@ impl Table { Err(format!("Row is missing an 'id' field: {:?}", row)) } } + + pub fn add_row_with_fk( + &mut self, + db: &Database, + row_data: serde_json::Value, + fk_constraints: Option<&[(&str, &str)]>, // Vec of (Table, Column) + ) -> Result<(), String> { + // Validate FK constraints if provided + if let Some(constraints) = fk_constraints { + for (table_name, fk_column) in constraints { + if let Some(fk_value) = row_data.get(fk_column).and_then(|v| v.as_str()) { + if !db.record_exists(table_name, fk_value) { + return Err(format!( + "Foreign key constraint failed: `{}` does not exist in `{}`", + fk_value, table_name + )); + } + } else { + return Err(format!("Missing value for foreign key column `{}`", fk_column)); + } + } + } + + // Add the row after validation + let row_id = row_data + .get("id") + .and_then(|id| id.as_str()) + .ok_or_else(|| "Missing primary key `id` in row data".to_string())?; + + self.rows.insert(row_id.to_string(), Row::new(row_data)); + Ok(()) + } } #[cfg(test)] diff --git a/src/database_operations/core.rs b/src/database_operations/core.rs index e043839..2977f92 100644 --- a/src/database_operations/core.rs +++ b/src/database_operations/core.rs @@ -134,6 +134,14 @@ impl Database { ))) } } + + pub fn record_exists(&self, table_name: &str, pk_value: &str) -> bool { + if let Some(table) = self.tables.get(table_name) { + table.rows.contains_key(pk_value) + } else { + false + } + } } #[cfg(test)] @@ -326,4 +334,65 @@ mod tests { let result = db.count_rows("NonExistentTable"); assert!(matches!(result, Err(DatabaseError::TableNotFound(_)))); } + + #[tokio::test] + async fn test_foreign_key_validation() { + #[derive(Serialize, Deserialize, Debug, PartialEq, Clone, Default)] + struct Post { + id: String, + title: String, + content: String, + user_id: String, + } + + #[derive(Serialize, Deserialize, Debug, PartialEq, Clone, Default)] + struct User { + id: String, + name: String, + email: String, + } + + let mut db = setup_temp_db().await; + + // Set up User table + let user_columns = Columns::from_struct::(true); + let mut users_table = Table::new("users".to_string(), user_columns.clone()); + db.add_table(&mut users_table).await.unwrap(); + + let user1 = json!({ + "id": "1", + "name": "John Doe", + "email": "johndoe@example.com" + }); + users_table.add_row(&mut db, user1).await; + + // Set up Post table + let post_columns = Columns::from_struct::(true); + let mut posts_table = Table::new("posts".to_string(), post_columns.clone()); + db.add_table(&mut posts_table).await.unwrap(); + + let valid_post = json!({ + "id": "101", + "title": "Valid Post", + "content": "Content", + "user_id": "1" + }); + + let invalid_post = json!({ + "id": "102", + "title": "Invalid Post", + "content": "Content", + "user_id": "999" + }); + + // Valid FK + assert!(posts_table + .add_row_with_fk(&db, valid_post, Some(&[("users", "user_id")])) + .is_ok()); + + // Invalid FK + assert!(posts_table + .add_row_with_fk(&db, invalid_post, Some(&[("users", "user_id")])) + .is_err()); + } }