langchain[minor], community[minor], core[minor]: Async Cache support and AsyncRedisCache #15817

Merged (25 commits) on Feb 8, 2024
17 changes: 17 additions & 0 deletions docker/docker-compose.yml
@@ -0,0 +1,17 @@
# docker-compose to make it easier to spin up integration tests.
Collaborator

@dzmitry-kankalovich I added a docker compose file under /docker. We might move it at some point, but the thinking is to make it easier for developers to spin up the services that integration tests depend on.

Contributor Author

yeah that would be super useful!

# Services should use NON standard ports to avoid collision with
version: "3"
name: langchain-tests

services:
redis:
image: redis/redis-stack-server:latest
# We use non standard ports since
# these instances are used for testing
# and users may already have existing
# redis instances set up locally
# for other projects
ports:
- "6020:6379"
volumes:
- ./redis-volume:/data
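
A minimal sketch (not part of this PR) of an integration test that would rely on the Redis service above; it assumes `pip install redis` and that the compose stack has been started from the /docker directory. The test name is hypothetical.

# Hypothetical smoke test: connects to the non-standard host port (6020)
# mapped in docker-compose.yml and checks that the server responds.
import redis


def test_redis_stack_is_reachable() -> None:
    client = redis.Redis(host="localhost", port=6020)
    assert client.ping()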
190 changes: 146 additions & 44 deletions libs/community/langchain_community/cache.py
@@ -27,6 +27,7 @@
import logging
import uuid
import warnings
from abc import ABC
from datetime import timedelta
from functools import lru_cache
from typing import (
@@ -351,21 +352,73 @@ def clear(self, **kwargs: Any) -> None:
self.redis.flushdb(flush_type=asynchronous)


class RedisCache(BaseCache):
"""Cache that uses Redis as a backend."""
class _RedisCacheBase(BaseCache, ABC):
@staticmethod
def _key(prompt: str, llm_string: str) -> str:
"""Compute key from prompt and llm_string"""
return _hash(prompt + llm_string)

@staticmethod
def _ensure_generation_type(return_val: RETURN_VAL_TYPE) -> None:
for gen in return_val:
if not isinstance(gen, Generation):
raise ValueError(
"RedisCache only supports caching of normal LLM generations, "
f"got {type(gen)}"
)

@staticmethod
def _get_generations(
results: dict[str | bytes, str | bytes],
) -> Optional[List[Generation]]:
generations = []
if results:
for _, text in results.items():
try:
generations.append(loads(text))
except Exception:
logger.warning(
"Retrieving a cache value that could not be deserialized "
"properly. This is likely due to the cache being in an "
"older format. Please recreate your cache to avoid this "
"error."
)
# In a previous life we stored the raw text directly
# in the table, so assume it's in that format.
generations.append(Generation(text=text))
return generations if generations else None

@staticmethod
def _configure_pipeline_for_update(
key: str, pipe: Any, return_val: RETURN_VAL_TYPE, ttl: Optional[int] = None
) -> None:
pipe.hset(
key,
mapping={
str(idx): dumps(generation) for idx, generation in enumerate(return_val)
},
)
if ttl is not None:
pipe.expire(key, ttl)


class RedisCache(_RedisCacheBase):
"""
Cache that uses Redis as a backend. Allows use of a sync `redis.Redis` client.
"""

def __init__(self, redis_: Any, *, ttl: Optional[int] = None):
"""
Initialize an instance of RedisCache.

This method initializes an object with Redis caching capabilities.
It takes a `redis_` parameter, which should be an instance of a Redis
client class, allowing the object to interact with a Redis
server for caching purposes.
client class (`redis.Redis`), allowing the object
to interact with a Redis server for caching purposes.

Parameters:
redis_ (Any): An instance of a Redis client class
(e.g., redis.Redis) used for caching.
(`redis.Redis`) to be used for caching.
This allows the object to communicate with a
Redis server for caching operations.
ttl (int, optional): Time-to-live (TTL) for cached items in seconds.
@@ -377,61 +430,27 @@ def __init__(self, redis_: Any, *, ttl: Optional[int] = None):
from redis import Redis
except ImportError:
raise ValueError(
"Could not import redis python package. "
"Could not import `redis` python package. "
"Please install it with `pip install redis`."
)
if not isinstance(redis_, Redis):
raise ValueError("Please pass in Redis object.")
raise ValueError("Please pass a valid `redis.Redis` client.")
self.redis = redis_
self.ttl = ttl

def _key(self, prompt: str, llm_string: str) -> str:
"""Compute key from prompt and llm_string"""
return _hash(prompt + llm_string)

def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string."""
generations = []
# Read from a Redis HASH
results = self.redis.hgetall(self._key(prompt, llm_string))
if results:
for _, text in results.items():
try:
generations.append(loads(text))
except Exception:
logger.warning(
"Retrieving a cache value that could not be deserialized "
"properly. This is likely due to the cache being in an "
"older format. Please recreate your cache to avoid this "
"error."
)
# In a previous life we stored the raw text directly
# in the table, so assume it's in that format.
generations.append(Generation(text=text))
return generations if generations else None
return self._get_generations(results) # type: ignore[arg-type]

def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
"""Update cache based on prompt and llm_string."""
for gen in return_val:
if not isinstance(gen, Generation):
raise ValueError(
"RedisCache only supports caching of normal LLM generations, "
f"got {type(gen)}"
)
# Write to a Redis HASH
self._ensure_generation_type(return_val)
key = self._key(prompt, llm_string)

with self.redis.pipeline() as pipe:
pipe.hset(
key,
mapping={
str(idx): dumps(generation)
for idx, generation in enumerate(return_val)
},
)
if self.ttl is not None:
pipe.expire(key, self.ttl)

self._configure_pipeline_for_update(key, pipe, return_val, self.ttl)
pipe.execute()

def clear(self, **kwargs: Any) -> None:
@@ -440,6 +459,89 @@ def clear(self, **kwargs: Any) -> None:
self.redis.flushdb(asynchronous=asynchronous, **kwargs)


class AsyncRedisCache(_RedisCacheBase):
"""
Cache that uses Redis as a backend. Allows use of an
async `redis.asyncio.Redis` client.
"""

def __init__(self, redis_: Any, *, ttl: Optional[int] = None):
"""
Initialize an instance of AsyncRedisCache.

This method initializes an object with Redis caching capabilities.
It takes a `redis_` parameter, which should be an instance of a Redis
client class (`redis.asyncio.Redis`), allowing the object
to interact with a Redis server for caching purposes.

Parameters:
redis_ (Any): An instance of a Redis client class
(`redis.asyncio.Redis`) to be used for caching.
This allows the object to communicate with a
Redis server for caching operations.
ttl (int, optional): Time-to-live (TTL) for cached items in seconds.
If provided, it sets the time duration for how long cached
items will remain valid. If not provided, cached items will not
have an automatic expiration.
"""
try:
from redis.asyncio import Redis
except ImportError:
raise ValueError(
"Could not import `redis.asyncio` python package. "
"Please install it with `pip install redis`."
)
if not isinstance(redis_, Redis):
raise ValueError("Please pass a valid `redis.asyncio.Redis` client.")
self.redis = redis_
self.ttl = ttl

def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string."""
raise NotImplementedError(
"This async Redis cache does not implement `lookup()` method. "
"Consider using the async `alookup()` version."
)

async def alookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt and llm_string. Async version."""
results = await self.redis.hgetall(self._key(prompt, llm_string))
return self._get_generations(results) # type: ignore[arg-type]

def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
"""Update cache based on prompt and llm_string."""
raise NotImplementedError(
"This async Redis cache does not implement `update()` method. "
"Consider using the async `aupdate()` version."
)

async def aupdate(
self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE
) -> None:
"""Update cache based on prompt and llm_string. Async version."""
self._ensure_generation_type(return_val)
key = self._key(prompt, llm_string)

async with self.redis.pipeline() as pipe:
self._configure_pipeline_for_update(key, pipe, return_val, self.ttl)
await pipe.execute() # type: ignore[attr-defined]

def clear(self, **kwargs: Any) -> None:
"""Clear cache. If `asynchronous` is True, flush asynchronously."""
raise NotImplementedError(
"This async Redis cache does not implement `clear()` method. "
"Consider using the async `aclear()` version."
)

async def aclear(self, **kwargs: Any) -> None:
"""
Clear cache. If `asynchronous` is True, flush asynchronously.
Async version.
"""
asynchronous = kwargs.get("asynchronous", False)
await self.redis.flushdb(asynchronous=asynchronous, **kwargs)


class RedisSemanticCache(BaseCache):
"""Cache that uses Redis as a vector-store backend."""

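As a usage illustration (a sketch, not taken from the diff), the new AsyncRedisCache can be exercised directly with an async client; this assumes `pip install redis` and a Redis server listening on the non-standard test port from docker/docker-compose.yml.

import asyncio

from redis.asyncio import Redis

from langchain_community.cache import AsyncRedisCache
from langchain_core.outputs import Generation


async def main() -> None:
    # ttl is optional; here cached entries expire after 600 seconds.
    cache = AsyncRedisCache(Redis(host="localhost", port=6020), ttl=600)
    await cache.aupdate("some prompt", "llm-config-string", [Generation(text="cached answer")])
    print(await cache.alookup("some prompt", "llm-config-string"))


asyncio.run(main())

In application code the cache would normally be installed globally (e.g. via set_llm_cache) so that async LLM calls route through alookup()/aupdate() automatically, as shown in the chat_models.py and llms.py changes below.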
15 changes: 15 additions & 0 deletions libs/core/langchain_core/caches.py
@@ -4,6 +4,7 @@
from typing import Any, Optional, Sequence

from langchain_core.outputs import Generation
from langchain_core.runnables import run_in_executor

RETURN_VAL_TYPE = Sequence[Generation]

@@ -22,3 +23,17 @@ def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
@abstractmethod
def clear(self, **kwargs: Any) -> None:
"""Clear cache that can take additional keyword arguments."""

async def alookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
Collaborator

An alternative would be to add these methods to the BaseCache class with a default implementation that calls the non-async method with run_in_executor, like what is done in BaseLLM and Embeddings.
Otherwise, there is no reason to prefix the methods here with the `a` prefix.

Collaborator

@dzmitry-kankalovich We generally prefer keeping async versions of methods in the same class as the sync versions, and having them provide a default async implementation that uses run_in_executor, like @cbornet suggested.

Would you mind merging into the existing abstractions? Ideally a single PR that just modifies the core interface first; then, separately, we can do a PR for any implementations.

Contributor Author

Done

Contributor Author

As for the smaller PRs - yes, I can move slower and split this up into several PRs, if you are happy with the current direction.

Collaborator

It's generally better to split PRs by package to minimize potential dependency conflicts, since the packages may have different release schedules.

"""Look up based on prompt and llm_string."""
return await run_in_executor(None, self.lookup, prompt, llm_string)

async def aupdate(
self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE
) -> None:
"""Update cache based on prompt and llm_string."""
return await run_in_executor(None, self.update, prompt, llm_string, return_val)

async def aclear(self, **kwargs: Any) -> None:
"""Clear cache that can take additional keyword arguments."""
return await run_in_executor(None, self.clear, **kwargs)
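
To illustrate what these defaults buy implementers, here is a sketch (the DictCache class is hypothetical, not part of the PR): a cache that only defines the sync interface still works through the new async entry points, because the inherited alookup/aupdate/aclear delegate to the sync methods via run_in_executor.

import asyncio
from typing import Any, Dict, Optional, Tuple

from langchain_core.caches import RETURN_VAL_TYPE, BaseCache
from langchain_core.outputs import Generation


class DictCache(BaseCache):
    """Toy in-memory cache implementing only the sync methods."""

    def __init__(self) -> None:
        self._store: Dict[Tuple[str, str], RETURN_VAL_TYPE] = {}

    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
        return self._store.get((prompt, llm_string))

    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
        self._store[(prompt, llm_string)] = return_val

    def clear(self, **kwargs: Any) -> None:
        self._store.clear()


async def demo() -> None:
    cache = DictCache()
    # The inherited async defaults run the sync methods in an executor thread.
    await cache.aupdate("p", "llm", [Generation(text="hi")])
    print(await cache.alookup("p", "llm"))


asyncio.run(demo())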
4 changes: 2 additions & 2 deletions libs/core/langchain_core/language_models/chat_models.py
@@ -622,7 +622,7 @@ async def _agenerate_with_cache(
else:
llm_string = self._get_llm_string(stop=stop, **kwargs)
prompt = dumps(messages)
cache_val = llm_cache.lookup(prompt, llm_string)
cache_val = await llm_cache.alookup(prompt, llm_string)
if isinstance(cache_val, list):
return ChatResult(generations=cache_val)
else:
@@ -632,7 +632,7 @@
)
else:
result = await self._agenerate(messages, stop=stop, **kwargs)
llm_cache.update(prompt, llm_string, result.generations)
await llm_cache.aupdate(prompt, llm_string, result.generations)
return result

@abstractmethod
42 changes: 40 additions & 2 deletions libs/core/langchain_core/language_models/llms.py
@@ -139,6 +139,26 @@ def get_prompts(
return existing_prompts, llm_string, missing_prompt_idxs, missing_prompts


async def aget_prompts(
params: Dict[str, Any], prompts: List[str]
) -> Tuple[Dict[int, List], str, List[int], List[str]]:
"""Get prompts that are already cached. Async version."""
llm_string = str(sorted([(k, v) for k, v in params.items()]))
missing_prompts = []
missing_prompt_idxs = []
existing_prompts = {}
llm_cache = get_llm_cache()
for i, prompt in enumerate(prompts):
if llm_cache:
cache_val = await llm_cache.alookup(prompt, llm_string)
if isinstance(cache_val, list):
existing_prompts[i] = cache_val
else:
missing_prompts.append(prompt)
missing_prompt_idxs.append(i)
return existing_prompts, llm_string, missing_prompt_idxs, missing_prompts


def update_cache(
existing_prompts: Dict[int, List],
llm_string: str,
@@ -157,6 +177,24 @@ def update_cache(
return llm_output


async def aupdate_cache(
existing_prompts: Dict[int, List],
llm_string: str,
missing_prompt_idxs: List[int],
new_results: LLMResult,
prompts: List[str],
) -> Optional[dict]:
"""Update the cache and get the LLM output. Async version"""
llm_cache = get_llm_cache()
for i, result in enumerate(new_results.generations):
existing_prompts[missing_prompt_idxs[i]] = result
prompt = prompts[missing_prompt_idxs[i]]
if llm_cache:
await llm_cache.aupdate(prompt, llm_string, result)
llm_output = new_results.llm_output
return llm_output


class BaseLLM(BaseLanguageModel[str], ABC):
"""Base LLM abstract interface.

@@ -869,7 +907,7 @@ async def agenerate(
llm_string,
missing_prompt_idxs,
missing_prompts,
) = get_prompts(params, prompts)
) = await aget_prompts(params, prompts)
disregard_cache = self.cache is not None and not self.cache
new_arg_supported = inspect.signature(self._agenerate).parameters.get(
"run_manager"
@@ -917,7 +955,7 @@ async def agenerate(
new_results = await self._agenerate_helper(
missing_prompts, stop, run_managers, bool(new_arg_supported), **kwargs
)
llm_output = update_cache(
llm_output = await aupdate_cache(
existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
)
run_info = (
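For reference, the llm_string half of the cache key used by get_prompts/aget_prompts is just the model parameters serialized deterministically; a tiny sketch (the parameter names are illustrative):

# Mirrors the expression used in get_prompts/aget_prompts above.
params = {"temperature": 0.0, "model_name": "gpt-4"}
llm_string = str(sorted([(k, v) for k, v in params.items()]))
print(llm_string)  # [('model_name', 'gpt-4'), ('temperature', 0.0)]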
2 changes: 2 additions & 0 deletions libs/langchain/langchain/cache.py
Collaborator

Note: we shouldn't import this in langchain; the imports here exist only for backwards compatibility.

Contributor Author

I see. Noted! Maybe it makes sense to drop a comment in that file the next time you are working with it, so future contributors will know.

@@ -1,6 +1,7 @@
from langchain_community.cache import (
AstraDBCache,
AstraDBSemanticCache,
AsyncRedisCache,
CassandraCache,
CassandraSemanticCache,
FullLLMCache,
@@ -22,6 +23,7 @@
"SQLAlchemyCache",
"SQLiteCache",
"UpstashRedisCache",
"AsyncRedisCache",
"RedisCache",
"RedisSemanticCache",
"GPTCache",
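The re-export keeps the old import path working; a short sketch (assuming both packages are installed). Per the reviewer's note above, new code should import from langchain_community directly.

# Both names resolve to the same class; the langchain.cache path exists
# only for backwards compatibility.
from langchain.cache import AsyncRedisCache
from langchain_community.cache import AsyncRedisCache as CommunityAsyncRedisCache

assert AsyncRedisCache is CommunityAsyncRedisCache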