add external memory and base memory
pkelaita committed May 15, 2024
1 parent a6a1b38 commit fb267f2
Showing 7 changed files with 195 additions and 57 deletions.
110 changes: 72 additions & 38 deletions l2m2/client/llm_client.py
@@ -1,4 +1,4 @@
from typing import Any, Set, Dict, Optional
from typing import Any, Set, Dict, Optional, Tuple

import google.generativeai as google
from cohere import Client as CohereClient
@@ -14,7 +14,14 @@
ModelEntry,
ParamName,
)
from l2m2.memory import ChatMemory, DEFAULT_WINDOW_SIZE
from l2m2.memory import (
ChatMemory,
CHAT_MEMORY_DEFAULT_WINDOW_SIZE,
ExternalMemory,
ExternalMemoryLoadingType,
MemoryType,
)
from l2m2.memory.base_memory import BaseMemory


class LLMClient:
@@ -23,8 +30,9 @@ class LLMClient:
def __init__(
self,
providers: Optional[Dict[str, str]] = None,
enable_memory: bool = False,
memory_window_size: int = DEFAULT_WINDOW_SIZE,
memory_type: Optional[MemoryType] = None,
memory_window_size: int = CHAT_MEMORY_DEFAULT_WINDOW_SIZE,
memory_loading_type: ExternalMemoryLoadingType = ExternalMemoryLoadingType.SYSTEM_PROMPT_APPEND,
) -> None:
"""Initialize the LLMClient, optionally with active providers.
@@ -39,9 +47,13 @@ def __init__(
}
Defaults to `None`.
enable_memory (bool, optional): Whether to enable memory. Defaults to `False`.
window_size (int, optional): The size of the memory window. Defaults to
`l2m2.memory.DEFAULT_WINDOW_SIZE`.
memory_type (MemoryType, optional): The type of memory to enable. If `None`, memory is
not enabled. Defaults to `None`.
memory_window_size (int, optional): The size of the memory window. Only applicable if
`memory_type` is `MemoryType.CHAT`, otherwise ignored. Defaults to `40`.
memory_loading_type (ExternalMemoryLoadingType, optional): How the model should load
external memory. Only applicable if `memory_type` is `MemoryType.EXTERNAL`,
otherwise ignored. Defaults to `ExternalMemoryLoadingType.SYSTEM_PROMPT_APPEND`.
Raises:
ValueError: If an invalid provider is specified in `providers`.
@@ -51,15 +63,17 @@ def __init__(
self.active_providers: Set[str] = set()
self.active_models: Set[str] = set()
self.preferred_providers: Dict[str, str] = {}

self.memory: Optional[ChatMemory] = None
self.memory: Optional[BaseMemory] = None

if providers is not None:
for provider, api_key in providers.items():
self.add_provider(provider, api_key)

if enable_memory:
self.memory = ChatMemory(memory_window_size)
if memory_type is not None:
if memory_type == MemoryType.CHAT:
self.memory = ChatMemory(window_size=memory_window_size)
elif memory_type == MemoryType.EXTERNAL:
self.memory = ExternalMemory(loading_type=memory_loading_type)

@staticmethod
def get_available_providers() -> Set[str]:
@@ -179,20 +193,18 @@ def set_preferred_providers(self, preferred_providers: Dict[str, str]) -> None:

self.preferred_providers.update(preferred_providers)

def get_memory(self) -> ChatMemory:
def get_memory(self) -> BaseMemory:
"""Get the memory object, if memory is enabled.
Returns:
ChatMemory: The memory object.
BaseMemory: The memory object.
Raises:
ValueError: If memory is not enabled.
"""
if self.memory is None:
raise ValueError(
"Client memory is not enabled. Instantiate the LLM client with enable_memory=True"
+ " to enable memory."
)
raise ValueError("Memory is not enabled.")

return self.memory

def clear_memory(self) -> None:
@@ -202,25 +214,19 @@ def clear_memory(self) -> None:
ValueError: If memory is not enabled.
"""
if self.memory is None:
raise ValueError(
"Client memory is not enabled. Instantiate the LLM client with enable_memory=True"
+ " to enable memory."
)
raise ValueError("Memory is not enabled.")

self.memory.clear()

def enable_memory(self, window_size: int = DEFAULT_WINDOW_SIZE) -> None:
"""Enable memory, with a specified window size.
def load_memory(self, memory_object: BaseMemory) -> None:
"""Loads memory into the LLM client. If the client already has memory enabled, the existing
memory is replaced with the new memory.
Args:
window_size (int, optional): The size of the memory window. Defaults to
`l2m2.memory.DEFAULT_WINDOW_SIZE`.
memory_object (BaseMemory): The memory to load.
Raises:
ValueError: If memory is already enabled.
"""
if self.memory is not None:
raise ValueError("Memory is already enabled.")
self.memory = ChatMemory(window_size)
self.memory = memory_object

def call(
self,
@@ -390,12 +396,19 @@ def add_param(name: ParamName, value: Any) -> None:
add_param("temperature", temperature)
add_param("max_tokens", max_tokens)

if isinstance(self.memory, ExternalMemory):
system_prompt, prompt = self._get_external_memory_prompts(
system_prompt, prompt
)

result = getattr(self, f"_call_{provider}")(
model_info["model_id"], prompt, system_prompt, params
)
if self.memory is not None:

if isinstance(self.memory, ChatMemory):
self.memory.add_user_message(prompt)
self.memory.add_agent_message(result)

return str(result)

def _call_openai(
@@ -409,7 +422,7 @@ def _call_openai(
messages = []
if system_prompt is not None:
messages.append({"role": "system", "content": system_prompt})
if self.memory is not None:
if isinstance(self.memory, ChatMemory):
messages.extend(self.memory.unpack("role", "content", "user", "assistant"))
messages.append({"role": "user", "content": prompt})
result = oai.chat.completions.create(
@@ -430,7 +443,7 @@ def _call_anthropic(
if system_prompt is not None:
params["system"] = system_prompt
messages = []
if self.memory is not None:
if isinstance(self.memory, ChatMemory):
messages.extend(self.memory.unpack("role", "content", "user", "assistant"))
messages.append({"role": "user", "content": prompt})
result = anthr.messages.create(
@@ -450,7 +463,7 @@ def _call_cohere(
cohere = CohereClient(api_key=self.api_keys["cohere"])
if system_prompt is not None:
params["preamble"] = system_prompt
if self.memory is not None:
if isinstance(self.memory, ChatMemory):
params["chat_history"] = self.memory.unpack(
"role", "message", "USER", "CHATBOT"
)
@@ -472,7 +485,7 @@ def _call_groq(
messages = []
if system_prompt is not None:
messages.append({"role": "system", "content": system_prompt})
if self.memory is not None:
if isinstance(self.memory, ChatMemory):
messages.extend(self.memory.unpack("role", "content", "user", "assistant"))
messages.append({"role": "user", "content": prompt})
result = groq.chat.completions.create(
@@ -501,7 +514,7 @@ def _call_google(
model = google.GenerativeModel(**model_params)

messages = []
if self.memory is not None:
if isinstance(self.memory, ChatMemory):
messages.extend(self.memory.unpack("role", "parts", "user", "model"))
messages.append({"role": "user", "parts": prompt})

@@ -521,8 +534,9 @@ def _call_replicate(
system_prompt: Optional[str],
params: Dict[str, Any],
) -> str:
if self.memory is not None:
raise ValueError("Memory is not supported with Replicate models.")
if isinstance(self.memory, ChatMemory):
raise ValueError("Chat memory is not supported with Replicate models.")

client = replicate.Client(api_token=self.api_keys["replicate"])
if system_prompt is not None:
params["system_prompt"] = system_prompt
Expand All @@ -534,3 +548,23 @@ def _call_replicate(
},
)
return "".join(result)

def _get_external_memory_prompts(
self, system_prompt: Optional[str], prompt: str
) -> Tuple[str, str]:
if not isinstance(self.memory, ExternalMemory):
raise ValueError("Memory is not enabled or is not of type ExternalMemory.")

if system_prompt is None:
system_prompt = ""

if self.memory.loading_type == ExternalMemoryLoadingType.SYSTEM_PROMPT_APPEND:
system_prompt += "\n" + self.memory.get_contents()
elif self.memory.loading_type == ExternalMemoryLoadingType.USER_PROMPT_APPEND:
prompt += "\n" + self.memory.get_contents()
else:
raise NotImplementedError(
f"Loading type {self.memory.loading_type} is not yet supported."
)

return system_prompt, prompt
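
Taken together, the llm_client.py changes replace the old `enable_memory` flag with a `memory_type` selector. A minimal sketch of the new construction surface, assuming the import path below (inferred from the file layout, not confirmed by this diff) and a placeholder API key:

from l2m2.client.llm_client import LLMClient
from l2m2.memory import ExternalMemory, ExternalMemoryLoadingType, MemoryType

# Sliding-window chat memory, selected via MemoryType.CHAT
# (replaces the previous enable_memory=True).
client = LLMClient(
    providers={"openai": "<YOUR_API_KEY>"},  # placeholder key
    memory_type=MemoryType.CHAT,
    memory_window_size=10,
)

# External memory; contents are appended to the system prompt by default.
client = LLMClient(
    providers={"openai": "<YOUR_API_KEY>"},
    memory_type=MemoryType.EXTERNAL,
    memory_loading_type=ExternalMemoryLoadingType.SYSTEM_PROMPT_APPEND,
)

# The new load_memory() swaps in a pre-built memory object, replacing
# whatever memory the client currently holds.
client.load_memory(ExternalMemory(contents="Some background context."))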
18 changes: 16 additions & 2 deletions l2m2/memory/__init__.py
@@ -1,3 +1,17 @@
from .chat_memory import ChatMemory, ChatMemoryEntry, DEFAULT_WINDOW_SIZE
from .chat_memory import (
ChatMemory,
ChatMemoryEntry,
CHAT_MEMORY_DEFAULT_WINDOW_SIZE,
)
from .external_memory import ExternalMemory, ExternalMemoryLoadingType

__all__ = ["ChatMemory", "ChatMemoryEntry", "DEFAULT_WINDOW_SIZE"]
from .base_memory import MemoryType

__all__ = [
"ChatMemory",
"ChatMemoryEntry",
"CHAT_MEMORY_DEFAULT_WINDOW_SIZE",
"ExternalMemory",
"ExternalMemoryLoadingType",
"MemoryType",
]
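
One breaking detail worth noting: `DEFAULT_WINDOW_SIZE` is renamed to `CHAT_MEMORY_DEFAULT_WINDOW_SIZE`, so downstream imports need a one-line update:

# Before this commit:
#   from l2m2.memory import ChatMemory, DEFAULT_WINDOW_SIZE

# After this commit:
from l2m2.memory import ChatMemory, CHAT_MEMORY_DEFAULT_WINDOW_SIZE

memory = ChatMemory()  # window_size still defaults to 40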
27 changes: 27 additions & 0 deletions l2m2/memory/base_memory.py
@@ -0,0 +1,27 @@
from enum import Enum
from abc import ABC, abstractmethod


class MemoryType(Enum):
"""The type of memory used by the model."""

CHAT = "chat"
EXTERNAL = "external"


class BaseMemory(ABC):
"""Abstract representation of a model's memory."""

def __init__(self, memory_type: MemoryType) -> None:
"""Create a new BaseMemory object.
Args:
memory_type (MemoryType): The type of memory managed by the model.
"""
self.memory_type = memory_type

@abstractmethod
def clear(self) -> None:
"""Clears the model's memory."""

pass
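
Because `BaseMemory` only obligates subclasses to call `super().__init__` with a `MemoryType` and implement `clear()`, custom memory types can extend it directly. A hypothetical sketch (`ScratchpadMemory` is illustrative and not part of this commit):

from typing import List

from l2m2.memory.base_memory import BaseMemory, MemoryType


class ScratchpadMemory(BaseMemory):
    """Hypothetical memory type backed by a simple list of notes."""

    def __init__(self) -> None:
        # MemoryType currently only defines CHAT and EXTERNAL, so a
        # custom type has to reuse one of them; EXTERNAL is the closer fit.
        super().__init__(MemoryType.EXTERNAL)
        self.notes: List[str] = []

    def clear(self) -> None:
        self.notes.clear()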
10 changes: 7 additions & 3 deletions l2m2/memory/chat_memory.py
@@ -2,7 +2,9 @@
from typing import Deque, Iterator, List, Dict
from enum import Enum

DEFAULT_WINDOW_SIZE = 40
from l2m2.memory.base_memory import BaseMemory, MemoryType

CHAT_MEMORY_DEFAULT_WINDOW_SIZE = 40


class ChatMemoryEntry:
@@ -17,13 +19,13 @@ def __init__(self, text: str, role: Role):
self.role = role


class ChatMemory:
class ChatMemory(BaseMemory):
"""Represents a sliding-window conversation memory between a user and an agent. `ChatMemory` is
the most basic type of memory and is designed to be passed directly to chat-based models such
as `llama3-70b-instruct`.
"""

def __init__(self, window_size: int = DEFAULT_WINDOW_SIZE) -> None:
def __init__(self, window_size: int = CHAT_MEMORY_DEFAULT_WINDOW_SIZE) -> None:
"""Create a new ChatMemory object.
Args:
@@ -33,6 +35,8 @@ def __init__(self, window_size: int = DEFAULT_WINDOW_SIZE) -> None:
Raises:
ValueError: If `window_size` is less than or equal to 0.
"""
super().__init__(MemoryType.CHAT)

if not window_size > 0:
raise ValueError("window_size must be a positive integer.")

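A quick sketch of the reworked `ChatMemory` on its own, using the `add_user_message`, `add_agent_message`, and `unpack` methods called from llm_client.py above (the `unpack` signature is taken from those call sites):

from l2m2.memory import ChatMemory, MemoryType

memory = ChatMemory(window_size=3)
assert memory.memory_type == MemoryType.CHAT  # now set by the BaseMemory base class

memory.add_user_message("What's the capital of France?")
memory.add_agent_message("Paris.")

# unpack(role_key, content_key, user_label, agent_label), matching the
# OpenAI-style call sites in llm_client.py:
messages = memory.unpack("role", "content", "user", "assistant")
# -> [{"role": "user", "content": "What's the capital of France?"},
#     {"role": "assistant", "content": "Paris."}]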
61 changes: 61 additions & 0 deletions l2m2/memory/external_memory.py
@@ -0,0 +1,61 @@
from enum import Enum

from l2m2.memory.base_memory import BaseMemory, MemoryType


class ExternalMemoryLoadingType(Enum):
"""Represents how the model should load external memory."""

SYSTEM_PROMPT_APPEND = "system_prompt_append"
USER_PROMPT_APPEND = "user_prompt_append"


class ExternalMemory(BaseMemory):
"""Represents custom memory that is managed completely externally to the model."""

def __init__(
self,
contents: str = "",
loading_type: ExternalMemoryLoadingType = ExternalMemoryLoadingType.SYSTEM_PROMPT_APPEND,
) -> None:
"""Create a new ExternalMemory object.
Args:
contents (str, optional): The memory to pre-load. Defaults to "".
loading_type (ExternalMemoryLoadingType, optional): How the model should load the memory:
either appended to the system prompt or appended to the most recent user prompt.
Defaults to ExternalMemoryLoadingType.SYSTEM_PROMPT_APPEND.
"""

super().__init__(MemoryType.EXTERNAL)
self.contents: str = contents
self.loading_type: ExternalMemoryLoadingType = loading_type

def get_contents(self) -> str:
"""Get the contents of the memory object.
Returns:
str: The entire memory contents
"""

return self.contents

def set_contents(self, new_contents: str) -> None:
"""Load the contents into the memory object, replacing the existing contents.
Args:
new_contents (str): The new contents to load.
"""
self.contents = new_contents

def append_contents(self, new_contents: str) -> None:
"""Append new contents to the memory object.
Args:
new_contents (str): The new contents to append.
"""
self.contents += new_contents

def clear(self) -> None:
"""Clear the memory."""
self.contents = ""
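
Finally, a self-contained sketch exercising each `ExternalMemory` method added in this file:

from l2m2.memory import ExternalMemory, ExternalMemoryLoadingType

mem = ExternalMemory(
    contents="User prefers metric units.",
    loading_type=ExternalMemoryLoadingType.USER_PROMPT_APPEND,
)

mem.append_contents(" User is in UTC+1.")
assert mem.get_contents() == "User prefers metric units. User is in UTC+1."

mem.set_contents("Replaced context.")  # overwrites rather than appends
mem.clear()
assert mem.get_contents() == ""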
