Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimized preproc function using CuPy #402

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 110 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,113 @@
# YOLOX-Bytetrack Algorithm Optimization with CuPy

#### This repository contains an optimized version of the YOLOX-Bytetrack algorithm.

## Abstract
The primary enhancement involves the use of CuPy instead of NumPy for the `preproc` function, resulting in significant performance improvements for the preprocessing stage. The changes also include the utilization of multithreading for parallel processing of multiple images.

## Key Improvements
### 1. CuPy Integration
The original `preproc` function used NumPy for various operations, which are now replaced with CuPy to leverage GPU acceleration. This change drastically reduces the preprocessing time, especially when dealing with large batches of images.

**Original `preproc` Function**
```python
def preproc(image, input_size, mean, std, swap=(2, 0, 1)):
if len(image.shape) == 3:
padded_img = np.ones((input_size[0], input_size[1], 3)) * 114.0
else:
padded_img = np.ones(input_size) * 114.0
img = np.array(image)
r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
resized_img = cv2.resize(
img,
(int(img.shape[1] * r), int(img.shape[0] * r)),
interpolation=cv2.INTER_LINEAR,
).astype(np.float32)
padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img

padded_img = padded_img[:, :, ::-1]
padded_img /= 255.0
if mean is not None:
padded_img -= mean
if std is not None:
padded_img /= std
padded_img = padded_img.transpose(swap)
padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
return padded_img, r
```

**Optimized preproc Function with CuPy**
```python
def preproc_with_cupy(image, input_size, mean, std, swap=(2, 0, 1)):
device = cp.cuda.Device(0)
device.use()

if len(image.shape) == 3:
padded_img = cp.ones((input_size[0], input_size[1], 3)) * 114.0
else:
padded_img = cp.ones(input_size) * 114.0

img = cp.array(image)
r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])

target_height = int(img.shape[0] * r)
target_width = int(img.shape[1] * r)

if target_height <= 0 or target_width <= 0:
raise ValueError(f"Invalid target size: ({target_width}, {target_height})")

resized_img = cp.array(cv2.resize(
cp.asnumpy(img),
(target_width, target_height),
interpolation=cv2.INTER_LINEAR,
).astype(np.float32))

if len(image.shape) == 3:
padded_img[:target_height, :target_width, :] = resized_img
else:
padded_img[:target_height, :target_width] = resized_img

padded_img = padded_img[:, :, ::-1] / 255.0 # BGR to RGB and normalize

if mean is not None:
mean_array = cp.array(mean).reshape(1, 1, 3)
padded_img -= mean_array

if std is not None:
std_array = cp.array(std).reshape(1, 1, 3)
padded_img /= std_array

padded_img = padded_img.transpose(swap)
padded_img = cp.ascontiguousarray(padded_img, dtype=cp.float32)
return padded_img, r
```
### 2. Multithreading for Image Processing

To further enhance performance, the process_images method now uses multithreading to preprocess multiple images in parallel. This change utilizes Python's ThreadPoolExecutor to handle image preprocessing concurrently.

**Added process_images Method**

```python
def process_images(self, image_list, input_size, mean, std, swap=(2, 0, 1)):
with ThreadPoolExecutor() as executor:
futures= [executor.submit(preproc_with_cupy, img, input_size, mean, std, swap) for img in image_list]
results = [future.result() for future in futures]
if results:
return results
```
## Result
The integration of CuPy and the use of multithreading have significantly improved the preprocessing time for image batches. Below is a comparison of the preprocessing time before and after the optimization:
The tests resulted in an FPS increase of around 1.5X-2X, depending on the graphics card used and the type of model.

**Before optimization**
<p align="center"><img src="assets/without_cupy.png" width="500"/></p>

**After optimization**
<p align="center"><img src="assets/with_cupy.png" width="500"/></p>

------------------


# ByteTrack

