# component_5_topic_normalization.py
import json
import logging
import os
import re
from collections import Counter
from pathlib import Path

import nltk
import openai
import pandas as pd
import torch
from openai import OpenAI
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

from pipeline.pipeline_component import PipelineComponent
from utils.clustering import hac_clustering_retain_index
from utils.constants import CULTURAL_TOPICS
from utils.prompt_utils import TOPIC_SYSTEM_MESSAGE, TOPIC_USER_MESSAGE_TEMPLATE

# Both the legacy module-level key and the OpenAI client pick up
# OPENAI_API_KEY from the environment.
openai.api_key = os.getenv("OPENAI_API_KEY")
client = OpenAI()

nltk.download("wordnet")

logger = logging.getLogger(__name__)
class TopicNormalizer(PipelineComponent):
    """Normalize free-form topic and cultural-group strings into canonical values."""

    description = "normalizing the topics and cultural groups"
    config_layer = "5_topic_normalizer"

    def __init__(self, config: dict):
        super().__init__(config)
        # Get the component-local slice of the pipeline config.
        self._local_config = config[self.config_layer]
        self._override_config()
        self._condition = f"group={self._local_config['cultural_group_threshold']}"
        self._create_new_output_dir()
        if "output_file" in self._local_config:
            self.check_if_output_exists(self._local_config["output_file"])
        # Filled with silhouette scores during clustering and written out as JSON.
        self.scores = {}

        # Set up models: SBERT runs on GPU when one is available.
        self.device = (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )
        logger.info(f"using {self.device}")
        self.sbert = SentenceTransformer(self._local_config["sbert"]["model"]).to(
            self.device
        )
    def _override_config(self):
        # Values set at the top level of the config override the local layer.
        for key in ["cultural_group_threshold"]:
            if key in self._config and self._config[key] is not None:
                self._local_config[key] = self._config[key]

    def _create_new_output_dir(self):
        # Insert the condition string (e.g. "group=0.8") as the final directory
        # of both output paths so runs with different thresholds do not collide.
        new_output_dir = "/".join(
            self._local_config["output_file"].split("/")[:-1] + [self._condition]
        )
        Path(new_output_dir).mkdir(parents=True, exist_ok=True)
        for key in ["output_file", "output_score_file"]:
            parts = self._local_config[key].split("/")
            self._local_config[key] = "/".join(
                parts[:-1] + [self._condition] + parts[-1:]
            )
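
    # Worked example of the path rewrite above (hypothetical paths): with
    # cultural_group_threshold=0.8, an output_file of "out/normalized.csv"
    # becomes "out/group=0.8/normalized.csv", and output_score_file is
    # rewritten the same way.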
    def read_input(self):
        df = pd.read_csv(self._local_config["input_file"])
        if self._config["dry_run"] is not None:
            df = df.head(self._config["dry_run"])
        return df

    def run(self):
        df = self.read_input()
        logger.info(f"total number of samples: {len(df)}")
        group_clusters = self.cultural_group_normalization(df)
        df = self.select_representative_summarization(
            df, "cultural group", group_clusters
        )
        df = self.topic_normalization(df)
        self.save_output(df)
        logger.info("Normalization Done!")
    def save_output(self, df):
        logger.info(f"save to {self._local_config['output_file']}")
        df.to_csv(self._local_config["output_file"], index=False)
        with open(self._local_config["output_score_file"], "w") as fh:
            json.dump(self.scores, fh)
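
    # The score file is a small JSON object, e.g. (value hypothetical):
    #   {"cultural_group_silhouette_score": 0.43}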
    def cultural_group_normalization(self, df):
        # Embed every cultural-group string with SBERT, then merge near-duplicates
        # with hierarchical agglomerative clustering (HAC) at the configured
        # distance threshold.
        sents = [str(value) for value in df["cultural group"]]
        logger.info(f"this many culture groups: {len(sents)}")
        embeddings = self.sbert.encode(sents, show_progress_bar=True)
        raw_clusters, score = hac_clustering_retain_index(
            sents, embeddings, self._local_config["cultural_group_threshold"]
        )
        logger.info(f"there are a total of {len(raw_clusters)} cultural groups")
        logger.info(
            "the size of the largest cultural group is: "
            f"{max(len(cluster) for cluster in raw_clusters)}"
        )
        logger.info(
            f"the silhouette_score for the cultural group clustering is {score}"
        )
        self.scores["cultural_group_silhouette_score"] = score
        return raw_clusters
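
    # Assumed shape of raw_clusters, inferred from how it is consumed in
    # select_representative_summarization below (strings hypothetical):
    #
    #   [
    #       [(0, "japanese"), (57, "japanese people")],  # one cluster
    #       [(3, "mexican"), (12, "mexicans")],
    #   ]
    #
    # i.e. each cluster is a list of (row position, original string) pairs.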
    def topic_normalization(self, df):
        # Map every free-form topic onto one of the predefined CULTURAL_TOPICS
        # by prompting the OpenAI chat model.
        df["representative_topic"] = ""
        model = self._local_config["openai"]["model"]
        temperature = self._local_config["openai"]["temperature"]
        max_tokens = self._local_config["openai"]["max_tokens"]
        top_p = self._local_config["openai"]["top_p"]
        seed = self._local_config["openai"]["seed"]
        for idx, df_line in tqdm(df.iterrows(), total=len(df)):
            # Retry up to 10 times: the model may answer with a topic outside
            # the predefined list, or the API call may fail transiently. Rows
            # that never succeed keep an empty representative_topic.
            for _ in range(10):
                try:
                    system_message = TOPIC_SYSTEM_MESSAGE.format(CULTURAL_TOPICS)
                    user_message = TOPIC_USER_MESSAGE_TEMPLATE.format(df_line["topic"])
                    messages = [
                        {"role": "system", "content": system_message},
                        {"role": "user", "content": user_message},
                    ]
                    response = client.chat.completions.create(
                        model=model,
                        messages=messages,
                        seed=seed,
                        temperature=temperature,
                        max_tokens=max_tokens,
                        top_p=top_p,
                    )
                    # Strip quotes and a trailing period so the answer can be
                    # matched exactly against the predefined topic list.
                    summarized_topic = response.choices[0].message.content.strip()
                    summarized_topic = re.sub(r'[\'"]', "", summarized_topic)
                    summarized_topic = re.sub(r"\.$", "", summarized_topic)
                    if summarized_topic not in CULTURAL_TOPICS:
                        logger.warning(
                            f"row {idx}: the summarized topic {summarized_topic} "
                            "does not fit into any of the predefined themes, retrying..."
                        )
                        continue
                    df.at[idx, "representative_topic"] = summarized_topic
                    break
                except Exception as e:
                    logger.warning(f"encountered error at row {idx}: {e}; retrying...")
                    continue
        return df
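
    # The two regex passes above, in isolation (doctest-style sketch with a
    # hypothetical model answer):
    #
    #   >>> s = re.sub(r'[\'"]', "", '"Food."')  # drop quotes -> 'Food.'
    #   >>> re.sub(r"\.$", "", s)                # drop trailing period
    #   'Food'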
    @staticmethod
    def select_representative_summarization(
        df, cluster_target, raw_clusters, strategy="majority"
    ):
        # For every cluster, pick a single representative value and write it
        # back to each member row along with its vote count and cluster id.
        final_values = [None] * df.shape[0]
        final_values_count = [None] * df.shape[0]
        final_cluster_id = [None] * df.shape[0]
        for i, cluster in enumerate(tqdm(raw_clusters)):
            actual_values = []
            for idx, _ in cluster:
                row = df.iloc[idx]
                if cluster_target == "topic":
                    # Topics are already normalized per row; vote over those values.
                    actual_values.append(row["representative_topic"])
                else:
                    actual_values.append(row[cluster_target])
            if strategy == "majority":
                # Count the occurrences of each element and take the majority vote.
                vote_counts = Counter(actual_values)
                majority_vote, majority_count = vote_counts.most_common(1)[0]
                rep_topic = majority_vote
            else:
                raise NotImplementedError
            for idx, _ in cluster:
                final_values[idx] = rep_topic
                if strategy == "majority":
                    final_values_count[idx] = majority_count
                final_cluster_id[idx] = i
        df[f"representative_{cluster_target}"] = final_values
        if strategy == "majority":
            df[f"representative_{cluster_target}_count"] = final_values_count
        df[f"representative_{cluster_target}_cluster_id"] = final_cluster_id
        return df
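

if __name__ == "__main__":
    # Minimal standalone sketch of driving this component outside the pipeline.
    # The config keys mirror the lookups above, but the concrete paths, model
    # names, and values are assumptions, not the project's real configuration,
    # and PipelineComponent is assumed to need no extra keys.
    logging.basicConfig(level=logging.INFO)
    example_config = {
        "dry_run": 5,  # process only the first 5 rows
        "cultural_group_threshold": None,  # no top-level override
        "5_topic_normalizer": {
            "input_file": "data/4_output.csv",  # hypothetical
            "output_file": "data/5_normalized.csv",  # hypothetical
            "output_score_file": "data/5_scores.json",  # hypothetical
            "cultural_group_threshold": 0.8,  # hypothetical HAC threshold
            "sbert": {"model": "sentence-transformers/all-MiniLM-L6-v2"},
            "openai": {
                "model": "gpt-3.5-turbo",
                "temperature": 0.0,
                "max_tokens": 16,
                "top_p": 1.0,
                "seed": 42,
            },
        },
    }
    TopicNormalizer(example_config).run()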