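# Overview: a Streamlit front end for a local RAG workflow. Uploaded PDFs are chunked into
# vault.txt, each chunk is embedded with Ollama's mxbai-embed-large model, and user queries
# are answered by llama3 through Ollama's OpenAI-compatible API, with top-k cosine-similarity
# retrieval pulling relevant chunks into the prompt.
#
# Assumed usage (not stated in the file itself): `streamlit run server.py`, with a local
# Ollama server listening on localhost:11434 and the llama3 and mxbai-embed-large models
# already pulled.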
import os
import re
import json
import torch
import ollama
import PyPDF2
import streamlit as st
from openai import OpenAI
from streamlit_chat import message
from io import StringIO
import subprocess
# Function to convert PDF to text and append to vault.txt
def convert_pdf_to_text(file):
    if file:
        pdf_reader = PyPDF2.PdfReader(file)
        num_pages = len(pdf_reader.pages)
        text = ''
        for page_num in range(num_pages):
            page = pdf_reader.pages[page_num]
            if page.extract_text():
                text += page.extract_text() + " "

        # Normalize whitespace and clean up text
        text = re.sub(r'\s+', ' ', text).strip()

        # Split text into chunks by sentences, respecting a maximum chunk size
        sentences = re.split(r'(?<=[.!?]) +', text)  # split on spaces following sentence-ending punctuation
        chunks = []
        current_chunk = ""
        for sentence in sentences:
            # Check if the current sentence plus the current chunk exceeds the limit
            if len(current_chunk) + len(sentence) + 1 < 1000:  # +1 for the space
                current_chunk += (sentence + " ").strip()
            else:
                # When the chunk exceeds 1000 characters, store it and start a new one
                chunks.append(current_chunk)
                current_chunk = sentence + " "
        if current_chunk:  # Don't forget the last chunk!
            chunks.append(current_chunk)

        with open("vault.txt", "a", encoding="utf-8") as vault_file:
            for chunk in chunks:
                # Write each chunk to its own line
                vault_file.write(chunk.strip() + "\n")
        return "PDF content appended to vault.txt with each chunk on a separate line."
    return "No file selected."
# Function to get relevant context from the vault based on user input
def get_relevant_context(rewritten_input, vault_embeddings, vault_content, top_k=3):
    if vault_embeddings.nelement() == 0:  # Check if the tensor has any elements
        return []
    # Encode the rewritten input
    input_embedding = ollama.embeddings(model='mxbai-embed-large', prompt=rewritten_input)["embedding"]
    # Compute cosine similarity between the input and vault embeddings
    cos_scores = torch.cosine_similarity(torch.tensor(input_embedding).unsqueeze(0), vault_embeddings)
    # Adjust top_k if it's greater than the number of available scores
    top_k = min(top_k, len(cos_scores))
    # Sort the scores and get the top-k indices
    top_indices = torch.topk(cos_scores, k=top_k)[1].tolist()
    # Get the corresponding context from the vault
    relevant_context = [vault_content[idx].strip() for idx in top_indices]
    return relevant_context
# Function to rewrite the user's query using recent conversation history, to improve retrieval
def rewrite_query(user_input_json, conversation_history, ollama_model):
    user_input = json.loads(user_input_json)["Query"]
    context = "\n".join([f"{msg['role']}: {msg['content']}" for msg in conversation_history[-2:]])
    prompt = f"""Rewrite the following query by incorporating relevant context from the conversation history.
The rewritten query should:
- Preserve the core intent and meaning of the original query
- Expand and clarify the query to make it more specific and informative for retrieving relevant context
- Avoid introducing new topics or queries that deviate from the original query
- DON'T EVER ANSWER the Original query, but instead focus on rephrasing and expanding it into a new query

Return ONLY the rewritten query text, without any additional formatting or explanations.

Conversation History:
{context}

Original query: [{user_input}]

Rewritten query:
"""
    response = client.chat.completions.create(
        model=ollama_model,
        messages=[{"role": "system", "content": prompt}],
        max_tokens=200,
        n=1,
        temperature=0.1,
    )
    rewritten_query = response.choices[0].message.content.strip()
    return json.dumps({"Rewritten Query": rewritten_query})
# Main chat function: rewrite the query when there is history, retrieve relevant vault context, and query the model
def ollama_chat(user_input, system_message, vault_embeddings, vault_content, ollama_model, conversation_history):
    conversation_history.append({"role": "user", "content": user_input})

    if len(conversation_history) > 1:
        query_json = {
            "Query": user_input,
            "Rewritten Query": ""
        }
        rewritten_query_json = rewrite_query(json.dumps(query_json), conversation_history, ollama_model)
        rewritten_query_data = json.loads(rewritten_query_json)
        rewritten_query = rewritten_query_data["Rewritten Query"]
    else:
        rewritten_query = user_input

    relevant_context = get_relevant_context(rewritten_query, vault_embeddings, vault_content)
    if relevant_context:
        context_str = "\n".join(relevant_context)
    else:
        context_str = "No relevant context found."

    # Append the retrieved context to the latest user message before calling the model
    user_input_with_context = user_input
    if relevant_context:
        user_input_with_context = user_input + "\n\nRelevant Context:\n" + context_str
    conversation_history[-1]["content"] = user_input_with_context

    messages = [
        {"role": "system", "content": system_message},
        *conversation_history
    ]
    response = client.chat.completions.create(
        model=ollama_model,
        messages=messages,
        max_tokens=2000,
    )
    conversation_history.append({"role": "assistant", "content": response.choices[0].message.content})
    return response.choices[0].message.content
# Configuration for the Ollama API client (OpenAI-compatible endpoint)
client = OpenAI(
    base_url='http://localhost:11434/v1',
    api_key='llama3'  # any non-empty string works; Ollama does not validate the API key
)

# Load the vault content
vault_content = []
if os.path.exists("vault.txt"):
    with open("vault.txt", "r", encoding='utf-8') as vault_file:
        vault_content = vault_file.readlines()

# Generate embeddings for the vault content using Ollama
vault_embeddings = []
for content in vault_content:
    response = ollama.embeddings(model='mxbai-embed-large', prompt=content)
    vault_embeddings.append(response["embedding"])

# Convert to tensor
vault_embeddings_tensor = torch.tensor(vault_embeddings)
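# Note (optional sketch, not part of the original flow): Streamlit re-executes this script on
# every interaction, so the vault is re-embedded on each rerun. Wrapping the embedding step in
# a cached helper avoids that; `embed_vault` is a hypothetical name introduced here.
#
#     @st.cache_data
#     def embed_vault(lines):
#         return torch.tensor([
#             ollama.embeddings(model='mxbai-embed-large', prompt=line)["embedding"]
#             for line in lines
#         ])
#
#     vault_embeddings_tensor = embed_vault(tuple(vault_content))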
# Conversation history, kept in Streamlit session state so it survives reruns
if "conversation_history" not in st.session_state:
    st.session_state.conversation_history = []
conversation_history = st.session_state.conversation_history

system_message = "You are a helpful assistant that is an expert at extracting the most useful information from a given text. Also bring in extra relevant information to the user query from outside the given context."
# Streamlit app
st.title("Document Query System")

if st.button("Pull llama3 Model"):
    try:
        # Pulls the llama3 model; the Ollama server itself must already be running locally
        subprocess.run(["ollama", "pull", "llama3"], check=True)
        st.success("llama3 model pulled successfully.")
    except subprocess.CalledProcessError as e:
        st.error(f"Error pulling model: {e}")

uploaded_file = st.file_uploader("Upload PDF", type="pdf")
if uploaded_file is not None:
    pdf_content = convert_pdf_to_text(uploaded_file)
    st.info(pdf_content)

user_input = st.text_input("Enter your query:")
if st.button("Ask Question"):
    if user_input.strip():
        response = ollama_chat(user_input, system_message, vault_embeddings_tensor, vault_content, "llama3", conversation_history)
        st.text_area("Response", response, height=300)
        st.session_state.conversation_history = conversation_history
    else:
        st.warning("Please enter a query.")

# Display chat history
if 'conversation_history' in st.session_state:
    conversation_history = st.session_state.conversation_history
    for i, chat in enumerate(conversation_history):
        # Explicit keys keep streamlit_chat from generating duplicate widget IDs
        message(chat["content"], is_user=chat["role"] == "user", key=f"chat_{i}")