diff --git a/pg-worm-derive/src/parse.rs b/pg-worm-derive/src/parse.rs index e72d3fc..dcd2ecf 100644 --- a/pg-worm-derive/src/parse.rs +++ b/pg-worm-derive/src/parse.rs @@ -88,17 +88,46 @@ impl ModelInput { let model = self.impl_model(); quote!( + #[automatically_derived] impl #ident { #column_consts #insert #columns } + #[automatically_derived] #try_from_row + #[automatically_derived] #model ) } + fn impl_table(&self) -> TokenStream { + let table_name = self.table_name(); + let primary_keys = self + .all_fields() + .filter_map(|i| { + if !i.primary_key { + None + } else { + Some(i.column_name()) + } + }) + .collect::>(); + let columns = self.all_fields().map(|i| i.create_column()); + + quote!( + fn table() -> ::pg_worm::migration::Table { + use ::pg_worm::migration::{Table, Column}; + + Table::new(#table_name).primary_key(vec![#(#primary_keys.to_string()), *]) + #( + .column(#columns) + )* + } + ) + } + /// Generate the code for implementing the /// `Model` trait. fn impl_model(&self) -> TokenStream { @@ -113,6 +142,7 @@ impl ModelInput { let delete = self.impl_delete(); let update = self.impl_update(); let query = self.impl_query(); + let table = self.impl_table(); quote!( #[pg_worm::async_trait] @@ -121,6 +151,7 @@ impl ModelInput { #update #delete #query + #table fn table_name() -> &'static str { #table_name @@ -384,6 +415,27 @@ impl ModelField { self.ident().to_string().to_lowercase() } + fn create_column(&self) -> TokenStream { + let name = self.column_name(); + let mut dtype = self.try_pg_datatype().unwrap().to_string(); + if self.array { + dtype.push_str("[]"); + } + + let mut res = + quote!(::pg_worm::migration::Column::new(#name.to_string(), #dtype.to_string())); + + if self.unique { + res.extend(quote!(.unique())); + } + + if !self.nullable { + res.extend(quote!(.not_null())); + } + + res + } + /// Get the corresponding postgres type fn try_pg_datatype(&self) -> Result { let ty = self.ty.clone(); diff --git a/pg-worm/src/lib.rs b/pg-worm/src/lib.rs index d36f8ee..954dcd2 100644 --- a/pg-worm/src/lib.rs +++ b/pg-worm/src/lib.rs @@ -374,6 +374,10 @@ pub enum Error { /// Emitted when no connection could be fetched from the pool. #[error("couldn't fetch connection from pool")] NoConnectionInPool, + /// Emitted when pg-worm couldn't parse the query result into + /// the corresponding model. + #[error("couldn't parse row into {0}: couldn't read field {1}")] + ParseError(&'static str, &'static str), /// Errors emitted by the Postgres server. /// /// Most likely an invalid query. @@ -401,6 +405,10 @@ pub trait Model: FromRow { #[must_use] fn _table_creation_sql() -> &'static str; + /// Returns a table object used to created migrations. + #[doc(hidden)] + fn table() -> migration::Table; + /// Returns a slice of all columns this model's table has. fn columns() -> &'static [&'static dyn Deref]; @@ -431,6 +439,23 @@ pub trait Model: FromRow { fn query(_: impl Into, _: Vec<&(dyn ToSql + Sync)>) -> Query<'_, Vec>; } +/// A cleaner api for [`migration::migrate_tables`]. +/// +/// Call like this: +/// ```ignore +/// #[derive(Model)] +/// struct Book { +/// id: 64 +/// } +/// +/// migrate_tables(Book).await?; +#[macro_export] +macro_rules! migrate_tables { + ($($x:ty), +) => { + $crate::migration::migrate_tables(vec![$(<$x as $crate::Model<$x>>::table()),*]) + }; +} + /// Create a table for your model. /// /// Use the [`try_create_table!`] macro for a more convenient api. diff --git a/pg-worm/src/migration/mod.rs b/pg-worm/src/migration/mod.rs index 6440c0a..939b617 100644 --- a/pg-worm/src/migration/mod.rs +++ b/pg-worm/src/migration/mod.rs @@ -2,23 +2,31 @@ #![allow(dead_code)] -use std::fmt::Display; +use std::{fmt::Display, ops::Deref}; + +use hashbrown::HashMap; +use tokio_postgres::Row; + +use crate::{pool::fetch_client, FromRow}; /// Represents a collection of tables. -#[derive(Default, Debug, Clone)] +#[derive(Debug, Clone)] pub struct Schema { + name: String, tables: Vec, } +/// Represents a table. #[derive(Debug, Clone)] -struct Table { +pub struct Table { name: String, columns: Vec, constraints: Vec, } +/// Represents a column. #[derive(Debug, Clone)] -struct Column { +pub struct Column { name: String, data_type: String, constraints: Vec, @@ -49,16 +57,158 @@ enum ColumnConstraint { RawCheckNamed(String, String), } +/// Fetch a schema from a database connection. +/// +/// May fail due to connection errors, parsing errors or +/// when querying for a not existing schema. +async fn fetch_schema( + schema_name: impl Into, + client: &tokio_postgres::Client, +) -> Result { + let schema_name = schema_name.into(); + + struct Entry { + table: String, + column: String, + data_type: String, + } + + impl TryFrom for Entry { + type Error = crate::Error; + + fn try_from(row: Row) -> Result { + Ok(Entry { + table: row + .try_get("table") + .map_err(|_| crate::Error::ParseError("Entry", "table"))?, + column: row.try_get("column")?, + data_type: row.try_get("data_type")?, + }) + } + } + + impl FromRow for Entry {} + + // Query all columns and their data type of all tables in this schema + let res = client + .query( + r#" + SELECT + pg_attribute.attname AS column, + pg_catalog.format_type(pg_attribute.atttypid, pg_attribute.atttypmod) AS data_type, + pg_class.relname AS table + FROM + pg_catalog.pg_attribute + INNER JOIN + pg_catalog.pg_class ON pg_class.oid = pg_attribute.attrelid + INNER JOIN + pg_catalog.pg_namespace ON pg_namespace.oid = pg_class.relnamespace + WHERE + pg_attribute.attnum > 0 + AND NOT pg_attribute.attisdropped + AND pg_namespace.nspname = $1 + AND pg_class.relkind = 'r' + ORDER BY + attnum ASC + "#, + &[&schema_name], + ) + .await?; + + // Parse the query result to `Entry` objects. + let entries: Vec = res + .into_iter() + .map(Entry::try_from) + .collect::, crate::Error>>()?; + + // Group the columns by table. No idea if there's a better way to do this + let mut map: HashMap> = HashMap::new(); + for i in entries { + if let Some(columns) = map.get_mut(&i.table) { + columns.push(i); + } else { + map.insert(i.table.clone(), vec![i]); + } + } + + let tables = map.into_iter().map(|(table, columns)| { + Table::new(table).columns( + columns + .into_iter() + .map(|i| Column::new(i.column, i.data_type)), + ) + }); + + let schema = Schema::new(schema_name).tables(tables); + + Ok(schema) +} + +/// Try to automatically migrate from `old` to `new`. +pub async fn try_migration_from( + old: &Schema, + new: &Schema, + client: &tokio_postgres::Client, +) -> Result<(), crate::Error> { + let stmts = old.migrate_from(new)._join("; "); + + client + .simple_query(&stmts) + .await + .map(|_| ()) + .map_err(crate::Error::PostgresError) +} + +/// Try to fetch the current schema and then automatically +/// migrate from that to `new`. +pub async fn try_migration_to( + new: &Schema, + client: &tokio_postgres::Client, +) -> Result<(), crate::Error> { + let old = fetch_schema(&new.name, client).await?; + try_migration_from(&old, new, client).await +} + +/// Automatically create new or alter existing tables in the `'public'` schema. +pub async fn migrate_tables(table: impl IntoIterator) -> Result<(), crate::Error> { + let new = Schema::default().tables(&mut table.into_iter()); + try_migration_to(&new, fetch_client().await?.deref()).await +} + +impl Default for Schema { + fn default() -> Self { + Schema { + name: "public".into(), + tables: Vec::new(), + } + } +} + impl Schema { + /// Create a new schema. + pub fn new(name: impl Into) -> Schema { + Self { + name: name.into(), + tables: Vec::new(), + } + } + /// Add a table to this schema. - fn table(mut self, table: Table) -> Schema { + pub fn table(mut self, table: Table) -> Self { self.tables.push(table); self } + /// Add multiple tables to this schema. + pub fn tables(mut self, tables: impl IntoIterator) -> Self { + self.tables.extend(&mut tables.into_iter()); + + self + } + /// Generate SQL statements which migrate `old` to this schema. - pub fn migrate_from(&self, old: &Schema) -> Vec { + fn migrate_from(&self, old: &Schema) -> Vec { let mut statements = Vec::new(); for table in &self.tables { if let Some(old_table) = old.tables.iter().find(|i| i.name == table.name) { @@ -83,7 +233,8 @@ impl Schema { } impl Table { - fn new(name: impl Into) -> Self { + /// Create a new table. + pub fn new(name: impl Into) -> Self { let columns: Vec = Vec::new(); Table { @@ -93,8 +244,60 @@ impl Table { } } - fn column(mut self, col: Column) -> Table { - self.columns.push(col); + /// Add a column to this table. + pub fn column(mut self, column: Column) -> Self { + self.columns.push(column); + + self + } + + /// Add columns to this table. + pub fn columns(mut self, columns: impl IntoIterator) -> Self { + self.columns.extend(&mut columns.into_iter()); + + self + } + + /// Add a unique constraint to column(s) of this table. + pub fn unique(mut self, cols: impl IntoIterator) -> Self { + self.constraints + .push(TableConstraint::Unique(cols.into_iter().collect())); + + self + } + + /// Add a named unique constraint to column(s) of this table. + pub fn unique_named( + mut self, + name: impl Into, + cols: impl IntoIterator, + ) -> Self { + self.constraints.push(TableConstraint::UniqueNamed( + name.into(), + cols.into_iter().collect(), + )); + + self + } + + /// Add a primary key to this table. + pub fn primary_key(mut self, cols: impl IntoIterator) -> Self { + self.constraints + .push(TableConstraint::PrimaryKey(cols.into_iter().collect())); + + self + } + + /// Add a foreign key constraint to this table. + pub fn foreign_key( + mut self, + table: impl Into, + columns: impl IntoIterator, + ) -> Self { + self.constraints.push(TableConstraint::ForeignKey( + table.into(), + columns.into_iter().collect(), + )); self } @@ -142,6 +345,14 @@ impl Table { } } + for i in old_table + .columns + .iter() + .filter(|i| !self.columns.iter().any(|j| i.name == j.name)) + { + statements.push(format!("ALTER TABLE {} {}", self.name, i.down())); + } + statements } @@ -176,17 +387,18 @@ impl Table { FROM pg_catalog.pg_constraint con INNER JOIN pg_catalog.pg_class rel ON rel.oid = con.conrelid INNER JOIN pg_catalog.pg_namespace nsp ON nsp.oid = connamespace - WHERE rel.relname = {0}) LOOP + WHERE rel.relname = '{0}') LOOP EXECUTE format('ALTER TABLE {0} DROP CONSTRAINT %I CASCADE', i.conname); END LOOP; - END $$;"#, + END $$"#, self.name ) } } impl Column { - fn new(name: impl Into, data_type: impl Into) -> Self { + /// Create a new column. + pub fn new(name: impl Into, data_type: impl Into) -> Self { Column { name: name.into(), data_type: data_type.into(), @@ -194,37 +406,58 @@ impl Column { } } - fn not_null(mut self) -> Self { + /// Make this column `NOT NULL`. + pub fn not_null(mut self) -> Self { self.constraints.push(ColumnConstraint::NotNull); self } - fn unique(mut self) -> Self { + /// Add a `UNIQUE` constraint to this column. + pub fn unique(mut self) -> Self { self.constraints.push(ColumnConstraint::Unique); self } - fn unique_named(mut self, name: String) -> Self { + /// Add a named `UNIQUE` constraint to this column. + pub fn unique_named(mut self, name: String) -> Self { self.constraints.push(ColumnConstraint::UniqueNamed(name)); self } - fn primary_key(mut self) -> Self { + /// Make this column the `PRIMARY KEY`. + pub fn primary_key(mut self) -> Self { self.constraints.push(ColumnConstraint::PrimaryKey); self } - fn foreign_key(mut self, table_name: String, column_name: String) -> Self { + /// Add a `FOREIGN KEY` constraint to this column. + pub fn foreign_key(mut self, table_name: String, column_name: String) -> Self { self.constraints .push(ColumnConstraint::ForeignKey(table_name, column_name)); self } + /// Add a raw `CHECK` to this column. + pub fn check(mut self, check: impl Into) -> Self { + self.constraints + .push(ColumnConstraint::RawCheck(check.into())); + + self + } + + /// Add a named raw `CHECK` to this column. + pub fn check_named(mut self, name: impl Into, check: impl Into) -> Self { + self.constraints + .push(ColumnConstraint::RawCheckNamed(name.into(), check.into())); + + self + } + fn up(&self) -> String { format!("{} {}", self.name, self.data_type) } @@ -238,7 +471,7 @@ impl Column { if self.data_type != other.data_type { statements.push(format!( - "ALTER COLUMN {} SET TYPE {}", + "ALTER COLUMN {} TYPE {}", self.name, self.data_type )); } @@ -366,16 +599,27 @@ where #[cfg(test)] mod tests { - use super::{Column, Schema, Table}; - - #[test] - fn migrate() { + use crate::pool::{fetch_client, Connection}; + + use super::{fetch_schema, try_migration_from, Column, Schema, Table}; + + #[tokio::test] + async fn migrate() -> Result<(), Box> { + Connection::build("postgres://postgres:postgres@localhost:5432") + .connect() + .await + .unwrap(); + let src = fetch_schema("public", &&fetch_client().await.unwrap()) + .await + .unwrap(); let dest = Schema::default().table( Table::new("book") .column(Column::new("id", "BIGINT").primary_key().not_null()) .column(Column::new("title", "TEXT").unique().not_null()), ); - dbg!(dest.migrate_from(&Schema::default())); + try_migration_from(&src, &dest, &&fetch_client().await?).await?; + + Ok(()) } } diff --git a/pg-worm/tests/connect.rs b/pg-worm/tests/connect.rs index a0dbdfb..a144a2a 100644 --- a/pg-worm/tests/connect.rs +++ b/pg-worm/tests/connect.rs @@ -1,15 +1,13 @@ #![allow(dead_code)] -use pg_worm::prelude::*; -use pg_worm::{force_create_table, pool::Connection, query::Transaction}; +use pg_worm::{migrate_tables, prelude::*}; +use pg_worm::{pool::Connection, query::Transaction}; #[derive(Model)] struct Book { #[column(primary_key, auto)] id: i64, title: String, - sub_title: Option, - pages: Vec, author_id: i64, } @@ -27,17 +25,9 @@ async fn complete_procedure() -> Result<(), pg_worm::Error> { .max_pool_size(16) .connect() .await?; - println!("Hello World!"); // Then, create the tables for your models. - // Use `register!` if you want to fail if a - // table with the same name already exists. - // - // `force_register` drops the old table, - // which is useful for development. - // - // If your tables already exist, skip this part. - force_create_table!(Author, Book).await?; + migrate_tables!(Book, Author).await?; // Next, insert some data. // This works by passing values for all @@ -45,25 +35,18 @@ async fn complete_procedure() -> Result<(), pg_worm::Error> { Author::insert("Stephen King").await?; Author::insert("Martin Luther King").await?; Author::insert("Karl Marx").await?; - Book::insert( - "Foo - Part I", - "Subtitle".to_string(), - vec!["Page 1".to_string()], - 1, - ) - .await?; - Book::insert("Foo - Part II", None, vec![], 2).await?; - Book::insert("Foo - Part III", None, vec![], 3).await?; + Book::insert("Foo - Part I", 1).await?; + Book::insert("Foo - Part II", 2).await?; + Book::insert("Foo - Part III", 3).await?; // Easily query for all books let books = Book::select().await?; - assert_eq!(books.len(), 3); + assert!(books.len() >= 3); // Or check whether your favorite book is listed, // along some other arbitrary conditions let manifesto = Book::select_one() .where_(Book::title.eq(&"The Communist Manifesto".into())) - .where_(Book::pages.contains(&"You have nothing to lose but your chains!".into())) .where_(Book::id.gt(&3)) .prepared() .await?; @@ -73,7 +56,7 @@ async fn complete_procedure() -> Result<(), pg_worm::Error> { let books_updated = Book::update() .set(Book::title, &"The name of this book is a secret".into()) .await?; - assert_eq!(books_updated, 3); + assert!(books_updated >= 3); // Or run a raw query which gets automagically parsed to `Vec`. // @@ -88,22 +71,7 @@ async fn complete_procedure() -> Result<(), pg_worm::Error> { vec![&"King".to_string()], ) .await?; - assert_eq!(king_books.len(), 2); - - // Or do some array operations - let page_1 = "Page 1".to_string(); - let page_2 = "Page 2".to_string(); - let pages = vec![&page_1, &page_2]; - - let any_page = Book::select_one() - .where_(Book::pages.contains_any(&pages)) - .await?; - assert!(any_page.is_some()); - - let both_pages = Book::select_one() - .where_(Book::pages.contains_all(&pages)) - .await?; - assert!(both_pages.is_none()); + assert!(king_books.len() >= 2); // You can even do transactions: let transaction = Transaction::begin().await?; @@ -115,7 +83,7 @@ async fn complete_procedure() -> Result<(), pg_worm::Error> { // Verify that they still exist *outside* the transaction: let all_books_outside_tx = Book::select().await?; - assert_eq!(all_books_outside_tx.len(), 3); + assert!(all_books_outside_tx.len() >= 3); // Commit the transaction transaction.commit().await?;