diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 0000000..defe515 --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1,2 @@ +* @pavlospt +* @nikoshet diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml new file mode 100644 index 0000000..bb2d2a4 --- /dev/null +++ b/.github/workflows/ci.yaml @@ -0,0 +1,48 @@ +name: CI Pipeline + +on: + pull_request: + branches: + - main + +concurrency: + group: '${{ github.workflow }} @ ${{ github.head_ref || github.ref }}' + cancel-in-progress: true + +jobs: + build: + name: cargo build + runs-on: ubuntu-latest + strategy: + fail-fast: true + matrix: + include: + - name: "library" + path: "." + - name: "client" + path: "rust-pgdatadiff-client" + steps: + - uses: actions/checkout@v4 + - uses: actions-rust-lang/setup-rust-toolchain@v1 + - name: Cargo Build ${{ matrix.name }} + run: cargo build + working-directory: ${{ matrix.path }} + test: + name: cargo test + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions-rust-lang/setup-rust-toolchain@v1 + - run: cargo test --all + format-and-clippy: + name: Cargo format & Clippy + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions-rust-lang/setup-rust-toolchain@v1 + with: + components: rustfmt, clippy + - name: Rustfmt Check + uses: actions-rust-lang/rustfmt@v1 + - name: Lint with Clippy + run: cargo clippy --all diff --git a/.github/workflows/git.yaml b/.github/workflows/git.yaml new file mode 100644 index 0000000..2d041b4 --- /dev/null +++ b/.github/workflows/git.yaml @@ -0,0 +1,30 @@ +name: Git Checks + +on: [pull_request] + +jobs: + block-fixup: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Block Fixup Commit Merge + uses: alexkappa/block-fixup-merge-action@v2 + add-assignee: + runs-on: ubuntu-latest + steps: + - uses: actions/github-script@v7 + with: + script: | + const issue = await github.rest.issues.get({ + owner: context.repo.owner, + repo: 
context.repo.repo, + issue_number: context.issue.number + }); + if (issue.data.assignees.length === 0) { + await github.rest.issues.addAssignees({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.issue.number, + assignees: [context.actor] + }); + } diff --git a/.gitignore b/.gitignore index 6985cf1..b0e0373 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,12 @@ Cargo.lock # MSVC Windows builds of rustc generate these, which store debugging information *.pdb + + +# Added by cargo + +/target +.idea/ +.DS_Store +postgres-data1/ +postgres-data2/ diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..0093db4 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,39 @@ +[package] +name = "rust-pgdatadiff" +version = "0.1.2" +edition = "2021" +license = "MIT" +description = "Rust library for comparing two PostgreSQL databases" +readme = "README.md" +homepage = "https://github.com/pavlospt/rust-pgdatadiff" +repository = "https://github.com/pavlospt/rust-pgdatadiff" +keywords = ["postgres", "postgresql", "diff"] +documentation = "https://docs.rs/rust-pgdatadiff" + +[dependencies] +anyhow = "1.0.81" +tokio = { version = "1.36.0", features = ["full"] } +sqlx = { version = "0.7", features = ["runtime-tokio", "tls-native-tls", "postgres"] } +colored = "2.1.0" +futures = { version = "0.3.30", default-features = true, features = ["async-await"] } +env_logger = "0.11.3" +log = "0.4.21" +async-trait = "0.1.77" +pretty_assertions = "1.4.0" + +[dependencies.clap] +version = "4.5.2" +features = ["derive"] + +[dev-dependencies] +mockall = "0.12.1" +tokio = { version = "1.36.0", features = ["rt-multi-thread", "macros"] } + +[lib] +test = true +edition = "2021" +crate-type = ["lib"] +name = "rust_pgdatadiff" + +[workspace] +members = ["rust-pgdatadiff-client"] diff --git a/README.md b/README.md new file mode 100644 index 0000000..1be8974 --- /dev/null +++ b/README.md @@ -0,0 +1,136 @@ +# Rust PGDataDiff + +`rust-pgdatadiff` is a re-write of the 
Python version of [pgdatadiff](https://github.com/dmarkey/pgdatadiff) + +## What makes it different? + +* It is schema aware right from the get-go, as when we had to use the original + `pgdatadiff` we ended up having different schemas that we needed to perform checks on. + +* It runs DB operations in a parallel fashion, + making it at least 3x faster in comparison to the original `pgdatadiff` which performs the checks sequentially. + +* It is written in Rust, which means that it is memory safe and has a very low overhead. + +* It provides both a library and a client, which means that you can use it as a standalone tool + and in your own projects. + +_The benchmarks below are based on DBs with 5 tables and 1M rows each. The results are as follows:_ + +## Python (sequential) +![python-timings](images/python.png) + +## Rust (parallel) +![rust-timings](images/rust.png) + +# Installation (Client) + +In case you want to use this as a client you can install it through `cargo`: + +```shell +cargo install rust-pgdatadiff-client +``` + +# Installation (Library) + +In case you want to use this as a library you can add it to your `Cargo.toml`: + +```shell +cargo add rust-pgdatadiff +``` + +or + +```toml +[dependencies] +rust-pgdatadiff = "0.1.0" +``` + +# Usage (Client) + +``` +Usage: rust-pgdatadiff-client diff [OPTIONS] + +Arguments: + postgres://postgres:postgres@localhost:5438/example + postgres://postgres:postgres@localhost:5439/example + +Options: + --only-tables Only compare data, exclude sequences + --only-sequences Only compare sequences, exclude data + --only-count Do a quick test based on counts alone + --chunk-size The chunk size when comparing data [default: 10000] + --max-connections Max connections for Postgres pool [default: 100] + -i, --include-tables [...] Tables included in the comparison + -e, --exclude-tables [...] 
Tables excluded from the comparison + --schema-name Schema name [default: public] + -h, --help Print help + -V, --version Print version +``` + +# Usage (Library) + +```rust +use rust_pgdatadiff::diff::diff_ops::Differ; +use rust_pgdatadiff::diff::diff_payload::DiffPayload; + +#[tokio::main] +async fn main() -> Result<()> { + let first_db = "postgres://postgres:postgres@localhost:5438/example"; + let second_db = "postgres://postgres:postgres@localhost:5439/example"; + + let payload = DiffPayload::new( + first_db.to_owned(), + second_db.to_owned(), + *only_data, + *only_sequences, + *only_count, + *chunk_size, + *max_connections, + included_tables.to_vec(), + schema_name.clone(), + ); + let diff_result = Differ::diff_dbs(payload).await; + // Handle `diff_result` in any way it fits your use case + Ok(()) +} +``` + +# Examples + +You can spin up two databases already prefilled with data through Docker Compose. + +```shell +docker compose up --build +``` + +Prefilled databases include a considerable amount of data + rows, so you can run benchmarks against them to check the +performance of it. You can modify a few of the generated data in order to see it in action. + +You can find an example of using it as a library in the [`examples`](./examples) directory. + +Run the example with the following command, after Docker Compose has started: + +```shell +cargo run --example example_diff diff \ + "postgresql://localhost:5438?dbname=example&user=postgres&password=postgres" \ + "postgresql://localhost:5439?dbname=example&user=postgres&password=postgres" +``` + +You can also enable Rust related logs by exporting the following: + +```shell +export RUST_LOG=rust_pgdatadiff=info +``` + +Switching from `info` to `debug` will give you more detailed logs. 
Also since we are utilizing +`sqlx` under the hood, you can enable `sqlx` logs by exporting the following: + +```shell +export RUST_LOG=rust_pgdatadiff=info,sqlx=debug +``` + +# Authors + +* [Pavlos-Petros Tournaris](https://github.com/pavlospt) +* [Nikolaos Nikitas](https://github.com/nikoshet) diff --git a/db/create_tables.sql b/db/create_tables.sql new file mode 100644 index 0000000..ba76b88 --- /dev/null +++ b/db/create_tables.sql @@ -0,0 +1,215 @@ +-- Creation of product table +CREATE TABLE IF NOT EXISTS product +( + product_id + INT + NOT + NULL, + name + varchar +( + 250 +) NOT NULL, + PRIMARY KEY +( + product_id +) + ); + +-- Creation of country table +CREATE TABLE IF NOT EXISTS country +( + country_id + INT + NOT + NULL, + country_name + varchar +( + 450 +) NOT NULL, + PRIMARY KEY +( + country_id +) + ); + +-- Creation of city table +CREATE TABLE IF NOT EXISTS city +( + city_id + INT + NOT + NULL, + city_name + varchar +( + 450 +) NOT NULL, + country_id INT NOT NULL, + PRIMARY KEY +( + city_id +), + CONSTRAINT fk_country + FOREIGN KEY +( + country_id +) + REFERENCES country +( + country_id +) + ); + +-- Creation of store table +CREATE TABLE IF NOT EXISTS store +( + store_id + INT + NOT + NULL, + name + varchar +( + 250 +) NOT NULL, + city_id INT NOT NULL, + PRIMARY KEY +( + store_id +), + CONSTRAINT fk_city + FOREIGN KEY +( + city_id +) + REFERENCES city +( + city_id +) + ); + +-- Creation of user table +CREATE TABLE IF NOT EXISTS users +( + user_id + INT + NOT + NULL, + name + varchar +( + 250 +) NOT NULL, + PRIMARY KEY +( + user_id +) + ); + +-- Creation of status_name table +CREATE TABLE IF NOT EXISTS status_name +( + status_name_id + INT + NOT + NULL, + status_name + varchar +( + 450 +) NOT NULL, + PRIMARY KEY +( + status_name_id +) + ); + +-- Creation of sale table +CREATE TABLE IF NOT EXISTS sale +( + sale_id + varchar +( + 200 +) NOT NULL, + amount DECIMAL +( + 20, + 3 +) NOT NULL, + date_sale TIMESTAMP, + product_id INT NOT NULL, + user_id INT NOT 
NULL, + store_id INT NOT NULL, + PRIMARY KEY +( + sale_id +), + CONSTRAINT fk_product + FOREIGN KEY +( + product_id +) + REFERENCES product +( + product_id +), + CONSTRAINT fk_user + FOREIGN KEY +( + user_id +) + REFERENCES users +( + user_id +), + CONSTRAINT fk_store + FOREIGN KEY +( + store_id +) + REFERENCES store +( + store_id +) + ); + +-- Creation of order_status table +CREATE TABLE IF NOT EXISTS order_status +( + order_status_id + varchar +( + 200 +) NOT NULL, + update_at TIMESTAMP, + sale_id varchar +( + 200 +) NOT NULL, + status_name_id INT NOT NULL, + PRIMARY KEY +( + order_status_id +), + CONSTRAINT fk_sale + FOREIGN KEY +( + sale_id +) + REFERENCES sale +( + sale_id +), + CONSTRAINT fk_status_name + FOREIGN KEY +( + status_name_id +) + REFERENCES status_name +( + status_name_id +) + ); diff --git a/db/fill_tables.sql b/db/fill_tables.sql new file mode 100644 index 0000000..5d372a9 --- /dev/null +++ b/db/fill_tables.sql @@ -0,0 +1,71 @@ +-- Set params +set session my.number_of_sales = '1000000'; +set session my.number_of_users = '1000000'; +set session my.number_of_products = '100000'; +set session my.number_of_stores = '100000'; +set session my.number_of_coutries = '100000'; +set session my.number_of_cities = '30000'; +set session my.status_names = '15'; +set session my.start_date = '2019-01-01 00:00:00'; +set session my.end_date = '2020-02-01 00:00:00'; + +-- load the pgcrypto extension to gen_random_uuid () +CREATE EXTENSION pgcrypto; + +-- Filling of products +INSERT INTO product +select id, concat('Product ', id) +FROM GENERATE_SERIES(1, current_setting('my.number_of_products')::int) as id; + +-- Filling of countries +INSERT INTO country +select id, concat('Country ', id) +FROM GENERATE_SERIES(1, current_setting('my.number_of_coutries')::int) as id; + +-- Filling of cities +INSERT INTO city +select id + , concat('City ', id) + , floor(random() * (current_setting('my.number_of_coutries')::int) + 1)::int +FROM GENERATE_SERIES(1, 
current_setting('my.number_of_cities')::int) as id; + +-- Filling of stores +INSERT INTO store +select id + , concat('Store ', id) + , floor(random() * (current_setting('my.number_of_cities')::int) + 1)::int +FROM GENERATE_SERIES(1, current_setting('my.number_of_stores')::int) as id; + +-- Filling of users +INSERT INTO users +select id + , concat('User ', id) +FROM GENERATE_SERIES(1, current_setting('my.number_of_users')::int) as id; + +-- Filling of users +INSERT INTO status_name +select status_name_id + , concat('Status Name ', status_name_id) +FROM GENERATE_SERIES(1, current_setting('my.status_names')::int) as status_name_id; + +-- Filling of sales +INSERT INTO sale +select gen_random_uuid () + , round(CAST(float8 (random() * 10000) as numeric), 3) + , TO_TIMESTAMP(start_date, 'YYYY-MM-DD HH24:MI:SS') + + random()* (TO_TIMESTAMP(end_date, 'YYYY-MM-DD HH24:MI:SS') + - TO_TIMESTAMP(start_date, 'YYYY-MM-DD HH24:MI:SS')) + , floor(random() * (current_setting('my.number_of_products')::int) + 1)::int + , floor(random() * (current_setting('my.number_of_users')::int) + 1)::int + , floor(random() * (current_setting('my.number_of_stores')::int) + 1)::int +FROM GENERATE_SERIES(1, current_setting('my.number_of_sales')::int) as id + , current_setting('my.start_date') as start_date + , current_setting('my.end_date') as end_date; + +-- Filling of order_status +INSERT INTO order_status +select gen_random_uuid () + , date_sale + random()* (date_sale + '5 days' - date_sale) + , sale_id + , floor(random() * (current_setting('my.status_names')::int) + 1)::int +from sale; diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..6ddbc4d --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,40 @@ +services: + postgres1: + image: public.ecr.aws/docker/library/postgres:16.2 + restart: always + environment: + - POSTGRES_USER=postgres + - POSTGRES_PASSWORD=postgres + - POSTGRES_DB=example + logging: + options: + max-size: 10m + max-file: "3" + ports: + - 
'5438:5432' + volumes: + - ./postgres-data1:/var/lib/postgresql/data + # copy the sql script to create tables + - ./db/create_tables.sql:/docker-entrypoint-initdb.d/create_tables.sql + # copy the sql script to fill tables + - ./db/fill_tables.sql:/docker-entrypoint-initdb.d/fill_tables.sql + + postgres2: + image: public.ecr.aws/docker/library/postgres:16.2 + restart: always + environment: + - POSTGRES_USER=postgres + - POSTGRES_PASSWORD=postgres + - POSTGRES_DB=example + logging: + options: + max-size: 10m + max-file: "3" + ports: + - '5439:5432' + volumes: + - ./postgres-data2:/var/lib/postgresql/data + # copy the sql script to create tables + - ./db/create_tables.sql:/docker-entrypoint-initdb.d/create_tables.sql + # copy the sql script to fill tables + - ./db/fill_tables.sql:/docker-entrypoint-initdb.d/fill_tables.sql diff --git a/examples/example_diff.rs b/examples/example_diff.rs new file mode 100644 index 0000000..c1dfba5 --- /dev/null +++ b/examples/example_diff.rs @@ -0,0 +1,94 @@ +// Path: examples/example_diff.rs +extern crate anyhow; +extern crate clap; +extern crate env_logger; +extern crate rust_pgdatadiff; + +use anyhow::Result; +use clap::{Parser, Subcommand}; +use rust_pgdatadiff::diff::diff_ops::Differ; +use rust_pgdatadiff::diff::diff_payload::DiffPayload; + +#[derive(Parser)] +#[command(author, version, about, long_about = None)] +#[command(propagate_version = true)] +struct Cli { + #[command(subcommand)] + command: Commands, +} + +#[derive(Subcommand)] +enum Commands { + #[command(about = "Print the version")] + Version, + Diff { + /// postgres://postgres:postgres@localhost:5438/example + first_db: String, + /// postgres://postgres:postgres@localhost:5439/example + second_db: String, + /// Only compare data, exclude sequences + #[arg(long, default_value_t = false, required = false)] + only_tables: bool, + /// Only compare sequences, exclude data + #[arg(long, default_value_t = false, required = false)] + only_sequences: bool, + /// Do a quick 
test based on counts alone + #[arg(long, default_value_t = false, required = false)] + only_count: bool, + /// The chunk size when comparing data + #[arg(long, default_value_t = 10000, required = false)] + chunk_size: i64, + /// Max connections for Postgres pool + #[arg(long, default_value_t = 100, required = false)] + max_connections: i64, + /// Tables included in the comparison + #[arg(short, long, value_delimiter = ',', num_args = 0.., required = false, conflicts_with = "exclude_tables")] + include_tables: Vec, + /// Tables excluded from the comparison + #[arg(short, long, value_delimiter = ',', num_args = 0.., required = false, conflicts_with = "include_tables")] + exclude_tables: Vec, + /// Schema name + #[arg(long, default_value = "public", required = false)] + schema_name: String, + }, +} + +#[tokio::main] +async fn main() -> Result<()> { + env_logger::init(); + + let cli = Cli::parse(); + match &cli.command { + Commands::Version => { + println!("Version: {}", env!("CARGO_PKG_VERSION")); + Ok(()) + } + Commands::Diff { + first_db, + second_db, + only_tables, + only_sequences, + only_count, + chunk_size, + max_connections, + include_tables, + exclude_tables, + schema_name, + } => { + let payload = DiffPayload::new( + first_db.clone(), + second_db.clone(), + *only_tables, + *only_sequences, + *only_count, + *chunk_size, + *max_connections, + include_tables.to_vec(), + exclude_tables.to_vec(), + schema_name.clone(), + ); + let _ = Differ::diff_dbs(payload).await; + Ok(()) + } + } +} diff --git a/images/python.png b/images/python.png new file mode 100644 index 0000000..f8f362e Binary files /dev/null and b/images/python.png differ diff --git a/images/rust.png b/images/rust.png new file mode 100644 index 0000000..2bd56ee Binary files /dev/null and b/images/rust.png differ diff --git a/rust-pgdatadiff-client/Cargo.toml b/rust-pgdatadiff-client/Cargo.toml new file mode 100644 index 0000000..2a8658a --- /dev/null +++ b/rust-pgdatadiff-client/Cargo.toml @@ -0,0 +1,18 
@@ +[package] +name = "rust-pgdatadiff-client" +version = "0.1.2" +edition = "2021" +license = "MIT" +description = "Rust client for comparing two PostgreSQL databases" +readme = "../README.md" +homepage = "https://github.com/pavlospt/rust-pgdatadiff" +repository = "https://github.com/pavlospt/rust-pgdatadiff" +keywords = ["postgres", "postgresql", "diff", "comparison"] +documentation = "https://docs.rs/rust-pgdatadiff-client" + +[dependencies] +anyhow = "1.0.81" +clap = { version = "4.5.2", features = ["derive"] } +tokio = "1.36.0" +env_logger = "0.11.3" +rust-pgdatadiff = { version = "0.1.2", path = ".." } diff --git a/rust-pgdatadiff-client/src/main.rs b/rust-pgdatadiff-client/src/main.rs new file mode 100644 index 0000000..fa90bad --- /dev/null +++ b/rust-pgdatadiff-client/src/main.rs @@ -0,0 +1,88 @@ +use anyhow::Result; +use clap::{Parser, Subcommand}; +use rust_pgdatadiff::diff::diff_ops::Differ; +use rust_pgdatadiff::diff::diff_payload::DiffPayload; + +#[derive(Parser)] +#[command(author, version, about, long_about = None)] +#[command(propagate_version = true)] +struct Cli { + #[command(subcommand)] + command: Commands, +} + +#[derive(Subcommand)] +enum Commands { + #[command(about = "Print the version")] + Version, + Diff { + /// postgres://postgres:postgres@localhost:5438/example + first_db: String, + /// postgres://postgres:postgres@localhost:5439/example + second_db: String, + /// Only compare data, exclude sequences + #[arg(long, default_value_t = false, required = false)] + only_tables: bool, + /// Only compare sequences, exclude data + #[arg(long, default_value_t = false, required = false)] + only_sequences: bool, + /// Do a quick test based on counts alone + #[arg(long, default_value_t = false, required = false)] + only_count: bool, + /// The chunk size when comparing data + #[arg(long, default_value_t = 10000, required = false)] + chunk_size: i64, + /// Max connections for Postgres pool + #[arg(long, default_value_t = 100, required = false)] + 
max_connections: i64, + /// Tables included in the comparison + #[arg(short, long, value_delimiter = ',', num_args = 0.., required = false, conflicts_with = "exclude_tables")] + include_tables: Vec, + /// Tables excluded from the comparison + #[arg(short, long, value_delimiter = ',', num_args = 0.., required = false, conflicts_with = "include_tables")] + exclude_tables: Vec, + /// Schema name + #[arg(long, default_value = "public", required = false)] + schema_name: String, + }, +} + +#[tokio::main] +async fn main() -> Result<()> { + env_logger::init(); + + let cli = Cli::parse(); + match &cli.command { + Commands::Version => { + println!("Version: {}", env!("CARGO_PKG_VERSION")); + Ok(()) + } + Commands::Diff { + first_db, + second_db, + only_tables, + only_sequences, + only_count, + chunk_size, + max_connections, + include_tables, + exclude_tables, + schema_name, + } => { + let payload = DiffPayload::new( + first_db.clone(), + second_db.clone(), + *only_tables, + *only_sequences, + *only_count, + *chunk_size, + *max_connections, + include_tables.to_vec(), + exclude_tables.to_vec(), + schema_name.clone(), + ); + let _ = Differ::diff_dbs(payload).await; + Ok(()) + } + } +} diff --git a/src/diff/db_clients.rs b/src/diff/db_clients.rs new file mode 100644 index 0000000..21fac0a --- /dev/null +++ b/src/diff/db_clients.rs @@ -0,0 +1,25 @@ +use sqlx::{Pool, Postgres}; + +/// This structure is holding 2 Postgres DB pools. +/// These will be used to query both the source and the destination databases. 
+pub struct DBClients { + first_db_pool: Pool, + second_db_pool: Pool, +} + +impl DBClients { + pub fn new(first_db_pool: Pool, second_db_pool: Pool) -> Self { + Self { + first_db_pool, + second_db_pool, + } + } + + pub fn first_db_pool(&self) -> Pool { + self.first_db_pool.clone() + } + + pub fn second_db_pool(&self) -> Pool { + self.second_db_pool.clone() + } +} diff --git a/src/diff/diff_ops.rs b/src/diff/diff_ops.rs new file mode 100644 index 0000000..046a6f5 --- /dev/null +++ b/src/diff/diff_ops.rs @@ -0,0 +1,131 @@ +use crate::diff::db_clients::DBClients; +use anyhow::Result; +use colored::Colorize; + +use crate::diff::diff_output::DiffOutput; +use log::info; +use sqlx::postgres::PgPoolOptions; +use sqlx::Executor; + +use crate::diff::diff_payload::DiffPayload; +use crate::diff::sequence::query::sequence_query_executor::{ + SequenceDualSourceQueryExecutorImpl, SequenceSingleSourceQueryExecutorImpl, +}; + +use crate::diff::sequence::sequence_differ::SequenceDiffer; +use crate::diff::table::query::table_query_executor::{ + TableDualSourceQueryExecutorImpl, TableSingleSourceQueryExecutorImpl, +}; + +use crate::diff::table::table_differ::TableDiffer; + +/// The `Differ` struct represents a database differ. +/// +/// It provides a method `diff_dbs` that performs the diffing operation between two databases. 
+pub struct Differ; + +impl Differ { + pub async fn diff_dbs(diff_payload: DiffPayload) -> Result> { + info!("{}", "Initiating DB diffing…".bold().blue()); + + let first_db_pool = PgPoolOptions::new() + .after_connect(|conn, _meta| { + Box::pin(async move { + conn.execute("SET application_name = 'rust-pgdatadiff';") + .await?; + Ok(()) + }) + }) + .max_connections(diff_payload.max_connections()) + .connect(diff_payload.first_db()) + .await + .expect("Failed to connect to first DB"); + + info!("{}", "Connected to first DB".magenta().bold()); + + let second_db_pool = PgPoolOptions::new() + .after_connect(|conn, _meta| { + Box::pin(async move { + conn.execute("SET application_name = 'rust-pgdatadiff';") + .await?; + Ok(()) + }) + }) + .max_connections(diff_payload.max_connections()) + .connect(diff_payload.second_db()) + .await + .expect("Failed to connect to second DB"); + + info!("{}", "Connected to second DB".magenta().bold()); + + let db_clients = DBClients::new(first_db_pool, second_db_pool); + + info!("{}", "Going for diff…".green().bold()); + + // Create a single source query executor for tables + let single_table_query_executor = + TableSingleSourceQueryExecutorImpl::new(db_clients.first_db_pool()); + + // Create a dual source query executor for tables + let dual_source_table_query_executor = TableDualSourceQueryExecutorImpl::new( + db_clients.first_db_pool(), + db_clients.second_db_pool(), + ); + + // Create a table differ + let table_differ = TableDiffer::new( + single_table_query_executor, + dual_source_table_query_executor, + ); + + // Create a single source query executor for sequences + let single_sequence_query_executor = + SequenceSingleSourceQueryExecutorImpl::new(db_clients.first_db_pool()); + + // Create a dual source query executor for sequences + let dual_source_sequence_query_executor = SequenceDualSourceQueryExecutorImpl::new( + db_clients.first_db_pool(), + db_clients.second_db_pool(), + ); + + // Create a sequence differ + let sequence_differ 
= SequenceDiffer::new( + single_sequence_query_executor, + dual_source_sequence_query_executor, + ); + + // Prepare diff output + let diff_output = if diff_payload.only_tables() { + // Load only tables diff + let original_table_diff = table_differ.diff_all_table_data(&diff_payload).await?; + original_table_diff.into_iter().collect::>() + } else if diff_payload.only_sequences() { + // Load only sequences diff + let original_sequence_diff = sequence_differ + .diff_all_sequences(diff_payload.schema_name().into()) + .await?; + original_sequence_diff + .into_iter() + .collect::>() + } else { + // Load both tables and sequences diff + let original_sequence_diff = + sequence_differ.diff_all_sequences(diff_payload.schema_name().into()); + + let original_table_diff = table_differ.diff_all_table_data(&diff_payload); + + let (table_diff, sequence_diff) = + futures::future::join(original_table_diff, original_sequence_diff).await; + + let table_diff: Vec = table_diff.unwrap(); + let sequence_diff: Vec = sequence_diff.unwrap(); + + table_diff + .into_iter() + .chain(sequence_diff.into_iter()) + .collect::>() + }; + + Ok(diff_output) + } +} diff --git a/src/diff/diff_output.rs b/src/diff/diff_output.rs new file mode 100644 index 0000000..5de708b --- /dev/null +++ b/src/diff/diff_output.rs @@ -0,0 +1,10 @@ +use crate::diff::sequence::query::output::SequenceDiffOutput; +use crate::diff::table::query::output::TableDiffOutput; + +/// The output of a diff operation. +/// This is used in order to have a common format for +/// both table and sequence diff outputs. +pub enum DiffOutput { + TableDiff(TableDiffOutput), + SequenceDiff(SequenceDiffOutput), +} diff --git a/src/diff/diff_payload.rs b/src/diff/diff_payload.rs new file mode 100644 index 0000000..701d33f --- /dev/null +++ b/src/diff/diff_payload.rs @@ -0,0 +1,120 @@ +/// Represents a payload for performing database diffs. 
+pub struct DiffPayload { + first_db: String, + second_db: String, + only_tables: bool, + only_sequences: bool, + only_count: bool, + chunk_size: i64, + max_connections: i64, + include_tables: Vec, + exclude_tables: Vec, + schema_name: String, +} + +impl DiffPayload { + /// Creates a new `DiffPayload` instance. + /// + /// # Arguments + /// + /// * `first_db` - The name of the first database. + /// * `second_db` - The name of the second database. + /// * `only_data` - A flag indicating whether to compare only data. + /// * `only_sequences` - A flag indicating whether to compare only sequences. + /// * `count_only` - A flag indicating whether to count differences only. + /// * `chunk_size` - The chunk size for processing large tables. + /// * `max_connections` - The maximum number of database connections to use. + /// * `include_tables` - A list of tables to include in the comparison. + /// * `exclude_tables` - A list of tables to exclude in the comparison. + /// * `schema_name` - The name of the schema to compare. + /// + /// # Returns + /// + /// A new `DiffPayload` instance. 
+ #[allow(clippy::too_many_arguments)] + pub fn new( + first_db: impl Into, + second_db: impl Into, + only_tables: bool, + only_sequences: bool, + only_count: bool, + chunk_size: i64, + max_connections: i64, + include_tables: Vec>, + exclude_tables: Vec>, + schema_name: impl Into, + ) -> Self { + let has_included_tables = !include_tables.is_empty(); + let has_excluded_tables = !exclude_tables.is_empty(); + + if has_included_tables && has_excluded_tables { + panic!("Cannot include and exclude tables at the same time"); + } + + Self { + first_db: first_db.into(), + second_db: second_db.into(), + only_tables, + only_sequences, + only_count, + chunk_size, + max_connections, + include_tables: include_tables.into_iter().map(|t| t.into()).collect(), + exclude_tables: exclude_tables.into_iter().map(|t| t.into()).collect(), + schema_name: schema_name.into(), + } + } + + pub fn first_db(&self) -> &str { + &self.first_db + } + pub fn second_db(&self) -> &str { + &self.second_db + } + pub fn only_tables(&self) -> bool { + self.only_tables + } + pub fn only_sequences(&self) -> bool { + self.only_sequences + } + pub fn only_count(&self) -> bool { + self.only_count + } + pub fn chunk_size(&self) -> i64 { + self.chunk_size + } + pub fn max_connections(&self) -> u32 { + self.max_connections as u32 + } + pub fn included_tables(&self) -> &Vec { + &self.include_tables + } + pub fn excluded_tables(&self) -> &Vec { + &self.exclude_tables + } + pub fn schema_name(&self) -> &str { + &self.schema_name + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + #[should_panic = "Cannot include and exclude tables at the same time"] + fn test_new_diff_payload() { + _ = DiffPayload::new( + "first_db", + "second_db", + false, + false, + false, + 10000, + 10, + vec!["table1"], + vec!["table2"], + "schema_name", + ); + } +} diff --git a/src/diff/internal/mod.rs b/src/diff/internal/mod.rs new file mode 100644 index 0000000..e19bdb6 --- /dev/null +++ b/src/diff/internal/mod.rs @@ -0,0 +1 @@ 
+pub(crate) mod tests; diff --git a/src/diff/internal/tests/mod.rs b/src/diff/internal/tests/mod.rs new file mode 100644 index 0000000..9afecd7 --- /dev/null +++ b/src/diff/internal/tests/mod.rs @@ -0,0 +1,8 @@ +#[cfg(test)] +pub(crate) fn sanitize_raw_string(raw: impl Into) -> String { + raw.into() + .split_ascii_whitespace() + .map(|e| e.to_string()) + .reduce(|acc, s| format!("{acc} {s}")) + .unwrap() +} diff --git a/src/diff/mod.rs b/src/diff/mod.rs new file mode 100644 index 0000000..bb5b893 --- /dev/null +++ b/src/diff/mod.rs @@ -0,0 +1,9 @@ +pub(crate) mod db_clients; +pub mod diff_ops; +pub mod diff_output; +pub mod diff_payload; +#[cfg(test)] +mod internal; +pub mod sequence; +pub mod table; +pub mod types; diff --git a/src/diff/sequence/mod.rs b/src/diff/sequence/mod.rs new file mode 100644 index 0000000..7e7a69a --- /dev/null +++ b/src/diff/sequence/mod.rs @@ -0,0 +1,4 @@ +pub mod query; +pub mod sequence_differ; +#[cfg(test)] +mod sequence_differ_tests; diff --git a/src/diff/sequence/query/input/mod.rs b/src/diff/sequence/query/input/mod.rs new file mode 100644 index 0000000..c017115 --- /dev/null +++ b/src/diff/sequence/query/input/mod.rs @@ -0,0 +1,65 @@ +use crate::diff::sequence::query::sequence_types::SequenceName; +use crate::diff::types::SchemaName; + +/// Represents the input for querying the sequence names for a schema. +pub struct QueryAllSequencesInput(SchemaName); + +impl QueryAllSequencesInput { + /// Creates a new `QueryAllSequencesInput` with the given schema name. + /// + /// # Arguments + /// + /// * `schema_name` - The name of the schema to query. + /// + /// # Returns + /// + /// A new `QueryAllSequencesInput` instance. + pub fn new(schema_name: SchemaName) -> Self { + Self(schema_name) + } + + /// Returns the schema name. + /// + /// # Returns + /// + /// A reference to the schema name. + pub fn schema_name(self) -> SchemaName { + self.0 + } +} + +/// Represents the input for querying the last values of a sequence. 
+pub struct QueryLastValuesInput(SchemaName, SequenceName); + +impl QueryLastValuesInput { + /// Creates a new `QueryLastValuesInput` with the given sequence name. + /// + /// # Arguments + /// + /// * `sequence_name` - The name of the sequence to query. + /// + /// # Returns + /// + /// A new `QueryLastValuesInput` instance. + pub fn new(schema_name: SchemaName, sequence_name: SequenceName) -> Self { + Self(schema_name, sequence_name) + } + + /// Returns the schema name. + /// + /// # Returns + /// + /// A reference to the schema name. + pub fn schema_name(&self) -> &SchemaName { + &self.0 + } + + /// Returns the sequence name. + /// + /// # Returns + /// + /// A reference to the sequence name. + pub fn sequence_name(&self) -> &SequenceName { + &self.1 + } +} diff --git a/src/diff/sequence/query/mod.rs b/src/diff/sequence/query/mod.rs new file mode 100644 index 0000000..add8727 --- /dev/null +++ b/src/diff/sequence/query/mod.rs @@ -0,0 +1,5 @@ +pub mod input; +pub mod output; +pub mod sequence_query; +pub mod sequence_query_executor; +pub mod sequence_types; diff --git a/src/diff/sequence/query/output/mod.rs b/src/diff/sequence/query/output/mod.rs new file mode 100644 index 0000000..a65ffc1 --- /dev/null +++ b/src/diff/sequence/query/output/mod.rs @@ -0,0 +1,88 @@ +use crate::diff::diff_output::DiffOutput; +use crate::diff::types::DiffOutputMarker; +use colored::{ColoredString, Colorize}; +use std::fmt::Display; + +/// Represents the source of a sequence. +#[derive(Clone)] +pub enum SequenceSource { + First, + Second, +} + +impl Display for SequenceSource { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::First => write!(f, "first"), + Self::Second => write!(f, "second"), + } + } +} + +/// Represents the difference in count between two sequences. +#[derive(Clone)] +pub struct SequenceCountDiff(i64, i64); + +impl SequenceCountDiff { + /// Creates a new `SequenceCountDiff` instance with the given counts. 
+ pub fn new(first: i64, second: i64) -> Self { + Self(first, second) + } + + /// Returns the count of the first sequence. + pub fn first(&self) -> i64 { + self.0 + } + + /// Returns the count of the second sequence. + pub fn second(&self) -> i64 { + self.1 + } +} + +#[derive(Clone)] +/// Represents the output of a sequence difference. +pub enum SequenceDiffOutput { + /// Indicates that there is no difference between the sequences. + NoDiff(String), + /// Indicates that a sequence does not exist in a specific source. + NotExists(String, SequenceSource), + /// Indicates a difference in count between the sequences. + Diff(String, SequenceCountDiff), +} + +impl SequenceDiffOutput { + /// Converts the `SequenceDiffOutput` to a colored string representation. + pub fn to_string(&self) -> ColoredString { + match self { + Self::NoDiff(sequence) => format!("{} - No difference\n", sequence).green().bold(), + Self::NotExists(sequence, source) => { + format!("{} - Does not exist in {}\n", sequence, source) + .red() + .bold() + .underline() + } + Self::Diff(sequence, diffs) => format!( + "Difference in sequence:{} - First: {}, Second: {}\n", + sequence, + diffs.first(), + diffs.second() + ) + .red() + .bold() + .underline(), + } + } +} + +impl DiffOutputMarker for SequenceDiffOutput { + fn convert(self) -> DiffOutput { + DiffOutput::SequenceDiff(self.clone()) + } +} + +impl From for DiffOutput { + fn from(val: SequenceDiffOutput) -> Self { + DiffOutput::SequenceDiff(val) + } +} diff --git a/src/diff/sequence/query/sequence_query.rs b/src/diff/sequence/query/sequence_query.rs new file mode 100644 index 0000000..3e90347 --- /dev/null +++ b/src/diff/sequence/query/sequence_query.rs @@ -0,0 +1,73 @@ +use crate::diff::sequence::query::sequence_types::SequenceName; +use crate::diff::types::SchemaName; +use std::fmt::{Display, Formatter}; + +/// Represents a query for retrieving information about sequences. 
+pub enum SequenceQuery { + /// Retrieves the last value of a specific sequence. + LastValue(SchemaName, SequenceName), + /// Retrieves all sequences in the database. + AllSequences(SchemaName), +} + +impl Display for SequenceQuery { + /// Formats the `SequenceQuery` as a string. + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Self::LastValue(schema_name, sequence_name) => { + write!( + f, + "SELECT last_value FROM {}.{};", + schema_name.name(), + sequence_name.name() + ) + } + SequenceQuery::AllSequences(schema_name) => { + write!( + f, + r#" + SELECT sequence_name + FROM information_schema.sequences + WHERE sequence_schema = '{}'; + "#, + schema_name.name() + ) + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::diff::internal::tests::sanitize_raw_string; + + impl From for String { + fn from(value: SequenceQuery) -> Self { + value.to_string() + } + } + + #[test] + fn test_sequence_last_value_query() { + let schema_name = SchemaName::new("test_schema"); + let sequence_name = SequenceName::new("test_sequence"); + let last_value_query = SequenceQuery::LastValue(schema_name, sequence_name); + + assert_eq!( + last_value_query.to_string(), + "SELECT last_value FROM test_schema.test_sequence;" + ); + } + + #[test] + fn test_all_sequences_query() { + let schema_name = SchemaName::new("test_schema"); + let all_sequences_query = SequenceQuery::AllSequences(schema_name); + + assert_eq!( + sanitize_raw_string(all_sequences_query), + "SELECT sequence_name FROM information_schema.sequences WHERE sequence_schema = 'test_schema';" + ); + } +} diff --git a/src/diff/sequence/query/sequence_query_executor.rs b/src/diff/sequence/query/sequence_query_executor.rs new file mode 100644 index 0000000..db84d0e --- /dev/null +++ b/src/diff/sequence/query/sequence_query_executor.rs @@ -0,0 +1,175 @@ +/// This module contains the implementation of query executors for sequence-related operations. 
+/// It provides traits and structs for executing queries on a single data source and on dual data sources. +/// The single data source executor is responsible for querying sequence names. +/// The dual data source executor is responsible for querying sequence last values. +/// Both executors use the `sqlx` crate for interacting with the database. +/// +/// # Examples +/// +/// ```no_run +/// use sqlx::postgres::PgPool; +/// use rust_pgdatadiff::diff::sequence::query::sequence_query_executor::SequenceSingleSourceQueryExecutorImpl; +/// use rust_pgdatadiff::diff::sequence::query::sequence_query_executor::SequenceSingleSourceQueryExecutor; +/// use rust_pgdatadiff::diff::sequence::query::input::QueryAllSequencesInput; +/// use rust_pgdatadiff::diff::types::SchemaName; +/// use rust_pgdatadiff::diff::sequence::query::sequence_query_executor::SequenceDualSourceQueryExecutorImpl; +/// use rust_pgdatadiff::diff::sequence::query::sequence_query_executor::SequenceDualSourceQueryExecutor; +/// use rust_pgdatadiff::diff::sequence::query::sequence_types::SequenceName; +/// use rust_pgdatadiff::diff::sequence::query::input::QueryLastValuesInput; +/// +/// #[tokio::main] +/// async fn main() { +/// +/// let db_client: PgPool = PgPool::connect("postgres://user:password@localhost:5432/database") +/// .await +/// .unwrap(); +/// +/// // Create a single data source executor +/// let single_source_executor = SequenceSingleSourceQueryExecutorImpl::new(db_client); +/// +/// // Query sequence names +/// let schema_name = SchemaName::new("public".to_string()); +/// let table_names = single_source_executor +/// .query_sequence_names(QueryAllSequencesInput::new(schema_name)) +/// .await; +/// +/// // Create a dual data source executor +/// let first_db_client: PgPool = PgPool::connect("postgres://user:password@localhost:5432/database1") +/// .await +/// .unwrap(); +/// let second_db_client: PgPool = PgPool::connect("postgres://user:password@localhost:5432/database2") +/// .await +/// 
.unwrap(); +/// let dual_source_executor = SequenceDualSourceQueryExecutorImpl::new(first_db_client, second_db_client); +/// +/// // Query sequence last values +/// let sequence_name = SequenceName::new("public".to_string()); +/// let schema_name = SchemaName::new("public".to_string()); +/// let (first_count, second_count) = dual_source_executor +/// .query_sequence_last_values(QueryLastValuesInput::new(schema_name, sequence_name)) +/// .await; +/// } +/// ``` +use crate::diff::sequence::query::input::{QueryAllSequencesInput, QueryLastValuesInput}; +use crate::diff::sequence::query::sequence_query::SequenceQuery; + +use anyhow::Result; +use async_trait::async_trait; +use log::error; +use sqlx::{Pool, Postgres, Row}; + +#[cfg_attr(test, mockall::automock)] +#[async_trait] +pub trait SequenceSingleSourceQueryExecutor { + /// Queries the sequence names from the database. + /// + /// # Returns + /// + /// A vector of sequence names. + async fn query_sequence_names(&self, input: QueryAllSequencesInput) -> Vec; +} + +pub struct SequenceSingleSourceQueryExecutorImpl { + db_pool: Pool, +} + +impl SequenceSingleSourceQueryExecutorImpl { + pub fn new(db_pool: Pool) -> Self { + Self { db_pool } + } +} + +#[async_trait] +impl SequenceSingleSourceQueryExecutor for SequenceSingleSourceQueryExecutorImpl { + async fn query_sequence_names(&self, input: QueryAllSequencesInput) -> Vec { + let pool = &self.db_pool; + + let schema_name = input.schema_name(); + let sequence_query = SequenceQuery::AllSequences(schema_name); + + let query_binding = sequence_query.to_string(); + + sqlx::query(query_binding.as_str()) + .fetch_all(pool) + .await + .unwrap() + .into_iter() + .map(|row| row.try_get("sequence_name").unwrap()) + .collect::>() + } +} + +#[cfg_attr(test, mockall::automock)] +#[async_trait] +pub trait SequenceDualSourceQueryExecutor { + /// Executes a query to retrieve the last value of a sequence. + /// + /// # Arguments + /// + /// * `input` - The input parameters for the query. 
+ /// + /// # Returns + /// + /// A tuple containing the result of the query as a `Result`. + async fn query_sequence_last_values( + &self, + input: QueryLastValuesInput, + ) -> (Result, Result); +} + +pub struct SequenceDualSourceQueryExecutorImpl { + first_db_pool: Pool, + second_db_pool: Pool, +} + +impl SequenceDualSourceQueryExecutorImpl { + pub fn new(first_db_pool: Pool, second_db_pool: Pool) -> Self { + Self { + first_db_pool, + second_db_pool, + } + } +} + +#[async_trait] +impl SequenceDualSourceQueryExecutor for SequenceDualSourceQueryExecutorImpl { + async fn query_sequence_last_values( + &self, + input: QueryLastValuesInput, + ) -> (Result, Result) { + let first_pool = &self.first_db_pool; + let second_pool = &self.second_db_pool; + + let sequence_query = SequenceQuery::LastValue( + input.schema_name().to_owned(), + input.sequence_name().to_owned(), + ); + + let query_binding = sequence_query.to_string(); + + let first_result = sqlx::query(query_binding.as_str()).fetch_one(first_pool); + + let second_result = sqlx::query(query_binding.as_str()).fetch_one(second_pool); + + let (first_result, second_result) = + futures::future::join(first_result, second_result).await; + + let first_count: Result = match first_result { + Ok(pg_row) => Ok(pg_row.try_get::("last_value").unwrap()), + Err(e) => { + error!("Error while fetching first sequence: {}", e); + Err(anyhow::anyhow!("Failed to fetch count for first sequence")) + } + }; + + let second_count: Result = match second_result { + Ok(pg_row) => Ok(pg_row.try_get::("last_value").unwrap()), + Err(e) => { + error!("Error while fetching second sequence: {}", e); + Err(anyhow::anyhow!("Failed to fetch count for second sequence")) + } + }; + + (first_count, second_count) + } +} diff --git a/src/diff/sequence/query/sequence_types.rs b/src/diff/sequence/query/sequence_types.rs new file mode 100644 index 0000000..b2f7ce5 --- /dev/null +++ b/src/diff/sequence/query/sequence_types.rs @@ -0,0 +1,12 @@ +#[derive(Clone)] 
+pub struct SequenceName(String); + +impl SequenceName { + pub fn new(name: impl Into) -> Self { + Self(name.into()) + } + + pub fn name(&self) -> String { + self.0.to_string() + } +} diff --git a/src/diff/sequence/sequence_differ.rs b/src/diff/sequence/sequence_differ.rs new file mode 100644 index 0000000..f62b94e --- /dev/null +++ b/src/diff/sequence/sequence_differ.rs @@ -0,0 +1,132 @@ +use anyhow::Result; +use colored::Colorize; + +use log::{debug, info}; + +use crate::diff::diff_output::DiffOutput; +use crate::diff::sequence::query::input::{QueryAllSequencesInput, QueryLastValuesInput}; +use crate::diff::sequence::query::output::{SequenceCountDiff, SequenceDiffOutput, SequenceSource}; +use tokio::time::Instant; + +use crate::diff::sequence::query::sequence_query_executor::{ + SequenceDualSourceQueryExecutor, SequenceSingleSourceQueryExecutor, +}; +use crate::diff::sequence::query::sequence_types::SequenceName; +use crate::diff::types::SchemaName; + +pub struct SequenceDiffer< + SQE: SequenceSingleSourceQueryExecutor, + DSQE: SequenceDualSourceQueryExecutor, +> { + single_sequence_query_executor: SQE, + dual_sequence_query_executor: DSQE, +} + +impl + SequenceDiffer +{ + pub fn new(single_sequence_query_executor: SQE, dual_sequence_query_executor: DSQE) -> Self { + Self { + single_sequence_query_executor, + dual_sequence_query_executor, + } + } + + pub async fn diff_all_sequences(&self, schema_name: String) -> Result> { + info!("{}", "Starting sequence analysis…".bold().yellow()); + let mut sequences = self.get_all_sequences(schema_name.to_owned()).await?; + + sequences.sort_by_key(|s| s.to_lowercase()); + + let sorted_sequences = sequences.to_owned(); + + let futures = sorted_sequences.iter().map(|sequence_name| async { + let start = Instant::now(); + + let schema_name = SchemaName::new(schema_name.to_owned()); + let sequence_name = SequenceName::new(sequence_name.to_owned()); + let input = QueryLastValuesInput::new(schema_name, sequence_name.to_owned()); + 
let (first_result, second_result) = self + .dual_sequence_query_executor + .query_sequence_last_values(input) + .await; + + debug!( + "{}", + format!("Analyzing sequence: {}", &sequence_name.name()) + .yellow() + .bold() + ); + + let sequence_diff_result = + Self::extract_result(sequence_name.name(), first_result, second_result); + + let elapsed = start.elapsed(); + debug!( + "{}", + format!("Sequence analysis completed in: {}ms", elapsed.as_millis()) + ); + debug!("##############################################"); + + sequence_diff_result + }); + + info!( + "{}", + "Waiting for total sequence analysis to complete…" + .yellow() + .bold() + ); + let start = Instant::now(); + let sequences_analysed = futures::future::join_all(futures).await; + let elapsed = start.elapsed(); + debug!( + "{}", + format!( + "Total sequence analysis completed in: {}ms", + elapsed.as_millis() + ) + .yellow() + .bold(), + ); + + for sequence_diff_result in &sequences_analysed { + info!("{}", sequence_diff_result.to_string()); + } + + Ok(sequences_analysed + .into_iter() + .map(|diff| diff.into()) + .collect()) + } + + pub async fn get_all_sequences(&self, schema_name: String) -> Result> { + let input = QueryAllSequencesInput::new(SchemaName::new(schema_name)); + let query_result = self + .single_sequence_query_executor + .query_sequence_names(input) + .await; + Ok(query_result) + } + + fn extract_result( + sequence_name: String, + first_result: Result, + second_result: Result, + ) -> SequenceDiffOutput { + match (first_result, second_result) { + (Ok(first_value), Ok(second_value)) => { + if first_value != second_value { + SequenceDiffOutput::Diff( + sequence_name, + SequenceCountDiff::new(first_value, second_value), + ) + } else { + SequenceDiffOutput::NoDiff(sequence_name) + } + } + (Err(_e), _) => SequenceDiffOutput::NotExists(sequence_name, SequenceSource::First), + (_, Err(_e)) => SequenceDiffOutput::NotExists(sequence_name, SequenceSource::Second), + } + } +} diff --git 
a/src/diff/sequence/sequence_differ_tests.rs b/src/diff/sequence/sequence_differ_tests.rs new file mode 100644 index 0000000..637c939 --- /dev/null +++ b/src/diff/sequence/sequence_differ_tests.rs @@ -0,0 +1,71 @@ +#[cfg(test)] +mod tests { + use crate::diff::diff_output::DiffOutput; + use crate::diff::sequence::query::output::SequenceDiffOutput; + use crate::diff::sequence::query::sequence_query_executor::{ + MockSequenceDualSourceQueryExecutor, MockSequenceSingleSourceQueryExecutor, + }; + use crate::diff::sequence::sequence_differ::SequenceDiffer; + + #[tokio::test] + async fn test_get_all_sequences() { + let mut single_source_query_executor = MockSequenceSingleSourceQueryExecutor::new(); + let dual_source_query_executor = MockSequenceDualSourceQueryExecutor::new(); + + single_source_query_executor + .expect_query_sequence_names() + .times(1) + .returning(|_| vec!["sequence1".to_string(), "sequence2".to_string()]); + + let sequence_differ = + SequenceDiffer::new(single_source_query_executor, dual_source_query_executor); + + let sequences = sequence_differ + .get_all_sequences("public".to_string()) + .await + .unwrap(); + + assert_eq!(sequences.len(), 2); + assert_eq!(sequences[0], "sequence1"); + assert_eq!(sequences[1], "sequence2"); + } + + #[tokio::test] + async fn test_diff_all_sequences() { + let mut single_source_query_executor = MockSequenceSingleSourceQueryExecutor::new(); + let mut dual_source_query_executor = MockSequenceDualSourceQueryExecutor::new(); + + single_source_query_executor + .expect_query_sequence_names() + .times(1) + .returning(|_| vec!["sequence1".to_string()]); + + dual_source_query_executor + .expect_query_sequence_last_values() + .times(1) + .returning(|_| (Ok(2), Ok(1))); + + let sequence_differ = + SequenceDiffer::new(single_source_query_executor, dual_source_query_executor); + + let sequences = sequence_differ + .diff_all_sequences("public".to_string()) + .await + .unwrap(); + let actual = sequences.first().unwrap(); + + 
assert_eq!(sequences.len(), 1); + assert!(matches!(actual, DiffOutput::SequenceDiff(_))); + match actual { + DiffOutput::SequenceDiff(sequence_diff_output) => match sequence_diff_output { + SequenceDiffOutput::Diff(sequence_name, sequence_count_diff) => { + assert_eq!("sequence1", sequence_name); + assert_eq!(1, sequence_count_diff.second()); + assert_eq!(2, sequence_count_diff.first()); + } + _ => panic!("Expected Diff"), + }, + _ => panic!("Expected SequenceDiff"), + } + } +} diff --git a/src/diff/table/mod.rs b/src/diff/table/mod.rs new file mode 100644 index 0000000..3992295 --- /dev/null +++ b/src/diff/table/mod.rs @@ -0,0 +1,5 @@ +pub mod query; +pub mod table_differ; + +#[cfg(test)] +mod table_differ_tests; diff --git a/src/diff/table/query/input/mod.rs b/src/diff/table/query/input/mod.rs new file mode 100644 index 0000000..3a5533d --- /dev/null +++ b/src/diff/table/query/input/mod.rs @@ -0,0 +1,123 @@ +use super::table_types::{TableName, TableOffset, TablePosition, TablePrimaryKeys}; +use crate::diff::types::SchemaName; + +/// Represents the input for querying the count of a table. +pub struct QueryTableCountInput { + schema_name: SchemaName, + table_name: TableName, +} + +impl QueryTableCountInput { + /// Creates a new `QueryTableCountInput` instance. + pub fn new(schema_name: SchemaName, table_name: TableName) -> Self { + Self { + schema_name, + table_name, + } + } + + pub fn schema_name(&self) -> &SchemaName { + &self.schema_name + } + + pub fn table_name(&self) -> &TableName { + &self.table_name + } +} + +/// Represents the input for querying table names. +pub struct QueryTableNamesInput { + schema_name: SchemaName, + included_tables: Vec, + excluded_tables: Vec, +} + +impl QueryTableNamesInput { + /// Creates a new `QueryTableNamesInput` instance. 
+ pub fn new( + schema_name: SchemaName, + included_tables: Vec>, + excluded_tables: Vec>, + ) -> Self { + Self { + schema_name, + included_tables: included_tables.into_iter().map(|t| t.into()).collect(), + excluded_tables: excluded_tables.into_iter().map(|t| t.into()).collect(), + } + } + + pub fn schema_name(&self) -> &SchemaName { + &self.schema_name + } + + pub fn included_tables(&self) -> Vec { + self.included_tables.to_vec() + } + + pub fn excluded_tables(&self) -> Vec { + self.excluded_tables.to_vec() + } +} + +/// Represents the input for querying hash data. +pub struct QueryHashDataInput { + schema_name: SchemaName, + table_name: TableName, + primary_keys: TablePrimaryKeys, + position: TablePosition, + offset: TableOffset, +} + +impl QueryHashDataInput { + /// Creates a new `QueryHashDataInput` instance. + pub fn new( + schema_name: SchemaName, + table_name: TableName, + primary_keys: TablePrimaryKeys, + position: TablePosition, + offset: TableOffset, + ) -> Self { + Self { + schema_name, + table_name, + primary_keys, + position, + offset, + } + } + + pub fn schema_name(&self) -> SchemaName { + self.schema_name.clone() + } + + pub fn table_name(&self) -> TableName { + self.table_name.clone() + } + + pub fn primary_keys(&self) -> TablePrimaryKeys { + self.primary_keys.clone() + } + + pub fn position(&self) -> TablePosition { + self.position.clone() + } + + pub fn offset(&self) -> TableOffset { + self.offset.clone() + } +} + +/// Represents the input for querying primary keys. 
+pub struct QueryPrimaryKeysInput { + table_name: String, +} + +impl QueryPrimaryKeysInput { + pub fn new(table_name: String) -> Self { + Self { table_name } + } + + pub fn table_name(&self) -> String { + self.table_name.to_string() + } +} diff --git a/src/diff/table/query/mod.rs b/src/diff/table/query/mod.rs new file mode 100644 index 0000000..aac5b85 --- /dev/null +++ b/src/diff/table/query/mod.rs @@ -0,0 +1,5 @@ +pub mod input; +pub mod output; +pub mod table_query; +pub mod table_query_executor; +pub mod table_types; diff --git a/src/diff/table/query/output/mod.rs b/src/diff/table/query/output/mod.rs new file mode 100644 index 0000000..090ca3e --- /dev/null +++ b/src/diff/table/query/output/mod.rs @@ -0,0 +1,150 @@ +use colored::{ColoredString, Colorize}; +use std::fmt::Display; + +use crate::diff::diff_output::DiffOutput; +use crate::diff::types::DiffOutputMarker; +use std::time::Duration; + +/// Represents the source of a table (either the first or the second). +#[cfg_attr(test, derive(PartialEq))] +#[derive(Debug, Clone)] +pub enum TableSource { + First, + Second, +} + +impl Display for TableSource { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::First => write!(f, "first"), + Self::Second => write!(f, "second"), + } + } +} + +/// Represents the difference in table counts between two tables. +#[cfg_attr(test, derive(PartialEq))] +#[derive(Debug, Clone)] +pub struct TableCountDiff(i64, i64); + +impl TableCountDiff { + /// Creates a new `TableCountDiff` instance with the given counts. + pub fn new(first: i64, second: i64) -> Self { + Self(first, second) + } + + pub fn first(&self) -> i64 { + self.0 + } + + pub fn second(&self) -> i64 { + self.1 + } +} + +/// Represents the output of a table difference. +#[cfg_attr(test, derive(PartialEq))] +#[derive(Debug, Clone)] +pub enum TableDiffOutput { + /// Indicates that there is no difference between the tables. 
+ NoCountDiff(String, i64), + /// Indicates that there is no difference between the tables, along with the duration of the comparison. + NoDiffWithDuration(String, Duration), + /// Indicates that the table does not exist in a specific source. + NotExists(String, TableSource), + /// Indicates a difference in table counts. + Diff(String, TableCountDiff), + /// Indicates that no primary key was found in the table. + NoPrimaryKeyFound(String), + /// Indicates a difference in table data, along with the duration of the comparison. + DataDiffWithDuration(String, i64, i64, Duration), +} + +impl TableDiffOutput { + /// Determines whether the table difference should be skipped. + pub fn skip_table_diff(&self) -> bool { + matches!(self, Self::Diff(_, _) | Self::NotExists(_, _)) + } + + /// Converts the table difference output to a colored string. + pub fn to_string(&self) -> ColoredString { + match self { + Self::NoCountDiff(table, count) => { + format!("{} - No difference. Total rows: {}", table, count) + .green() + .bold() + } + Self::NotExists(table, source) => format!("{} - Does not exist in {}", table, source) + .red() + .bold() + .underline(), + Self::Diff(table, diffs) => format!( + "{} - First table rows: {}, Second table rows: {}", + table, + diffs.first(), + diffs.second() + ) + .red() + .bold(), + TableDiffOutput::NoPrimaryKeyFound(table) => { + format!("{} - No primary key found", table).red().bold() + } + TableDiffOutput::NoDiffWithDuration(table, duration) => { + format!("{} - No difference in {}ms", table, duration.as_millis()) + .green() + .bold() + } + TableDiffOutput::DataDiffWithDuration(table_name, position, offset, duration) => { + format!( + "{} - Data diff between rows [{},{}] - in {}ms", + table_name, + position, + offset, + duration.as_millis() + ) + .red() + .bold() + } + } + } +} + +impl DiffOutputMarker for TableDiffOutput { + fn convert(self) -> DiffOutput { + DiffOutput::TableDiff(self.clone()) + } +} + +impl From for DiffOutput { + fn from(val: 
TableDiffOutput) -> Self { + DiffOutput::TableDiff(val) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_skip_table_when_needed() { + let no_count_diff = TableDiffOutput::NoCountDiff("test".to_string(), 1000); + let not_exists = TableDiffOutput::NotExists("test".to_string(), TableSource::First); + let diff = TableDiffOutput::Diff("test".to_string(), TableCountDiff::new(1, 2)); + let no_primary_key = TableDiffOutput::NoPrimaryKeyFound("test".to_string()); + let no_diff_with_duration = + TableDiffOutput::NoDiffWithDuration("test".to_string(), Duration::from_millis(1)); + let data_diff_with_duration = TableDiffOutput::DataDiffWithDuration( + "test".to_string(), + 1, + 2, + Duration::from_millis(1), + ); + + assert!(not_exists.skip_table_diff()); + assert!(diff.skip_table_diff()); + assert!(!no_count_diff.skip_table_diff()); + assert!(!no_primary_key.skip_table_diff()); + assert!(!no_diff_with_duration.skip_table_diff()); + assert!(!data_diff_with_duration.skip_table_diff()); + } +} diff --git a/src/diff/table/query/table_query.rs b/src/diff/table/query/table_query.rs new file mode 100644 index 0000000..14e2964 --- /dev/null +++ b/src/diff/table/query/table_query.rs @@ -0,0 +1,179 @@ +use crate::diff::table::query::table_types::{ + IncludedExcludedTables, TableMode, TableName, TableOffset, TablePosition, TablePrimaryKeys, +}; +use crate::diff::types::SchemaName; +use std::fmt::Display; + +pub enum TableQuery { + AllTablesForSchema(SchemaName, IncludedExcludedTables), + CountRowsForTable(SchemaName, TableName), + FindPrimaryKeyForTable(TableName), + HashQuery( + SchemaName, + TableName, + TablePrimaryKeys, + TablePosition, + TableOffset, + ), +} + +impl Display for TableQuery { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::AllTablesForSchema(schema_name, included_excluded_tables) => { + let inclusion_exclusion_statement = match included_excluded_tables.table_mode() { + None => "".to_string(), + 
Some(table_mode) => match table_mode { + TableMode::Include => included_excluded_tables.inclusion_statement(), + TableMode::Exclude => included_excluded_tables.exclusion_statement(), + }, + }; + + write!( + f, + r#" + SELECT table_name + FROM information_schema.tables + WHERE table_schema = '{}' + {} + "#, + schema_name.name(), + inclusion_exclusion_statement + ) + } + // https://stackoverflow.com/questions/7943233/fast-way-to-discover-the-row-count-of-a-table-in-postgresql + TableQuery::CountRowsForTable(schema_name, table_name) => { + write!( + f, + "SELECT count(*) FROM {}.{}", + schema_name.name(), + table_name.name() + ) + } + TableQuery::FindPrimaryKeyForTable(table_name) => write!( + f, + // language=postgresql + r#" + SELECT a.attname + FROM pg_index i + JOIN pg_attribute a ON a.attrelid = i.indrelid + AND a.attnum = ANY(i.indkey) + WHERE i.indrelid = '{}'::regclass + AND i.indisprimary"#, + table_name.name() + ), + TableQuery::HashQuery( + schema_name, + table_name, + table_primary_keys, + table_position, + table_offset, + ) => { + write!( + f, + r#" + SELECT md5(array_agg(md5((t.*)::varchar))::varchar) + FROM ( + SELECT * + FROM {}.{} + ORDER BY {} limit {} offset {} + ) AS t + "#, + schema_name.name(), + table_name.name(), + table_primary_keys.keys(), + table_offset.offset(), + table_position.position(), + ) + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use pretty_assertions::assert_eq; + + #[test] + fn test_display_all_tables_for_schema_with_included_tables() { + let schema_name = SchemaName::new("public"); + let included_tables = vec!["table1".to_string(), "table2".to_string()]; + let excluded_tables: Vec = vec![]; + let included_excluded_tables = + IncludedExcludedTables::new(included_tables, excluded_tables); + let query = TableQuery::AllTablesForSchema(schema_name, included_excluded_tables); + let expected = r#" + SELECT table_name + FROM information_schema.tables + WHERE table_schema = 'public' + AND table_name IN 
('table1','table2') + "#; + assert_eq!(expected, query.to_string()); + } + + #[test] + fn test_display_all_tables_for_schema_with_excluded_tables() { + let schema_name = SchemaName::new("public"); + let included_tables: Vec = vec![]; + let excluded_tables = vec!["table1", "table2"]; + let included_excluded_tables = + IncludedExcludedTables::new(included_tables, excluded_tables); + let query = TableQuery::AllTablesForSchema(schema_name, included_excluded_tables); + let expected = r#" + SELECT table_name + FROM information_schema.tables + WHERE table_schema = 'public' + AND table_name NOT IN ('table1','table2') + "#; + assert_eq!(expected, query.to_string()); + } + + #[test] + fn test_display_count_rows_for_table() { + let schema_name = SchemaName::new("public".to_string()); + let table_name = TableName::new("table1".to_string()); + let query = TableQuery::CountRowsForTable(schema_name, table_name); + let expected = "SELECT count(*) FROM public.table1"; + assert_eq!(expected, query.to_string()); + } + + #[test] + fn test_display_find_primary_key_for_table() { + let table_name = TableName::new("table1".to_string()); + let query = TableQuery::FindPrimaryKeyForTable(table_name); + let expected = r#" + SELECT a.attname + FROM pg_index i + JOIN pg_attribute a ON a.attrelid = i.indrelid + AND a.attnum = ANY(i.indkey) + WHERE i.indrelid = 'table1'::regclass + AND i.indisprimary"#; + assert_eq!(expected, query.to_string()); + } + + #[test] + fn test_display_hash_query() { + let schema_name = SchemaName::new("public".to_string()); + let table_name = TableName::new("table1".to_string()); + let table_primary_keys = TablePrimaryKeys::new("id".to_string()); + let table_position = TablePosition::new(0); + let table_offset = TableOffset::new(100); + let query = TableQuery::HashQuery( + schema_name, + table_name, + table_primary_keys, + table_position, + table_offset, + ); + let expected = r#" + SELECT md5(array_agg(md5((t.*)::varchar))::varchar) + FROM ( + SELECT * + FROM 
public.table1 + ORDER BY id limit 100 offset 0 + ) AS t + "#; + assert_eq!(expected, query.to_string()); + } +} diff --git a/src/diff/table/query/table_query_executor.rs b/src/diff/table/query/table_query_executor.rs new file mode 100644 index 0000000..80a1db0 --- /dev/null +++ b/src/diff/table/query/table_query_executor.rs @@ -0,0 +1,287 @@ +/// This module contains the implementation of query executors for table-related operations. +/// It provides traits and structs for executing queries on a single data source and on dual data sources. +/// The single data source executor is responsible for querying table names and primary keys. +/// The dual data source executor is responsible for querying table counts and hash data. +/// Both executors use the `sqlx` crate for interacting with the database. +/// +/// # Examples +/// +/// ```no_run +/// use sqlx::postgres::PgPool; +/// use rust_pgdatadiff::diff::table::query::table_query_executor::{ +/// TableSingleSourceQueryExecutor, TableSingleSourceQueryExecutorImpl, +/// TableDualSourceQueryExecutor, TableDualSourceQueryExecutorImpl, +/// }; +/// use rust_pgdatadiff::diff::table::query::input::{QueryHashDataInput, QueryPrimaryKeysInput, QueryTableCountInput, QueryTableNamesInput};/// +/// use rust_pgdatadiff::diff::table::query::table_types::{TableName, TableOffset, TablePosition, TablePrimaryKeys}; +/// use rust_pgdatadiff::diff::types::SchemaName; +/// +/// #[tokio::main] +/// async fn main() { +/// // Create a single data source executor +/// let db_client: PgPool = PgPool::connect("postgres://user:password@localhost:5432/database") +/// .await +/// .unwrap(); +/// let single_source_executor = TableSingleSourceQueryExecutorImpl::new(db_client); +/// +/// // Query table names +/// let schema_name = SchemaName::new("public".to_string()); +/// let included_tables = vec!["table1", "table2"]; +/// let excluded_tables: Vec = vec![]; +/// let table_names = single_source_executor +/// 
.query_table_names(QueryTableNamesInput::new(schema_name, included_tables, excluded_tables)) +/// .await; +/// +/// // Query primary keys +/// let primary_keys = single_source_executor +/// .query_primary_keys(QueryPrimaryKeysInput::new("table1".to_string())) +/// .await; +/// +/// // Create a dual data source executor +/// let first_db_client: PgPool = PgPool::connect("postgres://user:password@localhost:5432/database1") +/// .await +/// .unwrap(); +/// let second_db_client: PgPool = PgPool::connect("postgres://user:password@localhost:5432/database2") +/// .await +/// .unwrap(); +/// let dual_source_executor = TableDualSourceQueryExecutorImpl::new(first_db_client, second_db_client); +/// +/// // Query table counts +/// let schema_name = SchemaName::new("public"); +/// let table_name = TableName::new("table1"); +/// let (first_count, second_count) = dual_source_executor +/// .query_table_count(QueryTableCountInput::new(schema_name, table_name)) +/// .await; +/// +/// // Query hash data +/// let schema_name = SchemaName::new("public"); +/// let table_name = TableName::new("table1"); +/// let primary_keys = TablePrimaryKeys::new("id"); +/// let table_position = TablePosition::new(0); +/// let table_offset = TableOffset::new(100); +/// let (first_hash, second_hash) = dual_source_executor +/// .query_hash_data(QueryHashDataInput::new(schema_name, table_name, primary_keys, table_position, table_offset)) +/// .await; +/// } +/// ``` +use anyhow::Result; +use async_trait::async_trait; +use sqlx::{Pool, Postgres, Row}; + +use crate::diff::table::query::input::{ + QueryHashDataInput, QueryPrimaryKeysInput, QueryTableCountInput, QueryTableNamesInput, +}; +use crate::diff::table::query::table_query::TableQuery; +use crate::diff::table::query::table_types::{IncludedExcludedTables, TableName}; + +#[cfg(test)] +use mockall::automock; + +#[cfg_attr(test, automock)] +#[async_trait] +/// This trait represents a query executor for a single source table. 
+pub trait TableSingleSourceQueryExecutor { + /// Queries the table names from the database. + /// + /// # Arguments + /// + /// * `input` - The input parameters for the query. + /// + /// # Returns + /// + /// A vector of table names. + async fn query_table_names(&self, input: QueryTableNamesInput) -> Vec; + + /// Queries the primary keys of a table from the database. + /// + /// # Arguments + /// + /// * `input` - The input parameters for the query. + /// + /// # Returns + /// + /// A vector of primary key column names. + async fn query_primary_keys(&self, input: QueryPrimaryKeysInput) -> Vec; +} + +pub struct TableSingleSourceQueryExecutorImpl { + db_client: Pool, +} + +impl TableSingleSourceQueryExecutorImpl { + pub fn new(db_client: Pool) -> Self { + Self { db_client } + } +} + +#[async_trait] +impl TableSingleSourceQueryExecutor for TableSingleSourceQueryExecutorImpl { + async fn query_table_names(&self, input: QueryTableNamesInput) -> Vec { + // Clone the database client + let pool = self.db_client.clone(); + + // Prepare the query for fetching table names + let all_tables_query = TableQuery::AllTablesForSchema( + input.schema_name().to_owned(), + IncludedExcludedTables::new(input.included_tables(), input.excluded_tables()), + ); + + // Fetch table names + let query_result = sqlx::query(all_tables_query.to_string().as_str()) + .bind(input.schema_name().name()) + .fetch_all(&pool) + .await + .unwrap_or(vec![]); + + // Map query results to [Vec] + query_result + .iter() + .map(|row| row.get("table_name")) + .collect::>() + } + + async fn query_primary_keys(&self, input: QueryPrimaryKeysInput) -> Vec { + // Clone the database client + let pool = self.db_client.clone(); + + // Prepare the query for primary keys fetching + let find_primary_key_query = + TableQuery::FindPrimaryKeyForTable(TableName::new(input.table_name())); + + // Fetch primary keys for the table + let query_result = sqlx::query(find_primary_key_query.to_string().as_str()) + .fetch_all(&pool) + 
.await + .unwrap_or(vec![]); + + // Map query results to [Vec] + query_result + .iter() + .map(|row| row.get("attname")) + .collect::>() + } +} + +#[cfg_attr(test, automock)] +#[async_trait] +/// This trait defines the methods for executing queries on a dual source table. +pub trait TableDualSourceQueryExecutor { + /// Executes a query to retrieve the count of rows in a table. + /// + /// # Arguments + /// + /// * `input` - The input parameters for the query. + /// + /// # Returns + /// + /// A tuple containing the result of the query as a `Result`. + async fn query_table_count(&self, input: QueryTableCountInput) -> (Result, Result); + + /// Executes a query to retrieve the hash data of a table. + /// + /// # Arguments + /// + /// * `input` - The input parameters for the query. + /// + /// # Returns + /// + /// A tuple containing the hash data as two `String` values. + async fn query_hash_data(&self, input: QueryHashDataInput) -> (String, String); +} + +pub struct TableDualSourceQueryExecutorImpl { + first_db_client: Pool, + second_db_client: Pool, +} + +impl TableDualSourceQueryExecutorImpl { + pub fn new(first_db_client: Pool, second_db_client: Pool) -> Self { + Self { + first_db_client, + second_db_client, + } + } +} + +#[async_trait] +impl TableDualSourceQueryExecutor for TableDualSourceQueryExecutorImpl { + async fn query_table_count(&self, input: QueryTableCountInput) -> (Result, Result) { + // Clone the database clients + let first_pool = self.first_db_client.clone(); + let second_pool = self.second_db_client.clone(); + + // Prepare the query for counting rows + let count_rows_query = TableQuery::CountRowsForTable( + input.schema_name().to_owned(), + input.table_name().to_owned(), + ); + + let count_query_binding = count_rows_query.to_string(); + + // Prepare count queries for both databases + let first_count = sqlx::query(count_query_binding.as_str()).fetch_one(&first_pool); + let second_count = 
sqlx::query(count_query_binding.as_str()).fetch_one(&second_pool); + + // Fetch counts for both databases + let count_fetch_futures = futures::future::join_all(vec![first_count, second_count]).await; + + let first_count = count_fetch_futures.first().unwrap(); + let second_count = count_fetch_futures.get(1).unwrap(); + + // Map count results to [anyhow::Result] + let first_count: Result = match first_count { + Ok(pg_row) => Ok(pg_row.try_get::("count").unwrap()), + Err(_e) => Err(anyhow::anyhow!("Failed to fetch count for first table")), + }; + + let second_count: Result = match second_count { + Ok(pg_row) => Ok(pg_row.try_get::("count").unwrap()), + Err(_e) => Err(anyhow::anyhow!("Failed to fetch count for second table")), + }; + + (first_count, second_count) + } + + async fn query_hash_data(&self, input: QueryHashDataInput) -> (String, String) { + // Clone the database clients + let first_pool = self.first_db_client.clone(); + let second_pool = self.second_db_client.clone(); + + // Prepare the query for fetching data hashes + let hash_query = TableQuery::HashQuery( + input.schema_name(), + input.table_name(), + input.primary_keys(), + input.position(), + input.offset(), + ); + + let hash_query_binding = hash_query.to_string(); + + // Prepare hash queries for both databases + let first_hash = sqlx::query(hash_query_binding.as_str()).fetch_one(&first_pool); + let second_hash = sqlx::query(hash_query_binding.as_str()).fetch_one(&second_pool); + + // Fetch hashes for both databases + let hash_fetch_futures = futures::future::join_all(vec![first_hash, second_hash]).await; + + let first_hash = hash_fetch_futures.first().unwrap(); + let second_hash = hash_fetch_futures.get(1).unwrap(); + + // Map hash results to [String] + let first_hash = match first_hash { + Ok(pg_row) => pg_row + .try_get::("md5") + .unwrap_or("not_available".to_string()), + Err(e) => e.to_string(), + }; + let second_hash = match second_hash { + Ok(pg_row) => pg_row + .try_get::("md5") + 
/// Newtype for a table name.
#[derive(Clone, Debug)]
pub struct TableName(String);

impl TableName {
    pub fn new(name: impl Into<String>) -> Self {
        Self(name.into())
    }

    /// Returns the table name as a string slice.
    pub fn name(&self) -> &str {
        &self.0
    }
}

/// Newtype for a comma-joined list of primary key columns, used as the
/// ordering clause when hashing table windows.
#[derive(Clone, Debug)]
pub struct TablePrimaryKeys(String);

impl TablePrimaryKeys {
    pub fn new(keys: impl Into<String>) -> Self {
        Self(keys.into())
    }

    /// Returns the joined key list as a string slice.
    pub fn keys(&self) -> &str {
        &self.0
    }
}

/// Newtype for the current row offset (SQL `OFFSET`) while paging a table.
#[derive(Clone, Debug)]
pub struct TablePosition(i64);

impl TablePosition {
    pub fn new(position: i64) -> Self {
        Self(position)
    }

    pub fn position(&self) -> i64 {
        self.0
    }
}

/// Newtype for the page size (SQL `LIMIT`) while paging a table.
#[derive(Clone, Debug)]
pub struct TableOffset(i64);

impl TableOffset {
    pub fn new(offset: i64) -> Self {
        Self(offset)
    }

    pub fn offset(&self) -> i64 {
        self.0
    }
}

/// Mutually exclusive include/exclude table filters for schema scans.
pub struct IncludedExcludedTables {
    included_tables: Vec<String>,
    excluded_tables: Vec<String>,
}

/// Which of the two filter lists is active.
pub enum TableMode {
    Include,
    Exclude,
}

impl IncludedExcludedTables {
    /// Builds the filter pair.
    ///
    /// # Panics
    ///
    /// Panics when both lists are non-empty, since the two filters are
    /// mutually exclusive.
    pub fn new(
        include_tables: Vec<impl Into<String>>,
        exclude_tables: Vec<impl Into<String>>,
    ) -> Self {
        if !include_tables.is_empty() && !exclude_tables.is_empty() {
            panic!("Cannot include and exclude tables at the same time");
        }

        Self {
            included_tables: include_tables.into_iter().map(|t| t.into()).collect(),
            excluded_tables: exclude_tables.into_iter().map(|t| t.into()).collect(),
        }
    }

    /// Returns the active filter mode, or `None` when both lists are empty.
    pub fn table_mode(&self) -> Option<TableMode> {
        if self.has_included_tables() {
            Some(TableMode::Include)
        } else if self.has_excluded_tables() {
            Some(TableMode::Exclude)
        } else {
            None
        }
    }

    /// Returns an `AND table_name NOT IN (...)` clause, or an empty string
    /// when there is nothing to exclude.
    pub fn exclusion_statement(&self) -> String {
        if !self.has_excluded_tables() {
            return String::new();
        }

        format!(
            "AND table_name NOT IN ({})",
            Self::quote_and_join(&self.excluded_tables)
        )
    }

    /// Returns an `AND table_name IN (...)` clause, or an empty string when
    /// there is nothing to include.
    pub fn inclusion_statement(&self) -> String {
        if !self.has_included_tables() {
            return String::new();
        }

        format!(
            "AND table_name IN ({})",
            Self::quote_and_join(&self.included_tables)
        )
    }

    /// Renders `["a", "b"]` as `'a','b'` for a SQL `IN (...)` list.
    ///
    /// NOTE(review): names are interpolated without escaping — acceptable for
    /// trusted configuration input, but not safe for untrusted table names;
    /// confirm callers only pass operator-supplied values.
    fn quote_and_join(tables: &[String]) -> String {
        tables
            .iter()
            .map(|table| format!("'{}'", table))
            .collect::<Vec<String>>()
            .join(",")
    }

    fn has_included_tables(&self) -> bool {
        !self.included_tables.is_empty()
    }

    fn has_excluded_tables(&self) -> bool {
        !self.excluded_tables.is_empty()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_included_tables_when_include_tables_not_empty() {
        let included_tables = vec!["table1", "table2"];
        let excluded_tables: Vec<String> = vec![];
        let included_excluded_tables =
            IncludedExcludedTables::new(included_tables, excluded_tables);

        assert!(matches!(
            included_excluded_tables.table_mode().unwrap(),
            TableMode::Include
        ));

        assert_eq!(
            included_excluded_tables.inclusion_statement(),
            "AND table_name IN ('table1','table2')"
        );
    }

    #[test]
    fn test_excluded_tables_when_exclude_tables_not_empty() {
        let included_tables: Vec<String> = vec![];
        let excluded_tables = vec!["table1", "table2"];
        let included_excluded_tables =
            IncludedExcludedTables::new(included_tables, excluded_tables);

        assert!(matches!(
            included_excluded_tables.table_mode().unwrap(),
            TableMode::Exclude
        ));

        assert_eq!(
            included_excluded_tables.exclusion_statement(),
            "AND table_name NOT IN ('table1','table2')"
        );
    }

    #[test]
    fn test_when_included_tables_and_excluded_tables_empty() {
        let included_tables: Vec<String> = vec![];
        let excluded_tables: Vec<String> = vec![];
        let included_excluded_tables =
            IncludedExcludedTables::new(included_tables, excluded_tables);

        assert_eq!(included_excluded_tables.inclusion_statement(), "");
    }

    #[test]
    #[should_panic = "Cannot include and exclude tables at the same time"]
    fn test_when_included_tables_and_excluded_tables_are_both_not_empty() {
        let included_tables: Vec<&str> = vec!["table1"];
        let excluded_tables: Vec<&str> = vec!["table2"];
        _ = IncludedExcludedTables::new(included_tables, excluded_tables);
    }
}
test_when_included_tables_and_excluded_tables_are_both_not_empty() { + let included_tables: Vec<&str> = vec!["table1"]; + let excluded_tables: Vec<&str> = vec!["table2"]; + _ = IncludedExcludedTables::new(included_tables, excluded_tables); + } +} diff --git a/src/diff/table/table_differ.rs b/src/diff/table/table_differ.rs new file mode 100644 index 0000000..b9ed4b0 --- /dev/null +++ b/src/diff/table/table_differ.rs @@ -0,0 +1,267 @@ +use crate::diff::diff_payload::DiffPayload; +use crate::diff::table::query::input::{ + QueryHashDataInput, QueryPrimaryKeysInput, QueryTableCountInput, QueryTableNamesInput, +}; +use crate::diff::table::query::output::{TableCountDiff, TableDiffOutput, TableSource}; + +use crate::diff::table::query::table_query_executor::{ + TableDualSourceQueryExecutor, TableSingleSourceQueryExecutor, +}; +use crate::diff::table::query::table_types::{ + TableName, TableOffset, TablePosition, TablePrimaryKeys, +}; +use anyhow::Result; +use colored::Colorize; +use log::{debug, info}; + +use crate::diff::diff_output::DiffOutput; +use crate::diff::types::SchemaName; +use std::time::Instant; + +pub struct TableDiffer { + single_table_query_executor: TQE, + dual_table_query_executor: DTQE, +} + +impl + TableDiffer +{ + pub fn new(single_table_query_executor: TQE, dual_table_query_executor: DTQE) -> Self { + Self { + single_table_query_executor, + dual_table_query_executor, + } + } + + pub async fn diff_all_table_data(&self, diff_payload: &DiffPayload) -> Result> { + info!("{}", "Starting data analysis…".yellow().bold()); + + let mut tables = self.get_all_tables(diff_payload).await?; + + tables.sort_by_key(|s| s.to_lowercase()); + + let sorted_tables = tables.to_owned(); + + let futures = sorted_tables.iter().map(|table_name| async { + let start = Instant::now(); + + // Start loading counts for table from both DBs + let query_count_input = QueryTableCountInput::new( + SchemaName::new(diff_payload.schema_name().to_string()), + 
TableName::new(table_name.to_string()), + ); + + let table_counts_start = Instant::now(); + let (first_result, second_result) = self + .dual_table_query_executor + .query_table_count(query_count_input) + .await; + + let table_counts_elapsed = table_counts_start.elapsed(); + debug!( + "Table counts for {} loaded in: {}ms", + table_name.clone(), + table_counts_elapsed.as_millis() + ); + + debug!( + "{}", + format!("Analyzing table: {}", table_name.clone()) + .yellow() + .bold() + ); + + // Start counts comparison + let table_diff_result = Self::extract_result(table_name, first_result, second_result); + + let elapsed = start.elapsed(); + debug!( + "{}", + format!("Table analysis completed in: {}ms", elapsed.as_millis()) + ); + + debug!("##############################################"); + + // If we only care about counts, return the result + if diff_payload.only_count() { + return table_diff_result; + } + + // If the diff result permits us to skip data comparison, return the result + if table_diff_result.skip_table_diff() { + return table_diff_result; + } + + let query_primary_keys_input = QueryPrimaryKeysInput::new(table_name.clone()); + + let primary_keys = self + .single_table_query_executor + .query_primary_keys(query_primary_keys_input) + .await; + + // If no primary keys found, return the result + if primary_keys.is_empty() { + let table_diff_result = TableDiffOutput::NoPrimaryKeyFound(table_name.clone()); + return table_diff_result; + } + + // Prepare the primary keys for the table + // Will be used for query ordering when hashing data + let primary_keys = primary_keys.as_slice().join(","); + + let total_rows = match table_diff_result { + TableDiffOutput::NoCountDiff(_, rows) => rows, + _ => { + // Since we do not expect to reach here, print the result and panic + panic!("Unexpected table diff result") + } + }; + + let schema_name = SchemaName::new(diff_payload.schema_name().to_string()); + let query_table_name = TableName::new(table_name.clone()); + let 
table_offset = TableOffset::new(diff_payload.chunk_size()); + let table_primary_keys = TablePrimaryKeys::new(primary_keys); + + let start = Instant::now(); + + if let Some(value) = self + .diff_table_data( + diff_payload, + schema_name, + query_table_name, + table_offset, + table_primary_keys, + total_rows, + start, + ) + .await + { + return value; + } + + let elapsed = start.elapsed(); + + TableDiffOutput::NoDiffWithDuration(table_name.clone(), elapsed) + }); + + info!( + "{}", + "Waiting for table analysis to complete…".yellow().bold() + ); + let start = Instant::now(); + let analysed_tables = futures::future::join_all(futures).await; + let elapsed = start.elapsed(); + info!( + "{}", + format!( + "Total table analysis completed in: {}ms", + elapsed.as_millis() + ) + .yellow() + .bold(), + ); + info!( + "{}", + format!("Total tables for row count check: {}", tables.len()) + .bright_blue() + .bold() + ); + + info!("##############################################"); + info!("{}", "Table analysis results 👇".bright_magenta().bold()); + + for table_diff_result in &analysed_tables { + info!("{}", table_diff_result.to_string()); + } + + info!("##############################################"); + + Ok(analysed_tables + .into_iter() + .map(|diff| diff.into()) + .collect()) + } + + pub async fn get_all_tables(&self, diff_payload: &DiffPayload) -> Result> { + let input = QueryTableNamesInput::new( + SchemaName::new(diff_payload.schema_name().to_string()), + diff_payload.included_tables().to_vec(), + diff_payload.excluded_tables().to_vec(), + ); + let tables = self + .single_table_query_executor + .query_table_names(input) + .await; + Ok(tables) + } + + fn extract_result( + table_name: &str, + first_result: Result, + second_result: Result, + ) -> TableDiffOutput { + match (first_result, second_result) { + (Ok(first_total_rows), Ok(second_total_rows)) => { + if first_total_rows != second_total_rows { + TableDiffOutput::Diff( + table_name.to_owned(), + 
TableCountDiff::new(first_total_rows, second_total_rows), + ) + } else { + TableDiffOutput::NoCountDiff(table_name.to_owned(), first_total_rows) + } + } + (Err(_e), _) => TableDiffOutput::NotExists(table_name.to_owned(), TableSource::First), + (_, Err(_e)) => TableDiffOutput::NotExists(table_name.to_owned(), TableSource::Second), + } + } + #[allow(clippy::too_many_arguments)] + async fn diff_table_data( + &self, + diff_payload: &DiffPayload, + schema_name: SchemaName, + query_table_name: TableName, + table_offset: TableOffset, + table_primary_keys: TablePrimaryKeys, + total_rows: i64, + start: Instant, + ) -> Option { + // Start data comparison + let mut position = 0; + while position <= total_rows { + let input = QueryHashDataInput::new( + schema_name.clone(), + query_table_name.clone(), + table_primary_keys.clone(), + TablePosition::new(position), + table_offset.clone(), + ); + + let hash_fetch_start = Instant::now(); + let (first_hash, second_hash) = + self.dual_table_query_executor.query_hash_data(input).await; + let hash_fetch_elapsed = hash_fetch_start.elapsed(); + debug!( + "Hashes for {} loaded in: {}ms", + query_table_name.name(), + hash_fetch_elapsed.as_millis() + ); + + // If hashes are different, return the result + if first_hash != second_hash { + let elapsed = start.elapsed(); + return Some(TableDiffOutput::DataDiffWithDuration( + query_table_name.name().to_string(), + position, + position + diff_payload.chunk_size(), + elapsed, + )); + } + + // Increase the position for the next iteration + position += diff_payload.chunk_size(); + } + + None + } +} diff --git a/src/diff/table/table_differ_tests.rs b/src/diff/table/table_differ_tests.rs new file mode 100644 index 0000000..0df1d1d --- /dev/null +++ b/src/diff/table/table_differ_tests.rs @@ -0,0 +1,169 @@ +#[cfg(test)] +mod tests { + use crate::diff::diff_output::DiffOutput; + use crate::diff::diff_payload::DiffPayload; + use crate::diff::table::query::output::TableDiffOutput; + use 
crate::diff::table::query::table_query_executor::{ + MockTableDualSourceQueryExecutor, MockTableSingleSourceQueryExecutor, + }; + use crate::diff::table::table_differ::TableDiffer; + + const EMPTY_STRING_VEC: Vec = Vec::new(); + + #[tokio::test] + async fn test_get_all_tables_from_table_differ() { + let mut single_source_query_executor = MockTableSingleSourceQueryExecutor::new(); + let dual_source_query_executor = MockTableDualSourceQueryExecutor::new(); + + single_source_query_executor + .expect_query_table_names() + .times(1) + .returning(|_| vec!["table1".to_string(), "table2".to_string()]); + + let table_differ = + TableDiffer::new(single_source_query_executor, dual_source_query_executor); + + let diff_payload = DiffPayload::new( + "first_db", + "second_db", + false, + false, + false, + 10000, + 10, + vec!["table1", "table2"], + EMPTY_STRING_VEC, + "schema_name", + ); + + let tables = table_differ.get_all_tables(&diff_payload).await.unwrap(); + + assert_eq!(tables.len(), 2); + assert_eq!(tables[0], "table1"); + assert_eq!(tables[1], "table2"); + } + + #[tokio::test] + async fn test_not_diff_table_data_from_table_differ_when_different_counts() { + let mut single_source_query_executor = MockTableSingleSourceQueryExecutor::new(); + let mut dual_source_query_executor = MockTableDualSourceQueryExecutor::new(); + + single_source_query_executor + .expect_query_table_names() + .times(1) + .returning(|_| vec!["table1".to_string()]); + + dual_source_query_executor + .expect_query_table_count() + .times(1) + .returning(|_| (Ok(2), Ok(1))); + + single_source_query_executor + .expect_query_primary_keys() + .times(0); + + dual_source_query_executor.expect_query_hash_data().times(0); + + let table_differ = + TableDiffer::new(single_source_query_executor, dual_source_query_executor); + + let diff_payload = DiffPayload::new( + "first_db", + "second_db", + false, + false, + false, + 10000, + 10, + vec!["table1", "table2"], + EMPTY_STRING_VEC, + "schema_name", + ); + + let 
diff_output = table_differ + .diff_all_table_data(&diff_payload) + .await + .unwrap(); + + assert_eq!(diff_output.len(), 1); + + let actual = diff_output.first().unwrap(); + + assert!(matches!(actual, DiffOutput::TableDiff(_))); + match actual { + DiffOutput::TableDiff(table_diff_output) => match table_diff_output { + TableDiffOutput::Diff(table_name, table_count_diff) => { + assert_eq!("table1", table_name); + assert_eq!(2, table_count_diff.first()); + assert_eq!(1, table_count_diff.second()); + } + _ => panic!("Expected TableDiffOutput::Diff"), + }, + _ => panic!("Expected DiffOutput::TableDiff"), + } + } + + #[tokio::test] + async fn test_diff_all_table_data_from_table_differ_when_same_counts() { + let mut single_source_query_executor = MockTableSingleSourceQueryExecutor::new(); + let mut dual_source_query_executor = MockTableDualSourceQueryExecutor::new(); + + single_source_query_executor + .expect_query_table_names() + .times(1) + .returning(|_| vec!["table1".to_string()]); + + dual_source_query_executor + .expect_query_table_count() + .times(1) + .returning(|_| (Ok(1), Ok(1))); + + single_source_query_executor + .expect_query_primary_keys() + .times(1) + .returning(|_| vec!["id".to_string()]); + + dual_source_query_executor + .expect_query_hash_data() + .times(1) + .returning(|_| ("hash1".to_string(), "hash2".to_string())); + + let table_differ = + TableDiffer::new(single_source_query_executor, dual_source_query_executor); + + let diff_payload = DiffPayload::new( + "first_db", + "second_db", + false, + false, + false, + 10000, + 10, + vec!["table1", "table2"], + EMPTY_STRING_VEC, + "schema_name", + ); + + let diff_output = table_differ + .diff_all_table_data(&diff_payload) + .await + .unwrap(); + + assert_eq!(diff_output.len(), 1); + + let actual = diff_output.first().unwrap(); + + assert!(matches!(actual, DiffOutput::TableDiff(_))); + match actual { + DiffOutput::TableDiff(diff_output) => match diff_output { + 
TableDiffOutput::DataDiffWithDuration(table_name, position, offset, _) => { + assert_eq!("table1", table_name); + assert_eq!(0, *position); + assert_eq!(10000, *offset); + } + _ => panic!("Expected TableDiffOutput::DataDiffWithDuration"), + }, + _ => panic!("Expected DiffOutput::TableDiff"), + } + } +} diff --git a/src/diff/types/mod.rs b/src/diff/types/mod.rs new file mode 100644 index 0000000..71159cd --- /dev/null +++ b/src/diff/types/mod.rs @@ -0,0 +1,18 @@ +use crate::diff::diff_output::DiffOutput; + +#[derive(Clone)] +pub struct SchemaName(String); + +impl SchemaName { + pub fn new(name: impl Into) -> Self { + Self(name.into()) + } + + pub fn name(&self) -> &str { + &self.0 + } +} + +pub trait DiffOutputMarker { + fn convert(self) -> DiffOutput; +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..d0e98c5 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1 @@ +pub mod diff;