From 3a16dde693764e71c216a45efc30870ad3b34c0b Mon Sep 17 00:00:00 2001 From: Vidminas Mikucionis <5411598+Vidminas@users.noreply.github.com> Date: Wed, 24 Jan 2024 19:53:01 +0000 Subject: [PATCH] Implement basic Solid message history --- chatdocs/memory/__init__.py | 0 chatdocs/memory/solid_message_history.py | 241 +++++++++++++++++++++++ chatdocs/ui.py | 44 +++-- setup.py | 3 + 4 files changed, 276 insertions(+), 12 deletions(-) create mode 100644 chatdocs/memory/__init__.py create mode 100644 chatdocs/memory/solid_message_history.py diff --git a/chatdocs/memory/__init__.py b/chatdocs/memory/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/chatdocs/memory/solid_message_history.py b/chatdocs/memory/solid_message_history.py new file mode 100644 index 0000000..8c5cd64 --- /dev/null +++ b/chatdocs/memory/solid_message_history.py @@ -0,0 +1,241 @@ +import requests + +from langchain_core.chat_history import BaseChatMessageHistory +from langchain_core.messages import BaseMessage + + +class CssAccount: + def __init__( + self, + css_base_url: str, + name: str, + email: str, + password: str, + web_id: str, + pod_base_url: str, + ) -> None: + self.css_base_url = css_base_url + self.name = name + self.email = email + self.password = password + self.web_id = web_id + self.pod_base_url = pod_base_url + + +class ClientCredentials: + def __init__(self, client_id: str, client_secret: str) -> None: + self.client_id = client_id + self.client_secret = client_secret + + +def create_css_account( + css_base_url: str, name: str, email: str, password: str +) -> CssAccount: + register_endpoint = f"{css_base_url}/idp/register/" + + res = requests.post( + register_endpoint, + json={ + "createWebId": "on", + "webId": "", + "register": "on", + "createPod": "on", + "podName": name, + "email": email, + "password": password, + "confirmPassword": password, + }, + timeout=5000, + ) + + if not res.ok: + raise Exception(f"Could not create account: {res.status_code} {res.text}") + + data = res.json() + account = CssAccount( + css_base_url=css_base_url, + name=name, + email=email, + password=password, + web_id=data["webId"], + pod_base_url=data["podBaseUrl"], + ) + return account + + +def get_client_credentials(account: CssAccount) -> ClientCredentials: + credentials_endpoint = f"{account.css_base_url}/idp/credentials/" + + res = requests.post( + credentials_endpoint, + json={ + "name": "chatdocs-client-credentials", + "email": account.email, + "password": account.password, + }, + timeout=5000, + ) + + if not res.ok: + raise Exception( + f"Could not create client credentials: {res.status_code} {res.text}" + ) + + data = res.json() + return ClientCredentials(client_id=data["id"], client_secret=data["secret"]) + +def get_item_name(url) -> str: + if url[-1] == '/': + url = url[:-1] + + if url.count('/') == 2: # is base url, no item name + return '' + + i = url.rindex('/') + return url[i + 1:] + +class SolidChatMessageHistory(BaseChatMessageHistory): + """ + Chat message history that stores messages in a Solid pod. + + Args: + solid_server_url: A Community Solid Server base url. + """ + + def __init__(self, solid_server_url, account): + try: + from solid_client_credentials import SolidClientCredentialsAuth, DpopTokenProvider + except ImportError as e: + raise ImportError( + "Unable to import solid_client_credentials, please run `pip install SolidClientCredentials`." + ) from e + try: + from rdflib import Graph + except ImportError as e: + raise ImportError( + "Unable to import rdflib, please run `pip install rdflib`." + ) from e + + self.account = account + client_credentials = get_client_credentials(account) + token_provider = DpopTokenProvider( + issuer_url=solid_server_url, + client_id=client_credentials.client_id, + client_secret=client_credentials.client_secret + ) + self.session = requests.Session() + self.session.auth = SolidClientCredentialsAuth(token_provider) + self.graph = Graph() + + def is_item_available(self, url) -> bool: + try: + res = self.session.head(url, allow_redirects=True) + return res.ok + except requests.exceptions.ConnectionError: + return False + + def create_item(self, url: str) -> bool: + res = self.session.put(url, + data=None, + headers={ + "Accept": 'text/turtle', + "If-None-Match": "*", + 'Link': '; rel="type"' + if url.endswith("/") else + '; rel="type"', + 'Slug': get_item_name(url), + 'Content-Type': 'text/turtle', + }) + return res.ok + + @property + def messages(self) -> list[BaseMessage]: + """Retrieve the current list of messages""" + if not self.is_item_available(f"{self.account.pod_base_url}private/"): + self.create_item(f"{self.account.pod_base_url}private/") + if not self.is_item_available(f"{self.account.pod_base_url}private/chatdocs.ttl"): + self.create_item(f"{self.account.pod_base_url}private/chatdocs.ttl") + + res = self.session.get(f"{self.account.pod_base_url}private/chatdocs.ttl") + if not res.ok: + print("getting messages failed", res.text) + msgs = [] + else: + from rdflib.namespace import PROF, RDF + from rdflib.collection import Collection + + self.graph.parse(data=res.text, publicID=f"{self.account.pod_base_url}private/chatdocs.ttl") + list_node = self.graph.value(predicate=RDF.type, object=RDF.List) + if list_node is None: + return [] + + rdf_list = Collection(self.graph, list_node) + msgs = [BaseMessage( + content=self.graph.value(subject=msg, predicate=PROF.hasResource).toPython(), + type=self.graph.value(subject=msg, predicate=PROF.hasRole).toPython() + ) for msg in rdf_list] + return msgs + + def add_message(self, message: BaseMessage) -> None: + """Add a message to the session memory""" + # https://solidproject.org/TR/protocol#n3-patch seems to be broken with Community Solid Server + # https://www.w3.org/TR/sparql11-update/ works + from rdflib import Graph + from rdflib.term import Node, BNode, URIRef, Literal + from rdflib.namespace import RDF, PROF, XSD + from rdflib.collection import Collection + + update_graph = Graph() + + msg = BNode() + update_graph.add((msg, RDF.type, PROF.ResourceDescriptor)) + update_graph.add((msg, PROF.hasResource, Literal(message.content, datatype=XSD.string))) + update_graph.add((msg, PROF.hasRole, Literal(message.type, datatype=XSD.string))) + + list_node = self.graph.value(predicate=RDF.type, object=RDF.List) + if list_node is None: + msgs_node = URIRef(f"{self.account.pod_base_url}private/chatdocs.ttl#messages") + update_graph.add((msgs_node, RDF.type, RDF.List)) + + msgs = Collection(update_graph, msgs_node) + msgs.append(msg) + + triples = "\n".join([ + f"{subject.n3()} {predicate.n3()} {object.n3()} ." + for subject, predicate, object in update_graph + ]) + sparql = f"INSERT DATA {{{triples}}}" + else: + new_item = BNode() + update_graph.add((new_item, RDF.first, msg)) + update_graph.add((new_item, RDF.rest, RDF.nil)) + + triples = "\n".join([ + f"{subject.n3()} {predicate.n3()} {object.n3()} ." + for subject, predicate, object in update_graph + ]) + sparql = f""" + PREFIX rdf: + DELETE {{ ?end rdf:rest rdf:nil }} + INSERT {{ ?end rdf:rest {new_item.n3()} .\n + {triples} }} + WHERE {{ ?end rdf:rest rdf:nil }} + """ + + # Update remote copy + self.session.patch( + url=f"{self.account.pod_base_url}private/chatdocs.ttl", + data=sparql.encode("utf-8"), + headers={ + "Content-Type": "application/sparql-update", + } + ) + # Update local copy + self.graph.update(sparql) + + def clear(self) -> None: + """Clear session memory""" + from rdflib import Graph + + self.session.delete(f"{self.account.pod_base_url}private/chatdocs.ttl") + self.graph = Graph() diff --git a/chatdocs/ui.py b/chatdocs/ui.py index 39e5222..82143d5 100644 --- a/chatdocs/ui.py +++ b/chatdocs/ui.py @@ -1,8 +1,9 @@ import argparse from typing import Any -from uuid import UUID +from uuid import UUID, uuid4 import langchain # unused but needed to avoid circular import errors +from langchain_core.chat_history import BaseChatMessageHistory from langchain.memory.chat_message_histories import StreamlitChatMessageHistory from langchain.callbacks import StreamingStdOutCallbackHandler from langchain.callbacks.base import BaseCallbackHandler @@ -21,6 +22,7 @@ from .chains import make_conversation_chain from .st_utils import load_config +from memory.solid_message_history import SolidChatMessageHistory, CssAccount, create_css_account class StreamHandler(BaseCallbackHandler): @@ -81,19 +83,19 @@ def on_retriever_end(self, documents, **kwargs): self.status.update(state="complete") -def init_messages(msgs: StreamlitChatMessageHistory) -> None: +def init_messages(history: BaseChatMessageHistory) -> None: clear_button = st.sidebar.button("Clear Conversation", key="clear") - if clear_button or len(msgs.messages) == 0: - msgs.clear() + if clear_button or len(history.messages) == 0: + history.clear() -def print_state_messages(msgs: StreamlitChatMessageHistory): +def print_state_messages(history: BaseChatMessageHistory): roles = { "human": "user", "ai": "assistant", } - for message in msgs.messages: + for message in history.messages: with st.chat_message(roles[message.type]): st.markdown(message.content) @@ -103,6 +105,17 @@ def load_llm(config, selected_llm): return make_conversation_chain(config, selected_llm_index=selected_llm) +@st.cache_data +def create_random_solid_account(css_base_url: str) -> CssAccount: + name = f"test-{uuid4()}" + email = f"{name}@example.org" + password = "12345" + + return create_css_account( + css_base_url=css_base_url, name=name, email=email, password=password + ) + + def main(): parser = argparse.ArgumentParser() parser.add_argument( @@ -119,9 +132,16 @@ def main(): st.title("ChatDocs") st.sidebar.title("Options") - msgs = StreamlitChatMessageHistory(key="messages") - init_messages(msgs) - print_state_messages(msgs) + solid_server_url = st.sidebar.text_input("Solid Server URL", "https://localhost:1234/") + css_account = create_random_solid_account(solid_server_url) + + history_type = st.sidebar.radio("Message history", ("Local", "Solid")) + if history_type == "Local": + history = StreamlitChatMessageHistory(key="messages") + else: + history = SolidChatMessageHistory(solid_server_url, css_account) + init_messages(history) + print_state_messages(history) config = load_config() selected_llm = st.sidebar.radio("LLM", range(len(config["llms"])), format_func=lambda idx: config["llms"][idx]["model"]) @@ -130,16 +150,16 @@ def main(): if prompt := st.chat_input("Enter a query"): with st.chat_message("user"): st.markdown(prompt) - msgs.add_user_message(prompt) + history.add_user_message(prompt) retrieve_callback = PrintRetrievalHandler(st.container()) print_callback = StreamHandler(st.empty()) stdout_callback = StreamingStdOutCallbackHandler() response = llm( - { "question": prompt, "chat_history": msgs.messages }, + { "question": prompt, "chat_history": history.messages }, callbacks=[retrieve_callback, print_callback, stdout_callback], ) - msgs.add_ai_message(response["answer"]) + history.add_ai_message(response["answer"]) if __name__ == "__main__": diff --git a/setup.py b/setup.py index b6cc67f..5f73ade 100644 --- a/setup.py +++ b/setup.py @@ -38,6 +38,9 @@ "transformers>=4.35.0", "typer>=0.9.0", "typing-extensions>=4.4.0,<5.0.0", + # Solid + "SolidClientCredentials>=1.0.0", + "rdflib>=7.0.0", # UI "streamlit>=1.29.0", "plotly>=5.17.0",