
langchain[minor], community[minor], core[minor]: Async Cache support and AsyncRedisCache #15817

Merged
merged 25 commits · Feb 8, 2024

Changes from 11 commits
227 changes: 183 additions & 44 deletions libs/community/langchain_community/cache.py
@@ -21,12 +21,14 @@
"""
from __future__ import annotations

import asyncio
import hashlib
import inspect
import json
import logging
import uuid
import warnings
from abc import ABC
from datetime import timedelta
from functools import lru_cache
from typing import (
@@ -349,21 +351,69 @@ def clear(self, **kwargs: Any) -> None:
self.redis.flushdb(flush_type=asynchronous)


class RedisCacheBase(BaseCache, ABC):
Collaborator
Could we mark this as private?

Contributor Author
Sure, can you show me an example? Something like `__RedisCacheBase`?

Contributor Author
I looked at other examples in the codebase, and the pattern seems to be a single underscore, so I pushed the `_RedisCacheBase` rename.

@staticmethod
def _key(prompt: str, llm_string: str) -> str:
"""Compute key from prompt and llm_string"""
return _hash(prompt + llm_string)

@staticmethod
def _ensure_generation_type(return_val: RETURN_VAL_TYPE) -> None:
for gen in return_val:
if not isinstance(gen, Generation):
raise ValueError(
"RedisCache only supports caching of normal LLM generations, "
f"got {type(gen)}"
)

@staticmethod
def _get_generations(
    results: dict[str | bytes, str | bytes],
) -> Optional[list[Generation]]:
generations = []
if results:
for _, text in results.items():
try:
generations.append(loads(text))
except Exception:
logger.warning(
"Retrieving a cache value that could not be deserialized "
"properly. This is likely due to the cache being in an "
"older format. Please recreate your cache to avoid this "
"error."
)
# In a previous life we stored the raw text directly
# in the table, so assume it's in that format.
generations.append(Generation(text=text))
return generations if generations else None

@staticmethod
def _configure_pipeline_for_update(
    key: str, pipe: Any, return_val: RETURN_VAL_TYPE, ttl: Optional[int] = None
) -> None:
pipe.hset(
key,
mapping={
str(idx): dumps(generation) for idx, generation in enumerate(return_val)
},
)
if ttl is not None:
pipe.expire(key, ttl)


class RedisCache(RedisCacheBase):
Collaborator
An alternative design is to make the initializer able to accept both the async and sync versions of the redis client.

If it's initialized with the async redis client, it uses the async methods. If it's initialized with the sync client, it delegates all the async calls to the sync ones (using the trick in the abstract class).

What do you think would be better? A single RedisCache for users to know about, with the differentiation in the underlying redis client that's passed to initialize the cache?

Contributor Author
@eyurtsev that was actually my original implementation: a single cache class able to work with both sync and async clients. However, @cbornet explicitly asked to change it and break it out into two separate Redis caches. I can see pros and cons in either approach.

I think at this point, having two different opinions that both result in a fine working implementation means we can go either way. The maintenance burden of such a split seems negligible.

Contributor Author
I think even if the split turns out to be a mistake, we should be able to work around it in the future by consolidating features in a single class and then using a type alias for AsyncRedisCache to keep backwards compatibility with the existing codebase.

But I think it's just fine as it is.

Collaborator
Sorry, I didn't notice this comment -- should've responded earlier!

I touched base with @baskaryan; I think we'll try to go with a convention where the sync and async implementations live as close as possible to one another -- the thought is to use that as the design pattern unless there's a compelling reason to do otherwise.

The reasons for the convention itself:

  1. Make sure that an async implementation is always provided and doesn't drift from the sync implementation
  2. Reduce the number of new classes that users have to know about
  3. Fewer objects to document, and we can expand the documentation for each

The downside to this approach is that the user needs to know that they can pass an async redis client, which is more implicit than the approach with ARedisCache.

cc @cbornet

Apologies, I should've checked in before modifying the PR.

Let me know if you're OK with this change or not. If so, I can merge. If there's still a strong feeling that we should have two implementations, please let me know why and I can revert the changes.

Collaborator
OK, I get the arguments for the single class. Sorry for the pull in the wrong direction.
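For illustration, a minimal sketch of the single-class design discussed above; `UnifiedRedisCache` and its structure are hypothetical (not this PR's code), and `_hash` is the module's existing helper:

import asyncio
from typing import Any, Optional


class UnifiedRedisCache:
    """Hypothetical cache accepting either a sync or an async Redis client."""

    def __init__(self, redis_: Any, *, ttl: Optional[int] = None):
        from redis.asyncio import Redis as AsyncRedis

        self.redis = redis_
        self.ttl = ttl
        # Route async calls natively only if an async client was passed.
        self._is_async = isinstance(redis_, AsyncRedis)

    async def alookup(self, prompt: str, llm_string: str) -> Any:
        key = _hash(prompt + llm_string)
        if self._is_async:
            return await self.redis.hgetall(key)
        # Sync client: delegate to a thread so the event loop is not blocked.
        return await asyncio.get_running_loop().run_in_executor(
            None, self.redis.hgetall, key
        )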

"""
Cache that uses Redis as a backend. Accepts a sync `redis.Redis` client.
"""

def __init__(self, redis_: Any, *, ttl: Optional[int] = None):
"""
Initialize an instance of RedisCache.

This method initializes an object with Redis caching capabilities.
It takes a `redis_` parameter, which should be an instance of a Redis
client class (`redis.Redis`), allowing the object
to interact with a Redis server for caching purposes.

Parameters:
redis_ (Any): An instance of a Redis client class
(`redis.Redis`) to be used for caching.
This allows the object to communicate with a
Redis server for caching operations.
ttl (int, optional): Time-to-live (TTL) for cached items in seconds.
@@ -375,68 +425,157 @@ def __init__(self, redis_: Any, *, ttl: Optional[int] = None):
from redis import Redis
except ImportError:
raise ValueError(
"Could not import redis python package. "
"Could not import `redis` python package. "
"Please install it with `pip install redis`."
)
if not isinstance(redis_, Redis):
raise ValueError("Please pass in Redis object.")
raise ValueError("Please pass a valid `redis.Redis` client.")
self.redis = redis_
self.ttl = ttl


def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string."""
# Read from a Redis HASH
results = self.redis.hgetall(self._key(prompt, llm_string))
return self._get_generations(results)

async def alookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
logger.warning("Consider using `AsyncRedisCache` for async cache operations.")
Collaborator
I'm not sure it's a good thing to output these logs. It will flood the logs.

Contributor Author
Agree.
return await super().alookup(prompt, llm_string)

def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
"""Update cache based on prompt and llm_string."""
self._ensure_generation_type(return_val)
key = self._key(prompt, llm_string)

with self.redis.pipeline() as pipe:
self._configure_pipeline_for_update(key, pipe, return_val, self.ttl)
pipe.execute()

async def aupdate(
self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE
) -> None:
logger.warning("Consider using `AsyncRedisCache` for async cache operations.")
return await super().aupdate(prompt, llm_string, return_val)

def clear(self, **kwargs: Any) -> None:
"""Clear cache. If `asynchronous` is True, flush asynchronously."""
asynchronous = kwargs.get("asynchronous", False)
self.redis.flushdb(asynchronous=asynchronous, **kwargs)

async def aclear(self, **kwargs: Any) -> None:
Collaborator
Can be removed.

Contributor Author
Done.

logger.warning("Consider using `AsyncRedisCache` for async cache operations.")
return await super().aclear(**kwargs)
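A short usage sketch for the sync cache, assuming a Redis server on localhost and `pip install redis`; the one-hour TTL is illustrative:

from redis import Redis

from langchain_community.cache import RedisCache

# Entries expire after an hour; omit `ttl` to keep them indefinitely.
cache = RedisCache(redis_=Redis(host="localhost", port=6379), ttl=3600)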


class AsyncRedisCache(RedisCacheBase):
"""
Cache that uses Redis as a backend. Accepts an
async `redis.asyncio.Redis` client.
"""

def __init__(self, redis_: Any, *, ttl: Optional[int] = None):
"""
Initialize an instance of AsyncRedisCache.

This method initializes an object with Redis caching capabilities.
It takes a `redis_` parameter, which should be an instance of a Redis
client class (`redis.asyncio.Redis`), allowing the object
to interact with a Redis server for caching purposes.

Parameters:
redis_ (Any): An instance of a Redis client class
(`redis.asyncio.Redis`) to be used for caching.
This allows the object to communicate with a
Redis server for caching operations.
ttl (int, optional): Time-to-live (TTL) for cached items in seconds.
If provided, it sets the time duration for how long cached
items will remain valid. If not provided, cached items will not
have an automatic expiration.
"""
try:
from redis.asyncio import Redis
except ImportError:
raise ValueError(
"Could not import `redis.asyncio` python package. "
"Please install it with `pip install redis`."
)
if not isinstance(redis_, Redis):
raise ValueError("Please pass a valid `redis.asyncio.Redis` client.")
self.redis = redis_
self.ttl = ttl

def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string."""
logger.warning(
"This an async Redis cache. Did you mean to use `alookup()` method?"
)
try:
if asyncio.get_running_loop():
Collaborator
I think this is too complex and brittle. I would just throw NotImplementedError and nothing else.

Contributor Author
Agree.

# There is no nice way to run async code from sync function if there is
# an already existing event loop. Error out as the only option.
raise NotImplementedError(
"Cannot use sync `lookup()` in async context. "
"Consider using `alookup()`."
)
except RuntimeError:
# At this point, somebody is trying to run the async Redis cache in a
# non-async environment with no event loop.
# Weird, but it's technically possible.
return asyncio.run(self.alookup(prompt, llm_string))
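Following the review note above, the later simplification would drop the event-loop probing and raise unconditionally; a sketch of how that method might look on AsyncRedisCache (not the code at this commit):

def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
    """Sync lookup is not supported on the async cache."""
    raise NotImplementedError(
        "This async Redis cache does not support sync `lookup()`. "
        "Use `alookup()` instead."
    )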

async def alookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string. Async version."""
results = await self.redis.hgetall(self._key(prompt, llm_string))
return self._get_generations(results)

def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
"""Update cache based on prompt and llm_string."""
logger.warning(
"This an async Redis cache. Did you mean to use `aupdate()` method?"
)
try:
if asyncio.get_running_loop():
raise NotImplementedError(
"Cannot use sync `update()` in async context. "
"Consider using `aupdate()`."
)
except RuntimeError:
return asyncio.run(self.aupdate(prompt, llm_string, return_val))

async def aupdate(
self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE
) -> None:
"""Update cache based on prompt and llm_string. Async version."""
self._ensure_generation_type(return_val)
key = self._key(prompt, llm_string)

async with self.redis.pipeline() as pipe:
self._configure_pipeline_for_update(key, pipe, return_val, self.ttl)
await pipe.execute()

def clear(self, **kwargs: Any) -> None:
"""Clear cache. If `asynchronous` is True, flush asynchronously."""
logger.warning(
"This an async Redis cache. Did you mean to use `aclear()` method?"
)
try:
if asyncio.get_running_loop():
raise NotImplementedError(
"Cannot use sync `clear()` in async context. "
"Consider using `aclear()`."
)
except RuntimeError:
return asyncio.run(self.aclear(**kwargs))

async def aclear(self, **kwargs: Any) -> None:
"""
Clear cache. If `asynchronous` is True, flush asynchronously.
Async version.
"""
asynchronous = kwargs.get("asynchronous", False)
await self.redis.flushdb(asynchronous=asynchronous, **kwargs)
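A minimal end-to-end sketch of the async cache, assuming a local Redis server; the prompt and llm_string values are illustrative:

import asyncio

from langchain_core.outputs import Generation
from redis.asyncio import Redis

from langchain_community.cache import AsyncRedisCache


async def main() -> None:
    cache = AsyncRedisCache(redis_=Redis(host="localhost", port=6379), ttl=60)
    await cache.aupdate("my prompt", "llm config", [Generation(text="hello")])
    cached = await cache.alookup("my prompt", "llm config")
    print(cached)  # [Generation(text="hello")]
    await cache.aclear()


asyncio.run(main())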


class RedisSemanticCache(BaseCache):
"""Cache that uses Redis as a vector-store backend."""
15 changes: 15 additions & 0 deletions libs/core/langchain_core/caches.py
@@ -4,6 +4,7 @@
from typing import Any, Optional, Sequence

from langchain_core.outputs import Generation
from langchain_core.runnables import run_in_executor

RETURN_VAL_TYPE = Sequence[Generation]

@@ -22,3 +23,17 @@ def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
@abstractmethod
def clear(self, **kwargs: Any) -> None:
"""Clear cache that can take additional keyword arguments."""

async def alookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
Collaborator
An alternative would be to add these methods to the BaseCache class with a default implementation that calls the non-async method with run_in_executor, like what is done in BaseLLM and Embeddings. Otherwise, there is no reason to prefix the methods here with the `a` prefix.

Collaborator
@dzmitry-kankalovich We generally prefer keeping the async versions of methods in the same class as the sync versions, and having them provide a default async implementation that uses run_in_executor, as @cbornet suggested.

Would you mind merging into the existing abstractions? Ideally a single PR that just modifies the core interface first; then we can separately do a PR for any implementations.

Contributor Author
Done.

Contributor Author
As for the smaller PRs: yes, I can move slower and split the work into several PRs, if you are happy with the current direction.

Collaborator
Generally it's better to split PRs by package, to minimize potential dependency conflicts since the packages may have different release schedules.

"""Look up based on prompt and llm_string."""
return await run_in_executor(None, self.lookup, prompt, llm_string)

async def aupdate(
self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE
) -> None:
"""Update cache based on prompt and llm_string."""
return await run_in_executor(None, self.update, prompt, llm_string, return_val)

async def aclear(self, **kwargs: Any) -> None:
"""Clear cache that can take additional keyword arguments."""
return await run_in_executor(None, self.clear, **kwargs)
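With these defaults in place, any sync-only cache implementation gets working async methods for free; a minimal sketch (the in-memory `DictCache` is illustrative, not part of the PR):

from typing import Any, Dict, Optional, Tuple

from langchain_core.caches import RETURN_VAL_TYPE, BaseCache


class DictCache(BaseCache):
    """In-memory cache; `alookup`/`aupdate`/`aclear` come from BaseCache."""

    def __init__(self) -> None:
        self._store: Dict[Tuple[str, str], RETURN_VAL_TYPE] = {}

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        return self._store.get((prompt, llm_string))

    def update(
        self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE
    ) -> None:
        self._store[(prompt, llm_string)] = return_val

    def clear(self, **kwargs: Any) -> None:
        self._store.clear()


# `await DictCache().alookup(...)` now runs `lookup` in an executor.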
4 changes: 2 additions & 2 deletions libs/core/langchain_core/language_models/chat_models.py
@@ -621,7 +621,7 @@ async def _agenerate_with_cache(
else:
llm_string = self._get_llm_string(stop=stop, **kwargs)
prompt = dumps(messages)
cache_val = await llm_cache.alookup(prompt, llm_string)
if isinstance(cache_val, list):
return ChatResult(generations=cache_val)
else:
Expand All @@ -631,7 +631,7 @@ async def _agenerate_with_cache(
)
else:
result = await self._agenerate(messages, stop=stop, **kwargs)
await llm_cache.aupdate(prompt, llm_string, result.generations)
return result

@abstractmethod
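Taken together, a usage sketch of the async generation path, assuming a local Redis server and that `langchain.globals.set_llm_cache` is available; `FakeListChatModel` stands in for a real provider:

import asyncio

from langchain.globals import set_llm_cache
from langchain_community.cache import AsyncRedisCache
from langchain_community.chat_models.fake import FakeListChatModel
from redis.asyncio import Redis

set_llm_cache(AsyncRedisCache(redis_=Redis(host="localhost", port=6379)))


async def main() -> None:
    model = FakeListChatModel(responses=["hello"])
    # First call generates and caches via `aupdate`; the second is served
    # from Redis via `alookup`, without blocking the event loop.
    print(await model.ainvoke("hi"))
    print(await model.ainvoke("hi"))


asyncio.run(main())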