Skip to content

Commit

Permalink
rate limit emails (#6482)
Browse files Browse the repository at this point in the history
* rate limit emails

* add jitter
  • Loading branch information
escattone authored Feb 13, 2025
1 parent 6e798e3 commit 0356cc0
Show file tree
Hide file tree
Showing 3 changed files with 201 additions and 0 deletions.
7 changes: 7 additions & 0 deletions kitsune/lib/email.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from django.utils.module_loading import import_string
from sentry_sdk import capture_exception

from kitsune.sumo.redis_utils import RateLimit


log = logging.getLogger("k.lib.email")

Expand Down Expand Up @@ -116,6 +118,9 @@ class SMTPEmailBackendWithSentryCapture(smtp.EmailBackend):
def __init__(self, *args, **kwargs):
    """Initialize the SMTP backend with Sentry capture and shared email rate limiting."""
    super().__init__(*args, **kwargs)
    # Never silence failures here; errors are captured and sent to Sentry instead.
    self.fail_silently = False
    # Throttle outgoing mail across all processes via Redis: at most 100 sends
    # per second, pausing ~1s per retry and at most 30s in total when the
    # shared budget is exhausted.
    self.rate_limit = RateLimit(
        key="rate-limit-emails",
        rate="100/sec",
        wait_period=1,
        max_wait_period=30,
    )

def open(self):
try:
Expand All @@ -125,13 +130,15 @@ def open(self):
return None

def close(self):
    """
    Release the rate limiter's Redis client, then close the SMTP connection.

    SMTP errors raised while closing are reported to Sentry and swallowed
    (returning None), so callers never see a failure from cleanup.
    """
    self.rate_limit.close()
    try:
        result = super().close()
    except smtplib.SMTPException as err:
        capture_exception(err)
        return None
    return result

def _send(self, email_message):
self.rate_limit.wait()
try:
return super()._send(email_message)
except smtplib.SMTPException as err:
Expand Down
85 changes: 85 additions & 0 deletions kitsune/sumo/redis_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from functools import cached_property
from urllib.parse import parse_qsl
import random
import time

from django.conf import settings
from django.core.cache.backends.base import InvalidCacheBackendError
from redis import ConnectionError, Redis
from sentry_sdk import capture_exception


class RedisError(Exception):
Expand Down Expand Up @@ -87,3 +91,84 @@ def parse_backend_uri(backend_uri):
host = host[:-1]

return scheme, host, params


class RateLimit:
"""
Simple multi-process rate limiting class that uses Redis.
"""

ALLOWED_PERIODS = dict(sec=1, min=60, hour=3600, day=86400)

def __init__(
self,
key: str,
rate: str,
wait_period: int | float,
max_wait_period: int | float,
jitter: float = 0.2,
):
self.key = key
self.jitter = jitter # percentage
self.wait_period = wait_period # seconds
self.max_wait_period = max_wait_period # seconds
max_calls, period = rate.replace("/", " ").split()
self.max_calls = int(max_calls)
self.period = self.ALLOWED_PERIODS[period]

@cached_property
def redis(self):
"""
Creates and caches the Redis client on demand.
"""
try:
return redis_client("default")
except RedisError as err:
capture_exception(err)
return None

def close(self):
"""
Close the Redis client if it exists.
"""
# We only need to do something if we've cached a Redis client.
if "redis" in self.__dict__:
if self.redis:
try:
self.redis.close()
except Exception as err:
capture_exception(err)
# Remove the cached Redis client.
delattr(self, "redis")

def is_rate_limited(self):
"""
Returns True if the rate limit has been exceeded, False otherwise.
"""
if not self.redis:
# If we can't connect to Redis, don't rate limit.
return False
# The first caller refreshes the "token bucket" with the maximum number of
# calls allowed in a period, while this is a "noop" for all other callers.
# Only the first caller will be able to set the key and its expiration, since
# the key will only be set if it doesn't already exist (nx=True). Note that
# once a key expires, it's considered to no longer exist.
self.redis.set(self.key, self.max_calls, nx=True, ex=self.period)
# If the "token bucket" is empty, start rate limiting until it's refreshed
# again after the period expires.
return self.redis.decrby(self.key, 1) < 0

def wait(self):
"""
Wait until we're no longer rate limited. Waits either indefinitely, if
the "max_wait_period" is None or zero, or until the "max_wait_period"
has been reached. Returns the time spent waiting in seconds.
"""
waited = 0
while self.is_rate_limited():
jittered_wait = self.wait_period * random.uniform(1 - self.jitter, 1 + self.jitter)
time.sleep(jittered_wait)
waited += jittered_wait
if self.max_wait_period and (waited >= self.max_wait_period):
break
return waited
109 changes: 109 additions & 0 deletions kitsune/sumo/tests/test_redis_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import multiprocessing
import time
from unittest import mock

from kitsune.sumo.redis_utils import RateLimit, RedisError
from kitsune.sumo.tests import TestCase


class TestRateLimit(TestCase):

    def setUp(self):
        self.key = "test-key"
        self.max_calls = 5
        self.wait_period = 0.1
        self.max_wait_period = 2
        self.jitter = 0.2
        self.rate_limit = RateLimit(
            key=self.key,
            rate=f"{self.max_calls}/sec",
            wait_period=self.wait_period,
            max_wait_period=self.max_wait_period,
            jitter=self.jitter,
        )
        # Start every test with an empty token bucket in Redis.
        self.rate_limit.redis.delete(self.key)

    def tearDown(self):
        self.rate_limit.close()

    def _exhaust_bucket(self):
        """Use up the full call budget, asserting that no call is limited."""
        for call_number in range(1, self.max_calls + 1):
            self.assertFalse(
                self.rate_limit.is_rate_limited(),
                f"is_rate_limited() call: {call_number}",
            )

    def test_is_rate_limited(self):
        """Ensure basic operation of is_rate_limited()."""
        self._exhaust_bucket()
        self.assertTrue(self.rate_limit.is_rate_limited())

    def test_is_rate_limited_expiration(self):
        """Ensure is_rate_limited() resets after the expiration period."""
        self._exhaust_bucket()
        self.assertTrue(self.rate_limit.is_rate_limited())
        # The bucket refreshes once the 1-second period has elapsed.
        time.sleep(1)
        self.assertFalse(self.rate_limit.is_rate_limited())

    def test_wait(self):
        """Ensure wait() waits until we're no longer rate limited."""
        self._exhaust_bucket()

        elapsed = self.rate_limit.wait()

        self.assertFalse(self.rate_limit.is_rate_limited())
        self.assertTrue(elapsed >= self.wait_period * (1 - self.jitter))
        self.assertTrue(elapsed < self.max_wait_period)

    def test_wait_respects_max_wait_period(self):
        """Ensure wait() respects the "max_wait_period" setting."""
        self.rate_limit = RateLimit(
            key=self.key, rate="1/sec", wait_period=0.05, max_wait_period=0.1, jitter=0.0
        )
        self.assertFalse(self.rate_limit.is_rate_limited())
        elapsed = self.rate_limit.wait()
        # We stopped waiting only because we hit the maximum waiting period.
        self.assertTrue(self.rate_limit.is_rate_limited())
        self.assertTrue(elapsed == 0.1)

    def test_is_rate_limited_multiple_processes(self):
        """Test is_rate_limited() across multiple processes."""
        counter = multiprocessing.Value("i", 0)
        # Guard increments of the shared counter with a lock.
        counter_lock = multiprocessing.Lock()

        def rate_limited_task():
            """Worker function for multi-process testing."""
            limiter = RateLimit(
                key="test-key", rate="5/sec", wait_period=0.1, max_wait_period=2
            )
            if not limiter.is_rate_limited():
                with counter_lock:
                    counter.value += 1

        workers = [multiprocessing.Process(target=rate_limited_task) for _ in range(10)]

        for worker in workers:
            worker.start()

        # Wait until all of the processes have completed.
        for worker in workers:
            worker.join()

        self.assertEqual(counter.value, 5)

    @mock.patch("kitsune.sumo.redis_utils.redis_client")
    @mock.patch("kitsune.sumo.redis_utils.capture_exception")
    def test_redis_client_failure(self, capture_mock, redis_mock):
        """Ensure that RateLimit handles Redis failures gracefully."""
        redis_mock.side_effect = RedisError()

        self.rate_limit = RateLimit(
            key=self.key, rate="1/min", wait_period=0.05, max_wait_period=0.1
        )

        # If the creation of the redis client failed, there should be no rate limiting.
        for _ in range(3):
            self.assertFalse(self.rate_limit.is_rate_limited())

        redis_mock.assert_called_once()
        capture_mock.assert_called_once_with(redis_mock.side_effect)

0 comments on commit 0356cc0

Please sign in to comment.