diff --git a/cdqa/pipeline/cdqa_sklearn.py b/cdqa/pipeline/cdqa_sklearn.py index a15f01b..1e6eee4 100644 --- a/cdqa/pipeline/cdqa_sklearn.py +++ b/cdqa/pipeline/cdqa_sklearn.py @@ -4,6 +4,9 @@ import pandas as pd import numpy as np import torch +import inspect + +from typing import List from sklearn.base import BaseEstimator @@ -135,6 +138,7 @@ def predict( n_predictions: int = None, retriever_score_weight: float = 0.35, return_all_preds: bool = False, + extra_metadata: List[str] = [], ): """ Compute prediction of an answer to a question @@ -165,7 +169,6 @@ def predict( given the question. """ - if not isinstance(query, str): raise TypeError( "The input is not a string. Please provide a string as input." @@ -180,13 +183,16 @@ def predict( best_idx_scores=best_idx_scores, metadata=self.metadata, retrieve_by_doc=self.retrieve_by_doc, + extra_metadata=extra_metadata, ) + self.processor_predict.set_extra_metadata(extra_metadata) examples, features = self.processor_predict.fit_transform(X=squad_examples) prediction = self.reader.predict( X=(examples, features), n_predictions=n_predictions, retriever_score_weight=retriever_score_weight, return_all_preds=return_all_preds, + extra_metadata=extra_metadata, ) return prediction diff --git a/cdqa/reader/bertqa_sklearn.py b/cdqa/reader/bertqa_sklearn.py index 04987f7..b0e4c89 100644 --- a/cdqa/reader/bertqa_sklearn.py +++ b/cdqa/reader/bertqa_sklearn.py @@ -60,27 +60,29 @@ class SquadExample(object): def __init__( self, - qas_id, - question_text, - doc_tokens, - orig_answer_text=None, - start_position=None, - end_position=None, - is_impossible=None, - paragraph=None, - title=None, - retriever_score=None, + #qas_id, + #question_text, + #doc_tokens, + #orig_answer_text=None, + #start_position=None, + #end_position=None, + #is_impossible=None, + #paragraph=None, + #title=None, + #retriever_score=None, + **kwargs ): - self.qas_id = qas_id - self.question_text = question_text - self.doc_tokens = doc_tokens - self.orig_answer_text = orig_answer_text - self.start_position = start_position - self.end_position = end_position - self.is_impossible = is_impossible - self.paragraph = paragraph - self.title = title - self.retriever_score = retriever_score + #self.qas_id = qas_id + #self.question_text = question_text + #self.doc_tokens = doc_tokens + #self.orig_answer_text = orig_answer_text + #self.start_position = start_position + #self.end_position = end_position + #self.is_impossible = is_impossible + #self.paragraph = paragraph + #self.title = title + #self.retriever_score = retriever_score + self.__dict__.update(kwargs) def __str__(self): return self.__repr__() @@ -135,7 +137,7 @@ def __init__(self, self.is_impossible = is_impossible -def read_squad_examples(input_file, is_training, version_2_with_negative): +def read_squad_examples(input_file, is_training, version_2_with_negative, extra_metadata): """Read a SQuAD json file into a list of SquadExample.""" if isinstance(input_file, str): @@ -213,18 +215,34 @@ def read_squad_examples(input_file, is_training, version_2_with_negative): end_position = -1 orig_answer_text = "" + se_args = dict( + qas_id= qas_id, + question_text= question_text, + doc_tokens= doc_tokens, + orig_answer_text= orig_answer_text, + start_position= start_position, + end_position= end_position, + is_impossible= is_impossible, + paragraph= paragraph_text, + title= entry["title"], + retriever_score= retriever_score, + ) + for data in extra_metadata: + se_args[data] = entry[data] + examples.append( SquadExample( - qas_id=qas_id, - question_text=question_text, - doc_tokens=doc_tokens, - orig_answer_text=orig_answer_text, - start_position=start_position, - end_position=end_position, - is_impossible=is_impossible, - paragraph=paragraph_text, - title=entry["title"], - retriever_score=retriever_score, + #qas_id=qas_id, + #question_text=question_text, + #doc_tokens=doc_tokens, + #orig_answer_text=orig_answer_text, + #start_position=start_position, + #end_position=end_position, + #is_impossible=is_impossible, + #paragraph=paragraph_text, + #title=entry["title"], + #retriever_score=retriever_score, + **se_args ) ) return examples @@ -548,6 +566,7 @@ def write_predictions( null_score_diff_threshold, retriever_score_weight, n_predictions=None, + extra_metadata=[], ): """ Write final predictions to the json file and log-odds of null if needed. @@ -755,18 +774,22 @@ def write_predictions( best_dict["final_score"] = (1 - retriever_score_weight) * ( best_dict["start_logit"] + best_dict["end_logit"] ) + retriever_score_weight * best_dict["retriever_score"] + + for data in extra_metadata: + best_dict.update({data: getattr(example, data)}) final_predictions.append(best_dict) final_predictions_sorted = sorted( final_predictions, key=lambda d: d["final_score"], reverse=True ) - best_prediction = ( - final_predictions_sorted[0]["text"], - final_predictions_sorted[0]["title"], - final_predictions_sorted[0]["paragraph"], - final_predictions_sorted[0]["final_score"], - ) + #best_prediction = ( + # final_predictions_sorted[0]["text"], + # final_predictions_sorted[0]["title"], + # final_predictions_sorted[0]["paragraph"], + # final_predictions_sorted[0]["final_score"], + #) + best_prediction = final_predictions_sorted[0] return_list = [best_prediction, final_predictions_sorted] @@ -926,13 +949,13 @@ def _n_best_predictions(final_predictions_sorted, n): n = min(n, len(final_predictions_sorted)) final_prediction_list = [] for i in range(n): - curr_pred = ( - final_predictions_sorted[i]["text"], - final_predictions_sorted[i]["title"], - final_predictions_sorted[i]["paragraph"], - final_predictions_sorted[i]["final_score"], - ) - final_prediction_list.append(curr_pred) + #curr_pred = ( + # final_predictions_sorted[i]["text"], + # final_predictions_sorted[i]["title"], + # final_predictions_sorted[i]["paragraph"], + # final_predictions_sorted[i]["final_score"], + #) + final_prediction_list.append(final_predictions_sorted[i]) return final_prediction_list def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1): @@ -1000,6 +1023,7 @@ def __init__( max_query_length=64, verbose=False, tokenizer=None, + extra_metadata=[], ): self.bert_model = bert_model @@ -1010,6 +1034,7 @@ def __init__( self.doc_stride = doc_stride self.max_query_length = max_query_length self.verbose = verbose + self.extra_metadata = extra_metadata if tokenizer is None: self.tokenizer = BertTokenizer.from_pretrained( @@ -1027,6 +1052,7 @@ def transform(self, X): input_file=X, is_training=self.is_training, version_2_with_negative=self.version_2_with_negative, + extra_metadata=self.extra_metadata, ) features = convert_examples_to_features( @@ -1041,6 +1067,9 @@ def transform(self, X): return examples, features + def set_extra_metadata(self, extra_metadata=[]): + self.extra_metadata = extra_metadata + class BertQA(BaseEstimator): """ @@ -1428,7 +1457,7 @@ def fit(self, X, y=None): return self def predict( - self, X, n_predictions=None, retriever_score_weight=0.35, return_all_preds=False + self, X, n_predictions=None, retriever_score_weight=0.35, return_all_preds=False, extra_metadata=[] ): eval_examples, eval_features = X @@ -1507,6 +1536,7 @@ def predict( self.null_score_diff_threshold, retriever_score_weight, n_predictions, + extra_metadata, ) if n_predictions is not None: diff --git a/cdqa/utils/converters.py b/cdqa/utils/converters.py index 146c9b5..d832c88 100644 --- a/cdqa/utils/converters.py +++ b/cdqa/utils/converters.py @@ -61,7 +61,7 @@ def df2squad(df, squad_version="v1.1", output_dir=None, filename=None): return json_data -def generate_squad_examples(question, best_idx_scores, metadata, retrieve_by_doc): +def generate_squad_examples(question, best_idx_scores, metadata, retrieve_by_doc, extra_metadata): """ Creates a SQuAD examples json object for a given for a given question using outputs of retriever and document database. @@ -94,6 +94,8 @@ def generate_squad_examples(question, best_idx_scores, metadata, retrieve_by_doc for idx, row in metadata_sliced.iterrows(): temp = {"title": row["title"], "paragraphs": []} + for data in extra_metadata: + temp[data] = row[data] if retrieve_by_doc: for paragraph in row["paragraphs"]: