-
Notifications
You must be signed in to change notification settings - Fork 215
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #288 from lu-ny/248_Component-Memory-missing-call-…
…method Added call method and dialog turn to memory class, added 3 memory tests [issue #248]
- Loading branch information
Showing
3 changed files
with
114 additions
and
101 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,23 +1,77 @@ | ||
"""Memory for user-assistant conversations. [Not completed] | ||
"""Memory component for user-assistant conversations. | ||
Memory can include data modeling, in-memory data storage, local file data storage, cloud data persistence, data pipeline, data retriever. | ||
It is itself an LLM application and different use cases can do it differently. | ||
This component handles the storage and retrieval of conversation history between users | ||
and assistants. It provides local memory experience with the ability to format and | ||
return conversation history. | ||
This implementation covers the minimal and local memory experience for the user-assistant conversation. | ||
Attributes: | ||
current_conversation (Conversation): Stores the current active conversation. | ||
turn_db (LocalDB): Database for storing all conversation turns. | ||
conver_db (LocalDB): Database for storing complete conversations. | ||
""" | ||
|
||
from uuid import uuid4 | ||
from adalflow.core.component import Component | ||
from adalflow.core.db import LocalDB | ||
from adalflow.core.types import ( | ||
Conversation, | ||
DialogTurn, | ||
UserQuery, | ||
AssistantResponse, | ||
) | ||
|
||
from adalflow.core.db import LocalDB | ||
from adalflow.core.component import Component | ||
|
||
|
||
class Memory(Component): | ||
def __init__(self, turn_db: LocalDB = None): | ||
"""Initialize the Memory component. | ||
Args: | ||
turn_db (LocalDB, optional): Database for storing conversation turns. | ||
Defaults to None, in which case a new LocalDB is created. | ||
""" | ||
super().__init__() | ||
self.current_convesation = Conversation() | ||
self.current_conversation = Conversation() | ||
self.turn_db = turn_db or LocalDB() # all turns | ||
self.conver_db = LocalDB() # a list of conversations | ||
|
||
def call(self) -> str: | ||
"""Returns the current conversation history as a formatted string. | ||
Returns: | ||
str: Formatted conversation history with alternating user and assistant messages. | ||
Returns empty string if no conversation history exists. | ||
""" | ||
if not self.current_conversation.dialog_turns: | ||
return "" | ||
|
||
formatted_history = [] | ||
for turn in self.current_conversation.dialog_turns.values(): | ||
formatted_history.extend( | ||
[ | ||
f"User: {turn.user_query.query_str}", | ||
f"Assistant: {turn.assistant_response.response_str}", | ||
] | ||
) | ||
return "\n".join(formatted_history) | ||
|
||
def add_dialog_turn(self, user_query: str, assistant_response: str): | ||
"""Add a new dialog turn to the current conversation. | ||
Args: | ||
user_query (str): The user's input message. | ||
assistant_response (str): The assistant's response message. | ||
""" | ||
dialog_turn = DialogTurn( | ||
id=str(uuid4()), | ||
user_query=UserQuery(query_str=user_query), | ||
assistant_response=AssistantResponse(response_str=assistant_response), | ||
) | ||
|
||
self.current_conversation.append_dialog_turn(dialog_turn) | ||
|
||
self.turn_db.add( | ||
{"user_query": user_query, "assistant_response": assistant_response} | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
from adalflow.components.memory.memory import Memory | ||
|
||
|
||
def test_empty_memory(): | ||
memory = Memory() | ||
assert memory() == "" | ||
|
||
|
||
def test_add_dialog_turn(): | ||
memory = Memory() | ||
memory.add_dialog_turn("Hello", "Hi! How can I help you?") | ||
expected = "User: Hello\nAssistant: Hi! How can I help you?" | ||
assert memory() == expected | ||
|
||
|
||
def test_multiple_turns(): | ||
memory = Memory() | ||
memory.add_dialog_turn("Hello", "Hi!") | ||
memory.add_dialog_turn("How are you?", "I'm good!") | ||
expected = ( | ||
"User: Hello\n" "Assistant: Hi!\n" "User: How are you?\n" "Assistant: I'm good!" | ||
) | ||
assert memory() == expected |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters