Skip to content

Commit

Permalink
lint, fix test, debug log
Browse files Browse the repository at this point in the history
  • Loading branch information
jj22ee committed Feb 14, 2024
1 parent bc5617e commit 24fc63d
Show file tree
Hide file tree
Showing 13 changed files with 153 additions and 122 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,16 @@ def get_sampling_rules(self) -> [_SamplingRule]:
_logger.error("Request error occurred: %s", req_err)
except json.JSONDecodeError as json_err:
_logger.error("Error in decoding JSON response: %s", json_err)
# pylint: disable=broad-exception-caught
except Exception as err:
_logger.error("Error occurred when attempting to fetch rules: %s", err)

return sampling_rules

def get_sampling_targets_response(self, statistics: [dict]) -> _SamplingTargetResponse:
sampling_targets_response = _SamplingTargetResponse(LastRuleModification=None, SamplingTargetDocuments=None, UnprocessedStatistics=None)
sampling_targets_response = _SamplingTargetResponse(
LastRuleModification=None, SamplingTargetDocuments=None, UnprocessedStatistics=None
)
headers = {"content-type": "application/json"}
try:
xray_response = requests.post(
Expand All @@ -64,22 +67,20 @@ def get_sampling_targets_response(self, statistics: [dict]) -> _SamplingTargetRe
json={"SamplingStatisticsDocuments": statistics},
)
if xray_response is None:
_logger.error("GetSamplingTargets response is None. Unable to update targets.")
_logger.debug("GetSamplingTargets response is None. Unable to update targets.")
return sampling_targets_response
xray_response_json = xray_response.json()
if (
"SamplingTargetDocuments" not in xray_response_json
or "LastRuleModification" not in xray_response_json
):
_logger.error("getSamplingTargets response is invalid. Unable to update targets.")
if "SamplingTargetDocuments" not in xray_response_json or "LastRuleModification" not in xray_response_json:
_logger.debug("getSamplingTargets response is invalid. Unable to update targets.")
return sampling_targets_response

sampling_targets_response = _SamplingTargetResponse(**xray_response_json)
except requests.exceptions.RequestException as req_err:
_logger.error("Request error occurred: %s", req_err)
_logger.debug("Request error occurred: %s", req_err)
except json.JSONDecodeError as json_err:
_logger.error("Error in decoding JSON response: %s", json_err)
_logger.debug("Error in decoding JSON response: %s", json_err)
# pylint: disable=broad-exception-caught
except Exception as err:
_logger.error("Error occurred when attempting to fetch targets: %s", err)
_logger.debug("Error occurred when attempting to fetch targets: %s", err)

