diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_aws_xray_sampling_client.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_aws_xray_sampling_client.py new file mode 100644 index 000000000..09c56e386 --- /dev/null +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_aws_xray_sampling_client.py @@ -0,0 +1,51 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +import json +from logging import getLogger + +import requests + +from amazon.opentelemetry.distro.sampler._sampling_rule import _SamplingRule + +_logger = getLogger(__name__) + + +class _AwsXRaySamplingClient: + def __init__(self, endpoint=None, log_level=None): + # Override default log level + if log_level is not None: + _logger.setLevel(log_level) + + if endpoint is None: + _logger.error("endpoint must be specified") + self.__get_sampling_rules_endpoint = endpoint + "/GetSamplingRules" + + def get_sampling_rules(self) -> [_SamplingRule]: + sampling_rules = [] + headers = {"content-type": "application/json"} + + try: + xray_response = requests.post(url=self.__get_sampling_rules_endpoint, headers=headers, timeout=20) + if xray_response is None: + _logger.error("GetSamplingRules response is None") + return [] + sampling_rules_response = xray_response.json() + if "SamplingRuleRecords" not in sampling_rules_response: + _logger.error( + "SamplingRuleRecords is missing in getSamplingRules response: %s", sampling_rules_response + ) + return [] + + sampling_rules_records = sampling_rules_response["SamplingRuleRecords"] + for record in sampling_rules_records: + if "SamplingRule" not in record: + _logger.error("SamplingRule is missing in SamplingRuleRecord") + else: + sampling_rules.append(_SamplingRule(**record["SamplingRule"])) + + except requests.exceptions.RequestException as req_err: + _logger.error("Request error occurred: %s", req_err) + except json.JSONDecodeError as json_err: + _logger.error("Error in decoding JSON response: %s", json_err) + + return sampling_rules diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_sampling_rule.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_sampling_rule.py new file mode 100644 index 000000000..fe8f4f60b --- /dev/null +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_sampling_rule.py @@ -0,0 +1,37 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + + +# Disable snake_case naming style so this class can match the sampling rules response from X-Ray +# pylint: disable=invalid-name +class _SamplingRule: + def __init__( + self, + Attributes: dict = None, + FixedRate=None, + HTTPMethod=None, + Host=None, + Priority=None, + ReservoirSize=None, + ResourceARN=None, + RuleARN=None, + RuleName=None, + ServiceName=None, + ServiceType=None, + URLPath=None, + Version=None, + ): + self.Attributes = Attributes if Attributes is not None else {} + self.FixedRate = FixedRate if FixedRate is not None else 0.0 + self.HTTPMethod = HTTPMethod if HTTPMethod is not None else "" + self.Host = Host if Host is not None else "" + # Default to value with lower priority than default rule + self.Priority = Priority if Priority is not None else 10001 + self.ReservoirSize = ReservoirSize if ReservoirSize is not None else 0 + self.ResourceARN = ResourceARN if ResourceARN is not None else "" + self.RuleARN = RuleARN if RuleARN is not None else "" + self.RuleName = RuleName if RuleName is not None else "" + self.ServiceName = ServiceName if ServiceName is not None else "" + self.ServiceType = ServiceType if ServiceType is not None else "" + self.URLPath = URLPath if URLPath is not None else "" + self.Version = Version if Version is not None else 0 diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/aws_xray_remote_sampler.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/aws_xray_remote_sampler.py new file mode 100644 index 000000000..724da64f9 --- /dev/null +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/aws_xray_remote_sampler.py @@ -0,0 +1,101 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +import json +from logging import getLogger +from threading import Timer +from typing import Optional, Sequence + +from typing_extensions import override + +from amazon.opentelemetry.distro.sampler._aws_xray_sampling_client import _AwsXRaySamplingClient +from opentelemetry.context import Context +from opentelemetry.sdk.resources import Resource +from opentelemetry.sdk.trace.sampling import ALWAYS_OFF, Sampler, SamplingResult +from opentelemetry.trace import Link, SpanKind +from opentelemetry.trace.span import TraceState +from opentelemetry.util.types import Attributes + +_logger = getLogger(__name__) + +DEFAULT_RULES_POLLING_INTERVAL_SECONDS = 300 +DEFAULT_TARGET_POLLING_INTERVAL_SECONDS = 10 +DEFAULT_SAMPLING_PROXY_ENDPOINT = "http://127.0.0.1:2000" + + +class AwsXRayRemoteSampler(Sampler): + """ + Remote Sampler for OpenTelemetry that gets sampling configurations from AWS X-Ray + + Args: + resource: OpenTelemetry Resource (Required) + endpoint: proxy endpoint for AWS X-Ray Sampling (Optional) + polling_interval: Polling interval for getSamplingRules call (Optional) + log_level: custom log level configuration for remote sampler (Optional) + """ + + __resource: Resource + __polling_interval: int + __xray_client: _AwsXRaySamplingClient + + def __init__( + self, + resource: Resource, + endpoint=DEFAULT_SAMPLING_PROXY_ENDPOINT, + polling_interval=DEFAULT_RULES_POLLING_INTERVAL_SECONDS, + log_level=None, + ): + # Override default log level + if log_level is not None: + _logger.setLevel(log_level) + + self.__xray_client = _AwsXRaySamplingClient(endpoint, log_level=log_level) + self.__polling_interval = polling_interval + + # pylint: disable=unused-private-member + if resource is not None: + self.__resource = resource + else: + _logger.warning("OTel Resource provided is `None`. Defaulting to empty resource") + self.__resource = Resource.get_empty() + + # Schedule the next rule poll now + # Python Timers only run once, so they need to be recreated for every poll + self._timer = Timer(0, self.__start_sampling_rule_poller) + self._timer.daemon = True # Ensures that when the main thread exits, the Timer threads are killed + self._timer.start() + + # pylint: disable=no-self-use + @override + def should_sample( + self, + parent_context: Optional["Context"], + trace_id: int, + name: str, + kind: SpanKind = None, + attributes: Attributes = None, + links: Sequence["Link"] = None, + trace_state: "TraceState" = None, + ) -> SamplingResult: + # TODO: add sampling functionality + return ALWAYS_OFF.should_sample( + parent_context, trace_id, name, kind=kind, attributes=attributes, links=links, trace_state=trace_state + ) + + # pylint: disable=no-self-use + @override + def get_description(self) -> str: + description = "AwsXRayRemoteSampler{remote sampling with AWS X-Ray}" + return description + + def __get_and_update_sampling_rules(self): + sampling_rules = self.__xray_client.get_sampling_rules() + + # TODO: Update sampling rules cache + _logger.info("Got Sampling Rules: %s", {json.dumps([ob.__dict__ for ob in sampling_rules])}) + + def __start_sampling_rule_poller(self): + self.__get_and_update_sampling_rules() + # Schedule the next sampling rule poll + self._timer = Timer(self.__polling_interval, self.__start_sampling_rule_poller) + self._timer.daemon = True + self._timer.start() diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/data/get-sampling-rules-response-sample.json b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/data/get-sampling-rules-response-sample.json new file mode 100644 index 000000000..a0d3c5ba2 --- /dev/null +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/data/get-sampling-rules-response-sample.json @@ -0,0 +1,65 @@ +{ + "NextToken": null, + "SamplingRuleRecords": [ + { + "CreatedAt": 1.67799933E9, + "ModifiedAt": 1.67799933E9, + "SamplingRule": { + "Attributes": { + "foo": "bar", + "doo": "baz" + }, + "FixedRate": 0.05, + "HTTPMethod": "*", + "Host": "*", + "Priority": 1000, + "ReservoirSize": 10, + "ResourceARN": "*", + "RuleARN": "arn:aws:xray:us-west-2:123456789000:sampling-rule/Rule1", + "RuleName": "Rule1", + "ServiceName": "*", + "ServiceType": "AWS::Foo::Bar", + "URLPath": "*", + "Version": 1 + } + }, + { + "CreatedAt": 0.0, + "ModifiedAt": 1.611564245E9, + "SamplingRule": { + "Attributes": {}, + "FixedRate": 0.05, + "HTTPMethod": "*", + "Host": "*", + "Priority": 10000, + "ReservoirSize": 1, + "ResourceARN": "*", + "RuleARN": "arn:aws:xray:us-west-2:123456789000:sampling-rule/Default", + "RuleName": "Default", + "ServiceName": "*", + "ServiceType": "*", + "URLPath": "*", + "Version": 1 + } + }, + { + "CreatedAt": 1.676038494E9, + "ModifiedAt": 1.676038494E9, + "SamplingRule": { + "Attributes": {}, + "FixedRate": 0.2, + "HTTPMethod": "GET", + "Host": "*", + "Priority": 1, + "ReservoirSize": 10, + "ResourceARN": "*", + "RuleARN": "arn:aws:xray:us-west-2:123456789000:sampling-rule/Rule2", + "RuleName": "Rule2", + "ServiceName": "FooBar", + "ServiceType": "*", + "URLPath": "/foo/bar", + "Version": 1 + } + } + ] +} \ No newline at end of file diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_aws_xray_remote_sampler.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_aws_xray_remote_sampler.py new file mode 100644 index 000000000..17d0d5f97 --- /dev/null +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_aws_xray_remote_sampler.py @@ -0,0 +1,45 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +from logging import DEBUG +from unittest import TestCase + +from amazon.opentelemetry.distro.sampler.aws_xray_remote_sampler import AwsXRayRemoteSampler +from opentelemetry.sdk.resources import Resource + + +class TestAwsXRayRemoteSampler(TestCase): + def test_create_remote_sampler_with_empty_resource(self): + rs = AwsXRayRemoteSampler(resource=Resource.get_empty()) + self.assertIsNotNone(rs._timer) + self.assertEqual(rs._AwsXRayRemoteSampler__polling_interval, 300) + self.assertIsNotNone(rs._AwsXRayRemoteSampler__xray_client) + self.assertIsNotNone(rs._AwsXRayRemoteSampler__resource) + + def test_create_remote_sampler_with_populated_resource(self): + rs = AwsXRayRemoteSampler( + resource=Resource.create({"service.name": "test-service-name", "cloud.platform": "test-cloud-platform"}) + ) + self.assertIsNotNone(rs._timer) + self.assertEqual(rs._AwsXRayRemoteSampler__polling_interval, 300) + self.assertIsNotNone(rs._AwsXRayRemoteSampler__xray_client) + self.assertIsNotNone(rs._AwsXRayRemoteSampler__resource) + self.assertEqual(rs._AwsXRayRemoteSampler__resource.attributes["service.name"], "test-service-name") + self.assertEqual(rs._AwsXRayRemoteSampler__resource.attributes["cloud.platform"], "test-cloud-platform") + + def test_create_remote_sampler_with_all_fields_populated(self): + rs = AwsXRayRemoteSampler( + resource=Resource.create({"service.name": "test-service-name", "cloud.platform": "test-cloud-platform"}), + endpoint="http://abc.com", + polling_interval=120, + log_level=DEBUG, + ) + self.assertIsNotNone(rs._timer) + self.assertEqual(rs._AwsXRayRemoteSampler__polling_interval, 120) + self.assertIsNotNone(rs._AwsXRayRemoteSampler__xray_client) + self.assertIsNotNone(rs._AwsXRayRemoteSampler__resource) + self.assertEqual( + rs._AwsXRayRemoteSampler__xray_client._AwsXRaySamplingClient__get_sampling_rules_endpoint, + "http://abc.com/GetSamplingRules", + ) + self.assertEqual(rs._AwsXRayRemoteSampler__resource.attributes["service.name"], "test-service-name") + self.assertEqual(rs._AwsXRayRemoteSampler__resource.attributes["cloud.platform"], "test-cloud-platform") diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_aws_xray_sampling_client.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_aws_xray_sampling_client.py new file mode 100644 index 000000000..637e4545a --- /dev/null +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_aws_xray_sampling_client.py @@ -0,0 +1,105 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +import json +import os +from logging import getLogger +from unittest import TestCase +from unittest.mock import patch + +from amazon.opentelemetry.distro.sampler._aws_xray_sampling_client import _AwsXRaySamplingClient + +SAMPLING_CLIENT_LOGGER_NAME = "amazon.opentelemetry.distro.sampler._aws_xray_sampling_client" +_logger = getLogger(SAMPLING_CLIENT_LOGGER_NAME) + +TEST_DIR = os.path.dirname(os.path.realpath(__file__)) +DATA_DIR = os.path.join(TEST_DIR, "data") + + +class TestAwsXRaySamplingClient(TestCase): + @patch("requests.post") + def test_get_no_sampling_rules(self, mock_post=None): + mock_post.return_value.configure_mock(**{"json.return_value": {"SamplingRuleRecords": []}}) + client = _AwsXRaySamplingClient("http://127.0.0.1:2000") + sampling_rules = client.get_sampling_rules() + self.assertTrue(len(sampling_rules) == 0) + + @patch("requests.post") + def test_get_invalid_responses(self, mock_post=None): + mock_post.return_value.configure_mock(**{"json.return_value": {}}) + client = _AwsXRaySamplingClient("http://127.0.0.1:2000") + with self.assertLogs(_logger, level="ERROR"): + sampling_rules = client.get_sampling_rules() + self.assertTrue(len(sampling_rules) == 0) + + @patch("requests.post") + def test_get_sampling_rule_missing_in_records(self, mock_post=None): + mock_post.return_value.configure_mock(**{"json.return_value": {"SamplingRuleRecords": [{}]}}) + client = _AwsXRaySamplingClient("http://127.0.0.1:2000") + with self.assertLogs(_logger, level="ERROR"): + sampling_rules = client.get_sampling_rules() + self.assertTrue(len(sampling_rules) == 0) + + @patch("requests.post") + def test_default_values_used_when_missing_properties_in_sampling_rule(self, mock_post=None): + mock_post.return_value.configure_mock(**{"json.return_value": {"SamplingRuleRecords": [{"SamplingRule": {}}]}}) + client = _AwsXRaySamplingClient("http://127.0.0.1:2000") + sampling_rules = client.get_sampling_rules() + self.assertTrue(len(sampling_rules) == 1) + + sampling_rule = sampling_rules[0] + self.assertEqual(sampling_rule.Attributes, {}) + self.assertEqual(sampling_rule.FixedRate, 0.0) + self.assertEqual(sampling_rule.HTTPMethod, "") + self.assertEqual(sampling_rule.Host, "") + self.assertEqual(sampling_rule.Priority, 10001) + self.assertEqual(sampling_rule.ReservoirSize, 0) + self.assertEqual(sampling_rule.ResourceARN, "") + self.assertEqual(sampling_rule.RuleARN, "") + self.assertEqual(sampling_rule.RuleName, "") + self.assertEqual(sampling_rule.ServiceName, "") + self.assertEqual(sampling_rule.ServiceType, "") + self.assertEqual(sampling_rule.URLPath, "") + self.assertEqual(sampling_rule.Version, 0) + + @patch("requests.post") + def test_get_three_sampling_rules(self, mock_post=None): + sampling_records = [] + with open(f"{DATA_DIR}/get-sampling-rules-response-sample.json", encoding="UTF-8") as file: + sample_response = json.load(file) + sampling_records = sample_response["SamplingRuleRecords"] + mock_post.return_value.configure_mock(**{"json.return_value": sample_response}) + file.close() + client = _AwsXRaySamplingClient("http://127.0.0.1:2000") + sampling_rules = client.get_sampling_rules() + self.assertEqual(len(sampling_rules), 3) + self.assertEqual(len(sampling_rules), len(sampling_records)) + self.validate_match_sampling_rules_properties_with_records(sampling_rules, sampling_records) + + def validate_match_sampling_rules_properties_with_records(self, sampling_rules, sampling_records): + for _, (sampling_rule, sampling_record) in enumerate(zip(sampling_rules, sampling_records)): + self.assertIsNotNone(sampling_rule.Attributes) + self.assertEqual(sampling_rule.Attributes, sampling_record["SamplingRule"]["Attributes"]) + self.assertIsNotNone(sampling_rule.FixedRate) + self.assertEqual(sampling_rule.FixedRate, sampling_record["SamplingRule"]["FixedRate"]) + self.assertIsNotNone(sampling_rule.HTTPMethod) + self.assertEqual(sampling_rule.HTTPMethod, sampling_record["SamplingRule"]["HTTPMethod"]) + self.assertIsNotNone(sampling_rule.Host) + self.assertEqual(sampling_rule.Host, sampling_record["SamplingRule"]["Host"]) + self.assertIsNotNone(sampling_rule.Priority) + self.assertEqual(sampling_rule.Priority, sampling_record["SamplingRule"]["Priority"]) + self.assertIsNotNone(sampling_rule.ReservoirSize) + self.assertEqual(sampling_rule.ReservoirSize, sampling_record["SamplingRule"]["ReservoirSize"]) + self.assertIsNotNone(sampling_rule.ResourceARN) + self.assertEqual(sampling_rule.ResourceARN, sampling_record["SamplingRule"]["ResourceARN"]) + self.assertIsNotNone(sampling_rule.RuleARN) + self.assertEqual(sampling_rule.RuleARN, sampling_record["SamplingRule"]["RuleARN"]) + self.assertIsNotNone(sampling_rule.RuleName) + self.assertEqual(sampling_rule.RuleName, sampling_record["SamplingRule"]["RuleName"]) + self.assertIsNotNone(sampling_rule.ServiceName) + self.assertEqual(sampling_rule.ServiceName, sampling_record["SamplingRule"]["ServiceName"]) + self.assertIsNotNone(sampling_rule.ServiceType) + self.assertEqual(sampling_rule.ServiceType, sampling_record["SamplingRule"]["ServiceType"]) + self.assertIsNotNone(sampling_rule.URLPath) + self.assertEqual(sampling_rule.URLPath, sampling_record["SamplingRule"]["URLPath"]) + self.assertIsNotNone(sampling_rule.Version) + self.assertEqual(sampling_rule.Version, sampling_record["SamplingRule"]["Version"])