Skip to content

Commit

Permalink
Add filters to the assignments API
Browse files Browse the repository at this point in the history
  • Loading branch information
marcospri committed Jul 24, 2024
1 parent e0e3c0c commit 028562a
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 19 deletions.
10 changes: 5 additions & 5 deletions lms/services/assignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,13 +214,13 @@ def is_member(self, assignment: Assignment, h_userid: str) -> bool:
def get_assignments(
self,
instructor_h_userid: str | None = None,
course_id: int | None = None,
course_ids: list[int] | None = None,
h_userids: list[str] | None = None,
) -> Select[tuple[Assignment]]:
"""Get a query to fetch assignments.
:param instructor_h_userid: return only assignments where instructor_h_userid is an instructor.
:param course_id: only return assignments that belong to this course.
:param course_ids: only return assignments that belong to this course.
:param h_userids: return only assignments where these users are members.
"""

Expand All @@ -247,15 +247,15 @@ def get_assignments(
.where(User.h_userid.in_(h_userids))
)

if course_id:
if course_ids:
deduplicated_course_assignments = (
self._deduplicated_course_assigments_query([course_id]).subquery()
self._deduplicated_course_assigments_query(course_ids).subquery()
)

query = query.where(
# Get only assignment from the candidates above
Assignment.id == deduplicated_course_assignments.c.assignment_id,
deduplicated_course_assignments.c.grouping_id == course_id,
deduplicated_course_assignments.c.grouping_id.in_(course_ids),
)

return query.order_by(Assignment.title, Assignment.id).distinct()
Expand Down
15 changes: 11 additions & 4 deletions lms/views/dashboard/api/assignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,13 @@
class ListAssignmentsSchema(PaginationParametersMixin):
"""Query parameters to fetch a list of assignments."""

course_id = fields.Integer(required=False, validate=validate.Range(min=1))
"""Return assignments that belong to the course with this ID."""
course_ids = fields.List(
fields.Integer(validate=validate.Range(min=1)), data_key="course_id"
)
"""Return users that belong to these course IDs."""

h_userids = fields.List(fields.Str(), data_key="h_userid")
"""Return metrics for these users only."""


class AssignmentsMetricsSchema(PyramidRequestSchema):
Expand Down Expand Up @@ -51,11 +56,13 @@ def __init__(self, request) -> None:
schema=ListAssignmentsSchema,
)
def assignments(self) -> APIAssignments:
filter_by_h_userids = self.request.parsed_params.get("h_userids")
assignments = self.assignment_service.get_assignments(
instructor_h_userid=self.request.user.h_userid
if self.request.user
else None,
course_id=self.request.parsed_params.get("course_id"),
course_ids=self.request.parsed_params.get("course_ids"),
h_userids=filter_by_h_userids,
)
assignments, pagination = get_page(
self.request, assignments, [Assignment.title, Assignment.id]
Expand Down Expand Up @@ -104,7 +111,7 @@ def course_assignments_metrics(self) -> APIAssignments:
)
).all()
assignments_query = self.assignment_service.get_assignments(
course_id=course.id,
course_ids=[course.id],
instructor_h_userid=current_h_userid,
h_userids=filter_by_h_userids,
)
Expand Down
16 changes: 8 additions & 8 deletions tests/unit/lms/services/assignment_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def test_is_member(self, svc, db_session):
assert not svc.is_member(assignment, other_user.h_userid)

@pytest.mark.parametrize("instructor_h_userid", [True, False])
@pytest.mark.parametrize("course_id", [True, False])
@pytest.mark.parametrize("course_ids", [True, False])
@pytest.mark.parametrize("h_userids", [True, False])
def test_get_assignments(
self,
Expand All @@ -249,7 +249,7 @@ def test_get_assignments(
instructor_h_userid,
assignment,
with_assignment_noise,
course_id,
course_ids,
h_userids,
):
factories.User()
Expand All @@ -270,8 +270,8 @@ def test_get_assignments(
if instructor_h_userid:
query_parameters["instructor_h_userid"] = user.h_userid

if course_id:
query_parameters["course_id"] = course.id
if course_ids:
query_parameters["course_ids"] = [course.id]

if h_userids:
query_parameters["h_userids"] = [user.h_userid]
Expand Down Expand Up @@ -301,12 +301,12 @@ def test_get_assignments_by_course_id_with_duplicate(self, db_session, svc):
)
db_session.flush()

assert db_session.scalars(svc.get_assignments(course_id=course.id)).all() == [
assignment
]
assert db_session.scalars(
svc.get_assignments(course_ids=[course.id])
).all() == [assignment]
# We don't expect to get the other one at all, now the assignment belongs to the most recent course
assert not db_session.scalars(
svc.get_assignments(course_id=other_course.id)
svc.get_assignments(course_ids=[other_course.id])
).all()

def test_get_courses_assignments_count(self, svc, db_session):
Expand Down
8 changes: 6 additions & 2 deletions tests/unit/lms/views/dashboard/api/assignment_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,19 @@ class TestAssignmentViews:
def test_get_assignments(
self, assignment_service, pyramid_request, views, get_page
):
pyramid_request.parsed_params = {"course_id": sentinel.course_id}
pyramid_request.parsed_params = {
"course_ids": sentinel.course_ids,
"h_userids": sentinel.h_userids,
}
assignments = factories.Assignment.create_batch(5)
get_page.return_value = assignments, sentinel.pagination

response = views.assignments()

assignment_service.get_assignments.assert_called_once_with(
instructor_h_userid=pyramid_request.user.h_userid,
course_id=sentinel.course_id,
course_ids=sentinel.course_ids,
h_userids=sentinel.h_userids,
)
get_page.assert_called_once_with(
pyramid_request,
Expand Down

0 comments on commit 028562a

Please sign in to comment.