8
8
9
9
import numpy as np
10
10
import openai
11
- import redis
12
11
from redis .commands .search .field import (
13
12
NumericField ,
14
13
TextField ,
17
16
from redis .commands .search .indexDefinition import IndexDefinition , IndexType
18
17
from redis .commands .search .query import Query
19
18
20
- from server .config import REDIS_HOST , RedisDocument
19
+ from server import redis_client
20
+ from server .config import RedisDocument
21
+
22
+ assert redis_client is not None
21
23
22
24
cwd = os .path .dirname (__file__ )
23
25
24
26
VECTOR_DIMENSION = 1536
25
27
26
- # load redis client
27
- client = redis .Redis (host = REDIS_HOST , port = 6379 , decode_responses = True )
28
-
29
28
# load corpus
30
29
# with open('corpus.json', 'r') as f:
31
30
# corpus = json.load(f)
@@ -46,7 +45,7 @@ def load_corpus(corpus: list[RedisDocument]):
46
45
"""
47
46
print ("loading corpus..." )
48
47
49
- pipeline = client .pipeline ()
48
+ pipeline = redis_client .pipeline ()
50
49
for i , doc in enumerate (corpus , start = 1 ):
51
50
redis_key = f"documents:{ i :03} "
52
51
pipeline .json ().set (redis_key , "$" , doc )
@@ -81,9 +80,9 @@ def compute_embeddings():
81
80
print ("computing embeddings..." )
82
81
83
82
# get keys, questions, content
84
- keys = sorted (client .keys ("documents:*" )) # type: ignore
85
- questions = client .json ().mget (keys , "$.question" )
86
- content = client .json ().mget (keys , "$.content" )
83
+ keys = sorted (redis_client .keys ("documents:*" )) # type: ignore
84
+ questions = redis_client .json ().mget (keys , "$.question" )
85
+ content = redis_client .json ().mget (keys , "$.content" )
87
86
88
87
# compute embeddings
89
88
question_and_content = [
@@ -110,7 +109,7 @@ def load_embeddings(embeddings: list[list[float]]):
110
109
print ("loading embeddings into redis..." )
111
110
112
111
# load embeddings into redis
113
- pipeline = client .pipeline ()
112
+ pipeline = redis_client .pipeline ()
114
113
for i , embedding in enumerate (embeddings , start = 1 ):
115
114
redis_key = f"documents:{ i :03} "
116
115
pipeline .json ().set (redis_key , "$.question_and_content_embeddings" , embedding )
@@ -153,17 +152,17 @@ def create_index(corpus_len: int):
153
152
),
154
153
)
155
154
definition = IndexDefinition (prefix = ["documents:" ], index_type = IndexType .JSON )
156
- res = client .ft ("idx:documents_vss" ).create_index (
155
+ res = redis_client .ft ("idx:documents_vss" ).create_index (
157
156
fields = schema , definition = definition
158
157
)
159
158
160
159
if res == "OK" :
161
160
start = time .time ()
162
161
while 1 :
163
- if str (client .ft ("idx:documents_vss" ).info ()["num_docs" ]) == str (
162
+ if str (redis_client .ft ("idx:documents_vss" ).info ()["num_docs" ]) == str (
164
163
corpus_len
165
164
):
166
- info = client .ft ("idx:documents_vss" ).info ()
165
+ info = redis_client .ft ("idx:documents_vss" ).info ()
167
166
num_docs = info ["num_docs" ]
168
167
indexing_failures = info ["hash_indexing_failures" ]
169
168
print ("num_docs" , num_docs , "indexing_failures" , indexing_failures )
@@ -209,7 +208,7 @@ def queries(query, queries: list[str]) -> list[dict]:
209
208
results_list = []
210
209
for i , encoded_query in enumerate (encoded_queries ):
211
210
result_docs = (
212
- client .ft ("idx:documents_vss" )
211
+ redis_client .ft ("idx:documents_vss" )
213
212
.search (
214
213
query ,
215
214
{"query_vector" : np .array (encoded_query , dtype = np .float32 ).tobytes ()},
@@ -259,7 +258,7 @@ def embed_corpus(corpus: list[RedisDocument]):
259
258
"""
260
259
# flush database
261
260
print ("cleaning database..." )
262
- client .flushdb ()
261
+ redis_client .flushdb ()
263
262
print ("done cleaning database" )
264
263
265
264
# embed corpus
0 commit comments