Skip to content

Commit

Permalink
[ENH]: add rate limiting (chroma-core#1728)
Browse files Browse the repository at this point in the history
## Description of changes

*Summarize the changes made by this PR.*
 - New functionality
- Add a rate limiting service. If no rate limiting provider is configured,
requests are not rate limited.

## Test plan
*How are these changes tested?*
Unit test on rate limiting service.


---------

Co-authored-by: nicolas <[email protected]>
  • Loading branch information
nicolasgere and nicolas authored Feb 28, 2024
1 parent 44e8ff7 commit 5e9c7a7
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 2 deletions.
4 changes: 2 additions & 2 deletions chromadb/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@
"chromadb.segment.SegmentManager": "chroma_segment_manager_impl",
"chromadb.segment.distributed.SegmentDirectory": "chroma_segment_directory_impl",
"chromadb.segment.distributed.MemberlistProvider": "chroma_memberlist_provider_impl",

"chromadb.rate_limiting.RateLimitingProvider": "chroma_rate_limiting_provider_impl"
}

DEFAULT_TENANT = "default_tenant"
Expand All @@ -102,7 +102,7 @@ class Settings(BaseSettings): # type: ignore
"chromadb.segment.impl.manager.local.LocalSegmentManager"
)
chroma_quota_provider_impl:Optional[str] = None

chroma_rate_limiting_provider_impl:Optional[str] = None
# Distributed architecture specific components
chroma_segment_directory_impl: str = "chromadb.segment.impl.distributed.segment_directory.RendezvousHashSegmentDirectory"
chroma_memberlist_provider_impl: str = "chromadb.segment.impl.distributed.segment_directory.CustomResourceMemberlistProvider"
Expand Down
63 changes: 63 additions & 0 deletions chromadb/rate_limiting/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import inspect
from abc import abstractmethod
from functools import wraps
from typing import Optional, Any, Dict, Callable, cast

from chromadb.config import Component
from chromadb.quota import QuotaProvider, Resource


class RateLimitError(Exception):
    """Error signalling that a request exceeded its rate limit.

    The resource that was limited and the quota that applied are kept on
    the instance as ``resource`` and ``quota``.
    """

    def __init__(self, resource: Resource, quota: int):
        message = f"rate limit error. resource: {resource} quota: {quota}"
        super().__init__(message)
        self.resource = resource
        self.quota = quota

class RateLimitingProvider(Component):
    """Component interface for deciding whether a request may proceed."""

    @abstractmethod
    def is_allowed(self, key: str, quota: int, point: Optional[int] = 1) -> bool:
        """Decide whether the request identified by ``key`` may proceed.

        :param key: The identifier for the requestor being rate limited.
        :param quota: The quota which will be used for bucket size.
        :param point: The number of tokens required to fulfill the request.
        :return: True if the request can proceed, False otherwise.
        """


