Skip to content

Commit

Permalink
#11 Add files to label ground-truth SemDoms and evaluate SDI
Browse files Browse the repository at this point in the history
  • Loading branch information
Jonathan Janetzki committed May 15, 2023
1 parent 793dfea commit 09777c4
Show file tree
Hide file tree
Showing 14 changed files with 770 additions and 20 deletions.
Empty file added src/sd_labeling/__init__.py
Empty file.
16 changes: 16 additions & 0 deletions src/sd_labeling/eval_simple_sdi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import pandas as pd

# Numbers of the test-set files to aggregate (test_set_1.csv .. test_set_12.csv).
test_file_num_list = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]

# Load every labeled test set and stack them into one frame.
frames = [
    pd.read_csv(f"data/2_sd_labeling/test sets/test_set_{num}.csv",
                usecols=["direct_question", "answer"])
    for num in test_file_num_list
]
df_concat = pd.concat(frames, ignore_index=True)

# Report how often each answer label occurs across all test sets.
print(df_concat['answer'].value_counts())

# precision = 0.39 (464/1200)
# recall = 1.0
# f1 = 0.56
36 changes: 36 additions & 0 deletions src/sd_labeling/match_direct_questions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import pandas as pd
from tqdm import tqdm

# Copy manually collected answers from matched_questions.xlsx onto the
# per-verse raw question matches for the gospels, joining on the
# reference-free question text plus the qid.

# load matched_questions_gospels_raw.csv
raw_df = pd.read_csv('data/2_sd_labeling/matched_questions_gospels_raw.csv')

# keep rows with qids not starting with 9
raw_df = raw_df[~raw_df['qid'].str.startswith('9')]

# load matched_questions.xlsx (only the first 156501 rows carry answers)
mq_df = pd.read_excel('data/2_sd_labeling/matched_questions.xlsx', nrows=156501)

# Strip the leading '(Mat 1:1) '-style verse reference so identical questions
# asked for different verses compare equal.
raw_df['question_without_reference'] = raw_df['direct_question'].replace(r'\(\w{3} \d+:\d+\) ', '', regex=True)
mq_df['question_without_reference'] = mq_df['direct_question'].replace(r'\(\w{3} \d+:\d+\) ', '', regex=True)

# use question_without_reference as index for fast lookups
raw_df = raw_df.set_index('question_without_reference')
mq_df = mq_df.set_index('question_without_reference')

# for each question_without_reference in mq_df, find all corresponding rows in raw_df
for index, row in tqdm(mq_df.iterrows(), desc='Filling up answers...', total=mq_df.shape[0]):
    mask = (raw_df.index == index) & (raw_df['qid'] == row['qid'])
    # Assign the two values positionally so they broadcast to every matched
    # row. (Assigning a 1-row DataFrame here aligns on the index and would
    # write NaN into most matched rows instead of the actual answers.)
    raw_df.loc[mask, ['answer', 'gpt3_answer']] = row[['answer', 'gpt3_answer']].values

    matches = raw_df.loc[mask, ['answer', 'gpt3_answer']]
    if isinstance(matches, pd.Series):  # defensive: .loc with a column list returns a DataFrame
        matches = pd.DataFrame([matches])
    num_matches = len(matches)
    # Sanity check: every verse that contains this question should be matched.
    if row['num_verses'] != num_matches:
        print(row['num_verses'], num_matches, row)

# save raw_df to matched_questions_gospels_raw_out.csv
# (index=False intentionally drops the helper question_without_reference index)
raw_df.to_csv('data/2_sd_labeling/matched_questions_gospels_raw_out.csv', index=False)

print('Done.')
Empty file.
60 changes: 60 additions & 0 deletions src/sd_labeling/sd_label_tool/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import random
from flask import Flask, jsonify, render_template, request
from src.dictionary_creator.tfidf_dictionary_creator import TfidfDictionaryCreator
from src.semantic_domain_identifier import SemanticDomainIdentifier

app = Flask(__name__)

# Module-level setup (runs once at import time): load the English
# ('bid-eng-web') and German ('bid-deu') bibles and build the semantic
# domain identifier from them.
# NOTE(review): state_files_path is relative — starting the app from a
# different working directory will miss the state files; confirm.
dc = TfidfDictionaryCreator(['bid-eng-web', 'bid-deu'], score_threshold=0.2, state_files_path='../../../data/0_state')
sdi = SemanticDomainIdentifier(dc)
# Parallel per-verse token lists for the two bibles, indexed by verse.
verses_eng = sdi.dc.wtxts_by_verse_by_bid['bid-eng-web']
verses_deu = sdi.dc.wtxts_by_verse_by_bid['bid-deu']


