Skip to content

Commit

Permalink
fix demo script
Browse files Browse the repository at this point in the history
  • Loading branch information
max-unfinity committed Dec 13, 2024
1 parent 217aa54 commit f658902
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions supervisely_integration/demo/demo_pytorch.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from PIL import Image, ImageDraw
import torch
import torchvision.transforms as T
Expand All @@ -12,7 +13,7 @@
image_path = "img/coco_sample.jpg"


def draw(images, labels, boxes, scores, thrh = 0.6):
def draw(images, labels, boxes, scores, classes, thrh = 0.5):
for i, im in enumerate(images):
draw = ImageDraw.Draw(im)
scr = scores[i]
Expand All @@ -21,20 +22,22 @@ def draw(images, labels, boxes, scores, thrh = 0.6):
scrs = scores[i][scr > thrh]
for j,b in enumerate(box):
draw.rectangle(list(b), outline='red',)
draw.text((b[0], b[1]), text=f"{lab[j].item()} {round(scrs[j].item(),2)}", fill='blue', )
draw.text((b[0], b[1]), text=f"{classes[lab[j].item()]} {round(scrs[j].item(),2)}", fill='blue', )


if __name__ == "__main__":

# load checkpoint
with open(model_meta_path, "r") as f:
model_meta = json.load(f)
cfg = YAMLConfig(config_path, resume=checkpoint_path)
checkpoint = torch.load(checkpoint_path, map_location="cpu")
state = checkpoint["ema"]["module"] if "ema" in checkpoint else checkpoint["model"]
model = cfg.model
model.load_state_dict(state)
model.deploy().to(device)
postprocessor = cfg.postprocessor.deploy().to(device)
classes = ...
classes = [c["title"] for c in model_meta["classes"]]
h, w = 640, 640
transforms = T.Compose([
T.Resize((h, w)),
Expand All @@ -48,9 +51,9 @@ def draw(images, labels, boxes, scores, thrh = 0.6):
im_data = transforms(im_pil)[None].to(device)

# inference
output = model(im_data, orig_size)
labels, boxes, scores = output
output = model(im_data)
labels, boxes, scores = postprocessor(output, orig_size)

# save result
draw([im_pil], labels, boxes, scores)
draw([im_pil], labels, boxes, scores, classes)
im_pil.save("result.jpg")

0 comments on commit f658902

Please sign in to comment.