From 01b9b1614741512db6c2bbb25934fe81823724ce Mon Sep 17 00:00:00 2001 From: jjllee Date: Tue, 30 Jan 2024 12:41:27 -0800 Subject: [PATCH] remote sampling - rules caching and rules matching --- .../sampler/_aws_xray_sampling_client.py | 2 +- .../distro/sampler/_fallback_sampler.py | 39 +++ .../opentelemetry/distro/sampler/_matcher.py | 67 ++++ .../opentelemetry/distro/sampler/_rule.py | 91 ++++++ .../distro/sampler/_rule_cache.py | 94 ++++++ .../distro/sampler/_sampling_rule.py | 26 ++ .../distro/sampler/aws_xray_remote_sampler.py | 45 ++- .../get-sampling-rules-response-sample-2.json | 48 +++ .../sampler/test_aws_xray_sampling_client.py | 2 +- .../distro/sampler/test_matcher.py | 62 ++++ .../opentelemetry/distro/sampler/test_rule.py | 303 ++++++++++++++++++ .../distro/sampler/test_rule_cache.py | 87 +++++ .../distro/sampler/test_sampling_rule.py | 87 +++++ 13 files changed, 937 insertions(+), 16 deletions(-) create mode 100644 aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_fallback_sampler.py create mode 100644 aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_matcher.py create mode 100644 aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_rule.py create mode 100644 aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_rule_cache.py create mode 100644 aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/data/get-sampling-rules-response-sample-2.json create mode 100644 aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_matcher.py create mode 100644 aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_rule.py create mode 100644 aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_rule_cache.py create mode 100644 aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_sampling_rule.py diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_aws_xray_sampling_client.py 
b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_aws_xray_sampling_client.py index 09c56e386..770d175aa 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_aws_xray_sampling_client.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_aws_xray_sampling_client.py @@ -11,7 +11,7 @@ class _AwsXRaySamplingClient: - def __init__(self, endpoint=None, log_level=None): + def __init__(self, endpoint: str = None, log_level: str = None): # Override default log level if log_level is not None: _logger.setLevel(log_level) diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_fallback_sampler.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_fallback_sampler.py new file mode 100644 index 000000000..986ee0f16 --- /dev/null +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_fallback_sampler.py @@ -0,0 +1,39 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0
+from typing import Optional, Sequence
+
+from opentelemetry.context import Context
+from opentelemetry.sdk.trace.sampling import ALWAYS_ON, Sampler, SamplingResult, TraceIdRatioBased
+from opentelemetry.trace import Link, SpanKind
+from opentelemetry.trace.span import TraceState
+from opentelemetry.util.types import Attributes
+
+
+class _FallbackSampler(Sampler):
+    """Sampler used when no remote sampling rule can be applied.
+
+    Reservoir and fixed-rate sampling are not implemented yet (see the TODOs), so every
+    decision is currently delegated to ``ALWAYS_ON``.
+    """
+
+    def __init__(self):
+        # Created ahead of the fixed-rate implementation; nothing reads it yet.
+        # TODO: Add Reservoir sampler
+        # pylint: disable=unused-private-member
+        self.__fixed_rate_sampler = TraceIdRatioBased(0.05)
+
+    # pylint: disable=no-self-use
+    def should_sample(
+        self,
+        parent_context: Optional[Context],
+        trace_id: int,
+        name: str,
+        kind: SpanKind = None,
+        attributes: Attributes = None,
+        links: Sequence[Link] = None,
+        trace_state: TraceState = None,
+    ) -> SamplingResult:
+        """Return the ``ALWAYS_ON`` decision for the span described by the arguments."""
+        # TODO: add reservoir + fixed rate sampling
+        span_details = {"kind": kind, "attributes": attributes, "links": links, "trace_state": trace_state}
+        return ALWAYS_ON.should_sample(parent_context, trace_id, name, **span_details)
+
+    # pylint: disable=no-self-use
+    def get_description(self) -> str:
+        """Return the human-readable description of this sampler."""
+        return "FallbackSampler{fallback sampling with sampling config of 1 req/sec and 5% of additional requests}"
diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_matcher.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_matcher.py
new file mode 100644
index 000000000..88539a284
--- /dev/null
+++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_matcher.py
@@ -0,0 +1,67 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+import re
+
+from opentelemetry.semconv.resource import CloudPlatformValues
+from opentelemetry.util.types import Attributes
+
+# Maps the OTel ``cloud.platform`` resource value to the service type string used by sampling rules.
+cloud_platform_mapping = {
+    CloudPlatformValues.AWS_LAMBDA.value: "AWS::Lambda::Function",
+    CloudPlatformValues.AWS_ELASTIC_BEANSTALK.value: "AWS::ElasticBeanstalk::Environment",
+    CloudPlatformValues.AWS_EC2.value: "AWS::EC2::Instance",
+    CloudPlatformValues.AWS_ECS.value: "AWS::ECS::Container",
+    CloudPlatformValues.AWS_EKS.value: "AWS::EKS::Container",
+}
+
+
+class _Matcher:
+    """Static helpers that match span and resource values against sampling-rule patterns."""
+
+    @staticmethod
+    def wild_card_match(text: str = None, pattern: str = None) -> bool:
+        """Return True when ``text`` matches the wildcard ``pattern``.
+
+        ``*`` stands for any run of characters (including none) and ``?`` for exactly one
+        character; every other character must match literally. The pattern ``"*"``
+        matches everything, including a missing (``None``) text. A ``None`` pattern
+        never matches.
+        """
+        if pattern == "*":
+            return True
+        if text is None or pattern is None:
+            return False
+        # Attribute values may be numbers, booleans or sequences. Such a value can never
+        # equal a literal pattern, and handing it to ``re`` raises TypeError, so report
+        # a mismatch instead of crashing the sampling decision.
+        if not isinstance(text, str):
+            return False
+        if "*" in pattern or "?" in pattern:
+            # DOTALL lets the wildcards span line breaks contained in the text.
+            return re.fullmatch(_Matcher.to_regex_pattern(pattern), text, flags=re.DOTALL) is not None
+        return pattern == text
+
+    @staticmethod
+    def to_regex_pattern(rule_pattern: str) -> str:
+        """Translate a wildcard pattern into an equivalent regular expression.
+
+        ``*`` becomes ``.*``, ``?`` becomes ``.`` and every other character is escaped so
+        that it is matched literally.
+        """
+        wildcards = {"*": ".*", "?": "."}
+        return "".join(wildcards.get(char) or re.escape(char) for char in rule_pattern)
+
+    @staticmethod
+    def attribute_match(attributes: Attributes = None, rule_attributes: dict = None) -> bool:
+        """Return True when every rule attribute is satisfied by the span attributes.
+
+        A rule without attributes matches anything. Otherwise each rule key must be
+        present in ``attributes`` with a value accepted by ``wild_card_match``; span
+        attributes that the rule does not mention are ignored.
+        """
+        if not rule_attributes:
+            return True
+        if not attributes or len(rule_attributes) > len(attributes):
+            return False
+        return all(
+            key in attributes and _Matcher.wild_card_match(attributes[key], pattern)
+            for key, pattern in rule_attributes.items()
+        )
diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_rule.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_rule.py
new file mode 100644
index 000000000..fdfff39e1
--- /dev/null
+++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_rule.py
@@ -0,0 +1,91 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+from typing import Optional, Sequence
+
+from amazon.opentelemetry.distro.sampler._matcher import _Matcher, cloud_platform_mapping
+from amazon.opentelemetry.distro.sampler._sampling_rule import _SamplingRule
+from opentelemetry.context import Context
+from opentelemetry.sdk.resources import Resource
+from opentelemetry.sdk.trace.sampling import ALWAYS_ON, SamplingResult
+from opentelemetry.semconv.resource import ResourceAttributes
+from opentelemetry.semconv.trace import SpanAttributes
+from opentelemetry.trace import Link, SpanKind
+from opentelemetry.trace.span import TraceState
+from opentelemetry.util.types import Attributes
+
+
+class _Rule:
+    """A sampling rule paired with the sampler that takes decisions on its behalf."""
+
+    def __init__(self, sampling_rule: _SamplingRule):
+        self.sampling_rule = sampling_rule
+        # TODO add self.next_target_fetch_time from maybe time.process_time() or cache's datetime object
+        # TODO add statistics
+        # TODO change to rate limiter given rate, add fixed rate sampler
+        self.reservoir_sampler = ALWAYS_ON
+        # self.fixed_rate_sampler = None
+        # TODO add clientId
+
+    def should_sample(
+        self,
+        parent_context: Optional[Context],
+        trace_id: int,
+        name: str,
+        kind: SpanKind = None,
+        attributes: Attributes = None,
+        links: Sequence[Link] = None,
+        trace_state: TraceState = None,
+    ) -> SamplingResult:
+        """Delegate the sampling decision to this rule's reservoir sampler."""
+        sampler = self.reservoir_sampler
+        return sampler.should_sample(
+            parent_context, trace_id, name, kind=kind, attributes=attributes, links=links, trace_state=trace_state
+        )
+
+    def matches(self, resource: Resource, attributes: Attributes) -> bool:
+        """Return True when the span attributes and resource satisfy every field of the rule.
+
+        The rule's attributes, URL path, HTTP method, host, service name and service type
+        are each compared with ``_Matcher``; all of them have to match.
+        """
+        span_attributes = attributes if attributes is not None else {}
+        # NOTE: The span attribute keys below are deprecated in favor of:
+        #   URL_PATH/URL_QUERY, HTTP_REQUEST_METHOD, URL_FULL, SERVER_ADDRESS/SERVER_PORT
+        # For now, the old attribute keys are kept for consistency with other centralized samplers
+        http_target = span_attributes.get(SpanAttributes.HTTP_TARGET, None)
+        http_method = span_attributes.get(SpanAttributes.HTTP_METHOD, None)
+        http_url = span_attributes.get(SpanAttributes.HTTP_URL, None)
+        http_host = span_attributes.get(SpanAttributes.HTTP_HOST, None)
+
+        # A missing resource leaves the service name unset; callers are expected to pass
+        # an empty resource instead of None.
+        service_name = None
+        if resource is not None:
+            service_name = resource.attributes.get(ResourceAttributes.SERVICE_NAME, "")
+
+        # Without an explicit target, derive it from the URL.
+        if http_target is None and http_url is not None:
+            # Per spec, http.url is always populated with scheme://host/target. A URL
+            # without a scheme is treated as bad instrumentation and ignored.
+            _, scheme_separator, host_and_target = http_url.partition("://")
+            if scheme_separator:
+                slash_index = host_and_target.find("/")
+                http_target = "/" if slash_index == -1 else host_and_target[slash_index:]
+
+        rule = self.sampling_rule
+        return (
+            _Matcher.attribute_match(attributes, rule.Attributes)
+            and _Matcher.wild_card_match(http_target, rule.URLPath)
+            and _Matcher.wild_card_match(http_method, rule.HTTPMethod)
+            and _Matcher.wild_card_match(http_host, rule.Host)
+            and _Matcher.wild_card_match(service_name, rule.ServiceName)
+            and _Matcher.wild_card_match(self.get_service_type(resource), rule.ServiceType)
+        )
+
+    # pylint: disable=no-self-use
+    def get_service_type(self, resource: Resource) -> str:
+        """Return the service type mapped from the resource's cloud platform, or "" when unknown."""
+        if resource is None:
+            return ""
+        platform = resource.attributes.get(ResourceAttributes.CLOUD_PLATFORM, None)
+        return "" if platform is None else cloud_platform_mapping.get(platform, "")
diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_rule_cache.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_rule_cache.py
new file mode 100644
index 000000000..f182d9c8f
--- /dev/null
+++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_rule_cache.py
@@ -0,0 +1,94 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+import copy
+import datetime
+from logging import getLogger
+from threading import Lock
+from typing import Optional, Sequence
+
+from amazon.opentelemetry.distro.sampler._fallback_sampler import _FallbackSampler
+from amazon.opentelemetry.distro.sampler._rule import _Rule
+from amazon.opentelemetry.distro.sampler._sampling_rule import _SamplingRule
+from opentelemetry.context import Context
+from opentelemetry.sdk.resources import Resource
+from opentelemetry.sdk.trace.sampling import SamplingResult
+from opentelemetry.trace import Link, SpanKind
+from opentelemetry.trace.span import TraceState
+from opentelemetry.util.types import Attributes
+
+_logger = getLogger(__name__)
+
+CACHE_TTL_SECONDS = 3600
+
+
+class _RuleCache:
+    """Cache of sampling rules, sorted by priority and then rule name, guarded by a lock."""
+
+    # Class-level default kept for backward compatibility; __init__ gives every instance its own list.
+    rules: [_Rule] = []
+
+    def __init__(self, resource: Resource, fallback_sampler: _FallbackSampler, date_time: datetime, lock: Lock):
+        """Create an empty cache.
+
+        Args:
+            resource: resource handed to every rule when checking for a match.
+            fallback_sampler: sampler used when no cached rule matches.
+            date_time: object exposing ``datetime.now()`` (the ``datetime`` module itself
+                in the visible callers); it is the clock for every timestamp of the cache.
+            lock: lock held while the rules are swapped and while expiry is checked.
+        """
+        self.__cache_lock = lock
+        self.__resource = resource
+        self._fallback_sampler = fallback_sampler
+        self._date_time = date_time
+        self._last_modified = self._date_time.datetime.now()
+        # Instance-level list so that caches never share state through the class attribute.
+        self.rules = []
+
+    def should_sample(
+        self,
+        parent_context: Optional[Context],
+        trace_id: int,
+        name: str,
+        kind: SpanKind = None,
+        attributes: Attributes = None,
+        links: Sequence[Link] = None,
+        trace_state: TraceState = None,
+    ) -> SamplingResult:
+        """Sample with the first cached rule that matches, or with the fallback sampler."""
+        matched_rule = next((rule for rule in self.rules if rule.matches(self.__resource, attributes)), None)
+        sampler = matched_rule if matched_rule is not None else self._fallback_sampler
+        return sampler.should_sample(
+            parent_context, trace_id, name, kind=kind, attributes=attributes, links=links, trace_state=trace_state
+        )
+
+    def update_sampling_rules(self, new_sampling_rules: [_SamplingRule]) -> None:
+        """Replace the cached rules with ``new_sampling_rules``.
+
+        Rules with an empty name or a version other than 1 are skipped. A cached
+        ``_Rule`` whose sampling rule is unchanged is reused instead of being rebuilt.
+        The list passed in is not modified.
+        """
+        temp_rules = []
+        # sorted() orders a copy, so the caller's list keeps its original order.
+        for sampling_rule in sorted(new_sampling_rules):
+            if sampling_rule.RuleName == "":
+                _logger.info("sampling rule without rule name is not supported")
+                continue
+            if sampling_rule.Version != 1:
+                _logger.info("sampling rule without Version 1 is not supported: RuleName: %s", sampling_rule.RuleName)
+                continue
+            temp_rules.append(_Rule(copy.deepcopy(sampling_rule)))
+
+        # The context manager releases the lock even if the merge below raises.
+        with self.__cache_lock:
+            # map list of rules by each rule's sampling_rule name
+            rule_map = {rule.sampling_rule.RuleName: rule for rule in self.rules}
+
+            # If a sampling rule has not changed, keep its respective rule in the cache.
+            for index, new_rule in enumerate(temp_rules):
+                previous_rule = rule_map.get(new_rule.sampling_rule.RuleName)
+                if previous_rule is not None and new_rule.sampling_rule == previous_rule.sampling_rule:
+                    temp_rules[index] = previous_rule
+            self.rules = temp_rules
+            # Use the injected clock, exactly as __init__ does.
+            self._last_modified = self._date_time.datetime.now()
+
+    def expired(self) -> bool:
+        """Return True when the rules were last updated more than CACHE_TTL_SECONDS ago."""
+        with self.__cache_lock:
+            time_to_live = datetime.timedelta(seconds=CACHE_TTL_SECONDS)
+            return self._date_time.datetime.now() > self._last_modified + time_to_live
diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_sampling_rule.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_sampling_rule.py
index fe8f4f60b..1fabbdc90 100644
--- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_sampling_rule.py
+++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_sampling_rule.py
@@ -35,3 +35,29 @@ def __init__(
         self.ServiceType = ServiceType if ServiceType is not None else ""
         self.URLPath = URLPath if URLPath is not None else ""
         self.Version = Version if Version is not None else 0
+
+    def __lt__(self, other) -> bool:
+        if self.Priority == other.Priority:
+            # String order priority example:
+            # "A","Abc","a","ab","abc","abcdef"
+            return self.RuleName < other.RuleName
+        return self.Priority < other.Priority
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, _SamplingRule):
+            return False
+        return (
+            self.FixedRate == other.FixedRate
+            and self.HTTPMethod == other.HTTPMethod
+            and self.Host == other.Host
+            and self.Priority == other.Priority
+            and self.ReservoirSize == other.ReservoirSize
+            and self.ResourceARN == other.ResourceARN
+            and self.RuleARN == other.RuleARN
+            and self.RuleName == other.RuleName
+            and self.ServiceName == other.ServiceName
+            and self.ServiceType == other.ServiceType
+            and self.URLPath == other.URLPath
+            and self.Version == other.Version
+            and self.Attributes == other.Attributes
+        )
diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/aws_xray_remote_sampler.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/aws_xray_remote_sampler.py
index 724da64f9..bfc1a20c7 100644
--- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/aws_xray_remote_sampler.py
+++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/aws_xray_remote_sampler.py
@@ -1,16 +1,19 @@
 # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0 -import json +import datetime +import random from logging import getLogger -from threading import Timer +from threading import Lock, Timer from typing import Optional, Sequence from typing_extensions import override from amazon.opentelemetry.distro.sampler._aws_xray_sampling_client import _AwsXRaySamplingClient +from amazon.opentelemetry.distro.sampler._fallback_sampler import _FallbackSampler +from amazon.opentelemetry.distro.sampler._rule_cache import _RuleCache from opentelemetry.context import Context from opentelemetry.sdk.resources import Resource -from opentelemetry.sdk.trace.sampling import ALWAYS_OFF, Sampler, SamplingResult +from opentelemetry.sdk.trace.sampling import Sampler, SamplingResult from opentelemetry.trace import Link, SpanKind from opentelemetry.trace.span import TraceState from opentelemetry.util.types import Attributes @@ -48,8 +51,13 @@ def __init__( if log_level is not None: _logger.setLevel(log_level) + self.__date_time = datetime self.__xray_client = _AwsXRaySamplingClient(endpoint, log_level=log_level) + self.__rule_polling_jitter = random.uniform(0.0, 5.0) self.__polling_interval = polling_interval + self.__fallback_sampler = _FallbackSampler() + + # TODO add client id # pylint: disable=unused-private-member if resource is not None: @@ -58,6 +66,11 @@ def __init__( _logger.warning("OTel Resource provided is `None`. 
