langchain[minor], community[minor], core[minor]: Async Cache support and AsyncRedisCache #15817

Merged · 25 commits · Feb 8, 2024
Changes from 8 commits
110 changes: 82 additions & 28 deletions libs/community/langchain_community/cache.py
@@ -350,20 +350,24 @@ def clear(self, **kwargs: Any) -> None:


class RedisCache(BaseCache):
    """
    Cache that uses Redis as a backend. Accepts either a sync or an async
    Redis client; depending on the client passed, you are expected to use
    the corresponding sync or async methods.
    """

    def __init__(self, redis_: Any, *, ttl: Optional[int] = None):
Collaborator:

Maybe for this you could have a RedisCache that uses a redis.Redis client and an AsyncRedisCache that uses a redis.asyncio.Redis? (I know this is the opposite of the move made for BaseCache 😄)

Contributor Author:

I actually did it this way in the very first variant, but I didn't like that it was literally a 99% copy of the original RedisCache. The remaining 1% was the async/await keywords and the client check in the constructor. Feels like a potential future maintenance problem?

Contributor Author:

Also, I am not sure what I am supposed to do about the AsyncRedisCache sync methods lookup/update/clear. I could raise an error on those, or do the opposite trick of running async code in sync execution, but doesn't that feel like too much duplication of effort?

Collaborator:

You could put this common behavior in a parent abstract class?

Contributor Author:

Alright, I can do it. But before that, can you tell me what your rationale is here? Just so I know for the future :)

Collaborator:

- RedisBaseCache is abstract and holds the shared methods _key, __ensure_generation_type, __get_generations, and __configure_pipeline_for_update
- RedisCache extends RedisBaseCache and implements lookup, update, and clear using redis.Redis
- AsyncRedisCache extends RedisBaseCache and implements alookup, aupdate, and aclear using redis.asyncio.Redis
- AsyncRedisCache raises NotImplementedError() for lookup, update, and clear
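
For readers following along, here is a rough sketch of the hierarchy being proposed in the comment above. It is illustrative only: the names come from the comment, the bodies are elided, and it is not necessarily the code that was ultimately merged.

from typing import Any, Optional

from langchain_core.caches import RETURN_VAL_TYPE, BaseCache


class RedisBaseCache(BaseCache):
    """Shared, client-agnostic pieces: key hashing, (de)serialization, pipeline setup."""

    def __init__(self, redis_: Any, *, ttl: Optional[int] = None):
        self.redis = redis_
        self.ttl = ttl

    def _key(self, prompt: str, llm_string: str) -> str:
        ...  # hash prompt + llm_string, as in the existing RedisCache


class RedisCache(RedisBaseCache):
    """Sync variant, backed by redis.Redis."""

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        results = self.redis.hgetall(self._key(prompt, llm_string))
        ...  # decode the hash entries back into Generation objects

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        ...  # write generations to a Redis HASH via a pipeline

    def clear(self, **kwargs: Any) -> None:
        ...


class AsyncRedisCache(RedisBaseCache):
    """Async variant, backed by redis.asyncio.Redis; sync methods are rejected."""

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        raise NotImplementedError("Use alookup() with an async Redis client.")

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        raise NotImplementedError("Use aupdate() with an async Redis client.")

    def clear(self, **kwargs: Any) -> None:
        raise NotImplementedError("Use aclear() with an async Redis client.")

    async def alookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        results = await self.redis.hgetall(self._key(prompt, llm_string))
        ...  # same decoding helper as the sync class, shared via RedisBaseCache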

Contributor Author:

@cbornet I mean I was asking what the motivation for such a split is.

I've pushed an implementation which is basically what you said; however, I went a bit further with sync-in-async and actually made it so you can run it in a sync context (a weird case, though; I log a warning for that). I can rework it to raise NotImplementedError if you think that's cleaner.

Collaborator:

> or do the opposite trick for async in sync execution

At the moment, we generally try to avoid that. Only one event loop is allowed in a given thread, so there's some additional logic required to determine whether an event loop is already running and, if so, kick off another thread, etc.
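
To make the point above concrete, here is a minimal sketch (not LangChain code) of the kind of branching that calling async code from sync code requires once an event loop may already be running:

import asyncio
from concurrent.futures import ThreadPoolExecutor


def run_coro_from_sync(coro):
    """Run a coroutine from synchronous code, whether or not a loop is already running."""
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No event loop in this thread: we can simply start one.
        return asyncio.run(coro)
    # A loop is already running in this thread (e.g. we were called from async code),
    # so we cannot block on it; run the coroutine on a fresh loop in another thread.
    with ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, coro).result()

This is the extra machinery the PR avoids by keeping the sync and async paths separate.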

"""
Initialize an instance of RedisCache.

This method initializes an object with Redis caching capabilities.
It takes a `redis_` parameter, which should be an instance of a Redis
client class, allowing the object to interact with a Redis
server for caching purposes.
client class (`redis.Redis` or `redis.asyncio.Redis`), allowing the object
to interact with a Redis server for caching purposes.

Parameters:
redis_ (Any): An instance of a Redis client class
(e.g., redis.Redis) used for caching.
(`redis.Redis` or `redis.asyncio.Redis`) to be used for caching.
This allows the object to communicate with a
Redis server for caching operations.
ttl (int, optional): Time-to-live (TTL) for cached items in seconds.
@@ -372,26 +376,42 @@ def __init__(self, redis_: Any, *, ttl: Optional[int] = None):
                have an automatic expiration.
        """
        try:
            from redis import Redis as SyncRedis
            from redis.asyncio import Redis as AsyncRedis
        except ImportError:
            raise ValueError(
                "Could not import redis python package. "
                "Please install it with `pip install redis`."
            )
        if isinstance(redis_, SyncRedis):
            self._async = False
        elif isinstance(redis_, AsyncRedis):
            self._async = True
        else:
            raise ValueError("Please pass a valid Redis client.")
        self.redis = redis_
        self.ttl = ttl

    def _key(self, prompt: str, llm_string: str) -> str:
        """Compute key from prompt and llm_string"""
        return _hash(prompt + llm_string)

    def __ensure_sync(self, function_name: str = "function"):
        if self._async:
            raise ValueError(f"Cannot use sync {function_name} with async Redis client")

    @staticmethod
    def __ensure_generation_type(return_val: RETURN_VAL_TYPE):
        for gen in return_val:
            if not isinstance(gen, Generation):
                raise ValueError(
                    "RedisCache only supports caching of normal LLM generations, "
                    f"got {type(gen)}"
                )

    @staticmethod
    def __get_generations(results: dict[str | bytes, str | bytes]) -> list[Generation]:
        generations = []
        if results:
            for _, text in results.items():
                try:
@@ -408,35 +428,69 @@ def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
                    generations.append(Generation(text=text))
        return generations if generations else None

    def __configure_pipeline_for_update(self, key, pipe, return_val):
        pipe.hset(
            key,
            mapping={
                str(idx): dumps(generation) for idx, generation in enumerate(return_val)
            },
        )
        if self.ttl is not None:
            pipe.expire(key, self.ttl)

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string."""
        self.__ensure_sync("lookup")
        # Read from a Redis HASH
        results = self.redis.hgetall(self._key(prompt, llm_string))
        return self.__get_generations(results)

    async def alookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        """Look up based on prompt and llm_string. Async version."""
        if not self._async:
            return await super().alookup(prompt, llm_string)
        results = await self.redis.hgetall(self._key(prompt, llm_string))
        return self.__get_generations(results)

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        """Update cache based on prompt and llm_string."""
        self.__ensure_sync("update")
        self.__ensure_generation_type(return_val)
        key = self._key(prompt, llm_string)

        # Write to a Redis HASH
        with self.redis.pipeline() as pipe:
            self.__configure_pipeline_for_update(key, pipe, return_val)
            pipe.execute()

    async def aupdate(
        self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE
    ) -> None:
        """Update cache based on prompt and llm_string. Async version."""
        if not self._async:
            return await super().aupdate(prompt, llm_string, return_val)
        self.__ensure_generation_type(return_val)
        key = self._key(prompt, llm_string)

        async with self.redis.pipeline() as pipe:
            self.__configure_pipeline_for_update(key, pipe, return_val)
            await pipe.execute()

    def clear(self, **kwargs: Any) -> None:
        """Clear cache. If `asynchronous` is True, flush asynchronously."""
        self.__ensure_sync("clear")
        asynchronous = kwargs.get("asynchronous", False)
        self.redis.flushdb(asynchronous=asynchronous, **kwargs)

    async def aclear(self, **kwargs: Any) -> None:
        """
        Clear cache. If `asynchronous` is True, flush asynchronously.
        Async version.
        """
        if not self._async:
            return await super().aclear(**kwargs)
        asynchronous = kwargs.get("asynchronous", False)
        await self.redis.flushdb(asynchronous=asynchronous, **kwargs)


class RedisSemanticCache(BaseCache):
"""Cache that uses Redis as a vector-store backend."""
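For orientation before the next file, here is a short usage sketch of the updated RedisCache. It assumes a Redis server reachable on localhost and that this PR's changes are installed.

import asyncio

from redis import Redis
from redis.asyncio import Redis as AsyncRedis

from langchain_community.cache import RedisCache

# Sync client: use the plain sync methods.
sync_cache = RedisCache(Redis(host="localhost", port=6379), ttl=3600)
sync_cache.lookup("some prompt", "some llm string")

# Async client: the same class, but the a-prefixed methods must be used.
async_cache = RedisCache(AsyncRedis(host="localhost", port=6379), ttl=3600)


async def main() -> None:
    cached = await async_cache.alookup("some prompt", "some llm string")
    print(cached)


asyncio.run(main())
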
15 changes: 15 additions & 0 deletions libs/core/langchain_core/caches.py
@@ -4,6 +4,7 @@
from typing import Any, Optional, Sequence

from langchain_core.outputs import Generation
from langchain_core.runnables import run_in_executor

RETURN_VAL_TYPE = Sequence[Generation]

@@ -22,3 +23,17 @@ def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
    @abstractmethod
    def clear(self, **kwargs: Any) -> None:
        """Clear cache that can take additional keyword arguments."""

    async def alookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
Collaborator:

An alternative would be to add these methods to the BaseCache class with a default implementation that calls the non-async method with run_in_executor, like what is done in BaseLLM and Embeddings.
Otherwise, there is no reason to prefix the methods here with the `a` prefix.

Collaborator:

@dzmitry-kankalovich We generally prefer keeping async versions of methods in the same class as the sync versions, and having them provide a default async implementation that uses run_in_executor, like @cbornet suggested.

Would you mind merging into the existing abstractions? Ideally a single PR that just modifies the core interface first, and then separately we can do a PR for any implementations.

Contributor Author:

Done

Contributor Author:

As for the smaller PRs: yes, I can move slower and split this up into several PRs, if you are happy with the current direction.

Collaborator:

Generally, it's better to split PRs by package to minimize potential dependency conflicts, since the packages may have different release schedules.

"""Look up based on prompt and llm_string."""
return await run_in_executor(None, self.lookup, prompt, llm_string)

    async def aupdate(
        self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE
    ) -> None:
        """Update cache based on prompt and llm_string."""
        return await run_in_executor(None, self.update, prompt, llm_string, return_val)

    async def aclear(self, **kwargs: Any) -> None:
        """Clear cache that can take additional keyword arguments."""
        return await run_in_executor(None, self.clear, **kwargs)
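
To illustrate what these defaults buy implementers: a hypothetical cache that only defines the sync methods still gets working async methods for free, since the inherited defaults delegate to the sync code on an executor thread. The class below is an editorial sketch, not part of the PR.

from typing import Any, Optional

from langchain_core.caches import RETURN_VAL_TYPE, BaseCache


class DictCache(BaseCache):
    """Toy in-memory cache: only the sync interface is implemented."""

    def __init__(self) -> None:
        self._store: dict = {}

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        return self._store.get((prompt, llm_string))

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        self._store[(prompt, llm_string)] = return_val

    def clear(self, **kwargs: Any) -> None:
        self._store.clear()


# Thanks to the new BaseCache defaults, `await DictCache().alookup(...)` already
# works without any async-specific code in the subclass.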
4 changes: 2 additions & 2 deletions libs/core/langchain_core/language_models/chat_models.py
@@ -621,7 +621,7 @@ async def _agenerate_with_cache(
        else:
            llm_string = self._get_llm_string(stop=stop, **kwargs)
            prompt = dumps(messages)
            cache_val = await llm_cache.alookup(prompt, llm_string)
            if isinstance(cache_val, list):
                return ChatResult(generations=cache_val)
            else:
@@ -631,7 +631,7 @@
                )
        else:
            result = await self._agenerate(messages, stop=stop, **kwargs)
            await llm_cache.aupdate(prompt, llm_string, result.generations)
        return result

    @abstractmethod
42 changes: 40 additions & 2 deletions libs/core/langchain_core/language_models/llms.py
@@ -134,6 +134,26 @@ def get_prompts(
    return existing_prompts, llm_string, missing_prompt_idxs, missing_prompts


async def aget_prompts(
    params: Dict[str, Any], prompts: List[str]
) -> Tuple[Dict[int, List], str, List[int], List[str]]:
    """Get prompts that are already cached. Async version."""
    llm_string = str(sorted([(k, v) for k, v in params.items()]))
    missing_prompts = []
    missing_prompt_idxs = []
    existing_prompts = {}
    llm_cache = get_llm_cache()
    for i, prompt in enumerate(prompts):
        if llm_cache:
            cache_val = await llm_cache.alookup(prompt, llm_string)
            if isinstance(cache_val, list):
                existing_prompts[i] = cache_val
            else:
                missing_prompts.append(prompt)
                missing_prompt_idxs.append(i)
    return existing_prompts, llm_string, missing_prompt_idxs, missing_prompts


def update_cache(
    existing_prompts: Dict[int, List],
    llm_string: str,
@@ -152,6 +172,24 @@ def update_cache(
    return llm_output


async def aupdate_cache(
    existing_prompts: Dict[int, List],
    llm_string: str,
    missing_prompt_idxs: List[int],
    new_results: LLMResult,
    prompts: List[str],
) -> Optional[dict]:
    """Update the cache and get the LLM output. Async version."""
    llm_cache = get_llm_cache()
    for i, result in enumerate(new_results.generations):
        existing_prompts[missing_prompt_idxs[i]] = result
        prompt = prompts[missing_prompt_idxs[i]]
        if llm_cache:
            await llm_cache.aupdate(prompt, llm_string, result)
    llm_output = new_results.llm_output
    return llm_output


class BaseLLM(BaseLanguageModel[str], ABC):
"""Base LLM abstract interface.

@@ -864,7 +902,7 @@ async def agenerate(
            llm_string,
            missing_prompt_idxs,
            missing_prompts,
        ) = await aget_prompts(params, prompts)
        disregard_cache = self.cache is not None and not self.cache
        new_arg_supported = inspect.signature(self._agenerate).parameters.get(
            "run_manager"
@@ -912,7 +950,7 @@ async def agenerate(
            new_results = await self._agenerate_helper(
                missing_prompts, stop, run_managers, bool(new_arg_supported), **kwargs
            )
            llm_output = await aupdate_cache(
                existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
            )
            run_info = (