Skip to content

Commit

Permalink
wip 2
Browse files Browse the repository at this point in the history
  • Loading branch information
jj22ee committed Feb 1, 2024
1 parent 6e3c11c commit 3e9f5dd
Show file tree
Hide file tree
Showing 8 changed files with 328 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class AwsXRayRemoteSampler(Sampler):
log_level: custom log level configuration for remote sampler (Optional)
"""

__resource: Resource
# __resource: Resource
__polling_interval: int
__xray_client: AwsXRaySamplingClient

Expand All @@ -49,14 +49,14 @@ def __init__(
_logger.setLevel(log_level)

self.__xray_client = AwsXRaySamplingClient(endpoint, log_level=log_level)
self.__rule_polling_jitter = random.uniform(0.0, 5.0)
self.__rule_polling_jitter = 0#random.uniform(0.0, 5.0)
self.__polling_interval = polling_interval

self.__rule_cache_lock = Lock()
self.__rule_cache = RuleCache(self.__rule_cache_lock)
self.__rule_cache = RuleCache(resource, self.__rule_cache_lock)

# pylint: disable=unused-private-member
self.__resource = resource
# self.__resource = resource

self.__sampling_rules = []

Expand All @@ -74,7 +74,9 @@ def should_sample(
trace_state: "TraceState" = None,
) -> "SamplingResult":
# TODO: add sampling functionality
return ALWAYS_OFF.should_sample(self, parent_context, trace_id, name, kind=kind, attributes=attributes, links=links, trace_state=trace_state)
return ALWAYS_OFF.should_sample(
self, parent_context, trace_id, name, kind=kind, attributes=attributes, links=links, trace_state=trace_state
)

# pylint: disable=no-self-use
def get_description(self) -> str:
Expand All @@ -85,8 +87,11 @@ def __get_and_update_sampling_rules(self):
sampling_rules = self.__xray_client.get_sampling_rules()

# TODO: Update sampling rules cache
_logger.info("Got Sampling Rules: %s", {json.dumps([ob.__dict__ for ob in sampling_rules])})
self.__rule_cache.update_sampling_rules(sampling_rules)
# _logger.info("Got Sampling Rules: %s", {json.dumps([ob.__dict__ for ob in sampling_rules])})
_logger.info("Rules Cache ID: %s", [f"{rule.id} -> {rule.sampling_rule.Attributes}" for rule in self.__rule_cache.rules])

# rule_map = {rule.sampling_rule.RuleName: rule for rule in self.rules}
def __start_sampling_rule_poller(self):
self.__get_and_update_sampling_rules()
# Schedule the next sampling rule poll
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,17 @@ def should_sample(
) -> "SamplingResult":
# TODO: add sampling functionality
current_time = time.process_time()
if (current_time - self.__last_take >= 1.0):
res = SamplingResult(
Decision.RECORD_AND_SAMPLE,
trace_state=_get_parent_trace_state(parent_context))
if current_time - self.__last_take >= 1.0:
res = SamplingResult(Decision.RECORD_AND_SAMPLE, trace_state=_get_parent_trace_state(parent_context))
self.__last_take = time.process_time()
return res
return self.__default_sampler.should_sample(self, parent_context, trace_id, name, kind=kind, attributes=attributes, links=links, trace_state=trace_state)
return self.__default_sampler.should_sample(
self, parent_context, trace_id, name, kind=kind, attributes=attributes, links=links, trace_state=trace_state
)

# pylint: disable=no-self-use
def get_description(self) -> str:
description = "FallbackSampler{fallback sampling with sampling config of 1 req/sec and 5% of additional requests}"
description = (
"FallbackSampler{fallback sampling with sampling config of 1 req/sec and 5% of additional requests}"
)
return description
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import re

from opentelemetry.semconv.resource import CloudPlatformValues
from opentelemetry.util.types import Attributes

cloud_platform_mapping = {
CloudPlatformValues.AWS_LAMBDA.value: "AWS::Lambda::Function",
CloudPlatformValues.AWS_ELASTIC_BEANSTALK.value: "AWS::ElasticBeanstalk::Environment",
CloudPlatformValues.AWS_EC2.value: "AWS::EC2::Instance",
CloudPlatformValues.AWS_ECS.value: "AWS::ECS::Container",
CloudPlatformValues.AWS_EKS.value: "AWS::EKS::Container",
}

class Matcher:
@staticmethod
def wild_card_match(pattern:str = None, text:str = None) -> bool:
if pattern == "*":
return True
if text == None or pattern == None:
return False
if len(pattern) == 0:
return len(text) == 0
for c in pattern:
if c == "*" or c == "?":
return re.fullmatch(Matcher.to_regex_pattern(pattern), text) != None
return pattern == text

@staticmethod
def to_regex_pattern(rule_pattern: str) -> str:
token_start = -1
regex_pattern = ""
for i in range(0, len(rule_pattern)):
c = rule_pattern[i]
if c == '*' or c == '?':
if token_start != -1:
regex_pattern += re.escape(rule_pattern[token_start:i])
token_start = -1
if c == "*":
regex_pattern += ".*"
else:
regex_pattern += "."
else:
if token_start == -1:
token_start = i
if token_start != -1:
regex_pattern += re.escape(rule_pattern[token_start:])
return regex_pattern

def attribute_match(attributes: Attributes = None, rule_attributes: dict = None):
if rule_attributes == None or len(rule_attributes) == 0:
return True
if attributes == None or len(attributes) == 0 or len(rule_attributes) > len(attributes):
return False
matched_count = 0

for key, val in rule_attributes:
pattern = val
text_to_match = rule_attributes.get(key, None)
if text_to_match == None or Matcher.wild_card_match(pattern=pattern, text=text_to_match) == False:
return False
return True

#previous logic

# for key, val in attributes.items():
# text_to_match = val
# pattern = rule_attributes.get(key, None)
# if pattern == None:
# continue
# else:
# if Matcher.wild_card_match(pattern, text_to_match):
# matched_count += 1
# return matched_count == len(attributes)
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import json
import random
from logging import getLogger
from threading import Lock, Timer
from typing import Optional, Sequence

from amazon.opentelemetry.distro.sampler.aws_xray_sampling_client import AwsXRaySamplingClient
from amazon.opentelemetry.distro.sampler.rule_cache import RuleCache
from opentelemetry.context import Context
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace.sampling import ALWAYS_OFF, Sampler, SamplingResult
from opentelemetry.trace import Link, SpanKind
from opentelemetry.trace.span import TraceState
from opentelemetry.util.types import Attributes

_logger = getLogger(__name__)

DEFAULT_RULES_POLLING_INTERVAL = 300
DEFAULT_TARGET_POLLING_INTERVAL = 10
DEFAULT_SAMPLING_PROXY_ENDPOINT = "http://127.0.0.1:2000"


class ReservoirSampler(Sampler):
def __init__(
self,
quota=None,
):
self.quota = quota
self.quota_balance = 0
pass

# pylint: disable=no-self-use
def should_sample(
self,
parent_context: Optional["Context"],
trace_id: int,
name: str,
kind: SpanKind = None,
attributes: Attributes = None,
links: Sequence["Link"] = None,
trace_state: "TraceState" = None,
) -> "SamplingResult":
# TODO: add sampling functionality
pass


Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import time
from typing import Optional, Sequence

from amazon.opentelemetry.distro.sampler.sampling_statistics_document import SamplingStatisticsDocument
from opentelemetry.context import Context
from opentelemetry.trace import Link, SpanKind
from opentelemetry.trace.span import TraceState
from opentelemetry.util.types import Attributes


class Rule:
def __init__(self, sampling_rule, id):
self.sampling_rule = sampling_rule
self.rervoir_sampler = None
self.fixed_rate_sampler = None
self.statistics = SamplingStatisticsDocument()
self.next_target_fetch_time = time.process_time()
self.id = id

def should_sample(
self,
parent_context: Optional["Context"],
trace_id: int,
name: str,
kind: SpanKind = None,
attributes: Attributes = None,
links: Sequence["Link"] = None,
trace_state: "TraceState" = None,
):
pass

def matches(self, *args, **kwargs):
attributes = {}
http_target = None
http_url = None
http_method = None
http_host = None

http_target = attributes.get(Attributes.HTTP_TARGET, None)
http_method = attributes.get(Attributes.HTTP_METHOD, None)
http_url = attributes.get(Attributes.HTTP_URL, None)
http_host = attributes.get(Attributes.HTTP_HOST, None)

# target may be in url
if http_target is None and http_url is not None:
scheme_end_index = http_url.find("://")
# Per spec, http.url is always populated with scheme://host/target. If scheme doesn't
# match, assume it's bad instrumentation and ignore.
if scheme_end_index > -1:
path_index = http_url.find("/", scheme_end_index + len("://"))
if path_index == -1:
http_target = "/"
else:
http_target = http_url[path_index:]

service_name = attributes.get("service.name", "")

#rule_match_here
"""
// URL path may be in either http.target or http.url
if (httpTarget == null && httpUrl != null)
{
int schemeEndIndex = httpUrl.IndexOf("://", StringComparison.Ordinal);
// Per spec, http.url is always populated with scheme://host/target. If scheme doesn't
// match, assume it's bad instrumentation and ignore.
if (schemeEndIndex > 0)
{
int pathIndex = httpUrl.IndexOf('/', schemeEndIndex + "://".Length);
if (pathIndex < 0)
{
httpTarget = "/";
}
else
{
httpTarget = httpUrl.Substring(pathIndex);
}
}
}
SpanAttributes.HTTP_SCHEME: self.scope["scheme"],
SpanAttributes.NET_HOST_PORT: self.scope["server"][1],
SpanAttributes.HTTP_HOST: self.scope["server"][0],
SpanAttributes.HTTP_FLAVOR: self.scope["http_version"],
SpanAttributes.HTTP_TARGET: self.scope["path"],
SpanAttributes.HTTP_URL: f'{self.scope["scheme"]}://{self.scope["server"][0]}{self.scope["path"]}',
SpanAttributes.NET_PEER_IP: self.scope["client"][0],
SpanAttributes.NET_PEER_PORT: self.scope["client"][1],
SpanAttributes.HTTP_STATUS_CODE: 200,
if (samplingParameters.Tags is not null)
{
foreach (var tag in samplingParameters.Tags)
{
if (tag.Key.Equals(SemanticConventions.AttributeHttpTarget, StringComparison.Ordinal))
{
httpTarget = (string?)tag.Value;
}
else if (tag.Key.Equals(SemanticConventions.AttributeHttpUrl, StringComparison.Ordinal))
{
httpUrl = (string?)tag.Value;
}
else if (tag.Key.Equals(SemanticConventions.AttributeHttpMethod, StringComparison.Ordinal))
{
httpMethod = (string?)tag.Value;
}
else if (tag.Key.Equals(SemanticConventions.AttributeHttpHost, StringComparison.Ordinal))
{
httpHost = (string?)tag.Value;
}
}
}
"""


# pass
Loading

0 comments on commit 3e9f5dd

Please sign in to comment.