-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path few_shot.py
42 lines (36 loc) · 1.28 KB
/
few_shot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import pandas as pd
import json
import sys
class FewShotPosts:
    """Load posts from a JSON file and select few-shot examples from them.

    Each post record is read for at least the ``line_count``, ``tags`` and
    ``language`` fields; a derived ``length`` bucket is added on load.
    """

    def __init__(self, file_path="data/processed_posts.json"):
        """Load and index the posts stored at *file_path*."""
        self.df = None  # pandas DataFrame of posts, filled in by load_posts
        self.unique_tags = None  # distinct tags across all posts (list)
        self.load_posts(file_path)

    @staticmethod
    def _as_tag_list(tags):
        """Normalise one post's ``tags`` cell to a list (NaN/None -> [])."""
        if isinstance(tags, (list, tuple, set)):
            return list(tags)
        if isinstance(tags, str):
            return [tags]
        # json_normalize fills a missing "tags" key with NaN; treat as no tags.
        return []

    def load_posts(self, file_path):
        """Read posts from *file_path*, bucket them by length and collect tags.

        Args:
            file_path: Path to a UTF-8 encoded JSON file holding a list of
                post objects.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
            KeyError: If a non-empty file has no ``line_count`` field at all.
        """
        with open(file_path, encoding="utf-8") as f:
            posts = json.load(f)
        self.df = pd.json_normalize(posts)
        if self.df.empty:
            # json_normalize([]) returns a frame with no columns; give it the
            # columns the other methods read so an empty file loads cleanly.
            self.df = pd.DataFrame(columns=["line_count", "tags", "language"])
        self.df["length"] = self.df["line_count"].apply(self.categorize_length)
        # Flatten every post's tag list in one pass, keeping first-seen order.
        # (Series.sum() over lists concatenated pairwise -- quadratic -- and
        # failed on an empty frame or on posts with missing tags.)
        self.unique_tags = list(
            dict.fromkeys(
                tag
                for tags in self.df["tags"]
                for tag in self._as_tag_list(tags)
            )
        )

    def categorize_length(self, line_count):
        """Bucket a post by line count: <5 "Short", 5-10 "Medium", else "Long"."""
        if line_count < 5:
            return "Short"
        elif line_count <= 10:
            return "Medium"
        else:
            # Also reached for NaN, since every comparison with NaN is False.
            return "Long"

    def get_tags(self):
        """Return the list of distinct tags found across all loaded posts."""
        return self.unique_tags

    def get_filtered_posts(self, length, language, tag):
        """Return posts whose bucket, language and tags all match.

        Args:
            length: Length bucket to match ("Short", "Medium" or "Long").
            language: Exact value to match in the ``language`` column.
            tag: A tag that must appear in the post's ``tags``.

        Returns:
            A list of dicts, one per matching post, containing every column.
        """
        if self.df.empty:
            return []
        has_tag = self.df["tags"].map(
            lambda tags: tag in self._as_tag_list(tags)
        ).astype(bool)
        mask = (
            has_tag
            & (self.df["language"] == language)
            & (self.df["length"] == length)
        )
        return self.df.loc[mask].to_dict(orient="records")
if __name__ == "__main__":
    # Switch stdout to UTF-8 so non-ASCII post text can be printed even when
    # the console's default encoding is narrower (presumably e.g. cp1252).
    sys.stdout.reconfigure(encoding='utf-8')
    few_shot = FewShotPosts()
    matches = few_shot.get_filtered_posts(
        'Short',
        'English',
        'Scams',
    )
    print(matches)