Skip to content

Commit

Permalink
add method return typing, sampling_rule defaults, update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jj22ee committed Feb 6, 2024
1 parent 1b9e3b5 commit 51d76a0
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def __init__(self, endpoint=None, log_level=None):
_logger.error("endpoint must be specified")
self.__get_sampling_rules_endpoint = endpoint + "/GetSamplingRules"

def get_sampling_rules(self):
def get_sampling_rules(self) -> [_SamplingRule]:
sampling_rules = []
headers = {"content-type": "application/json"}

Expand All @@ -38,7 +38,10 @@ def get_sampling_rules(self):

sampling_rules_records = sampling_rules_response["SamplingRuleRecords"]
for record in sampling_rules_records:
sampling_rules.append(_SamplingRule(**record["SamplingRule"]))
if "SamplingRule" not in record:
_logger.error("SamplingRule is missing in SamplingRuleRecord")
else:
sampling_rules.append(_SamplingRule(**record["SamplingRule"]))

except requests.exceptions.RequestException as req_err:
_logger.error("Request error occurred: %s", req_err)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
class _SamplingRule:
def __init__(
self,
Attributes=None,
Attributes: dict = None,
FixedRate=None,
HTTPMethod=None,
Host=None,
Expand All @@ -21,16 +21,16 @@ def __init__(
URLPath=None,
Version=None,
):
self.Attributes = Attributes
self.FixedRate = FixedRate
self.HTTPMethod = HTTPMethod
self.Host = Host
self.Priority = Priority
self.ReservoirSize = ReservoirSize
self.ResourceARN = ResourceARN
self.RuleARN = RuleARN
self.RuleName = RuleName
self.ServiceName = ServiceName
self.ServiceType = ServiceType
self.URLPath = URLPath
self.Version = Version
self.Attributes = Attributes if Attributes is not None else {}
self.FixedRate = FixedRate if FixedRate is not None else ""
self.HTTPMethod = HTTPMethod if HTTPMethod is not None else ""
self.Host = Host if Host is not None else ""
self.Priority = Priority if Priority is not None else ""
self.ReservoirSize = ReservoirSize if ReservoirSize is not None else ""
self.ResourceARN = ResourceARN if ResourceARN is not None else ""
self.RuleARN = RuleARN if RuleARN is not None else ""
self.RuleName = RuleName if RuleName is not None else ""
self.ServiceName = ServiceName if ServiceName is not None else ""
self.ServiceType = ServiceType if ServiceType is not None else ""
self.URLPath = URLPath if URLPath is not None else ""
self.Version = Version if Version is not None else ""
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,9 @@ def __init__(
self.__resource = Resource.get_empty()

# Schedule the next rule poll now
# Python Timers only run once, so they need to be recreated for every poll
self._timer = Timer(0, self.__start_sampling_rule_poller)
self._timer.daemon = True
self._timer.daemon = True # Ensures that when the main thread exits, the Timer threads are killed
self._timer.start()

# pylint: disable=no-self-use
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,82 @@ def test_get_no_sampling_rules(self, mock_post=None):
self.assertTrue(len(sampling_rules) == 0)

@patch("requests.post")
def test_get_invalid_response(self, mock_post=None):
def test_get_invalid_responses(self, mock_post=None):
mock_post.return_value.configure_mock(**{"json.return_value": {}})
client = _AwsXRaySamplingClient("http://127.0.0.1:2000")
with self.assertLogs(_logger, level="ERROR"):
sampling_rules = client.get_sampling_rules()
self.assertTrue(len(sampling_rules) == 0)
self.assertTrue(len(sampling_rules) == 0)

@patch("requests.post")
def test_get_sampling_rule_missing_in_records(self, mock_post=None):
mock_post.return_value.configure_mock(**{"json.return_value": {"SamplingRuleRecords": [{}]}})
client = _AwsXRaySamplingClient("http://127.0.0.1:2000")
with self.assertLogs(_logger, level="ERROR"):
sampling_rules = client.get_sampling_rules()
self.assertTrue(len(sampling_rules) == 0)

@patch("requests.post")
def test_get_two_sampling_rules(self, mock_post=None):
def test_default_values_used_when_missing_properties_in_sampling_rule(self, mock_post=None):
mock_post.return_value.configure_mock(**{"json.return_value": {"SamplingRuleRecords": [{"SamplingRule": {}}]}})
client = _AwsXRaySamplingClient("http://127.0.0.1:2000")
sampling_rules = client.get_sampling_rules()
self.assertTrue(len(sampling_rules) == 1)

sampling_rule = sampling_rules[0]
self.assertEqual(sampling_rule.Attributes, {})
self.assertEqual(sampling_rule.FixedRate, "")
self.assertEqual(sampling_rule.HTTPMethod, "")
self.assertEqual(sampling_rule.Host, "")
self.assertEqual(sampling_rule.Priority, "")
self.assertEqual(sampling_rule.ReservoirSize, "")
self.assertEqual(sampling_rule.ResourceARN, "")
self.assertEqual(sampling_rule.RuleARN, "")
self.assertEqual(sampling_rule.RuleName, "")
self.assertEqual(sampling_rule.ServiceName, "")
self.assertEqual(sampling_rule.ServiceType, "")
self.assertEqual(sampling_rule.URLPath, "")
self.assertEqual(sampling_rule.Version, "")

@patch("requests.post")
def test_get_three_sampling_rules(self, mock_post=None):
sampling_records = []
with open(f"{DATA_DIR}/get-sampling-rules-response-sample.json", encoding="UTF-8") as file:
mock_post.return_value.configure_mock(**{"json.return_value": json.load(file)})
sample_response = json.load(file)
sampling_records = sample_response["SamplingRuleRecords"]
mock_post.return_value.configure_mock(**{"json.return_value": sample_response})
file.close()
client = _AwsXRaySamplingClient("http://127.0.0.1:2000")
sampling_rules = client.get_sampling_rules()
self.assertTrue(len(sampling_rules) == 3)
self.assertEqual(len(sampling_rules), 3)
self.assertEqual(len(sampling_rules), len(sampling_records))
self.validate_match_sampling_rules_properties_with_records(sampling_rules, sampling_records)

def validate_match_sampling_rules_properties_with_records(self, sampling_rules, sampling_records):
for _, (sampling_rule, sampling_record) in enumerate(zip(sampling_rules, sampling_records)):
self.assertIsNotNone(sampling_rule.Attributes)
self.assertEqual(sampling_rule.Attributes, sampling_record["SamplingRule"]["Attributes"])
self.assertIsNotNone(sampling_rule.FixedRate)
self.assertEqual(sampling_rule.FixedRate, sampling_record["SamplingRule"]["FixedRate"])
self.assertIsNotNone(sampling_rule.HTTPMethod)
self.assertEqual(sampling_rule.HTTPMethod, sampling_record["SamplingRule"]["HTTPMethod"])
self.assertIsNotNone(sampling_rule.Host)
self.assertEqual(sampling_rule.Host, sampling_record["SamplingRule"]["Host"])
self.assertIsNotNone(sampling_rule.Priority)
self.assertEqual(sampling_rule.Priority, sampling_record["SamplingRule"]["Priority"])
self.assertIsNotNone(sampling_rule.ReservoirSize)
self.assertEqual(sampling_rule.ReservoirSize, sampling_record["SamplingRule"]["ReservoirSize"])
self.assertIsNotNone(sampling_rule.ResourceARN)
self.assertEqual(sampling_rule.ResourceARN, sampling_record["SamplingRule"]["ResourceARN"])
self.assertIsNotNone(sampling_rule.RuleARN)
self.assertEqual(sampling_rule.RuleARN, sampling_record["SamplingRule"]["RuleARN"])
self.assertIsNotNone(sampling_rule.RuleName)
self.assertEqual(sampling_rule.RuleName, sampling_record["SamplingRule"]["RuleName"])
self.assertIsNotNone(sampling_rule.ServiceName)
self.assertEqual(sampling_rule.ServiceName, sampling_record["SamplingRule"]["ServiceName"])
self.assertIsNotNone(sampling_rule.ServiceType)
self.assertEqual(sampling_rule.ServiceType, sampling_record["SamplingRule"]["ServiceType"])
self.assertIsNotNone(sampling_rule.URLPath)
self.assertEqual(sampling_rule.URLPath, sampling_record["SamplingRule"]["URLPath"])
self.assertIsNotNone(sampling_rule.Version)
self.assertEqual(sampling_rule.Version, sampling_record["SamplingRule"]["Version"])

0 comments on commit 51d76a0

Please sign in to comment.