Skip to content

Commit

Permalink
Paginated assignments API endpoint
Browse files Browse the repository at this point in the history
The main goal is to apply the same pagination method we just introduced
for courses to a new entity and generalize the code around to avoid
duplication.

- Move dashboard pagination to its own module
- Handle the limit request parameter directly in get_page
- Replace query parameters in the original URL instead of constructing it from scratch
- Use sqlalchemy.Select instead of sqlalchemy.Query. The latter is
  deprecated and we should move all API to the new SQLAlchemy 2.0

And finally:

- Paginated API endpoint for assignments

This is a barebones endpoint, it doesn't yet accept any query
parameters for for example filter by course.
  • Loading branch information
marcospri committed Jun 27, 2024
1 parent 904a582 commit 5128d35
Show file tree
Hide file tree
Showing 12 changed files with 285 additions and 166 deletions.
4 changes: 3 additions & 1 deletion lms/js_config_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,15 @@ class APICourses(TypedDict):
class APIAssignment(TypedDict):
id: int
title: str
course: APICourse
course: NotRequired[APICourse]
annotation_metrics: NotRequired[AnnotationMetrics]


class APIAssignments(TypedDict):
assignments: list[APIAssignment]

pagination: NotRequired[Pagination]


class APIStudents(TypedDict):
students: list[APIStudent]
Expand Down
1 change: 1 addition & 0 deletions lms/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,3 +277,4 @@ def includeme(config): # noqa: PLR0915
"api.dashboard.course.assignments.stats",
"/api/dashboard/courses/{course_id}/assignments/stats",
)
config.add_route("api.dashboard.assignments", "/api/dashboard/assignments")
18 changes: 18 additions & 0 deletions lms/services/assignment.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging

from sqlalchemy import Select, select
from sqlalchemy.orm import Session, joinedload

