From 028562a97a796a478fdea7f3fbe2836c1ab36f2f Mon Sep 17 00:00:00 2001 From: Marcos Prieto Date: Wed, 24 Jul 2024 11:08:59 +0200 Subject: [PATCH] Add filters to the assignments API --- lms/services/assignment.py | 10 +++++----- lms/views/dashboard/api/assignment.py | 15 +++++++++++---- tests/unit/lms/services/assignment_test.py | 16 ++++++++-------- .../lms/views/dashboard/api/assignment_test.py | 8 ++++++-- 4 files changed, 30 insertions(+), 19 deletions(-) diff --git a/lms/services/assignment.py b/lms/services/assignment.py index 7d2482268d..b0f40b3a5b 100644 --- a/lms/services/assignment.py +++ b/lms/services/assignment.py @@ -214,13 +214,13 @@ def is_member(self, assignment: Assignment, h_userid: str) -> bool: def get_assignments( self, instructor_h_userid: str | None = None, - course_id: int | None = None, + course_ids: list[int] | None = None, h_userids: list[str] | None = None, ) -> Select[tuple[Assignment]]: """Get a query to fetch assignments. :param instructor_h_userid: return only assignments where instructor_h_userid is an instructor. - :param course_id: only return assignments that belong to this course. + :param course_ids: only return assignments that belong to this course. :param h_userids: return only assignments where these users are members. """ @@ -247,15 +247,15 @@ def get_assignments( .where(User.h_userid.in_(h_userids)) ) - if course_id: + if course_ids: deduplicated_course_assignments = ( - self._deduplicated_course_assigments_query([course_id]).subquery() + self._deduplicated_course_assigments_query(course_ids).subquery() ) query = query.where( # Get only assignment from the candidates above Assignment.id == deduplicated_course_assignments.c.assignment_id, - deduplicated_course_assignments.c.grouping_id == course_id, + deduplicated_course_assignments.c.grouping_id.in_(course_ids), ) return query.order_by(Assignment.title, Assignment.id).distinct() diff --git a/lms/views/dashboard/api/assignment.py b/lms/views/dashboard/api/assignment.py index 991182d18a..7e3f12f7d5 100644 --- a/lms/views/dashboard/api/assignment.py +++ b/lms/views/dashboard/api/assignment.py @@ -18,8 +18,13 @@ class ListAssignmentsSchema(PaginationParametersMixin): """Query parameters to fetch a list of assignments.""" - course_id = fields.Integer(required=False, validate=validate.Range(min=1)) - """Return assignments that belong to the course with this ID.""" + course_ids = fields.List( + fields.Integer(validate=validate.Range(min=1)), data_key="course_id" + ) + """Return users that belong to these course IDs.""" + + h_userids = fields.List(fields.Str(), data_key="h_userid") + """Return metrics for these users only.""" class AssignmentsMetricsSchema(PyramidRequestSchema): @@ -51,11 +56,13 @@ def __init__(self, request) -> None: schema=ListAssignmentsSchema, ) def assignments(self) -> APIAssignments: + filter_by_h_userids = self.request.parsed_params.get("h_userids") assignments = self.assignment_service.get_assignments( instructor_h_userid=self.request.user.h_userid if self.request.user else None, - course_id=self.request.parsed_params.get("course_id"), + course_ids=self.request.parsed_params.get("course_ids"), + h_userids=filter_by_h_userids, ) assignments, pagination = get_page( self.request, assignments, [Assignment.title, Assignment.id] @@ -104,7 +111,7 @@ def course_assignments_metrics(self) -> APIAssignments: ) ).all() assignments_query = self.assignment_service.get_assignments( - course_id=course.id, + course_ids=[course.id], instructor_h_userid=current_h_userid, h_userids=filter_by_h_userids, ) diff --git a/tests/unit/lms/services/assignment_test.py b/tests/unit/lms/services/assignment_test.py index 5561383c79..f7ddb6d340 100644 --- a/tests/unit/lms/services/assignment_test.py +++ b/tests/unit/lms/services/assignment_test.py @@ -240,7 +240,7 @@ def test_is_member(self, svc, db_session): assert not svc.is_member(assignment, other_user.h_userid) @pytest.mark.parametrize("instructor_h_userid", [True, False]) - @pytest.mark.parametrize("course_id", [True, False]) + @pytest.mark.parametrize("course_ids", [True, False]) @pytest.mark.parametrize("h_userids", [True, False]) def test_get_assignments( self, @@ -249,7 +249,7 @@ def test_get_assignments( instructor_h_userid, assignment, with_assignment_noise, - course_id, + course_ids, h_userids, ): factories.User() @@ -270,8 +270,8 @@ def test_get_assignments( if instructor_h_userid: query_parameters["instructor_h_userid"] = user.h_userid - if course_id: - query_parameters["course_id"] = course.id + if course_ids: + query_parameters["course_ids"] = [course.id] if h_userids: query_parameters["h_userids"] = [user.h_userid] @@ -301,12 +301,12 @@ def test_get_assignments_by_course_id_with_duplicate(self, db_session, svc): ) db_session.flush() - assert db_session.scalars(svc.get_assignments(course_id=course.id)).all() == [ - assignment - ] + assert db_session.scalars( + svc.get_assignments(course_ids=[course.id]) + ).all() == [assignment] # We don't expect to get the other one at all, now the assignment belongs to the most recent course assert not db_session.scalars( - svc.get_assignments(course_id=other_course.id) + svc.get_assignments(course_ids=[other_course.id]) ).all() def test_get_courses_assignments_count(self, svc, db_session): diff --git a/tests/unit/lms/views/dashboard/api/assignment_test.py b/tests/unit/lms/views/dashboard/api/assignment_test.py index a146fe8704..770f0afca7 100644 --- a/tests/unit/lms/views/dashboard/api/assignment_test.py +++ b/tests/unit/lms/views/dashboard/api/assignment_test.py @@ -16,7 +16,10 @@ class TestAssignmentViews: def test_get_assignments( self, assignment_service, pyramid_request, views, get_page ): - pyramid_request.parsed_params = {"course_id": sentinel.course_id} + pyramid_request.parsed_params = { + "course_ids": sentinel.course_ids, + "h_userids": sentinel.h_userids, + } assignments = factories.Assignment.create_batch(5) get_page.return_value = assignments, sentinel.pagination @@ -24,7 +27,8 @@ def test_get_assignments( assignment_service.get_assignments.assert_called_once_with( instructor_h_userid=pyramid_request.user.h_userid, - course_id=sentinel.course_id, + course_ids=sentinel.course_ids, + h_userids=sentinel.h_userids, ) get_page.assert_called_once_with( pyramid_request,