Skip to content

Commit

Permalink
Replace refinery with our own backwards compatible migration code
Browse files Browse the repository at this point in the history
Currently server_tests must be run with --test-threads=1 because of two things:
1. port collision
2. concurrent migrations failing

I explored fixing these in #926. Ports are easily solved, but migrations are a pain

This migration code works around concurrency with two changes from refinery:
1. ignore unique constraint violation from CREATE TABLE IF NOT EXISTS
2. lock migration table so concurrent migrations are serialized

Considered submitting a PR to refinery with these two fixes,
but a change that is simple here would have been non-trivial there, since refinery supports multiple async/sync database drivers
  • Loading branch information
serprex committed Dec 29, 2023
1 parent d8a0d0f commit 78fb996
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 115 deletions.
119 changes: 25 additions & 94 deletions nexus/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion nexus/catalog/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ prost = "0.12"
peer-cursor = { path = "../peer-cursor" }
peer-postgres = { path = "../peer-postgres" }
pt = { path = "../pt" }
refinery = { version = "0.8", features = ["tokio-postgres"] }
include_dir = { version = "0.7", default-features = false }
tokio = { version = "1.13.0", features = ["full"] }
tokio-postgres = { version = "0.7.6", features = [
"with-chrono-0_4",
Expand All @@ -21,4 +21,5 @@ tokio-postgres = { version = "0.7.6", features = [
] }
tracing = "0.1.29"
serde_json = "1.0"
siphasher = "1.0"
postgres-connection = { path = "../postgres-connection" }
147 changes: 127 additions & 20 deletions nexus/catalog/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use std::cmp::Ordering;
use std::hash::{Hash, Hasher};
use std::{collections::HashMap, sync::Arc};

use anyhow::{anyhow, Context};
use include_dir::{include_dir, Dir, File};
use peer_cursor::QueryExecutor;
use peer_postgres::PostgresQueryExecutor;
use postgres_connection::{connect_postgres, get_pg_connection_string};
Expand All @@ -11,31 +14,62 @@ use pt::{
peerdb_peers::{peer::Config, DbType, Peer},
};
use serde_json::Value;
use tokio_postgres::{types, Client};
use siphasher::sip::SipHasher13;
use tokio_postgres::{error::SqlState, types, Client};

mod embedded {
use refinery::embed_migrations;
embed_migrations!("migrations");
static MIGRATIONS: Dir<'_> = include_dir!("$CARGO_MANIFEST_DIR/migrations");

/// One embedded SQL migration file, parsed from its
/// `V{version}__{name}.sql` filename (refinery's naming convention).
#[derive(Eq)]
struct Migration<'a> {
    // Raw embedded file; its contents are the SQL to execute.
    pub file: &'a File<'a>,
    // Integer between the `V` prefix and `__`; defines apply order.
    pub version: i32,
    // Name portion of the filename, after the `__` separator.
    pub name: &'a str,
}

pub struct Catalog {
pg: Box<Client>,
executor: Arc<dyn QueryExecutor>,
impl<'a> Migration<'a> {
    /// Parse a `V{version}__{name}.sql` migration filename into a
    /// [`Migration`].
    ///
    /// # Errors
    /// Returns an error when the filename is not UTF-8 or does not follow
    /// the refinery naming convention (`V` prefix, integer version, `__`
    /// separator).
    pub fn new(file: &'a File<'a>) -> anyhow::Result<Self> {
        let Some(f) = file.path().to_str() else {
            return Err(anyhow!("migration filename must be utf8"));
        };
        let Some(f) = f.strip_prefix('V') else {
            return Err(anyhow!("migration name must start with V"));
        };
        let Some(__idx) = f.find("__") else {
            return Err(anyhow!("migration name must contain __"));
        };
        let Ok(version) = f[..__idx].parse() else {
            return Err(anyhow!("migration name must have number in between V & __"));
        };
        // Strip the `.sql` extension so names match what refinery recorded
        // in refinery_schema_history (refinery stores the name WITHOUT the
        // extension). Without this, the name-mismatch check during
        // migration would reject every previously applied migration.
        let name = &f[__idx + 2..];
        let name = name.strip_suffix(".sql").unwrap_or(name);
        Ok(Self {
            file,
            version,
            name,
        })
    }
}

/// Migration identity is its version alone: two migrations with the same
/// version compare equal regardless of file contents or name.
impl<'a> PartialEq for Migration<'a> {
    fn eq(&self, other: &Self) -> bool {
        i32::eq(&self.version, &other.version)
    }
}

impl<'a> Ord for Migration<'a> {
    /// Total order by ascending migration version, so sorting a list of
    /// migrations yields apply order.
    fn cmp(&self, other: &Self) -> Ordering {
        Ord::cmp(&self.version, &other.version)
    }
}

async fn run_migrations(client: &mut Client) -> anyhow::Result<()> {
let migration_report = embedded::migrations::runner()
.run_async(client)
.await
.context("Failed to run migrations")?;
for migration in migration_report.applied_migrations() {
tracing::info!(
"Migration Applied - Name: {}, Version: {}",
migration.name(),
migration.version()
);
impl<'a> PartialOrd for Migration<'a> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.version.cmp(&other.version))
}
Ok(())
}

/// Handle to the PeerDB catalog database.
pub struct Catalog {
    // Owned Postgres client connected to the catalog database.
    pg: Box<Client>,
    // Executor used to run queries against the catalog as a peer.
    executor: Arc<dyn QueryExecutor>,
}

#[derive(Debug, Copy, Clone)]
Expand Down Expand Up @@ -86,7 +120,80 @@ impl Catalog {
}

/// Apply any embedded migrations not yet recorded in
/// `refinery_schema_history`, serializing concurrent runners via a table
/// lock. Backwards compatible with history tables written by refinery.
///
/// # Errors
/// Fails on malformed migration filenames, non-UTF-8 SQL, a recorded
/// history that diverges from the embedded files, or any database error.
pub async fn run_migrations(&mut self) -> anyhow::Result<()> {
    let mut migrations = MIGRATIONS
        .files()
        .map(Migration::new)
        .collect::<anyhow::Result<Vec<_>>>()?;
    // Ord on Migration is by version, so this sorts into apply order.
    migrations.sort();

    // Create the history table on the plain connection, NOT inside the
    // transaction below. Concurrent CREATE TABLE IF NOT EXISTS can race
    // and fail with a unique-constraint violation, which we deliberately
    // ignore — but inside a transaction any error (even one ignored
    // client-side) aborts the transaction, so every later statement would
    // fail with 25P02. It therefore must run before the transaction opens.
    let create = self
        .pg
        .execute(
            "create table if not exists refinery_schema_history(\
             version int4 primary key, name text, applied_on text, checksum text)",
            &[],
        )
        .await;
    if let Err(err) = create {
        if err.code() != Some(&SqlState::UNIQUE_VIOLATION) {
            return Err(err.into());
        }
    }

    let tx = self.pg.transaction().await?;
    // Serialize concurrent migration runs: SHARE UPDATE EXCLUSIVE
    // self-conflicts, so a second runner blocks here until this commits.
    tx.execute(
        "lock table refinery_schema_history in share update exclusive mode",
        &[],
    )
    .await?;
    let rows = tx
        .query(
            "select version, name from refinery_schema_history order by version",
            &[],
        )
        .await?;
    let mut applied = rows
        .iter()
        .map(|row| (row.get::<usize, i32>(0), row.get::<usize, &str>(1)));

    for migration in migrations {
        // Already-applied migrations must line up one-to-one, in order,
        // with the embedded files; any mismatch means history diverged.
        if let Some((applied_version, applied_name)) = applied.next() {
            if migration.version != applied_version {
                return Err(anyhow!(
                    "Migration version mismatch: {} & {}",
                    migration.version,
                    applied_version
                ));
            }
            if migration.name != applied_name {
                return Err(anyhow!(
                    "Migration name mismatch: '{}' & '{}'",
                    migration.name,
                    applied_name
                ));
            }
            continue;
        }
        let Some(sql) = migration.file.contents_utf8() else {
            return Err(anyhow!("migration sql must be utf8"))
        };
        // Checksum is recorded for parity with refinery's schema; it is
        // written here but not verified against existing rows.
        let checksum = {
            let mut hasher = SipHasher13::new();
            migration.name.hash(&mut hasher);
            migration.version.hash(&mut hasher);
            sql.hash(&mut hasher);
            hasher.finish()
        };

        tx.batch_execute(sql).await?;
        tx.execute(
            "insert into refinery_schema_history (version, name, applied_on, checksum) values ($1, $2, NOW(), $3)",
            &[&migration.version, &migration.name, &checksum.to_string()],
        )
        .await?;
        tracing::info!(
            "Migration Applied: {} {}",
            migration.version,
            migration.name
        );
    }

    tx.commit().await.map_err(|err| err.into())
}

pub fn get_executor(&self) -> &Arc<dyn QueryExecutor> {
Expand Down

0 comments on commit 78fb996

Please sign in to comment.