From 38f155e8453263c17ee23d94f475bbded59f131a Mon Sep 17 00:00:00 2001 From: Marcos Prieto Date: Mon, 24 Jun 2024 11:29:08 +0200 Subject: [PATCH] Courses endpoint with pagination API endpoint to drive the course filter/drop-down in the instructor dashboard. It exposes a list of courses the current user has access to with pagination. Responses include a pagination object with a next value. --- lms/js_config_types.py | 6 ++ lms/routes.py | 2 + lms/services/course.py | 36 +++++--- lms/views/dashboard/api/course.py | 89 +++++++++++++++++- tests/unit/lms/services/course_test.py | 4 +- .../lms/views/dashboard/api/course_test.py | 92 ++++++++++++++++++- 6 files changed, 209 insertions(+), 20 deletions(-) diff --git a/lms/js_config_types.py b/lms/js_config_types.py index 2ec04b3658..fc9118d095 100644 --- a/lms/js_config_types.py +++ b/lms/js_config_types.py @@ -7,6 +7,10 @@ from typing import NotRequired, TypedDict +class Pagination(TypedDict): + next: NotRequired[str] + + class AnnotationMetrics(TypedDict): annotations: int replies: int @@ -40,6 +44,8 @@ class APIStudent(TypedDict): class APICourses(TypedDict): courses: list[APICourse] + pagination: NotRequired[Pagination] + class APIAssignment(TypedDict): id: int diff --git a/lms/routes.py b/lms/routes.py index 9af59d407e..f02a6271b1 100644 --- a/lms/routes.py +++ b/lms/routes.py @@ -258,6 +258,8 @@ def includeme(config): # noqa: PLR0915 factory="lms.resources.dashboard.DashboardResource", ) + config.add_route("api.dashboard.courses", "/api/dashboard/courses") + config.add_route( "api.dashboard.organizations.courses", "/api/dashboard/organizations/{organization_public_id}/courses", diff --git a/lms/services/course.py b/lms/services/course.py index 026a07fd10..a9c0927c39 100644 --- a/lms/services/course.py +++ b/lms/services/course.py @@ -140,19 +140,33 @@ def search( # noqa: PLR0913, PLR0917 h_userid=h_userid, ).all() - def get_organization_courses( - self, organization: Organization, h_userid: str | None + def get_courses( + self, + h_userid: str | None, + organization: Organization | None = None, ): - courses_query = self._search_query( - organization_ids=[organization.id], - h_userid=h_userid, - limit=None, - ) - return ( + """Get a list of unique courses. + + :param organization: organization the courses belong to. + :param h_userid: only courses this user has access to. + """ + courses_query = ( + self._search_query( + organization_ids=[organization.id] if organization else None, + h_userid=h_userid, + limit=None, + ) # Deduplicate courses by authority_provided_id, take the last updated one - courses_query.distinct(Course.authority_provided_id) - .order_by(Course.authority_provided_id, Course.updated.desc()) - .all() + .distinct(Course.authority_provided_id).order_by( + Course.authority_provided_id, Course.updated.desc() + ) + # Only select the ID of the deduplicated courses + ).with_entities(Course.id) + + return ( + self._db.query(Course).filter(Course.id.in_(courses_query)) + # We can sort these again without affecting deduplication + .order_by(Course.lms_name.desc(), Course.id.desc()) ) def _deduplicated_course_assigments_query(self, courses: list[Course]): diff --git a/lms/views/dashboard/api/course.py b/lms/views/dashboard/api/course.py index 6c6dc31bcc..1414b6d0ed 100644 --- a/lms/views/dashboard/api/course.py +++ b/lms/views/dashboard/api/course.py @@ -8,11 +8,66 @@ APICourses, CourseMetrics, ) -from lms.models import RoleScope, RoleType +import json +import base64 +from lms.models import Course, RoleScope, RoleType from lms.security import Permissions from lms.services.h_api import HAPI from lms.services.organization import OrganizationService +from lms.validation._base import PyramidRequestSchema from lms.views.dashboard.base import get_request_course, get_request_organization +from marshmallow import fields, ValidationError, post_load + +from lms.js_config_types import Pagination + +PAGINATION_LIMIT = 100 +"""Maximum number of items to return in paginated endpoints""" + + +def build_courses_pagination(request, courses: list[Course]) -> Pagination | None: + """Build the pagination details for a list of courses. + + For the next parameter, we'll sort courses by both name and id, expose both as a tuple. + """ + if not courses: + return None + + last_element = courses[-1] + + cursor_data = (last_element.lms_name, last_element.id) + coursor_value = ( + base64.urlsafe_b64encode(json.dumps(cursor_data).encode("utf-8")) + .decode("utf-8") + .replace("=", "") + ) + return Pagination( + next=request.route_url( + "api.dashboard.courses", _query={"cursor": coursor_value} + ) + ) + + +class ListCoursesSchema(PyramidRequestSchema): + location = "query" + + limit = fields.Integer(required=False, load_default=PAGINATION_LIMIT) + """Maximum number of items to return.""" + + cursor = fields.Str() + """Position to return the elements from. This correponds to the `next` value in the repsonse.""" + + @post_load + def decode_cursor(self, in_data, **kwargs): + cursor = in_data.get("cursor") + if not cursor: + return in_data + + try: + in_data["cursor"] = json.loads(base64.urlsafe_b64decode(cursor + "===")) + except ValueError: + raise ValidationError("Invalid value for pagination cursor.") + + return in_data class CourseViews: @@ -22,6 +77,33 @@ def __init__(self, request) -> None: self.h_api = request.find_service(HAPI) self.organization_service = request.find_service(OrganizationService) + @view_config( + route_name="api.dashboard.courses", + request_method="GET", + renderer="json", + permission=Permissions.DASHBOARD_VIEW, + schema=ListCoursesSchema, + ) + def courses(self) -> APICourses: + courses = self.course_service.get_courses( + h_userid=self.request.user.h_userid if self.request.user else None, + ) + + limit = min(PAGINATION_LIMIT, self.request.parsed_params["limit"]) + if cursor_values := self.request.parsed_params.get("cursor"): + cursor_course_name, cursor_course_id = cursor_values + courses = courses.filter( + (Course.lms_name, Course.id) > (cursor_course_name, cursor_course_id) + ) + + courses = courses.limit(limit).all() + return { + "courses": [ + APICourse(id=course.id, title=course.lms_name) for course in courses + ], + "pagination": build_courses_pagination(self.request, courses), + } + @view_config( route_name="api.dashboard.organizations.courses", request_method="GET", @@ -30,8 +112,9 @@ def __init__(self, request) -> None: ) def organization_courses(self) -> APICourses: org = get_request_organization(self.request, self.organization_service) - courses = self.course_service.get_organization_courses( - org, h_userid=self.request.user.h_userid if self.request.user else None + courses = self.course_service.get_courses( + organization=org, + h_userid=self.request.user.h_userid if self.request.user else None, ) courses_assignment_counts = self.course_service.get_courses_assignments_count( courses diff --git a/tests/unit/lms/services/course_test.py b/tests/unit/lms/services/course_test.py index 1f5150499a..729b133fd9 100644 --- a/tests/unit/lms/services/course_test.py +++ b/tests/unit/lms/services/course_test.py @@ -366,7 +366,7 @@ def test_get_members(self, svc, db_session): course, role_scope=lti_role.scope, role_type=lti_role.type ) == [user] - def test_get_organization_courses_deduplicates(self, db_session, svc): + def test_get_courses_deduplicates(self, db_session, svc): org = factories.Organization() ai = factories.ApplicationInstance(organization=org) @@ -388,7 +388,7 @@ def test_get_organization_courses_deduplicates(self, db_session, svc): assert set(svc.search(organization_ids=[org.id])) == {course, older_course} # But organization deduplicate, We only get the most recent course - assert svc.get_organization_courses(org, None) == [course] + assert svc.get_courses(organization=org, h_userid=None).all() == [course] def test_get_assignments(self, db_session, svc): course = factories.Course() diff --git a/tests/unit/lms/views/dashboard/api/course_test.py b/tests/unit/lms/views/dashboard/api/course_test.py index 18a27bf235..29076f7327 100644 --- a/tests/unit/lms/views/dashboard/api/course_test.py +++ b/tests/unit/lms/views/dashboard/api/course_test.py @@ -1,21 +1,82 @@ +import base64 + +from lms.validation import ValidationError +import json from unittest.mock import sentinel +from h_matchers import Any import pytest -from lms.views.dashboard.api.course import CourseViews +from lms.views.dashboard.api.course import ( + CourseViews, + build_courses_pagination, + ListCoursesSchema, +) from tests import factories +from lms.models import Course pytestmark = pytest.mark.usefixtures("course_service", "h_api", "organization_service") class TestCourseViews: + def test_build_courses_pagination_empty_result(self, pyramid_request): + assert not build_courses_pagination(pyramid_request, []) + + def test_get_courses(self, course_service, pyramid_request, views, db_session): + pyramid_request.parsed_params = {"limit": 100} + courses = factories.Course.create_batch(5) + course_service.get_courses.return_value = db_session.query(Course) + db_session.flush() + + response = views.courses() + + course_service.get_courses.assert_called_once_with( + h_userid=pyramid_request.user.h_userid, + ) + assert response == { + "courses": [{"id": c.id, "title": c.lms_name} for c in courses], + "pagination": { + "next": Any.url.with_path("/api/dashboard/courses").with_query( + {"cursor": Any.string()} + ) + }, + } + + def test_get_courses_with_cursor( + self, course_service, pyramid_request, views, db_session + ): + courses = sorted(factories.Course.create_batch(10), key=lambda c: c.lms_name) + db_session.flush() + course_service.get_courses.return_value = db_session.query(Course).order_by( + Course.lms_name, Course.id + ) + + pyramid_request.parsed_params = { + "limit": 1, + "cursor": (courses[4].lms_name, courses[4].id), + } + + response = views.courses() + + course_service.get_courses.assert_called_once_with( + h_userid=pyramid_request.user.h_userid, + ) + assert response == { + "courses": [{"id": c.id, "title": c.lms_name} for c in courses[5:6]], + "pagination": { + "next": Any.url.with_path("/api/dashboard/courses").with_query( + {"cursor": Any.string()} + ) + }, + } + def test_get_organization_courses( self, course_service, organization_service, pyramid_request, views, db_session ): org = factories.Organization() courses = factories.Course.create_batch(5) organization_service.get_by_public_id.return_value = org - course_service.get_organization_courses.return_value = courses + course_service.get_courses.return_value = courses pyramid_request.matchdict["organization_public_id"] = sentinel.public_id db_session.flush() @@ -24,12 +85,12 @@ def test_get_organization_courses( organization_service.get_by_public_id.assert_called_once_with( sentinel.public_id ) - course_service.get_organization_courses.assert_called_once_with( + course_service.get_courses.assert_called_once_with( organization=org, h_userid=pyramid_request.user.h_userid, ) course_service.get_courses_assignments_count.assert_called_once_with( - course_service.get_organization_courses.return_value + course_service.get_courses.return_value ) assert response == { @@ -132,3 +193,26 @@ def test_course_assignments( @pytest.fixture def views(self, pyramid_request): return CourseViews(pyramid_request) + + +class TestListCoursesSchema: + def test_limit_default(self, pyramid_request): + assert ListCoursesSchema(pyramid_request).parse() == {"limit": 100} + + def test_invalid_cursor(self, pyramid_request): + pyramid_request.GET = {"cursor": "NOPE"} + + with pytest.raises(ValidationError): + ListCoursesSchema(pyramid_request).parse() + + def test_cursor(self, pyramid_request): + pyramid_request.GET = { + "cursor": base64.urlsafe_b64encode( + json.dumps(("VALUE", "OTHER_VALUE")).encode("utf-8") + ).decode("utf-8") + } + + assert ListCoursesSchema(pyramid_request).parse() == { + "limit": 100, + "cursor": ["VALUE", "OTHER_VALUE"], + }