from lms.models import (
Expand Down Expand Up @@ -222,6 +223,23 @@ def get_members(
.where(LTIRole.scope == role_scope, LTIRole.type == role_type)
]

def get_assignments(self, h_userid: str | None = None) -> Select[tuple[Assignment]]:
"""Get a query to fetch assignments.
:params: h_userid only return assignments the users is a member of.
"""

assignments_query = select(Assignment)

if h_userid:
assignments_query = (
assignments_query.join(AssignmentMembership)
.join(User)
.where(User.h_userid == h_userid)
)

return assignments_query.order_by(Assignment.title, Assignment.id)


def factory(_context, request):
return AssignmentService(db=request.db, misc_plugin=request.product.plugin.misc)
13 changes: 7 additions & 6 deletions lms/services/course.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import json
from copy import deepcopy

from sqlalchemy import Text, column, func, select
from sqlalchemy.orm import Query
from sqlalchemy import Select, Text, column, func, select

from lms.db import full_text_match
from lms.models import (
Expand Down Expand Up @@ -145,7 +144,7 @@ def get_courses(
self,
h_userid: str | None,
organization: Organization | None = None,
) -> Query:
) -> Select[tuple[Course]]:
"""Get a list of unique courses.
:param organization: organization the courses belong to.
Expand All @@ -164,9 +163,11 @@ def get_courses(
).with_entities(Course.id)

return (
self._db.query(Course)
.filter(Course.id.in_(courses_query))
# We can sort these again without affecting deduplication
select(Course)
.where(
Course.id.in_(courses_query)
# We can sort these again without affecting deduplication
)
.order_by(Course.lms_name, Course.id)
)

Expand Down
33 changes: 32 additions & 1 deletion lms/views/dashboard/api/assignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,55 @@
from lms.js_config_types import (
AnnotationMetrics,
APIAssignment,
APIAssignments,
APICourse,
APIStudent,
APIStudents,
)
from lms.models import RoleScope, RoleType
from lms.models import Assignment, RoleScope, RoleType
from lms.security import Permissions
from lms.services.h_api import HAPI
from lms.views.dashboard.base import get_request_assignment
from lms.views.dashboard.pagination import PaginationParametersMixin, get_page

LOG = logging.getLogger(__name__)


class ListAssignmentsSchema(PaginationParametersMixin):
"""Query parameters to fetch a list of assignments.
Only the pagination related ones from the mixin.
"""


class AssignmentViews:
def __init__(self, request) -> None:
self.request = request
self.h_api = request.find_service(HAPI)
self.assignment_service = request.find_service(name="assignment")

@view_config(
route_name="api.dashboard.assignments",
request_method="GET",
renderer="json",
permission=Permissions.DASHBOARD_VIEW,
schema=ListAssignmentsSchema,
)
def assignments(self) -> APIAssignments:
assignments = self.assignment_service.get_assignments(
h_userid=self.request.user.h_userid if self.request.user else None,
)
assignments, pagination = get_page(
self.request, assignments, [Assignment.title, Assignment.id]
)
return {
"assignments": [
APIAssignment(id=assignment.id, title=assignment.title)
for assignment in assignments
],
"pagination": pagination,
}

@view_config(
route_name="api.dashboard.assignment",
request_method="GET",
Expand Down
64 changes: 8 additions & 56 deletions lms/views/dashboard/api/course.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
import json

from marshmallow import ValidationError, fields, post_load
from pyramid.view import view_config
from sqlalchemy.orm import Query

from lms.js_config_types import (
AnnotationMetrics,
Expand All @@ -11,62 +7,23 @@
APICourse,
APICourses,
CourseMetrics,
Pagination,
)
from lms.models import Course, RoleScope, RoleType
from lms.security import Permissions
from lms.services.h_api import HAPI
from lms.services.organization import OrganizationService
from lms.validation._base import PyramidRequestSchema
from lms.views.dashboard.base import get_request_course, get_request_organization
from lms.views.dashboard.pagination import PaginationParametersMixin, get_page

MAX_ITEMS_PER_PAGE = 100
"""Maximum number of items to return in paginated endpoints"""


def get_courses_page(
request, courses_query: Query[Course], limit: int = MAX_ITEMS_PER_PAGE
) -> tuple[list[Course], Pagination]:
"""Return the first page and pagination metadata from a course's query."""
# Over fetch one element to check if need to calculate the next cursor
courses = courses_query.limit(limit + 1).all()
if not courses or len(courses) <= limit:
return courses, Pagination(next=None)

courses = courses[0:limit]
last_element = courses[-1]
cursor_data = json.dumps([last_element.lms_name, last_element.id])
next_url_query = {"cursor": cursor_data}
# Include query parameters in the original request so clients can use the next param verbatim.
if limit := request.params.get("limit"):
next_url_query["limit"] = limit

return courses, Pagination(
next=request.route_url("api.dashboard.courses", _query=next_url_query)
)


class ListCoursesSchema(PyramidRequestSchema):
location = "query"

limit = fields.Integer(required=False, load_default=MAX_ITEMS_PER_PAGE)
"""Maximum number of items to return."""

cursor = fields.Str()
"""Position to return elements from."""
class ListCoursesSchema(PaginationParametersMixin):
"""Query parameters to fetch a list of courses.
@post_load
def decode_cursor(self, in_data, **_kwargs):
cursor = in_data.get("cursor")
if not cursor:
return in_data

try:
in_data["cursor"] = json.loads(cursor)
except ValueError as exc:
raise ValidationError("Invalid value for pagination cursor.") from exc

return in_data
Only the pagination related ones from the mixin.
"""


class CourseViews:
Expand All @@ -87,14 +44,9 @@ def courses(self) -> APICourses:
courses = self.course_service.get_courses(
h_userid=self.request.user.h_userid if self.request.user else None,
)

limit = min(MAX_ITEMS_PER_PAGE, self.request.parsed_params["limit"])
if cursor_values := self.request.parsed_params.get("cursor"):
cursor_course_name, cursor_course_id = cursor_values
courses = courses.filter(
(Course.lms_name, Course.id) > (cursor_course_name, cursor_course_id)
)
courses, pagination = get_courses_page(self.request, courses, limit)
courses, pagination = get_page(
self.request, courses, [Course.lms_name, Course.id]
)
return {
"courses": [
APICourse(id=course.id, title=course.lms_name) for course in courses
Expand Down
79 changes: 79 additions & 0 deletions lms/views/dashboard/pagination.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import json
from typing import TypeVar
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse

from marshmallow import ValidationError, fields, post_load
from sqlalchemy import Select

from lms.js_config_types import Pagination
from lms.models import Assignment, Course
from lms.validation._base import PyramidRequestSchema

T = TypeVar("T", Course, Assignment)
"""Types for which support pagination."""


MAX_ITEMS_PER_PAGE = 100
"""Maximum number of items to return in paginated endpoints"""


def _get_cursor_value(items: list[T], cursor_columns: list) -> str:
last_element = items[-1]
# Get the relevant values from the last element on the page
values = [getattr(last_element, column.key) for column in cursor_columns]
return json.dumps(values)


def _get_next_url(current_url, cursor_value) -> str:
"""Add or replace the cursor value on top of any other query parameters present in current_url."""
parsed_url = urlparse(current_url)
query_params = parse_qs(parsed_url.query)

query_params["cursor"] = [cursor_value]

new_query_string = urlencode(query_params, doseq=True)

return urlunparse(parsed_url._replace(query=new_query_string))


def get_page(
request, items_query: Select[tuple[T]], cursor_columns: list
) -> tuple[list[T], Pagination]:
"""Return the first page and pagination metadata from a query."""
if cursor_values := request.parsed_params.get("cursor"):
# If we have a cursor only fetch the elements that follow
items_query = items_query.where(tuple(cursor_columns) > tuple(cursor_values)) # type: ignore

limit = min(MAX_ITEMS_PER_PAGE, request.parsed_params["limit"])
# Over fetch one element to check if need to calculate the next cursor
items = request.db.scalars(items_query.limit(limit + 1)).all()
if not items or len(items) <= limit:
# No elements or no next page, no pagination.next
return items, Pagination(next=None)
items = items[0:limit]

cursor_value = _get_cursor_value(items, cursor_columns)
return items, Pagination(next=_get_next_url(request.url, cursor_value))


class PaginationParametersMixin(PyramidRequestSchema):
location = "query"

limit = fields.Integer(required=False, load_default=MAX_ITEMS_PER_PAGE)
"""Maximum number of items to return."""

cursor = fields.Str()
"""Position to return elements from."""

@post_load
def decode_cursor(self, in_data, **_kwargs):
cursor = in_data.get("cursor")
if not cursor:
return in_data

try:
in_data["cursor"] = json.loads(cursor)
except ValueError as exc:
raise ValidationError("Invalid value for pagination cursor.") from exc

return in_data
17 changes: 17 additions & 0 deletions tests/unit/lms/services/assignment_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,23 @@ def test_get_members(self, svc, db_session):
assignment, role_scope=lti_role.scope, role_type=lti_role.type
) == [user]

def test_get_assignments(self, svc, db_session):
assert db_session.scalars(svc.get_assignments()).all()

def test_get_assignments_with_h_userid(self, svc, db_session):
factories.User() # User not in assignment
assignment = factories.Assignment()
user = factories.User()
lti_role = factories.LTIRole(scope=RoleScope.COURSE)
factories.AssignmentMembership.create(
assignment=assignment, user=user, lti_role=lti_role
)
db_session.flush()

assert (
db_session.scalars(svc.get_assignments(user.h_userid)).one() == assignment
)

@pytest.fixture
def svc(self, db_session, misc_plugin):
return AssignmentService(db_session, misc_plugin)
Expand Down
4 changes: 3 additions & 1 deletion tests/unit/lms/services/course_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,9 @@ def test_get_courses_deduplicates(self, db_session, svc):
assert set(svc.search(organization_ids=[org.id])) == {course, older_course}

# But organization deduplicate, We only get the most recent course
assert svc.get_courses(organization=org, h_userid=None).all() == [course]
assert db_session.scalars(
svc.get_courses(organization=org, h_userid=None)
).all() == [course]

def test_get_assignments(self, db_session, svc):
course = factories.Course()
Expand Down
Loading

0 comments on commit 5128d35

Please sign in to comment.