# Copyright (c) 2024, RTE (https://www.rte-france.com)
# See AUTHORS.txt
# SPDX-License-Identifier: MPL-2.0
# This file is part of BERTrend.
from typing import Dict

import pandas as pd
import streamlit as st
from bertopic import BERTopic
from pandas import Timestamp
from plotly import graph_objects as go

from bertrend import OUTPUT_PATH, SIGNAL_EVOLUTION_DATA_DIR
from bertrend.demos.demos_utils.icons import WARNING_ICON, SUCCESS_ICON, INFO_ICON
from bertrend.demos.weak_signals.messages import HTML_GENERATION_FAILED_WARNING
from bertrend.demos.demos_utils.state_utils import SessionStateManager
from bertrend.config.parameters import (
    MAX_WINDOW_SIZE,
    DEFAULT_WINDOW_SIZE,
    INDIVIDUAL_MODEL_TOPIC_COUNTS_FILE,
    CUMULATIVE_MERGED_TOPIC_COUNTS_FILE,
)
from bertrend.trend_analysis.visualizations import (
    create_sankey_diagram_plotly,
    plot_newly_emerged_topics,
    plot_topics_for_model,
    compute_popularity_values_and_thresholds,
    create_topic_size_evolution_figure,
    plot_topic_size_evolution,
)
from bertrend.trend_analysis.weak_signals import (
    classify_signals,
    save_signal_evolution_data,
    analyze_signal,
)

PLOTLY_BUTTON_SAVE_CONFIG = {
    "toImageButtonOptions": {
        "format": "svg",
        # 'height': 500,
        # 'width': 1500,
        "scale": 1,
    }
}
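# Usage: pass this config to Streamlit's Plotly calls so the toolbar's camera
# button exports figures as SVG at native scale, e.g.:
#
#     st.plotly_chart(fig, config=PLOTLY_BUTTON_SAVE_CONFIG, use_container_width=True)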


def display_sankey_diagram(all_merge_histories_df: pd.DataFrame) -> None:
    """
    Display a Sankey diagram visualizing the topic merging process.

    Args:
        all_merge_histories_df (pd.DataFrame): The DataFrame containing all merge histories.
    """
    with st.expander("Topic Merging Process", expanded=False):
        # Search box and slider to filter the displayed topic pairs
        search_term = st.text_input("Search topics by keyword:")
        max_pairs = st.slider(
            "Max number of topic pairs to display",
            min_value=1,
            max_value=1000,
            value=20,
        )

        # Create the Sankey diagram
        sankey_diagram = create_sankey_diagram_plotly(
            all_merge_histories_df, search_term, max_pairs
        )

        # Display the diagram inside the expander
        st.plotly_chart(
            sankey_diagram, config=PLOTLY_BUTTON_SAVE_CONFIG, use_container_width=True
        )


def display_signal_categories_df(
    noise_topics_df: pd.DataFrame,
    weak_signal_topics_df: pd.DataFrame,
    strong_signal_topics_df: pd.DataFrame,
    window_end: Timestamp,
):
    """Display the dataframes associated with each signal category: noise, weak signal, strong signal."""
    columns = [
        "Topic",
        "Sources",
        "Source_Diversity",
        "Representation",
        "Latest_Popularity",
        "Docs_Count",
        "Paragraphs_Count",
        "Latest_Timestamp",
        "Documents",
    ]

    st.subheader(":grey[Noise]")
    if not noise_topics_df.empty:
        st.dataframe(
            noise_topics_df.astype(str)[columns].sort_values(
                by=["Topic", "Latest_Popularity"], ascending=[False, False]
            )
        )
    else:
        st.info(
            f"No noisy signals were detected at timestamp {window_end}.", icon=INFO_ICON
        )

    st.subheader(":orange[Weak Signals]")
    if not weak_signal_topics_df.empty:
        st.dataframe(
            weak_signal_topics_df.astype(str)[columns].sort_values(
                by=["Latest_Popularity"], ascending=True
            )
        )
    else:
        st.info(
            f"No weak signals were detected at timestamp {window_end}.", icon=INFO_ICON
        )

    st.subheader(":green[Strong Signals]")
    if not strong_signal_topics_df.empty:
        st.dataframe(
            strong_signal_topics_df.astype(str)[columns].sort_values(
                by=["Topic", "Latest_Popularity"], ascending=[False, False]
            )
        )
    else:
        st.info(
            f"No strong signals were detected at timestamp {window_end}.",
            icon=INFO_ICON,
        )


def display_popularity_evolution():
    """Display the popularity evolution diagram."""
    window_size = st.number_input(
        "Retrospective Period (days)",
        min_value=1,
        max_value=MAX_WINDOW_SIZE,
        value=DEFAULT_WINDOW_SIZE,
        key="window_size",
    )

    bertrend = SessionStateManager.get("bertrend")
    all_merge_histories_df = bertrend.all_merge_histories_df
    min_datetime = all_merge_histories_df["Timestamp"].min().to_pydatetime()
    max_datetime = all_merge_histories_df["Timestamp"].max().to_pydatetime()

    # Get granularity
    granularity = st.session_state["granularity"]

    # Slider to select the date
    current_date = st.slider(
        "Current date",
        min_value=min_datetime,
        max_value=max_datetime,
        step=pd.Timedelta(days=granularity),
        format="YYYY-MM-DD",
        help="""The earliest selectable date corresponds to the earliest timestamp when topics were merged
        (with the smallest possible value being the earliest timestamp in the provided data).
        The latest selectable date corresponds to the most recent topic merges, which is at most equal
        to the latest timestamp in the data minus the provided granularity.""",
        key="current_date",
    )

    # Compute threshold values
    window_start, window_end, all_popularity_values, q1, q3 = (
        compute_popularity_values_and_thresholds(
            bertrend.topic_sizes, window_size, granularity, current_date
        )
    )

    # Classify signals
    noise_topics_df, weak_signal_topics_df, strong_signal_topics_df = classify_signals(
        bertrend.topic_sizes, window_start, window_end, q1, q3
    )

    # Display threshold values for noise and strong signals
    col1, col2 = st.columns(2)
    with col1:
        st.write(f"### Noise Threshold: {q1:.3f}")
    with col2:
        st.write(f"### Strong Signal Threshold: {q3:.3f}")

    # Plot popularity evolution with thresholds
    fig = plot_topic_size_evolution(
        create_topic_size_evolution_figure(bertrend.topic_sizes),
        current_date,
        window_start,
        window_end,
        all_popularity_values,
        q1,
        q3,
    )
    st.plotly_chart(fig, config=PLOTLY_BUTTON_SAVE_CONFIG, use_container_width=True)

    # Display DataFrames for each category: noise, weak signals, strong signals
    display_signal_categories_df(
        noise_topics_df, weak_signal_topics_df, strong_signal_topics_df, window_end
    )
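

# Note on the thresholds used above (an assumption based on the q1/q3 naming and
# the labels displayed, not a guarantee of the library's exact rule):
# compute_popularity_values_and_thresholds appears to derive two quantiles of the
# popularity values observed inside the retrospective window, and classify_signals
# then buckets each topic roughly along these lines:
#
#     popularity < q1          -> noise
#     q1 <= popularity <= q3   -> weak signal
#     popularity > q3          -> strong signal
#
# See bertrend.trend_analysis.weak_signals.classify_signals for the exact logic.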


def save_signal_evolution():
    """Save signal evolution data for later investigation in a separate notebook."""
    bertrend = SessionStateManager.get("bertrend")
    granularity = SessionStateManager.get("granularity")
    all_merge_histories_df = bertrend.all_merge_histories_df
    min_datetime = all_merge_histories_df["Timestamp"].min().to_pydatetime()
    max_datetime = all_merge_histories_df["Timestamp"].max().to_pydatetime()

    # Date range over which the signal evolution data will be saved
    start_date, end_date = st.select_slider(
        "Select date range for saving signal evolution data:",
        options=pd.date_range(
            start=min_datetime,
            end=max_datetime,
            freq=pd.Timedelta(days=granularity),
        ),
        value=(min_datetime, max_datetime),
        format_func=lambda x: x.strftime("%Y-%m-%d"),
    )

    if st.button("Save Signal Evolution Data"):
        try:
            save_path = save_signal_evolution_data(
                all_merge_histories_df=all_merge_histories_df,
                topic_sizes=dict(bertrend.topic_sizes),
                topic_last_popularity=bertrend.topic_last_popularity,
                topic_last_update=bertrend.topic_last_update,
                window_size=SessionStateManager.get("window_size"),
                granularity=granularity,
                start_timestamp=pd.Timestamp(start_date),
                end_timestamp=pd.Timestamp(end_date),
            )
            st.success(
                f"Signal evolution data saved successfully at {save_path}",
                icon=SUCCESS_ICON,
            )
        except Exception as e:
            st.error(f"Error encountered while saving signal evolution data: {e}")


def display_newly_emerged_topics(all_new_topics_df: pd.DataFrame) -> None:
    """
    Display the newly emerged topics over time (dataframe and figure).

    Args:
        all_new_topics_df (pd.DataFrame): The DataFrame containing information about newly emerged topics.
    """
    fig_new_topics = plot_newly_emerged_topics(all_new_topics_df)
    with st.expander("Newly Emerged Topics", expanded=False):
        st.dataframe(
            all_new_topics_df[
                [
                    "Topic",
                    "Count",
                    "Document_Count",
                    "Representation",
                    "Documents",
                    "Timestamp",
                ]
            ].sort_values(by=["Timestamp", "Document_Count"], ascending=[True, False])
        )
        st.plotly_chart(
            fig_new_topics, config=PLOTLY_BUTTON_SAVE_CONFIG, use_container_width=True
        )


def display_topics_per_timestamp(topic_models: Dict[pd.Timestamp, BERTopic]) -> None:
    """
    Plot the topics discussed per source for each timestamp.

    Args:
        topic_models (Dict[pd.Timestamp, BERTopic]): A dictionary of BERTopic models, where the key is the timestamp
            and the value is the corresponding model.
    """
    with st.expander("Explore topic models"):
        model_periods = sorted(topic_models.keys())
        selected_model_period = st.select_slider(
            "Select Model", options=model_periods, key="model_slider"
        )
        selected_model = topic_models[selected_model_period]
        fig = plot_topics_for_model(selected_model)
        st.plotly_chart(fig, config=PLOTLY_BUTTON_SAVE_CONFIG, use_container_width=True)
        st.dataframe(
            selected_model.doc_info_df[
                ["Paragraph", "document_id", "Topic", "Representation", "source"]
            ],
            use_container_width=True,
        )
        st.dataframe(selected_model.topic_info_df, use_container_width=True)


def display_signal_analysis(
    topic_number: int, output_file_name: str = "signal_llm.html"
):
    """Display an LLM-based analysis of a specific topic."""
    language = SessionStateManager.get("language")
    bertrend = SessionStateManager.get("bertrend")
    granularity = SessionStateManager.get("granularity")
    all_merge_histories_df = bertrend.all_merge_histories_df

    st.subheader("Signal Interpretation")
    with st.spinner("Analyzing signal..."):
        summary, analysis, formatted_html = analyze_signal(
            topic_number,
            SessionStateManager.get("current_date"),
            all_merge_histories_df,
            granularity,
            language,
        )

        # Check if the HTML file was created successfully
        output_file_path = OUTPUT_PATH / output_file_name
        if output_file_path.exists():
            # Read and display the HTML content
            with open(output_file_path, "r", encoding="utf-8") as file:
                html_content = file.read()
            st.html(html_content)
        else:
            # Fallback to displaying markdown if HTML generation failed
            st.warning(HTML_GENERATION_FAILED_WARNING, icon=WARNING_ICON)
            col1, col2 = st.columns(spec=[0.5, 0.5], gap="medium")
            with col1:
                st.markdown(summary)
            with col2:
                st.markdown(analysis)


def retrieve_topic_counts(topic_models: Dict[pd.Timestamp, BERTopic]) -> None:
    """Save the number of topics per individual model and per cumulative merged model as JSON files."""
    # Number of topics per individual topic model (topic IDs run 0..N-1, so max + 1 is the count)
    individual_model_topic_counts = [
        (timestamp, model.topic_info_df["Topic"].max() + 1)
        for timestamp, model in topic_models.items()
    ]
    df_individual_models = pd.DataFrame(
        individual_model_topic_counts,
        columns=["timestamp", "num_topics"],
    )

    # Number of topics per cumulative merged model
    cumulative_merged_topic_counts = SessionStateManager.get(
        "merge_df_size_over_time", []
    )
    df_cumulative_merged = pd.DataFrame(
        cumulative_merged_topic_counts,
        columns=["timestamp", "num_topics"],
    )

    # Convert to JSON
    json_individual_models = df_individual_models.to_json(
        orient="records", date_format="iso", indent=4
    )
    json_cumulative_merged = df_cumulative_merged.to_json(
        orient="records", date_format="iso", indent=4
    )

    # Save individual model topic counts
    json_file_path = (
        SIGNAL_EVOLUTION_DATA_DIR
        / f"retrospective_{SessionStateManager.get('window_size')}_days"
    )
    json_file_path.mkdir(parents=True, exist_ok=True)
    (json_file_path / INDIVIDUAL_MODEL_TOPIC_COUNTS_FILE).write_text(
        json_individual_models
    )

    # Save cumulative merged model topic counts
    (json_file_path / CUMULATIVE_MERGED_TOPIC_COUNTS_FILE).write_text(
        json_cumulative_merged
    )

    st.success(
        f"Topic counts for individual and cumulative merged models saved to {json_file_path}",
        icon=SUCCESS_ICON,
    )
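

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): how these helpers
# might be combined in a Streamlit demo page. It assumes a fitted `bertrend`
# object and a `granularity` value have already been stored in the session
# state by the main app; `all_new_topics_df` and `topic_models` are
# hypothetical names for data the caller is expected to provide.
#
#     bertrend = SessionStateManager.get("bertrend")
#     display_sankey_diagram(bertrend.all_merge_histories_df)
#     display_popularity_evolution()
#     save_signal_evolution()
#     display_newly_emerged_topics(all_new_topics_df)  # hypothetical source
#     display_topics_per_timestamp(topic_models)       # hypothetical source
#     display_signal_analysis(topic_number=42)         # hypothetical topic ID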