[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/bytetrack-multi-object-tracking-by-1/multi-object-tracking-on-mot17)](https://paperswithcode.com/sota/multi-object-tracking-on-mot17?p=bytetrack-multi-object-tracking-by-1)
Expand Down
Binary file added assets/with_cupy.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/without_cupy.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
41 changes: 29 additions & 12 deletions tools/demo_track.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,22 @@
import time
import cv2
import torch

from loguru import logger
import cupy as cp

import sys
sys.path.append('.')


from yolox.data.data_augment import preproc
from yolox.data.data_augment import preproc, preproc_with_cupy
from yolox.exp import get_exp
from yolox.utils import fuse_model, get_model_info, postprocess
from yolox.utils.visualize import plot_tracking
from yolox.tracker.byte_tracker import BYTETracker
from yolox.tracking_utils.timer import Timer

from concurrent.futures import ThreadPoolExecutor

IMAGE_EXT = [".jpg", ".jpeg", ".webp", ".bmp", ".png"]


def make_parser():
parser = argparse.ArgumentParser("ByteTrack Demo!")
parser.add_argument(
Expand Down Expand Up @@ -100,7 +102,6 @@ def get_image_list(path):
image_names.append(apath)
return image_names


def write_results(filename, results):
save_format = '{frame},{id},{x1},{y1},{w},{h},{s},-1,-1,-1\n'
with open(filename, 'w') as f:
Expand All @@ -113,15 +114,14 @@ def write_results(filename, results):
f.write(line)
logger.info('save results to {}'.format(filename))


class Predictor(object):
def __init__(
self,
model,
exp,
trt_file=None,
decoder=None,
device=torch.device("cpu"),
device=None,
fp16=False
):
self.model = model
Expand All @@ -130,7 +130,7 @@ def __init__(
self.confthre = exp.test_conf
self.nmsthre = exp.nmsthre
self.test_size = exp.test_size
self.device = device
self.device = str(device)
self.fp16 = fp16
if trt_file is not None:
from torch2trt import TRTModule
Expand All @@ -157,9 +157,20 @@ def inference(self, img, timer):
img_info["width"] = width
img_info["raw_img"] = img

img, ratio = preproc(img, self.test_size, self.rgb_means, self.std)
img_info["ratio"] = ratio
img = torch.from_numpy(img).unsqueeze(0).float().to(self.device)
if self.device=='cuda':
img=[img]
processed_images = self.process_images(img, self.test_size, self.rgb_means, self.std)

if processed_images:
img = processed_images[0][0]
img_info["ratio"] = processed_images[0][1]

img = torch.from_numpy(cp.asnumpy(img)).unsqueeze(0).float().to(self.device)
else:
img, ratio = preproc(img, self.test_size, self.rgb_means, self.std)
img_info["ratio"] = ratio
img = torch.from_numpy(img).unsqueeze(0).float().to(self.device)

if self.fp16:
img = img.half() # to FP16

Expand All @@ -174,6 +185,12 @@ def inference(self, img, timer):
#logger.info("Infer time: {:.4f}s".format(time.time() - t0))
return outputs, img_info

def process_images(self, image_list, input_size, mean, std, swap=(2, 0, 1)):
with ThreadPoolExecutor() as executor:
futures= [executor.submit(preproc_with_cupy, img, input_size, mean, std, swap) for img in image_list]
results = [future.result() for future in futures]
if results:
return results

def image_demo(predictor, vis_folder, current_time, args):
if osp.isdir(args.path):
Expand Down
50 changes: 45 additions & 5 deletions yolox/data/data_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,10 @@

import cv2
import numpy as np

import torch

from yolox.utils import xyxy2cxcywh

import math
import random

import cupy as cp

def augment_hsv(img, hgain=0.015, sgain=0.7, vgain=0.4):
r = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1 # random gains
Expand Down Expand Up @@ -210,6 +206,50 @@ def preproc(image, input_size, mean, std, swap=(2, 0, 1)):
padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
return padded_img, r

def preproc_with_cupy(image, input_size, mean, std, swap=(2, 0, 1)):
device = cp.cuda.Device(0)
device.use()

if len(image.shape) == 3:
padded_img = cp.ones((input_size[0], input_size[1], 3)) * 114.0
else:
padded_img = cp.ones(input_size) * 114.0

img = cp.array(image)
r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])

# Hedef boyutları hesaplayalım
target_height = int(img.shape[0] * r)
target_width = int(img.shape[1] * r)

# Hedef boyutların sıfır veya negatif olmadığını kontrol edelim
if target_height <= 0 or target_width <= 0:
raise ValueError(f"Invalid target size: ({target_width}, {target_height})")

resized_img = cp.array(cv2.resize(
cp.asnumpy(img),
(target_width, target_height),
interpolation=cv2.INTER_LINEAR,
).astype(np.float32))

if len(image.shape) == 3:
padded_img[:target_height, :target_width, :] = resized_img
else:
padded_img[:target_height, :target_width] = resized_img

padded_img = padded_img[:, :, ::-1] / 255.0 # BGR to RGB and normalize

if mean is not None:
mean_array = cp.array(mean).reshape(1, 1, 3)
padded_img -= mean_array

if std is not None:
std_array = cp.array(std).reshape(1, 1, 3)
padded_img /= std_array

padded_img = padded_img.transpose(swap)
padded_img = cp.ascontiguousarray(padded_img, dtype=cp.float32)
return padded_img, r

class TrainTransform:
def __init__(self, p=0.5, rgb_means=None, std=None, max_labels=100):
Expand Down