Add Triton benchmarks for blog #509

Open · wants to merge 10 commits into `main`
44 changes: 44 additions & 0 deletions examples/ML+DL-Examples/Spark-DL/dl_inference/benchmark/README.md
@@ -0,0 +1,44 @@
# Batch Inference Benchmark

This folder contains benchmark code comparing two inference implementations:
1. [`spark_resnet.py`](spark_resnet.py): uses `predict_batch_udf` to perform in-process prediction on the GPU.
2. [`spark_resnet_triton.py`](spark_resnet_triton.py): uses `predict_batch_udf` to send inference requests to Triton, which performs inference on the GPU.

Spark cannot change the task parallelism within a stage based on the resources required (i.e., multiple CPUs for preprocessing vs. a single GPU for inference). Therefore, implementation (1) is limited to one task per GPU, so that only a single instance of the model runs on the GPU. In contrast, implementation (2) can run as many tasks in parallel as there are cores on the executor, since Triton handles inference on the GPU.
> **Collaborator:** For ResNet-50, could multiple model instances fit in the GPU? If so, it might be good to benchmark that case, where multiple Spark tasks run per GPU, each with its own model instance. With multiple processes, GPU compute will be time-sliced, so performance could take a hit, but it would still be interesting to compare.

<img src="../images/benchmark_comparison.png" alt="drawing" width="1000"/>

### Setup

The workload consists of the following 4-step pipeline:
1. Read binary JPEG image data from parquet
2. Preprocess on CPU (decompress, resize, crop, normalize)
3. Inference on GPU
4. Write results to parquet

<img src="../images/benchmark_pipeline.png" alt="drawing" width="800"/>
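Condensed from [`spark_resnet.py`](spark_resnet.py), the pipeline amounts to a few DataFrame transformations (`preprocess` and `classify` are the UDFs defined in that file; `file_path` and `output_path` stand in for the actual dataset locations):

```python
df = spark.read.parquet(file_path)                                 # 1. read raw JPEG bytes
preprocessed = df.withColumn("images", preprocess(col("value")))   # 2. preprocess on CPU
preds = preprocessed.withColumn("preds", classify(col("images")))  # 3. inference on GPU
preds.write.mode("overwrite").parquet(output_path)                 # 4. write results
```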

We used the [ImageNet 2012](https://image-net.org/challenges/LSVRC/2012/2012-downloads.php#Images) validation dataset containing 50,000 images and a pre-trained [PyTorch ResNet-50](https://pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html) model for classification. The [`prepare_dataset.py`](prepare_dataset.py) script loads the ImageNet data and prepares it in a binary parquet format to be read with Spark, as shown below.
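For example, to prepare the full 50,000-image dataset (the script also accepts `1k`, `5k`, and `10k` subsets via `--size`):

```shell
python prepare_dataset.py --size 50k
```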

### Environment

We used the `spark-dl-torch` conda environment, set up following the [README](../README.md#create-environment).
We tested on a local standalone cluster with 1 executor: 1 A6000 GPU, 16 cores, and 32GB of memory. The cluster can be started like so:
```shell
conda activate spark-dl-torch
export SPARK_HOME=</path/to/spark>
export MASTER=spark://$(hostname):7077
export SPARK_WORKER_INSTANCES=1
export CORES_PER_WORKER=16
export SPARK_WORKER_OPTS="-Dspark.worker.resource.gpu.amount=1 \
-Dspark.worker.resource.gpu.discoveryScript=$SPARK_HOME/examples/src/main/scripts/getGpusResources.sh"
${SPARK_HOME}/sbin/start-master.sh; ${SPARK_HOME}/sbin/start-worker.sh -c ${CORES_PER_WORKER} -m 32G ${MASTER}
```
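When finished, the cluster can be stopped with the matching scripts:

```shell
${SPARK_HOME}/sbin/stop-worker.sh; ${SPARK_HOME}/sbin/stop-master.sh
```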

The Spark configurations we used for the two implementations can be found under [`bench_spark_resnet.sh`](bench_spark_resnet.sh) and [`bench_spark_resnet_triton.sh`](bench_spark_resnet_triton.sh), respectively. The only differences are in the task-parallelism settings, i.e., `spark.task.resource.gpu.amount` and `spark.task.cpus`, excerpted below.
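Concretely, the in-process implementation pins the whole GPU (and all 16 cores) to a single task, while the Triton implementation gives each of the 16 concurrent tasks a 1/16 share of the GPU:

```shell
# bench_spark_resnet.sh: one task per executor, owning the GPU and all 16 cores
--conf spark.task.resource.gpu.amount=1
--conf spark.task.cpus=16

# bench_spark_resnet_triton.sh: 16 concurrent tasks sharing the GPU (1/16 = 0.0625)
--conf spark.task.resource.gpu.amount=0.0625
```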

### Results

End-to-end throughput of the two implementations (higher is better):

<img src="../images/benchmark_results.png" alt="drawing" width="800"/>
20 changes: 20 additions & 0 deletions examples/ML+DL-Examples/Spark-DL/dl_inference/benchmark/bench_spark_resnet.sh
@@ -0,0 +1,20 @@
#!/bin/bash

spark-submit \
    --master spark://$(hostname):7077 \
    --num-executors 1 \
    --executor-cores 16 \
    --executor-memory 32g \
    --conf spark.executor.resource.gpu.amount=1 \
    --conf spark.task.resource.gpu.amount=1 \
    --conf spark.task.cpus=16 \
    --conf spark.task.maxFailures=1 \
    --conf spark.sql.execution.arrow.pyspark.enabled=true \
    --conf spark.python.worker.reuse=true \
    --conf spark.pyspark.python=${CONDA_PREFIX}/bin/python \
    --conf spark.pyspark.driver.python=${CONDA_PREFIX}/bin/python \
    --conf spark.locality.wait=0s \
    --conf spark.sql.adaptive.enabled=false \
    --conf spark.sql.execution.sortBeforeRepartition=false \
    --conf spark.sql.files.minPartitionNum=16 \
    spark_resnet.py
19 changes: 19 additions & 0 deletions examples/ML+DL-Examples/Spark-DL/dl_inference/benchmark/bench_spark_resnet_triton.sh
@@ -0,0 +1,19 @@
#!/bin/bash

spark-submit \
    --master spark://$(hostname):7077 \
    --num-executors 1 \
    --executor-cores 16 \
    --executor-memory 32g \
    --conf spark.executor.resource.gpu.amount=1 \
    --conf spark.task.resource.gpu.amount=0.0625 \
    --conf spark.task.maxFailures=1 \
    --conf spark.sql.execution.arrow.pyspark.enabled=true \
    --conf spark.python.worker.reuse=true \
    --conf spark.pyspark.python=${CONDA_PREFIX}/bin/python \
    --conf spark.pyspark.driver.python=${CONDA_PREFIX}/bin/python \
    --conf spark.locality.wait=0s \
    --conf spark.sql.adaptive.enabled=false \
    --conf spark.sql.execution.sortBeforeRepartition=false \
    --conf spark.sql.files.minPartitionNum=16 \
    spark_resnet_triton.py
57 changes: 57 additions & 0 deletions examples/ML+DL-Examples/Spark-DL/dl_inference/benchmark/gpu_monitor.py
@@ -0,0 +1,57 @@
import os
import subprocess
from datetime import datetime

import numpy as np


class GPUMonitor:
    def __init__(self, gpu_ids=None, interval=1):
        # Avoid a mutable default argument; default to monitoring GPU 0.
        self.gpu_ids = gpu_ids if gpu_ids is not None else [0]
        self.interval = interval
        self.log_file = f"results/gpu_metrics_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
        self.process = None

    def start(self):
        if not os.path.exists("results"):
            os.makedirs("results")
        with open(self.log_file, 'w') as f:
            f.write("timestamp,gpu_id,utilization\n")

        # Poll nvidia-smi in a background shell loop, appending one CSV row
        # per monitored GPU every `interval` seconds.
        cmd = f"""
        while true; do
            nvidia-smi --query-gpu=timestamp,index,utilization.gpu \
                --format=csv,noheader,nounits \
                -i {','.join(map(str, self.gpu_ids))} >> {self.log_file}
            sleep {self.interval}
        done
        """

        self.process = subprocess.Popen(cmd, shell=True)
        print(f"Started GPU monitoring, logging to {self.log_file}")

    def stop(self):
        if self.process:
            self.process.terminate()
            self.process.wait()
            print("Stopped GPU monitoring")

        try:
            with open(self.log_file, 'r') as f:
                next(f)  # skip the CSV header

                gpu_utils = {}
                for line in f:
                    _, gpu_id, util = line.strip().split(',')
                    if gpu_id not in gpu_utils:
                        gpu_utils[gpu_id] = []
                    gpu_utils[gpu_id].append(float(util))

            print("\nGPU Utilization Summary:")
            for gpu_id, utils in gpu_utils.items():
                avg_util = sum(utils) / len(utils)
                max_util = max(utils)
                median_util = np.median(utils)
                print(f"GPU {gpu_id}:")
                print(f"  Average: {avg_util:.1f}%")
                print(f"  Median: {median_util:.1f}%")
                print(f"  Maximum: {max_util:.1f}%")
        except Exception as e:
            print(f"Error generating summary: {e}")
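For reference, a minimal sketch of how `GPUMonitor` is meant to be driven (mirroring its use in `spark_resnet.py` below; `run_workload` is a hypothetical stand-in for the Spark job):

```python
from gpu_monitor import GPUMonitor

monitor = GPUMonitor(gpu_ids=[0], interval=1)
monitor.start()
try:
    run_workload()  # hypothetical stand-in for the benchmarked Spark pipeline
finally:
    monitor.stop()  # terminates the sampler and prints the utilization summary
```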
75 changes: 75 additions & 0 deletions examples/ML+DL-Examples/Spark-DL/dl_inference/benchmark/prepare_dataset.py
@@ -0,0 +1,75 @@
import os
import tarfile

import pandas as pd
from pyspark.sql import SparkSession


def prepare_imagenet_parquet(size='50k', data_dir="spark-dl-datasets/imagenet-val"):
    """Prepare ImageNet validation set as a parquet file of raw JPEG bytes."""

    size_map = {
        '1k': 1000,
        '5k': 5000,
        '10k': 10000,
        '50k': 50000
    }
    num_images = size_map.get(size, 50000)

    valdata_path = os.path.join(data_dir, 'ILSVRC2012_img_val.tar')
    if not os.path.exists(valdata_path):
        raise RuntimeError(
            "ImageNet validation data not found. Please download:\n"
            "ILSVRC2012_img_val.tar\n"
            f"And place it in {data_dir}"
        )

    images = []
    count = 0

    # Collect raw compressed JPEG bytes (no decoding) to write to parquet
    with tarfile.open(valdata_path, 'r:') as tar:
        for member in tar.getmembers():
            if count >= num_images:
                break

            if member.isfile() and member.name.endswith(('.JPEG', '.jpg', '.jpeg')):
                f = tar.extractfile(member)
                if f is not None:
                    images.append(f.read())
                    count += 1

                    if count % 100 == 0:
                        print(f"Processed {count} images")

    return pd.DataFrame({'value': images})


def main():
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--size', type=str, default='50k', help='Dataset size (e.g., 1k, 5k, 10k, 50k)')
    args = parser.parse_args()

    pdf = prepare_imagenet_parquet(size=args.size)
    if not os.path.exists("spark-dl-datasets"):
        os.makedirs("spark-dl-datasets")

    # Stage the pandas output under a separate name, since Spark cannot
    # overwrite a path that it is also reading from.
    staged_path = f"spark-dl-datasets/imagenet_{args.size}_raw.parquet"
    pdf.to_parquet(staged_path)

    # Repartition and write to parquet
    spark = SparkSession.builder.appName("prepare-imagenet-parquet").getOrCreate()

    spark.conf.set("spark.sql.execution.arrow.useLargeVarTypes", "true")
    spark.conf.set("spark.sql.parquet.columnarReaderBatchSize", "1024")
    spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "1024")

    df = spark.read.parquet(staged_path)
    df = df.repartition(16)
    df.write.mode("overwrite").parquet(f"spark-dl-datasets/imagenet_{args.size}.parquet")


if __name__ == "__main__":
    main()
117 changes: 117 additions & 0 deletions examples/ML+DL-Examples/Spark-DL/dl_inference/benchmark/spark_resnet.py
@@ -0,0 +1,117 @@
import os
import time
import argparse
from typing import Iterator

import numpy as np
import pandas as pd
from pyspark import TaskContext
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf, col
from pyspark.sql.types import ArrayType, FloatType
from pyspark.ml.functions import predict_batch_udf
from gpu_monitor import GPUMonitor


def predict_batch_fn():
    """Classify a batch of images. predict_batch_udf invokes this once per
    Python worker to load the model; the returned closure is then reused
    across batches (spark.python.worker.reuse=true keeps workers alive)."""
    import torch
    import torchvision.models as models

    start_load = time.perf_counter()
    # pretrained=True is deprecated in recent torchvision;
    # IMAGENET1K_V1 is the equivalent weights enum.
    model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1).to("cuda")
    model.eval()
    end_load = time.perf_counter()
    print(f"Model loaded in {end_load - start_load:.4f} seconds")

    def predict(inputs):
        print(f"PARTITION {TaskContext.get().partitionId()}: Inferring batch of size {len(inputs)}")
        batch_tensor = torch.from_numpy(inputs).to("cuda")

        with torch.no_grad():
            outputs = model(batch_tensor)

        _, predicted_ids = torch.max(outputs, 1)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        confidences = torch.max(probabilities, dim=1)[0]
        indices = predicted_ids.cpu().numpy()
        scores = confidences.cpu().numpy()
        # Return (class index, confidence) pairs as float32 to match
        # the ArrayType(FloatType()) return type of the UDF
        results = np.stack([indices, scores], axis=1).astype(np.float32)
        return results

    return predict


@pandas_udf(ArrayType(FloatType()))
def preprocess(image_iter: Iterator[pd.Series]) -> Iterator[pd.Series]:
    """Preprocess images (raw JPEG bytes) into a batch of tensors"""
    import io
    from PIL import Image
    from torchvision import transforms
    import torch
    from pyspark import TaskContext

    transform = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    part_id = TaskContext.get().partitionId()

    for image_batch in image_iter:
        batch_size = len(image_batch)
        print(f"PARTITION {part_id}: number of images: {batch_size}")

        # Pre-allocate tensor for batch
        batch_tensor = torch.empty(batch_size, 3, 224, 224, dtype=torch.float32)

        # Decompress and transform images
        for idx, raw_bytes in enumerate(image_batch):
            img = Image.open(io.BytesIO(raw_bytes))
            if img.mode != 'RGB':
                img = img.convert('RGB')
            batch_tensor[idx] = transform(img)

        # Flatten each image to 1-D; predict_batch_udf restores the shape
        # on the inference side via input_tensor_shapes=[[3, 224, 224]]
        numpy_batch = batch_tensor.numpy()
        flattened_batch = numpy_batch.reshape(batch_size, -1)

        yield pd.Series(list(flattened_batch))


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--size', type=str, default='50k', help='Dataset size (e.g., 1k, 5k, 10k, 50k)')
    args = parser.parse_args()
    spark = SparkSession.builder.appName("bench-spark-resnet").getOrCreate()

    # Avoid OOM for image loading from raw byte arrays
    spark.conf.set("spark.sql.execution.arrow.useLargeVarTypes", "true")
    spark.conf.set("spark.sql.parquet.columnarReaderBatchSize", "1024")
    spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "1024")

    file_path = os.path.abspath(f"spark-dl-datasets/imagenet_{args.size}.parquet")
    # input_tensor_shapes tells predict_batch_udf how to un-flatten each row;
    # batch_size caps the number of rows per call to the predict function
    classify = predict_batch_udf(predict_batch_fn,
                                 return_type=ArrayType(FloatType()),
                                 input_tensor_shapes=[[3, 224, 224]],
                                 batch_size=1024)

    # Start GPU utilization monitoring
    monitor = GPUMonitor()
    monitor.start()

    try:
        start_read = time.perf_counter()

        df = spark.read.parquet(file_path)
        preprocessed_df = df.withColumn("images", preprocess(col("value"))).drop("value")
        preds = preprocessed_df.withColumn("preds", classify(col("images")))
        preds.write.mode("overwrite").parquet(f"spark-dl-datasets/imagenet_{args.size}_preds.parquet")

        end_write = time.perf_counter()

        print(f"E2E read -> inference -> write time: {end_write - start_read:.4f} seconds")
    finally:
        monitor.stop()


if __name__ == "__main__":
    main()