generated from supervisely-ecosystem/appV2-templates
-
Notifications
You must be signed in to change notification settings - Fork 4
/
main.py
90 lines (72 loc) · 2.56 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import os
from typing import Dict, List
import supervisely as sly
import helpers
# Module-level handle for the loaded model. deploy_model() assigns it through
# `global my_model`; it stays None until a real model is plugged in there.
my_model = None
def get_classes_and_tags() -> sly.ProjectMeta:
    """Build the project meta listing the model's object classes and tags.

    Returns:
        A ``sly.ProjectMeta`` holding three rectangle classes ("person",
        "car", "bus") and one numeric "confidence" tag.
    """
    # Example class names -- put any needed classes here.
    # Every class in this template uses the sly.Rectangle geometry.
    class_names = ("person", "car", "bus")
    obj_classes = sly.ObjClassCollection(
        [sly.ObjClass(class_name, sly.Rectangle) for class_name in class_names]
    )

    # Example: confidence tag for bounding boxes with a number value.
    confidence_tag = sly.TagMeta("confidence", sly.TagValueType.ANY_NUMBER)
    tag_metas = sly.TagMetaCollection([confidence_tag])

    return sly.ProjectMeta(obj_classes=obj_classes, tag_metas=tag_metas)
def get_session_info() -> Dict:
    """Return key-value information describing this serving session.

    Returns:
        A new dict on every call containing the recommended info values.
        Any extra key-value pairs you want to expose can be added to it.
    """
    # Recommended descriptive values.
    info: Dict = {
        "app": "Serve Custom Detection Model Template",
        "model_name": "Put your model name",
        "device": "cpu",
    }

    # The counts mirror what get_classes_and_tags() declares
    # (3 classes, 1 tag); update them together.
    info["classes_count"] = 3
    info["tags_count"] = 1
    info["sliding_window_support"] = False

    # Put any key-value that you want to ....
    return info
def inference(image_path: str) -> List[Dict]:
    """Run the model on one image and return its detections.

    Args:
        image_path: Path of the image file to read.

    Returns:
        A list of prediction dicts, each with:
          * "bbox": [top, left, bottom, right]
          * "class": a class name as declared in get_classes_and_tags()
          * "confidence": optional score

        The template returns one hard-coded example detection; remove it
        once your own predictions are wired in.
    """
    # Per the template: image has shape [H, W, 3], RGB.
    image = sly.image.read(path=image_path)

    #########################
    # INSERT YOUR CODE HERE #
    #########################
    # predictions = my_model(image)

    # Example output (remove it when you use your own predictions).
    demo_detection = {
        "bbox": [50, 100, 77, 145],  # [top, left, bottom, right]
        "class": "person",  # class name like in get_classes_and_tags()
        "confidence": 0.88,  # optional
    }
    return [demo_detection]
def deploy_model(model_weights_path: str) -> None:
    """Load the model and store it in the module-level ``my_model``.

    The template does not load anything yet: it only resets ``my_model``
    to None. Replace the placeholder assignment with your own model
    initialisation.

    Args:
        model_weights_path: Path to the weights file (unused by the
            placeholder implementation).
    """
    global my_model

    #########################
    # INSERT YOUR CODE HERE #
    #########################
    # Example:
    #   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    #   my_model = init_model(weights=model_weights_path, device=device)

    # Placeholder until a real model is assigned above.
    my_model = None
def main():
    """Serve the model in production, or run a one-image local debug pass.

    When the ``TASK_ID`` environment variable is present, the four template
    callbacks are handed to ``helpers.serve_detection``. Otherwise the model
    is deployed from a hard-coded weights path, inference runs on a
    hard-coded image, and the predictions are passed to
    ``helpers.draw_demo_result`` along with the input and result paths.
    """
    if "TASK_ID" in os.environ:
        # Used for production.
        helpers.serve_detection(
            get_session_info,
            get_classes_and_tags,
            inference,
            deploy_model,
        )
        return

    # Used for local debug -- adjust these paths to your own files.
    weights_file = "/my-folder/my_weights.pth"
    source_image = "/my-folder/my_image.png"
    output_image = "/my-folder/result_image.png"

    deploy_model(weights_file)
    detections = inference(source_image)
    helpers.draw_demo_result(detections, source_image, output_image)
# Script entry point: only runs when the file is executed directly.
# main() is invoked through Supervisely's main_wrapper under the name "main".
if __name__ == "__main__":
    sly.main_wrapper("main", main)