-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
#11 Add files to label ground-truth SemDoms and evaluate SDI
- Loading branch information
Jonathan Janetzki
committed
May 15, 2023
1 parent
793dfea
commit 09777c4
Showing
14 changed files
with
770 additions
and
20 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
import pandas as pd | ||
|
||
# Numbers of the test-set CSV files that are aggregated below.
test_file_num_list = list(range(1, 13))

# Read the question/answer columns of every test set and stack them into one frame.
dfs = [
    pd.read_csv(f"data/2_sd_labeling/test sets/test_set_{i}.csv", usecols=["direct_question", "answer"])
    for i in test_file_num_list
]
df_concat = pd.concat(dfs, ignore_index=True)

# Report how often each answer value occurs.
answer_counts = df_concat['answer'].value_counts()
print(answer_counts)

# precision = 0.39 (464/1200)
# recall = 1.0
# f1 = 0.56
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import pandas as pd | ||
from tqdm import tqdm | ||
|
||
# load matched_questions_gospels_raw.csv | ||
# Copy the 'answer' and 'gpt3_answer' values from matched_questions.xlsx onto the
# matching rows of matched_questions_gospels_raw.csv and write the result to a new CSV.

# load matched_questions_gospels_raw.csv
raw_df = pd.read_csv('data/2_sd_labeling/matched_questions_gospels_raw.csv')

# keep rows with qids not starting with 9; copy so that the column added below
# is written to an independent frame instead of a slice of the loaded one
raw_df = raw_df[~raw_df['qid'].str.startswith('9')].copy()

# load matched_questions.xlsx (only the first 156501 rows are read)
mq_df = pd.read_excel('data/2_sd_labeling/matched_questions.xlsx', nrows=156501)

# add column question_without_reference to raw_df and mq_df by removing
# references of the form "(Abc 1:2) " from the question text
REFERENCE_PATTERN = r'\(\w{3} \d+:\d+\) '
raw_df['question_without_reference'] = raw_df['direct_question'].replace(REFERENCE_PATTERN, '', regex=True)
mq_df['question_without_reference'] = mq_df['direct_question'].replace(REFERENCE_PATTERN, '', regex=True)

# use question_without_reference as index
raw_df = raw_df.set_index('question_without_reference')
mq_df = mq_df.set_index('question_without_reference')

# for each question_without_reference in mq_df, find all corresponding rows in raw_df
answer_columns = ['answer', 'gpt3_answer']
for index, row in tqdm(mq_df.iterrows(), desc='Filling up answers...', total=mq_df.shape[0]):
    # rows of raw_df with the same reference-free question text and the same qid;
    # neither the index nor 'qid' changes in this loop, so the mask is built once
    mask = (raw_df.index == index) & (raw_df['qid'] == row['qid'])
    raw_df.loc[mask, answer_columns] = pd.DataFrame([row[answer_columns]])

    # sanity check: report questions whose match count differs from 'num_verses'
    num_matches = int(mask.sum())
    if row['num_verses'] != num_matches:
        print(row['num_verses'], num_matches, row)

# save raw_df to matched_questions_gospels_raw_out.csv (the index column is not written)
raw_df.to_csv('data/2_sd_labeling/matched_questions_gospels_raw_out.csv', index=False)

print('Done.')
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import random | ||
from flask import Flask, jsonify, render_template, request | ||
from src.dictionary_creator.tfidf_dictionary_creator import TfidfDictionaryCreator | ||
from src.semantic_domain_identifier import SemanticDomainIdentifier | ||
|
||
app = Flask(__name__)

# Build the dictionary creator for the 'bid-eng-web' and 'bid-deu' bibles and
# hand it to the semantic domain identifier.
dc = TfidfDictionaryCreator(
    ['bid-eng-web', 'bid-deu'],
    score_threshold=0.2,
    state_files_path='../../../data/0_state',
)
sdi = SemanticDomainIdentifier(dc)

# Per-verse word lists of both bibles, looked up by bible id.
verses_eng, verses_deu = (
    sdi.dc.wtxts_by_verse_by_bid[bid] for bid in ('bid-eng-web', 'bid-deu')
)
|
||
|
||
def generate_sentence():
    """Pick a random verse index and return it with that verse from both bibles.

    Returns:
        tuple: ``(idx, verses_eng[idx], verses_deu[idx])`` where ``idx`` is drawn
        uniformly from ``range(len(verses_eng))``.
    """
    verse_idx = random.randrange(len(verses_eng))
    eng_words = verses_eng[verse_idx]
    deu_words = verses_deu[verse_idx]
    return verse_idx, eng_words, deu_words
