Skip to content

Commit acb8e55

Browse files
committed
update test db url and redis client
1 parent 2e77577 commit acb8e55

File tree

3 files changed

+26 additions, -19 deletions (lines changed)

server/__init__.py

+8
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
from typing import Type, cast
44

55
import numpy
6+
import redis
67
from apiflask import APIFlask
78
from flask import redirect, render_template
89
from flask_cors import CORS
@@ -59,6 +60,8 @@ class ProperlyTypedSQLAlchemy(SQLAlchemy):
5960
db = SQLAlchemy(model_class=Base)
6061
db = cast(ProperlyTypedSQLAlchemy, db)
6162

63+
redis_client: redis.Redis | None = None
64+
6265

6366
def create_app():
6467
"""Create the Flask app."""
@@ -78,6 +81,11 @@ def create_app():
7881
with app.app_context():
7982
db.init_app(app)
8083

84+
global redis_client
85+
redis_client = redis.Redis(
86+
host=app.config["REDIS_HOST"], port=6379, decode_responses=True
87+
)
88+
8189
allowed_domains = app.config.get("ALLOWED_DOMAINS")
8290

8391
cors.init_app(

server/nlp/embeddings.py

+14 -15
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,6 @@
88

99
import numpy as np
1010
import openai
11-
import redis
1211
from redis.commands.search.field import (
1312
NumericField,
1413
TextField,
@@ -17,15 +16,15 @@
1716
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
1817
from redis.commands.search.query import Query
1918

20-
from server.config import REDIS_HOST, RedisDocument
19+
from server import redis_client
20+
from server.config import RedisDocument
21+
22+
assert redis_client is not None
2123

2224
cwd = os.path.dirname(__file__)
2325

2426
VECTOR_DIMENSION = 1536
2527

26-
# load redis client
27-
client = redis.Redis(host=REDIS_HOST, port=6379, decode_responses=True)
28-
2928
# load corpus
3029
# with open('corpus.json', 'r') as f:
3130
# corpus = json.load(f)
@@ -46,7 +45,7 @@ def load_corpus(corpus: list[RedisDocument]):
4645
"""
4746
print("loading corpus...")
4847

49-
pipeline = client.pipeline()
48+
pipeline = redis_client.pipeline()
5049
for i, doc in enumerate(corpus, start=1):
5150
redis_key = f"documents:{i:03}"
5251
pipeline.json().set(redis_key, "$", doc)
@@ -81,9 +80,9 @@ def compute_embeddings():
8180
print("computing embeddings...")
8281

8382
# get keys, questions, content
84-
keys = sorted(client.keys("documents:*")) # type: ignore
85-
questions = client.json().mget(keys, "$.question")
86-
content = client.json().mget(keys, "$.content")
83+
keys = sorted(redis_client.keys("documents:*")) # type: ignore
84+
questions = redis_client.json().mget(keys, "$.question")
85+
content = redis_client.json().mget(keys, "$.content")
8786

8887
# compute embeddings
8988
question_and_content = [
@@ -110,7 +109,7 @@ def load_embeddings(embeddings: list[list[float]]):
110109
print("loading embeddings into redis...")
111110

112111
# load embeddings into redis
113-
pipeline = client.pipeline()
112+
pipeline = redis_client.pipeline()
114113
for i, embedding in enumerate(embeddings, start=1):
115114
redis_key = f"documents:{i:03}"
116115
pipeline.json().set(redis_key, "$.question_and_content_embeddings", embedding)
@@ -153,17 +152,17 @@ def create_index(corpus_len: int):
153152
),
154153
)
155154
definition = IndexDefinition(prefix=["documents:"], index_type=IndexType.JSON)
156-
res = client.ft("idx:documents_vss").create_index(
155+
res = redis_client.ft("idx:documents_vss").create_index(
157156
fields=schema, definition=definition
158157
)
159158

160159
if res == "OK":
161160
start = time.time()
162161
while 1:
163-
if str(client.ft("idx:documents_vss").info()["num_docs"]) == str(
162+
if str(redis_client.ft("idx:documents_vss").info()["num_docs"]) == str(
164163
corpus_len
165164
):
166-
info = client.ft("idx:documents_vss").info()
165+
info = redis_client.ft("idx:documents_vss").info()
167166
num_docs = info["num_docs"]
168167
indexing_failures = info["hash_indexing_failures"]
169168
print("num_docs", num_docs, "indexing_failures", indexing_failures)
@@ -209,7 +208,7 @@ def queries(query, queries: list[str]) -> list[dict]:
209208
results_list = []
210209
for i, encoded_query in enumerate(encoded_queries):
211210
result_docs = (
212-
client.ft("idx:documents_vss")
211+
redis_client.ft("idx:documents_vss")
213212
.search(
214213
query,
215214
{"query_vector": np.array(encoded_query, dtype=np.float32).tobytes()},
@@ -259,7 +258,7 @@ def embed_corpus(corpus: list[RedisDocument]):
259258
"""
260259
# flush database
261260
print("cleaning database...")
262-
client.flushdb()
261+
redis_client.flushdb()
263262
print("done cleaning database")
264263

265264
# embed corpus

server_tests/conftest.py

+4 -4
Original file line number | Diff line number | Diff line change
@@ -40,11 +40,11 @@ def db_url(db_name="pigeondb_test"):
4040

4141
conn.close()
4242

43-
yield "postgresql://postgres:password@database/pigeondb_test"
43+
yield f"postgresql://postgres:password@{host}/{db_name}"
4444

4545

4646
@pytest.fixture(scope="session")
47-
def redis_db_index():
47+
def redis_host():
4848
"""Yields test redis db host.
4949
5050
Flushes test db if it already exists.
@@ -58,9 +58,9 @@ def redis_db_index():
5858

5959

6060
@pytest.fixture(scope="session")
61-
def app(db_url: str, redis_db_index: str):
61+
def app(db_url: str, redis_host: str):
6262
os.environ["DATABASE_URL"] = db_url
63-
os.environ["REDIS_HOST"] = redis_db_index
63+
os.environ["REDIS_HOST"] = redis_host
6464

6565
app = create_app()
6666
app.config.update(

0 commit comments

Comments
 (0)