From 51d76a0e7ea0e776ba76cc387ca9948bd28b76fd Mon Sep 17 00:00:00 2001 From: jjllee Date: Mon, 5 Feb 2024 16:18:43 -0800 Subject: [PATCH] add method return typing, sampling_rule defaults, update tests --- .../sampler/_aws_xray_sampling_client.py | 7 +- .../distro/sampler/_sampling_rule.py | 28 +++---- .../distro/sampler/aws_xray_remote_sampler.py | 3 +- .../sampler/test_aws_xray_sampling_client.py | 74 +++++++++++++++++-- 4 files changed, 90 insertions(+), 22 deletions(-) diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_aws_xray_sampling_client.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_aws_xray_sampling_client.py index 3ef627154..09c56e386 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_aws_xray_sampling_client.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_aws_xray_sampling_client.py @@ -20,7 +20,7 @@ def __init__(self, endpoint=None, log_level=None): _logger.error("endpoint must be specified") self.__get_sampling_rules_endpoint = endpoint + "/GetSamplingRules" - def get_sampling_rules(self): + def get_sampling_rules(self) -> [_SamplingRule]: sampling_rules = [] headers = {"content-type": "application/json"} @@ -38,7 +38,10 @@ def get_sampling_rules(self): sampling_rules_records = sampling_rules_response["SamplingRuleRecords"] for record in sampling_rules_records: - sampling_rules.append(_SamplingRule(**record["SamplingRule"])) + if "SamplingRule" not in record: + _logger.error("SamplingRule is missing in SamplingRuleRecord") + else: + sampling_rules.append(_SamplingRule(**record["SamplingRule"])) except requests.exceptions.RequestException as req_err: _logger.error("Request error occurred: %s", req_err) diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_sampling_rule.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_sampling_rule.py index afc0b1368..40b463934 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_sampling_rule.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_sampling_rule.py @@ -7,7 +7,7 @@ class _SamplingRule: def __init__( self, - Attributes=None, + Attributes: dict = None, FixedRate=None, HTTPMethod=None, Host=None, @@ -21,16 +21,16 @@ def __init__( URLPath=None, Version=None, ): - self.Attributes = Attributes - self.FixedRate = FixedRate - self.HTTPMethod = HTTPMethod - self.Host = Host - self.Priority = Priority - self.ReservoirSize = ReservoirSize - self.ResourceARN = ResourceARN - self.RuleARN = RuleARN - self.RuleName = RuleName - self.ServiceName = ServiceName - self.ServiceType = ServiceType - self.URLPath = URLPath - self.Version = Version + self.Attributes = Attributes if Attributes is not None else {} + self.FixedRate = FixedRate if FixedRate is not None else "" + self.HTTPMethod = HTTPMethod if HTTPMethod is not None else "" + self.Host = Host if Host is not None else "" + self.Priority = Priority if Priority is not None else "" + self.ReservoirSize = ReservoirSize if ReservoirSize is not None else "" + self.ResourceARN = ResourceARN if ResourceARN is not None else "" + self.RuleARN = RuleARN if RuleARN is not None else "" + self.RuleName = RuleName if RuleName is not None else "" + self.ServiceName = ServiceName if ServiceName is not None else "" + self.ServiceType = ServiceType if ServiceType is not None else "" + self.URLPath = URLPath if URLPath is not None else "" + self.Version = Version if Version is not None else "" diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/aws_xray_remote_sampler.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/aws_xray_remote_sampler.py index 6d7714923..724da64f9 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/aws_xray_remote_sampler.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/aws_xray_remote_sampler.py @@ -59,8 +59,9 @@ def __init__( self.__resource = Resource.get_empty() # Schedule the next rule poll now + # Python Timers only run once, so they need to be recreated for every poll self._timer = Timer(0, self.__start_sampling_rule_poller) - self._timer.daemon = True + self._timer.daemon = True # Ensures that when the main thread exits, the Timer threads are killed self._timer.start() # pylint: disable=no-self-use diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_aws_xray_sampling_client.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_aws_xray_sampling_client.py index 98c9c96ef..bff47599a 100644 --- a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_aws_xray_sampling_client.py +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_aws_xray_sampling_client.py @@ -24,18 +24,82 @@ def test_get_no_sampling_rules(self, mock_post=None): self.assertTrue(len(sampling_rules) == 0) @patch("requests.post") - def test_get_invalid_response(self, mock_post=None): + def test_get_invalid_responses(self, mock_post=None): mock_post.return_value.configure_mock(**{"json.return_value": {}}) client = _AwsXRaySamplingClient("http://127.0.0.1:2000") with self.assertLogs(_logger, level="ERROR"): sampling_rules = client.get_sampling_rules() - self.assertTrue(len(sampling_rules) == 0) + self.assertTrue(len(sampling_rules) == 0) + + @patch("requests.post") + def test_get_sampling_rule_missing_in_records(self, mock_post=None): + mock_post.return_value.configure_mock(**{"json.return_value": {"SamplingRuleRecords": [{}]}}) + client = _AwsXRaySamplingClient("http://127.0.0.1:2000") + with self.assertLogs(_logger, level="ERROR"): + sampling_rules = client.get_sampling_rules() + self.assertTrue(len(sampling_rules) == 0) @patch("requests.post") - def test_get_two_sampling_rules(self, mock_post=None): + def test_default_values_used_when_missing_properties_in_sampling_rule(self, mock_post=None): + mock_post.return_value.configure_mock(**{"json.return_value": {"SamplingRuleRecords": [{"SamplingRule": {}}]}}) + client = _AwsXRaySamplingClient("http://127.0.0.1:2000") + sampling_rules = client.get_sampling_rules() + self.assertTrue(len(sampling_rules) == 1) + + sampling_rule = sampling_rules[0] + self.assertEqual(sampling_rule.Attributes, {}) + self.assertEqual(sampling_rule.FixedRate, "") + self.assertEqual(sampling_rule.HTTPMethod, "") + self.assertEqual(sampling_rule.Host, "") + self.assertEqual(sampling_rule.Priority, "") + self.assertEqual(sampling_rule.ReservoirSize, "") + self.assertEqual(sampling_rule.ResourceARN, "") + self.assertEqual(sampling_rule.RuleARN, "") + self.assertEqual(sampling_rule.RuleName, "") + self.assertEqual(sampling_rule.ServiceName, "") + self.assertEqual(sampling_rule.ServiceType, "") + self.assertEqual(sampling_rule.URLPath, "") + self.assertEqual(sampling_rule.Version, "") + + @patch("requests.post") + def test_get_three_sampling_rules(self, mock_post=None): + sampling_records = [] with open(f"{DATA_DIR}/get-sampling-rules-response-sample.json", encoding="UTF-8") as file: - mock_post.return_value.configure_mock(**{"json.return_value": json.load(file)}) + sample_response = json.load(file) + sampling_records = sample_response["SamplingRuleRecords"] + mock_post.return_value.configure_mock(**{"json.return_value": sample_response}) file.close() client = _AwsXRaySamplingClient("http://127.0.0.1:2000") sampling_rules = client.get_sampling_rules() - self.assertTrue(len(sampling_rules) == 3) + self.assertEqual(len(sampling_rules), 3) + self.assertEqual(len(sampling_rules), len(sampling_records)) + self.validate_match_sampling_rules_properties_with_records(sampling_rules, sampling_records) + + def validate_match_sampling_rules_properties_with_records(self, sampling_rules, sampling_records): + for _, (sampling_rule, sampling_record) in enumerate(zip(sampling_rules, sampling_records)): + self.assertIsNotNone(sampling_rule.Attributes) + self.assertEqual(sampling_rule.Attributes, sampling_record["SamplingRule"]["Attributes"]) + self.assertIsNotNone(sampling_rule.FixedRate) + self.assertEqual(sampling_rule.FixedRate, sampling_record["SamplingRule"]["FixedRate"]) + self.assertIsNotNone(sampling_rule.HTTPMethod) + self.assertEqual(sampling_rule.HTTPMethod, sampling_record["SamplingRule"]["HTTPMethod"]) + self.assertIsNotNone(sampling_rule.Host) + self.assertEqual(sampling_rule.Host, sampling_record["SamplingRule"]["Host"]) + self.assertIsNotNone(sampling_rule.Priority) + self.assertEqual(sampling_rule.Priority, sampling_record["SamplingRule"]["Priority"]) + self.assertIsNotNone(sampling_rule.ReservoirSize) + self.assertEqual(sampling_rule.ReservoirSize, sampling_record["SamplingRule"]["ReservoirSize"]) + self.assertIsNotNone(sampling_rule.ResourceARN) + self.assertEqual(sampling_rule.ResourceARN, sampling_record["SamplingRule"]["ResourceARN"]) + self.assertIsNotNone(sampling_rule.RuleARN) + self.assertEqual(sampling_rule.RuleARN, sampling_record["SamplingRule"]["RuleARN"]) + self.assertIsNotNone(sampling_rule.RuleName) + self.assertEqual(sampling_rule.RuleName, sampling_record["SamplingRule"]["RuleName"]) + self.assertIsNotNone(sampling_rule.ServiceName) + self.assertEqual(sampling_rule.ServiceName, sampling_record["SamplingRule"]["ServiceName"]) + self.assertIsNotNone(sampling_rule.ServiceType) + self.assertEqual(sampling_rule.ServiceType, sampling_record["SamplingRule"]["ServiceType"]) + self.assertIsNotNone(sampling_rule.URLPath) + self.assertEqual(sampling_rule.URLPath, sampling_record["SamplingRule"]["URLPath"]) + self.assertIsNotNone(sampling_rule.Version) + self.assertEqual(sampling_rule.Version, sampling_record["SamplingRule"]["Version"])