Skip to content

Commit

Permalink
remote sampling - rules caching and rules matching
Browse files Browse the repository at this point in the history
  • Loading branch information
jj22ee committed Feb 6, 2024
1 parent 8dcaba7 commit 01b9b16
Show file tree
Hide file tree
Showing 13 changed files with 937 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


class _AwsXRaySamplingClient:
def __init__(self, endpoint=None, log_level=None):
def __init__(self, endpoint: str = None, log_level: str = None):
# Override default log level
if log_level is not None:
_logger.setLevel(log_level)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Sequence

from opentelemetry.context import Context
from opentelemetry.sdk.trace.sampling import ALWAYS_ON, Sampler, SamplingResult, TraceIdRatioBased
from opentelemetry.trace import Link, SpanKind
from opentelemetry.trace.span import TraceState
from opentelemetry.util.types import Attributes


class _FallbackSampler(Sampler):
def __init__(self):
# TODO: Add Reservoir sampler
# pylint: disable=unused-private-member
self.__fixed_rate_sampler = TraceIdRatioBased(0.05)

# pylint: disable=no-self-use
def should_sample(
self,
parent_context: Optional[Context],
trace_id: int,
name: str,
kind: SpanKind = None,
attributes: Attributes = None,
links: Sequence[Link] = None,
trace_state: TraceState = None,
) -> SamplingResult:
# TODO: add reservoir + fixed rate sampling
return ALWAYS_ON.should_sample(
parent_context, trace_id, name, kind=kind, attributes=attributes, links=links, trace_state=trace_state
)

# pylint: disable=no-self-use
def get_description(self) -> str:
description = (
"FallbackSampler{fallback sampling with sampling config of 1 req/sec and 5% of additional requests}"
)
return description
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import re

from opentelemetry.semconv.resource import CloudPlatformValues
from opentelemetry.util.types import Attributes

cloud_platform_mapping = {
CloudPlatformValues.AWS_LAMBDA.value: "AWS::Lambda::Function",
CloudPlatformValues.AWS_ELASTIC_BEANSTALK.value: "AWS::ElasticBeanstalk::Environment",
CloudPlatformValues.AWS_EC2.value: "AWS::EC2::Instance",
CloudPlatformValues.AWS_ECS.value: "AWS::ECS::Container",
CloudPlatformValues.AWS_EKS.value: "AWS::EKS::Container",
}


class _Matcher:
@staticmethod
def wild_card_match(text: str = None, pattern: str = None) -> bool:
if pattern == "*":
return True
if text is None or pattern is None:
return False
if len(pattern) == 0:
return len(text) == 0
for char in pattern:
if char in ("*", "?"):
return re.fullmatch(_Matcher.to_regex_pattern(pattern), text) is not None
return pattern == text

@staticmethod
def to_regex_pattern(rule_pattern: str) -> str:
token_start = -1
regex_pattern = ""
for index, char in enumerate(rule_pattern):
char = rule_pattern[index]
if char in ("*", "?"):
if token_start != -1:
regex_pattern += re.escape(rule_pattern[token_start:index])
token_start = -1
if char == "*":
regex_pattern += ".*"
else:
regex_pattern += "."
else:
if token_start == -1:
token_start = index
if token_start != -1:
regex_pattern += re.escape(rule_pattern[token_start:])
return regex_pattern

@staticmethod
def attribute_match(attributes: Attributes = None, rule_attributes: dict = None) -> bool:
if rule_attributes is None or len(rule_attributes) == 0:
return True
if attributes is None or len(attributes) == 0 or len(rule_attributes) > len(attributes):
return False

matched_count = 0
for key, val in attributes.items():
text_to_match = val
pattern = rule_attributes.get(key, None)
if pattern is None:
continue
if _Matcher.wild_card_match(text_to_match, pattern):
matched_count += 1
return matched_count == len(rule_attributes)
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Sequence

from amazon.opentelemetry.distro.sampler._matcher import _Matcher, cloud_platform_mapping
from amazon.opentelemetry.distro.sampler._sampling_rule import _SamplingRule
from opentelemetry.context import Context
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace.sampling import ALWAYS_ON, SamplingResult
from opentelemetry.semconv.resource import ResourceAttributes
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.trace import Link, SpanKind
from opentelemetry.trace.span import TraceState
from opentelemetry.util.types import Attributes


class _Rule:
def __init__(self, sampling_rule: _SamplingRule):
self.sampling_rule = sampling_rule
# TODO add self.next_target_fetch_time from maybe time.process_time() or cache's datetime object
# TODO add statistics
# TODO change to rate limiter given rate, add fixed rate sampler
self.reservoir_sampler = ALWAYS_ON
# self.fixed_rate_sampler = None
# TODO add clientId

def should_sample(
self,
parent_context: Optional[Context],
trace_id: int,
name: str,
kind: SpanKind = None,
attributes: Attributes = None,
links: Sequence[Link] = None,
trace_state: TraceState = None,
) -> SamplingResult:
return self.reservoir_sampler.should_sample(
parent_context, trace_id, name, kind=kind, attributes=attributes, links=links, trace_state=trace_state
)

def matches(self, resource: Resource, attributes: Attributes) -> bool:
http_target = None
http_url = None
http_method = None
http_host = None
service_name = None

if attributes is not None:
http_target = attributes.get(SpanAttributes.HTTP_TARGET, None)
http_method = attributes.get(SpanAttributes.HTTP_METHOD, None)
http_url = attributes.get(SpanAttributes.HTTP_URL, None)
http_host = attributes.get(SpanAttributes.HTTP_HOST, None)
# NOTE: The above span attribute keys are deprecated in favor of:
# URL_PATH/URL_QUERY, HTTP_REQUEST_METHOD, URL_FULL, SERVER_ADDRESS/SERVER_PORT
# For now, the old attribute keys are kept for consistency with other centralized samplers

