diff --git a/lms/js_config_types.py b/lms/js_config_types.py index ed0ba94975..6d72fc685e 100644 --- a/lms/js_config_types.py +++ b/lms/js_config_types.py @@ -51,13 +51,15 @@ class APICourses(TypedDict): class APIAssignment(TypedDict): id: int title: str - course: APICourse + course: NotRequired[APICourse] annotation_metrics: NotRequired[AnnotationMetrics] class APIAssignments(TypedDict): assignments: list[APIAssignment] + pagination: NotRequired[Pagination] + class APIStudents(TypedDict): students: list[APIStudent] diff --git a/lms/routes.py b/lms/routes.py index f02a6271b1..729f188442 100644 --- a/lms/routes.py +++ b/lms/routes.py @@ -277,3 +277,4 @@ def includeme(config): # noqa: PLR0915 "api.dashboard.course.assignments.stats", "/api/dashboard/courses/{course_id}/assignments/stats", ) + config.add_route("api.dashboard.assignments", "/api/dashboard/assignments") diff --git a/lms/services/assignment.py b/lms/services/assignment.py index 296b170234..ff0456d169 100644 --- a/lms/services/assignment.py +++ b/lms/services/assignment.py @@ -1,5 +1,6 @@ import logging +from sqlalchemy import Select, select from sqlalchemy.orm import Session, joinedload from lms.models import ( @@ -222,6 +223,23 @@ def get_members( .where(LTIRole.scope == role_scope, LTIRole.type == role_type) ] + def get_assignments(self, h_userid: str | None = None) -> Select[tuple[Assignment]]: + """Get a query to fetch assignments. + + :params: h_userid only return assignments the users is a member of. + """ + + assignments_query = select(Assignment) + + if h_userid: + assignments_query = ( + assignments_query.join(AssignmentMembership) + .join(User) + .where(User.h_userid == h_userid) + ) + + return assignments_query.order_by(Assignment.title, Assignment.id) + def factory(_context, request): return AssignmentService(db=request.db, misc_plugin=request.product.plugin.misc) diff --git a/lms/services/course.py b/lms/services/course.py index 4dcc43d449..9bdfb28e28 100644 --- a/lms/services/course.py +++ b/lms/services/course.py @@ -1,8 +1,7 @@ import json from copy import deepcopy -from sqlalchemy import Text, column, func, select -from sqlalchemy.orm import Query +from sqlalchemy import Select, Text, column, func, select from lms.db import full_text_match from lms.models import ( @@ -145,7 +144,7 @@ def get_courses( self, h_userid: str | None, organization: Organization | None = None, - ) -> Query: + ) -> Select[tuple[Course]]: """Get a list of unique courses. :param organization: organization the courses belong to. @@ -164,9 +163,11 @@ def get_courses( ).with_entities(Course.id) return ( - self._db.query(Course) - .filter(Course.id.in_(courses_query)) - # We can sort these again without affecting deduplication + select(Course) + .where( + Course.id.in_(courses_query) + # We can sort these again without affecting deduplication + ) .order_by(Course.lms_name, Course.id) ) diff --git a/lms/views/dashboard/api/assignment.py b/lms/views/dashboard/api/assignment.py index 1f60195a4d..eeb661321d 100644 --- a/lms/views/dashboard/api/assignment.py +++ b/lms/views/dashboard/api/assignment.py @@ -5,24 +5,55 @@ from lms.js_config_types import ( AnnotationMetrics, APIAssignment, + APIAssignments, APICourse, APIStudent, APIStudents, ) -from lms.models import RoleScope, RoleType +from lms.models import Assignment, RoleScope, RoleType from lms.security import Permissions from lms.services.h_api import HAPI from lms.views.dashboard.base import get_request_assignment +from lms.views.dashboard.pagination import PaginationParametersMixin, get_page LOG = logging.getLogger(__name__) +class ListAssignmentsSchema(PaginationParametersMixin): + """Query parameters to fetch a list of assignments. + + Only the pagination related ones from the mixin. + """ + + class AssignmentViews: def __init__(self, request) -> None: self.request = request self.h_api = request.find_service(HAPI) self.assignment_service = request.find_service(name="assignment") + @view_config( + route_name="api.dashboard.assignments", + request_method="GET", + renderer="json", + permission=Permissions.DASHBOARD_VIEW, + schema=ListAssignmentsSchema, + ) + def assignments(self) -> APIAssignments: + assignments = self.assignment_service.get_assignments( + h_userid=self.request.user.h_userid if self.request.user else None, + ) + assignments, pagination = get_page( + self.request, assignments, [Assignment.title, Assignment.id] + ) + return { + "assignments": [ + APIAssignment(id=assignment.id, title=assignment.title) + for assignment in assignments + ], + "pagination": pagination, + } + @view_config( route_name="api.dashboard.assignment", request_method="GET", diff --git a/lms/views/dashboard/api/course.py b/lms/views/dashboard/api/course.py index 8c2cc81bb5..6d6c23d902 100644 --- a/lms/views/dashboard/api/course.py +++ b/lms/views/dashboard/api/course.py @@ -1,8 +1,4 @@ -import json - -from marshmallow import ValidationError, fields, post_load from pyramid.view import view_config -from sqlalchemy.orm import Query from lms.js_config_types import ( AnnotationMetrics, @@ -11,62 +7,23 @@ APICourse, APICourses, CourseMetrics, - Pagination, ) from lms.models import Course, RoleScope, RoleType from lms.security import Permissions from lms.services.h_api import HAPI from lms.services.organization import OrganizationService -from lms.validation._base import PyramidRequestSchema from lms.views.dashboard.base import get_request_course, get_request_organization +from lms.views.dashboard.pagination import PaginationParametersMixin, get_page MAX_ITEMS_PER_PAGE = 100 """Maximum number of items to return in paginated endpoints""" -def get_courses_page( - request, courses_query: Query[Course], limit: int = MAX_ITEMS_PER_PAGE -) -> tuple[list[Course], Pagination]: - """Return the first page and pagination metadata from a course's query.""" - # Over fetch one element to check if need to calculate the next cursor - courses = courses_query.limit(limit + 1).all() - if not courses or len(courses) <= limit: - return courses, Pagination(next=None) - - courses = courses[0:limit] - last_element = courses[-1] - cursor_data = json.dumps([last_element.lms_name, last_element.id]) - next_url_query = {"cursor": cursor_data} - # Include query parameters in the original request so clients can use the next param verbatim. - if limit := request.params.get("limit"): - next_url_query["limit"] = limit - - return courses, Pagination( - next=request.route_url("api.dashboard.courses", _query=next_url_query) - ) - - -class ListCoursesSchema(PyramidRequestSchema): - location = "query" - - limit = fields.Integer(required=False, load_default=MAX_ITEMS_PER_PAGE) - """Maximum number of items to return.""" - - cursor = fields.Str() - """Position to return elements from.""" +class ListCoursesSchema(PaginationParametersMixin): + """Query parameters to fetch a list of courses. - @post_load - def decode_cursor(self, in_data, **_kwargs): - cursor = in_data.get("cursor") - if not cursor: - return in_data - - try: - in_data["cursor"] = json.loads(cursor) - except ValueError as exc: - raise ValidationError("Invalid value for pagination cursor.") from exc - - return in_data + Only the pagination related ones from the mixin. + """ class CourseViews: @@ -87,14 +44,9 @@ def courses(self) -> APICourses: courses = self.course_service.get_courses( h_userid=self.request.user.h_userid if self.request.user else None, ) - - limit = min(MAX_ITEMS_PER_PAGE, self.request.parsed_params["limit"]) - if cursor_values := self.request.parsed_params.get("cursor"): - cursor_course_name, cursor_course_id = cursor_values - courses = courses.filter( - (Course.lms_name, Course.id) > (cursor_course_name, cursor_course_id) - ) - courses, pagination = get_courses_page(self.request, courses, limit) + courses, pagination = get_page( + self.request, courses, [Course.lms_name, Course.id] + ) return { "courses": [ APICourse(id=course.id, title=course.lms_name) for course in courses diff --git a/lms/views/dashboard/pagination.py b/lms/views/dashboard/pagination.py new file mode 100644 index 0000000000..045fb28d44 --- /dev/null +++ b/lms/views/dashboard/pagination.py @@ -0,0 +1,81 @@ +import json +from typing import TypeVar +from urllib.parse import parse_qs, urlencode, urlparse, urlunparse + +from marshmallow import ValidationError, fields, post_load +from sqlalchemy import Select + +from lms.js_config_types import Pagination +from lms.models import Assignment, Course +from lms.validation._base import PyramidRequestSchema + +T = TypeVar("T", Course, Assignment) +"""Types for which support pagination.""" + + +MAX_ITEMS_PER_PAGE = 100 +"""Maximum number of items to return in paginated endpoints""" + + +def _get_cursor_value(items: list[T], cursor_columns: list) -> str: + last_element = items[-1] + # Get the relevant values from the last element on the page + values = [getattr(last_element, column.key) for column in cursor_columns] + return json.dumps(values) + + +def _get_next_url(current_url, cursor_value) -> str: + parsed_url = urlparse(current_url) + query_params = parse_qs(parsed_url.query) + + # Update the query parameter + query_params["cursor"] = [cursor_value] + + # Encode the query parameters back to a query string + new_query_string = urlencode(query_params, doseq=True) + + # Construct the new URL with the updated query string + return urlunparse(parsed_url._replace(query=new_query_string)) + + +def get_page( + request, items_query: Select[tuple[T]], cursor_columns: list +) -> tuple[list[T], Pagination]: + """Return the first page and pagination metadata from a query.""" + if cursor_values := request.parsed_params.get("cursor"): + # If we have a cursor only fetch the elements that follow + items_query = items_query.where(tuple(cursor_columns) > tuple(cursor_values)) # type: ignore + + limit = min(MAX_ITEMS_PER_PAGE, request.parsed_params["limit"]) + # Over fetch one element to check if need to calculate the next cursor + items = request.db.scalars(items_query.limit(limit + 1)).all() + if not items or len(items) <= limit: + # No elements or no next page, no pagination.next + return items, Pagination(next=None) + items = items[0:limit] + + cursor_value = _get_cursor_value(items, cursor_columns) + return items, Pagination(next=_get_next_url(request.url, cursor_value)) + + +class PaginationParametersMixin(PyramidRequestSchema): + location = "query" + + limit = fields.Integer(required=False, load_default=MAX_ITEMS_PER_PAGE) + """Maximum number of items to return.""" + + cursor = fields.Str() + """Position to return elements from.""" + + @post_load + def decode_cursor(self, in_data, **_kwargs): + cursor = in_data.get("cursor") + if not cursor: + return in_data + + try: + in_data["cursor"] = json.loads(cursor) + except ValueError as exc: + raise ValidationError("Invalid value for pagination cursor.") from exc + + return in_data diff --git a/tests/unit/lms/services/assignment_test.py b/tests/unit/lms/services/assignment_test.py index 5c7e40d044..be093116b7 100644 --- a/tests/unit/lms/services/assignment_test.py +++ b/tests/unit/lms/services/assignment_test.py @@ -260,6 +260,23 @@ def test_get_members(self, svc, db_session): assignment, role_scope=lti_role.scope, role_type=lti_role.type ) == [user] + def test_get_assignments(self, svc, db_session): + assert db_session.scalars(svc.get_assignments()).all() + + def test_get_assignments_with_h_userid(self, svc, db_session): + factories.User() # User not in assignment + assignment = factories.Assignment() + user = factories.User() + lti_role = factories.LTIRole(scope=RoleScope.COURSE) + factories.AssignmentMembership.create( + assignment=assignment, user=user, lti_role=lti_role + ) + db_session.flush() + + assert ( + db_session.scalars(svc.get_assignments(user.h_userid)).one() == assignment + ) + @pytest.fixture def svc(self, db_session, misc_plugin): return AssignmentService(db_session, misc_plugin) diff --git a/tests/unit/lms/services/course_test.py b/tests/unit/lms/services/course_test.py index 729b133fd9..345fdecd9e 100644 --- a/tests/unit/lms/services/course_test.py +++ b/tests/unit/lms/services/course_test.py @@ -388,7 +388,9 @@ def test_get_courses_deduplicates(self, db_session, svc): assert set(svc.search(organization_ids=[org.id])) == {course, older_course} # But organization deduplicate, We only get the most recent course - assert svc.get_courses(organization=org, h_userid=None).all() == [course] + assert db_session.scalars( + svc.get_courses(organization=org, h_userid=None) + ).all() == [course] def test_get_assignments(self, db_session, svc): course = factories.Course() diff --git a/tests/unit/lms/views/dashboard/api/assignment_test.py b/tests/unit/lms/views/dashboard/api/assignment_test.py index 2d1b756416..850f94433d 100644 --- a/tests/unit/lms/views/dashboard/api/assignment_test.py +++ b/tests/unit/lms/views/dashboard/api/assignment_test.py @@ -2,6 +2,7 @@ import pytest +from lms.models import Assignment from lms.views.dashboard.api.assignment import AssignmentViews from tests import factories @@ -9,6 +10,27 @@ class TestAssignmentViews: + def test_get_assignments( + self, assignment_service, pyramid_request, views, get_page + ): + assignments = factories.Assignment.create_batch(5) + get_page.return_value = assignments, sentinel.pagination + + response = views.assignments() + + assignment_service.get_assignments.assert_called_once_with( + pyramid_request.user.h_userid + ) + get_page.assert_called_once_with( + pyramid_request, + assignment_service.get_assignments.return_value, + [Assignment.title, Assignment.id], + ) + assert response == { + "assignments": [{"id": a.id, "title": a.title} for a in assignments], + "pagination": sentinel.pagination, + } + def test_assignment( self, views, pyramid_request, assignment_service, course, assignment, db_session ): @@ -110,6 +132,10 @@ def views(self, pyramid_request): def course(self): return factories.Course() + @pytest.fixture + def get_page(self, patch): + return patch("lms.views.dashboard.api.assignment.get_page") + @pytest.fixture def assignment(self, course): assignment = factories.Assignment() diff --git a/tests/unit/lms/views/dashboard/api/course_test.py b/tests/unit/lms/views/dashboard/api/course_test.py index 5dd60c27be..46259813f3 100644 --- a/tests/unit/lms/views/dashboard/api/course_test.py +++ b/tests/unit/lms/views/dashboard/api/course_test.py @@ -1,105 +1,32 @@ -import json from unittest.mock import sentinel import pytest -from h_matchers import Any from lms.models import Course -from lms.validation import ValidationError -from lms.views.dashboard.api.course import CourseViews, ListCoursesSchema +from lms.views.dashboard.api.course import CourseViews from tests import factories pytestmark = pytest.mark.usefixtures("course_service", "h_api", "organization_service") class TestCourseViews: - def test_get_courses(self, course_service, pyramid_request, views, db_session): - pyramid_request.parsed_params = {"limit": 100} + def test_get_courses(self, course_service, pyramid_request, views, get_page): courses = factories.Course.create_batch(5) - course_service.get_courses.return_value = db_session.query(Course) - db_session.flush() - - response = views.courses() - - course_service.get_courses.assert_called_once_with( - h_userid=pyramid_request.user.h_userid, - ) - assert response == { - "courses": [{"id": c.id, "title": c.lms_name} for c in courses], - "pagination": {"next": None}, - } - - def test_get_courses_empty( - self, course_service, pyramid_request, views, db_session - ): - pyramid_request.parsed_params = {"limit": 100} - course_service.get_courses.return_value = db_session.query(Course) - - response = views.courses() - - course_service.get_courses.assert_called_once_with( - h_userid=pyramid_request.user.h_userid, - ) - assert response == { - "courses": [], - "pagination": {"next": None}, - } - - def test_get_courses_with_cursor( - self, course_service, pyramid_request, views, db_session - ): - courses = sorted(factories.Course.create_batch(10), key=lambda c: c.lms_name) - db_session.flush() - course_service.get_courses.return_value = db_session.query(Course).order_by( - Course.lms_name, Course.id - ) - - pyramid_request.params = {"limit": 1} - pyramid_request.parsed_params = { - "limit": 1, - "cursor": (courses[4].lms_name, courses[4].id), - } + get_page.return_value = courses, sentinel.pagination response = views.courses() course_service.get_courses.assert_called_once_with( - h_userid=pyramid_request.user.h_userid, + pyramid_request.user.h_userid ) - assert response == { - "courses": [{"id": c.id, "title": c.lms_name} for c in courses[5:6]], - "pagination": { - "next": Any.url.with_path("/api/dashboard/courses").with_query( - {"cursor": Any.string(), "limit": "1"} - ) - }, - } - - def test_get_courses_next_doesnt_include_limit_if_not_in_original_request( - self, course_service, pyramid_request, views, db_session - ): - courses = sorted(factories.Course.create_batch(10), key=lambda c: c.lms_name) - db_session.flush() - course_service.get_courses.return_value = db_session.query(Course).order_by( - Course.lms_name, Course.id - ) - - pyramid_request.parsed_params = { - "limit": 1, - "cursor": (courses[4].lms_name, courses[4].id), - } - - response = views.courses() - - course_service.get_courses.assert_called_once_with( - h_userid=pyramid_request.user.h_userid, + get_page.assert_called_once_with( + pyramid_request, + course_service.get_courses.return_value, + [Course.lms_name, Course.id], ) assert response == { - "courses": [{"id": c.id, "title": c.lms_name} for c in courses[5:6]], - "pagination": { - "next": Any.url.with_path("/api/dashboard/courses").with_query( - {"cursor": Any.string()} - ) - }, + "courses": [{"id": c.id, "title": c.lms_name} for c in courses], + "pagination": sentinel.pagination, } def test_get_organization_courses( @@ -226,21 +153,6 @@ def test_course_assignments( def views(self, pyramid_request): return CourseViews(pyramid_request) - -class TestListCoursesSchema: - def test_limit_default(self, pyramid_request): - assert ListCoursesSchema(pyramid_request).parse() == {"limit": 100} - - def test_invalid_cursor(self, pyramid_request): - pyramid_request.GET = {"cursor": "NOPE"} - - with pytest.raises(ValidationError): - ListCoursesSchema(pyramid_request).parse() - - def test_cursor(self, pyramid_request): - pyramid_request.GET = {"cursor": json.dumps(("VALUE", "OTHER_VALUE"))} - - assert ListCoursesSchema(pyramid_request).parse() == { - "limit": 100, - "cursor": ["VALUE", "OTHER_VALUE"], - } + @pytest.fixture + def get_page(self, patch): + return patch("lms.views.dashboard.api.course.get_page") diff --git a/tests/unit/lms/views/dashboard/pagination_test.py b/tests/unit/lms/views/dashboard/pagination_test.py new file mode 100644 index 0000000000..486430f479 --- /dev/null +++ b/tests/unit/lms/views/dashboard/pagination_test.py @@ -0,0 +1,78 @@ +import json + +import pytest +from h_matchers import Any +from sqlalchemy import select + +from lms.js_config_types import Pagination +from lms.models import Course +from lms.validation import ValidationError +from lms.views.dashboard.pagination import PaginationParametersMixin, get_page +from tests import factories + + +class TestGetPage: + def test_it_no_next_page(self, pyramid_request, db_session): + pyramid_request.parsed_params = {"limit": 100} + courses = factories.Course.create_batch(5) + query = select(Course) + db_session.flush() + + items, pagination = get_page(pyramid_request, query, (Course.id,)) + + assert items == courses + assert pagination == Pagination(next=None) + + def test_it_empty(self, pyramid_request): + pyramid_request.parsed_params = {"limit": 100} + query = select(Course).where(False) + + items, pagination = get_page(pyramid_request, query, (Course.id,)) + + assert items == [] + assert pagination == Pagination(next=None) + + def test_it_calculates_next(self, pyramid_request, db_session): + pyramid_request.parsed_params = {"limit": 1} + courses = factories.Course.create_batch(5) + query = select(Course) + db_session.flush() + + items, pagination = get_page(pyramid_request, query, (Course.id,)) + + assert items == courses[0:1] + assert pagination == Pagination( + next=Any.url.with_query({"cursor": json.dumps([courses[0].id])}) + ) + + def test_it_filters_by_cursor(self, pyramid_request, db_session): + courses = factories.Course.create_batch(5) + query = select(Course) + db_session.flush() + pyramid_request.parsed_params = { + "cursor": [courses[0].id, courses[0].lms_name], + "limit": 1, + } + + items, _ = get_page(pyramid_request, query, (Course.id, Course.lms_name)) + + assert items == courses[1:2] + + +class TestPaginationParametersMixin: + def test_limit_default(self, pyramid_request): + assert PaginationParametersMixin(pyramid_request).parse() == {"limit": 100} + + def test_invalid_cursor(self, pyramid_request): + pyramid_request.GET = {"cursor": "NOPE"} + + with pytest.raises(ValidationError): + PaginationParametersMixin(pyramid_request).parse() + + def test_cursor(self, pyramid_request): + pyramid_request.GET = {"cursor": json.dumps(("VALUE", "OTHER_VALUE"))} + + assert PaginationParametersMixin(pyramid_request).parse() == { + "limit": 100, + "cursor": ["VALUE", "OTHER_VALUE"], + }