|
||
|
||
def generate_checkboxes():
    """Build the HTML sentence and the checkbox options for one random verse.

    Returns:
        tuple: ``(sentence, checkboxes)``. ``sentence`` is an HTML string with the
        verse index, the joined English words, the English words cut at their
        first ``'_'``, and the joined words of the second bible, separated by
        ``<br>``. ``checkboxes`` is a list of dicts with the keys ``'checked'``
        (``'checked'`` or ``''``) and ``'word'`` (the question text with ``'#'``
        replaced by the quoted, upper-cased word).
    """
    from collections import Counter

    verse_idx, words_eng, words_deu = generate_sentence()
    sentence = f'{verse_idx}<br>' + ' '.join(words_eng) + '<br>'
    sentence += ' '.join([wtxt.split('_')[0] for wtxt in words_eng]) + '<br>' + ' '.join(words_deu)

    # NOTE(review): presumably switches the identifier to its ground-truth
    # word-to-qid mapping before identifying domains - confirm against
    # SemanticDomainIdentifier.
    sdi.qid_by_wtxt = sdi.gt_qid_by_wtxt
    identified_qids = sdi.identify_semantic_domains([words_eng])
    # order by the first tuple element (token index), then by qid
    identified_qids = sorted(identified_qids, key=lambda t: (t[0], t[2]))

    # show each word as its quoted, upper-cased part before the first '_'
    identified_qids = [(token_idx, f'"{wtxt.split("_")[0].upper()}"', qid, sd_name, question, words)
                       for (token_idx, wtxt, qid, sd_name, question, words) in identified_qids]

    # count the options per token index once instead of rescanning the list for every entry
    options_per_token = Counter(t[0] for t in identified_qids)

    # a token index that occurs only once is pre-checked (because it is the only option)
    checkboxes = [{'checked': 'checked' if options_per_token[token_idx] == 1 else '',
                   'word': question.replace("#", wtxt)}
                  for (token_idx, wtxt, qid, sd_name, question, words) in identified_qids]

    return sentence, checkboxes
|
||
|
||
@app.route('/', methods=['GET', 'POST'])
def home():
    """Serve the labeling page; on POST, print the submitted labels first.

    A POST prints the submitted sentence and the list of checked 'word'
    values. Every request then renders 'index.html' with a newly generated
    sentence and its checkboxes.
    """
    is_submission = request.method == 'POST'
    if is_submission:
        chosen_words = request.form.getlist('word')
        labeled_sentence = request.form['sentence']
        print(labeled_sentence, chosen_words)

    next_sentence, next_checkboxes = generate_checkboxes()
    return render_template('index.html', sentence=next_sentence, checkboxes=next_checkboxes)
|
||
|
||
@app.route('/generate_checkboxes')
def fetch_checkboxes():
    """Return a newly generated sentence and its checkbox options as JSON."""
    payload = dict(zip(('sentence', 'checkboxes'), generate_checkboxes()))
    return jsonify(payload)
|
||
|
||
# Start the Flask server in debug mode when this file is executed directly.
if "__main__" == __name__:
    app.run(debug=True)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
