Commit

refactor

Lordworms committed Dec 18, 2024
1 parent 3aaafbe commit bca4a22
Showing 2 changed files with 169 additions and 111 deletions.
160 changes: 116 additions & 44 deletions benchmarks/src/bin/tpcds.rs
@@ -29,27 +29,30 @@ use test_utils::tpcds::tpcds_schemas;

/// Global list of TPC-DS table names
static TPCDS_TABLES: &[&str] = &[
"store_sales",
"catalog_sales",
"web_sales",
"store_returns",
"catalog_returns",
"web_returns",
"inventory",
"store",
"call_center",
"catalog_page",
"web_page",
"warehouse",
"catalog_returns",
"catalog_sales",
"customer",
"customer_address",
"customer_demographics",
"date_dim",
"household_demographics",
"income_band",
"inventory",
"item",
"promotion",
"reason",
"ship_mode",
"store",
"store_returns",
"store_sales",
"time_dim",
"warehouse",
"web_page",
"web_returns",
"web_sales",
"web_site",
];

#[cfg(all(feature = "snmalloc", feature = "mimalloc"))]
@@ -68,24 +71,30 @@ static ALLOC: mimalloc::MiMalloc = mimalloc::MiMalloc;
/// Command-line options for the TPC-DS benchmark tool
#[derive(Debug, StructOpt)]
#[structopt(name = "tpcds", about = "TPC-DS Benchmark Tool.")]
/// Enum for TPC-DS command-line options
/// Includes options to run a single query or all queries
enum TpcdsOpt {
/// Run TPC-DS queries
/// Run a single TPC-DS query
Run(RunOpt),

/// Run all TPC-DS queries
QueryAll(QueryAllOpt),
}

/// Options for running TPC-DS queries
/// Options for running a single query
#[derive(Debug, StructOpt)]
pub struct RunOpt {
/// Query number (e.g., 1 for query1.sql)
/// The query number (e.g., 1 for query1.sql)
#[structopt(short, long)]
query: usize,

/// Path to the data directory containing Parquet files
/// The path to the data directory containing Parquet files
#[structopt(short, long)]
data_dir: String,
}

impl RunOpt {
/// Executes a single query
pub async fn run(&self) -> Result<()> {
let query_number = self.query;
let parquet_dir = &self.data_dir;
@@ -95,8 +104,12 @@ impl RunOpt {

println!("▶️ Running query {}", query_number);

// Compare DuckDB and DataFusion results
if let Err(e) = compare_duckdb_datafusion(&sql, parquet_dir).await {
// Create a new DuckDB connection and register tables
let conn = create_duckdb_connection(parquet_dir)?;
let ctx = create_tpcds_context(parquet_dir).await?;

// Compare results between DuckDB and DataFusion
if let Err(e) = compare_duckdb_datafusion(&sql, &conn, &ctx).await {
eprintln!("❌ Query {} failed: {}", query_number, e);
return Err(e);
}
@@ -106,7 +119,76 @@
}
}

/// Unified function to register all TPC-DS tables in DataFusion's SessionContext
/// Options for running all queries
#[derive(Debug, StructOpt)]
pub struct QueryAllOpt {
/// The path to the data directory containing Parquet files
#[structopt(short, long)]
data_dir: String,
}

impl QueryAllOpt {
/// Executes all queries sequentially
pub async fn run(&self) -> Result<()> {
let parquet_dir = &self.data_dir;

println!("▶️ Running all TPC-DS queries...");

// Create a single DuckDB connection and register tables once
let conn = create_duckdb_connection(parquet_dir)?;
let ctx = create_tpcds_context(parquet_dir).await?;

// Iterate through query numbers 1 to 99 and execute each query
for query_number in 1..=99 {
match load_query(query_number) {
Ok(sql) => {
println!("▶️ Running query {}", query_number);

// Compare results between DuckDB and DataFusion
if let Err(e) = compare_duckdb_datafusion(&sql, &conn, &ctx).await {
eprintln!("❌ Query {} failed: {}", query_number, e);
continue;
}

println!("✅ Query {} passed.", query_number);
}
Err(e) => {
eprintln!("❌ Failed to load query {}: {}", query_number, e);
continue;
}
}
}

println!("✅ All TPC-DS queries completed.");
Ok(())
}
}

/// Creates a new DuckDB connection and registers all TPC-DS tables
fn create_duckdb_connection(parquet_dir: &str) -> Result<Connection> {
let conn = Connection::open_in_memory().map_err(|e| {
DataFusionError::Execution(format!("DuckDB connection error: {}", e))
})?;

for table in TPCDS_TABLES {
let path = format!("{}/{}.parquet", parquet_dir, table);
let sql = format!(
"CREATE TABLE {} AS SELECT * FROM read_parquet('{}')",
table, path
);
conn.execute(&sql, []).map_err(|e| {
DataFusionError::Execution(format!(
"Error registering table '{}': {}",
table, e
))
})?;
}

println!("✅ All TPC-DS tables registered in DuckDB.");
Ok(conn)
}

/// Registers all TPC-DS tables in DataFusion's SessionContext
async fn create_tpcds_context(parquet_dir: &str) -> Result<SessionContext> {
let ctx = SessionContext::new();

@@ -117,16 +199,17 @@ async fn create_tpcds_context(parquet_dir: &str) -> Result<SessionContext> {
.await?;
}

println!("✅ All TPC-DS tables registered in DataFusion.");
Ok(ctx)
}

/// Compare RecordBatch results from DuckDB and DataFusion
/// Compares results of a query between DuckDB and DataFusion
async fn compare_duckdb_datafusion(
sql: &str,
parquet_dir: &str,
conn: &Connection,
ctx: &SessionContext,
) -> Result<(), DataFusionError> {
let expected_batches = execute_duckdb_query(sql, parquet_dir)?;
let ctx = create_tpcds_context(parquet_dir).await?;
let expected_batches = execute_duckdb_query(sql, conn)?;
let actual_batches = execute_datafusion_query(sql, ctx).await?;
let expected_output = pretty_format_batches(&expected_batches)?.to_string();
let actual_output = pretty_format_batches(&actual_batches)?.to_string();
@@ -145,26 +228,8 @@ async fn compare_duckdb_datafusion(
Ok(())
}

/// Execute a query in DuckDB and return the results as RecordBatch
fn execute_duckdb_query(sql: &str, parquet_dir: &str) -> Result<Vec<RecordBatch>> {
let conn = Connection::open_in_memory().map_err(|e| {
DataFusionError::Execution(format!("DuckDB connection error: {}", e))
})?;

for table in TPCDS_TABLES {
let path = format!("{}/{}.parquet", parquet_dir, table);
let sql = format!(
"CREATE TABLE {} AS SELECT * FROM read_parquet('{}')",
table, path
);
conn.execute(&sql, []).map_err(|e| {
DataFusionError::Execution(format!(
"Error registering table '{}': {}",
table, e
))
})?;
}

/// Executes a query in DuckDB and returns the results as RecordBatch
fn execute_duckdb_query(sql: &str, conn: &Connection) -> Result<Vec<RecordBatch>> {
let mut stmt = conn.prepare(sql).map_err(|e| {
DataFusionError::Execution(format!("SQL preparation error: {}", e))
})?;
@@ -176,17 +241,18 @@ fn execute_duckdb_query(sql: &str, parquet_dir: &str) -> Result<Vec<RecordBatch>
Ok(batches)
}

/// Execute a query in DataFusion and return the results as RecordBatch
/// Executes a query in DataFusion and returns the results as RecordBatch
async fn execute_datafusion_query(
sql: &str,
ctx: SessionContext,
ctx: &SessionContext,
) -> Result<Vec<RecordBatch>> {
let df = ctx.sql(sql).await?;
df.collect().await
}

/// Loads the SQL file for a given query number
fn load_query(query_number: usize) -> Result<String> {
let query_path = format!("datafusion/core/tests/tpc-ds/{}.sql", query_number);
let query_path = format!("../datafusion/core/tests/tpc-ds/{}.sql", query_number);
fs::read_to_string(&query_path).map_err(|e| {
DataFusionError::Execution(format!(
"Failed to load query {}: {}",
@@ -195,11 +261,17 @@ fn load_query(query_number: usize) -> Result<String> {
})
}

/// Main function
#[tokio::main]
async fn main() -> Result<()> {
env_logger::init();

// Parse command-line arguments
let opt = TpcdsOpt::from_args();

// Execute based on the selected option
match opt {
TpcdsOpt::Run(opt) => opt.run().await,
TpcdsOpt::QueryAll(opt) => opt.run().await,
}
}
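
For context, a minimal usage sketch of the refactored binary. These invocations assume structopt's default kebab-case subcommand names derived from the TpcdsOpt variants (run, query-all), that the tool is launched from the benchmarks/ directory (suggested by the relative ../datafusion/core/tests/tpc-ds/ query path in load_query), and a hypothetical /path/to/tpcds_parquet data directory:

cargo run --release --bin tpcds -- run --query 1 --data-dir /path/to/tpcds_parquet
cargo run --release --bin tpcds -- query-all --data-dir /path/to/tpcds_parquet

The first form exercises RunOpt::run for a single query; the second drives QueryAllOpt::run over queries 1 through 99, reusing one DuckDB connection and one DataFusion SessionContext for all of them.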