def rate_limit(
    subject: str,
    resource: str
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Build a decorator that rate limits a component method.

    The decorated method must be an instance method whose ``self`` exposes
    ``_system`` (used to read the settings and to resolve the
    ``QuotaProvider`` and ``RateLimitingProvider`` components).

    :param subject: Name of the decorated method's parameter whose runtime
        value identifies who is being rate limited.
    :param resource: Name of the resource being limited; used for the quota
        lookup and as the prefix of the rate limiting key.
    :return: A decorator that wraps the method with the rate limit check.
    :raises Exception: At decoration time, if ``subject`` is not one of the
        decorated method's positional parameters.
    :raises RateLimitError: At call time, if the rate limiting provider
        rejects the request.
    """
    def decorator(f: Callable[..., Any]) -> Callable[..., Any]:
        args_name = inspect.getfullargspec(f)[0]
        if subject not in args_name:
            raise Exception(f'rate_limit decorator have unknown subject "{subject}", available {args_name}')
        # Position of the subject in f's signature, where index 0 is ``self``.
        key_index = args_name.index(subject)

        @wraps(f)
        def wrapper(self, *args: Any, **kwargs: Any) -> Any:
            # If no rate limiting provider is present, just run and return the function.
            if self._system.settings.chroma_rate_limiting_provider_impl is None:
                return f(self, *args, **kwargs)

            if subject in kwargs:
                subject_value = kwargs[subject]
            elif len(args) >= key_index:
                # ``args`` excludes ``self``, hence the shift by one.
                subject_value = args[key_index - 1]
            else:
                # Subject was not supplied; nothing to rate limit against.
                return f(self, *args, **kwargs)

            # The subject may not be a string (e.g. an id object), so
            # normalise it once for both the quota lookup and the key.
            subject_key = str(subject_value)
            key_value = f"{resource}-{subject_key}"

            quota_provider = self._system.require(QuotaProvider)
            rate_limiter = self._system.require(RateLimitingProvider)
            # Look the quota up for the actual subject value, not for the
            # name of the parameter that carries it.
            quota = quota_provider.get_for_subject(resource=resource, subject=subject_key)
            if quota is None:
                # No quota configured for this subject: do not rate limit.
                return f(self, *args, **kwargs)
            if rate_limiter.is_allowed(key_value, quota) is False:
                raise RateLimitError(resource=resource, quota=quota)
            return f(self, *args, **kwargs)
        return wrapper

    return decorator
15 changes: 15 additions & 0 deletions chromadb/rate_limiting/test_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from typing import Optional, Dict

from overrides import overrides

from chromadb.config import System
from chromadb.rate_limiting import RateLimitingProvider


class RateLimitingTestProvider(RateLimitingProvider):
    """Rate limiting provider for tests that allows every request."""

    def __init__(self, system: System):
        super().__init__(system)

    @overrides
    def is_allowed(self, key: str, quota: int, point: Optional[int] = 1) -> bool:
        """Allow the request unconditionally.

        Returns an explicit ``True`` so the declared ``bool`` return type
        holds; tests that need a different answer patch this method.
        """
        return True
9 changes: 9 additions & 0 deletions chromadb/server/fastapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
InvalidHTTPVersion,
)
from chromadb.quota import QuotaError
from chromadb.rate_limiting import RateLimitError
from chromadb.server.fastapi.types import (
AddEmbedding,
CreateDatabase,
Expand Down Expand Up @@ -142,6 +143,7 @@ def __init__(self, settings: Settings):
allow_methods=["*"],
)
self._app.add_exception_handler(QuotaError, self.quota_exception_handler)
self._app.add_exception_handler(RateLimitError, self.rate_limit_exception_handler)

self._app.on_event("shutdown")(self.shutdown)

Expand Down Expand Up @@ -290,6 +292,13 @@ def shutdown(self) -> None:
def app(self) -> fastapi.FastAPI:
    """Return the FastAPI application held in ``self._app``."""
    return self._app

async def rate_limit_exception_handler(self, request: Request, exc: RateLimitError):
    """Turn a ``RateLimitError`` into an HTTP 429 JSON response.

    The response body carries a ``message`` naming the limited resource and
    the quota taken from the exception. ``request`` is not used.
    """
    message = f"rate limit. resource: {exc.resource} quota: {exc.quota}"
    payload = {"message": message}
    return JSONResponse(status_code=429, content=payload)


def root(self) -> Dict[str, int]:
    """Return the API heartbeat under the ``"nanosecond heartbeat"`` key."""
    heartbeat = self._api.heartbeat()
    return {"nanosecond heartbeat": heartbeat}

Expand Down
50 changes: 50 additions & 0 deletions chromadb/test/rate_limiting/test_rate_limiting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from typing import Optional
from unittest.mock import patch

from chromadb.config import System, Settings, Component
from chromadb.quota import QuotaEnforcer, Resource
import pytest

from chromadb.rate_limiting import rate_limit


class RateLimitingGym(Component):
    """Minimal component used to exercise the ``rate_limit`` decorator."""

    def __init__(self, system: System) -> None:
        super().__init__(system)
        # Also keep the system under a public attribute name.
        self.system = system

    # Rate limited on the value passed as ``bar`` for "FAKE_RESOURCE".
    @rate_limit(subject="bar", resource="FAKE_RESOURCE")
    def bench(self, foo: str, bar: str) -> str:
        """Return ``foo`` unchanged; ``bar`` only serves as the rate limit subject."""
        return foo

def mock_get_for_subject(
    self,
    resource: Resource,
    subject: Optional[str] = "",
    tier: Optional[str] = "",
) -> Optional[int]:
    """Stand-in for quota retrieval: always report a quota of 10."""
    fixed_quota = 10
    return fixed_quota

@pytest.fixture(scope="module")
def rate_limiting_gym() -> RateLimitingGym:
    """Build a ``RateLimitingGym`` on a system configured with the test providers.

    Both the quota provider and the rate limiting provider are set, so the
    ``rate_limit`` decorator does not take its "no provider" shortcut.
    """
    settings = Settings(
        chroma_quota_provider_impl="chromadb.quota.test_provider.QuotaProviderForTest",
        chroma_rate_limiting_provider_impl="chromadb.rate_limiting.test_provider.RateLimitingTestProvider"
    )
    system = System(settings)
    return RateLimitingGym(system)


@patch('chromadb.quota.test_provider.QuotaProviderForTest.get_for_subject', mock_get_for_subject)
@patch('chromadb.rate_limiting.test_provider.RateLimitingTestProvider.is_allowed', lambda self, key, quota, point=1: False)
def test_rate_limiting_should_raise(rate_limiting_gym: RateLimitingGym):
    """A call rejected by the rate limiter must raise ``RateLimitError``."""
    # Imported locally so this test does not depend on the module's import block.
    from chromadb.rate_limiting import RateLimitError

    # Expect the specific error type: a bare ``Exception`` would also pass on
    # unrelated failures such as an AttributeError inside the decorator.
    with pytest.raises(RateLimitError) as exc_info:
        rate_limiting_gym.bench("foo", "bar")
    assert "FAKE_RESOURCE" in str(exc_info.value.resource)
    # The patched quota lookup always returns 10.
    assert exc_info.value.quota == 10

@patch('chromadb.quota.test_provider.QuotaProviderForTest.get_for_subject', mock_get_for_subject)
@patch('chromadb.rate_limiting.test_provider.RateLimitingTestProvider.is_allowed', lambda self, key, quota, point=1: True)
def test_rate_limiting_should_not_raise_with_keyword_args(rate_limiting_gym: RateLimitingGym):
    """An allowed call passes through when the subject is given by keyword.

    Named distinctly from the positional variant so that pytest collects
    both tests instead of the later definition replacing this one.
    """
    # Compare by value; identity against a string literal is not guaranteed.
    assert rate_limiting_gym.bench(foo="foo", bar="bar") == "foo"

@patch('chromadb.quota.test_provider.QuotaProviderForTest.get_for_subject', mock_get_for_subject)
@patch('chromadb.rate_limiting.test_provider.RateLimitingTestProvider.is_allowed', lambda self, key, quota, point=1: True)
def test_rate_limiting_should_not_raise(rate_limiting_gym: RateLimitingGym):
    """An allowed call passes through when the subject is given positionally."""
    # Compare by value; identity against a string literal is not guaranteed.
    assert rate_limiting_gym.bench("foo", "bar") == "foo"

0 comments on commit 5e9c7a7

Please sign in to comment.