This folder contains the benchmark code to compare two implementations:

- `spark_resnet.py`: uses `predict_batch_udf` to perform in-process prediction on the GPU.
- `spark_resnet_triton.py`: uses `predict_batch_udf` to send inference requests to Triton, which performs inference on the GPU.
Spark cannot change the task parallelism within a stage based on the resources each step requires (i.e., multiple CPU cores for preprocessing vs. a single GPU for inference). Therefore, implementation (1) is limited to 1 task per GPU, so that only one instance of the model runs on the GPU. In contrast, implementation (2) can run as many tasks in parallel as there are cores on the executor, since Triton handles inference on the GPU.
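The task-slot difference can be sketched with a bit of arithmetic (a rough illustration only; the variable names below are descriptive, not Spark APIs, and the values assume this benchmark's 16-core, 1-GPU executor):

```python
# Illustrative sketch of Spark's task-slot arithmetic for a 16-core, 1-GPU executor.
executor_cores = 16
executor_gpus = 1

# Implementation (1): each task claims the whole GPU, so only one task
# can run at a time regardless of available cores.
task_gpu_amount = 1.0
in_process_slots = min(executor_cores, int(executor_gpus / task_gpu_amount))

# Implementation (2): Triton owns the GPU, so tasks are limited only by CPU cores.
task_cpus = 1
triton_slots = executor_cores // task_cpus

print(in_process_slots, triton_slots)  # 1 16
```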
The workload consists of the following 4-step pipeline:
- Read binary JPEG image data from parquet
- Preprocess on CPU (decompress, resize, crop, normalize)
- Inference on GPU
- Write results to parquet
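As a rough illustration of the normalize step in the CPU preprocessing above (a sketch only; the actual transforms live in the benchmark scripts), standard ImageNet normalization for a pretrained ResNet-50 looks like:

```python
import numpy as np

# ImageNet channel means/stds commonly used with torchvision's pretrained ResNet-50.
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def normalize(img_hwc: np.ndarray) -> np.ndarray:
    """Scale uint8 HWC pixels to [0, 1], normalize per channel, return CHW."""
    x = img_hwc.astype(np.float32) / 255.0
    x = (x - MEAN) / STD
    return x.transpose(2, 0, 1)  # HWC -> CHW, as PyTorch expects

# A dummy 224x224 RGB "image" of constant gray:
out = normalize(np.full((224, 224, 3), 128, dtype=np.uint8))
print(out.shape)  # (3, 224, 224)
```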
We used the ImageNet 2012 validation dataset containing 50,000 images, and a pre-trained PyTorch ResNet-50 model for classification. We used the `prepare_dataset.py` script to load and prepare the ImageNet data into a binary parquet format to be read with Spark.
We used the `spark-dl-torch` conda environment, set up by following the README.
We tested on a local standalone cluster with 1 executor: 1 A6000 GPU, 16 cores, and 32GB of memory. The cluster can be started like so:
```shell
conda activate spark-dl-torch

export SPARK_HOME=</path/to/spark>
export MASTER=spark://$(hostname):7077
export SPARK_WORKER_INSTANCES=1
export CORES_PER_WORKER=16
export SPARK_WORKER_OPTS="-Dspark.worker.resource.gpu.amount=1 \
-Dspark.worker.resource.gpu.discoveryScript=$SPARK_HOME/examples/src/main/scripts/getGpusResources.sh"

${SPARK_HOME}/sbin/start-master.sh; ${SPARK_HOME}/sbin/start-worker.sh -c ${CORES_PER_WORKER} -m 32G ${MASTER}
```
The Spark configurations we used for the two implementations can be found in `bench_spark_resnet.sh` and `bench_spark_resnet_triton.sh`, respectively. The only differences are in the task-parallelism settings, i.e., `spark.task.resource.gpu.amount` and `spark.task.cpus`.
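For illustration, the differing settings look something like the following (the values here are hypothetical examples, not copied from the scripts; the authoritative configurations are in the two shell scripts):

```python
# Hypothetical task-parallelism settings for each implementation, assuming a
# 16-core, 1-GPU executor; check the benchmark shell scripts for exact values.
in_process_confs = {
    "spark.task.resource.gpu.amount": "1",  # each task claims the whole GPU
    "spark.task.cpus": "16",                # one task occupies all executor cores
}
triton_confs = {
    "spark.task.resource.gpu.amount": "0.0625",  # 1/16: all 16 cores can run tasks
    "spark.task.cpus": "1",
}
```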
End-to-end throughput of the two implementations (higher is better):