Skip to content

Commit

Permalink
reformat code
Browse files Browse the repository at this point in the history
  • Loading branch information
doodledood committed Nov 12, 2023
1 parent 924d941 commit 10dc8d4
Show file tree
Hide file tree
Showing 50 changed files with 1,553 additions and 1,408 deletions.
70 changes: 36 additions & 34 deletions chatflock/ai_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, Dict, Optional, Sequence, Type, TypeVar

import json
from json import JSONDecodeError
from typing import Optional, Dict, Any, TypeVar, Type, Sequence

from halo import Halo
from langchain.chat_models.base import BaseChatModel
Expand All @@ -14,38 +15,41 @@


def execute_chat_model_messages(
chat_model: BaseChatModel,
messages: Sequence[BaseMessage],
chat_model_args: Optional[Dict[str, Any]] = None,
tools: Optional[Sequence[BaseTool]] = None,
spinner: Optional[Halo] = None) -> str:
chat_model: BaseChatModel,
messages: Sequence[BaseMessage],
chat_model_args: Optional[Dict[str, Any]] = None,
tools: Optional[Sequence[BaseTool]] = None,
spinner: Optional[Halo] = None,
) -> str:
chat_model_args = chat_model_args or {}

assert 'functions' not in chat_model_args, ('The `functions` argument is reserved for the '
'`execute_chat_model_messages` function. If you want to add more '
'functions use the `functions` argument to this method.')
assert "functions" not in chat_model_args, (
"The `functions` argument is reserved for the "
"`execute_chat_model_messages` function. If you want to add more "
"functions use the `functions` argument to this method."
)

if tools is not None and len(tools) > 0:
chat_model_args['functions'] = [format_tool_to_openai_function(tool) for tool in tools]
chat_model_args["functions"] = [format_tool_to_openai_function(tool) for tool in tools]

function_map = {tool.name: tool for tool in tools or []}

all_messages = list(messages).copy()

last_message = chat_model.predict_messages(all_messages, **chat_model_args)
function_call = last_message.additional_kwargs.get('function_call')
function_call = last_message.additional_kwargs.get("function_call")

while function_call is not None:
function_name = function_call['name']
function_name = function_call["name"]
if function_name in function_map:
tool = function_map[function_name]
args = function_call['arguments']
args = function_call["arguments"]

if spinner is not None:
if hasattr(tool, 'progress_text'):
if hasattr(tool, "progress_text"):
progress_text = tool.progress_text
else:
progress_text = f'Executing function `{function_name}`...'
progress_text = f"Executing function `{function_name}`..."

spinner.start(progress_text)

