forked from dora-rs/dora
-
Notifications
You must be signed in to change notification settings - Fork 0
/
sentence_transformers_op.py
92 lines (76 loc) · 2.72 KB
/
sentence_transformers_op.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
from sentence_transformers import SentenceTransformer
from sentence_transformers import util
from dora import DoraStatus
import os
import sys
import torch
import pyarrow as pa
# Whitelist of bare file names (not full paths) that get_all_functions loads
# into the searchable corpus; every other file is skipped.
SHOULD_BE_INCLUDED = ["webcam.py", "object_detection.py", "plot.py"]
## Get the path and contents of every whitelisted python file in a directory tree
def get_all_functions(path, include=None):
    """Read the whitelisted Python source files found under *path*.

    Despite its name, this does not extract individual functions: it walks
    *path* recursively and returns the full text of every ``.py`` file whose
    bare file name appears in *include*.

    Args:
        path: Root directory to search (walked recursively with ``os.walk``).
        include: Collection of file names to keep. ``None`` (the default)
            means ``SHOULD_BE_INCLUDED``.

    Returns:
        A ``(raw, paths)`` tuple of parallel lists: ``raw[i]`` is the UTF-8
        decoded text of the file at ``paths[i]``. Directories and files are
        visited in sorted order, so the result order is deterministic.

    Side effects:
        Each directory containing a kept file is appended to ``sys.path``
        at most once (presumably so those modules can be imported by name
        elsewhere; confirm with the dataflow).
    """
    if include is None:
        include = SHOULD_BE_INCLUDED
    raw = []
    paths = []
    for root, dirs, files in os.walk(path):
        # Sorting ``dirs`` in place steers os.walk's descent order; together
        # with sorted file names this keeps corpus indices stable across runs.
        dirs.sort()
        for name in sorted(files):
            if not name.endswith(".py") or name not in include:
                continue
            # Use a distinct name so the ``path`` argument is not shadowed.
            file_path = os.path.join(root, name)
            # Avoid piling up duplicate sys.path entries, one per file.
            if root not in sys.path:
                sys.path.append(root)
            with open(file_path, "r", encoding="utf8") as f:
                raw.append(f.read())
            paths.append(file_path)
    return raw, paths
def search(query_embedding, corpus_embeddings, paths, raw, k=5, file_extension=None):
    """Rank corpus entries by cosine similarity to the first query.

    Args:
        query_embedding: Query embedding(s). Only row 0 of the similarity
            matrix is used, i.e. only the first query is ranked.
        corpus_embeddings: Corpus embeddings, parallel to *paths* and *raw*.
        paths: File paths, parallel to *corpus_embeddings*.
        raw: File contents, parallel to *corpus_embeddings*.
        k: Maximum number of hits (capped at the corpus size).
        file_extension: Unused; kept only for interface compatibility.

    Returns:
        A flat list ``[raw_0, path_0, score_0, raw_1, path_1, score_1, ...]``
        ordered from best to worst match. Each score is the element taken
        from the ``torch.topk`` result. The list is empty when there is
        nothing to rank.
    """
    similarities = util.cos_sim(query_embedding, corpus_embeddings)[0]
    best = torch.topk(similarities, k=min(k, len(similarities)), sorted=True)
    return [
        item
        for score, idx in zip(best.values, best.indices)
        for item in (raw[idx], paths[idx], score)
    ]
class Operator:
    """Dora operator that returns the indexed source file closest to a query.

    At start-up it reads the whitelisted Python files under this module's
    directory (via ``get_all_functions``) and embeds each one with a
    SentenceTransformer model. It then reacts to ``INPUT`` events:

    * id ``"query"``: embed the query text and send the best-matching file on
      the ``"raw_file"`` output as ``{"raw", "path", "user_message"}``.
    * any other id: treat ``value[0]`` as ``{"path", "raw"}`` for an updated
      file and refresh the stored text and embedding for that path.
    """

    def __init__(self):
        # TODO: Add an initialisation step
        self.model = SentenceTransformer("BAAI/bge-large-en-v1.5")
        # Index the files that live alongside this module.
        here = os.path.dirname(os.path.abspath(__file__))
        self.raw, self.path = get_all_functions(here)
        # self.encoding[i] is the embedding of self.raw[i] (file self.path[i]).
        self.encoding = self.model.encode(self.raw)

    def on_event(
        self,
        dora_event,
        send_output,
    ) -> DoraStatus:
        """Dispatch one dora event; always returns ``DoraStatus.CONTINUE``."""
        if dora_event["type"] != "INPUT":
            return DoraStatus.CONTINUE
        if dora_event["id"] == "query":
            self._answer_query(dora_event, send_output)
        else:
            self._update_file(dora_event)
        return DoraStatus.CONTINUE

    def _answer_query(self, dora_event, send_output):
        """Send the best-matching indexed file for the query in *dora_event*."""
        values = dora_event["value"].to_pylist()
        if not values:
            # values[0] below needs at least one query; nothing to answer.
            return
        query_embeddings = self.model.encode(values)
        output = search(
            query_embeddings,
            self.encoding,
            self.path,
            self.raw,
        )
        if not output:
            # search() returns [] when there is nothing to rank.
            print("sentence_transformers_op: no indexed files to search", file=sys.stderr)
            return
        # search() returns a flat [raw, path, score, ...] list; keep the top hit.
        raw, path, _score = output[:3]
        send_output(
            "raw_file",
            pa.array([{"raw": raw, "path": path, "user_message": values[0]}]),
            dora_event["metadata"],
        )

    def _update_file(self, dora_event):
        """Replace the stored text and embedding of the file named in the event."""
        update = dora_event["value"][0].as_py()
        try:
            index = self.path.index(update["path"])
        except ValueError:
            # An unknown path would otherwise crash the whole operator.
            print(
                f"sentence_transformers_op: ignoring update for unknown path {update['path']!r}",
                file=sys.stderr,
            )
            return
        self.raw[index] = update["raw"]
        self.encoding[index] = self.model.encode([update["raw"]])[0]
if __name__ == "__main__":
    # Constructing the operator outside a dataflow loads the model and
    # embeds the whitelisted files, which serves as a quick smoke test.
    operator = Operator()