
Commit

JuliaS92 committed Jan 15, 2025
1 parent c84d8ad commit 77d6e4a
Showing 2 changed files with 23 additions and 18 deletions.
8 changes: 4 additions & 4 deletions alphastats/gui/pages/06_LLM.py
@@ -235,7 +235,7 @@ def llm_config():
)

with st.spinner("Processing initial prompt..."):
- llm_integration.chat_completion(initial_prompt, pinned=True)
+ llm_integration.chat_completion(initial_prompt, pin_message=True)

st.rerun(scope="app")
except AuthenticationError:
@@ -249,7 +249,7 @@ def llm_config():
def llm_chat(
llm_integration: LLMIntegration,
show_all: bool = False,
- show_inidvidual_tokens: bool = False,
+ show_individual_tokens: bool = False,
):
"""The chat interface for the LLM analysis."""

@@ -269,11 +269,11 @@ def llm_chat(
total_tokens += tokens
if message[MessageKeys.PINNED]:
pinned_tokens += tokens
- if message[MessageKeys.PINNED] or show_inidvidual_tokens:
+ if message[MessageKeys.PINNED] or show_individual_tokens:
token_message = ""
if message[MessageKeys.PINNED]:
token_message += ":pushpin: "
- if show_inidvidual_tokens:
+ if show_individual_tokens:
token_message += f"*estimated tokens: {str(tokens)}*"
st.markdown(token_message)
if not message[MessageKeys.IN_CONTEXT]:
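Aside from adopting the `pin_message` rename in the call above, this file fixes the `show_inidvidual_tokens` typo in `llm_chat`. The sketch below isolates the annotation logic that flag controls: pinned messages always get a pushpin, and the estimated token count is shown only when the flag is on. It is a simplified, standalone illustration with an invented function name and plain arguments, not the page code.

```python
def token_annotation(pinned: bool, tokens: int, show_individual_tokens: bool) -> str:
    """Build the markdown annotation rendered under a chat message."""
    annotation = ""
    if pinned:
        annotation += ":pushpin: "
    if show_individual_tokens:
        annotation += f"*estimated tokens: {tokens}*"
    return annotation


# A pinned message with per-message token display enabled:
print(token_annotation(pinned=True, tokens=123, show_individual_tokens=True))
# :pushpin: *estimated tokens: 123*
```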
33 changes: 19 additions & 14 deletions alphastats/llm/llm_integration.py
@@ -199,6 +199,7 @@ def estimate_tokens(
]
)
except KeyError:
+ # if the model is not in the tiktoken library a key error is raised by encoding_for_model, we use a rough estimate instead
total_tokens = sum(
[
len(message[MessageKeys.CONTENT]) / average_chars_per_token
@@ -224,32 +225,36 @@ def _truncate_conversation_history(
# TODO: avoid important messages being removed (e.g. facts about genes)
# TODO: find out how messages can be None type and handle them earlier
total_tokens = self.estimate_tokens(self._messages, average_chars_per_token)
- while total_tokens > self._max_tokens and len(self._messages) > 1:
- oldest_unpinned = -1
+ while total_tokens > self._max_tokens:
+ if len(self._messages) == 1:
+ raise ValueError(
+ "Truncating conversation history failed, as the only remaining message exceeds the token limit. Please increase the token limit and reset the LLM analysis."
+ )
+ oldest_not_pinned = -1
for message_idx, message in enumerate(self._messages):
if not message[MessageKeys.PINNED]:
- oldest_unpinned = message_idx
+ oldest_not_pinned = message_idx
break
- if oldest_unpinned == -1:
+ if oldest_not_pinned == -1:
raise ValueError(
"Truncating conversation history failed, as all remaining messages are pinned. Please increase the token limit and reset the LLM analysis, or unpin messages."
)
- removed_message = self._messages.pop(oldest_unpinned)
+ removed_message = self._messages.pop(oldest_not_pinned)
warnings.warn(
f"Truncating conversation history to stay within token limits.\nRemoved message:{removed_message[MessageKeys.ROLE]}: {removed_message[MessageKeys.CONTENT][0:min(30, len(removed_message[MessageKeys.CONTENT]))]}..."
f"Truncating conversation history to stay within token limits.\nRemoved message:{removed_message[MessageKeys.ROLE]}: {removed_message[MessageKeys.CONTENT][0:30]}..."
)
while (
removed_message[MessageKeys.ROLE] == Roles.ASSISTANT
- and self._messages[oldest_unpinned][MessageKeys.ROLE] == Roles.TOOL
+ and self._messages[oldest_not_pinned][MessageKeys.ROLE] == Roles.TOOL
):
# This is required as the chat completion fails if there are tool outputs without corresponding tool calls in the message history.
- removed_toolmessage = self._messages.pop(oldest_unpinned)
+ removed_toolmessage = self._messages.pop(oldest_not_pinned)
warnings.warn(
f"Removing corresponsing tool output as well.\nRemoved message:{removed_toolmessage[MessageKeys.ROLE]}: {removed_toolmessage[MessageKeys.CONTENT][0:min(30, len(removed_toolmessage[MessageKeys.CONTENT]))]}..."
f"Removing corresponsing tool output as well.\nRemoved message:{removed_toolmessage[MessageKeys.ROLE]}: {removed_toolmessage[MessageKeys.CONTENT][0:30]}..."
)
if len(self._messages) == 0:
raise ValueError(
"Truncating conversation history failed, as the most recent artifacts exceeded the token limit. Please increase the token limit and reset the LLM analysis."
"Truncating conversation history failed, as the artifact from the last call exceeds the token limit. Please increase the token limit and reset the LLM analysis."
)
total_tokens = self.estimate_tokens(self._messages, average_chars_per_token)

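As an aid to reading the hunk above, here is a simplified, self-contained sketch of the truncation policy it implements: repeatedly drop the oldest non-pinned message until the estimate fits, fail loudly if only one message remains or everything left is pinned, and remove tool outputs together with the assistant tool call they belong to. The plain-dict message shape, key names, and helper signature are assumptions of the sketch; the real method works on `self._messages` with the `MessageKeys` and `Roles` constants and the class's own token estimator.

```python
from typing import Callable, Dict, List


def truncate_history(
    messages: List[Dict],
    estimate_tokens: Callable[[List[Dict]], float],
    max_tokens: int,
) -> List[Dict]:
    """Drop the oldest non-pinned messages until the estimate fits max_tokens."""
    messages = list(messages)
    while estimate_tokens(messages) > max_tokens:
        if len(messages) == 1:
            raise ValueError("The only remaining message exceeds the token limit.")
        # Find the oldest message that is not pinned.
        oldest_not_pinned = next(
            (idx for idx, msg in enumerate(messages) if not msg["pinned"]), -1
        )
        if oldest_not_pinned == -1:
            raise ValueError("All remaining messages are pinned.")
        removed = messages.pop(oldest_not_pinned)
        # Tool outputs without their originating assistant tool call would make
        # the chat completion fail, so remove them together with that call.
        while (
            removed["role"] == "assistant"
            and oldest_not_pinned < len(messages)
            and messages[oldest_not_pinned]["role"] == "tool"
        ):
            messages.pop(oldest_not_pinned)
            if not messages:
                raise ValueError("The artifacts from the last call exceed the token limit.")
    return messages


# Toy example with a character-count estimator: the pinned system prompt
# survives, the oldest unpinned turn is dropped first.
history = [
    {"role": "system", "content": "You are a proteomics assistant.", "pinned": True},
    {"role": "user", "content": "Explain protein X. " * 50, "pinned": False},
    {"role": "assistant", "content": "Protein X is ...", "pinned": False},
]
print(truncate_history(history, lambda msgs: sum(len(m["content"]) for m in msgs), 200))
```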
@@ -407,7 +412,7 @@ def get_chat_log_txt(self) -> str:
return chatlog

def chat_completion(
- self, prompt: str, role: str = Roles.USER, *, pinned=False
+ self, prompt: str, role: str = Roles.USER, *, pin_message=False
) -> None:
"""
Generate a chat completion based on the given prompt and manage any resulting artifacts.
@@ -418,15 +423,15 @@ def chat_completion(
The user's input prompt
role : str, optional
The role of the message sender, by default "user"
- pinned : bool, optional
+ pin_message : bool, optional
Whether the prompt and assistant reply should be pinned, by default False
Returns
-------
Tuple[str, Dict[str, Any]]
A tuple containing the generated response and a dictionary of new artifacts
"""
- self._append_message(role, prompt, pinned=pinned)
+ self._append_message(role, prompt, pinned=pin_message)

try:
response = self._chat_completion_create()
@@ -441,7 +446,7 @@ def chat_completion(

content, _ = self._handle_function_calls(tool_calls)

- self._append_message(Roles.ASSISTANT, content, pinned=pinned)
+ self._append_message(Roles.ASSISTANT, content, pinned=pin_message)

except ArithmeticError as e:
error_message = f"Error in chat completion: {str(e)}"
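Finally, the comment added in the first hunk of llm_integration.py documents a fallback for models that tiktoken does not recognise. Below is a standalone sketch of that strategy, assuming the primary path uses `tiktoken.encoding_for_model` (which raises `KeyError` for unknown model names); the function shown here, its message shape, and the default characters-per-token ratio are illustrative rather than the class's actual implementation.

```python
from typing import Dict, List

import tiktoken


def estimate_tokens(
    messages: List[Dict[str, str]],
    model: str,
    average_chars_per_token: float = 4.0,  # assumed default, not from the repository
) -> float:
    """Estimate the total number of tokens in a list of {"content": ...} messages."""
    contents = [message["content"] for message in messages]
    try:
        # Model-specific tokenizer; raises KeyError for models tiktoken does not know.
        encoding = tiktoken.encoding_for_model(model)
        return float(sum(len(encoding.encode(content)) for content in contents))
    except KeyError:
        # Rough estimate instead: assume a fixed number of characters per token.
        return sum(len(content) / average_chars_per_token for content in contents)


messages = [{"content": "How significant is the fold change of protein P12345?"}]
print(estimate_tokens(messages, "gpt-4o"))          # tokenizer-based count
print(estimate_tokens(messages, "my-local-model"))  # character-based estimate
```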
