Skip to content

Commit

Permalink
Issue #131: Add extended data splits from annotation round 3.
Browse files: browse the repository at this point in the history
  • Loading branch information
dominikmn committed Apr 25, 2021
1 parent a9fd69d commit 81c452d
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 7 deletions.
7 changes: 5 additions & 2 deletions utils/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class Posts:
"""
AVAILABLE_LABELS = ['label_argumentsused', 'label_discriminating', 'label_inappropriate',
'label_offtopic', 'label_personalstories', 'label_possiblyfeedback',
'label_sentimentnegative', 'label_sentimentpositive', 'label_needsmoderation', 'label_negative']
'label_sentimentnegative', 'label_sentimentpositive', 'label_needsmoderation']

def __init__(self):
df = loading.load_extended_posts()
Expand All @@ -70,9 +70,12 @@ def get_X_y(self, split:str=None, label:str=None, balance_method:str=None, sampl
X: The feature (text column) of the posts
y: The target annotations
"""
if split:
if split in ['train', 'test', 'val']:
filter_frame = pd.read_csv(f'./data/ann2_{split}.csv', header=None, index_col=0, names=['id_post'])
df = self.df.merge(filter_frame, how='inner', on='id_post')
elif split in ['ann3_all', 'ann3_israelpalestine', 'ann3_refugees', 'ann3_women']:
filter_frame = pd.read_csv(f'./data/{split}.csv', header=None, index_col=0, names=['id_post'])
df = self.df.merge(filter_frame, how='inner', on='id_post')
else:
df = self.df.copy()

Expand Down
26 changes: 21 additions & 5 deletions utils/train_test_val_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,30 @@ def create_splits():

df_train = pd.concat([ann2_train, df_ann3_feedback_stories], axis=0)

print(f"Number of posts in train-set: {df_train.shape[0]}")
print(f"Number of posts in val-set: {ann2_val.shape[0]}")
print(f"Number of posts in test-set: {ann2_test.shape[0]}")
label_subset = ['label_sentimentnegative', 'label_inappropriate', 'label_discriminating']
articles_israelpalestine = {9767, 10820, 11105}
articles_refugees = {1860, 11004, 10425, 10707}
articles_women = {1172, 1704, 1831}
ann3_all = df_posts.query("ann_round == 3").dropna(subset=label_subset)
ann3_israelpalestine = df_posts.query("ann_round == 3 and id_article in @articles_israelpalestine").dropna(subset=label_subset)
ann3_refugees = df_posts.query("ann_round == 3 and id_article in @articles_refugees").dropna(subset=label_subset)
ann3_women = df_posts.query("ann_round == 3 and id_article in @articles_women").dropna(subset=label_subset)

print(f"Number of posts in train set: {df_train.shape[0]}")
print(f"Number of posts in val set: {ann2_val.shape[0]}")
print(f"Number of posts in test set: {ann2_test.shape[0]}")
print(f"Number of posts in ann3_all set: {ann3_all.shape[0]}")
print(f"Number of posts in ann3_israelpalestine set: {ann3_israelpalestine.shape[0]}")
print(f"Number of posts in ann3_refugees set: {ann3_refugees.shape[0]}")
print(f"Number of posts in ann3_women set: {ann3_women.shape[0]}")
df_train.id_post.to_csv('./data/ann2_train.csv', header=False)
ann2_test.id_post.to_csv('./data/ann2_test.csv', header=False)
ann2_val.id_post.to_csv('./data/ann2_val.csv', header=False)
ann2_test.id_post.to_csv('./data/ann2_test.csv', header=False)
ann3_all.id_post.to_csv('./data/ann3_all.csv', header=False)
ann3_israelpalestine.id_post.to_csv('./data/ann3_israelpalestine.csv', header=False)
ann3_refugees.id_post.to_csv('./data/ann3_refugees.csv', header=False)
ann3_women.id_post.to_csv('./data/ann3_women.csv', header=False)
print('Splits created.')


# Run create_splits() only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
create_splits()

0 comments on commit 81c452d

Please sign in to comment.