From ddaf9de169e629ab3c56a76b2228d7f67054ef04 Mon Sep 17 00:00:00 2001 From: Joshua Carroll Date: Sat, 9 Mar 2024 13:42:22 -0800 Subject: [PATCH] community: Fix bug with StreamlitChatMessageHistory (#18834) - **Description:** Fix Streamlit bug which was introduced by https://github.com/langchain-ai/langchain/pull/18250, update integration test - **Issue:** https://github.com/langchain-ai/langchain/issues/18684 - **Dependencies:** None --- .../chat_message_histories/streamlit.py | 6 +-- .../chat_message_histories/test_streamlit.py | 42 ++++++++++++------- 2 files changed, 29 insertions(+), 19 deletions(-) diff --git a/libs/community/langchain_community/chat_message_histories/streamlit.py b/libs/community/langchain_community/chat_message_histories/streamlit.py index 3952bc47ca383..c04e865bf9a99 100644 --- a/libs/community/langchain_community/chat_message_histories/streamlit.py +++ b/libs/community/langchain_community/chat_message_histories/streamlit.py @@ -22,14 +22,13 @@ def __init__(self, key: str = "langchain_messages"): if key not in st.session_state: st.session_state[key] = [] + self._messages = st.session_state[key] self._key = key @property def messages(self) -> List[BaseMessage]: """Retrieve the current list of messages""" - import streamlit as st - - return st.session_state[self._key] + return self._messages @messages.setter def messages(self, value: List[BaseMessage]) -> None: @@ -37,6 +36,7 @@ def messages(self, value: List[BaseMessage]) -> None: import streamlit as st st.session_state[self._key] = value + self._messages = st.session_state[self._key] def add_message(self, message: BaseMessage) -> None: """Add a message to the session memory""" diff --git a/libs/community/tests/integration_tests/chat_message_histories/test_streamlit.py b/libs/community/tests/integration_tests/chat_message_histories/test_streamlit.py index 16a304bf3818b..8b670feb6e3eb 100644 --- a/libs/community/tests/integration_tests/chat_message_histories/test_streamlit.py +++ b/libs/community/tests/integration_tests/chat_message_histories/test_streamlit.py @@ -6,7 +6,7 @@ import streamlit as st from langchain.memory import ConversationBufferMemory from langchain_community.chat_message_histories import StreamlitChatMessageHistory - from langchain_core.messages import message_to_dict + from langchain_core.messages import message_to_dict, BaseMessage message_history = StreamlitChatMessageHistory() memory = ConversationBufferMemory(chat_memory=message_history, return_messages=True) @@ -23,6 +23,15 @@ st.markdown("Cleared!") memory.chat_memory.clear() + # Use message setter + if st.checkbox("Override messages"): + memory.chat_memory.messages = [ + BaseMessage(content="A basic message", type="basic") + ] + st.session_state["langchain_messages"].append( + BaseMessage(content="extra cool message", type="basic") + ) + # Write the output to st.code as a json blob for inspection messages = memory.chat_memory.messages messages_json = json.dumps([message_to_dict(msg) for msg in messages]) @@ -33,32 +42,33 @@ @pytest.mark.requires("streamlit") def test_memory_with_message_store() -> None: try: - from streamlit.testing.script_interactions import InteractiveScriptTests + from streamlit.testing.v1 import AppTest except ModuleNotFoundError: pytest.skip("Incorrect version of Streamlit installed") - test_handler = InteractiveScriptTests() - test_handler.setUp() - try: - sr = test_handler.script_from_string(test_script).run() - except TypeError: - # Earlier version expected 2 arguments - sr = test_handler.script_from_string("memory_test.py", test_script).run() + at = AppTest.from_string(test_script).run(timeout=10) # Initial run should write two messages - messages_json = sr.get("text")[-1].value + messages_json = at.get("text")[-1].value assert "This is me, the AI" in messages_json assert "This is me, the human" in messages_json # Uncheck the initial write, they should persist in session_state - sr = sr.get("checkbox")[0].uncheck().run() - assert sr.get("markdown")[0].value == "Skipped add" - messages_json = sr.get("text")[-1].value + at.get("checkbox")[0].uncheck().run() + assert at.get("markdown")[0].value == "Skipped add" + messages_json = at.get("text")[-1].value assert "This is me, the AI" in messages_json assert "This is me, the human" in messages_json # Clear the message history - sr = sr.get("checkbox")[1].check().run() - assert sr.get("markdown")[1].value == "Cleared!" - messages_json = sr.get("text")[-1].value + at.get("checkbox")[1].check().run() + assert at.get("markdown")[1].value == "Cleared!" + messages_json = at.get("text")[-1].value assert messages_json == "[]" + + # Use message setter + at.get("checkbox")[1].uncheck() + at.get("checkbox")[2].check().run() + messages_json = at.get("text")[-1].value + assert "A basic message" in messages_json + assert "extra cool message" in messages_json