Skip to content

Commit

Permalink
Merge pull request #57 from reddit/fix_prune_fn
Browse files Browse the repository at this point in the history
Fix `_prune_extracted_dict()` definition location
  • Loading branch information
mrlevitas authored Jun 30, 2022
2 parents f34b8d7 + 5d18dc3 commit 16fe9c9
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 73 deletions.
64 changes: 32 additions & 32 deletions reddit_decider/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,12 +248,12 @@ def _send_expose(
return

experiment = ExperimentConfig(
id=Decider._cast_to_int(exp_id),
id=self._cast_to_int(exp_id),
name=name,
version=version,
bucket_val=bucket_val,
start_ts=Decider._cast_to_int(start_ts),
stop_ts=Decider._cast_to_int(stop_ts),
start_ts=self._cast_to_int(start_ts),
stop_ts=self._cast_to_int(stop_ts),
owner=owner,
)

Expand Down Expand Up @@ -299,12 +299,12 @@ def _send_expose_if_holdout(
# 2: holdout
if event_type == "2":
experiment = ExperimentConfig(
id=Decider._cast_to_int(exp_id),
id=self._cast_to_int(exp_id),
name=name,
version=version,
bucket_val=bucket_val,
start_ts=Decider._cast_to_int(start_ts),
stop_ts=Decider._cast_to_int(stop_ts),
start_ts=self._cast_to_int(start_ts),
stop_ts=self._cast_to_int(stop_ts),
owner=owner,
)

Expand All @@ -321,35 +321,15 @@ def _send_expose_if_holdout(
)
return

@classmethod
def _cast_to_int(cls, input: str) -> int:
@staticmethod
def _cast_to_int(input: str) -> int:
out = 1
try:
out = int(input)
except ValueError as e:
logger.info(f"Encountered error casting to integer: {e}")
return out

@classmethod
def _prune_extracted_dict(cls, extracted_dict: dict) -> dict:
parsed_extracted_fields = deepcopy(extracted_dict)

for k, v in extracted_dict.items():
# remove invalid keys
if k is None or not isinstance(k, str):
logger.info(
f"{k} key in request_field_extractor() dict is not of type str and is removed."
)
del parsed_extracted_fields[k]
continue
# remove invalid values
if not isinstance(v, (int, float, str, bool)) and v is not None:
logger.info(
f"{k}: {v} value in `request_field_extractor()` dict is not one of type: [None, int, float, str, bool] and is removed."
)
del parsed_extracted_fields[k]
return parsed_extracted_fields

def get_variant(
self, experiment_name: str, **exposure_kwargs: Optional[Dict[str, Any]]
) -> Optional[str]:
Expand Down Expand Up @@ -955,14 +935,34 @@ def __init__(
self._event_logger = event_logger
self._request_field_extractor = request_field_extractor

@classmethod
def is_employee(cls, edge_context: Any) -> bool:
@staticmethod
def _is_employee(edge_context: Any) -> bool:
return (
any([edge_context.user.has_role(role) for role in EMPLOYEE_ROLES])
if edge_context.user.is_logged_in
else False
)

@staticmethod
def _prune_extracted_dict(extracted_dict: dict) -> dict:
parsed_extracted_fields = deepcopy(extracted_dict)

for k, v in extracted_dict.items():
# remove invalid keys
if k is None or not isinstance(k, str):
logger.info(
f"{k} key in request_field_extractor() dict is not of type str and is removed."
)
del parsed_extracted_fields[k]
continue
# remove invalid values
if not isinstance(v, (int, float, str, bool)) and v is not None:
logger.info(
f"{k}: {v} value in `request_field_extractor()` dict is not one of type: [None, int, float, str, bool] and is removed."
)
del parsed_extracted_fields[k]
return parsed_extracted_fields

def _minimal_decider(
self, name: str, span: Span, parsed_extracted_fields: Optional[Dict] = None
) -> Decider:
Expand Down Expand Up @@ -997,7 +997,7 @@ def make_object_for_context(self, name: str, span: Span) -> Decider:
if self._request_field_extractor:
extracted_fields = self._request_field_extractor(request)
# prune any invalid keys/values
parsed_extracted_fields = Decider._prune_extracted_dict(
parsed_extracted_fields = self._prune_extracted_dict(
extracted_dict=extracted_fields
)
except Exception as exc:
Expand Down Expand Up @@ -1079,7 +1079,7 @@ def make_object_for_context(self, name: str, span: Span) -> Decider:

is_employee = None
try:
is_employee = DeciderContextFactory.is_employee(ec)
is_employee = self._is_employee(ec)
except Exception as exc:
logger.info(
f"Error in `DeciderContextFactory.is_employee(ec)` in `make_object_for_context()`. details: {exc}"
Expand Down
105 changes: 64 additions & 41 deletions tests/decider_tests.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import contextlib
import json
import logging
import tempfile
import unittest

Expand All @@ -18,6 +19,8 @@
from reddit_decider import EventType
from reddit_decider import init_decider_parser

logger = logging.getLogger()

USER_ID = "t2_1234"
IS_LOGGED_IN = True
AUTH_CLIENT_ID = "token"
Expand Down Expand Up @@ -67,37 +70,39 @@ def setUp(self):
self.mock_span.context = None

def test_make_clients(self, file_watcher_mock):
decider_ctx_factory = decider_client_from_config(
{"experiments.path": "/tmp/test"}, self.event_logger
)
with create_temp_config_file({}) as f:
decider_ctx_factory = decider_client_from_config(
{"experiments.path": f.name}, self.event_logger
)
self.assertIsInstance(decider_ctx_factory, DeciderContextFactory)
file_watcher_mock.assert_called_once_with(
path="/tmp/test", parser=init_decider_parser, timeout=None, backoff=None
path=f.name, parser=init_decider_parser, timeout=None, backoff=None
)

def test_timeout(self, file_watcher_mock):
decider_ctx_factory = decider_client_from_config(
{"experiments.path": "/tmp/test", "experiments.timeout": "60 seconds"},
self.event_logger,
)
with create_temp_config_file({}) as f:
decider_ctx_factory = decider_client_from_config(
{"experiments.path": f.name, "experiments.timeout": "2 seconds"},
self.event_logger,
)
self.assertIsInstance(decider_ctx_factory, DeciderContextFactory)
file_watcher_mock.assert_called_once_with(
path="/tmp/test", parser=init_decider_parser, timeout=60.0, backoff=None
path=f.name, parser=init_decider_parser, timeout=2.0, backoff=None
)

def test_prefix(self, file_watcher_mock):
decider_ctx_factory = decider_client_from_config(
{"r2_experiments.path": "/tmp/test", "r2_experiments.timeout": "60 seconds"},
self.event_logger,
prefix="r2_experiments.",
)
with create_temp_config_file({}) as f:
decider_ctx_factory = decider_client_from_config(
{"r2_experiments.path": f.name, "r2_experiments.timeout": "2 seconds"},
self.event_logger,
prefix="r2_experiments.",
)
self.assertIsInstance(decider_ctx_factory, DeciderContextFactory)
file_watcher_mock.assert_called_once_with(
path="/tmp/test", parser=init_decider_parser, timeout=60.0, backoff=None
path=f.name, parser=init_decider_parser, timeout=2.0, backoff=None
)


@mock.patch("reddit_decider.FileWatcher")
class DeciderContextFactoryTests(unittest.TestCase):
def setUp(self):
super().setUp()
Expand All @@ -115,14 +120,22 @@ def setUp(self):
self.mock_span.context.edge_context.origin_service.name = ORIGIN_SERVICE
self.mock_span.context.edge_context.device.id = DEVICE_ID

def test_make_object_for_context_and_decider_context(self, _filewatcher):
decider_ctx_factory = decider_client_from_config(
{"experiments.path": "/tmp/test", "experiments.timeout": "60 seconds"},
self.event_logger,
prefix="experiments.",
request_field_extractor=decider_field_extractor,
)
decider = decider_ctx_factory.make_object_for_context(name="test", span=self.mock_span)
def test_make_object_for_context_and_decider_context(self):
with create_temp_config_file({}) as f:
decider_ctx_factory = decider_client_from_config(
{"experiments.path": f.name, "experiments.timeout": "2 seconds"},
self.event_logger,
prefix="experiments.",
request_field_extractor=decider_field_extractor,
)
with self.assertLogs(logger, logging.WARN) as captured:
# ensure no warnings are printed except for the dummy one
# https://stackoverflow.com/a/61381576/4260179
logger.warn("Dummy warning")
decider = decider_ctx_factory.make_object_for_context(name="test", span=self.mock_span)
assert len(captured.records) == 1
self.assertEqual(["WARNING:root:Dummy warning"], captured.output)

self.assertIsInstance(decider, Decider)

decider_context = getattr(decider, "_decider_context")
Expand Down Expand Up @@ -183,22 +196,29 @@ def test_make_object_for_context_and_decider_context(self, _filewatcher):
self.assertEqual(decider_event_dict["canonical_url"], CANONICAL_URL)
self.assertEqual(decider_event_dict["request"]["canonical_url"], CANONICAL_URL)

def test_make_object_for_context_and_decider_context_without_span(self, _filewatcher):
decider_ctx_factory = decider_client_from_config(
{"experiments.path": "/tmp/test", "experiments.timeout": "60 seconds"},
self.event_logger,
prefix="experiments.",
request_field_extractor=decider_field_extractor,
)
decider = decider_ctx_factory.make_object_for_context(name="test", span=None)
def test_make_object_for_context_and_decider_context_without_span(self):
with create_temp_config_file({}) as f:
decider_ctx_factory = decider_client_from_config(
{"experiments.path": f.name, "experiments.timeout": "2 seconds"},
self.event_logger,
prefix="experiments.",
request_field_extractor=decider_field_extractor,
)
with self.assertLogs(logger, logging.WARN) as captured:
# ensure no warnings are printed except for the dummy one
# https://stackoverflow.com/a/61381576/4260179
logger.warn("Dummy warning")

decider = decider_ctx_factory.make_object_for_context(name="test", span=None)
assert len(captured.records) == 1
self.assertEqual(["WARNING:root:Dummy warning"], captured.output)

self.assertIsInstance(decider, Decider)

decider_ctx_dict = decider._decider_context.to_dict()
self.assertEqual(decider_ctx_dict["user_id"], None)

def test_make_object_for_context_and_decider_context_with_broken_decider_field_extractor(
self, _filewatcher
):
def test_make_object_for_context_and_decider_context_with_broken_decider_field_extractor(self):
def broken_decider_field_extractor(_request: RequestContext):
return {
"app_name": {},
Expand All @@ -208,16 +228,19 @@ def broken_decider_field_extractor(_request: RequestContext):
None: "xyz",
}

decider_ctx_factory = decider_client_from_config(
{"experiments.path": "/tmp/test", "experiments.timeout": "60 seconds"},
self.event_logger,
prefix="experiments.",
request_field_extractor=broken_decider_field_extractor,
)
with create_temp_config_file({}) as f:
decider_ctx_factory = decider_client_from_config(
{"experiments.path": f.name, "experiments.timeout": "2 seconds"},
self.event_logger,
prefix="experiments.",
request_field_extractor=broken_decider_field_extractor,
)

with self.assertLogs() as captured:
decider_ctx_factory.make_object_for_context(name="test", span=self.mock_span)

assert len(captured.records) == 3

assert any(
"None key in request_field_extractor() dict is not of type str and is removed."
in x.getMessage()
Expand Down

0 comments on commit 16fe9c9

Please sign in to comment.