# Resource shouldn't be none as it should default to empty resource
if resource is not None:
service_name = resource.attributes.get(ResourceAttributes.SERVICE_NAME, "")

# target may be in url
if http_target is None and http_url is not None:
scheme_end_index = http_url.find("://")
# Per spec, http.url is always populated with scheme://host/target. If scheme doesn't
# match, assume it's bad instrumentation and ignore.
if scheme_end_index > -1:
path_index = http_url.find("/", scheme_end_index + len("://"))
if path_index == -1:
http_target = "/"
else:
http_target = http_url[path_index:]

return (
_Matcher.attribute_match(attributes, self.sampling_rule.Attributes)
and _Matcher.wild_card_match(http_target, self.sampling_rule.URLPath)
and _Matcher.wild_card_match(http_method, self.sampling_rule.HTTPMethod)
and _Matcher.wild_card_match(http_host, self.sampling_rule.Host)
and _Matcher.wild_card_match(service_name, self.sampling_rule.ServiceName)
and _Matcher.wild_card_match(self.get_service_type(resource), self.sampling_rule.ServiceType)
)

# pylint: disable=no-self-use
def get_service_type(self, resource: Resource) -> str:
if resource is None:
return ""

cloud_platform = resource.attributes.get(ResourceAttributes.CLOUD_PLATFORM, None)
if cloud_platform is None:
return ""

return cloud_platform_mapping.get(cloud_platform, "")
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import copy
import datetime
from logging import getLogger
from threading import Lock
from typing import Optional, Sequence

from amazon.opentelemetry.distro.sampler._fallback_sampler import _FallbackSampler
from amazon.opentelemetry.distro.sampler._rule import _Rule
from amazon.opentelemetry.distro.sampler._sampling_rule import _SamplingRule
from opentelemetry.context import Context
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace.sampling import SamplingResult
from opentelemetry.trace import Link, SpanKind
from opentelemetry.trace.span import TraceState
from opentelemetry.util.types import Attributes

_logger = getLogger(__name__)

CACHE_TTL_SECONDS = 3600


class _RuleCache:
rules: [_Rule] = []

def __init__(self, resource: Resource, fallback_sampler: _FallbackSampler, date_time: datetime, lock: Lock):
self.__cache_lock = lock
self.__resource = resource
self._fallback_sampler = fallback_sampler
self._date_time = date_time
self._last_modified = self._date_time.datetime.now()

def should_sample(
self,
parent_context: Optional[Context],
trace_id: int,
name: str,
kind: SpanKind = None,
attributes: Attributes = None,
links: Sequence[Link] = None,
trace_state: TraceState = None,
) -> SamplingResult:
for rule in self.rules:
if rule.matches(self.__resource, attributes):
return rule.should_sample(
parent_context,
trace_id,
name,
kind=kind,
attributes=attributes,
links=links,
trace_state=trace_state,
)

return self._fallback_sampler.should_sample(
parent_context, trace_id, name, kind=kind, attributes=attributes, links=links, trace_state=trace_state
)

def update_sampling_rules(self, new_sampling_rules: [_SamplingRule]) -> None:
new_sampling_rules.sort()
temp_rules = []
for sampling_rule in new_sampling_rules:
if sampling_rule.RuleName == "":
_logger.info("sampling rule without rule name is not supported")
continue
if sampling_rule.Version != 1:
_logger.info("sampling rule without Version 1 is not supported: RuleName: %s", sampling_rule.RuleName)
continue
temp_rules.append(_Rule(copy.deepcopy(sampling_rule)))

self.__cache_lock.acquire()

# map list of rules by each rule's sampling_rule name
rule_map = {rule.sampling_rule.RuleName: rule for rule in self.rules}

# If a sampling rule has not changed, keep its respective rule in the cache.
for index, new_rule in enumerate(temp_rules):
rule_name_to_check = new_rule.sampling_rule.RuleName
if rule_name_to_check in rule_map:
previous_rule = rule_map[rule_name_to_check]
if new_rule.sampling_rule == previous_rule.sampling_rule:
temp_rules[index] = previous_rule
self.rules = temp_rules
self._last_modified = datetime.datetime.now()

self.__cache_lock.release()

def expired(self) -> bool:
self.__cache_lock.acquire()
try:
return datetime.datetime.now() > self._last_modified + datetime.timedelta(seconds=CACHE_TTL_SECONDS)
finally:
self.__cache_lock.release()
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,29 @@ def __init__(
self.ServiceType = ServiceType if ServiceType is not None else ""
self.URLPath = URLPath if URLPath is not None else ""
self.Version = Version if Version is not None else 0

def __lt__(self, other) -> bool:
if self.Priority == other.Priority:
# String order priority example:
# "A","Abc","a","ab","abc","abcdef"
return self.RuleName < other.RuleName
return self.Priority < other.Priority

def __eq__(self, other: object) -> bool:
if not isinstance(other, _SamplingRule):
return False
return (
self.FixedRate == other.FixedRate
and self.HTTPMethod == other.HTTPMethod
and self.Host == other.Host
and self.Priority == other.Priority
and self.ReservoirSize == other.ReservoirSize
and self.ResourceARN == other.ResourceARN
and self.RuleARN == other.RuleARN
and self.RuleName == other.RuleName
and self.ServiceName == other.ServiceName
and self.ServiceType == other.ServiceType
and self.URLPath == other.URLPath
and self.Version == other.Version
and self.Attributes == other.Attributes
)
Loading

0 comments on commit 01b9b16

Please sign in to comment.