-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathmain.py
334 lines (282 loc) · 10.4 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
"""
Main application
"""
from typing import List
import streamlit as st
from scraper.app import execute_scraping
from scraper.config import tasks
from scraper.logging import setup_logging
def main():
    """Build the Streamlit page: styles, session state and the two-pane layout.

    The left pane holds the title, the scraping controls, a collapsed
    "Configuration" expander and a GitHub link. The right pane holds the
    question-answering output. ``setup_logging`` is called last.
    """
    st.set_page_config(layout="wide")

    # Inject styles and seed session state before drawing widgets; the
    # widgets below read session-state flags such as ``scraping_done``.
    load_css()
    initialize_session_state()

    # Three columns are requested; the narrow middle one is never used
    # and only separates the two panes.
    panes = st.columns([1, 0.1, 2.6])
    controls_pane, results_pane = panes[0], panes[2]

    github_link_html = (
        "<a class='github-link' href='https://github.com/arkeodev/scraper' target='_blank'>"
        "<span class='github-icon'></span>View on GitHub</a>"
    )

    with controls_pane:
        display_title()
        display_scraping_ui()
        config_section = st.expander("Configuration", expanded=False)
        with config_section:
            display_config_ui()
        # Blank markdown element used as a vertical spacer above the link.
        st.markdown(" ")
        st.markdown(github_link_html, unsafe_allow_html=True)

    with results_pane:
        display_qa_ui()

    setup_logging()
