Skip to content
This repository has been archived by the owner on Jul 11, 2022. It is now read-only.

Commit

Permalink
Add sampler types and py.typed file (#338)
Browse files Browse the repository at this point in the history
Signed-off-by: Kai Mueller <[email protected]>
  • Loading branch information
kasium authored Sep 10, 2021
1 parent ac8e752 commit ba85203
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 38 deletions.
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ recursive-include config *.json
recursive-include config *.yaml
recursive-include docs *.rst
recursive-include tests *.py
recursive-include jaeger_client *.typed
Empty file added jaeger_client/py.typed
Empty file.
79 changes: 42 additions & 37 deletions jaeger_client/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@
SAMPLER_TYPE_RATE_LIMITING,
SAMPLER_TYPE_LOWER_BOUND,
)
from .metrics import Metrics, LegacyMetricsFactory
from .metrics import Metrics, LegacyMetricsFactory, MetricsFactory
from .utils import ErrorReporter
from .rate_limiter import RateLimiter
from typing import Any, Dict, Optional, Tuple

default_logger = logging.getLogger('jaeger_tracing')

Expand All @@ -52,35 +53,38 @@
PROBABILISTIC_SAMPLING_STRATEGY = 'PROBABILISTIC'
RATE_LIMITING_SAMPLING_STRATEGY = 'RATE_LIMITING'

_TagsType = Dict[str, Any]
_IsSampledType = Tuple[bool, _TagsType]


class Sampler(object):
"""
Sampler is responsible for deciding if a particular span should be
"sampled", i.e. recorded in permanent storage.
"""

def __init__(self, tags=None):
self._tags = tags
def __init__(self, tags: Optional[_TagsType] = None) -> None:
self._tags = tags or {}

def is_sampled(self, trace_id, operation=''):
def is_sampled(self, trace_id: int, operation: str = '') -> _IsSampledType:
raise NotImplementedError()

def close(self):
def close(self) -> None:
raise NotImplementedError()

def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
return (
isinstance(other, self.__class__) and self.__dict__ == other.__dict__
)

def __ne__(self, other):
def __ne__(self, other: Any) -> bool:
return not self.__eq__(other)


class ConstSampler(Sampler):
"""ConstSampler always returns the same decision."""

