Chat optimization around tokens and keeping important messages in history #393

Open · wants to merge 13 commits into base: development
alphastats/gui/pages/06_LLM.py (83 changes: 71 additions & 12 deletions)
@@ -20,7 +20,7 @@
init_session_state,
sidebar_info,
)
from alphastats.llm.llm_integration import LLMIntegration, Models
from alphastats.llm.llm_integration import LLMIntegration, MessageKeys, Models, Roles
from alphastats.llm.prompts import get_initial_prompt, get_system_message
from alphastats.plots.plot_utils import PlotlyObject

@@ -82,6 +82,14 @@ def llm_config():
else:
st.error(f"Connection to {model_name} failed: {str(error)}")

st.number_input(
"Maximal number of tokens",
value=st.session_state[StateKeys.MAX_TOKENS],
min_value=2000,
max_value=128000, # TODO: set this automatically based on the selected model
key=StateKeys.MAX_TOKENS,
)

if current_model != st.session_state[StateKeys.MODEL_NAME]:
st.rerun(scope="app")
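
A side note on the TODO above: one way the hard-coded upper bound of 128000 could later be derived from the selected model is a small lookup table. This is only a sketch, not part of this diff; the helper name, model names, and context sizes are assumptions.

```python
# Hypothetical helper, not part of this PR: derive the token ceiling for the
# number_input from the selected model instead of hard-coding 128000.
# Model names and context sizes below are illustrative assumptions.
_ASSUMED_CONTEXT_WINDOWS = {
    "gpt-4o": 128_000,
    "gpt-4o-mini": 128_000,
    "llama3.1:70b": 32_000,
}


def get_max_tokens_for_model(model_name: str, default: int = 128_000) -> int:
    """Return the assumed context window for a model, with a safe fallback."""
    return _ASSUMED_CONTEXT_WINDOWS.get(model_name, default)
```

The widget above could then pass `max_value=get_max_tokens_for_model(st.session_state[StateKeys.MODEL_NAME])` instead of the fixed value.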

@@ -216,6 +224,7 @@ def llm_config():
base_url=OLLAMA_BASE_URL,
dataset=st.session_state[StateKeys.DATASET],
genes_of_interest=list(regulated_genes_dict.keys()),
max_tokens=st.session_state[StateKeys.MAX_TOKENS],
)

st.session_state[StateKeys.LLM_INTEGRATION][model_name] = llm_integration
@@ -226,7 +235,7 @@ def llm_config():
)

with st.spinner("Processing initial prompt..."):
llm_integration.chat_completion(initial_prompt)
llm_integration.chat_completion(initial_prompt, pin_message=True)

st.rerun(scope="app")
except AuthenticationError:
@@ -237,7 +246,11 @@ def llm_config():


@st.fragment
def llm_chat(llm_integration: LLMIntegration, show_all: bool = False):
def llm_chat(
llm_integration: LLMIntegration,
show_all: bool = False,
show_individual_tokens: bool = False,
):
"""The chat interface for the LLM analysis."""

# TODO dump to file -> static file name, plus button to do so
@@ -246,10 +259,30 @@ def llm_chat(llm_integration: LLMIntegration, show_all: bool = False):
# Alternatively write it all in one pdf report using e.g. pdfrw and reportlab (I have code for that combo).

# no. tokens spent
total_tokens = 0
pinned_tokens = 0
for message in llm_integration.get_print_view(show_all=show_all):
with st.chat_message(message["role"]):
st.markdown(message["content"])
for artifact in message["artifacts"]:
with st.chat_message(message[MessageKeys.ROLE]):
st.markdown(message[MessageKeys.CONTENT])
tokens = llm_integration.estimate_tokens([message])
if message[MessageKeys.IN_CONTEXT]:
total_tokens += tokens
if message[MessageKeys.PINNED]:
pinned_tokens += tokens
if (
message[MessageKeys.PINNED]
or not message[MessageKeys.IN_CONTEXT]
or show_individual_tokens
):
token_message = ""
if message[MessageKeys.PINNED]:
token_message += ":pushpin: "
if not message[MessageKeys.IN_CONTEXT]:
token_message += ":x: "
if show_individual_tokens:
token_message += f"*tokens: {tokens}*"
st.markdown(token_message)
for artifact in message[MessageKeys.ARTIFACTS]:
if isinstance(artifact, pd.DataFrame):
st.dataframe(artifact)
elif isinstance(
@@ -260,7 +293,17 @@ def llm_chat(llm_integration: LLMIntegration, show_all: bool = False):
st.warning("Don't know how to display artifact:")
st.write(artifact)

st.markdown(
f"*total tokens used: {total_tokens}, tokens used for pinned messages: {pinned_tokens}*"
)

if prompt := st.chat_input("Say something"):
with st.chat_message(Roles.USER):
st.markdown(prompt)
if show_individual_tokens:
st.markdown(
f"*tokens: {llm_integration.estimate_tokens([{MessageKeys.CONTENT: prompt}])}*"
)
with st.spinner("Processing prompt..."):
llm_integration.chat_completion(prompt)
st.rerun(scope="fragment")
@@ -272,11 +315,27 @@ def llm_chat(llm_integration: LLMIntegration, show_all: bool = False):
"text/plain",
)

st.markdown(
"*icons: :pushpin: pinned message, :x: message no longer in context due to token limitations*"
)

show_all = st.checkbox(
"Show system messages",
key="show_system_messages",
help="Show all messages in the chat interface.",
)

llm_chat(st.session_state[StateKeys.LLM_INTEGRATION][model_name], show_all)
c1, c2 = st.columns((1, 2))
with c1:
show_all = st.checkbox(
"Show system messages",
key="show_system_messages",
help="Show all messages in the chat interface.",
)
with c2:
show_individual_tokens = st.checkbox(
"Show individual token estimates",
key="show_individual_tokens",
help="Show individual token estimates for each message.",
)

llm_chat(
st.session_state[StateKeys.LLM_INTEGRATION][model_name],
show_all,
show_individual_tokens,
)
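
The chat code above relies on `LLMIntegration.estimate_tokens` and on the `MessageKeys.PINNED` / `MessageKeys.IN_CONTEXT` flags, none of which are shown in this diff. Below is a minimal sketch of how that side could work, assuming tiktoken for counting and plain string keys standing in for the `MessageKeys` constants; the actual implementation in `alphastats.llm.llm_integration` may differ.

```python
# Sketch only: rough token counting and a pinned-aware context pass.
from typing import Any

import tiktoken


def estimate_tokens(messages: list[dict[str, Any]], model: str = "gpt-4o") -> int:
    """Roughly count the tokens contained in a list of message dicts."""
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        encoding = tiktoken.get_encoding("cl100k_base")
    return sum(len(encoding.encode(m.get("content", ""))) for m in messages)


def mark_in_context(messages: list[dict[str, Any]], max_tokens: int) -> None:
    """Flag which messages still fit the token budget.

    Pinned messages are always kept in context; the remaining budget is filled
    with the most recent unpinned messages, and older ones are flagged as out
    of context (rendered with :x: in the chat above).
    """
    budget = max_tokens - sum(
        estimate_tokens([m]) for m in messages if m.get("pinned")
    )
    for message in reversed(messages):  # walk from newest to oldest
        if message.get("pinned"):
            message["in_context"] = True
            continue
        cost = estimate_tokens([message])
        if cost <= budget:
            message["in_context"] = True
            budget -= cost
        else:
            message["in_context"] = False
            budget = 0  # once one message is dropped, drop everything older
```

Under these assumptions, `chat_completion(initial_prompt, pin_message=True)` would simply set the pinned flag on that message before storing it, so the initial prompt survives truncation.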
alphastats/gui/utils/ui_helper.py (4 changes: 4 additions & 0 deletions)
@@ -136,6 +136,9 @@ def init_session_state() -> None:
DefaultStates.SELECTED_UNIPROT_FIELDS.copy()
)

if StateKeys.MAX_TOKENS not in st.session_state:
st.session_state[StateKeys.MAX_TOKENS] = 10000


class StateKeys(metaclass=ConstantsClass):
USER_SESSION_ID = "user_session_id"
@@ -152,5 +155,6 @@ class StateKeys(metaclass=ConstantsClass):
LLM_INTEGRATION = "llm_integration"
ANNOTATION_STORE = "annotation_store"
SELECTED_UNIPROT_FIELDS = "selected_uniprot_fields"
MAX_TOKENS = "max_tokens"

ORGANISM = "organism" # TODO this is essentially a constant
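
The `MessageKeys` and `Roles` constants imported by 06_LLM.py are likewise not part of this diff. Assuming they follow the same `ConstantsClass` pattern as `StateKeys` above, they would look roughly like this (hypothetical; the real definitions live in `alphastats.llm.llm_integration`):

```python
# Hypothetical mirror of the constants used on the LLM page; the actual
# names and values in alphastats.llm.llm_integration may differ.
class MessageKeys(metaclass=ConstantsClass):
    ROLE = "role"
    CONTENT = "content"
    PINNED = "pinned"
    IN_CONTEXT = "in_context"
    ARTIFACTS = "artifacts"


class Roles(metaclass=ConstantsClass):
    USER = "user"
    SYSTEM = "system"
    ASSISTANT = "assistant"
    TOOL = "tool"
```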