diff --git a/lms/js_config_types.py b/lms/js_config_types.py index 2ec04b3658..5ec82812d0 100644 --- a/lms/js_config_types.py +++ b/lms/js_config_types.py @@ -7,6 +7,10 @@ from typing import NotRequired, TypedDict +class Pagination(TypedDict): + next: str | None + + class AnnotationMetrics(TypedDict): annotations: int replies: int @@ -40,6 +44,8 @@ class APIStudent(TypedDict): class APICourses(TypedDict): courses: list[APICourse] + pagination: NotRequired[Pagination] + class APIAssignment(TypedDict): id: int diff --git a/lms/routes.py b/lms/routes.py index 9af59d407e..f02a6271b1 100644 --- a/lms/routes.py +++ b/lms/routes.py @@ -258,6 +258,8 @@ def includeme(config): # noqa: PLR0915 factory="lms.resources.dashboard.DashboardResource", ) + config.add_route("api.dashboard.courses", "/api/dashboard/courses") + config.add_route( "api.dashboard.organizations.courses", "/api/dashboard/organizations/{organization_public_id}/courses", diff --git a/lms/services/course.py b/lms/services/course.py index 026a07fd10..4224f92b9b 100644 --- a/lms/services/course.py +++ b/lms/services/course.py @@ -140,19 +140,33 @@ def search( # noqa: PLR0913, PLR0917 h_userid=h_userid, ).all() - def get_organization_courses( - self, organization: Organization, h_userid: str | None + def get_courses( + self, + h_userid: str | None, + organization: Organization | None = None, ): - courses_query = self._search_query( - organization_ids=[organization.id], - h_userid=h_userid, - limit=None, - ) - return ( + """Get a list of unique courses. + + :param organization: organization the courses belong to. + :param h_userid: only courses this user has access to. + """ + courses_query = ( + self._search_query( + organization_ids=[organization.id] if organization else None, + h_userid=h_userid, + limit=None, + ) # Deduplicate courses by authority_provided_id, take the last updated one - courses_query.distinct(Course.authority_provided_id) + .distinct(Course.authority_provided_id) .order_by(Course.authority_provided_id, Course.updated.desc()) - .all() + # Only select the ID of the deduplicated courses + ).with_entities(Course.id) + + return ( + self._db.query(Course) + .filter(Course.id.in_(courses_query)) + # We can sort these again without affecting deduplication + .order_by(Course.lms_name, Course.id) ) def _deduplicated_course_assigments_query(self, courses: list[Course]): diff --git a/lms/views/dashboard/api/course.py b/lms/views/dashboard/api/course.py index 6c6dc31bcc..7cbebc4375 100644 --- a/lms/views/dashboard/api/course.py +++ b/lms/views/dashboard/api/course.py @@ -1,3 +1,7 @@ +import base64 +import json + +from marshmallow import ValidationError, fields, post_load from pyramid.view import view_config from lms.js_config_types import ( @@ -7,13 +11,71 @@ APICourse, APICourses, CourseMetrics, + Pagination, ) -from lms.models import RoleScope, RoleType +from lms.models import Course, RoleScope, RoleType from lms.security import Permissions from lms.services.h_api import HAPI from lms.services.organization import OrganizationService +from lms.validation._base import PyramidRequestSchema from lms.views.dashboard.base import get_request_course, get_request_organization +PAGINATION_LIMIT = 100 +"""Maximum number of items to return in paginated endpoints""" + + +def build_courses_pagination( + request, courses, limit: int = PAGINATION_LIMIT +) -> tuple[list[Course], Pagination]: + """Build the pagination details for a list of courses. + + For the next parameter, we'll sort courses by both name and id, expose both as a tuple. + """ + # Over fetch one element to check if need to calculate the next cursor + courses = courses.limit(limit + 1).all() + if not courses or len(courses) <= limit: + return courses, Pagination(next=None) + + courses = courses[0:limit] + last_element = courses[-1] + cursor_data = (last_element.lms_name, last_element.id) + coursor_value = ( + base64.urlsafe_b64encode(json.dumps(cursor_data).encode("utf-8")) + .decode("utf-8") + .replace("=", "") + ) + next_url_query = {"cursor": coursor_value} + # Include query parameters in the original request so clients can use the next param verbatim. + if limit := request.params.get("limit"): + next_url_query["limit"] = limit + + return courses, Pagination( + next=request.route_url("api.dashboard.courses", _query=next_url_query) + ) + + +class ListCoursesSchema(PyramidRequestSchema): + location = "query" + + limit = fields.Integer(required=False, load_default=PAGINATION_LIMIT) + """Maximum number of items to return.""" + + cursor = fields.Str() + """Position to return the elements from. This correponds to the `next` value in the repsonse.""" + + @post_load + def decode_cursor(self, in_data, **_kwargs): + cursor = in_data.get("cursor") + if not cursor: + return in_data + + try: + in_data["cursor"] = json.loads(base64.urlsafe_b64decode(cursor + "===")) + except ValueError as exc: + raise ValidationError("Invalid value for pagination cursor.") from exc + + return in_data + class CourseViews: def __init__(self, request) -> None: @@ -22,6 +84,32 @@ def __init__(self, request) -> None: self.h_api = request.find_service(HAPI) self.organization_service = request.find_service(OrganizationService) + @view_config( + route_name="api.dashboard.courses", + request_method="GET", + renderer="json", + permission=Permissions.DASHBOARD_VIEW, + schema=ListCoursesSchema, + ) + def courses(self) -> APICourses: + courses = self.course_service.get_courses( + h_userid=self.request.user.h_userid if self.request.user else None, + ) + + limit = min(PAGINATION_LIMIT, self.request.parsed_params["limit"]) + if cursor_values := self.request.parsed_params.get("cursor"): + cursor_course_name, cursor_course_id = cursor_values + courses = courses.filter( + (Course.lms_name, Course.id) > (cursor_course_name, cursor_course_id) + ) + courses, pagination = build_courses_pagination(self.request, courses, limit) + return { + "courses": [ + APICourse(id=course.id, title=course.lms_name) for course in courses + ], + "pagination": pagination, + } + @view_config( route_name="api.dashboard.organizations.courses", request_method="GET", @@ -30,8 +118,9 @@ def __init__(self, request) -> None: ) def organization_courses(self) -> APICourses: org = get_request_organization(self.request, self.organization_service) - courses = self.course_service.get_organization_courses( - org, h_userid=self.request.user.h_userid if self.request.user else None + courses = self.course_service.get_courses( + organization=org, + h_userid=self.request.user.h_userid if self.request.user else None, ) courses_assignment_counts = self.course_service.get_courses_assignments_count( courses diff --git a/tests/unit/lms/services/course_test.py b/tests/unit/lms/services/course_test.py index 1f5150499a..729b133fd9 100644 --- a/tests/unit/lms/services/course_test.py +++ b/tests/unit/lms/services/course_test.py @@ -366,7 +366,7 @@ def test_get_members(self, svc, db_session): course, role_scope=lti_role.scope, role_type=lti_role.type ) == [user] - def test_get_organization_courses_deduplicates(self, db_session, svc): + def test_get_courses_deduplicates(self, db_session, svc): org = factories.Organization() ai = factories.ApplicationInstance(organization=org) @@ -388,7 +388,7 @@ def test_get_organization_courses_deduplicates(self, db_session, svc): assert set(svc.search(organization_ids=[org.id])) == {course, older_course} # But organization deduplicate, We only get the most recent course - assert svc.get_organization_courses(org, None) == [course] + assert svc.get_courses(organization=org, h_userid=None).all() == [course] def test_get_assignments(self, db_session, svc): course = factories.Course() diff --git a/tests/unit/lms/views/dashboard/api/course_test.py b/tests/unit/lms/views/dashboard/api/course_test.py index 18a27bf235..0e3addd82f 100644 --- a/tests/unit/lms/views/dashboard/api/course_test.py +++ b/tests/unit/lms/views/dashboard/api/course_test.py @@ -1,21 +1,119 @@ +import base64 +import json from unittest.mock import sentinel import pytest - -from lms.views.dashboard.api.course import CourseViews +from h_matchers import Any + +from lms.models import Course +from lms.validation import ValidationError +from lms.views.dashboard.api.course import ( + CourseViews, + ListCoursesSchema, + build_courses_pagination, +) from tests import factories pytestmark = pytest.mark.usefixtures("course_service", "h_api", "organization_service") class TestCourseViews: + def test_get_courses(self, course_service, pyramid_request, views, db_session): + pyramid_request.parsed_params = {"limit": 100} + courses = factories.Course.create_batch(5) + course_service.get_courses.return_value = db_session.query(Course) + db_session.flush() + + response = views.courses() + + course_service.get_courses.assert_called_once_with( + h_userid=pyramid_request.user.h_userid, + ) + assert response == { + "courses": [{"id": c.id, "title": c.lms_name} for c in courses], + "pagination": {"next": None}, + } + + def test_get_courses_empty( + self, course_service, pyramid_request, views, db_session + ): + pyramid_request.parsed_params = {"limit": 100} + course_service.get_courses.return_value = db_session.query(Course) + + response = views.courses() + + course_service.get_courses.assert_called_once_with( + h_userid=pyramid_request.user.h_userid, + ) + assert response == { + "courses": [], + "pagination": {"next": None}, + } + + def test_get_courses_with_cursor( + self, course_service, pyramid_request, views, db_session + ): + courses = sorted(factories.Course.create_batch(10), key=lambda c: c.lms_name) + db_session.flush() + course_service.get_courses.return_value = db_session.query(Course).order_by( + Course.lms_name, Course.id + ) + + pyramid_request.params = {"limit": 1} + pyramid_request.parsed_params = { + "limit": 1, + "cursor": (courses[4].lms_name, courses[4].id), + } + + response = views.courses() + + course_service.get_courses.assert_called_once_with( + h_userid=pyramid_request.user.h_userid, + ) + assert response == { + "courses": [{"id": c.id, "title": c.lms_name} for c in courses[5:6]], + "pagination": { + "next": Any.url.with_path("/api/dashboard/courses").with_query( + {"cursor": Any.string(), "limit": "1"} + ) + }, + } + + def test_get_courses_next_doesnt_include_limit_if_not_in_original_request( + self, course_service, pyramid_request, views, db_session + ): + courses = sorted(factories.Course.create_batch(10), key=lambda c: c.lms_name) + db_session.flush() + course_service.get_courses.return_value = db_session.query(Course).order_by( + Course.lms_name, Course.id + ) + + pyramid_request.parsed_params = { + "limit": 1, + "cursor": (courses[4].lms_name, courses[4].id), + } + + response = views.courses() + + course_service.get_courses.assert_called_once_with( + h_userid=pyramid_request.user.h_userid, + ) + assert response == { + "courses": [{"id": c.id, "title": c.lms_name} for c in courses[5:6]], + "pagination": { + "next": Any.url.with_path("/api/dashboard/courses").with_query( + {"cursor": Any.string()} + ) + }, + } + def test_get_organization_courses( self, course_service, organization_service, pyramid_request, views, db_session ): org = factories.Organization() courses = factories.Course.create_batch(5) organization_service.get_by_public_id.return_value = org - course_service.get_organization_courses.return_value = courses + course_service.get_courses.return_value = courses pyramid_request.matchdict["organization_public_id"] = sentinel.public_id db_session.flush() @@ -24,12 +122,12 @@ def test_get_organization_courses( organization_service.get_by_public_id.assert_called_once_with( sentinel.public_id ) - course_service.get_organization_courses.assert_called_once_with( + course_service.get_courses.assert_called_once_with( organization=org, h_userid=pyramid_request.user.h_userid, ) course_service.get_courses_assignments_count.assert_called_once_with( - course_service.get_organization_courses.return_value + course_service.get_courses.return_value ) assert response == { @@ -132,3 +230,26 @@ def test_course_assignments( @pytest.fixture def views(self, pyramid_request): return CourseViews(pyramid_request) + + +class TestListCoursesSchema: + def test_limit_default(self, pyramid_request): + assert ListCoursesSchema(pyramid_request).parse() == {"limit": 100} + + def test_invalid_cursor(self, pyramid_request): + pyramid_request.GET = {"cursor": "NOPE"} + + with pytest.raises(ValidationError): + ListCoursesSchema(pyramid_request).parse() + + def test_cursor(self, pyramid_request): + pyramid_request.GET = { + "cursor": base64.urlsafe_b64encode( + json.dumps(("VALUE", "OTHER_VALUE")).encode("utf-8") + ).decode("utf-8") + } + + assert ListCoursesSchema(pyramid_request).parse() == { + "limit": 100, + "cursor": ["VALUE", "OTHER_VALUE"], + }