Skip to content

Commit

Permalink
Add a conversation memory that combines a (optionally persistent) vec…
Browse files Browse the repository at this point in the history
…torstore history with a token buffer (#22155)

**langchain: ConversationVectorStoreTokenBufferMemory**

-**Description:** This PR adds ConversationVectorStoreTokenBufferMemory.
It is similar in concept to ConversationSummaryBufferMemory. It
maintains an in-memory buffer of messages up to a preset token limit.
After the limit is hit timestamped messages are written into a
vectorstore retriever rather than into a summary. The user's prompt is
then used to retrieve relevant fragments of the previous conversation.
By persisting the vectorstore, one can maintain memory from session to
session.
-**Issue:** n/a
-**Dependencies:** none
-**Twitter handle:** Please no!!!
- [X] **Add tests and docs**: I looked to see how the unit tests were
written for the other ConversationMemory modules, but couldn't find
anything other than a test for successful import. I need to know whether
you are using pytest.mock or another fixture to simulate the LLM and
vectorstore. In addition, I would like guidance on where to place the
documentation. Should it be a notebook file in docs/docs?

- [X] **Lint and test**: I am seeing some linting errors from a couple
of modules unrelated to this PR.

If no one reviews your PR within a few days, please @-mention one of
baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17.

---------

Co-authored-by: Lincoln Stein <[email protected]>
Co-authored-by: isaac hershenson <[email protected]>
  • Loading branch information
3 people authored Jun 26, 2024
1 parent 32f8f39 commit c314222
Show file tree
Hide file tree
Showing 4 changed files with 190 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""__ModuleName__ document loader."""

from typing import Iterator

from langchain_core.document_loaders.base import BaseLoader
from langchain_core.documents import Document

Expand Down
4 changes: 4 additions & 0 deletions libs/langchain/langchain/memory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@
from langchain.memory.summary_buffer import ConversationSummaryBufferMemory
from langchain.memory.token_buffer import ConversationTokenBufferMemory
from langchain.memory.vectorstore import VectorStoreRetrieverMemory
from langchain.memory.vectorstore_token_buffer_memory import (
ConversationVectorStoreTokenBufferMemory, # avoid circular import
)

if TYPE_CHECKING:
from langchain_community.chat_message_histories import (
Expand Down Expand Up @@ -122,6 +125,7 @@ def __getattr__(name: str) -> Any:
"ConversationSummaryBufferMemory",
"ConversationSummaryMemory",
"ConversationTokenBufferMemory",
"ConversationVectorStoreTokenBufferMemory",
"CosmosDBChatMessageHistory",
"DynamoDBChatMessageHistory",
"ElasticsearchChatMessageHistory",
Expand Down
184 changes: 184 additions & 0 deletions libs/langchain/langchain/memory/vectorstore_token_buffer_memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
"""
Class for a conversation memory buffer with older messages stored in a vectorstore .
This implementats a conversation memory in which the messages are stored in a memory
buffer up to a specified token limit. When the limit is exceeded, older messages are
saved to a vectorstore backing database. The vectorstore can be made persistent across
sessions.
"""

import warnings
from datetime import datetime
from typing import Any, Dict, List

from langchain_core.messages import BaseMessage
from langchain_core.prompts.chat import SystemMessagePromptTemplate
from langchain_core.pydantic_v1 import Field, PrivateAttr
from langchain_core.vectorstores import VectorStoreRetriever

from langchain.memory import ConversationTokenBufferMemory, VectorStoreRetrieverMemory
from langchain.memory.chat_memory import BaseChatMemory
from langchain.text_splitter import RecursiveCharacterTextSplitter

DEFAULT_HISTORY_TEMPLATE = """
Current date and time: {current_time}.
Potentially relevant timestamped excerpts of previous conversations (you
do not need to use these if irrelevant):
{previous_history}
"""

TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S %Z"


class ConversationVectorStoreTokenBufferMemory(ConversationTokenBufferMemory):
"""Conversation chat memory with token limit and vectordb backing.
load_memory_variables() will return a dict with the key "history".
It contains background information retrieved from the vector store
plus recent lines of the current conversation.
To help the LLM understand the part of the conversation stored in the
vectorstore, each interaction is timestamped and the current date and
time is also provided in the history. A side effect of this is that the
LLM will have access to the current date and time.
Initialization arguments:
This class accepts all the initialization arguments of
ConversationTokenBufferMemory, such as `llm`. In addition, it
accepts the following additional arguments
retriever: (required) A VectorStoreRetriever object to use
as the vector backing store
split_chunk_size: (optional, 1000) Token chunk split size
for long messages generated by the AI
previous_history_template: (optional) Template used to format
the contents of the prompt history
Example using ChromaDB:
.. code-block:: python
from langchain.memory.token_buffer_vectorstore_memory import (
ConversationVectorStoreTokenBufferMemory
)
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceInstructEmbeddings
from langchain_openai import OpenAI
embedder = HuggingFaceInstructEmbeddings(
query_instruction="Represent the query for retrieval: "
)
chroma = Chroma(collection_name="demo",
embedding_function=embedder,
collection_metadata={"hnsw:space": "cosine"},
)
retriever = chroma.as_retriever(
search_type="similarity_score_threshold",
search_kwargs={
'k': 5,
'score_threshold': 0.75,
},
)
conversation_memory = ConversationVectorStoreTokenBufferMemory(
return_messages=True,
llm=OpenAI(),
retriever=retriever,
max_token_limit = 1000,
)
conversation_memory.save_context({"Human": "Hi there"},
{"AI": "Nice to meet you!"}
)
conversation_memory.save_context({"Human": "Nice day isn't it?"},
{"AI": "I love Wednesdays."}
)
conversation_memory.load_memory_variables({"input": "What time is it?"})
"""

retriever: VectorStoreRetriever = Field(exclude=True)
memory_key: str = "history"
previous_history_template: str = DEFAULT_HISTORY_TEMPLATE
split_chunk_size: int = 1000

_memory_retriever: VectorStoreRetrieverMemory = PrivateAttr(default=None)
_timestamps: List[datetime] = PrivateAttr(default_factory=list)

@property
def memory_retriever(self) -> VectorStoreRetrieverMemory:
"""Return a memory retriever from the passed retriever object."""
if self._memory_retriever is not None:
return self._memory_retriever
self._memory_retriever = VectorStoreRetrieverMemory(retriever=self.retriever)
return self._memory_retriever

def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Return history and memory buffer."""
try:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
memory_variables = self.memory_retriever.load_memory_variables(inputs)
previous_history = memory_variables[self.memory_retriever.memory_key]
except AssertionError: # happens when db is empty
previous_history = ""
current_history = super().load_memory_variables(inputs)
template = SystemMessagePromptTemplate.from_template(
self.previous_history_template
)
messages = [
template.format(
previous_history=previous_history,
current_time=datetime.now().astimezone().strftime(TIMESTAMP_FORMAT),
)
]
messages.extend(current_history[self.memory_key])
return {self.memory_key: messages}

def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Save context from this conversation to buffer. Pruned."""
BaseChatMemory.save_context(self, inputs, outputs)
self._timestamps.append(datetime.now().astimezone())
# Prune buffer if it exceeds max token limit
buffer = self.chat_memory.messages
curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)
if curr_buffer_length > self.max_token_limit:
while curr_buffer_length > self.max_token_limit:
self._pop_and_store_interaction(buffer)
curr_buffer_length = self.llm.get_num_tokens_from_messages(buffer)

def save_remainder(self) -> None:
"""
Save the remainder of the conversation buffer to the vector store.
This is useful if you have made the vectorstore persistent, in which
case this can be called before the end of the session to store the
remainder of the conversation.
"""
buffer = self.chat_memory.messages
while len(buffer) > 0:
self._pop_and_store_interaction(buffer)

def _pop_and_store_interaction(self, buffer: List[BaseMessage]) -> None:
input = buffer.pop(0)
output = buffer.pop(0)
timestamp = self._timestamps.pop(0).strftime(TIMESTAMP_FORMAT)
# Split AI output into smaller chunks to avoid creating documents
# that will overflow the context window
ai_chunks = self._split_long_ai_text(str(output.content))
for index, chunk in enumerate(ai_chunks):
self.memory_retriever.save_context(
{"Human": f"<{timestamp}/00> {str(input.content)}"},
{"AI": f"<{timestamp}/{index:02}> {chunk}"},
)

def _split_long_ai_text(self, text: str) -> List[str]:
splitter = RecursiveCharacterTextSplitter(chunk_size=self.split_chunk_size)
return [chunk.page_content for chunk in splitter.create_documents([text])]
1 change: 1 addition & 0 deletions libs/langchain/tests/unit_tests/memory/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"ConversationSummaryBufferMemory",
"ConversationSummaryMemory",
"ConversationTokenBufferMemory",
"ConversationVectorStoreTokenBufferMemory",
"CosmosDBChatMessageHistory",
"DynamoDBChatMessageHistory",
"ElasticsearchChatMessageHistory",
Expand Down

0 comments on commit c314222

Please sign in to comment.