Add adetailer QNN Example (#1652)
## Describe your changes
Model quantization and evaluation on QNN.

This example covers the face model from [Bingsu/adetailer](https://huggingface.co/Bingsu/adetailer); the remaining models (hand, person, etc.) will be added later.
The [CUHK-CSE/wider_face](https://huggingface.co/datasets/CUHK-CSE/wider_face) dataset is used to evaluate mAP50 and mAP50-95.
The `custom` metric type is used with a user-defined `metric_func`.


![image](https://github.com/user-attachments/assets/86f4909a-6a8c-4094-949e-fac4d22e8cfa)


A sample output from this model:

![image](https://github.com/user-attachments/assets/e43a3c3b-ebe9-4a16-989d-377e087a4bf8)

## Checklist before requesting a review
- [ ] Add unit tests for this change.
- [ ] Make sure all tests can pass.
- [x] Update documents if necessary.
- [x] Lint and apply fixes to your code by running `lintrunner -a`
- [ ] Is this a user-facing change? If yes, give a description of this
change to be included in the release notes.
- [x] Is this PR including examples changes? If yes, please remember to
update [example
documentation](https://github.com/microsoft/Olive/blob/main/docs/source/examples.md)
in a follow-up PR.

## (Optional) Issue link
fangyangci authored Feb 28, 2025
1 parent ba91e14 commit 34eb17a
Showing 7 changed files with 350 additions and 17 deletions.
19 changes: 19 additions & 0 deletions examples/adetailer/README.md
@@ -0,0 +1,19 @@
## How to run
### Pip requirements
Install the necessary Python packages:
```
python -m pip install -r requirements.txt
```

### Prepare models
```
python prepare_onnx_model.py
```

### Run sample using config
```
olive run --config ./face_yolo_qnn.json
```

**Note**: The `op_types_to_quantize` list in `face_yolo_qnn.json` deliberately excludes the `Mul` operation, because quantizing `Mul` significantly increases this model's latency on QNN.
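
When tuning `op_types_to_quantize`, it can help to see which op types the exported model actually contains. A minimal sketch (assuming the `onnx` Python package is installed and the model has already been prepared):
```
from collections import Counter

import onnx

# Count how many nodes of each op type appear in the exported face model.
model = onnx.load("models/face/face_yolov9c.onnx")
print(Counter(node.op_type for node in model.graph.node))
```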

126 changes: 126 additions & 0 deletions examples/adetailer/face_yolo_qnn.json
@@ -0,0 +1,126 @@
{
"input_model": { "type": "ONNXModel", "model_path": "models/face/face_yolov9c.onnx" },
"systems": {
"qnn_system": {
"type": "LocalSystem",
"accelerators": [ { "device": "npu", "execution_providers": [ "QNNExecutionProvider" ] } ]
}
},
"data_configs": [
{
"name": "face_data_config",
"type": "HuggingfaceContainer",
"user_script": "user_script.py",
"load_dataset_config": {
"data_name": "CUHK-CSE/wider_face",
"split": "validation",
"streaming": true,
"trust_remote_code": true
},
"pre_process_data_config": { "type": "face_pre_process", "size": 128, "cache_key": "wider_face" },
"dataloader_config": { "type": "no_auto_batch_dataloader" }
}
],
"evaluators": {
"common_evaluator": {
"metrics": [
{
"name": "accuracy_qnn",
"type": "custom",
"data_config": "face_data_config",
"sub_types": [
{ "name": "map 50", "priority": 1, "higher_is_better": true },
{ "name": "map 50-95", "priority": 2, "higher_is_better": true }
],
"user_config": {
"user_script": "user_script.py",
"metric_func": "face_metric",
"inference_settings": {
"onnx": {
"session_options": {
"extra_session_config": { "session.disable_cpu_ep_fallback": "1" }
},
"execution_provider": "QNNExecutionProvider",
"provider_options": [ { "backend_path": "QnnHtp.dll" } ]
}
}
}
},
{
"name": "accuracy_cpu",
"type": "custom",
"data_config": "face_data_config",
"sub_types": [
{ "name": "map 50", "priority": 3, "higher_is_better": true },
{ "name": "map 50-95", "priority": 4, "higher_is_better": true }
],
"user_config": {
"user_script": "user_script.py",
"metric_func": "face_metric",
"inference_settings": { "onnx": { "execution_provider": "CPUExecutionProvider" } }
}
},
{
"name": "latency_qnn",
"type": "latency",
"data_config": "face_data_config",
"sub_types": [ { "name": "avg", "priority": 5 } ],
"user_config": {
"inference_settings": {
"onnx": {
"session_options": {
"extra_session_config": { "session.disable_cpu_ep_fallback": "1" }
},
"execution_provider": "QNNExecutionProvider",
"provider_options": [ { "backend_path": "QnnHtp.dll" } ]
}
}
}
},
{
"name": "latency_cpu",
"type": "latency",
"data_config": "face_data_config",
"sub_types": [ { "name": "avg", "priority": 6 } ],
"user_config": {
"inference_settings": { "onnx": { "execution_provider": "CPUExecutionProvider" } }
}
}
]
}
},
"passes": {
"QNNPreprocess": { "type": "QNNPreprocess" },
"OnnxQuantization": {
"type": "OnnxStaticQuantization",
"quant_preprocess": true,
"data_config": "face_data_config",
"activation_type": "QUInt16",
"weight_type": "QUInt8",
"calibrate_method": "MinMax",
"op_types_to_quantize": [
"Reshape",
"Transpose",
"Softmax",
"Add",
"Split",
"AveragePool",
"Div",
"Conv",
"Sigmoid",
"Slice",
"MaxPool",
"Sub",
"Concat",
"Resize"
]
}
},
"host": "qnn_system",
"target": "qnn_system",
"evaluator": "common_evaluator",
"cache_dir": "cache",
"clean_cache": true,
"output_dir": "models/face/output",
"evaluate_input_model": true
}
25 changes: 25 additions & 0 deletions examples/adetailer/prepare_onnx_model.py
@@ -0,0 +1,25 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from pathlib import Path

import torch
from huggingface_hub import hf_hub_download
from ultralytics import YOLO


def download(model_name: str):
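    """Download a YOLO checkpoint from Bingsu/adetailer, save the underlying torch module, and export it to ONNX."""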
models_dir = Path("./models", model_name.split("_")[0])
models_dir.mkdir(parents=True, exist_ok=True)
hf_hub_download("Bingsu/adetailer", f"{model_name}.pt", local_dir=f"./{models_dir}/")
yolo_model = YOLO(f"{models_dir}/{model_name}.pt")
torch_model = yolo_model.model
torch.save(torch_model, f"{models_dir}/{model_name}_pytorch.pt")
yolo_model.export(format="onnx")


download("face_yolov9c")
download("hand_yolov9c")
download("person_yolov8m-seg")
download("deepfashion2_yolov8s-seg")
2 changes: 2 additions & 0 deletions examples/adetailer/requirements.txt
@@ -0,0 +1,2 @@
pycocotools
ultralytics
156 changes: 156 additions & 0 deletions examples/adetailer/user_script.py
@@ -0,0 +1,156 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger
from pathlib import Path

import numpy as np
import torch
import torchvision.ops as ops
from torch.utils.data import Dataset
from torchmetrics.detection.mean_ap import MeanAveragePrecision

from olive.data.registry import Registry

logger = getLogger(__name__)

# The number of boxes in the labels is not fixed.
# If they are directly used as the return value of the FaceDataset,
# an error will occur when performing torch.cat(targets, dim=0) later.
# So, this cache is used as a workaround.
# pylint: disable=global-statement
_curlabels_np = None


class FaceDataset(Dataset):
def __init__(self, data):
global _curlabels_np
_curlabels_np = data["labels"]
self.images_np = data["images"]

def __len__(self):
return min(len(self.images_np), len(_curlabels_np))

def __getitem__(self, idx):
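        # Convert the HWC uint8 image to NCHW float32 in [0, 1]; the "label" is just the
        # sample index, used later to look up the cached ground-truth boxes.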
input_img = self.images_np[idx]
input_img = np.transpose(input_img, (2, 0, 1))
input_img = np.expand_dims(input_img, axis=0).astype(np.float32) / 255.0
return {"images": input_img}, torch.tensor([idx], dtype=torch.int32)


def face_get_boxes(output):
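    # Each column of `output` is a candidate detection: rows 0-3 hold x_center, y_center,
    # width, height and row 4 the confidence. Low-confidence boxes are dropped and the
    # rest are de-duplicated with NMS.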
confidence_threshold = 0.1
boxes = []
scores = []

for i in range(output.shape[1]):
confidence = output[4, i]
if confidence > confidence_threshold:
x_center = output[0, i]
y_center = output[1, i]
width = output[2, i]
height = output[3, i]

x1 = int(x_center - width / 2)
y1 = int(y_center - height / 2)
x2 = int(x_center + width / 2)
y2 = int(y_center + height / 2)
boxes.append([x1, y1, x2, y2])
scores.append(confidence.item())

if len(boxes) == 0:
return boxes, scores

boxes = torch.tensor(boxes, dtype=torch.float32)
scores = torch.tensor(scores, dtype=torch.float32)

nms_threshold = 0.4
keep_indices = ops.nms(boxes, scores, nms_threshold)
keep_indices = keep_indices.tolist()

keep_boxes = []
keep_scores = []
for i in keep_indices:
x1, y1, x2, y2 = boxes[i].int().tolist()
keep_boxes.append([x1, y1, x2, y2])
keep_scores.append(scores[i])

return keep_boxes, keep_scores


@Registry.register_pre_process()
def face_pre_process(validation_dataset, **kwargs):
cache_key = kwargs.get("cache_key")
size = kwargs.get("size", 256)
cache_file = None
if cache_key:
cache_file = Path(f"./cache/data/{cache_key}_{size}.npz")
if cache_file.exists():
with np.load(Path(cache_file), allow_pickle=True) as data:
return FaceDataset(data)

images = []
labels = []

target_size = (640, 640)

for i, sample in enumerate(validation_dataset):
if i == size:
break
saved_img = sample["image"]
original_width, original_height = saved_img.size
saved_img = saved_img.resize(target_size)
img_array = np.array(saved_img)
images.append(img_array)

bbox_list = sample["faces"]["bbox"]
scaled_bbox_list = []
width_scale = target_size[0] / original_width
height_scale = target_size[1] / original_height
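        # wider_face boxes are [x, y, w, h]; scale them to the 640x640 input and
        # convert to [x1, y1, x2, y2].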
for bbox in bbox_list:
x, y, w, h = bbox
scaled_x = x * width_scale
scaled_y = y * height_scale
scaled_w = w * width_scale
scaled_h = h * height_scale
scaled_bbox_list.append([scaled_x, scaled_y, scaled_x + scaled_w, scaled_y + scaled_h])
labels.append(scaled_bbox_list)

images_np = np.array(images)
labels_np = np.array(labels, dtype=object)
result_data = {"images": images_np, "labels": labels_np}

if cache_file:
cache_file.parent.resolve().mkdir(parents=True, exist_ok=True)
np.savez(cache_file, **result_data)

return FaceDataset(result_data)


def face_metric(model_output, targets):
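    # model_output[0] holds the raw YOLO outputs for the batch; each entry of `targets` is
    # the dataset index used to look up the cached ground-truth boxes in _curlabels_np.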
prediction_data = []
target_data = []

for i, target in enumerate(targets):
keep_boxes, keep_scores = face_get_boxes(model_output[0][i])
target_boxes = _curlabels_np[target]
prediction_data.append(
{
"boxes": torch.tensor(keep_boxes, dtype=torch.float32),
"scores": torch.tensor(keep_scores, dtype=torch.float32),
"labels": torch.zeros(len(keep_boxes), dtype=torch.int64),
}
)
target_data.append(
{
"boxes": torch.tensor(target_boxes, dtype=torch.float32),
"labels": torch.zeros(len(target_boxes), dtype=torch.int64),
}
)

iou_thresholds = torch.arange(0.5, 1, 0.05).tolist()
metric = MeanAveragePrecision(iou_thresholds=iou_thresholds)
metric.update(prediction_data, target_data)
result = metric.compute()
return {"map 50-95": result["map"], "map 50": result["map_50"]}
4 changes: 2 additions & 2 deletions examples/resnet/imagenet.py
@@ -45,14 +45,14 @@ def imagenet_post_fun(output):
 @Registry.register_pre_process()
 def dataset_pre_process(output_data, **kwargs):
     cache_key = kwargs.get("cache_key")
+    size = kwargs.get("size", 256)
     cache_file = None
     if cache_key:
-        cache_file = Path(f"./cache/data/{cache_key}.npz")
+        cache_file = Path(f"./cache/data/{cache_key}_{size}.npz")
         if cache_file.exists():
             with np.load(Path(cache_file)) as data:
                 return ImagenetDataset(data)
 
-    size = kwargs.get("size", 256)
     labels = []
     images = []
     for i, sample in enumerate(output_data):