
Chat optimization around tokens and keeping important messages in history #393

Merged · 15 commits · Jan 22, 2025
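This PR lets the user cap the token budget for the chat history and pins key messages (such as the initial prompt) so they are never dropped from context. A minimal sketch of the pin-then-trim policy the diff implies — the function name, dictionary keys, and trimming order here are assumptions for illustration, not code from this PR:

```python
from typing import Any, Callable, Dict, List


def flag_in_context(
    messages: List[Dict[str, Any]],
    budget: int,
    estimate: Callable[[Dict[str, Any]], int],
) -> List[Dict[str, Any]]:
    """Hypothetical trimming policy: pinned messages always stay in context;
    the leftover budget is filled with the most recent unpinned messages."""
    remaining = budget - sum(estimate(m) for m in messages if m.get("pinned"))
    for message in reversed(messages):  # walk from newest to oldest
        if message.get("pinned"):
            message["in_context"] = True
        elif estimate(message) <= remaining:
            message["in_context"] = True
            remaining -= estimate(message)
        else:
            message["in_context"] = False  # surfaced in the UI as "no longer in context"
    return messages
```

The GUI changes below then render each message, show a pushpin for pinned ones, and flag anything whose in-context marker is false.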
76 changes: 64 additions & 12 deletions alphastats/gui/pages/06_LLM.py
@@ -20,7 +20,7 @@
init_session_state,
sidebar_info,
)
from alphastats.llm.llm_integration import LLMIntegration, Models
from alphastats.llm.llm_integration import LLMIntegration, MessageKeys, Models, Roles
from alphastats.llm.prompts import get_initial_prompt, get_system_message
from alphastats.plots.plot_utils import PlotlyObject

@@ -82,6 +82,14 @@ def llm_config():
else:
st.error(f"Connection to {model_name} failed: {str(error)}")

st.number_input(
"Maximal number of tokens",
value=st.session_state[StateKeys.MAX_TOKENS],
min_value=1,
max_value=100000,
key=StateKeys.MAX_TOKENS,
)

if current_model != st.session_state[StateKeys.MODEL_NAME]:
st.rerun(scope="app")

@@ -216,6 +224,7 @@ def llm_config():
base_url=OLLAMA_BASE_URL,
dataset=st.session_state[StateKeys.DATASET],
genes_of_interest=list(regulated_genes_dict.keys()),
max_tokens=st.session_state[StateKeys.MAX_TOKENS],
)

st.session_state[StateKeys.LLM_INTEGRATION][model_name] = llm_integration
@@ -226,7 +235,7 @@ def llm_config():
)

with st.spinner("Processing initial prompt..."):
llm_integration.chat_completion(initial_prompt)
llm_integration.chat_completion(initial_prompt, pinned=True)

st.rerun(scope="app")
except AuthenticationError:
@@ -237,7 +246,11 @@ def llm_config():


@st.fragment
def llm_chat(llm_integration: LLMIntegration, show_all: bool = False):
def llm_chat(
llm_integration: LLMIntegration,
show_all: bool = False,
show_individual_tokens: bool = False,
):
"""The chat interface for the LLM analysis."""

# TODO dump to file -> static file name, plus button to do so
@@ -246,10 +259,28 @@ def llm_chat(llm_integration: LLMIntegration, show_all: bool = False):
# Alternatively write it all in one pdf report using e.g. pdfrw and reportlab (I have code for that combo).

# no. tokens spent
total_tokens = 0
pinned_tokens = 0
for message in llm_integration.get_print_view(show_all=show_all):
with st.chat_message(message["role"]):
st.markdown(message["content"])
for artifact in message["artifacts"]:
with st.chat_message(message[MessageKeys.ROLE]):
st.markdown(message[MessageKeys.CONTENT])
tokens = llm_integration.estimate_tokens([message])
if message[MessageKeys.IN_CONTEXT]:
total_tokens += tokens
if message[MessageKeys.PINNED]:
pinned_tokens += tokens
if message[MessageKeys.PINNED] or show_individual_tokens:
token_message = ""
if message[MessageKeys.PINNED]:
token_message += ":pushpin: "
if show_individual_tokens:
token_message += f"*estimated tokens: {str(tokens)}*"
st.markdown(token_message)
if not message[MessageKeys.IN_CONTEXT]:
st.markdown(
"**This message is no longer in context due to token limitations.**"
)
for artifact in message[MessageKeys.ARTIFACTS]:
if isinstance(artifact, pd.DataFrame):
st.dataframe(artifact)
elif isinstance(
@@ -260,7 +291,16 @@ def llm_chat(llm_integration: LLMIntegration, show_all: bool = False):
st.warning("Don't know how to display artifact:")
st.write(artifact)

st.markdown(
f"*total tokens used: {str(total_tokens)}, tokens used for pinned messages: {str(pinned_tokens)}*"
)

if prompt := st.chat_input("Say something"):
with st.chat_message(Roles.USER):
st.markdown(prompt)
st.markdown(
f"*estimated tokens: {str(llm_integration.estimate_tokens([{MessageKeys.CONTENT:prompt}]))}*"
)
with st.spinner("Processing prompt..."):
llm_integration.chat_completion(prompt)
st.rerun(scope="fragment")
@@ -273,10 +313,22 @@ def llm_chat(llm_integration: LLMIntegration, show_all: bool = False):
)


show_all = st.checkbox(
"Show system messages",
key="show_system_messages",
help="Show all messages in the chat interface.",
)
c1, c2 = st.columns((1, 2))
with c1:
show_all = st.checkbox(
"Show system messages",
key="show_system_messages",
help="Show all messages in the chat interface.",
)
with c2:
show_individual_tokens = st.checkbox(
"Show individual token estimates",
key="show_individual_tokens",
help="Show individual token estimates for each message.",
)

llm_chat(st.session_state[StateKeys.LLM_INTEGRATION][model_name], show_all)
llm_chat(
st.session_state[StateKeys.LLM_INTEGRATION][model_name],
show_all,
show_individual_tokens,
)
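The chat code above calls `llm_integration.estimate_tokens([message])` and reads `MessageKeys.ROLE`, `CONTENT`, `PINNED`, `IN_CONTEXT`, and `ARTIFACTS`, none of which are shown in this diff. A rough sketch of what that integration side could look like — the constant values and the character-based heuristic are assumptions, and a tokenizer-backed count (e.g. tiktoken for OpenAI models) would be more accurate:

```python
from typing import Any, Dict, List


class MessageKeys:
    """Assumed keys of the message dictionaries returned by get_print_view()."""
    ROLE = "role"
    CONTENT = "content"
    PINNED = "pinned"
    IN_CONTEXT = "in_context"
    ARTIFACTS = "artifacts"


def estimate_tokens(messages: List[Dict[str, Any]], chars_per_token: float = 4.0) -> int:
    """Crude estimate: roughly four characters of English prose per token."""
    total_chars = sum(len(m.get(MessageKeys.CONTENT, "")) for m in messages)
    return int(total_chars / chars_per_token)
```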
4 changes: 4 additions & 0 deletions alphastats/gui/utils/ui_helper.py
@@ -136,6 +136,9 @@ def init_session_state() -> None:
DefaultStates.SELECTED_UNIPROT_FIELDS.copy()
)

if StateKeys.MAX_TOKENS not in st.session_state:
st.session_state[StateKeys.MAX_TOKENS] = 10000


class StateKeys(metaclass=ConstantsClass):
USER_SESSION_ID = "user_session_id"
@@ -152,5 +155,6 @@ class StateKeys(metaclass=ConstantsClass):
LLM_INTEGRATION = "llm_integration"
ANNOTATION_STORE = "annotation_store"
SELECTED_UNIPROT_FIELDS = "selected_uniprot_fields"
MAX_TOKENS = "max_tokens"

ORGANISM = "organism" # TODO this is essentially a constant
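Seeding the default in `init_session_state` is what makes the `st.number_input` on the LLM page start at 10000 and keep the user's edits across reruns: the widget is bound to the same session-state slot through its `key` argument. A stripped-down illustration of that Streamlit pattern, using a stand-in key rather than the app's `StateKeys`:

```python
import streamlit as st

MAX_TOKENS = "max_tokens"  # stand-in for StateKeys.MAX_TOKENS

# Seed the default once; on later reruns the stored (possibly user-edited) value wins.
if MAX_TOKENS not in st.session_state:
    st.session_state[MAX_TOKENS] = 10000

# Bound via `key`, the widget displays and updates st.session_state[MAX_TOKENS] directly.
st.number_input("Maximal number of tokens", min_value=1, max_value=100000, key=MAX_TOKENS)

st.write(f"Current token budget: {st.session_state[MAX_TOKENS]}")
```

Passing only `key=` here sidesteps the warning Streamlit can emit when a keyed widget also receives an explicit `value=` while its state is set through the Session State API.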