Skip to content

Commit

Permalink
refactor(datasets): BI-5663 simplify user processing in the RLS module (
Browse files Browse the repository at this point in the history
  • Loading branch information
MCPN authored Jul 10, 2024
1 parent a573532 commit 9c9fc75
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -377,25 +377,25 @@ def validate_select_field(self, block_spec: BlockSpec, field: BIField) -> None:
assert avatar_id is not None
if not self._avatar_exists(avatar_id=avatar_id):
raise dl_core.exc.UnknownReferencedAvatar(
f"Field {field.title!r} ({field_id}) references unknown source avatar " f"{field.avatar_id}."
f"Field {field.title!r} ({field_id}) references unknown source avatar {field.avatar_id}."
)

if not field.valid:
# FIXME: BI-2714 Investigate if/why this error is happening and return the raise
# raise exc.InvalidFieldError(
LOGGER.error(f"Field {field.title!r} ({field_id}) is invalid " f"and cannot be selected. Error ignored.")
LOGGER.error(f"Field {field.title!r} ({field_id}) is invalid and cannot be selected. Error ignored.")

self._ensure_not_unsupported_type(field)
if not block_spec.allow_measure_fields:
self._ensure_not_measure(field)

def _make_rls_filter_specs(self, subject_type: RLSSubjectType) -> List[FilterFieldSpec]:
subject_id = self._rci.user_id
if not subject_type or not subject_id:
raise Exception("No subject to use in RLS")
def _make_rls_filter_specs(self) -> List[FilterFieldSpec]:
user_id = self._rci.user_id
if not user_id:
raise RuntimeError("No subject to use in RLS")

result: List[FilterFieldSpec] = []
restrictions = self._dataset.rls.get_subject_restrictions(subject_type=subject_type, subject_id=subject_id)
restrictions = self._dataset.rls.get_user_restrictions(user_id=user_id)
for field_guid, values in restrictions.items():
result.append(
FilterFieldSpec(
Expand All @@ -405,7 +405,7 @@ def _make_rls_filter_specs(self, subject_type: RLSSubjectType) -> List[FilterFie
anonymous=True,
)
)
self._log_info("RLS filters for %s %s: %s", subject_type.name, subject_id, result)
self._log_info("RLS filters for user %s: %s", user_id, result)
return result

def make_phantom_select_ids(
Expand All @@ -429,7 +429,7 @@ def make_filter_specs(self, block_spec: BlockSpec) -> List[FilterFieldSpec]:
)

if not block_spec.disable_rls:
filter_specs += self._make_rls_filter_specs(subject_type=RLSSubjectType.user)
filter_specs += self._make_rls_filter_specs()

return filter_specs

Expand Down
33 changes: 11 additions & 22 deletions lib/dl_rls/dl_rls/rls.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,37 +40,30 @@ def has_restrictions(self) -> bool:
def fields_with_rls(self) -> list[str]:
return list(set(item.field_guid for item in self.items))

def get_entries(
self, field_guid: str, subject_type: RLSSubjectType, subject_id: str, add_userid_entry: bool = True
) -> list[RLSEntry]:
def get_entries(self, field_guid: str, user_id: str) -> list[RLSEntry]:
return [
item
for item in self.items
if item.field_guid == field_guid
and (
# Same subject
(item.subject.subject_type == subject_type and item.subject.subject_id == subject_id)
# same user
(item.subject.subject_type == RLSSubjectType.user and item.subject.subject_id == user_id)
# user is in the group
or (
item.subject.subject_type == RLSSubjectType.group and item.subject.subject_id in self.allowed_groups
)
# 'all subjects' matches any subject.
or item.subject.subject_type == RLSSubjectType.all
# `userid: userid`
or (add_userid_entry and item.pattern_type == RLSPatternType.userid)
or item.pattern_type == RLSPatternType.userid
)
]

def get_field_restriction_for_subject(
self,
field_guid: str,
subject_type: RLSSubjectType,
subject_id: str,
) -> FieldRestrictions:
def get_field_restriction_for_user(self, field_guid: str, user_id: str) -> FieldRestrictions:
"""
For subject and field, return `allow_all_values, allowed_values`.
For user and field, return `allow_all_values, allowed_values`.
"""
rls_entries = self.get_entries(field_guid=field_guid, subject_type=subject_type, subject_id=subject_id)
rls_entries = self.get_entries(field_guid=field_guid, user_id=user_id)

# There's a `*: {current_user}` entry, no need to filter.
if any(rls_entry.pattern_type == RLSPatternType.all for rls_entry in rls_entries):
Expand Down Expand Up @@ -105,23 +98,19 @@ def get_field_restriction_for_subject(

return FieldRestrictions(allow_all_values=False, allow_userid=allow_userid, allowed_values=allowed_values)

def get_subject_restrictions(
self,
subject_type: RLSSubjectType,
subject_id: str,
) -> dict[str, list[str]]:
def get_user_restrictions(self, user_id: str) -> dict[str, list[str]]:
result = {}
for field_guid in self.fields_with_rls:
allow_all_values, allow_userid, allowed_values = self.get_field_restriction_for_subject(
field_guid=field_guid, subject_type=subject_type, subject_id=subject_id
allow_all_values, allow_userid, allowed_values = self.get_field_restriction_for_user(
field_guid=field_guid, user_id=user_id
)
if allow_all_values:
# `*: subject` => 'not restricted'
continue

# For `userid: userid`, add the subject id to the values.
if allow_userid:
allowed_values = list(allowed_values) + [subject_id]
allowed_values = list(allowed_values) + [user_id]

result[field_guid] = allowed_values

Expand Down
10 changes: 3 additions & 7 deletions lib/dl_rls/dl_rls_tests/unit/test_rls.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,9 @@ def test_rls_simple():
_add_rls_restrictions(rls, field_guid=field_guid, restrictions=[dict(allowed_value="QQQ", subject_id="qwerty")])

assert rls.has_restrictions
allow_all_values, allow_userid, allowed_values = rls.get_field_restriction_for_subject(
allow_all_values, allow_userid, allowed_values = rls.get_field_restriction_for_user(
field_guid=field_guid,
subject_id="qwerty",
subject_type=RLSSubjectType.user,
user_id="qwerty",
)
assert not allow_all_values
assert not allow_userid
Expand Down Expand Up @@ -256,10 +255,7 @@ def test_rls(entrysets: dict, expected_restrictions: dict):
_add_rls_restrictions(rls, field_guid, entries)

assert rls.has_restrictions
restrictions = {
user_id: rls.get_subject_restrictions(subject_type=RLSSubjectType.user, subject_id=user_id)
for user_id in expected_restrictions.keys()
}
restrictions = {user_id: rls.get_user_restrictions(user_id=user_id) for user_id in expected_restrictions.keys()}
# Map back fields to aliases for readability
restrictions = {
user_id: {field_guid: field_restrictions for field_guid, field_restrictions in user_restrictions.items()}
Expand Down

0 comments on commit 9c9fc75

Please sign in to comment.