Expand All @@ -59,37 +63,35 @@ def execute_chat_model_messages(
args = json.loads(args)
result = tool.run(args)
except JSONDecodeError as e:
result = f'Error decoding args for function: {e}'
result = f"Error decoding args for function: {e}"
except Exception as e:
result = f'Error executing function: {e}'
result = f"Error executing function: {e}"

all_messages.append(FunctionMessage(
name=function_name,
content=f'The function execution returned:\n```{str(result).strip()}```' or 'None'
))
all_messages.append(
FunctionMessage(
name=function_name,
content=f"The function execution returned:\n```{str(result).strip()}```" or "None",
)
)

last_message = chat_model.predict_messages(all_messages, **chat_model_args)
function_call = last_message.additional_kwargs.get('function_call')
function_call = last_message.additional_kwargs.get("function_call")
else:
raise FunctionNotFoundError(function_name)

return str(last_message.content)


PydanticType = TypeVar('PydanticType', bound=Type[BaseModel])
PydanticType = TypeVar("PydanticType", bound=Type[BaseModel])


def pydantic_to_openai_function(pydantic_type: PydanticType,
function_name: Optional[str] = None,
function_description: Optional[str] = None) -> Dict[str, Any]:
def pydantic_to_openai_function(
pydantic_type: PydanticType, function_name: Optional[str] = None, function_description: Optional[str] = None
) -> Dict[str, Any]:
base_schema = pydantic_type.model_json_schema()
del base_schema['title']
del base_schema['description']
del base_schema["title"]
del base_schema["description"]

description = function_description if function_description is not None else (pydantic_type.__doc__ or '')
description = function_description if function_description is not None else (pydantic_type.__doc__ or "")

return {
'name': function_name or pydantic_type.__name__,
'description': description,
'parameters': base_schema
}
return {"name": function_name or pydantic_type.__name__, "description": description, "parameters": base_schema}
5 changes: 1 addition & 4 deletions chatflock/backing_stores/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
from .in_memory import InMemoryChatDataBackingStore
from .langchain import LangChainMemoryBasedChatDataBackingStore

__all__ = [
'InMemoryChatDataBackingStore',
'LangChainMemoryBasedChatDataBackingStore'
]
__all__ = ["InMemoryChatDataBackingStore", "LangChainMemoryBasedChatDataBackingStore"]
22 changes: 13 additions & 9 deletions chatflock/backing_stores/in_memory.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import Dict, List, Optional

import datetime
from typing import List, Dict, Optional

from chatflock.base import ChatDataBackingStore, ChatMessage, ChatParticipant, ActiveChatParticipant
from chatflock.base import ActiveChatParticipant, ChatDataBackingStore, ChatMessage, ChatParticipant
from chatflock.errors import ChatParticipantAlreadyJoinedToChatError, ChatParticipantNotJoinedToChatError


Expand All @@ -10,8 +11,9 @@ class InMemoryChatDataBackingStore(ChatDataBackingStore):
participants: Dict[str, ChatParticipant]
last_message_id: Optional[int] = None

def __init__(self, messages: Optional[List[ChatMessage]] = None,
participants: Optional[List[ChatParticipant]] = None):
def __init__(
self, messages: Optional[List[ChatMessage]] = None, participants: Optional[List[ChatParticipant]] = None
):
self.messages = messages or []
self.participants = {participant.name: participant for participant in (participants or [])}
self.last_message_id = None if len(self.messages) == 0 else self.messages[-1].id
Expand All @@ -26,7 +28,7 @@ def add_message(self, sender_name: str, content: str, timestamp: Optional[dateti
id=self.last_message_id,
sender_name=sender_name,
content=content,
timestamp=timestamp or datetime.datetime.now()
timestamp=timestamp or datetime.datetime.now(),
)

self.messages.append(message)
Expand All @@ -39,15 +41,17 @@ def clear_messages(self):

def get_active_participants(self) -> List[ActiveChatParticipant]:
participants = list(self.participants.values())
active_participants = [participant for participant in participants if
isinstance(participant, ActiveChatParticipant)]
active_participants = [
participant for participant in participants if isinstance(participant, ActiveChatParticipant)
]

return active_participants

def get_non_active_participants(self) -> List[ChatParticipant]:
participants = list(self.participants.values())
participants = [participant for participant in participants if
not isinstance(participant, ActiveChatParticipant)]
participants = [
participant for participant in participants if not isinstance(participant, ActiveChatParticipant)
]

return participants

Expand Down
64 changes: 31 additions & 33 deletions chatflock/backing_stores/langchain.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Callable, List, Optional

import datetime
import re
from typing import List, Optional, Callable

from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import BaseMessage
Expand All @@ -12,47 +13,42 @@
def base_message_to_chat_message(base_message: BaseMessage) -> ChatMessage:
content = str(base_message.content)

pattern = re.compile(r'(\d+)\.\s*(.+?):\s*(.*)', re.DOTALL)
pattern = re.compile(r"(\d+)\.\s*(.+?):\s*(.*)", re.DOTALL)
match = pattern.match(content)

if not match:
return ChatMessage(
id=-1,
sender_name='SYSTEM',
content=content
)
return ChatMessage(id=-1, sender_name="SYSTEM", content=content)

id_number = int(match.group(1))
sender_name = match.group(2)
message_content = match.group(3)

return ChatMessage(
id=id_number,
sender_name=sender_name,
content=message_content
)
return ChatMessage(id=id_number, sender_name=sender_name, content=message_content)


class LangChainMemoryBasedChatDataBackingStore(InMemoryChatDataBackingStore):
no_output_message: str = '##NO_OUTPUT##'

def __init__(self,
memory: BaseChatMemory,
memory_key_getter: Optional[Callable[[BaseChatMemory], str]] = None,
messages: Optional[List[ChatMessage]] = None,
include_timestamp_in_messages: bool = False,
participants: Optional[List[ChatParticipant]] = None):
no_output_message: str = "##NO_OUTPUT##"

def __init__(
self,
memory: BaseChatMemory,
memory_key_getter: Optional[Callable[[BaseChatMemory], str]] = None,
messages: Optional[List[ChatMessage]] = None,
include_timestamp_in_messages: bool = False,
participants: Optional[List[ChatParticipant]] = None,
):
super().__init__(participants=participants)

self.memory = memory
self.include_timestamp_in_messages = include_timestamp_in_messages

if memory_key_getter is None:

def default_memory_key_getter(memory: BaseChatMemory) -> str:
if hasattr(memory, 'memory_key'):
if hasattr(memory, "memory_key"):
return str(memory.memory_key)

return self.memory.output_key or 'history'
return self.memory.output_key or "history"

self.memory_key_getter: Callable[[BaseChatMemory], str] = default_memory_key_getter
else:
Expand All @@ -65,8 +61,11 @@ def get_messages(self) -> List[ChatMessage]:

memory_key = self.memory_key_getter(self.memory)
base_messages = self.memory.load_memory_variables({})[memory_key]
chat_messages = [base_message_to_chat_message(base_message) for base_message in base_messages if
base_message.content != self.no_output_message]
chat_messages = [
base_message_to_chat_message(base_message)
for base_message in base_messages
if base_message.content != self.no_output_message
]

self.memory.return_messages = prev_return_messages

Expand All @@ -75,16 +74,15 @@ def get_messages(self) -> List[ChatMessage]:
def add_message(self, sender_name: str, content: str, timestamp: Optional[datetime.datetime] = None) -> ChatMessage:
message = super().add_message(sender_name=sender_name, content=content)

prefix = ''
prefix = ""
if self.include_timestamp_in_messages:
pretty_datetime = message.timestamp.strftime('%m-%d-%Y %H:%M:%S')
prefix = f'[{pretty_datetime}] '

self.memory.save_context({
"input": f'{prefix}{message.id}. {message.sender_name}: {message.content}'
}, {
'output': self.no_output_message
})
pretty_datetime = message.timestamp.strftime("%m-%d-%Y %H:%M:%S")
prefix = f"[{pretty_datetime}] "

self.memory.save_context(
{"input": f"{prefix}{message.id}. {message.sender_name}: {message.content}"},
{"output": self.no_output_message},
)

return message

Expand Down
Loading

0 comments on commit 10dc8d4

Please sign in to comment.