Skip to content

retrieve extra metadata from csv (#323) #332

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion cdqa/pipeline/cdqa_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
import pandas as pd
import numpy as np
import torch
import inspect

from typing import List

from sklearn.base import BaseEstimator

Expand Down Expand Up @@ -135,6 +138,7 @@ def predict(
n_predictions: int = None,
retriever_score_weight: float = 0.35,
return_all_preds: bool = False,
extra_metadata: List[str] = [],
):
""" Compute prediction of an answer to a question

Expand Down Expand Up @@ -165,7 +169,6 @@ def predict(
given the question.

"""

if not isinstance(query, str):
raise TypeError(
"The input is not a string. Please provide a string as input."
Expand All @@ -180,13 +183,16 @@ def predict(
best_idx_scores=best_idx_scores,
metadata=self.metadata,
retrieve_by_doc=self.retrieve_by_doc,
extra_metadata=extra_metadata,
)
self.processor_predict.set_extra_metadata(extra_metadata)
examples, features = self.processor_predict.fit_transform(X=squad_examples)
prediction = self.reader.predict(
X=(examples, features),
n_predictions=n_predictions,
retriever_score_weight=retriever_score_weight,
return_all_preds=return_all_preds,
extra_metadata=extra_metadata,
)
return prediction

Expand Down
120 changes: 75 additions & 45 deletions cdqa/reader/bertqa_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,27 +60,29 @@ class SquadExample(object):

def __init__(
    self,
    **kwargs
):
    """Build a SQuAD example from arbitrary keyword arguments.

    The core fields historically accepted here (``qas_id``,
    ``question_text``, ``doc_tokens``, ``orig_answer_text``,
    ``start_position``, ``end_position``, ``is_impossible``,
    ``paragraph``, ``title``, ``retriever_score``) are now passed as
    keyword arguments alongside any extra metadata columns, and every
    keyword is stored verbatim as an instance attribute.

    NOTE(review): no field is validated or defaulted, so attribute
    access (e.g. ``example.qas_id``) raises ``AttributeError`` if the
    caller omits it — confirm all call sites supply the core fields.
    """
    # Store every keyword as an attribute so extra metadata columns
    # can ride along with the core fields without code changes here.
    self.__dict__.update(kwargs)

def __str__(self):
    # str() and repr() render identically; delegate to __repr__.
    return repr(self)
Expand Down Expand Up @@ -135,7 +137,7 @@ def __init__(self,
self.is_impossible = is_impossible


def read_squad_examples(input_file, is_training, version_2_with_negative):
def read_squad_examples(input_file, is_training, version_2_with_negative, extra_metadata):
"""Read a SQuAD json file into a list of SquadExample."""

if isinstance(input_file, str):
Expand Down Expand Up @@ -213,18 +215,34 @@ def read_squad_examples(input_file, is_training, version_2_with_negative):
end_position = -1
orig_answer_text = ""

se_args = dict(
qas_id= qas_id,
question_text= question_text,
doc_tokens= doc_tokens,
orig_answer_text= orig_answer_text,
start_position= start_position,
end_position= end_position,
is_impossible= is_impossible,
paragraph= paragraph_text,
title= entry["title"],
retriever_score= retriever_score,
)
for data in extra_metadata:
se_args[data] = entry[data]

examples.append(
SquadExample(
qas_id=qas_id,
question_text=question_text,
doc_tokens=doc_tokens,
orig_answer_text=orig_answer_text,
start_position=start_position,
end_position=end_position,
is_impossible=is_impossible,
paragraph=paragraph_text,
title=entry["title"],
retriever_score=retriever_score,
#qas_id=qas_id,
#question_text=question_text,
#doc_tokens=doc_tokens,
#orig_answer_text=orig_answer_text,
#start_position=start_position,
#end_position=end_position,
#is_impossible=is_impossible,
#paragraph=paragraph_text,
#title=entry["title"],
#retriever_score=retriever_score,
**se_args
)
)
return examples
Expand Down Expand Up @@ -548,6 +566,7 @@ def write_predictions(
null_score_diff_threshold,
retriever_score_weight,
n_predictions=None,
extra_metadata=[],
):
"""
Write final predictions to the json file and log-odds of null if needed.
Expand Down Expand Up @@ -755,18 +774,22 @@ def write_predictions(
best_dict["final_score"] = (1 - retriever_score_weight) * (
best_dict["start_logit"] + best_dict["end_logit"]
) + retriever_score_weight * best_dict["retriever_score"]

for data in extra_metadata:
best_dict.update({data: getattr(example, data)})
final_predictions.append(best_dict)

final_predictions_sorted = sorted(
final_predictions, key=lambda d: d["final_score"], reverse=True
)

best_prediction = (
final_predictions_sorted[0]["text"],
final_predictions_sorted[0]["title"],
final_predictions_sorted[0]["paragraph"],
final_predictions_sorted[0]["final_score"],
)
#best_prediction = (
# final_predictions_sorted[0]["text"],
# final_predictions_sorted[0]["title"],
# final_predictions_sorted[0]["paragraph"],
# final_predictions_sorted[0]["final_score"],
#)
best_prediction = final_predictions_sorted[0]

return_list = [best_prediction, final_predictions_sorted]

Expand Down Expand Up @@ -926,13 +949,13 @@ def _n_best_predictions(final_predictions_sorted, n):
def _n_best_predictions(final_predictions_sorted, n):
    """Return the top-``n`` prediction dicts.

    Parameters
    ----------
    final_predictions_sorted : list of dict
        Predictions already sorted by descending ``final_score``; each
        dict carries at least ``text``, ``title``, ``paragraph`` and
        ``final_score`` plus any extra metadata fields.
    n : int
        Number of predictions to return; clamped to the list length.

    Returns
    -------
    list of dict
        The first ``n`` prediction dicts, in order.
    """
    # Return the whole prediction dicts (not (text, title, paragraph,
    # score) tuples) so extra metadata fields survive into the output.
    n = min(n, len(final_predictions_sorted))
    return final_predictions_sorted[:n]

def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
Expand Down Expand Up @@ -1000,6 +1023,7 @@ def __init__(
max_query_length=64,
verbose=False,
tokenizer=None,
extra_metadata=[],
):

self.bert_model = bert_model
Expand All @@ -1010,6 +1034,7 @@ def __init__(
self.doc_stride = doc_stride
self.max_query_length = max_query_length
self.verbose = verbose
self.extra_metadata = extra_metadata

if tokenizer is None:
self.tokenizer = BertTokenizer.from_pretrained(
Expand All @@ -1027,6 +1052,7 @@ def transform(self, X):
input_file=X,
is_training=self.is_training,
version_2_with_negative=self.version_2_with_negative,
extra_metadata=self.extra_metadata,
)

features = convert_examples_to_features(
Expand All @@ -1041,6 +1067,9 @@ def transform(self, X):

return examples, features

def set_extra_metadata(self, extra_metadata=None):
    """Set the extra metadata column names attached to predictions.

    Parameters
    ----------
    extra_metadata : list of str, optional
        Names of additional metadata columns to propagate through the
        pipeline. Defaults to an empty list; ``None`` is normalized to
        a fresh list to avoid the shared-mutable-default pitfall of
        the previous ``extra_metadata=[]`` signature.
    """
    self.extra_metadata = [] if extra_metadata is None else extra_metadata


class BertQA(BaseEstimator):
"""
Expand Down Expand Up @@ -1428,7 +1457,7 @@ def fit(self, X, y=None):
return self

def predict(
self, X, n_predictions=None, retriever_score_weight=0.35, return_all_preds=False
self, X, n_predictions=None, retriever_score_weight=0.35, return_all_preds=False, extra_metadata=[]
):

eval_examples, eval_features = X
Expand Down Expand Up @@ -1507,6 +1536,7 @@ def predict(
self.null_score_diff_threshold,
retriever_score_weight,
n_predictions,
extra_metadata,
)

if n_predictions is not None:
Expand Down
4 changes: 3 additions & 1 deletion cdqa/utils/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def df2squad(df, squad_version="v1.1", output_dir=None, filename=None):
return json_data


def generate_squad_examples(question, best_idx_scores, metadata, retrieve_by_doc):
def generate_squad_examples(question, best_idx_scores, metadata, retrieve_by_doc, extra_metadata):
"""
Creates a SQuAD examples json object for a given for a given question using outputs of retriever and document database.

Expand Down Expand Up @@ -94,6 +94,8 @@ def generate_squad_examples(question, best_idx_scores, metadata, retrieve_by_doc

for idx, row in metadata_sliced.iterrows():
temp = {"title": row["title"], "paragraphs": []}
for data in extra_metadata:
temp[data] = row[data]

if retrieve_by_doc:
for paragraph in row["paragraphs"]:
Expand Down