"""Product recognition demo: batch-ingest product image embeddings into a store,
then search the store, either with a single example image or with YOLO-detected
crops from a scene frame.

Usage:
    python main.py --batch-ingest
    python main.py --search-example
    python main.py --scene-frame-example [--use-clip] [--chroma-dump-path PATH]
"""
import argparse
import logging
import os

import cv2
import matplotlib.pyplot as plt
from tqdm import tqdm
from ultralytics import YOLO

from embedder import CLIPEmbedder, Embedder
from embeddings_store import EmbeddingsStore

# GCP project used by the multimodal Embedder (CLIPEmbedder does not need it)
PROJECT_ID = "vertex-ai-playground-402513"


def get_image_paths() -> list[str]:
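    """Walk ./products/<product_id>/<image> and return every image path."""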
    image_paths = []
    # the data path can be swapped for a GCS path and everything works the same
    data_path = "./products"
    product_ids = os.listdir(data_path)
    for product in tqdm(product_ids):
        images = os.listdir(os.path.join(data_path, product))
        for image in images:
            image_paths.append(os.path.join(data_path, product, image))
    return image_paths


def search_example(use_clip: bool = False):
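    """Embed one example image and log the closest products from the store."""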
    img = cv2.imread("./data/salatka_example_2.jpeg")
    e = CLIPEmbedder() if use_clip else Embedder(PROJECT_ID)
    es = EmbeddingsStore()
    img_embeddings = e.embed(img)
    res = es.search(img_embeddings)
    logging.info([(i.product_id, i.similarity) for i in res])


def scene_frame_example(use_clip: bool = False, chroma_dump_path: str | None = None):
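    """Detect products in a scene frame with YOLO and look up each crop in the store."""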
    yolo = YOLO("./retail-yolo.pt")
    img = cv2.imread("./data/IMG_0504.jpg")
    e = CLIPEmbedder() if use_clip else Embedder(PROJECT_ID)
    # the store can optionally be backed by an existing chroma dump
    es = EmbeddingsStore(chroma_dump_path)
    res = yolo(img)
    boxes = res[0].boxes.xyxy
    # optionally preview the raw YOLO detections:
    # plt.imshow(res[0].plot())
    # plt.show()
    # plt.clf()
    for detection_id in range(len(boxes)):
        # crop the detected product out of the frame
        x1, y1, x2, y2 = boxes[detection_id]
        product = img[int(y1):int(y2), int(x1):int(x2)]
        embeddings = e.embed(product)
        results = es.search(embeddings)
        if not results:
            logging.warning("No results found for detection %s", detection_id)
            continue
        logging.info(results)
        # show the crop next to each retrieved product image
        _, axs = plt.subplots(1, len(results) + 1)
        axs[0].imshow(cv2.cvtColor(product, cv2.COLOR_BGR2RGB))
        axs[0].set_title("Product")
        for i, search_res in enumerate(results, start=1):
            # hacky, the product ID doubles as the image path
            predicted_product = cv2.imread(search_res.product_id)
            axs[i].imshow(cv2.cvtColor(predicted_product, cv2.COLOR_BGR2RGB))
            axs[i].set_title(f"Predicted: {search_res.similarity} sim")
        plt.show()


def batch_ingest(use_clip: bool = False):
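    """Embed every product image and ingest it into the embeddings store."""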
    embedder = CLIPEmbedder() if use_clip else Embedder(PROJECT_ID)
    embeddings_store = EmbeddingsStore()
    image_paths = get_image_paths()
    logging.info("embedding and ingesting a total of %d images", len(image_paths))
    oks = 0
    # TODO: spawn multiple threads to speed up the process
    # (sequential is fine for now; i am happy to wait 2 hours, weather is nice)
    for image_path in tqdm(image_paths):
        if embeddings_store.exists(image_path):
            logging.info("Skipping %s, already exists", image_path)
            continue
        embeddings = embedder.embed_path(image_path)
        ok = embeddings_store.ingest(
            image_path,
            embeddings,
            {"product_name": image_path},
        )
        if not ok:
            logging.error("Failed to ingest %s", image_path)
            continue
        oks += 1
    logging.info("Batch ingest complete, %d images ingested", oks)


if __name__ == "__main__":
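    # CLI entry point: configure logging, parse the flags, run the selected examples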
    logging.basicConfig(level=logging.INFO)
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--batch-ingest",
        action="store_true",
        help="Run batch ingest",
    )
    parser.add_argument(
        "--use-clip",
        action="store_true",
        help="Use the CLIP embedder instead of the multimodal one",
    )
    parser.add_argument(
        "--search-example",
        action="store_true",
        help="Run search example",
    )
    parser.add_argument(
        "--scene-frame-example",
        action="store_true",
        help="Run scene frame example",
    )
    parser.add_argument(
        "--chroma-dump-path",
        type=str,
        help="Path to the chroma dump",
    )
    args = parser.parse_args()
    if args.batch_ingest:
        batch_ingest(args.use_clip)
    if args.search_example:
        search_example(args.use_clip)
    if args.scene_frame_example:
        scene_frame_example(args.use_clip, args.chroma_dump_path)