diff --git a/kitsune/lib/email.py b/kitsune/lib/email.py
index 2f930aef87d..bfdd0e0468a 100644
--- a/kitsune/lib/email.py
+++ b/kitsune/lib/email.py
@@ -7,6 +7,8 @@
 from django.utils.module_loading import import_string
 from sentry_sdk import capture_exception
 
+from kitsune.sumo.redis_utils import RateLimit
+
 log = logging.getLogger("k.lib.email")
 
 
@@ -116,6 +118,9 @@ class SMTPEmailBackendWithSentryCapture(smtp.EmailBackend):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.fail_silently = False
+        self.rate_limit = RateLimit(
+            key="rate-limit-emails", rate="100/sec", wait_period=1, max_wait_period=30
+        )
 
     def open(self):
         try:
@@ -125,6 +130,7 @@ def open(self):
             return None
 
     def close(self):
+        self.rate_limit.close()
         try:
             return super().close()
         except smtplib.SMTPException as err:
@@ -132,6 +138,7 @@ def close(self):
             return None
 
     def _send(self, email_message):
+        self.rate_limit.wait()
         try:
             return super()._send(email_message)
         except smtplib.SMTPException as err:
diff --git a/kitsune/sumo/redis_utils.py b/kitsune/sumo/redis_utils.py
index 2b70f167393..a7d85ce9836 100644
--- a/kitsune/sumo/redis_utils.py
+++ b/kitsune/sumo/redis_utils.py
@@ -1,8 +1,12 @@
+from functools import cached_property
 from urllib.parse import parse_qsl
+import random
+import time
 
 from django.conf import settings
 from django.core.cache.backends.base import InvalidCacheBackendError
 from redis import ConnectionError, Redis
+from sentry_sdk import capture_exception
 
 
 class RedisError(Exception):
@@ -87,3 +91,84 @@ def parse_backend_uri(backend_uri):
         host = host[:-1]
 
     return scheme, host, params
+
+
+class RateLimit:
+    """
+    Simple multi-process rate limiting class that uses Redis.
+    """
+
+    ALLOWED_PERIODS = dict(sec=1, min=60, hour=3600, day=86400)
+
+    def __init__(
+        self,
+        key: str,
+        rate: str,
+        wait_period: int | float,
+        max_wait_period: int | float,
+        jitter: float = 0.2,
+    ):
+        self.key = key
+        self.jitter = jitter  # percentage
+        self.wait_period = wait_period  # seconds
+        self.max_wait_period = max_wait_period  # seconds
+        max_calls, period = rate.replace("/", " ").split()
+        self.max_calls = int(max_calls)
+        self.period = self.ALLOWED_PERIODS[period]
+
+    @cached_property
+    def redis(self):
+        """
+        Creates and caches the Redis client on demand.
+        """
+        try:
+            return redis_client("default")
+        except RedisError as err:
+            capture_exception(err)
+            return None
+
+    def close(self):
+        """
+        Close the Redis client if it exists.
+        """
+        # We only need to do something if we've cached a Redis client.
+        if "redis" in self.__dict__:
+            if self.redis:
+                try:
+                    self.redis.close()
+                except Exception as err:
+                    capture_exception(err)
+            # Remove the cached Redis client.
+            delattr(self, "redis")
+
+    def is_rate_limited(self):
+        """
+        Returns True if the rate limit has been exceeded, False otherwise.
+        """
+        if not self.redis:
+            # If we can't connect to Redis, don't rate limit.
+            return False
+        # The first caller refreshes the "token bucket" with the maximum number of
+        # calls allowed in a period, while this is a "noop" for all other callers.
+        # Only the first caller will be able to set the key and its expiration, since
+        # the key will only be set if it doesn't already exist (nx=True). Note that
+        # once a key expires, it's considered to no longer exist.
+        self.redis.set(self.key, self.max_calls, nx=True, ex=self.period)
+        # If the "token bucket" is empty, start rate limiting until it's refreshed
+        # again after the period expires.
+        return self.redis.decrby(self.key, 1) < 0
+
+    def wait(self):
+        """
+        Wait until we're no longer rate limited. Waits either indefinitely, if
+        the "max_wait_period" is None or zero, or until the "max_wait_period"
+        has been reached. Returns the time spent waiting in seconds.
+        """
+        waited = 0
+        while self.is_rate_limited():
+            jittered_wait = self.wait_period * random.uniform(1 - self.jitter, 1 + self.jitter)
+            time.sleep(jittered_wait)
+            waited += jittered_wait
+            if self.max_wait_period and (waited >= self.max_wait_period):
+                break
+        return waited
diff --git a/kitsune/sumo/tests/test_redis_utils.py b/kitsune/sumo/tests/test_redis_utils.py
new file mode 100644
index 00000000000..0c973d740ea
--- /dev/null
+++ b/kitsune/sumo/tests/test_redis_utils.py
@@ -0,0 +1,109 @@
+import multiprocessing
+import time
+from unittest import mock
+
+from kitsune.sumo.redis_utils import RateLimit, RedisError
+from kitsune.sumo.tests import TestCase
+
+
+class TestRateLimit(TestCase):
+
+    def setUp(self):
+        self.key = "test-key"
+        self.max_calls = 5
+        self.wait_period = 0.1
+        self.max_wait_period = 2
+        self.jitter = 0.2
+        self.rate_limit = RateLimit(
+            key=self.key,
+            rate=f"{self.max_calls}/sec",
+            wait_period=self.wait_period,
+            max_wait_period=self.max_wait_period,
+            jitter=self.jitter,
+        )
+        self.rate_limit.redis.delete(self.key)
+
+    def tearDown(self):
+        self.rate_limit.close()
+
+    def test_is_rate_limited(self):
+        """Ensure basic operation of is_rate_limited()."""
+        for i in range(self.max_calls):
+            self.assertFalse(self.rate_limit.is_rate_limited(), f"is_rate_limited() call: {i+1}")
+
+        self.assertTrue(self.rate_limit.is_rate_limited())
+
+    def test_is_rate_limited_expiration(self):
+        """Ensure is_rate_limited() resets after the expiration period."""
+        for i in range(self.max_calls):
+            self.assertFalse(self.rate_limit.is_rate_limited(), f"is_rate_limited() call: {i+1}")
+
+        self.assertTrue(self.rate_limit.is_rate_limited())
+        time.sleep(1)
+        self.assertFalse(self.rate_limit.is_rate_limited())
+
+    def test_wait(self):
+        """Ensure wait() waits until we're no longer rate limited."""
+        for i in range(self.max_calls):
+            self.assertFalse(self.rate_limit.is_rate_limited(), f"is_rate_limited() call: {i+1}")
+
+        time_waited = self.rate_limit.wait()
+
+        self.assertFalse(self.rate_limit.is_rate_limited())
+        self.assertTrue(time_waited >= self.wait_period * (1 - self.jitter))
+        self.assertTrue(time_waited < self.max_wait_period)
+
+    def test_wait_respects_max_wait_period(self):
+        """Ensure wait() respects the "max_wait_period" setting."""
+        self.rate_limit = RateLimit(
+            key=self.key, rate="1/sec", wait_period=0.05, max_wait_period=0.1, jitter=0.0
+        )
+        self.assertFalse(self.rate_limit.is_rate_limited())
+        time_waited = self.rate_limit.wait()
+        # We stopped waiting only because we hit the maximum waiting period.
+        self.assertTrue(self.rate_limit.is_rate_limited())
+        self.assertTrue(time_waited == 0.1)
+
+    def test_is_rate_limited_multiple_processes(self):
+        """Test is_rate_limited() across multiple processes."""
+        shared_counter = multiprocessing.Value("i", 0)
+        # Create a lock to ensure safe increments of the shared counter.
+        shared_counter_lock = multiprocessing.Lock()
+
+        def rate_limited_task():
+            """Worker function for multi-process testing."""
+            rate_limit = RateLimit(
+                key="test-key", rate="5/sec", wait_period=0.1, max_wait_period=2
+            )
+            if not rate_limit.is_rate_limited():
+                with shared_counter_lock:
+                    shared_counter.value += 1
+
+        processes = [multiprocessing.Process(target=rate_limited_task) for _ in range(10)]
+
+        for p in processes:
+            p.start()
+
+        # Wait until all of the processes have completed.
+        for p in processes:
+            p.join()
+
+        self.assertEqual(shared_counter.value, 5)
+
+    @mock.patch("kitsune.sumo.redis_utils.redis_client")
+    @mock.patch("kitsune.sumo.redis_utils.capture_exception")
+    def test_redis_client_failure(self, capture_mock, redis_mock):
+        """Ensure that RateLimit handles Redis failures gracefully."""
+        redis_mock.side_effect = RedisError()
+
+        self.rate_limit = RateLimit(
+            key=self.key, rate="1/min", wait_period=0.05, max_wait_period=0.1
+        )
+
+        # If the creation of the redis client failed, there should be no rate limiting.
+        self.assertFalse(self.rate_limit.is_rate_limited())
+        self.assertFalse(self.rate_limit.is_rate_limited())
+        self.assertFalse(self.rate_limit.is_rate_limited())
+
+        redis_mock.assert_called_once()
+        capture_mock.assert_called_once_with(redis_mock.side_effect)
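
Usage note (illustrative only, not part of this diff): a minimal sketch of how the RateLimit class added in kitsune/sumo/redis_utils.py could throttle any bulk job whose worker processes share a Redis instance, the same way the SMTP backend above throttles outgoing mail. The "bulk-task" key, the send_notification() helper, and the items iterable are hypothetical placeholders; only the RateLimit constructor, wait(), and close() calls come from the code in the diff.

    from kitsune.sumo.redis_utils import RateLimit

    def send_notification(item):
        """Hypothetical unit of work being throttled."""
        ...

    items = range(200)  # hypothetical workload

    # Allow at most 50 calls per second across every process sharing the "bulk-task" key.
    rate_limit = RateLimit(key="bulk-task", rate="50/sec", wait_period=0.5, max_wait_period=10)
    try:
        for item in items:
            # Blocks (sleeping in jittered increments of wait_period) while the shared
            # token bucket is empty, for at most max_wait_period seconds per call.
            rate_limit.wait()
            send_notification(item)
    finally:
        rate_limit.close()  # drop the cached Redis client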