Skip to content

Commit

Permalink
Enforce dashboard admin access for assignments and users
Browse files Browse the repository at this point in the history
  • Loading branch information
marcospri committed Jul 29, 2024
1 parent 8d336a0 commit 2ec7601
Show file tree
Hide file tree
Showing 14 changed files with 252 additions and 145 deletions.
44 changes: 33 additions & 11 deletions lms/services/assignment.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import logging
from typing import cast

from sqlalchemy import Select, func, select
from sqlalchemy import BinaryExpression, Select, false, func, or_, select
from sqlalchemy.orm import Session

from lms.models import (
ApplicationInstance,
Assignment,
AssignmentGrouping,
AssignmentMembership,
Grouping,
LTIRole,
Organization,
RoleScope,
RoleType,
User,
Expand Down Expand Up @@ -214,32 +217,51 @@ def is_member(self, assignment: Assignment, h_userid: str) -> bool:
def get_assignments(
self,
instructor_h_userid: str | None = None,
admin_organization_ids: list[int] | None = None,
course_ids: list[int] | None = None,
h_userids: list[str] | None = None,
) -> Select[tuple[Assignment]]:
"""Get a query to fetch assignments.
:param instructor_h_userid: return only assignments where instructor_h_userid is an instructor.
:param admin_organization_ids: organizations where the current user is an admin.
:param course_ids: only return assignments that belong to this course.
:param h_userids: return only assignments where these users are members.
"""

query = select(Assignment)

# Let's crate no op clauses by default to avoid having to check the presence of these filters
instructor_h_userid_clause = cast(BinaryExpression, false())
admin_organization_ids_clause = cast(BinaryExpression, false())

if instructor_h_userid:
query = query.where(
Assignment.id.in_(
select(AssignmentMembership.assignment_id)
.join(User)
.join(LTIRole)
.where(
User.h_userid == instructor_h_userid,
LTIRole.scope == RoleScope.COURSE,
LTIRole.type == RoleType.INSTRUCTOR,
)
instructor_h_userid_clause = Assignment.id.in_(
select(AssignmentMembership.assignment_id)
.join(User)
.join(LTIRole)
.where(
User.h_userid == instructor_h_userid,
LTIRole.scope == RoleScope.COURSE,
LTIRole.type == RoleType.INSTRUCTOR,
)
)

if admin_organization_ids:
admin_organization_ids_clause = Assignment.id.in_(
select(Assignment.id)
.join(AssignmentGrouping)
.join(Grouping)
.join(ApplicationInstance)
.join(Organization)
.where(Organization.id.in_(admin_organization_ids))
)
# instructor_h_userid and admin_organization_ids are about access rather than filtering.
# we apply them both as an or to fetch assignments where the users is either an instructor or an admin
query = query.where(
or_(instructor_h_userid_clause, admin_organization_ids_clause)
)

if h_userids:
query = (
query.join(AssignmentMembership)
Expand Down
27 changes: 22 additions & 5 deletions lms/services/dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,16 @@


class DashboardService:
def __init__(self, db, assignment_service, course_service, organization_service):
self._db = db
def __init__(
self, request, assignment_service, course_service, organization_service
):
self._db = request.db

self._assignment_service = assignment_service
self._course_service = course_service
self._organization_service = organization_service

def get_request_assignment(self, request):
def get_request_assignment(self, request, admin_organizations: list[Organization]):
"""Get and authorize an assignment for the given request."""
assigment_id = request.matchdict.get(
"assignment_id"
Expand All @@ -28,12 +30,20 @@ def get_request_assignment(self, request):
# STAFF members in our admin pages can access all assignments
return assignment

if (
admin_organizations
and assignment.course.application_instance.organization
in admin_organizations
):
# Organization admins have access to all the assignments in their organizations
return assignment

if not self._assignment_service.is_member(assignment, request.user.h_userid):
raise HTTPUnauthorized()

return assignment

def get_request_course(self, request):
def get_request_course(self, request, admin_organizations: list[Organization]):
"""Get and authorize a course for the given request."""
course = self._course_service.get_by_id(request.matchdict["course_id"])
if not course:
Expand All @@ -43,6 +53,13 @@ def get_request_course(self, request):
# STAFF members in our admin pages can access all courses
return course

if (
admin_organizations
and course.application_instance.organization in admin_organizations
):
# Organization admins have access to all the courses in their organizations
return course

if not self._course_service.is_member(course, request.user.h_userid):
raise HTTPUnauthorized()

Expand Down Expand Up @@ -74,7 +91,7 @@ def delete_dashboard_admin(self, dashboard_admin_id: int) -> None:

def factory(_context, request):
return DashboardService(
db=request.db,
request=request,
assignment_service=request.find_service(name="assignment"),
course_service=request.find_service(name="course"),
organization_service=request.find_service(OrganizationService),
Expand Down
43 changes: 32 additions & 11 deletions lms/services/user.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
from functools import lru_cache
from typing import cast

from sqlalchemy import select
from sqlalchemy import BinaryExpression, false, or_, select
from sqlalchemy.exc import NoResultFound
from sqlalchemy.sql import Select

from lms.models import (
ApplicationInstance,
AssignmentGrouping,
AssignmentMembership,
LTIRole,
LTIUser,
Organization,
RoleScope,
RoleType,
User,
Expand Down Expand Up @@ -107,6 +110,7 @@ def get_users( # noqa: PLR0913, PLR0917
role_scope: RoleScope,
role_type: RoleType,
instructor_h_userid: str | None = None,
admin_organization_ids: list[int] | None = None,
course_ids: list[int] | None = None,
h_userids: list[str] | None = None,
assignment_ids: list[int] | None = None,
Expand All @@ -117,6 +121,7 @@ def get_users( # noqa: PLR0913, PLR0917
:param role_scope: return only users with this LTI role scope.
:param role_type: return only users with this LTI role type.
:param instructor_h_userid: return only users that belongs to courses/assignments where the user instructor_h_userid is an instructor.
:param admin_organization_ids: organizations where the current user is an admin.
:param h_userids: return only users with a h_userid in this list.
:param course_ids: return only users that belong to these courses.
:param assignment_ids: return only users that belong these assignments.
Expand All @@ -128,20 +133,36 @@ def get_users( # noqa: PLR0913, PLR0917
.where(LTIRole.scope == role_scope, LTIRole.type == role_type)
)

# Let's crate no op clauses by default to avoid having to check the presence of these filters
instructor_h_userid_clause = cast(BinaryExpression, false())
admin_organization_ids_clause = cast(BinaryExpression, false())

if instructor_h_userid:
query = query.where(
AssignmentMembership.assignment_id.in_(
select(AssignmentMembership.assignment_id)
.join(User)
.join(LTIRole)
.where(
User.h_userid == instructor_h_userid,
LTIRole.scope == RoleScope.COURSE,
LTIRole.type == RoleType.INSTRUCTOR,
)
instructor_h_userid_clause = AssignmentMembership.assignment_id.in_(
select(AssignmentMembership.assignment_id)
.join(User)
.join(LTIRole)
.where(
User.h_userid == instructor_h_userid,
LTIRole.scope == RoleScope.COURSE,
LTIRole.type == RoleType.INSTRUCTOR,
)
)

if admin_organization_ids:
admin_organization_ids_clause = User.id.in_(
select(User.id)
.join(ApplicationInstance)
.join(Organization)
.where(Organization.id.in_(admin_organization_ids))
)

# instructor_h_userid and admin_organization_ids are about access rather than filtering.
# we apply them both as an or to fetch users where the users is either an instructor or an admin
query = query.where(
or_(instructor_h_userid_clause, admin_organization_ids_clause)
)

if h_userids:
query = query.where(User.h_userid.in_(h_userids))

Expand Down
17 changes: 15 additions & 2 deletions lms/views/dashboard/api/assignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@ def __init__(self, request) -> None:
self.course_service = request.find_service(name="course")
self.user_service: UserService = request.find_service(UserService)

self.admin_organizations = (
self.dashboard_service.get_organizations_by_admin_email(
request.lti_user.email if request.lti_user else request.identity.userid
)
)

@view_config(
route_name="api.dashboard.assignments",
request_method="GET",
Expand All @@ -58,6 +64,7 @@ def __init__(self, request) -> None:
def assignments(self) -> APIAssignments:
filter_by_h_userids = self.request.parsed_params.get("h_userids")
assignments = self.assignment_service.get_assignments(
admin_organization_ids=[org.id for org in self.admin_organizations],
instructor_h_userid=self.request.user.h_userid
if self.request.user
else None,
Expand All @@ -82,7 +89,9 @@ def assignments(self) -> APIAssignments:
permission=Permissions.DASHBOARD_VIEW,
)
def assignment(self) -> APIAssignment:
assignment = self.dashboard_service.get_request_assignment(self.request)
assignment = self.dashboard_service.get_request_assignment(
self.request, self.admin_organizations
)
return APIAssignment(
id=assignment.id,
title=assignment.title,
Expand All @@ -100,17 +109,21 @@ def course_assignments_metrics(self) -> APIAssignments:
current_h_userid = self.request.user.h_userid if self.request.user else None
filter_by_h_userids = self.request.parsed_params.get("h_userids")
filter_by_assignment_ids = self.request.parsed_params.get("assignment_ids")
course = self.dashboard_service.get_request_course(self.request)
course = self.dashboard_service.get_request_course(
self.request, self.admin_organizations
)
course_students = self.request.db.scalars(
self.user_service.get_users(
course_ids=[course.id],
role_scope=RoleScope.COURSE,
role_type=RoleType.LEARNER,
instructor_h_userid=current_h_userid,
admin_organization_ids=[org.id for org in self.admin_organizations],
h_userids=filter_by_h_userids,
)
).all()
assignments_query = self.assignment_service.get_assignments(
admin_organization_ids=[org.id for org in self.admin_organizations],
course_ids=[course.id],
instructor_h_userid=current_h_userid,
h_userids=filter_by_h_userids,
Expand Down
20 changes: 9 additions & 11 deletions lms/views/dashboard/api/course.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,10 @@ def __init__(self, request) -> None:
self.dashboard_service = request.find_service(name="dashboard")
self.assignment_service = request.find_service(name="assignment")

self.current_user_email = (
self.request.lti_user.email if request.lti_user else request.identity.userid
self.admin_organizations = (
self.dashboard_service.get_organizations_by_admin_email(
request.lti_user.email if request.lti_user else request.identity.userid
)
)

@view_config(
Expand All @@ -58,11 +60,8 @@ def __init__(self, request) -> None:
def courses(self) -> APICourses:
filter_by_h_userids = self.request.parsed_params.get("h_userids")
filter_by_assignment_ids = self.request.parsed_params.get("assignment_ids")
admin_organizations = self.dashboard_service.get_organizations_by_admin_email(
self.current_user_email
)
courses = self.course_service.get_courses(
admin_organization_ids=[org.id for org in admin_organizations],
admin_organization_ids=[org.id for org in self.admin_organizations],
instructor_h_userid=self.request.user.h_userid
if self.request.user
else None,
Expand Down Expand Up @@ -90,11 +89,8 @@ def courses_metrics(self) -> APICourses:
filter_by_h_userids = self.request.parsed_params.get("h_userids")
filter_by_assignment_ids = self.request.parsed_params.get("assignment_ids")
filter_by_course_ids = self.request.parsed_params.get("course_ids")
admin_organizations = self.dashboard_service.get_organizations_by_admin_email(
self.current_user_email
)
courses_query = self.course_service.get_courses(
admin_organization_ids=[org.id for org in admin_organizations],
admin_organization_ids=[org.id for org in self.admin_organizations],
instructor_h_userid=self.request.user.h_userid
if self.request.user
else None,
Expand Down Expand Up @@ -132,7 +128,9 @@ def courses_metrics(self) -> APICourses:
permission=Permissions.DASHBOARD_VIEW,
)
def course(self) -> APICourse:
course = self.dashboard_service.get_request_course(self.request)
course = self.dashboard_service.get_request_course(
self.request, self.admin_organizations
)
return {
"id": course.id,
"title": course.lms_name,
Expand Down
12 changes: 11 additions & 1 deletion lms/views/dashboard/api/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@ def __init__(self, request) -> None:
self.h_api = request.find_service(HAPI)
self.user_service: UserService = request.find_service(UserService)

self.admin_organizations = (
self.dashboard_service.get_organizations_by_admin_email(
request.lti_user.email if request.lti_user else request.identity.userid
)
)

@view_config(
route_name="api.dashboard.students",
request_method="GET",
Expand All @@ -62,6 +68,7 @@ def students(self) -> APIStudents:
instructor_h_userid=self.request.user.h_userid
if self.request.user
else None,
admin_organization_ids=[org.id for org in self.admin_organizations],
course_ids=self.request.parsed_params.get("course_ids"),
assignment_ids=self.request.parsed_params.get("assignment_ids"),
)
Expand Down Expand Up @@ -89,7 +96,9 @@ def students(self) -> APIStudents:
def students_metrics(self) -> APIStudents:
"""Fetch the stats for one particular assignment."""
request_h_userids = self.request.parsed_params.get("h_userids")
assignment = self.dashboard_service.get_request_assignment(self.request)
assignment = self.dashboard_service.get_request_assignment(
self.request, self.admin_organizations
)
stats = self.h_api.get_annotation_counts(
[g.authority_provided_id for g in assignment.groupings],
group_by="user",
Expand All @@ -108,6 +117,7 @@ def students_metrics(self) -> APIStudents:
instructor_h_userid=self.request.user.h_userid
if self.request.user
else None,
admin_organization_ids=[org.id for org in self.admin_organizations],
# Users the current user requested
h_userids=request_h_userids,
)
Expand Down
Loading

0 comments on commit 2ec7601

Please sign in to comment.