From d8e6125dd311fa73f660cc8897634144b5d8c1c9 Mon Sep 17 00:00:00 2001 From: Marcos Prieto Date: Fri, 26 Jul 2024 16:45:13 +0200 Subject: [PATCH] Enforce dashboard admin access for assignments and users --- lms/services/assignment.py | 44 ++++++++--- lms/services/dashboard.py | 27 +++++-- lms/services/user.py | 43 ++++++++--- lms/views/dashboard/api/assignment.py | 17 +++- lms/views/dashboard/api/course.py | 20 +++-- lms/views/dashboard/api/user.py | 12 ++- lms/views/dashboard/views.py | 14 +++- tests/unit/lms/services/assignment_test.py | 32 ++++---- tests/unit/lms/services/dashboard_test.py | 54 ++++++++++--- tests/unit/lms/services/user_test.py | 77 ++++++++----------- .../views/dashboard/api/assignment_test.py | 5 +- .../lms/views/dashboard/api/course_test.py | 37 +++------ .../unit/lms/views/dashboard/api/user_test.py | 4 +- tests/unit/lms/views/dashboard/views_test.py | 11 ++- 14 files changed, 252 insertions(+), 145 deletions(-) diff --git a/lms/services/assignment.py b/lms/services/assignment.py index b0f40b3a5b..3ba0986f11 100644 --- a/lms/services/assignment.py +++ b/lms/services/assignment.py @@ -1,14 +1,17 @@ import logging +from typing import cast -from sqlalchemy import Select, func, select +from sqlalchemy import BinaryExpression, Select, false, func, or_, select from sqlalchemy.orm import Session from lms.models import ( + ApplicationInstance, Assignment, AssignmentGrouping, AssignmentMembership, Grouping, LTIRole, + Organization, RoleScope, RoleType, User, @@ -214,32 +217,51 @@ def is_member(self, assignment: Assignment, h_userid: str) -> bool: def get_assignments( self, instructor_h_userid: str | None = None, + admin_organization_ids: list[int] | None = None, course_ids: list[int] | None = None, h_userids: list[str] | None = None, ) -> Select[tuple[Assignment]]: """Get a query to fetch assignments. :param instructor_h_userid: return only assignments where instructor_h_userid is an instructor. + :param admin_organization_ids: organizations where the current user is an admin. :param course_ids: only return assignments that belong to this course. :param h_userids: return only assignments where these users are members. """ query = select(Assignment) + # Let's crate no op clauses by default to avoid having to check the presence of these filters + instructor_h_userid_clause = cast(BinaryExpression, false()) + admin_organization_ids_clause = cast(BinaryExpression, false()) + if instructor_h_userid: - query = query.where( - Assignment.id.in_( - select(AssignmentMembership.assignment_id) - .join(User) - .join(LTIRole) - .where( - User.h_userid == instructor_h_userid, - LTIRole.scope == RoleScope.COURSE, - LTIRole.type == RoleType.INSTRUCTOR, - ) + instructor_h_userid_clause = Assignment.id.in_( + select(AssignmentMembership.assignment_id) + .join(User) + .join(LTIRole) + .where( + User.h_userid == instructor_h_userid, + LTIRole.scope == RoleScope.COURSE, + LTIRole.type == RoleType.INSTRUCTOR, ) ) + if admin_organization_ids: + admin_organization_ids_clause = Assignment.id.in_( + select(Assignment.id) + .join(AssignmentGrouping) + .join(Grouping) + .join(ApplicationInstance) + .join(Organization) + .where(Organization.id.in_(admin_organization_ids)) + ) + # instructor_h_userid and admin_organization_ids are about access rather than filtering. + # we apply them both as an or to fetch assignments where the users is either an instructor or an admin + query = query.where( + or_(instructor_h_userid_clause, admin_organization_ids_clause) + ) + if h_userids: query = ( query.join(AssignmentMembership) diff --git a/lms/services/dashboard.py b/lms/services/dashboard.py index 1bde515f52..7df980007d 100644 --- a/lms/services/dashboard.py +++ b/lms/services/dashboard.py @@ -8,14 +8,16 @@ class DashboardService: - def __init__(self, db, assignment_service, course_service, organization_service): - self._db = db + def __init__( + self, request, assignment_service, course_service, organization_service + ): + self._db = request.db self._assignment_service = assignment_service self._course_service = course_service self._organization_service = organization_service - def get_request_assignment(self, request): + def get_request_assignment(self, request, admin_organizations: list[Organization]): """Get and authorize an assignment for the given request.""" assigment_id = request.matchdict.get( "assignment_id" @@ -28,12 +30,20 @@ def get_request_assignment(self, request): # STAFF members in our admin pages can access all assignments return assignment + if ( + admin_organizations + and assignment.course.application_instance.organization + in admin_organizations + ): + # Organization admins have access to all the assignments in their organizations + return assignment + if not self._assignment_service.is_member(assignment, request.user.h_userid): raise HTTPUnauthorized() return assignment - def get_request_course(self, request): + def get_request_course(self, request, admin_organizations: list[Organization]): """Get and authorize a course for the given request.""" course = self._course_service.get_by_id(request.matchdict["course_id"]) if not course: @@ -43,6 +53,13 @@ def get_request_course(self, request): # STAFF members in our admin pages can access all courses return course + if ( + admin_organizations + and course.application_instance.organization in admin_organizations + ): + # Organization admins have access to all the courses in their organizations + return course + if not self._course_service.is_member(course, request.user.h_userid): raise HTTPUnauthorized() @@ -74,7 +91,7 @@ def delete_dashboard_admin(self, dashboard_admin_id: int) -> None: def factory(_context, request): return DashboardService( - db=request.db, + request=request, assignment_service=request.find_service(name="assignment"), course_service=request.find_service(name="course"), organization_service=request.find_service(OrganizationService), diff --git a/lms/services/user.py b/lms/services/user.py index fb66bc2ba2..46239cbcee 100644 --- a/lms/services/user.py +++ b/lms/services/user.py @@ -1,14 +1,17 @@ from functools import lru_cache +from typing import cast -from sqlalchemy import select +from sqlalchemy import BinaryExpression, false, or_, select from sqlalchemy.exc import NoResultFound from sqlalchemy.sql import Select from lms.models import ( + ApplicationInstance, AssignmentGrouping, AssignmentMembership, LTIRole, LTIUser, + Organization, RoleScope, RoleType, User, @@ -107,6 +110,7 @@ def get_users( # noqa: PLR0913, PLR0917 role_scope: RoleScope, role_type: RoleType, instructor_h_userid: str | None = None, + admin_organization_ids: list[int] | None = None, course_ids: list[int] | None = None, h_userids: list[str] | None = None, assignment_ids: list[int] | None = None, @@ -117,6 +121,7 @@ def get_users( # noqa: PLR0913, PLR0917 :param role_scope: return only users with this LTI role scope. :param role_type: return only users with this LTI role type. :param instructor_h_userid: return only users that belongs to courses/assignments where the user instructor_h_userid is an instructor. + :param admin_organization_ids: organizations where the current user is an admin. :param h_userids: return only users with a h_userid in this list. :param course_ids: return only users that belong to these courses. :param assignment_ids: return only users that belong these assignments. @@ -128,20 +133,36 @@ def get_users( # noqa: PLR0913, PLR0917 .where(LTIRole.scope == role_scope, LTIRole.type == role_type) ) + # Let's crate no op clauses by default to avoid having to check the presence of these filters + instructor_h_userid_clause = cast(BinaryExpression, false()) + admin_organization_ids_clause = cast(BinaryExpression, false()) + if instructor_h_userid: - query = query.where( - AssignmentMembership.assignment_id.in_( - select(AssignmentMembership.assignment_id) - .join(User) - .join(LTIRole) - .where( - User.h_userid == instructor_h_userid, - LTIRole.scope == RoleScope.COURSE, - LTIRole.type == RoleType.INSTRUCTOR, - ) + instructor_h_userid_clause = AssignmentMembership.assignment_id.in_( + select(AssignmentMembership.assignment_id) + .join(User) + .join(LTIRole) + .where( + User.h_userid == instructor_h_userid, + LTIRole.scope == RoleScope.COURSE, + LTIRole.type == RoleType.INSTRUCTOR, ) ) + if admin_organization_ids: + admin_organization_ids_clause = User.id.in_( + select(User.id) + .join(ApplicationInstance) + .join(Organization) + .where(Organization.id.in_(admin_organization_ids)) + ) + + # instructor_h_userid and admin_organization_ids are about access rather than filtering. + # we apply them both as an or to fetch users where the users is either an instructor or an admin + query = query.where( + or_(instructor_h_userid_clause, admin_organization_ids_clause) + ) + if h_userids: query = query.where(User.h_userid.in_(h_userids)) diff --git a/lms/views/dashboard/api/assignment.py b/lms/views/dashboard/api/assignment.py index 7e3f12f7d5..28a44b1075 100644 --- a/lms/views/dashboard/api/assignment.py +++ b/lms/views/dashboard/api/assignment.py @@ -48,6 +48,12 @@ def __init__(self, request) -> None: self.course_service = request.find_service(name="course") self.user_service: UserService = request.find_service(UserService) + self.admin_organizations = ( + self.dashboard_service.get_organizations_by_admin_email( + request.lti_user.email if request.lti_user else request.identity.userid + ) + ) + @view_config( route_name="api.dashboard.assignments", request_method="GET", @@ -58,6 +64,7 @@ def __init__(self, request) -> None: def assignments(self) -> APIAssignments: filter_by_h_userids = self.request.parsed_params.get("h_userids") assignments = self.assignment_service.get_assignments( + admin_organization_ids=[org.id for org in self.admin_organizations], instructor_h_userid=self.request.user.h_userid if self.request.user else None, @@ -82,7 +89,9 @@ def assignments(self) -> APIAssignments: permission=Permissions.DASHBOARD_VIEW, ) def assignment(self) -> APIAssignment: - assignment = self.dashboard_service.get_request_assignment(self.request) + assignment = self.dashboard_service.get_request_assignment( + self.request, self.admin_organizations + ) return APIAssignment( id=assignment.id, title=assignment.title, @@ -100,17 +109,21 @@ def course_assignments_metrics(self) -> APIAssignments: current_h_userid = self.request.user.h_userid if self.request.user else None filter_by_h_userids = self.request.parsed_params.get("h_userids") filter_by_assignment_ids = self.request.parsed_params.get("assignment_ids") - course = self.dashboard_service.get_request_course(self.request) + course = self.dashboard_service.get_request_course( + self.request, self.admin_organizations + ) course_students = self.request.db.scalars( self.user_service.get_users( course_ids=[course.id], role_scope=RoleScope.COURSE, role_type=RoleType.LEARNER, instructor_h_userid=current_h_userid, + admin_organization_ids=[org.id for org in self.admin_organizations], h_userids=filter_by_h_userids, ) ).all() assignments_query = self.assignment_service.get_assignments( + admin_organization_ids=[org.id for org in self.admin_organizations], course_ids=[course.id], instructor_h_userid=current_h_userid, h_userids=filter_by_h_userids, diff --git a/lms/views/dashboard/api/course.py b/lms/views/dashboard/api/course.py index 602302ca0e..e78eb7abad 100644 --- a/lms/views/dashboard/api/course.py +++ b/lms/views/dashboard/api/course.py @@ -44,8 +44,10 @@ def __init__(self, request) -> None: self.dashboard_service = request.find_service(name="dashboard") self.assignment_service = request.find_service(name="assignment") - self.current_user_email = ( - self.request.lti_user.email if request.lti_user else request.identity.userid + self.admin_organizations = ( + self.dashboard_service.get_organizations_by_admin_email( + request.lti_user.email if request.lti_user else request.identity.userid + ) ) @view_config( @@ -58,11 +60,8 @@ def __init__(self, request) -> None: def courses(self) -> APICourses: filter_by_h_userids = self.request.parsed_params.get("h_userids") filter_by_assignment_ids = self.request.parsed_params.get("assignment_ids") - admin_organizations = self.dashboard_service.get_organizations_by_admin_email( - self.current_user_email - ) courses = self.course_service.get_courses( - admin_organization_ids=[org.id for org in admin_organizations], + admin_organization_ids=[org.id for org in self.admin_organizations], instructor_h_userid=self.request.user.h_userid if self.request.user else None, @@ -90,11 +89,8 @@ def courses_metrics(self) -> APICourses: filter_by_h_userids = self.request.parsed_params.get("h_userids") filter_by_assignment_ids = self.request.parsed_params.get("assignment_ids") filter_by_course_ids = self.request.parsed_params.get("course_ids") - admin_organizations = self.dashboard_service.get_organizations_by_admin_email( - self.current_user_email - ) courses_query = self.course_service.get_courses( - admin_organization_ids=[org.id for org in admin_organizations], + admin_organization_ids=[org.id for org in self.admin_organizations], instructor_h_userid=self.request.user.h_userid if self.request.user else None, @@ -132,7 +128,9 @@ def courses_metrics(self) -> APICourses: permission=Permissions.DASHBOARD_VIEW, ) def course(self) -> APICourse: - course = self.dashboard_service.get_request_course(self.request) + course = self.dashboard_service.get_request_course( + self.request, self.admin_organizations + ) return { "id": course.id, "title": course.lms_name, diff --git a/lms/views/dashboard/api/user.py b/lms/views/dashboard/api/user.py index efb012c74b..4ab40f0692 100644 --- a/lms/views/dashboard/api/user.py +++ b/lms/views/dashboard/api/user.py @@ -48,6 +48,12 @@ def __init__(self, request) -> None: self.h_api = request.find_service(HAPI) self.user_service: UserService = request.find_service(UserService) + self.admin_organizations = ( + self.dashboard_service.get_organizations_by_admin_email( + request.lti_user.email if request.lti_user else request.identity.userid + ) + ) + @view_config( route_name="api.dashboard.students", request_method="GET", @@ -62,6 +68,7 @@ def students(self) -> APIStudents: instructor_h_userid=self.request.user.h_userid if self.request.user else None, + admin_organization_ids=[org.id for org in self.admin_organizations], course_ids=self.request.parsed_params.get("course_ids"), assignment_ids=self.request.parsed_params.get("assignment_ids"), ) @@ -89,7 +96,9 @@ def students(self) -> APIStudents: def students_metrics(self) -> APIStudents: """Fetch the stats for one particular assignment.""" request_h_userids = self.request.parsed_params.get("h_userids") - assignment = self.dashboard_service.get_request_assignment(self.request) + assignment = self.dashboard_service.get_request_assignment( + self.request, self.admin_organizations + ) stats = self.h_api.get_annotation_counts( [g.authority_provided_id for g in assignment.groupings], group_by="user", @@ -108,6 +117,7 @@ def students_metrics(self) -> APIStudents: instructor_h_userid=self.request.user.h_userid if self.request.user else None, + admin_organization_ids=[org.id for org in self.admin_organizations], # Users the current user requested h_userids=request_h_userids, ) diff --git a/lms/views/dashboard/views.py b/lms/views/dashboard/views.py index 9a7d433b7a..07dbb8ee1a 100644 --- a/lms/views/dashboard/views.py +++ b/lms/views/dashboard/views.py @@ -40,6 +40,12 @@ def __init__(self, request) -> None: ) self.dashboard_service = request.find_service(name="dashboard") + self.admin_organizations = ( + self.dashboard_service.get_organizations_by_admin_email( + request.lti_user.email if request.lti_user else request.identity.userid + ) + ) + @view_config( route_name="dashboard.launch.assignment", permission=Permissions.DASHBOARD_VIEW, @@ -71,7 +77,9 @@ def assignment_show(self): Authenticated via the LTIUser present in a cookie making this endpoint accessible directly in the browser. """ - assignment = self.dashboard_service.get_request_assignment(self.request) + assignment = self.dashboard_service.get_request_assignment( + self.request, self.admin_organizations + ) self.request.context.js_config.enable_dashboard_mode() self._set_lti_user_cookie(self.request.response) return {"title": assignment.title} @@ -87,7 +95,9 @@ def course_show(self): Authenticated via the LTIUser present in a cookie making this endpoint accessible directly in the browser. """ - course = self.dashboard_service.get_request_course(self.request) + course = self.dashboard_service.get_request_course( + self.request, self.admin_organizations + ) self.request.context.js_config.enable_dashboard_mode() self._set_lti_user_cookie(self.request.response) return {"title": course.lms_name} diff --git a/tests/unit/lms/services/assignment_test.py b/tests/unit/lms/services/assignment_test.py index f7ddb6d340..857a7c4c79 100644 --- a/tests/unit/lms/services/assignment_test.py +++ b/tests/unit/lms/services/assignment_test.py @@ -248,12 +248,13 @@ def test_get_assignments( db_session, instructor_h_userid, assignment, - with_assignment_noise, course_ids, h_userids, + organization, + application_instance, ): factories.User() - course = factories.Course() + course = factories.Course(application_instance=application_instance) user = factories.User() lti_role = factories.LTIRole(scope=RoleScope.COURSE, type=RoleType.INSTRUCTOR) factories.AssignmentMembership.create( @@ -269,26 +270,23 @@ def test_get_assignments( if instructor_h_userid: query_parameters["instructor_h_userid"] = user.h_userid + else: + query_parameters["admin_organization_ids"] = [organization.id] if course_ids: query_parameters["course_ids"] = [course.id] if h_userids: query_parameters["h_userids"] = [user.h_userid] - query = svc.get_assignments(**query_parameters) - if not query_parameters: - assert set(db_session.scalars(query).all()) == set( - [assignment] + with_assignment_noise - ) - - else: - assert db_session.scalars(query).all() == [assignment] + assert db_session.scalars(query).all() == [assignment] - def test_get_assignments_by_course_id_with_duplicate(self, db_session, svc): - course = factories.Course() - other_course = factories.Course() + def test_get_assignments_by_course_id_with_duplicate( + self, db_session, svc, application_instance, organization + ): + course = factories.Course(application_instance=application_instance) + other_course = factories.Course(application_instance=application_instance) assignment = factories.Assignment() @@ -302,11 +300,15 @@ def test_get_assignments_by_course_id_with_duplicate(self, db_session, svc): db_session.flush() assert db_session.scalars( - svc.get_assignments(course_ids=[course.id]) + svc.get_assignments( + course_ids=[course.id], admin_organization_ids=[organization.id] + ) ).all() == [assignment] # We don't expect to get the other one at all, now the assignment belongs to the most recent course assert not db_session.scalars( - svc.get_assignments(course_ids=[other_course.id]) + svc.get_assignments( + course_ids=[other_course.id], admin_organization_ids=[organization.id] + ) ).all() def test_get_courses_assignments_count(self, svc, db_session): diff --git a/tests/unit/lms/services/dashboard_test.py b/tests/unit/lms/services/dashboard_test.py index 98f9fb5c54..f213c1a2f6 100644 --- a/tests/unit/lms/services/dashboard_test.py +++ b/tests/unit/lms/services/dashboard_test.py @@ -21,7 +21,7 @@ def test_get_request_assignment_404( assignment_service.get_by_id.return_value = None with pytest.raises(HTTPNotFound): - svc.get_request_assignment(pyramid_request) + svc.get_request_assignment(pyramid_request, []) def test_get_request_assignment_403( self, @@ -33,7 +33,7 @@ def test_get_request_assignment_403( assignment_service.is_member.return_value = False with pytest.raises(HTTPUnauthorized): - svc.get_request_assignment(pyramid_request) + svc.get_request_assignment(pyramid_request, []) def test_get_request_assignment_for_staff( self, pyramid_request, assignment_service, pyramid_config, svc @@ -42,18 +42,37 @@ def test_get_request_assignment_for_staff( pyramid_request.matchdict["assignment_id"] = sentinel.id assignment_service.is_member.return_value = False - assert svc.get_request_assignment(pyramid_request) + assert svc.get_request_assignment(pyramid_request, []) def test_get_request_assignment(self, pyramid_request, assignment_service, svc): pyramid_request.matchdict["assignment_id"] = sentinel.id assignment_service.is_member.return_value = True - assert svc.get_request_assignment(pyramid_request) + assert svc.get_request_assignment(pyramid_request, []) assignment_service.is_member.assert_called_once_with( assignment_service.get_by_id.return_value, pyramid_request.user.h_userid ) + def test_get_request_assignment_for_admin( + self, + pyramid_request, + assignment_service, + svc, + organization, + db_session, + application_instance, + ): + assignment = factories.Assignment() + course = factories.Course(application_instance=application_instance) + factories.AssignmentGrouping(assignment=assignment, grouping=course) + assignment_service.get_by_id.return_value = assignment + db_session.flush() + + pyramid_request.matchdict["assignment_id"] = sentinel.id + + assert svc.get_request_assignment(pyramid_request, [organization]) + def test_get_request_course_404( self, pyramid_request, @@ -64,14 +83,14 @@ def test_get_request_course_404( course_service.get_by_id.return_value = None with pytest.raises(HTTPNotFound): - svc.get_request_course(pyramid_request) + svc.get_request_course(pyramid_request, []) def test_get_request_course_403(self, pyramid_request, course_service, svc): pyramid_request.matchdict["course_id"] = sentinel.id course_service.is_member.return_value = False with pytest.raises(HTTPUnauthorized): - svc.get_request_course(pyramid_request) + svc.get_request_course(pyramid_request, []) def test_get_request_course_for_staff( self, pyramid_request, course_service, pyramid_config, svc @@ -80,13 +99,23 @@ def test_get_request_course_for_staff( pyramid_request.matchdict["course_id"] = sentinel.id course_service.is_member.return_value = False - assert svc.get_request_course(pyramid_request) + assert svc.get_request_course(pyramid_request, []) + + def test_get_request_course_for_admin( + self, pyramid_request, course_service, svc, application_instance, organization + ): + course_service.get_by_id.return_value = factories.Course( + application_instance=application_instance + ) + pyramid_request.matchdict["course_id"] = sentinel.id + + assert svc.get_request_course(pyramid_request, [organization]) def test_get_request_course(self, pyramid_request, course_service, svc): pyramid_request.matchdict["course_id"] = sentinel.id course_service.is_member.return_value = True - assert svc.get_request_course(pyramid_request) + assert svc.get_request_course(pyramid_request, []) def test_add_dashboard_admin(self, svc, db_session): admin = svc.add_dashboard_admin( @@ -114,9 +143,11 @@ def test_get_organizations_by_admin_email(self, svc, db_session, organization): assert svc.get_organizations_by_admin_email(admin.email) == [organization] @pytest.fixture() - def svc(self, assignment_service, course_service, organization_service, db_session): + def svc( + self, assignment_service, course_service, organization_service, pyramid_request + ): return DashboardService( - db_session, assignment_service, course_service, organization_service + pyramid_request, assignment_service, course_service, organization_service ) @pytest.fixture(autouse=True) @@ -133,12 +164,11 @@ def test_it( DashboardService, course_service, organization_service, - db_session, ): service = factory(sentinel.context, pyramid_request) DashboardService.assert_called_once_with( - db=db_session, + request=pyramid_request, assignment_service=assignment_service, course_service=course_service, organization_service=organization_service, diff --git a/tests/unit/lms/services/user_test.py b/tests/unit/lms/services/user_test.py index 72d050b47e..cfbf43fff2 100644 --- a/tests/unit/lms/services/user_test.py +++ b/tests/unit/lms/services/user_test.py @@ -71,54 +71,55 @@ def test_get_not_found(self, user, service): with pytest.raises(UserNotFound): service.get(user.application_instance, "some-other-id") - def test_get_users(self, service, db_session): - assignment = factories.Assignment() - student = factories.User() - factories.User(h_userid=student.h_userid) # Duplicated student - teacher = factories.User() - factories.AssignmentMembership.create( - assignment=assignment, - user=student, - lti_role=factories.LTIRole(scope=RoleScope.COURSE, type=RoleType.LEARNER), - ) - factories.AssignmentMembership.create( - assignment=assignment, - user=teacher, - lti_role=factories.LTIRole( - scope=RoleScope.COURSE, type=RoleType.INSTRUCTOR - ), - ) + def test_get_users( + self, + service, + db_session, + organization, + student_in_assigment, + ): + factories.User(h_userid=student_in_assigment.h_userid) # Duplicated student query = service.get_users( - role_scope=RoleScope.COURSE, role_type=RoleType.LEARNER + role_scope=RoleScope.COURSE, + role_type=RoleType.LEARNER, + admin_organization_ids=[organization.id], ) - assert db_session.scalars(query).all() == [student] + assert db_session.scalars(query).all() == [student_in_assigment] def test_get_users_by_h_userids( - self, service, db_session, student_in_assigment, assignment + self, service, db_session, student_in_assigment, assignment, organization ): - other_student = factories.User() + other_student = factories.User( + application_instance=organization.application_instances[0] + ) factories.AssignmentMembership.create( assignment=assignment, user=other_student, lti_role=factories.LTIRole(scope=RoleScope.COURSE, type=RoleType.LEARNER), ) + db_session.flush() # Make sure we have in fact two users assert db_session.scalars( - service.get_users(role_scope=RoleScope.COURSE, role_type=RoleType.LEARNER) + service.get_users( + role_scope=RoleScope.COURSE, + role_type=RoleType.LEARNER, + admin_organization_ids=[organization.id], + ) ).all() == [student_in_assigment, other_student] query = service.get_users( role_scope=RoleScope.COURSE, role_type=RoleType.LEARNER, h_userids=[other_student.h_userid], + admin_organization_ids=[organization.id], ) assert db_session.scalars(query).all() == [other_student] def test_get_users_by_course_id( - self, service, db_session, student_in_assigment, assignment + self, service, db_session, student_in_assigment, assignment, organization ): course = factories.Course() factories.AssignmentGrouping(assignment=assignment, grouping=course) @@ -128,36 +129,26 @@ def test_get_users_by_course_id( role_scope=RoleScope.COURSE, role_type=RoleType.LEARNER, course_ids=[course.id], + admin_organization_ids=[organization.id], ) assert db_session.scalars(query).all() == [student_in_assigment] - def test_get_users_by_assigment_id(self, service, db_session): - assignment = factories.Assignment() - student = factories.User() - factories.User(h_userid=student.h_userid) # Duplicated student - teacher = factories.User() - factories.AssignmentMembership.create( - assignment=assignment, - user=student, - lti_role=factories.LTIRole(scope=RoleScope.COURSE, type=RoleType.LEARNER), - ) - factories.AssignmentMembership.create( - assignment=assignment, - user=teacher, - lti_role=factories.LTIRole( - scope=RoleScope.COURSE, type=RoleType.INSTRUCTOR - ), - ) + @pytest.mark.usefixtures("teacher_in_assigment") + def test_get_users_by_assigment_id( + self, service, db_session, student_in_assigment, assignment, organization + ): + factories.User(h_userid=student_in_assigment.h_userid) # Duplicated student db_session.flush() query = service.get_users( role_scope=RoleScope.COURSE, role_type=RoleType.LEARNER, assignment_ids=[assignment.id], + admin_organization_ids=[organization.id], ) - assert db_session.scalars(query).all() == [student] + assert db_session.scalars(query).all() == [student_in_assigment] def test_get_users_by_instructor_h_userid( self, @@ -188,8 +179,8 @@ def assignment(self): return factories.Assignment() @pytest.fixture - def student_in_assigment(self, assignment): - student = factories.User() + def student_in_assigment(self, assignment, application_instance): + student = factories.User(application_instance=application_instance) factories.AssignmentMembership.create( assignment=assignment, user=student, diff --git a/tests/unit/lms/views/dashboard/api/assignment_test.py b/tests/unit/lms/views/dashboard/api/assignment_test.py index 770f0afca7..1fa9eeecd5 100644 --- a/tests/unit/lms/views/dashboard/api/assignment_test.py +++ b/tests/unit/lms/views/dashboard/api/assignment_test.py @@ -29,6 +29,7 @@ def test_get_assignments( instructor_h_userid=pyramid_request.user.h_userid, course_ids=sentinel.course_ids, h_userids=sentinel.h_userids, + admin_organization_ids=[], ) get_page.assert_called_once_with( pyramid_request, @@ -50,7 +51,8 @@ def test_assignment( response = views.assignment() dashboard_service.get_request_assignment.assert_called_once_with( - pyramid_request + pyramid_request, + dashboard_service.get_organizations_by_admin_email.return_value, ) assert response == { @@ -97,6 +99,7 @@ def test_course_assignments( role_type=RoleType.LEARNER, instructor_h_userid=pyramid_request.user.h_userid, h_userids=sentinel.h_userids, + admin_organization_ids=[], ) h_api.get_annotation_counts.assert_called_once_with( [course.authority_provided_id, section.authority_provided_id], diff --git a/tests/unit/lms/views/dashboard/api/course_test.py b/tests/unit/lms/views/dashboard/api/course_test.py index 116e8f4ec2..3a79d8ecb9 100644 --- a/tests/unit/lms/views/dashboard/api/course_test.py +++ b/tests/unit/lms/views/dashboard/api/course_test.py @@ -17,22 +17,18 @@ class TestCourseViews: - def test_get_courses( - self, course_service, pyramid_request, views, get_page, dashboard_service - ): - org = factories.Organization() + def test_get_courses(self, course_service, pyramid_request, views, get_page): courses = factories.Course.create_batch(5) get_page.return_value = courses, sentinel.pagination pyramid_request.parsed_params = { "h_userids": sentinel.h_userids, "assignment_ids": sentinel.assignment_ids, } - dashboard_service.get_organizations_by_admin_email.return_value = [org] response = views.courses() course_service.get_courses.assert_called_once_with( - admin_organization_ids=[org.id], + admin_organization_ids=[], instructor_h_userid=pyramid_request.user.h_userid, h_userids=sentinel.h_userids, assignment_ids=sentinel.assignment_ids, @@ -48,15 +44,8 @@ def test_get_courses( } def test_course_metrics( - self, - course_service, - pyramid_request, - views, - db_session, - dashboard_service, - assignment_service, + self, course_service, pyramid_request, views, db_session, assignment_service ): - org = factories.Organization() courses = factories.Course.create_batch(5) course_service.get_courses.return_value = select(Course).order_by(Course.id) pyramid_request.matchdict["organization_public_id"] = sentinel.public_id @@ -65,12 +54,11 @@ def test_course_metrics( "assignment_ids": sentinel.assignment_ids, } db_session.flush() - dashboard_service.get_organizations_by_admin_email.return_value = [org] response = views.courses_metrics() course_service.get_courses.assert_called_once_with( - admin_organization_ids=[org.id], + admin_organization_ids=[], instructor_h_userid=pyramid_request.user.h_userid, h_userids=sentinel.h_userids, assignment_ids=sentinel.assignment_ids, @@ -94,26 +82,18 @@ def test_course_metrics( } def test_courses_metrics_by_courses( - self, - course_service, - pyramid_request, - views, - db_session, - dashboard_service, - assignment_service, + self, course_service, pyramid_request, views, db_session, assignment_service ): - org = factories.Organization() courses = factories.Course.create_batch(5) course_service.get_courses.return_value = select(Course).order_by(Course.id) pyramid_request.matchdict["organization_public_id"] = sentinel.public_id - dashboard_service.get_organizations_by_admin_email.return_value = [org] db_session.flush() pyramid_request.parsed_params = {"course_ids": [courses[0].id]} response = views.courses_metrics() course_service.get_courses.assert_called_once_with( - admin_organization_ids=[org.id], + admin_organization_ids=[], instructor_h_userid=pyramid_request.user.h_userid, h_userids=None, assignment_ids=None, @@ -143,7 +123,10 @@ def test_course(self, views, pyramid_request, dashboard_service): response = views.course() - dashboard_service.get_request_course.assert_called_once_with(pyramid_request) + dashboard_service.get_request_course.assert_called_once_with( + pyramid_request, + dashboard_service.get_organizations_by_admin_email.return_value, + ) assert response == { "id": course.id, diff --git a/tests/unit/lms/views/dashboard/api/user_test.py b/tests/unit/lms/views/dashboard/api/user_test.py index aac019c12d..93d497c057 100644 --- a/tests/unit/lms/views/dashboard/api/user_test.py +++ b/tests/unit/lms/views/dashboard/api/user_test.py @@ -28,6 +28,7 @@ def test_get_students(self, user_service, pyramid_request, views, get_page): role_scope=RoleScope.COURSE, role_type=RoleType.LEARNER, instructor_h_userid=pyramid_request.user.h_userid, + admin_organization_ids=[], course_ids=sentinel.course_ids, assignment_ids=sentinel.assignment_ids, ) @@ -94,7 +95,8 @@ def test_students_metrics( response = views.students_metrics() dashboard_service.get_request_assignment.assert_called_once_with( - pyramid_request + pyramid_request, + dashboard_service.get_organizations_by_admin_email.return_value, ) h_api.get_annotation_counts.assert_called_once_with( [g.authority_provided_id for g in assignment.groupings], diff --git a/tests/unit/lms/views/dashboard/views_test.py b/tests/unit/lms/views/dashboard/views_test.py index 6b62da1978..6640c6e593 100644 --- a/tests/unit/lms/views/dashboard/views_test.py +++ b/tests/unit/lms/views/dashboard/views_test.py @@ -49,7 +49,8 @@ def test_assignment_show(self, views, pyramid_request, dashboard_service): views.assignment_show() dashboard_service.get_request_assignment.assert_called_once_with( - pyramid_request + pyramid_request, + dashboard_service.get_organizations_by_admin_email.return_value, ) pyramid_request.context.js_config.enable_dashboard_mode.assert_called_once() assert ( @@ -67,7 +68,10 @@ def test_course_show(self, views, pyramid_request, dashboard_service): views.course_show() - dashboard_service.get_request_course.assert_called_once_with(pyramid_request) + dashboard_service.get_request_course.assert_called_once_with( + pyramid_request, + dashboard_service.get_organizations_by_admin_email.return_value, + ) pyramid_request.context.js_config.enable_dashboard_mode.assert_called_once() assert ( pyramid_request.response.headers["Set-Cookie"] @@ -101,7 +105,8 @@ def test_assignment_show_with_no_lti_user( views.assignment_show() dashboard_service.get_request_assignment.assert_called_once_with( - pyramid_request + pyramid_request, + dashboard_service.get_organizations_by_admin_email.return_value, ) pyramid_request.context.js_config.enable_dashboard_mode.assert_called_once()