Skip to content

Commit 082dffc

Browse files
committed
begin setting up more tests
1 parent 367562f commit 082dffc

File tree

4 files changed

+65
-21
lines changed

4 files changed

+65
-21
lines changed

server/controllers/emails.py

+23-19
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,7 @@ def get_response():
442442
response = db.session.execute(
443443
select(Response).where(Response.email_id == data["id"])
444444
).scalar()
445+
445446
if not response:
446447
return {"message": "Response not found"}, 400
447448
return response.map()
@@ -524,23 +525,26 @@ def get_threads():
524525
525526
Get a list of all threads.
526527
"""
527-
thread_list = db.session.execute(select(Thread)).all()
528-
print("thread list", thread_list)
529-
# thread_list = db.session.execute(
530-
# select(Thread).order_by(Thread.resolved, Thread.last_email.desc())
531-
# ).all()
532-
# email_list = [
533-
# {
534-
# "id": thread.id,
535-
# "resolved": thread.resolved,
536-
# "emailList": [
537-
# thread_email.map()
538-
# for thread_email in db.session.execute(
539-
# select(Email).where(Email.thread_id == thread.id)
540-
# ).all()
541-
# ],
542-
# }
543-
# for thread in thread_list
544-
# ]
545-
email_list = []
528+
thread_list = (
529+
db.session.execute(
530+
select(Thread).order_by(Thread.resolved, Thread.last_email.desc())
531+
)
532+
.scalars()
533+
.all()
534+
)
535+
email_list = [
536+
{
537+
"id": thread.id,
538+
"resolved": thread.resolved,
539+
"emailList": [
540+
thread_email.map()
541+
for thread_email in db.session.execute(
542+
select(Email).where(Email.thread_id == thread.id)
543+
)
544+
.scalars()
545+
.all()
546+
],
547+
}
548+
for thread in thread_list
549+
]
546550
return email_list

server/models/response.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,25 @@ class Response(db.Model):
5454
)
5555

5656
def map(self):
57-
"""Map the response to a dictionary."""
57+
"""Map the response to a dictionary.
58+
59+
Groups documents and document_confidences into a list of lists clustered by
60+
question.
61+
"""
62+
doc_confs = []
63+
docs = []
64+
cur_idx = 0
65+
for num_docs in self.docs_per_question:
66+
doc_confs.append(self.document_confidences[cur_idx : cur_idx + num_docs])
67+
docs.append(self.documents[cur_idx : cur_idx + num_docs])
68+
cur_idx += num_docs
69+
docs = [[doc.map() for doc in doc_list] for doc_list in docs]
5870
return {
5971
"id": self.id,
6072
"content": self.response,
6173
"questions": self.questions,
62-
"document_confidences": self.document_confidences,
74+
"documents": docs,
75+
"document_confidences": doc_confs,
6376
"confidence": self.confidence,
6477
"emailId": self.email_id,
6578
}

server_tests/test_email.py

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from apiflask import APIFlask
2+
from flask.testing import FlaskClient
3+
4+
from server_tests.utils import assert_status
5+
6+
7+
def test_get_threads(app: APIFlask, client: FlaskClient):
8+
"""Test fetching threads."""
9+
response = client.get("/api/emails/get_threads")
10+
assert_status(response, 200)

server_tests/utils.py

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
"""Utils for testing."""
2+
3+
import logging
4+
5+
from werkzeug.test import TestResponse
6+
7+
8+
def assert_status(response: TestResponse, status: int):
9+
"""Asserts a response's status code, logging the response if it fails."""
10+
try:
11+
assert response.status_code == status
12+
except AssertionError:
13+
logging.error(
14+
f"Assertion failed: expected {status}, got {response.status_code}. "
15+
f"Response body: {response.data.decode()}"
16+
)
17+
raise

0 commit comments

Comments
 (0)