From f8afc26e6b6bba736aff553a7592abe36fea9dbd Mon Sep 17 00:00:00 2001 From: Christian Steck Date: Fri, 23 Apr 2021 13:33:04 +0200 Subject: [PATCH] test_round_3 added to loading --- utils/augmenting.py | 35 ++++++++++++++++++++++++++++++++++- utils/loading.py | 12 +++++++++--- 2 files changed, 43 insertions(+), 4 deletions(-) diff --git a/utils/augmenting.py b/utils/augmenting.py index beefdd9..e2bf7a9 100644 --- a/utils/augmenting.py +++ b/utils/augmenting.py @@ -112,4 +112,37 @@ def get_augmented_val_X_y(X, y, label): - \ No newline at end of file +def get_augmented_val_id(): + '''get a dataset with augmented texts for the minority positive label + Arguments: X, y - pandas series containing the validation data that needs to be augmented + label - label that needs to be augmented + sampling_strategy - float representing the proportion of positive vs negative labels in the augmented dataframe (range [>0.0; <=1.ß]) + Return: augmented X, y''' + + label_range = ['label_sentimentnegative', 'label_inappropriate', 'label_discriminating', 'label_needsmoderation'] + file_cached = "./cache/df_r3.csv" + try: + df_r3 = pd.read_csv(file_cached) + + except: + df_r3 = loading.load_extended_posts(label=label) + df_r3 = feature_engineering.add_column_ann_round(df_r3) + df_r3 = feature_engineering.add_column_text(df_r3) + df_r3 = df_r3.query('ann_round==3').copy() + df_r3.to_csv(file_cached) + + df_r3 = feature_engineering.add_column_label_needsmoderation(df_r3) + art_list = list(df_r3.id_article.unique()) + + label_range = ['label_sentimentnegative', 'label_inappropriate', 'label_discriminating', 'label_needsmoderation'] + df_ann = pd.DataFrame(columns=df_r3.columns) + + id_list = [] + for label in label_range: + for i in art_list: + df_ann = pd.concat((df_ann, + df_r3.query(f'id_article=={i} and {label}==1').sample(1, + random_state=42))) + id_list.extend(list(df_ann.id_post)) + + return list(set(id_list)) \ No newline at end of file diff --git a/utils/loading.py b/utils/loading.py index 7f124ba..1a59234 100644 --- a/utils/loading.py +++ b/utils/loading.py @@ -1,7 +1,7 @@ # Imports import pandas as pd import sqlite3 - +from utils import augmenting, feature_engineering def get_database_connection(path='./data/corpus.sqlite3'): con = sqlite3.connect(path) @@ -60,7 +60,7 @@ def load_extended_posts(split:str=None, label:str=None): ''' Load post table extended by annotations and staff. Args: - - split: [None, 'test', 'train', 'val']. Reduce dataframe to test/train/validation split only. + - split: [None, 'test', 'test_r3', 'train', 'val']. Reduce dataframe to test/train/validation split only. Returns: - Dataframe ''' @@ -70,9 +70,14 @@ def load_extended_posts(split:str=None, label:str=None): df_staff = load_staff() df_articles = load_articles() - if split: + if split in ['train', 'test', 'val']: filter_frame = pd.read_csv(f'./data/ann2_{split}.csv', header=None, index_col=0, names=['id_post']) df_posts = df_posts.merge(filter_frame, how='inner', on='id_post') + elif split == 'test_r3': + id_list = augmenting.get_augmented_val_id() + df_posts= feature_engineering.add_column_ann_round(df_posts) + df_posts= df_posts.query(f'id_post not in {id_list} & ann_round==3') + print(df_posts.shape) # prepare annotations annotations = df_annotations.pivot(index="id_post", columns="category", values="value") @@ -95,3 +100,4 @@ def load_extended_posts(split:str=None, label:str=None): if label: df = df.dropna(subset=[label]) return df +