def generate_sentence():
    """Pick a random verse and return (verse index, English tokens, German tokens)."""
    verse_idx = random.randrange(len(verses_eng))
    return verse_idx, verses_eng[verse_idx], verses_deu[verse_idx]


def generate_checkboxes():
    """Build the HTML header for a random verse and one checkbox per identified SD question.

    Returns a (sentence, checkboxes) pair: sentence is an HTML string
    (verse index, raw English wtxts, their surface forms, German tokens);
    checkboxes is a list of {'checked': ..., 'word': ...} dicts.
    """
    idx, words_eng, words_deu = generate_sentence()
    # Header line: verse index, raw wtxts, surface forms (text before '_'),
    # and the parallel German verse, separated by <br>.
    sentence = f'{idx}<br>' + ' '.join(words_eng) + '<br>'
    sentence += ' '.join([wtxt.split('_')[0] for wtxt in words_eng]) + '<br>' + ' '.join(words_deu)

    # Use the ground-truth word->qid mapping for identification.
    sdi.qid_by_wtxt = sdi.gt_qid_by_wtxt
    identified_qids = sdi.identify_semantic_domains([words_eng])
    # Sort by start token index, then qid. Tuples appear to be
    # (start_token_idx, wtxt, qid, sd_name, question, words) — TODO confirm
    # against SemanticDomainIdentifier.
    identified_qids = sorted(identified_qids, key=lambda t: (t[0], t[2]))

    # Emphasize the matched word as an uppercase quoted string for display.
    identified_qids = [(idx, f'"{wtxt.split("_")[0].upper()}"', qid, sd_name, question, words)
                       for (idx, wtxt, qid, sd_name, question, words) in identified_qids]
    # One checkbox per question; '#' in the question text is the placeholder
    # for the matched word.
    checkboxes = [{'checked': '',
                   'word': f'{question.replace("#", wtxt)}'}
                  for (idx, wtxt, qid, sd_name, question, words) in identified_qids]

    # for each start_token_idx that occurs only once, mark it as checked (because it is the only option)
    for outer_idx, (idx, wtxt, qid, sd_name, question, words) in enumerate(identified_qids):
        if len([t for t in identified_qids if t[0] == idx]) == 1:
            checkboxes[outer_idx]['checked'] = 'checked'

    return sentence, checkboxes


@app.route('/', methods=['GET', 'POST'])
def home():
    """Render the labeling page; on POST, log the submitted labels, then serve a fresh verse."""
    if request.method == 'POST':
        submitted_words = request.form.getlist('word')
        submitted_sentence = request.form['sentence']
        print(submitted_sentence, submitted_words)
    new_sentence, new_checkboxes = generate_checkboxes()
    return render_template('index.html', sentence=new_sentence, checkboxes=new_checkboxes)


@app.route('/generate_checkboxes')
def fetch_checkboxes():
    """JSON endpoint returning a freshly generated sentence and its checkboxes."""
    new_sentence, new_checkboxes = generate_checkboxes()
    payload = {'sentence': new_sentence, 'checkboxes': new_checkboxes}
    return jsonify(payload)


if __name__ == '__main__':
    # Run the Flask development server (debug mode enables auto-reload).
    app.run(debug=True)
27 changes: 27 additions & 0 deletions src/sd_labeling/sd_label_tool/static/script.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// NOTE: this fetch-based refresh of the checkbox list is currently disabled;
// the form POST in index.html reloads the page instead.
// function generateCheckboxes() {
// const checkboxesDiv = document.getElementById('checkboxes');
// let sentence = document.getElementById('sentence').textContent;
//
// fetch('/generate_checkboxes')
// .then(res => res.json())
// .then(data => {
// // Update the sentence text
// sentence = data.sentence;
// document.getElementById('sentence').textContent = sentence;
//
// // Update the checkboxes
// checkboxesDiv.innerHTML = '';
// for (const checkbox of data.checkboxes) {
// const label = document.createElement('label');
// const input = document.createElement('input');
// input.type = 'checkbox';
// input.name = 'word';
// input.value = checkbox.word;
// label.appendChild(input);
// label.appendChild(document.createTextNode(' ' + checkbox.word));
// checkboxesDiv.appendChild(label);
// checkboxesDiv.appendChild(document.createElement('br'));
// }
// })
// .catch(error => console.error(error));
// }
19 changes: 19 additions & 0 deletions src/sd_labeling/sd_label_tool/templates/index.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
<!DOCTYPE html>
<html>
<head>
    <title>Semantic Domain Labeler</title>
    <script src="{{ url_for('static', filename='script.js') }}"></script>