def __init__(self, decision):
def __init__(self, decision: bool) -> None:
super(ConstSampler, self).__init__(
tags={
SAMPLER_TYPE_TAG_KEY: SAMPLER_TYPE_CONST,
Expand All @@ -89,13 +93,13 @@ def __init__(self, decision):
)
self.decision = decision

def is_sampled(self, trace_id, operation=''):
def is_sampled(self, trace_id: int, operation: str = '') -> _IsSampledType:
return self.decision, self._tags

def close(self):
pass

def __str__(self):
def __str__(self) -> str:
return 'ConstSampler(%s)' % self.decision


Expand All @@ -110,7 +114,7 @@ class ProbabilisticSampler(Sampler):
Note that we actually ignore (zero out) the most significant bit.
"""

def __init__(self, rate):
def __init__(self, rate: float) -> None:
super(ProbabilisticSampler, self).__init__(
tags={
SAMPLER_TYPE_TAG_KEY: SAMPLER_TYPE_PROBABILISTIC,
Expand All @@ -122,14 +126,14 @@ def __init__(self, rate):
self.max_number = 1 << _max_id_bits
self.boundary = rate * self.max_number

def is_sampled(self, trace_id, operation=''):
def is_sampled(self, trace_id: int, operation: str = '') -> _IsSampledType:
trace_id = trace_id & (self.max_number - 1)
return trace_id < self.boundary, self._tags

def close(self):
def close(self) -> None:
pass

def __str__(self):
def __str__(self) -> str:
return 'ProbabilisticSampler(%s)' % self.rate


Expand All @@ -142,9 +146,9 @@ class RateLimitingSampler(Sampler):
sequential requests can be sampled each second.
"""

def __init__(self, max_traces_per_second=10):
def __init__(self, max_traces_per_second: float = 10) -> None:
super(RateLimitingSampler, self).__init__()
self.rate_limiter = None
self.rate_limiter: RateLimiter = None # type:ignore # value is set below
self._init(max_traces_per_second)

def _init(self, max_traces_per_second):
Expand All @@ -164,13 +168,13 @@ def _init(self, max_traces_per_second):
else:
self.rate_limiter.update(max_traces_per_second, max_balance)

def is_sampled(self, trace_id, operation=''):
def is_sampled(self, trace_id: int, operation: str = '') -> _IsSampledType:
return self.rate_limiter.check_credit(1.0), self._tags

def close(self):
def close(self) -> None:
pass

def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
"""The last_tick and balance fields can be different"""
if not isinstance(other, self.__class__):
return False
Expand All @@ -180,13 +184,13 @@ def __eq__(self, other):
d1['last_tick'] = d2['last_tick']
return d1 == d2

def update(self, max_traces_per_second):
def update(self, max_traces_per_second: float) -> bool:
if self.traces_per_second == max_traces_per_second:
return False
self._init(max_traces_per_second)
return True

def __str__(self):
def __str__(self) -> str:
return 'RateLimitingSampler(%s)' % self.traces_per_second


Expand All @@ -202,7 +206,7 @@ class GuaranteedThroughputProbabilisticSampler(Sampler):
ie. if is_sampled() for both samplers return true, the tags for
ProbabilisticSampler will be used.
"""
def __init__(self, operation, lower_bound, rate):
def __init__(self, operation: str, lower_bound: float, rate: float) -> None:
super(GuaranteedThroughputProbabilisticSampler, self).__init__(
tags={
SAMPLER_TYPE_TAG_KEY: SAMPLER_TYPE_LOWER_BOUND,
Expand All @@ -215,7 +219,7 @@ def __init__(self, operation, lower_bound, rate):
self.rate = rate
self.lower_bound = lower_bound

def is_sampled(self, trace_id, operation=''):
def is_sampled(self, trace_id: int, operation: str = '') -> _IsSampledType:
sampled, tags = \
self.probabilistic_sampler.is_sampled(trace_id, operation)
if sampled:
Expand All @@ -224,11 +228,11 @@ def is_sampled(self, trace_id, operation=''):
sampled, _ = self.lower_bound_sampler.is_sampled(trace_id, operation)
return sampled, self._tags

def close(self):
def close(self) -> None:
self.probabilistic_sampler.close()
self.lower_bound_sampler.close()

def update(self, lower_bound, rate):
def update(self, lower_bound: int, rate: float) -> None:
# (NB) This function should only be called while holding a Write lock.
if self.rate != rate:
self.probabilistic_sampler = ProbabilisticSampler(rate)
Expand All @@ -241,7 +245,7 @@ def update(self, lower_bound, rate):
self.lower_bound_sampler.update(lower_bound)
self.lower_bound = lower_bound

def __str__(self):
def __str__(self) -> str:
return 'GuaranteedThroughputProbabilisticSampler(%s, %f, %f)' \
% (self.operation, self.rate, self.lower_bound)

Expand All @@ -253,7 +257,7 @@ class AdaptiveSampler(Sampler):
of all operations and delegates calls the the respective
GuaranteedThroughputProbabilisticSampler.
"""
def __init__(self, strategies, max_operations):
def __init__(self, strategies: Dict[str, Any], max_operations: int) -> None:
super(AdaptiveSampler, self).__init__()

samplers = {}
Expand All @@ -275,7 +279,7 @@ def __init__(self, strategies, max_operations):
self.lower_bound = strategies.get(DEFAULT_LOWER_BOUND_STR, DEFAULT_LOWER_BOUND)
self.max_operations = max_operations

def is_sampled(self, trace_id, operation=''):
def is_sampled(self, trace_id: int, operation: str = '') -> _IsSampledType:
sampler = self.samplers.get(operation)
if not sampler:
if len(self.samplers) >= self.max_operations:
Expand All @@ -289,7 +293,7 @@ def is_sampled(self, trace_id, operation=''):
return sampler.is_sampled(trace_id, operation)
return sampler.is_sampled(trace_id, operation)

def update(self, strategies):
def update(self, strategies: Dict[str, Any]) -> None:
# (NB) This function should only be called while holding a Write lock.
for strategy in strategies.get(STRATEGIES_STR, []):
operation = strategy.get(OPERATION_STR)
Expand All @@ -313,19 +317,19 @@ def update(self, strategies):
self.default_sampler = \
ProbabilisticSampler(self.default_sampling_probability)

def close(self):
def close(self) -> None:
for _, sampler in self.samplers.items():
sampler.close()

def __str__(self):
def __str__(self) -> str:
return 'AdaptiveSampler(%f, %f, %d)' \
% (self.default_sampling_probability, self.lower_bound,
self.max_operations)


class RemoteControlledSampler(Sampler):
"""Periodically loads the sampling strategy from a remote server."""
def __init__(self, channel, service_name, **kwargs):
def __init__(self, channel: Any, service_name: str, **kwargs: Any) -> None:
"""
:param channel: channel for communicating with jaeger-agent
:param service_name: name of this application
Expand Down Expand Up @@ -378,8 +382,9 @@ def __init__(self, channel, service_name, **kwargs):
# unless already running in the loop, so we use `add_callback`
self.io_loop.add_callback(self._init_polling)

def is_sampled(self, trace_id, operation=''):
def is_sampled(self, trace_id: int, operation: str = '') -> _IsSampledType:
with self.lock:
assert self.sampler # needed for mypy
return self.sampler.is_sampled(trace_id, operation)

def _init_polling(self):
Expand Down Expand Up @@ -496,14 +501,14 @@ def _poll_sampling_manager(self):
fut = self._channel.request_sampling_strategy(self.service_name)
fut.add_done_callback(self._sampling_request_callback)

def close(self):
def close(self) -> None:
with self.lock:
self.running = False
if self.periodic:
self.periodic.stop()


def get_sampling_probability(strategy=None):
def get_sampling_probability(strategy: Optional[Dict[str, Any]] = None) -> float:
if not strategy:
return DEFAULT_SAMPLING_PROBABILITY
probability_strategy = strategy.get(PROBABILISTIC_SAMPLING_STR)
Expand All @@ -512,7 +517,7 @@ def get_sampling_probability(strategy=None):
return probability_strategy.get(SAMPLING_RATE_STR, DEFAULT_SAMPLING_PROBABILITY)


def get_rate_limit(strategy=None):
def get_rate_limit(strategy: Optional[Dict[str, Any]] = None) -> float:
if not strategy:
return DEFAULT_LOWER_BOUND
rate_limit_strategy = strategy.get(RATE_LIMITING_SAMPLING_STR)
Expand All @@ -524,7 +529,7 @@ def get_rate_limit(strategy=None):
class SamplerMetrics(object):
"""Sampler specific metrics."""

def __init__(self, metrics_factory):
def __init__(self, metrics_factory: MetricsFactory) -> None:
self.sampler_retrieved = \
metrics_factory.create_counter(name='jaeger:sampler_queries', tags={'result': 'ok'})
self.sampler_query_failure = \
Expand Down
2 changes: 1 addition & 1 deletion jaeger_client/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def start_span(self,
baggage = None
if parent is None:
sampled, sampler_tags = \
self.sampler.is_sampled(trace_id, operation_name)
self.sampler.is_sampled(trace_id, operation_name or '')
if sampled:
flags = SAMPLED_FLAG
tags = tags or {}
Expand Down

0 comments on commit ba85203

Please sign in to comment.