def load_css():
    """Inject the application's stylesheet into the Streamlit page.

    Reads ``.css/app_styles.css`` (resolved against the current working
    directory) and emits its contents wrapped in a ``<style>`` tag.

    Raises:
        OSError: If the stylesheet cannot be opened or read.
        UnicodeDecodeError: If the stylesheet is not valid UTF-8.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open(".css/app_styles.css", "r", encoding="utf-8") as css_file:
        css = css_file.read()
    st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)
def display_title() -> None:
    """Render the app heading, the logo image and the tagline, in that order."""
    heading_html = "<h1>m o l e</h1>"
    tagline_html = "<h2>AI powered web scraping</h2>"

    st.markdown(heading_html, unsafe_allow_html=True)
    st.image("images/mole.png")
    st.markdown(tagline_html, unsafe_allow_html=True)
def display_scraping_ui() -> None:
    """Render the source picker, source input, task picker and action buttons.

    Writes ``selected_source`` and ``selected_task_index`` to the session
    state. Every input except "Refresh" is disabled while
    ``st.session_state.scraping_done`` is set.
    """
    locked = st.session_state.scraping_done

    # Index the configured tasks by their source definition so the
    # select box value can be mapped back to its task object.
    tasks_by_source = {}
    for task in tasks:
        tasks_by_source[task.source_def] = task

    chosen_source_name = st.selectbox(
        "Select Source",
        options=list(tasks_by_source),
        placeholder="Select source...",
        index=0,
        key="source_key",
        disabled=locked,
    )
    source_task = tasks_by_source[chosen_source_name]
    st.session_state.selected_source = source_task

    # The source is either a URL typed by the user or an uploaded file.
    if not source_task.is_url:
        display_file_uploader(source_task.allowed_extensions)
    else:
        display_url_input()

    # Offer the tasks of the chosen source and remember the chosen position.
    task_names = source_task.task_def
    chosen_task = st.selectbox(
        "Select Task",
        options=task_names,
        placeholder="Select task...",
        index=0,
        key="task_key",
        disabled=locked,
    )
    st.session_state.selected_task_index = task_names.index(chosen_task)

    # Surface any error message stored in the session state.
    error_text = st.session_state.error_mes
    if error_text:
        st.error(f"{error_text}")

    def start_scraping():
        """Run the scraper against the current session state."""
        execute_scraping(st.session_state)

    # Start and Refresh sit side by side.
    start_slot, refresh_slot = st.columns([1, 1], gap="small")
    with start_slot:
        st.button(
            "Start",
            on_click=start_scraping,
            key="start_button",
            disabled=locked,
        )
    with refresh_slot:
        st.button("Refresh", key="refresh_button", on_click=trigger_refresh)
def display_file_uploader(allowed_extensions: List[str]):
    """Show a file picker limited to ``allowed_extensions``.

    The uploader's return value is stored in ``st.session_state.source``.
    The widget is disabled while ``st.session_state.scraping_done`` is set.
    """
    uploaded_file = st.file_uploader(
        "Choose a file to parse",
        type=allowed_extensions,
        disabled=st.session_state.scraping_done,
    )
    st.session_state.source = uploaded_file
def display_url_input():
    """Show the URL text box and store its value in ``st.session_state.source``.

    The widget is disabled while ``st.session_state.scraping_done`` is set.
    """
    entered_url = st.text_input(
        "Enter the URL of the website to scrape:",
        key="url_input",
        placeholder="http://example.com",
        disabled=st.session_state.scraping_done,
    )
    st.session_state.source = entered_url
def display_config_ui() -> None:
    """Render the model-provider picker followed by the provider's settings.

    The chosen provider is stored in ``st.session_state.model_company`` and
    then passed to ``load_model_specific_ui``.
    """
    provider = st.selectbox(
        "Select the Model Company:",
        options=("OpenAI", "Hugging Face"),
        placeholder="Select model company...",
        index=0,
        key="model_company_key",
        disabled=st.session_state.scraping_done,
    )
    st.session_state.model_company = provider

    # The remaining inputs depend on which provider was picked.
    load_model_specific_ui(provider)
def load_model_specific_ui(company_name: str):
    """Render the model, API-key, temperature and max-token inputs.

    For a known provider ("OpenAI" or "Hugging Face") a model select box
    and an API-key field are shown, and their values are stored in
    ``st.session_state.model_name`` and ``st.session_state.api_key``. For
    any other ``company_name`` those two inputs are skipped. The temperature
    slider and max-token input are always shown, and their values are stored
    in ``st.session_state.temperature`` and ``st.session_state.max_tokens``.
    """
    locked = st.session_state.scraping_done

    # Per-provider model choices and the label of the API-key field.
    provider_settings = {
        "OpenAI": (
            ("gpt-4o-mini", "gpt-4o", "gpt-4", "gpt-4-turbo"),
            "OpenAI API Key",
        ),
        "Hugging Face": (
            (
                "mistralai/Mistral-7B-Instruct-v0.3",
                "meta-llama/Meta-Llama-3-8B-Instruct",
            ),
            "Hugging Face API Key",
        ),
    }

    if company_name in provider_settings:
        model_choices, api_key_label = provider_settings[company_name]
        st.session_state.model_name = st.selectbox(
            "Select the Model:",
            options=model_choices,
            placeholder="Select model...",
            index=0,
            key="model_name_key",
            disabled=locked,
        )
        st.session_state.api_key = st.text_input(
            api_key_label,
            type="password",
            key="chatbot_api_key",
            disabled=locked,
        )

    # Generation settings shared by every provider.
    st.session_state.temperature = st.slider(
        "Temperature",
        0.0,
        1.0,
        0.7,
        key="temperature_key",
        disabled=locked,
    )
    st.session_state.max_tokens = st.number_input(
        "Max Output Tokens",
        min_value=1,
        value=1000,
        key="max_tokens_key",
        disabled=locked,
    )
def display_qa_ui() -> None:
    """Render the output pane for the selected task once scraping is done.

    Nothing is drawn until ``st.session_state.scraping_done`` is set. Task
    index 0 shows a chat input and forwards non-empty questions to
    ``handle_submit``; index 1 runs ``handle_summary_submit``; index 2 runs
    ``handle_keypoints_submit``. Any other index draws nothing.
    """
    if not st.session_state.scraping_done:
        return

    task_index = st.session_state.selected_task_index

    if task_index == 0:
        question = st.chat_input("Please ask your questions", key="question_input")
        if question:
            handle_submit(question)
        return

    # The remaining tasks need no user input and run immediately.
    one_shot_handlers = {
        1: handle_summary_submit,
        2: handle_keypoints_submit,
    }
    handler = one_shot_handlers.get(task_index)
    if handler is not None:
        handler()
def handle_submit(user_input: str):
    """Answer a chat question and render the conversation, newest first.

    The question is passed to ``st.session_state.graph.execute``. A falsy
    answer is replaced by a fixed apology. Answer and question are appended
    to ``st.session_state.chat_history``, and the whole history is then
    rendered in reverse order.

    Args:
        user_input: The question typed into the chat input.
    """
    # Function-scope import so this block stays self-contained.
    import html

    with st.spinner("Fetching answer..."):
        graph = st.session_state.graph
        answer = graph.execute(user_input)
        if not answer:
            answer = "I'm sorry, I don't answer this question."
        # Stored answer-first so that the reversed rendering below shows
        # each question directly above its answer.
        st.session_state.chat_history.append(("assistant", answer))
        st.session_state.chat_history.append(("user", user_input))

    # Reverse the list to display the latest message first.
    for role, content in reversed(st.session_state.chat_history):
        # The markup is rendered with unsafe_allow_html=True, so both the
        # user's text and the model's output must be escaped; otherwise
        # any tags they contain would be injected into the page.
        safe_content = html.escape(str(content))
        st.markdown(
            f"<div class='chat-message-{role}'>{safe_content}</div>",
            unsafe_allow_html=True,
        )
        # Blank markdown element used as a spacer between messages.
        st.markdown(" ")
def handle_summary_submit():
    """Generate the document summary and show it in a read-only text area.

    Calls ``st.session_state.graph.execute()`` with no arguments. A falsy
    result is replaced by a fixed apology. The text is stored in
    ``st.session_state.summary_result``, which is also the widget's key.
    """
    with st.spinner("Fetching summary..."):
        graph = st.session_state.graph
        summary = graph.execute()
        if not summary:
            summary = "I'm sorry, I couldn't generate a summary for this document."
        # Must be assigned before the widget with the same key is created.
        st.session_state.summary_result = summary

    st.markdown("<h2>Summary Result</h2>", unsafe_allow_html=True)
    st.text_area(
        # Streamlit discourages empty labels for accessibility reasons;
        # give the widget a real label and hide it instead (the heading
        # above already titles the box).
        "Summary Result",
        value=st.session_state.summary_result,
        height=500,
        key="summary_result",
        disabled=True,
        label_visibility="collapsed",
    )
def handle_keypoints_submit():
    """Extract the document's key points and show them in a read-only text area.

    Calls ``st.session_state.graph.execute()`` with no arguments. A falsy
    result is replaced by a fixed apology. The text is stored in
    ``st.session_state.key_points_result``, which is also the widget's key.
    """
    with st.spinner("Fetching key points..."):
        graph = st.session_state.graph
        key_points = graph.execute()
        if not key_points:
            key_points = "I'm sorry, I couldn't extract key points from this document."
        # Must be assigned before the widget with the same key is created.
        st.session_state.key_points_result = key_points

    st.markdown("<h2>Key Points Result</h2>", unsafe_allow_html=True)
    st.text_area(
        # Streamlit discourages empty labels for accessibility reasons;
        # give the widget a real label and hide it instead (the heading
        # above already titles the box).
        "Key Points Result",
        value=st.session_state.key_points_result,
        height=500,
        key="key_points_result",
        disabled=True,
        label_visibility="collapsed",
    )
def trigger_refresh() -> None:
    """Request a full state reset on the next script run.

    Used as the "Refresh" button's ``on_click`` callback. It only raises
    the ``refresh_triggered`` flag; the ``__main__`` guard checks that flag
    at the start of a run and calls ``clear_state`` when it is set.
    """
    # No explicit st.rerun() here: Streamlit reruns the script on its own
    # once a widget callback returns, and calling st.rerun() from inside a
    # callback is a no-op that only produces a warning.
    st.session_state.refresh_triggered = True
def initialize_session_state() -> None:
    """Seed the session state with defaults, keeping values already present."""
    defaults = {
        "url": "",
        "model_company_key": "OpenAI",
        "model_name_key": "gpt-4o-mini",
        "chatbot_api_key": "",
        "source_key": "URL",
        "task_key": "Chat",
        "temperature_key": 0.7,
        "max_tokens_key": 1000,
        "status": [],
        "graph": None,
        "chat_history": [],
        "scraping_done": False,
        "question_input": "",
        "summary_result": "",
        "key_points_result": "",
        "refresh_triggered": False,
        "selected_task_index": 0,
        "error_mes": "",
    }

    # Only fill in the entries that are absent, so state set on an earlier
    # run is not overwritten.
    missing = {
        name: default
        for name, default in defaults.items()
        if name not in st.session_state
    }
    for name, default in missing.items():
        st.session_state[name] = default
def clear_state() -> None:
    """Wipe the session state, restore the defaults and rerun the script.

    Every existing key is deleted, the default values are written back,
    the ``st.cache_data`` cache is cleared, and ``st.rerun()`` is called
    last to restart the script run.
    """
    # Snapshot the keys first; deleting while iterating the live view
    # would mutate the mapping mid-iteration.
    for key in list(st.session_state.keys()):
        del st.session_state[key]

    reset_values = {
        "status": [],
        "url_input": "",
        "question_input": "",
        "chat_history": [],
        "model_company_key": "OpenAI",
        "model_name_key": "gpt-4o-mini",
        "chatbot_api_key": "",
        "source_key": "URL",
        "task_key": "Chat",
        "temperature_key": 0.7,
        "max_tokens_key": 1000,
        "selected_task_index": 0,
        "key_points_result": "",
        "summary_result": "",
        "scraping_done": False,
        # Plain None rather than the one-element tuple ``(None,)`` that a
        # stray trailing comma produced; the tuple is truthy and would
        # pass any "is a QA object set?" check.
        "qa": None,
        "refresh_triggered": False,
        "error_mes": "",
    }
    for name, value in reset_values.items():
        st.session_state[name] = value

    st.cache_data.clear()
    st.rerun()
if __name__ == "__main__":
    # A pending refresh request (flag raised by trigger_refresh) is handled
    # before the page is built; clear_state() ends by calling st.rerun().
    refresh_requested = st.session_state.get("refresh_triggered")
    if refresh_requested:
        clear_state()
    main()