return sampling_targets_response
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,13 @@ def should_sample(
for rule_applier in self.__rule_appliers:
if rule_applier.matches(self.__resource, attributes):
return rule_applier.should_sample(
parent_context, trace_id, name, kind=kind, attributes=attributes, links=links, trace_state=trace_state
parent_context,
trace_id,
name,
kind=kind,
attributes=attributes,
links=links,
trace_state=trace_state,
)

# Should not ever reach fallback sampler as default rule is able to match
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,12 @@ def matches(self, resource: Resource, attributes: Attributes) -> bool:
# also check `HTTP_TARGET/HTTP_URL/HTTP_METHOD/HTTP_HOST` respectively as backup
url_path = attributes.get(SpanAttributes.URL_PATH, attributes.get(SpanAttributes.HTTP_TARGET, None))
url_full = attributes.get(SpanAttributes.URL_FULL, attributes.get(SpanAttributes.HTTP_URL, None))
http_request_method = attributes.get(SpanAttributes.HTTP_REQUEST_METHOD, attributes.get(SpanAttributes.HTTP_METHOD, None))
server_address = attributes.get(SpanAttributes.SERVER_ADDRESS, attributes.get(SpanAttributes.HTTP_HOST, None))
http_request_method = attributes.get(
SpanAttributes.HTTP_REQUEST_METHOD, attributes.get(SpanAttributes.HTTP_METHOD, None)
)
server_address = attributes.get(
SpanAttributes.SERVER_ADDRESS, attributes.get(SpanAttributes.HTTP_HOST, None)
)

# Resource shouldn't be none as it should default to empty resource
if resource is not None:
Expand Down Expand Up @@ -163,21 +167,24 @@ def __get_arn(self, resource: Resource, attributes: Attributes) -> str:
arn = resource.attributes.get(ResourceAttributes.AWS_ECS_CONTAINER_ARN, None)
if arn is not None:
return arn
if resource is not None and resource.attributes.get(ResourceAttributes.CLOUD_PLATFORM) == CloudPlatformValues.AWS_LAMBDA.value:
if (
resource is not None
and resource.attributes.get(ResourceAttributes.CLOUD_PLATFORM) == CloudPlatformValues.AWS_LAMBDA.value
):
return self.__get_lambda_arn(resource, attributes)
return ""

def __get_lambda_arn(self, resource: Resource, attributes: Attributes) -> str:
arn = resource.attributes.get(ResourceAttributes.CLOUD_RESOURCE_ID,
resource.attributes.get(ResourceAttributes.FAAS_ID, None))
arn = resource.attributes.get(
ResourceAttributes.CLOUD_RESOURCE_ID, resource.attributes.get(ResourceAttributes.FAAS_ID, None)
)
if arn is not None:
return arn

# Note from `SpanAttributes.CLOUD_RESOURCE_ID`:
# "On some cloud providers, it may not be possible to determine the full ID at startup,
# so it may be necessary to set cloud.resource_id as a span attribute instead."
arn = attributes.get(SpanAttributes.CLOUD_RESOURCE_ID,
attributes.get("faas.id", None))
arn = attributes.get(SpanAttributes.CLOUD_RESOURCE_ID, attributes.get("faas.id", None))
if arn is not None:
return arn

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

_logger = getLogger(__name__)


# Disable snake_case naming style so this class can match the sampling rules response from X-Ray
# pylint: disable=invalid-name
class _SamplingTarget:
Expand Down Expand Up @@ -49,12 +50,12 @@ def __init__(
try:
self.SamplingTargetDocuments.append(_SamplingTarget(**document))
except TypeError as e:
_logger.debug("TypeError occurred: ", e)
_logger.debug("TypeError occurred: %s", e)

self.UnprocessedStatistics: [_UnprocessedStatistics] = []
if UnprocessedStatistics is not None:
for unprocessed in UnprocessedStatistics:
try:
self.UnprocessedStatistics.append(_UnprocessedStatistics(**unprocessed))
except TypeError as e:
_logger.debug("TypeError occurred: ", e)
_logger.debug("TypeError occurred: %s", e)
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,14 @@
from unittest import TestCase
from unittest.mock import patch

from mock_clock import MockClock

from amazon.opentelemetry.distro.sampler import _clock
from amazon.opentelemetry.distro.sampler._sampling_rule import _SamplingRule
from amazon.opentelemetry.distro.sampler._sampling_target import _SamplingTargetResponse
from amazon.opentelemetry.distro.sampler.aws_xray_remote_sampler import AwsXRayRemoteSampler
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace.sampling import Decision

TEST_DIR = os.path.dirname(os.path.realpath(__file__))
DATA_DIR = os.path.join(TEST_DIR, "data")


class TestAwsXRayRemoteSampler(TestCase):
def test_create_remote_sampler_with_empty_resource(self):
rs = AwsXRayRemoteSampler(resource=Resource.get_empty())
Expand Down Expand Up @@ -63,31 +59,35 @@ class MockResponse:
def __init__(self, json_data, status_code):
self.json_data = json_data
self.status_code = status_code

def json(self):
return self.json_data

if kwargs["url"] == 'http://127.0.0.1:2000/GetSamplingRules':
if kwargs["url"] == "http://127.0.0.1:2000/GetSamplingRules":
with open(f"{DATA_DIR}/test-remote-sampler_sampling-rules-response-sample.json", encoding="UTF-8") as file:
sample_response = json.load(file)
file.close()
return MockResponse(sample_response, 200)
elif kwargs["url"] == 'http://127.0.0.1:2000/SamplingTargets':
with open(f"{DATA_DIR}/test-remote-sampler_sampling-targets-response-sample.json", encoding="UTF-8") as file:
if kwargs["url"] == "http://127.0.0.1:2000/SamplingTargets":
with open(
f"{DATA_DIR}/test-remote-sampler_sampling-targets-response-sample.json", encoding="UTF-8"
) as file:
sample_response = json.load(file)
file.close()
return MockResponse(sample_response, 200)
return MockResponse(None, 404)


@patch("requests.post", side_effect=mocked_requests_get)
@patch('amazon.opentelemetry.distro.sampler.aws_xray_remote_sampler.DEFAULT_TARGET_POLLING_INTERVAL_SECONDS', new=2)
@patch("amazon.opentelemetry.distro.sampler.aws_xray_remote_sampler.DEFAULT_TARGET_POLLING_INTERVAL_SECONDS", new=2)
def test_update_sampling_rules_and_targets_with_pollers_and_should_sample(self, mock_post=None):
rs = AwsXRayRemoteSampler(
resource=Resource.create({"service.name": "test-service-name", "cloud.platform": "test-cloud-platform"})
)

time.sleep(1.0)
self.assertEqual(rs._AwsXRayRemoteSampler__rule_cache._RuleCache__rule_appliers[0].sampling_rule.RuleName, "test")
self.assertEqual(
rs._AwsXRayRemoteSampler__rule_cache._RuleCache__rule_appliers[0].sampling_rule.RuleName, "test"
)
self.assertEqual(rs.should_sample(None, 0, "name", attributes={"abc": "1234"}).decision, Decision.DROP)

# wait 2 more seconds since targets polling was patched to 2 seconds (rather than 10s)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@ def validate_match_sampling_rules_properties_with_records(self, sampling_rules,

@patch("requests.post")
def test_get_sampling_targets(self, mock_post=None):
sampling_targets = []
with open(f"{DATA_DIR}/get-sampling-targets-response-sample.json", encoding="UTF-8") as file:
sample_response = json.load(file)
mock_post.return_value.configure_mock(**{"json.return_value": sample_response})
Expand All @@ -119,11 +118,15 @@ def test_get_sampling_targets(self, mock_post=None):

@patch("requests.post")
def test_get_invalid_sampling_targets(self, mock_post=None):
mock_post.return_value.configure_mock(**{"json.return_value": {
"LastRuleModification": None,
"SamplingTargetDocuments": None,
"UnprocessedStatistics": None
}})
mock_post.return_value.configure_mock(
**{
"json.return_value": {
"LastRuleModification": None,
"SamplingTargetDocuments": None,
"UnprocessedStatistics": None,
}
}
)
client = _AwsXRaySamplingClient("http://127.0.0.1:2000")
sampling_targets_response = client.get_sampling_targets_response(statistics=[])
self.assertEqual(sampling_targets_response.SamplingTargetDocuments, [])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
class TestClock(TestCase):
def test_from_timestamp(self):
pass

def test_time_delta(self):
clock = _Clock()
dt = clock.from_timestamp(1707551387.0)
delta = clock.time_delta(3600)
new_dt = dt + delta
self.assertTrue(new_dt.timestamp() - dt.timestamp() == 3600)
self.assertTrue(new_dt.timestamp() - dt.timestamp() == 3600)
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
# # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# # SPDX-License-Identifier: Apache-2.0
# from decimal import Decimal
# from threading import Lock
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import datetime
from unittest import TestCase

from mock_clock import MockClock

from amazon.opentelemetry.distro.sampler._fallback_sampler import _FallbackSampler
from amazon.opentelemetry.distro.sampler._rate_limiter import _RateLimiter
from opentelemetry.sdk.trace.sampling import ALWAYS_OFF, Decision


class TestRateLimitingSampler(TestCase):
# pylint: disable=too-many-branches
def test_should_sample(self):
time_now = datetime.datetime.fromtimestamp(1707551387.0)
clock = MockClock(time_now)
Expand All @@ -26,64 +24,63 @@ def test_should_sample(self):

# 0 seconds passed, 0 quota available
sampled = 0
for _ in range(0,100):
for _ in range(0, 30):
if sampler.should_sample(None, 1234, "name").decision != Decision.DROP:
sampled += 1
self.assertEqual(sampled, 0)

# 0.4 seconds passed, 0.4 quota available
sampled = 0
clock.add_time(0.4)
for _ in range(0,100):
for _ in range(0, 30):
if sampler.should_sample(None, 1234, "name").decision != Decision.DROP:
sampled += 1
self.assertEqual(sampled, 0)

# 0.8 seconds passed, 0.8 quota available
sampled = 0
clock.add_time(0.4)
for _ in range(0,100):
for _ in range(0, 30):
if sampler.should_sample(None, 1234, "name").decision != Decision.DROP:
sampled += 1
self.assertEqual(sampled, 0)

# 1.2 seconds passed, 1 quota consumed, 0 quota available
sampled = 0
clock.add_time(0.4)
for _ in range(0,100):
for _ in range(0, 30):
if sampler.should_sample(None, 1234, "name").decision != Decision.DROP:
sampled += 1
self.assertEqual(sampled, 1)

# 1.6 seconds passed, 0.4 quota available
sampled = 0
clock.add_time(0.4)
for _ in range(0,100):
for _ in range(0, 30):
if sampler.should_sample(None, 1234, "name").decision != Decision.DROP:
sampled += 1
self.assertEqual(sampled, 0)


# 2.0 seconds passed, 0.8 quota available
sampled = 0
clock.add_time(0.4)
for _ in range(0,100):
for _ in range(0, 30):
if sampler.should_sample(None, 1234, "name").decision != Decision.DROP:
sampled += 1
self.assertEqual(sampled, 0)

# 2.4 seconds passed, one more quota consumed, 0 quota available
sampled = 0
clock.add_time(0.4)
for _ in range(0,100):
for _ in range(0, 30):
if sampler.should_sample(None, 1234, "name").decision != Decision.DROP:
sampled += 1
self.assertEqual(sampled, 1)

# 100 seconds passed, only one quota can be consumed
# 30 seconds passed, only one quota can be consumed
sampled = 0
clock.add_time(100)
for _ in range(0,100):
for _ in range(0, 30):
if sampler.should_sample(None, 1234, "name").decision != Decision.DROP:
sampled += 1
self.assertEqual(sampled, 1)
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
# # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# # SPDX-License-Identifier: Apache-2.0
# from decimal import Decimal
# from threading import Lock
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import datetime
from unittest import TestCase

Expand All @@ -17,28 +15,28 @@ def test_try_spend(self):
rate_limiter = _RateLimiter(1, 30, clock)

spent = 0
for _ in range(0,100):
for _ in range(0, 100):
if rate_limiter.try_spend(1, False):
spent += 1
self.assertEqual(spent, 0)

spent = 0
clock.add_time(0.5)
for _ in range(0,100):
for _ in range(0, 100):
if rate_limiter.try_spend(1, False):
spent += 1
self.assertEqual(spent, 15)

spent = 0
clock.add_time(1)
for _ in range(0,100):
for _ in range(0, 100):
if rate_limiter.try_spend(1, True):
spent += 1
self.assertEqual(spent, 1)

spent = 0
clock.add_time(1000)
for _ in range(0,100):
for _ in range(0, 100):
if rate_limiter.try_spend(1, False):
spent += 1
self.assertEqual(spent, 30)
Loading

0 comments on commit 24fc63d

Please sign in to comment.