feat: Reinstate disk-based shuffle (#47)
* old old shuffle reader/writer

* old old shuffle reader/writer

* remove ray shuffle

* revert more changes

* save progress

* update expected plans

* remove unused code

* fix regression
andygrove authored Nov 19, 2024
1 parent a86218c commit 31f8833
Showing 36 changed files with 872 additions and 864 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

2 changes: 2 additions & 0 deletions Cargo.toml
@@ -32,10 +32,12 @@ build = "build.rs"
datafusion = { version = "42.0.0", features = ["pyarrow", "avro"] }
datafusion-proto = "42.0.0"
futures = "0.3"
glob = "0.3.1"
log = "0.4"
prost = "0.13"
pyo3 = { version = "0.22", features = ["extension-module", "abi3", "abi3-py38"] }
tokio = { version = "1.40", features = ["macros", "rt", "rt-multi-thread", "sync"] }
uuid = "1.11.0"

[build-dependencies]
prost-types = "0.13"
99 changes: 46 additions & 53 deletions datafusion_ray/context.py
@@ -29,23 +29,31 @@
from datafusion import SessionContext


def schedule_execution(
graph: ExecutionGraph,
stage_id: int,
is_final_stage: bool,
) -> list[ray.ObjectRef]:
stage = graph.get_query_stage(stage_id)
@ray.remote(num_cpus=0)
def execute_query_stage(
query_stages: list[QueryStage],
stage_id: int
) -> tuple[int, list[ray.ObjectRef]]:
"""
Execute a query stage on the workers.
Returns the stage ID, and a list of futures for the output partitions of the query stage.
"""
stage = QueryStage(stage_id, query_stages[stage_id])

# execute child stages first
# A list of (stage ID, list of futures) for each child stage
# Each list is a 2-D array of (input partitions, output partitions).
child_outputs = []
child_futures = []
for child_id in stage.get_child_stage_ids():
child_outputs.append((child_id, schedule_execution(graph, child_id, False)))
# child_outputs.append((child_id, schedule_execution(graph, child_id)))
child_futures.append(
execute_query_stage.remote(query_stages, child_id)
)

# if the query stage has a single output partition then we need to execute for the output
# partition, otherwise we need to execute in parallel for each input partition
concurrency = stage.get_input_partition_count()
output_partitions_count = stage.get_output_partition_count()
if is_final_stage:
if output_partitions_count == 1:
# reduce stage
print("Forcing reduce stage concurrency from {} to 1".format(concurrency))
concurrency = 1

@@ -55,50 +63,33 @@ def schedule_execution(
)
)

def _get_worker_inputs(
part: int,
) -> tuple[list[tuple[int, int, int]], list[ray.ObjectRef]]:
ids = []
futures = []
for child_stage_id, child_futures in child_outputs:
for i, lst in enumerate(child_futures):
if isinstance(lst, list):
for j, f in enumerate(lst):
if concurrency == 1 or j == part:
# If concurrency is 1, pass in all shuffle partitions. Otherwise,
# only pass in the partitions that match the current worker partition.
ids.append((child_stage_id, i, j))
futures.append(f)
elif concurrency == 1 or part == 0:
ids.append((child_stage_id, i, 0))
futures.append(lst)
return ids, futures
# A list of (stage ID, list of futures) for each child stage
# Each list is a 2-D array of (input partitions, output partitions).
child_outputs = ray.get(child_futures)

# if we are using disk-based shuffle, wait for the child stages to finish
# writing the shuffle files to disk first.
ray.get([f for _, lst in child_outputs for f in lst])

# schedule the actual execution workers
plan_bytes = stage.get_execution_plan_bytes()
futures = []
opt = {}
# TODO not sure why we had this but my Ray cluster could not find suitable resource
# until I commented this out
# opt["resources"] = {"worker": 1e-3}
opt["num_returns"] = output_partitions_count
for part in range(concurrency):
ids, inputs = _get_worker_inputs(part)
futures.append(
execute_query_partition.options(**opt).remote(
stage_id, plan_bytes, part, ids, *inputs
stage_id, plan_bytes, part
)
)
return futures

return stage_id, futures


@ray.remote
def execute_query_partition(
stage_id: int,
plan_bytes: bytes,
part: int,
input_partition_ids: list[tuple[int, int, int]],
*input_partitions: list[pa.RecordBatch],
part: int
) -> Iterable[pa.RecordBatch]:
start_time = time.time()
# plan = datafusion_ray.deserialize_execution_plan(plan_bytes)
@@ -109,13 +100,10 @@ def execute_query_partition(
# input_partition_ids,
# )
# )
partitions = [
(s, j, p) for (s, _, j), p in zip(input_partition_ids, input_partitions)
]
# This is delegating to DataFusion for execution, but this would be a good place
# to plug in other execution engines by translating the plan into another engine's plan
# (perhaps via Substrait, once DataFusion supports converting a physical plan to Substrait)
ret = datafusion_ray.execute_partition(plan_bytes, part, partitions)
ret = datafusion_ray.execute_partition(plan_bytes, part)
duration = time.time() - start_time
event = {
"cat": f"{stage_id}-{part}",
@@ -153,19 +141,24 @@ def sql(self, sql: str) -> pa.RecordBatch:
return []

df = self.df_ctx.sql(sql)
execution_plan = df.execution_plan()
return self.plan(df.execution_plan())

def plan(self, execution_plan: Any) -> pa.RecordBatch:

graph = self.ctx.plan(execution_plan)
final_stage_id = graph.get_final_query_stage().id()
partitions = schedule_execution(graph, final_stage_id, True)
# serialize the query stages and store in Ray object store
query_stages = [
graph.get_query_stage(i).get_execution_plan_bytes()
for i in range(final_stage_id + 1)
]
# schedule execution
future = execute_query_stage.remote(
query_stages,
final_stage_id
)
_, partitions = ray.get(future)
# assert len(partitions) == 1, len(partitions)
result_set = ray.get(partitions[0])
return result_set

def plan(self, physical_plan: Any) -> pa.RecordBatch:
graph = self.ctx.plan(physical_plan)
final_stage_id = graph.get_final_query_stage().id()
partitions = schedule_execution(graph, final_stage_id, True)
# assert len(partitions) == 1, len(partitions)
result_set = ray.get(partitions[0])
return result_set
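
The driver now serializes every query stage up front and schedules them recursively: each execute_query_stage task launches its child stages first, blocks until their shuffle files are on disk, and only then runs its own partitions. Below is a minimal, self-contained sketch of that recursive pattern; the stages dict and run_partition task are hypothetical stand-ins for the real QueryStage and execute_query_partition, not the datafusion_ray API.

import ray

ray.init(ignore_reinit_error=True)

@ray.remote
def run_partition(stage_id: int, part: int) -> str:
    # Placeholder for the real work: execute the physical plan for one partition,
    # writing shuffle files to disk (or returning record batches for the final stage).
    return f"stage {stage_id}, partition {part} done"

@ray.remote(num_cpus=0)  # num_cpus=0 so the scheduling task does not hold a CPU while it blocks
def execute_stage(stages: dict, stage_id: int):
    """Schedule child stages, wait for their shuffle output, then run this stage."""
    stage = stages[stage_id]
    # Recursively schedule children; each returns (stage_id, [partition futures]).
    child_futures = [execute_stage.remote(stages, child) for child in stage["children"]]
    child_outputs = ray.get(child_futures)
    # Disk-based shuffle barrier: block until every child partition has been written.
    ray.get([f for _, parts in child_outputs for f in parts])
    # Run this stage's partitions in parallel.
    futures = [run_partition.remote(stage_id, p) for p in range(stage["partitions"])]
    return stage_id, futures

# Hypothetical two-stage graph: stage 0 feeds stage 1 (the final stage).
stages = {0: {"children": [], "partitions": 2}, 1: {"children": [0], "partitions": 1}}
_, final_partitions = ray.get(execute_stage.remote(stages, 1))
print(ray.get(final_partitions))

As in execute_query_stage above, the ray.get over the child partition futures is what turns the disk shuffle into a stage boundary: no partition of a stage starts until all of its inputs have been fully written.
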
69 changes: 5 additions & 64 deletions src/context.rs
@@ -16,8 +16,7 @@
// under the License.

use crate::planner::{make_execution_graph, PyExecutionGraph};
use crate::shuffle::{RayShuffleReaderExec, ShuffleCodec};
use datafusion::arrow::pyarrow::FromPyArrow;
use crate::shuffle::ShuffleCodec;
use datafusion::arrow::pyarrow::ToPyArrow;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::error::{DataFusionError, Result};
@@ -31,7 +30,7 @@ use futures::StreamExt;
use prost::Message;
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyList, PyLong, PyTuple};
use pyo3::types::{PyBytes, PyTuple};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::runtime::Runtime;
@@ -117,22 +116,20 @@ impl PyContext {
&self,
plan: &Bound<'_, PyBytes>,
part: usize,
inputs: PyObject,
py: Python,
) -> PyResult<PyResultSet> {
execute_partition(plan, part, inputs, py)
execute_partition(plan, part, py)
}
}

#[pyfunction]
pub fn execute_partition(
plan_bytes: &Bound<'_, PyBytes>,
part: usize,
inputs: PyObject,
py: Python,
) -> PyResult<PyResultSet> {
let plan = deserialize_execution_plan(plan_bytes)?;
_execute_partition(plan, part, inputs)
_execute_partition(plan, part)
.unwrap()
.into_iter()
.map(|batch| batch.to_pyarrow(py))
@@ -170,59 +167,10 @@ pub fn deserialize_execution_plan(proto_msg: &Bound<PyBytes>) -> PyResult<Arc<dy
Ok(plan)
}

/// Iterate down an ExecutionPlan and set the input objects for RayShuffleReaderExec.
fn _set_inputs_for_ray_shuffle_reader(
plan: Arc<dyn ExecutionPlan>,
input_partitions: &Bound<'_, PyList>,
) -> Result<()> {
if let Some(reader_exec) = plan.as_any().downcast_ref::<RayShuffleReaderExec>() {
let exec_stage_id = reader_exec.stage_id;
// iterate over inputs, wrap in PyBytes and set as input objects
for item in input_partitions.iter() {
let pytuple = item
.downcast::<PyTuple>()
.map_err(|e| DataFusionError::Execution(format!("{}", e)))?;
let stage_id = pytuple
.get_item(0)
.map_err(|e| DataFusionError::Execution(format!("{}", e)))?
.downcast::<PyLong>()
.map_err(|e| DataFusionError::Execution(format!("{}", e)))?
.extract::<usize>()
.map_err(|e| DataFusionError::Execution(format!("{}", e)))?;
if stage_id != exec_stage_id {
continue;
}
let part = pytuple
.get_item(1)
.map_err(|e| DataFusionError::Execution(format!("{}", e)))?
.downcast::<PyLong>()
.map_err(|e| DataFusionError::Execution(format!("{}", e)))?
.extract::<usize>()
.map_err(|e| DataFusionError::Execution(format!("{}", e)))?;
let batch = RecordBatch::from_pyarrow_bound(
&pytuple
.get_item(2)
.map_err(|e| DataFusionError::Execution(format!("{}", e)))?,
)
.map_err(|e| DataFusionError::Execution(format!("{}", e)))?;
reader_exec.add_input_partition(part, batch)?;
}
} else {
for child in plan.children() {
_set_inputs_for_ray_shuffle_reader(child.to_owned(), input_partitions)?;
}
}
Ok(())
}

/// Execute a partition of a query plan. This will typically execute a shuffle write and
/// write the results to disk, except for the final query stage, which will return the data.
/// inputs is a list of tuples of (stage_id, partition_id, bytes) for each input partition.
fn _execute_partition(
plan: Arc<dyn ExecutionPlan>,
part: usize,
inputs: PyObject,
) -> Result<Vec<RecordBatch>> {
fn _execute_partition(plan: Arc<dyn ExecutionPlan>, part: usize) -> Result<Vec<RecordBatch>> {
let ctx = Arc::new(TaskContext::new(
Some("task_id".to_string()),
"session_id".to_string(),
@@ -233,13 +181,6 @@ fn _execute_partition(
Arc::new(RuntimeEnv::default()),
));

Python::with_gil(|py| {
let input_partitions = inputs
.downcast_bound::<PyList>(py)
.map_err(|e| DataFusionError::Execution(format!("{}", e)))?;
_set_inputs_for_ray_shuffle_reader(plan.clone(), input_partitions)
})?;

// create a Tokio runtime to run the async code
let rt = Runtime::new().unwrap();

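
On the Rust side, execute_partition now takes only the serialized plan and a partition index: non-final stages write their shuffle output to disk, and the reader operator of the consuming stage picks it up from the stage's shuffle directory instead of receiving record batches through the Ray object store. The sketch below illustrates that disk-based shuffle idea in Python using pyarrow IPC files; the actual ShuffleWriterExec/ShuffleReaderExec are Rust operators whose file naming, hashing, and on-disk format are not shown in this commit, so the part-{input}-{output}.arrow names and Python hash() used here are assumptions for illustration only.

import os
import pyarrow as pa
import pyarrow.ipc as ipc

def shuffle_write(batches, num_output_parts, key_index, shuffle_dir, input_part):
    """Hash-partition batches on one column; write one IPC file per output partition."""
    os.makedirs(shuffle_dir, exist_ok=True)
    writers = {}
    for batch in batches:
        keys = batch.column(key_index).to_pylist()
        for out_part in range(num_output_parts):
            # Keep only the rows whose key hashes to this output partition.
            mask = pa.array([hash(k) % num_output_parts == out_part for k in keys])
            selected = batch.filter(mask)
            if selected.num_rows == 0:
                continue
            if out_part not in writers:
                path = os.path.join(shuffle_dir, f"part-{input_part}-{out_part}.arrow")
                writers[out_part] = ipc.new_file(path, batch.schema)
            writers[out_part].write_batch(selected)
    for writer in writers.values():
        writer.close()

def shuffle_read(shuffle_dir, out_part):
    """Read every input partition's file for one output partition."""
    batches = []
    for name in sorted(os.listdir(shuffle_dir)):
        if name.endswith(f"-{out_part}.arrow"):
            reader = ipc.open_file(os.path.join(shuffle_dir, name))
            batches.extend(reader.read_all().to_batches())
    return batches

# Toy round trip: one input partition writes, then output partition 0 reads back.
batch = pa.RecordBatch.from_pydict({"id": [1, 2, 3, 4], "v": ["a", "b", "c", "d"]})
shuffle_write([batch], num_output_parts=2, key_index=0, shuffle_dir="/tmp/shuffle-demo", input_part=0)
print(shuffle_read("/tmp/shuffle-demo", 0))
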
20 changes: 17 additions & 3 deletions src/planner.rs
@@ -17,7 +17,7 @@

use crate::query_stage::PyQueryStage;
use crate::query_stage::QueryStage;
use crate::shuffle::{RayShuffleReaderExec, RayShuffleWriterExec};
use crate::shuffle::{ShuffleReaderExec, ShuffleWriterExec};
use datafusion::error::Result;
use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
use datafusion::physical_plan::repartition::RepartitionExec;
@@ -29,6 +29,7 @@ use pyo3::prelude::*;
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use uuid::Uuid;

#[pyclass(name = "ExecutionGraph", module = "datafusion_ray", subclass)]
pub struct PyExecutionGraph {
@@ -200,11 +201,15 @@ fn create_shuffle_exchange(
// introduce shuffle to produce one output partition
let stage_id = graph.next_id();

// create temp dir for stage shuffle files
let temp_dir = create_temp_dir(stage_id)?;

let shuffle_writer_input = plan.clone();
let shuffle_writer: Arc<dyn ExecutionPlan> = Arc::new(RayShuffleWriterExec::new(
let shuffle_writer: Arc<dyn ExecutionPlan> = Arc::new(ShuffleWriterExec::new(
stage_id,
shuffle_writer_input,
partitioning_scheme.clone(),
&temp_dir,
));

debug!(
@@ -214,13 +219,22 @@

let stage_id = graph.add_query_stage(stage_id, shuffle_writer);
// replace the plan with a shuffle reader
Ok(Arc::new(RayShuffleReaderExec::new(
Ok(Arc::new(ShuffleReaderExec::new(
stage_id,
plan.schema(),
partitioning_scheme,
&temp_dir,
)))
}

fn create_temp_dir(stage_id: usize) -> Result<String> {
let uuid = Uuid::new_v4();
let temp_dir = format!("/tmp/ray-sql-{uuid}-stage-{stage_id}");
debug!("Creating temp shuffle dir: {temp_dir}");
std::fs::create_dir(&temp_dir)?;
Ok(temp_dir)
}

#[cfg(test)]
mod test {
use super::*;
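
The planner change pairs each shuffle exchange with a freshly created temporary directory: the child plan is wrapped in a ShuffleWriterExec that becomes its own query stage, and a ShuffleReaderExec over the same directory replaces the exchange in the parent plan, so writer and reader rendezvous on disk. Below is a toy sketch of that split using hypothetical PlanNode objects rather than DataFusion's ExecutionPlan types.

import os
import uuid
from dataclasses import dataclass, field

@dataclass
class PlanNode:
    name: str
    children: list = field(default_factory=list)

def create_temp_dir(stage_id: int) -> str:
    # Mirrors the Rust helper: one unique directory per shuffle stage.
    path = f"/tmp/ray-sql-{uuid.uuid4()}-stage-{stage_id}"
    os.makedirs(path)
    return path

def create_shuffle_exchange(plan: PlanNode, stage_id: int, stages: list) -> PlanNode:
    """Split the plan at an exchange: the writer becomes a query stage, and a
    reader over the same directory takes its place in the parent plan."""
    temp_dir = create_temp_dir(stage_id)
    writer = PlanNode(f"ShuffleWriter(stage={stage_id}, dir={temp_dir})", [plan])
    stages.append(writer)
    return PlanNode(f"ShuffleReader(stage={stage_id}, dir={temp_dir})")
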
12 changes: 8 additions & 4 deletions src/proto/datafusion_ray.proto
@@ -11,25 +11,29 @@ import "datafusion.proto";

message RaySqlExecNode {
oneof PlanType {
RayShuffleReaderExecNode ray_shuffle_reader = 3;
RayShuffleWriterExecNode ray_shuffle_writer = 4;
ShuffleReaderExecNode shuffle_reader = 1;
ShuffleWriterExecNode shuffle_writer = 2;
}
}

message RayShuffleReaderExecNode {
message ShuffleReaderExecNode {
// stage to read from
uint32 stage_id = 1;
// schema of the shuffle stage
datafusion_common.Schema schema = 2;
// this must match the output partitioning of the writer we are reading from
datafusion.PhysicalHashRepartition partitioning = 3;
// directory for shuffle files
string shuffle_dir = 4;
}

message RayShuffleWriterExecNode {
message ShuffleWriterExecNode {
// stage that is writing the shuffle files
uint32 stage_id = 1;
// plan to execute
datafusion.PhysicalPlanNode plan = 2;
// output partitioning schema
datafusion.PhysicalHashRepartition partitioning = 3;
// directory for shuffle files
string shuffle_dir = 4;
}