// function generateCheckboxes() { | ||
// const checkboxesDiv = document.getElementById('checkboxes'); | ||
// let sentence = document.getElementById('sentence').textContent; | ||
// | ||
// fetch('/generate_checkboxes') | ||
// .then(res => res.json()) | ||
// .then(data => { | ||
// // Update the sentence text | ||
// sentence = data.sentence; | ||
// document.getElementById('sentence').textContent = sentence; | ||
// | ||
// // Update the checkboxes | ||
// checkboxesDiv.innerHTML = ''; | ||
// for (const checkbox of data.checkboxes) { | ||
// const label = document.createElement('label'); | ||
// const input = document.createElement('input'); | ||
// input.type = 'checkbox'; | ||
// input.name = 'word'; | ||
// input.value = checkbox.word; | ||
// label.appendChild(input); | ||
// label.appendChild(document.createTextNode(' ' + checkbox.word)); | ||
// checkboxesDiv.appendChild(label); | ||
// checkboxesDiv.appendChild(document.createElement('br')); | ||
// } | ||
// }) | ||
// .catch(error => console.error(error)); | ||
// } |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
<!DOCTYPE html>
{# Labeling page: shows one sentence and one checkbox per candidate question. #}
<html lang="en">
<head>
    <meta charset="utf-8">
    <title>Semantic Domain Labeler</title>
    <script src="{{ url_for('static', filename='script.js') }}"></script>
</head>
<body>
    <form method="POST">
        {# The sentence contains <br> markup built by the server, hence the safe filter. #}
        <h1>{{ sentence | safe }}</h1>
        {# Sent back with the form so the server knows which sentence was labeled. #}
        <input type="hidden" name="sentence" id="sentence-input" value="{{ sentence }}">
        <button type="submit">Next</button>
        <div id="checkboxes">
            {% for checkbox in checkboxes %}
            <label><input type="checkbox" name="word" value="{{ checkbox.word }}" {{ checkbox.checked }}>{{ checkbox.word }}</label><br>
            {% endfor %}
        </div>
    </form>
</body>
</html>
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
# Bulk answer semantic domain questions: when all labeled answers of a (qid, token) pair agree, reuse that answer for the pair's unanswered questions. | ||
|
||
import math | ||
import pandas as pd | ||
from collections import defaultdict | ||
from tqdm import tqdm | ||
|
||
# Load the matched questions (only the first 156501 rows and the four columns used below).
df = pd.read_excel("data/2_sd_labeling/matched_questions.xlsx", usecols=["direct_question", "answer", "qid", "tokens"],
                   nrows=156501)

# assert that no qid starts with '9'
assert not any(df['qid'].str.startswith('9'))

# filter out qids that start with '1'; copy so that the answers written below
# land in an independent frame instead of a slice of the loaded one
df = df[~df['qid'].str.startswith('1')].copy()

# group questions by (qid, token); each entry keeps the row label together with
# a snapshot of the row so the label can be used to write back into df later
qid_and_token_to_rows = defaultdict(list)
for idx, row in df.iterrows():
    qid_and_token_to_rows[(row['qid'], row['tokens'])].append((idx, row))

# filter (qid, token) pairs that have at least 2 answers (0, 1)
qid_and_token_to_rows = {k: rows for k, rows in qid_and_token_to_rows.items()
                         if sum(1 for (idx, row) in rows if row['answer'] in (0, 1)) >= 2}

# if all answers are the same, use this answer as the answer to all questions of this (qid, token) pair
additional_answers_count = 0
for rows in tqdm(qid_and_token_to_rows.values(),
                 desc='Bulk answering SD questions...',
                 total=len(qid_and_token_to_rows)):
    answers = [row['answer'] for (idx, row) in rows if row['answer'] in (0, 1)]
    if len(set(answers)) > 1:
        # contradicting labels: leave this pair untouched
        continue
    bulk_answer = answers[0]
    for (idx, row) in rows:
        if math.isnan(row['answer']):
            additional_answers_count += 1
            assert df.at[idx, 'direct_question'] == row['direct_question']
            # write through df.at: df.loc[idx] returns a copy of the row, so
            # assigning to that copy never changed df and the answers were lost
            df.at[idx, 'answer'] = bulk_answer
            print(f'{bulk_answer}: {row["direct_question"]}')
        else:
            assert (row['answer'] == bulk_answer)
print(f'Added {additional_answers_count} additional answers.')

# save to csv
df.to_csv("data/2_sd_labeling/matched_questions_with_bulk_answers.csv", index=False)
print('Done.')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import pandas as pd | ||
|
||
# Test sets test_file_num_start .. test_file_num_end - 1 are the newest batch of three files.
test_file_num_start = 10
test_file_num_end = test_file_num_start + 3

# # draw 100 random rows from df for each new test set and save them to a csv file
# df = pd.read_excel("data/2_sd_labeling/matched_questions.xlsx", nrows=156501)
#
# for i in range(test_file_num_start, test_file_num_end):
#     df_sample = df.sample(n=100, random_state=i)
#     df_sample.to_csv(f"data/2_sd_labeling/test sets/test_set_{i}.csv", index=False)

# Load every test set from 1 up to (excluding) test_file_num_end and stack them.
dfs = [
    pd.read_csv(f"data/2_sd_labeling/test sets/test_set_{i}.csv")
    for i in range(1, test_file_num_end)
]
df_concat = pd.concat(dfs, ignore_index=True)

# Flag every repeated occurrence of a question (the first occurrence is not flagged).
is_duplicate = df_concat['direct_question'].duplicated()

# number of duplicates
print(is_duplicate.sum())

# the duplicated rows themselves
print(df_concat[is_duplicate])
Oops, something went wrong.