Defaulting to empty resource") self.__resource = Resource.get_empty() + self.__rule_cache_lock = Lock() + self.__rule_cache = _RuleCache( + self.__resource, self.__fallback_sampler, self.__date_time, self.__rule_cache_lock + ) + # Schedule the next rule poll now # Python Timers only run once, so they need to be recreated for every poll self._timer = Timer(0, self.__start_sampling_rule_poller) @@ -68,16 +81,22 @@ def __init__( @override def should_sample( self, - parent_context: Optional["Context"], + parent_context: Optional[Context], trace_id: int, name: str, kind: SpanKind = None, attributes: Attributes = None, - links: Sequence["Link"] = None, - trace_state: "TraceState" = None, + links: Sequence[Link] = None, + trace_state: TraceState = None, ) -> SamplingResult: - # TODO: add sampling functionality - return ALWAYS_OFF.should_sample( + + if self.__rule_cache.expired(): + _logger.info("Rule cache is expired so using fallback sampling strategy") + return self.__fallback_sampler.should_sample( + parent_context, trace_id, name, kind=kind, attributes=attributes, links=links, trace_state=trace_state + ) + + return self.__rule_cache.should_sample( parent_context, trace_id, name, kind=kind, attributes=attributes, links=links, trace_state=trace_state ) @@ -87,15 +106,13 @@ def get_description(self) -> str: description = "AwsXRayRemoteSampler{remote sampling with AWS X-Ray}" return description - def __get_and_update_sampling_rules(self): + def __get_and_update_sampling_rules(self) -> None: sampling_rules = self.__xray_client.get_sampling_rules() + self.__rule_cache.update_sampling_rules(sampling_rules) - # TODO: Update sampling rules cache - _logger.info("Got Sampling Rules: %s", {json.dumps([ob.__dict__ for ob in sampling_rules])}) - - def __start_sampling_rule_poller(self): + def __start_sampling_rule_poller(self) -> None: self.__get_and_update_sampling_rules() # Schedule the next sampling rule poll - self._timer = Timer(self.__polling_interval, 
self.__start_sampling_rule_poller) + self._timer = Timer(self.__polling_interval + self.__rule_polling_jitter, self.__start_sampling_rule_poller) self._timer.daemon = True self._timer.start() diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/data/get-sampling-rules-response-sample-2.json b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/data/get-sampling-rules-response-sample-2.json new file mode 100644 index 000000000..6bf24ebac --- /dev/null +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/data/get-sampling-rules-response-sample-2.json @@ -0,0 +1,48 @@ +{ + "NextToken": null, + "SamplingRuleRecords": [ + { + "CreatedAt": 1.676038494E9, + "ModifiedAt": 1.676038494E9, + "SamplingRule": { + "Attributes": { + "foo": "bar", + "abc": "1234" + }, + "FixedRate": 0.05, + "HTTPMethod": "*", + "Host": "*", + "Priority": 10000, + "ReservoirSize": 100, + "ResourceARN": "*", + "RuleARN": "arn:aws:xray:us-east-1:999999999999:sampling-rule/Default", + "RuleName": "Default", + "ServiceName": "*", + "ServiceType": "*", + "URLPath": "*", + "Version": 1 + } + }, + { + "CreatedAt": 1.67799933E9, + "ModifiedAt": 1.67799933E9, + "SamplingRule": { + "Attributes": { + "abc": "1234" + }, + "FixedRate": 0.11, + "HTTPMethod": "*", + "Host": "*", + "Priority": 20, + "ReservoirSize": 1, + "ResourceARN": "*", + "RuleARN": "arn:aws:xray:us-east-1:999999999999:sampling-rule/test", + "RuleName": "test", + "ServiceName": "*", + "ServiceType": "*", + "URLPath": "*", + "Version": 1 + } + } + ] +} \ No newline at end of file diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_aws_xray_sampling_client.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_aws_xray_sampling_client.py index 637e4545a..9affdc28d 100644 --- a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_aws_xray_sampling_client.py +++ 
b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_aws_xray_sampling_client.py
@@ -62,7 +62,7 @@ def test_default_values_used_when_missing_properties_in_sampling_rule(self, mock
         self.assertEqual(sampling_rule.Version, 0)
 
     @patch("requests.post")
-    def test_get_three_sampling_rules(self, mock_post=None):
+    def test_get_correct_number_of_sampling_rules(self, mock_post=None):
         sampling_records = []
         with open(f"{DATA_DIR}/get-sampling-rules-response-sample.json", encoding="UTF-8") as file:
             sample_response = json.load(file)
diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_matcher.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_matcher.py
new file mode 100644
index 000000000..c6946ef9f
--- /dev/null
+++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_matcher.py
@@ -0,0 +1,62 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+from unittest import TestCase
+
+from amazon.opentelemetry.distro.sampler._matcher import _Matcher
+from opentelemetry.util.types import Attributes
+
+
+class TestMatcher(TestCase):
+    """Unit tests for the wildcard and attribute matching helpers of ``_Matcher``."""
+
+    def test_wild_card_match(self):
+        matching_cases = [
+            (None, "*"),
+            ("", "*"),
+            ("HelloWorld", "*"),
+            ("HelloWorld", "HelloWorld"),
+            ("HelloWorld", "Hello*"),
+            ("HelloWorld", "*World"),
+            ("HelloWorld", "?ello*"),
+            ("HelloWorld", "Hell?W*d"),
+            ("Hello.World", "*.World"),
+            ("Bye.World", "*.World"),
+        ]
+        for text, pattern in matching_cases:
+            # subTest reports every failing pair instead of stopping at the first one.
+            with self.subTest(text=text, pattern=pattern):
+                self.assertTrue(_Matcher.wild_card_match(text=text, pattern=pattern))
+
+    def test_wild_card_not_match(self):
+        non_matching_cases = [(None, "Hello*"), ("HelloWorld", None)]
+        for text, pattern in non_matching_cases:
+            with self.subTest(text=text, pattern=pattern):
+                self.assertFalse(_Matcher.wild_card_match(text=text, pattern=pattern))
+
+    def test_attribute_matching(self):
+        attributes: Attributes = {
+            "dog": "bark",
+            "cat": "meow",
+            "cow": "mooo",
+        }
+        rule_attributes = {
+            "dog": "bar?",
+            "cow": "mooo",
+        }
+
+        self.assertTrue(_Matcher.attribute_match(attributes, rule_attributes))
+
+    def test_attribute_matching_without_rule_attributes(self):
+        attributes = {
+            "dog": "bark",
+            "cat": "meow",
+            "cow": "mooo",
+        }
+        rule_attributes = {}
+
+        self.assertTrue(_Matcher.attribute_match(attributes, rule_attributes))
+
+    def test_attribute_matching_without_span_attributes(self):
+        attributes = {}
+        rule_attributes = {
+            "dog": "bar?",
+            "cow": "mooo",
+        }
+
+        self.assertFalse(_Matcher.attribute_match(attributes, rule_attributes))
diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_rule.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_rule.py
new file mode 100644
index 000000000..9bc5dbed1
--- /dev/null
+++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_rule.py
@@ -0,0 +1,303 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+import json
+import os
+from unittest import TestCase
+
+from amazon.opentelemetry.distro.sampler._rule import _Rule
+from amazon.opentelemetry.distro.sampler._sampling_rule import _SamplingRule
+from opentelemetry.sdk.resources import Resource
+from opentelemetry.semconv.resource import ResourceAttributes
+from opentelemetry.semconv.trace import SpanAttributes
+from opentelemetry.util.types import Attributes
+
+TEST_DIR = os.path.dirname(os.path.realpath(__file__))
+DATA_DIR = os.path.join(TEST_DIR, "data")
+
+
+def _make_sampling_rule(**overrides) -> _SamplingRule:
+    """Build a version-1 rule that matches everything, with ``overrides`` replacing selected fields."""
+    fields = {
+        "Attributes": {},
+        "FixedRate": 0.11,
+        "HTTPMethod": "*",
+        "Host": "*",
+        "Priority": 20,
+        "ReservoirSize": 1,
+        # ResourceARN can only be "*"
+        # See: https://docs.aws.amazon.com/xray/latest/devguide/xray-console-sampling.html#xray-console-sampling-options  # noqa: E501
+        "ResourceARN": "*",
+        "RuleARN": "arn:aws:xray:us-east-1:999999999999:sampling-rule/test",
+        "RuleName": "test",
+        "ServiceName": "*",
+        "ServiceType": "*",
+        "URLPath": "*",
+        "Version": 1,
+    }
+    fields.update(overrides)
+    return _SamplingRule(**fields)
+
+
+def _make_resource(cloud_platform: str) -> Resource:
+    """Build a resource for service ``myServiceName`` running on ``cloud_platform``."""
+    return Resource.create(
+        attributes={
+            ResourceAttributes.SERVICE_NAME: "myServiceName",
+            ResourceAttributes.CLOUD_PLATFORM: cloud_platform,
+        }
+    )
+
+
+class TestRule(TestCase):
+    """Unit tests for ``_Rule.matches``."""
+
+    def test_rule_attribute_matching_from_xray_response(self):
+        # The with-block closes the file; no explicit close() is needed.
+        with open(f"{DATA_DIR}/get-sampling-rules-response-sample-2.json", encoding="UTF-8") as file:
+            sample_response = json.load(file)
+        all_rules = sample_response["SamplingRuleRecords"]
+        default_rule = _SamplingRule(**all_rules[0]["SamplingRule"])
+
+        res = Resource.create(
+            attributes={
+                ResourceAttributes.SERVICE_NAME: "test_service_name",
+                ResourceAttributes.CLOUD_PLATFORM: "test_cloud_platform",
+            }
+        )
+        attr: Attributes = {
+            SpanAttributes.HTTP_TARGET: "target",
+            SpanAttributes.HTTP_METHOD: "method",
+            SpanAttributes.HTTP_URL: "url",
+            SpanAttributes.HTTP_HOST: "host",
+            "foo": "bar",
+            "abc": "1234",
+        }
+
+        rule0 = _Rule(default_rule)
+        self.assertTrue(rule0.matches(res, attr))
+
+    def test_rule_matches_with_all_attributes(self):
+        sampling_rule = _make_sampling_rule(
+            Attributes={"abc": "123", "def": "4?6", "ghi": "*89"},
+            HTTPMethod="GET",
+            Host="localhost",
+            ServiceName="myServiceName",
+            ServiceType="AWS::EKS::Container",
+            URLPath="/helloworld",
+        )
+
+        attributes: Attributes = {
+            "http.host": "localhost",
+            SpanAttributes.HTTP_METHOD: "GET",
+            "http.url": "http://127.0.0.1:5000/helloworld",
+            "abc": "123",
+            "def": "456",
+            "ghi": "789",
+        }
+
+        rule = _Rule(sampling_rule)
+        self.assertTrue(rule.matches(_make_resource("aws_eks"), attributes))
+
+    def test_rule_wild_card_attributes_matches_span_attributes(self):
+        sampling_rule = _make_sampling_rule(
+            Attributes={
+                "attr1": "*",
+                "attr2": "*",
+                "attr3": "HelloWorld",
+                "attr4": "Hello*",
+                "attr5": "*World",
+                "attr6": "?ello*",
+                "attr7": "Hell?W*d",
+                "attr8": "*.World",
+                "attr9": "*.World",
+            },
+        )
+
+        attributes: Attributes = {
+            "attr1": "",
+            "attr2": "HelloWorld",
+            "attr3": "HelloWorld",
+            "attr4": "HelloWorld",
+            "attr5": "HelloWorld",
+            "attr6": "HelloWorld",
+            "attr7": "HelloWorld",
+            "attr8": "Hello.World",
+            "attr9": "Bye.World",
+        }
+
+        rule = _Rule(sampling_rule)
+        self.assertTrue(rule.matches(Resource.get_empty(), attributes))
+
+    def test_rule_wild_card_attributes_matches_http_span_attributes(self):
+        attributes: Attributes = {
+            SpanAttributes.HTTP_HOST: "localhost",
+            SpanAttributes.HTTP_METHOD: "GET",
+            SpanAttributes.HTTP_URL: "http://127.0.0.1:5000/helloworld",
+        }
+
+        rule = _Rule(_make_sampling_rule())
+        self.assertTrue(rule.matches(Resource.get_empty(), attributes))
+
+    def test_rule_wild_card_attributes_matches_with_empty_attributes(self):
+        attributes: Attributes = {}
+        resource = _make_resource("aws_ec2")
+
+        rule = _Rule(_make_sampling_rule())
+        self.assertTrue(rule.matches(resource, attributes))
+        self.assertTrue(rule.matches(resource, None))
+        self.assertTrue(rule.matches(Resource.get_empty(), attributes))
+        self.assertTrue(rule.matches(Resource.get_empty(), None))
+        self.assertTrue(rule.matches(None, attributes))
+        self.assertTrue(rule.matches(None, None))
+
+    def test_rule_does_not_match_without_http_target(self):
+        attributes: Attributes = {}
+
+        rule = _Rule(_make_sampling_rule(URLPath="/helloworld"))
+        self.assertFalse(rule.matches(_make_resource("aws_ec2"), attributes))
+
+    def test_rule_matches_with_http_target(self):
+        attributes: Attributes = {SpanAttributes.HTTP_TARGET: "/helloworld"}
+
+        rule = _Rule(_make_sampling_rule(URLPath="/hello*"))
+        self.assertTrue(rule.matches(_make_resource("aws_ec2"), attributes))
+
+    def test_rule_matches_with_span_attributes(self):
+        sampling_rule = _make_sampling_rule(Attributes={"abc": "123", "def": "456", "ghi": "789"})
+
+        attributes: Attributes = {
+            "http.host": "localhost",
+            SpanAttributes.HTTP_METHOD: "GET",
+            "http.url": "http://127.0.0.1:5000/helloworld",
+            "abc": "123",
+            "def": "456",
+            "ghi": "789",
+        }
+
+        rule = _Rule(sampling_rule)
+        self.assertTrue(rule.matches(_make_resource("aws_eks"), attributes))
+
+    def test_rule_does_not_match_with_less_span_attributes(self):
+        sampling_rule = _make_sampling_rule(Attributes={"abc": "123", "def": "456", "ghi": "789"})
+
+        attributes: Attributes = {
+            "http.host": "localhost",
+            SpanAttributes.HTTP_METHOD: "GET",
+            "http.url": "http://127.0.0.1:5000/helloworld",
+            "abc": "123",
+        }
+
+        rule = _Rule(sampling_rule)
+        self.assertFalse(rule.matches(_make_resource("aws_eks"), attributes))
diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_rule_cache.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_rule_cache.py
new file mode 100644
index 000000000..acc763889
--- /dev/null
+++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_rule_cache.py
@@ -0,0 +1,87 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import datetime
from threading import Lock
from unittest import TestCase

from amazon.opentelemetry.distro.sampler._rule_cache import CACHE_TTL_SECONDS, _RuleCache
from amazon.opentelemetry.distro.sampler._sampling_rule import _SamplingRule
from opentelemetry.sdk.resources import Resource


class TestRuleCache(TestCase):
    """Tests for ``_RuleCache``: rule ordering, expiry, and what happens to rules across updates."""

    def test_cache_update_rules_and_sorts_rules(self):
        """Updating the cache stores every rule, ordered by priority and then by rule name."""
        cache = _RuleCache(None, None, datetime, Lock())
        self.assertEqual(len(cache.rules), 0)

        rule1 = _SamplingRule(Priority=200, RuleName="only_one_rule", Version=1)
        cache.update_sampling_rules([rule1])
        self.assertEqual(len(cache.rules), 1)

        # Supplied out of order on purpose; the cache is expected to sort them.
        rule1 = _SamplingRule(Priority=200, RuleName="abcdef", Version=1)
        rule2 = _SamplingRule(Priority=100, RuleName="abc", Version=1)
        rule3 = _SamplingRule(Priority=100, RuleName="Abc", Version=1)
        rule4 = _SamplingRule(Priority=100, RuleName="ab", Version=1)
        rule5 = _SamplingRule(Priority=100, RuleName="A", Version=1)
        rule6 = _SamplingRule(Priority=1, RuleName="abcdef", Version=1)
        cache.update_sampling_rules([rule1, rule2, rule3, rule4, rule5, rule6])

        # One list comparison checks both the count and the order, and shows the
        # full actual ordering if it fails.
        cached_names = [rule.sampling_rule.RuleName for rule in cache.rules]
        self.assertEqual(cached_names, ["abcdef", "A", "Abc", "ab", "abc", "abcdef"])

    def test_rule_cache_expiration_logic(self):
        """The cache reports expired only once its last modification is older than CACHE_TTL_SECONDS."""
        dt = datetime
        cache = _RuleCache(None, Resource.get_empty(), dt, Lock())
        self.assertFalse(cache.expired())

        # Five seconds short of the TTL: still fresh.
        cache._last_modified = dt.datetime.now() - dt.timedelta(seconds=CACHE_TTL_SECONDS - 5)
        self.assertFalse(cache.expired())

        # One second past the TTL: expired.
        cache._last_modified = dt.datetime.now() - dt.timedelta(seconds=CACHE_TTL_SECONDS + 1)
        self.assertTrue(cache.expired())

    def test_update_cache_with_only_one_rule_changed(self):
        """Unchanged rules keep their cached objects across an update; a changed rule gets a new one."""
        dt = datetime
        cache = _RuleCache(None, Resource.get_empty(), dt, Lock())
        rule1 = _SamplingRule(Priority=1, RuleName="abcdef", Version=1)
        rule2 = _SamplingRule(Priority=10, RuleName="ab", Version=1)
        rule3 = _SamplingRule(Priority=100, RuleName="Abc", Version=1)
        cache.update_sampling_rules([rule1, rule2, rule3])

        # This is a reference to the current list, not a copy. The identity checks
        # below are only meaningful if update_sampling_rules rebinds cache.rules to
        # a new list instead of mutating it in place.
        # NOTE(review): presumably it does - confirm against _RuleCache.
        rules_before_update = cache.rules

        # Same name as rule3 but a different priority, so it moves ahead of "ab".
        new_rule3 = _SamplingRule(Priority=5, RuleName="Abc", Version=1)
        cache.update_sampling_rules([rule1, rule2, new_rule3])

        cached_names = [rule.sampling_rule.RuleName for rule in cache.rules]
        self.assertEqual(cached_names, ["abcdef", "Abc", "ab"])

        # "abcdef" and "ab" did not change, so the very same objects must still be
        # cached (at their new positions); "Abc" changed and must be a new object.
        self.assertIs(rules_before_update[0], cache.rules[0])
        self.assertIs(rules_before_update[1], cache.rules[2])
        self.assertIsNot(rules_before_update[2], cache.rules[1])

    def test_update_rules_removes_older_rule(self):
        """A rule that is absent from the next update is dropped from the cache."""
        cache = _RuleCache(None, None, datetime, Lock())
        self.assertEqual(len(cache.rules), 0)

        rule1 = _SamplingRule(Priority=200, RuleName="first_rule", Version=1)
        cache.update_sampling_rules([rule1])
        self.assertEqual(len(cache.rules), 1)
        self.assertEqual(cache.rules[0].sampling_rule.RuleName, "first_rule")

        rule1 = _SamplingRule(Priority=200, RuleName="second_rule", Version=1)
        cache.update_sampling_rules([rule1])
        self.assertEqual(len(cache.rules), 1)
        self.assertEqual(cache.rules[0].sampling_rule.RuleName, "second_rule")
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from unittest import TestCase

from amazon.opentelemetry.distro.sampler._sampling_rule import _SamplingRule


# NOTE(review): this class tests _SamplingRule, yet it shares its name with the
# class in test_rule_cache.py (looks like a copy-paste). Consider renaming it to
# TestSamplingRule; it is left unchanged here so existing test selectors keep working.
class TestRuleCache(TestCase):
    """Tests for the ordering (``<``) and equality (``==``) of ``_SamplingRule``."""

    def test_sampling_rule_ordering(self):
        """Rules order by ascending Priority, with ties broken by RuleName."""
        # Listed in the expected ascending order.
        ordered_rules = [
            _SamplingRule(Priority=1, RuleName="abcdef", Version=1),
            _SamplingRule(Priority=100, RuleName="A", Version=1),
            _SamplingRule(Priority=100, RuleName="Abc", Version=1),
            _SamplingRule(Priority=100, RuleName="ab", Version=1),
            _SamplingRule(Priority=100, RuleName="abc", Version=1),
            _SamplingRule(Priority=200, RuleName="abcdef", Version=1),
        ]

        # Same checks as one chained `a < b < c ...` expression, but a failure
        # names the offending pair instead of just reporting "False is not true".
        for position, (lower, higher) in enumerate(zip(ordered_rules, ordered_rules[1:])):
            with self.subTest(position=position):
                self.assertLess(lower, higher)

    def test_sampling_rule_equality(self):
        """Equality ignores the key order of Attributes but detects any changed field value."""
        # Baseline field values; each variant below overrides only what it changes.
        base_fields = {
            "Attributes": {"abc": "123", "def": "4?6", "ghi": "*89"},
            "FixedRate": 0.11,
            "HTTPMethod": "GET",
            "Host": "localhost",
            "Priority": 20,
            "ReservoirSize": 1,
            "ResourceARN": "*",
            "RuleARN": "arn:aws:xray:us-east-1:999999999999:sampling-rule/test",
            "RuleName": "test",
            "ServiceName": "myServiceName",
            "ServiceType": "AWS::EKS::Container",
            "URLPath": "/helloworld",
            "Version": 1,
        }
        sampling_rule = _SamplingRule(**base_fields)

        # Same attribute pairs, different insertion order.
        sampling_rule_attr_unordered = _SamplingRule(
            **{**base_fields, "Attributes": {"ghi": "*89", "abc": "123", "def": "4?6"}}
        )
        self.assertEqual(sampling_rule, sampling_rule_attr_unordered)

        # Differs from the baseline in URLPath only (attribute order also shuffled).
        sampling_rule_updated = _SamplingRule(
            **{
                **base_fields,
                "Attributes": {"ghi": "*89", "abc": "123", "def": "4?6"},
                "URLPath": "/helloworld_new",
            }
        )

        # Differs from the baseline in a single attribute value ("abc": "123" -> "128").
        sampling_rule_updated_2 = _SamplingRule(
            **{**base_fields, "Attributes": {"abc": "128", "def": "4?6", "ghi": "*89"}}
        )

        # `==` is spelled out so that __eq__ itself is exercised; assertNotEqual
        # would go through `!=` instead.
        self.assertFalse(sampling_rule == sampling_rule_updated)
        self.assertFalse(sampling_rule == sampling_rule_updated_2)