</head>
<body>
    <form method="POST">
        <!-- sentence is server-built HTML (contains <br> tags), hence "| safe". -->
        <h1>{{ sentence | safe }}</h1>
        <!-- Echo the sentence back so the server can log which verse was labeled. -->
        <input type="hidden" name="sentence" id="sentence-input" value="{{ sentence }}">
        <button type="submit">Next</button>
        <div id="checkboxes">
            <!-- One checkbox per identified semantic domain question;
                 options that are the only candidate for a token arrive pre-checked. -->
            {% for checkbox in checkboxes %}
            <label><input type="checkbox" name="word" value="{{ checkbox.word }}" {{ checkbox.checked }}>{{ checkbox.word }}</label><br>
            {% endfor %}
        </div>
    </form>
</body>
</html>
49 changes: 49 additions & 0 deletions src/sd_labeling/sd_question_bulk_answerer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# bulk answer semantic domain questions using the following approach:
# if all existing (0/1) answers for a (qid, token) pair agree, propagate that
# answer to every still-unanswered question of the same pair.

import math
import pandas as pd
from collections import defaultdict
from tqdm import tqdm

df = pd.read_excel("data/2_sd_labeling/matched_questions.xlsx", usecols=["direct_question", "answer", "qid", "tokens"],
                   nrows=156501)

# assert that no qid starts with '9'
assert not any(df['qid'].str.startswith('9'))

# filter out qids that start with '1'
df = df[~df['qid'].str.startswith('1')]

# group questions by (qid, token); keep the original df index with each row
qid_and_token_to_rows = defaultdict(list)
for idx, row in df.iterrows():
    qid_and_token_to_rows[(row['qid'], row['tokens'])].append((idx, row))

# filter (qid, token) pairs that have at least 2 answers (0, 1)
qid_and_token_to_rows = {k: rows for k, rows in qid_and_token_to_rows.items()
                         if len([row['answer'] for (idx, row) in rows
                                 if row['answer'] in (0, 1)]) >= 2}

# if all answers are the same, use this answer as the answer to all questions of this (qid, token) pair
additional_answers_count = 0
for _, rows in tqdm(qid_and_token_to_rows.items(),
                    desc='Bulk answering SD questions...',
                    total=len(qid_and_token_to_rows)):
    answers = [row['answer'] for (idx, row) in rows if row['answer'] in (0, 1)]
    if len(set(answers)) > 1:
        # conflicting human answers -> do not bulk answer this pair
        continue
    bulk_answer = answers[0]
    for (idx, row) in rows:
        if math.isnan(row['answer']):
            additional_answers_count += 1
            # sanity check that idx still addresses the expected question
            assert df.loc[idx, 'direct_question'] == row['direct_question']
            # Write directly into df: `df.loc[idx]` returns a copy, so
            # assigning into that copy (as before) never updated the frame.
            df.loc[idx, 'answer'] = bulk_answer
            print(f'{bulk_answer}: {row["direct_question"]}')
        else:
            assert (row['answer'] == bulk_answer)
print(f'Added {additional_answers_count} additional answers.')

# save to csv
df.to_csv("data/2_sd_labeling/matched_questions_with_bulk_answers.csv", index=False)
print('Done.')
21 changes: 21 additions & 0 deletions src/sd_labeling/sd_test_set_creator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import pandas as pd

# Range of test-set file numbers: files 1 .. test_file_num_end - 1 exist.
test_file_num_start = 10
test_file_num_end = test_file_num_start + 3

# # select 3 random rows from df and save them to a csv file
# df = pd.read_excel("data/2_sd_labeling/matched_questions.xlsx", nrows=156501)
#
# for i in range(test_file_num_start, test_file_num_end):
#     df_sample = df.sample(n=100, random_state=i)
#     df_sample.to_csv(f"data/2_sd_labeling/test sets/test_set_{i}.csv", index=False)

# Load all existing test sets and count duplicated questions across them.
frames = [pd.read_csv(f"data/2_sd_labeling/test sets/test_set_{i}.csv")
          for i in range(1, test_file_num_end)]
df_concat = pd.concat(frames, ignore_index=True)
print(df_concat['direct_question'].duplicated().sum())

# print duplicates
print(df_concat[df_concat['direct_question'].duplicated()])
Loading

0 comments on commit 09777c4

Please sign in to comment.