Skip to content

Commit

Permalink
Explicitly use h_userid to search and check membership of users to co…
Browse files Browse the repository at this point in the history
…urses

The same person might have multiple "User" rows for example in different
applications instances.

The h_userid value would be the same thought. Explicitly use this value
instead of the full object both on the method to search for courses and
to check user membership.
  • Loading branch information
marcospri committed Jun 6, 2024
1 parent 3612889 commit f39c729
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 14 deletions.
16 changes: 9 additions & 7 deletions lms/services/course.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def search( # noqa: PLR0913, PLR0917
name: str | None = None,
limit: int = 100,
organization_ids: list[int] | None = None,
user: User | None = None,
h_userid: str | None = None,
) -> list[Course]:
query = self._db.query(Course)

Expand All @@ -106,12 +106,12 @@ def search( # noqa: PLR0913, PLR0917
.filter(Organization.id.in_(organization_ids))
)

if user:
# Only courses `user` belongs to
if h_userid:
# Only courses where the H's h_userid belongs to
query = (
query.join(GroupingMembership)
.join(User)
.filter(User.h_userid == user.h_userid)
.filter(User.h_userid == h_userid)
)

return query.limit(limit).all()
Expand Down Expand Up @@ -210,9 +210,11 @@ def get_by_id(self, id_: int) -> Course | None:

return None

def is_member(self, course: Course, user: User) -> bool:
"""Check if a user is a member of a course."""
return bool(course.memberships.filter_by(user=user).first())
def is_member(self, course: Course, h_userid: str) -> bool:
"""Check if an H user is a member of a course."""
return bool(
course.memberships.join(User).filter(User.h_userid == h_userid).first()
)

def get_assignments(self, course: Course) -> list[Assignment]:
"""
Expand Down
4 changes: 3 additions & 1 deletion lms/views/dashboard/api/course.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ def __init__(self, request) -> None:
def get_organization_courses(self) -> APICourses:
org = get_request_organization(self.request, self.organization_service)
courses = self.course_service.search(
limit=None, organization_ids=[org.id], user=self.request.user
limit=None,
organization_ids=[org.id],
h_userid=self.request.user.h_userid if self.request.user else None,
)
return {
"courses": [
Expand Down
2 changes: 1 addition & 1 deletion lms/views/dashboard/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def get_request_course(request, course_service):
# STAFF members in our admin pages can access all courses
return course

if not course_service.is_member(course, request.user):
if not course_service.is_member(course, request.user.h_userid):
raise HTTPUnauthorized()

return course
Expand Down
8 changes: 4 additions & 4 deletions tests/unit/lms/services/course_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,15 +300,15 @@ def test_search_by_organization(self, svc, db_session):

assert result == [course]

def test_search_by_user(self, svc, db_session):
def test_search_by_h_userid(self, svc, db_session):
user = factories.User()
course = factories.Course()
factories.Course.create_batch(10)
factories.GroupingMembership(grouping=course, user=user)
# Ensure ids are written
db_session.flush()

result = svc.search(user=user)
result = svc.search(h_userid=user.h_userid)

assert result == [course]

Expand All @@ -335,8 +335,8 @@ def test_is_member(self, svc, db_session):

db_session.flush()

assert svc.is_member(course, user)
assert not svc.is_member(course, other_user)
assert svc.is_member(course, user.h_userid)
assert not svc.is_member(course, other_user.h_userid)

def test_get_assignments(self, db_session, svc):
course = factories.Course()
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/lms/views/dashboard/api/course_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_get_organization_courses(
course_service.search.assert_called_once_with(
limit=None,
organization_ids=[org.id],
user=pyramid_request.user,
h_userid=pyramid_request.user.h_userid,
)

assert response == {
Expand Down

0 comments on commit f39c729

Please sign in to comment.