# open_scholar.py
from tqdm import tqdm
import os
import re

import numpy as np
import spacy
import vllm
from nltk import sent_tokenize
from FlagEmbedding import FlagReranker

import src.instructions as instructions
from src.use_search_apis import search_paper_via_query, retrieve_pes2o_passages

nlp = spacy.load('en_core_web_sm')
# OpenAI API prices in USD per million tokens (input and output priced separately);
# see https://openai.com/ja-JP/api/pricing/ for the current rates.
price_per_million = {"gpt-4o": 2.50, "gpt-4o-2024-08-06": 2.50, "gpt-4o-2024-05-13": 5.00, "gpt-4o-mini": 0.15, "gpt-4o-mini-2024-07-18": 0.15, "gpt-4-turbo": 10.0, "gpt-3.5-turbo-0125": 0.50}
price_per_million_output = {"gpt-4o": 10.00, "gpt-4o-2024-08-06": 10.00, "gpt-4o-2024-05-13": 15.00, "gpt-4o-mini": 0.600, "gpt-4o-mini-2024-07-18": 0.600, "gpt-4-turbo": 30.0, "gpt-3.5-turbo-0125": 1.50}
def calculate_openai_api_cost(input_tokens: int, output_tokens: int, model_name: str) -> float:
    """
    Calculate the OpenAI API cost from the number of input and output tokens.
    Args:
    - input_tokens (int): Number of tokens in the input.
    - output_tokens (int): Estimated number of tokens in the output.
    - model_name (str): Model key used to look up the per-million-token prices above.
    Returns:
    - float: The total API cost in USD, rounded to six decimal places.
    """
    total_cost_input = (input_tokens / 1_000_000) * price_per_million[model_name]
    total_cost_output = (output_tokens / 1_000_000) * price_per_million_output[model_name]
    return round(total_cost_input + total_cost_output, 6)
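# Worked example with hypothetical token counts: 12,000 input and 800 output tokens on
# gpt-4o cost (12000 / 1e6) * 2.50 + (800 / 1e6) * 10.00 = 0.03 + 0.008 = 0.038 USD, i.e.
# calculate_openai_api_cost(12000, 800, "gpt-4o") == 0.038.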
def remove_citations(sent):
    # Strip inline citation markers such as "[3]" or " [12|13]" from a sentence.
    return re.sub(r"\[\d+", "", re.sub(r" \[\d+", "", sent)).replace(" |", "").replace("]", "")
def rerank_paragraphs_bge(query, paragraphs, reranker, norm_cite=False, start_index=0, use_abstract=False):
    """Rerank paragraphs with a BGE cross-encoder, optionally boosting scores by citation counts.
    Note: start_index is accepted for API compatibility but is currently unused."""
    paragraphs = [p for p in paragraphs if p["text"] is not None]
    if use_abstract is True:
        paragraph_texts = [p["title"] + "\n" + p["abstract"] + "\n" + p["text"] if "title" in p and "abstract" in p else p["text"] for p in paragraphs]
    else:
        paragraph_texts = [p["title"] + " " + p["text"] if "title" in p and p["title"] is not None else p["text"] for p in paragraphs]
    scores = reranker.compute_score([[query, p] for p in paragraph_texts], batch_size=100)
    # compute_score returns a bare float for a single pair and a list otherwise.
    if isinstance(scores, float):
        result_dic = {0: scores}
    else:
        result_dic = {p_id: score for p_id, score in enumerate(scores)}
    if norm_cite is True and len([item["citation_counts"] for item in paragraphs if "citation_counts" in item and item["citation_counts"] is not None]) > 0:
        # Add each paragraph's citation count, normalized by the maximum, to its relevance score.
        max_citations = max([item["citation_counts"] for item in paragraphs if "citation_counts" in item and item["citation_counts"] is not None])
        for p_id in result_dic:
            if "citation_counts" in paragraphs[p_id] and paragraphs[p_id]["citation_counts"] is not None:
                result_dic[p_id] = result_dic[p_id] + (paragraphs[p_id]["citation_counts"] / max_citations)
    # Sort paragraphs by descending score and record the new-to-original index mapping.
    p_ids = sorted(result_dic.items(), key=lambda x: x[1], reverse=True)
    new_orders = []
    id_mapping = {}
    for i, (orig_id, _score) in enumerate(p_ids):
        new_orders.append(paragraphs[orig_id])
        id_mapping[i] = int(orig_id)
    return new_orders, result_dic, id_mapping
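# A minimal usage sketch for the reranker (the checkpoint name and paragraphs below are
# illustrative assumptions, not fixed by this file):
#   reranker = FlagReranker("BAAI/bge-reranker-large", use_fp16=True)
#   paragraphs = [
#       {"title": "Attention Is All You Need", "text": "We propose the Transformer ...", "citation_counts": 100000},
#       {"title": None, "text": "An unrelated passage.", "citation_counts": 3},
#   ]
#   ranked, scores, mapping = rerank_paragraphs_bge("What is the Transformer?", paragraphs, reranker, norm_cite=True)
#   # ranked[0] should be the Transformer passage; mapping maps new rank -> original index.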
def create_prompt_with_llama3_format(prompt, system_message="You are a helpful AI assistant for scientific literature review. Please carefully follow user's instruction and help them to understand the most recent papers."):
if system_message is not None:
formatted_text = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{0}<|eot_id|>".format(system_message)
else:
formatted_text = "<|begin_of_text|>"
formatted_text += "<|start_header_id|>user<|end_header_id|>\n\n" + prompt + "<|eot_id|>"
formatted_text += "<|start_header_id|>assistant<|end_header_id|>\n\n"
return formatted_text
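# For example, create_prompt_with_llama3_format("Hi") with the default system message yields
# the Llama-3 chat template:
#   <|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>
#   <|start_header_id|>user<|end_header_id|>\n\nHi<|eot_id|>
#   <|start_header_id|>assistant<|end_header_id|>\n\n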
def load_hf_tokenizer(
        model_name_or_path,
        tokenizer_name_or_path=None,
        use_fast_tokenizer=True,
        padding_side="left",
        token=os.getenv("HF_TOKEN", None),
    ):
    # AutoTokenizer is imported lazily here (originally to handle the OLMo tokenizer).
    from transformers import AutoTokenizer

    if not tokenizer_name_or_path:
        tokenizer_name_or_path = model_name_or_path
    try:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, use_fast=use_fast_tokenizer, token=token)
    except Exception:
        # Some tokenizers (e.g., GPTNeoXTokenizer) only ship one of the slow/fast
        # implementations, so fall back to the default loader.
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, token=token)
    # Set padding side to left for batch generation.
    tokenizer.padding_side = padding_side
    # Set pad token to eos token if pad token is not set (as is the case for llama models).
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
    return tokenizer
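# A usage sketch; the checkpoint name below is an illustrative assumption:
#   tokenizer = load_hf_tokenizer("meta-llama/Meta-Llama-3-8B-Instruct")
#   batch = tokenizer(["first prompt", "second prompt"], padding=True, return_tensors="pt")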
class OpenScholar(object):
    def __init__(self, model, tokenizer, client=None, api_model_name=None, use_contexts=True, top_n=8, reranker=None, min_citation=None, norm_cite=False, ss_retriever=False):
        """Retrieval-augmented answer generation with reranking, feedback-based editing,
        and post-hoc citation attribution.

        If `client` is set, an OpenAI-compatible API is used for generation; otherwise
        `model` is treated as a vLLM model. `top_n` controls how many reranked passages
        go into the prompt, and `ss_retriever` enables paper search for feedback that
        asks follow-up questions."""
        self.model = model
        self.tokenizer = tokenizer
        self.client = client
        self.model_name = api_model_name
        self.top_n = top_n
        self.no_retrieval = not use_contexts
        self.reranker = reranker
        self.min_citation = min_citation
        self.norm_cite = norm_cite
        self.ss_retriever = ss_retriever
        self.use_contexts = use_contexts
    # Reranking: rerank passages by the LM's prediction of how useful each passage is.
    def process_ranking_results(self, result):
        # Parse lines of the form "[3] Rating: 5" into {passage_id: rating}.
        ratings = {int(match.group(1)): int(match.group(2)) for match in re.finditer(r'\[(\d+)\] Rating: (\d)', result)}
        return ratings
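    # Example: process_ranking_results("[0] Rating: 5\n[1] Rating: 2") returns {0: 5, 1: 2}.
    # Note the pattern only captures single-digit ratings.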
    def reranking_passages_cross_encoder(self, item, batch_size=5, llama3_chat=False, task_name="default", use_abstract=False):
        if self.min_citation is not None:
            # Filter by citation count (None-safe), but only keep the filtered set when
            # enough passages survive the threshold.
            ctx_above_threshold = [p for p in item["ctxs"] if p.get("citation_counts") is not None and p["citation_counts"] >= self.min_citation]
            if len(ctx_above_threshold) > self.top_n:
                item["ctxs"] = ctx_above_threshold
            print("after filtering -- number of ctxs: {0}".format(len(item["ctxs"])))
reranked_contexts, sorted_results, id_mapping = rerank_paragraphs_bge(item["input"], item["ctxs"], self.reranker, norm_cite=self.norm_cite, use_abstract=use_abstract)
return reranked_contexts, sorted_results, id_mapping
    def reranking_passages_cross_encoder_supplemental(self, item, passages, batch_size=5, llama3_chat=False, task_name="default"):
        if self.min_citation is not None:
            ctx_above_threshold = [p for p in passages if p.get("citation_counts") is not None and p["citation_counts"] >= self.min_citation]
            if len(ctx_above_threshold) > self.top_n:
                passages = ctx_above_threshold
            print("after filtering -- number of ctxs: {0}".format(len(passages)))
reranked_contexts, sorted_results, id_mapping = rerank_paragraphs_bge(item["input"], passages, self.reranker, norm_cite=False, start_index=len(item["ctxs"]))
return reranked_contexts, sorted_results, id_mapping
def retrieve_keywords(self, question):
prompt = [instructions.keyword_extraction_prompt.format_map({"question": question})]
if self.client is not None:
result = self.client.chat.completions.create(
model=self.model_name,
messages=[
{"role": "user",
"content": prompt[0]},
],
temperature=0.9,
max_tokens=1000,
)
raw_output = result.choices[0].message.content
outputs = raw_output
else:
            sampling_params = vllm.SamplingParams(
                temperature=0.9,  # sample (not greedy) to get diverse keyword candidates
                max_tokens=1000,
                stop_token_ids=[128009]
            )
            outputs = self.model.generate(prompt, sampling_params)
            outputs = [it.outputs[0].text for it in outputs][0]
        # Extract the span between [Response_Start] and [Response_End] when present.
        raw_output = [t.split("[Response_End]")[0] for t in outputs.split("[Response_Start]") if "[Response_End]" in t][0] if "[Response_End]" in outputs else outputs
        # Keep at most three comma-separated queries and strip the "Search queries: " prefix.
        queries = raw_output.split(", ")[:3]
        queries = [query.replace("Search queries: ", "") for query in queries if len(query) > 0]
return queries
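    # Example of the expected parse (the model output below is hypothetical):
    #   "Search queries: retrieval-augmented LMs, scientific fact verification, citation attribution"
    # becomes ["retrieval-augmented LMs", "scientific fact verification", "citation attribution"].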
    # Generation: generate an answer conditioned on the query and the retrieved passages.
def generate_response(self, item, max_tokens=3000, llama3_chat=False, task_name="default", zero_shot=False):
ranked_results = {}
print("zero-shot?: {}".format(zero_shot))
print(item["input"])
        if self.use_contexts is False:
            # No-retrieval mode: build the prompt without reference passages, using
            # task-specific instructions and demonstrations where available.
            ctxs = []
            if task_name in instructions.task_instructions:
                if zero_shot is True:
                    input_query = instructions.task_instructions[task_name][0] + instructions.task_instructions[task_name][1] + item["input"]
                else:
                    demonstration = instructions.demonstrations[task_name]
                    input_query = instructions.task_instructions[task_name][0] + demonstration + instructions.task_instructions[task_name][1] + item["input"]
            if task_name == "single_qa":
                input_query = instructions.generation_instance_prompts_w_references_single_paper_no_context.format_map({"input": item["input"]})
else:
ctxs = ""
for doc_idx, doc in enumerate(item["ctxs"][:self.top_n]):
if "title" in doc and len(doc["title"]) > 0:
ctxs += "[{0}] Title: {1} Text: {2}\n".format(doc_idx, doc["title"], doc["text"])
else:
ctxs += "[{0}] {1}\n".format(doc_idx, doc["text"])
item["final_passages"] = ctxs
if task_name =="summarization":
if zero_shot is True:
input_query = instructions.prompts_w_references_summarization_zero_shot.format_map({"context": ctxs, "input": item["input"]})
else:
input_query = instructions.generation_instance_prompts_summarization.format_map({"context": ctxs, "input": item["input"]})
elif task_name == "single_qa":
if zero_shot is True:
input_query = instructions.generation_instance_prompts_w_references_single_paper_zero_shot.format_map({"context": ctxs, "input": item["input"]})
else:
input_query = instructions.generation_instance_prompts_w_references_single_paper.format_map({"context": ctxs, "input": item["input"]})
elif task_name in instructions.task_instructions:
task_instruction = instructions.task_instructions[task_name][0]
instance_header = instructions.task_instructions[task_name][1]
if zero_shot is True:
input_query = "{0}\nReferences:\n{1}\n{2}{3}".format(task_instruction, ctxs, instance_header, item["input"])
else:
demonstration = instructions.demonstrations[task_name]
input_query = "{0}{1}\nReferences:\n{2}\n{3}{4}".format(task_instruction, demonstration, ctxs, instance_header, item["input"])
else:
if zero_shot is True:
input_query = instructions.generation_instance_prompts_w_references_zero_shot.format_map({"context": ctxs, "input": item["input"]})
else:
input_query = instructions.generation_instance_prompts_w_references.format_map({"context": ctxs, "input": item["input"]})
if llama3_chat is True:
input_query = create_prompt_with_llama3_format(input_query)
if self.client is not None:
result = self.client.chat.completions.create(
model=self.model_name,
messages=[
{"role": "user",
"content": input_query},
],
temperature=0.7,
max_tokens=max_tokens,
)
raw_output = result.choices[0].message.content
outputs = raw_output
            # Approximate token counts with whitespace-split word counts for cost accounting.
            cost = calculate_openai_api_cost(len(input_query.split(" ")), len(raw_output.split(" ")), self.model_name)
else:
            sampling_params = vllm.SamplingParams(
                temperature=0.7,
                max_tokens=max_tokens,
                stop_token_ids=[128009]
            )
outputs = self.model.generate([input_query], sampling_params)
outputs = [it.outputs[0].text for it in outputs][0]
cost = 0
raw_output = [t.split("[Response_End]")[0] for t in outputs.split("[Response_Start]") if "[Response_End]" in t][0] if "[Response_End]" in outputs else outputs
if "References:" in raw_output:
raw_output = raw_output.split("References:")[0]
item["output"] = raw_output
return raw_output, ctxs, cost
    # Feedback: generate feedback on the model's initial answer, then apply it.
    def process_feedback(self, response):
        # Parse "Feedback: ..." lines, each optionally followed by "Question: ..." before the newline.
        feedbacks_and_questions = re.findall(r'Feedback: (.*?)(?:Question: (.*?))?\n', response)
        ratings = [(feedback.strip(), question.strip() if question else "") for feedback, question in feedbacks_and_questions]
        return ratings
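    # Example of the expected parse (hypothetical model output):
    #   "Feedback: Add recent baselines. Question: Which methods were proposed after 2023?\n"
    # yields [("Add recent baselines.", "Which methods were proposed after 2023?")]; an empty
    # question string means the feedback can be applied without additional retrieval.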
def get_feedback(self, item, llama3_chat):
input_query = instructions.feedback_example_instance_prompt.format_map({"question": item["input"], "passages": item["final_passages"], "answer": item["output"]})
# TODO: check if the llama3 chat format is helpful or not.
if llama3_chat is True:
input_query = create_prompt_with_llama3_format(input_query)
if self.client is not None:
result = self.client.chat.completions.create(
model=self.model_name,
messages=[
{"role": "user",
"content": input_query},
],
temperature=0.7,
max_tokens=2000,
)
outputs = result.choices[0].message.content
            cost = calculate_openai_api_cost(len(input_query.split(" ")), len(outputs.split(" ")), self.model_name)
else:
            sampling_params = vllm.SamplingParams(
                temperature=0.7,
                max_tokens=2000,
                stop_token_ids=[128009]
            )
outputs = self.model.generate([input_query], sampling_params)
outputs = [it.outputs[0].text for it in outputs][0]
cost = 0
raw_output = [t.split("[Response_End]")[0] for t in outputs.split("[Response_Start]") if "[Response_End]" in t][0] if "[Response_End]" in outputs else outputs
feedbacks = self.process_feedback(raw_output)
return feedbacks, cost
def edit_with_feedback(self, item, feedback, max_tokens=3000, llama3_chat=False):
input_query = instructions.editing_instance_prompt.format_map({"question": item["input"], "passages": item["final_passages"], "answer": item["output"], "feedback": feedback})
# TODO: check if the llama3 chat format is helpful or not.
if llama3_chat is True:
input_query = create_prompt_with_llama3_format(input_query)
if self.client is not None:
result = self.client.chat.completions.create(
model=self.model_name,
messages=[
{"role": "user",
"content": input_query},
],
temperature=0.7,
max_tokens=max_tokens,
)
raw_output = result.choices[0].message.content
outputs = raw_output
            cost = calculate_openai_api_cost(len(input_query.split(" ")), len(outputs.split(" ")), self.model_name)
else:
            sampling_params = vllm.SamplingParams(
                temperature=0.7,
                max_tokens=max_tokens,
                stop_token_ids=[128009]
            )
outputs = self.model.generate([input_query], sampling_params)
outputs = [it.outputs[0].text for it in outputs][0]
cost = 0
raw_output = [t.split("[Response_End]")[0] for t in outputs.split("[Response_Start]") if "[Response_End]" in t][0] if "[Response_End]" in outputs else outputs
print("orig answer: {}".format( item["output"]))
print("feedback: {}".format(feedback))
print("updated answer: {}".format(raw_output))
return raw_output, cost
    def edit_with_feedback_retrieval(self, item, feedback, passages, passage_start_index, max_tokens=3000, llama3_chat=False):
        # Number the newly retrieved passages starting at passage_start_index so their
        # citation indices continue on from the original contexts.
        processed_passages = ""
        for doc_idx, doc in enumerate(passages[:self.top_n]):
            if "title" in doc and len(doc["title"]) > 0:
                processed_passages += "[{0}] Title: {1} Text: {2}\n".format(passage_start_index + doc_idx, doc["title"], doc["text"])
            else:
                processed_passages += "[{0}] {1}\n".format(passage_start_index + doc_idx, doc["text"])
input_query = instructions.editing_with_retrieval_instance_prompt.format_map({"question": item["input"], "retrieved_passages": processed_passages, "answer": item["output"], "feedback": feedback})
if llama3_chat is True:
input_query = create_prompt_with_llama3_format(input_query)
if self.client is not None:
result = self.client.chat.completions.create(
model=self.model_name,
messages=[
{"role": "user",
"content": input_query},
],
temperature=0.7,
                max_tokens=max_tokens,
)
raw_output = result.choices[0].message.content
outputs = raw_output
            cost = calculate_openai_api_cost(len(input_query.split(" ")), len(outputs.split(" ")), self.model_name)
else:
            sampling_params = vllm.SamplingParams(
                temperature=0.7,
                max_tokens=max_tokens,
                stop_token_ids=[128009]
            )
outputs = self.model.generate([input_query], sampling_params)
outputs = [it.outputs[0].text for it in outputs][0]
cost = 0
raw_output = [t.split("[Response_End]")[0] for t in outputs.split("[Response_Start]") if "[Response_End]" in t][0] if "[Response_End]" in outputs else outputs
return raw_output, cost
def insert_attributions_posthoc_paragraph(self, item, llama3_chat=False):
text = item["output"]
if "final_passages" in item:
passages = item["final_passages"]
else:
ctxs = item["ctxs"]
passages = ""
for idx, p in enumerate(ctxs):
passages += "[{0}] {1}\n".format(idx, p)
        sentences = text.split("\n")
# post process sentences
updated_sentences = []
post_hoc_sentence = {}
for s_index, statement in enumerate(sentences):
if len(statement) < 10:
if len(updated_sentences) > 0 and len(statement) > 0 and statement[0] == "[":
updated_sentences[-1] = updated_sentences[-1] + " " + statement
else:
updated_sentences.append(statement)
else:
# cases where citations are included
if "[" in statement or (s_index < len(sentences) - 1 and len(sentences[s_index+1]) > 0 and sentences[s_index+1][0] == "["):
updated_sentences.append(statement)
else:
updated_sentences.append("[replace_{}]".format(s_index))
post_hoc_sentence["[replace_{}]".format(s_index)] = statement
if len(post_hoc_sentence) > 0:
print("{0} sentences require attributions, e..g, {1}".format(len(post_hoc_sentence), list(post_hoc_sentence.values())[0] ))
prompts = []
for s in list(post_hoc_sentence.values()):
input_query = instructions.posthoc_attributions_paragraph.format_map({"statement": s, "passages": passages})
if llama3_chat is True:
input_query = create_prompt_with_llama3_format(input_query)
prompts.append(input_query)
if self.client is not None:
outputs = []
for input_query in prompts:
result = self.client.chat.completions.create(
model=self.model_name,
messages=[
{"role": "user",
"content": input_query},
],
temperature=0.7,
max_tokens=2000,
)
raw_output = result.choices[0].message.content
outputs.append(raw_output)
else:
            sampling_params = vllm.SamplingParams(
                temperature=0.7,
                max_tokens=2000,
                stop_token_ids=[128009]
            )
outputs = self.model.generate(prompts, sampling_params)
outputs = [it.outputs[0].text for it in outputs]
        # Postprocess output: keep the original sentence unless the model returned a
        # well-formed [Response_Start]...[Response_End] span with the attributed rewrite.
        for output, sentence_key in zip(outputs, list(post_hoc_sentence.keys())):
            processed = [t.split("[Response_End]")[0] for t in output.split("[Response_Start]") if "[Response_End]" in t]
            if len(processed) > 0:
                post_hoc_sentence[sentence_key] = processed[0]
        final_processed_outputs = []
        for sent in updated_sentences:  # fresh loop variable; avoid shadowing the `item` argument
            if sent in post_hoc_sentence:
                final_processed_outputs.append(post_hoc_sentence[sent])
            else:
                final_processed_outputs.append(sent)
        updated_sentences = final_processed_outputs
return "\n".join(updated_sentences)
def insert_attributions_posthoc(self, item, llama3_chat=False):
text = item["output"]
passages = item["final_passages"]
sentences = sent_tokenize(text)
# post process sentences
updated_sentences = []
post_hoc_sentence = {}
        for s_index, statement in enumerate(sentences):
            if len(statement) < 10:
                # Merge a bare citation fragment such as "[3]" into the previous sentence.
                if len(updated_sentences) > 0 and len(statement) > 0 and statement[0] == "[":
                    updated_sentences[-1] = updated_sentences[-1] + " " + statement
                else:
                    updated_sentences.append(statement)
            else:
                # Keep sentences that already carry citations (or are followed by one);
                # queue the rest for post-hoc attribution.
                if "[" in statement or (s_index < len(sentences) - 1 and len(sentences[s_index + 1]) > 0 and sentences[s_index + 1][0] == "["):
                    updated_sentences.append(statement)
                else:
                    updated_sentences.append("[replace_{}]".format(s_index))
                    post_hoc_sentence["[replace_{}]".format(s_index)] = statement
if len(post_hoc_sentence) > 0:
print("{0} sentences require attributions, e..g, {1}".format(len(post_hoc_sentence), list(post_hoc_sentence.values())[0] ))
prompts = []
for s in list(post_hoc_sentence.values()):
input_query = instructions.posthoc_attributions.format_map({"statement": s, "passages": passages})
if llama3_chat is True:
input_query = create_prompt_with_llama3_format(input_query)
prompts.append(input_query)
if self.client is not None:
outputs = []
for input_query in prompts:
result = self.client.chat.completions.create(
model=self.model_name,
messages=[
{"role": "user",
"content": input_query},
],
temperature=0.7,
max_tokens=2000,
)
raw_output = result.choices[0].message.content
outputs.append(raw_output)
else:
            sampling_params = vllm.SamplingParams(
                temperature=0.7,
                max_tokens=2000,
                stop_token_ids=[128009]
            )
outputs = self.model.generate(prompts, sampling_params)
outputs = [it.outputs[0].text for it in outputs]
        # Process outputs: keep the original sentence unless the model returned a
        # well-formed [Response_Start]...[Response_End] span.
        for output, sentence_key in zip(outputs, list(post_hoc_sentence.keys())):
            processed = [t.split("[Response_End]")[0] for t in output.split("[Response_Start]") if "[Response_End]" in t]
            if len(processed) > 0:
                post_hoc_sentence[sentence_key] = processed[0]
        final_processed_outputs = []
        for sent in updated_sentences:  # fresh loop variable; avoid shadowing the `item` argument
            if sent in post_hoc_sentence:
                final_processed_outputs.append(post_hoc_sentence[sent])
            else:
                final_processed_outputs.append(sent)
        updated_sentences = final_processed_outputs
return " ".join(updated_sentences)
def insert_attributions_posthoc_paragraph_all(self, item, llama3_chat=False):
text = item["output"]
if "final_passages" in item:
passages = item["final_passages"]
else:
ctxs = item["ctxs"]
passages = ""
for idx, p in enumerate(ctxs):
passages += "[{0}] {1}\n".format(idx, p)
        sentences = text.split("\n")
updated_sentences = []
post_hoc_sentence = {}
prompts = []
for s_index, statement in enumerate(sentences):
if len(statement) < 10:
if len(updated_sentences) > 0 and len(statement) > 0 and statement[0] == "[":
updated_sentences[-1] = updated_sentences[-1] + " " + statement
else:
updated_sentences.append(statement)
else:
updated_sentences.append("[replace_{}]".format(s_index))
post_hoc_sentence["[replace_{}]".format(s_index)] = statement
for s in list(post_hoc_sentence.values()):
input_query = instructions.posthoc_attributions_paragraph_all.format_map({"statement": s, "passages": passages})
if llama3_chat is True:
input_query = create_prompt_with_llama3_format(input_query)
prompts.append(input_query)
if self.client is not None:
outputs = []
cost = 0
for input_query in prompts:
result = self.client.chat.completions.create(
model=self.model_name,
messages=[
{"role": "user",
"content": input_query},
],
temperature=0.7,
max_tokens=1000,
)
raw_output = result.choices[0].message.content
outputs.append(raw_output)
                cost += calculate_openai_api_cost(len(input_query.split(" ")), len(raw_output.split(" ")), self.model_name)
else:
sampling_params = vllm.SamplingParams(
temperature=0.7,
max_tokens=1000,
stop_token_ids=[128009]
)
outputs = self.model.generate(prompts, sampling_params)
outputs = [it.outputs[0].text for it in outputs]
cost = 0
        # Process outputs: keep the original sentence unless the model returned a
        # well-formed [Response_Start]...[Response_End] span.
        for output, sentence_key in zip(outputs, list(post_hoc_sentence.keys())):
            processed = [t.split("[Response_End]")[0] for t in output.split("[Response_Start]") if "[Response_End]" in t]
            if len(processed) > 0:
                post_hoc_sentence[sentence_key] = processed[0]
        final_processed_outputs = []
        for sent in updated_sentences:  # fresh loop variable; avoid shadowing the `item` argument
            if sent in post_hoc_sentence:
                final_processed_outputs.append(post_hoc_sentence[sent])
            else:
                final_processed_outputs.append(sent)
        updated_sentences = final_processed_outputs
return "\n".join(updated_sentences), cost
def run(self, item, ranking_ce=False, use_feedback=False, skip_generation=False, posthoc_at=False, llama3_chat=False, task_name="default", zero_shot=False, max_per_paper=None, use_abstract=False, max_tokens=3000):
print("llama3 chat format? {0}".format(llama3_chat))
print("use feedback: {}".format(use_feedback))
total_cost = 0
        if ranking_ce is True:
            # Pass use_abstract through instead of hard-coding it to False.
            item["ctxs"], ranked_results, id_mapping = self.reranking_passages_cross_encoder(item, batch_size=1, llama3_chat=llama3_chat, task_name=task_name, use_abstract=use_abstract)
item["ranked_results"] = ranked_results
item["id_mapping"] = id_mapping
        if max_per_paper is not None:
            # Cap the number of passages kept from any single paper (keyed by title).
            filtered_ctxs = []
            title_to_count = {}
            for ctx in item["ctxs"]:
                if "title" not in ctx or ctx["title"] is None:
                    ctx["title"] = ""
                title_to_count.setdefault(ctx["title"], 0)
                if title_to_count[ctx["title"]] >= max_per_paper:
                    # We have already added this paper max_per_paper times.
                    continue
                filtered_ctxs.append(ctx)
                title_to_count[ctx["title"]] += 1
            item["ctxs"] = filtered_ctxs
        if skip_generation is False:
            generated_result, passages, gen_cost = self.generate_response(item, max_tokens=max_tokens, llama3_chat=llama3_chat, task_name=task_name, zero_shot=zero_shot)
            # Drop a trailing references section if the model appended one
            # (the original bare `if "\n\n References":` was always true).
            if "\n\n References" in generated_result:
                generated_result = generated_result.split("\n\n References")[0]
item["initial_result"] = generated_result
total_cost += gen_cost
if use_feedback is True:
print("generating feedback")
            # get_feedback returns (feedbacks, cost); the cap to three feedbacks is applied below.
            feedbacks, feedback_cost = self.get_feedback(item, llama3_chat=llama3_chat)
total_cost += feedback_cost
item["feedbacks"] = feedbacks
for feedback_idx, feedback in tqdm(enumerate(feedbacks[:3])):
# currently only supports non retrieval feedback
if len(feedback[1]) == 0:
edited_answer, edited_cost = self.edit_with_feedback(item, feedback[0], llama3_chat=llama3_chat)
if "Here is the revised answer:\n\n" in edited_answer:
edited_answer = edited_answer.split("Here is the revised answer:\n\n")[1]
total_cost += edited_cost
if len(item["output"]) > 0 and len(edited_answer) / len(item["output"]) > 0.9:
item["output"] = edited_answer
item["edited_answer_{}".format(feedback_idx)] = edited_answer
else:
print("skipping as edited answers got too short")
else:
                    new_papers = []
                    # new_papers = retrieve_pes2o_passages(feedback[1], 20, "pes2o")
                    print("newly retrieved papers: {}".format(len(new_papers)))
if self.ss_retriever is True:
new_keywords = self.retrieve_keywords(feedback[1])
paper_list = {}
                        if len(new_keywords) > 0:
                            for keyword in new_keywords:
                                top_papers = search_paper_via_query(keyword)
                                if top_papers is None:
                                    print("no papers found for keyword: {}".format(keyword))
                                else:
                                    for paper in top_papers:
                                        if paper["paperId"] not in paper_list:
                                            paper["text"] = paper["abstract"]
                                            paper["citation_counts"] = paper["citationCount"]
                                            paper_list[paper["paperId"]] = paper
new_papers += list(paper_list.values())
                    # Remove duplicated papers, keyed by the first 100 chars of text plus the title.
if len(new_papers) > 0:
print("before deduplication: {}".format(len(new_papers)))
new_papers_dicts = {paper["text"][:100] + paper["title"]: paper for paper in new_papers if paper is not None and type(paper["text"]) is str}
new_papers = list(new_papers_dicts.values())
print("after deduplication: {}".format(len(new_papers)))
                    # Add the new passages only when new papers were actually retrieved.
if len(new_papers) > 0:
new_passages_reranked, _ , _ = self.reranking_passages_cross_encoder_supplemental(item, new_papers, batch_size=10, llama3_chat=llama3_chat, task_name=task_name)
passages_start_index = len(item["ctxs"])
edited_answer, edited_cost = self.edit_with_feedback_retrieval(item, feedback[0], new_passages_reranked, passages_start_index)
total_cost += edited_cost
if len(item["output"]) > 0 and len(edited_answer) / len(item["output"]) > 0.9:
item["ctxs"] += new_passages_reranked[:self.top_n]
item["edited_answer_{}".format(feedback_idx)] = edited_answer
item["output"] = edited_answer
item["edited_answer_{}".format(feedback_idx)] = edited_answer
elif len(item["output"]) == 0 and len(edited_answer) > 0:
item["ctxs"] += new_passages_reranked[:self.top_n]
item["edited_answer_{}".format(feedback_idx)] = edited_answer
item["output"] = edited_answer
item["edited_answer_{}".format(feedback_idx)] = edited_answer
else:
print("skipping as edited answers got too short")
if posthoc_at is True:
# attributed_results = self.insert_attributions_posthoc(item, llama3_chat=llama3_chat)
# attributed_results = self.insert_attributions_posthoc_paragraph(item, llama3_chat=llama3_chat)
attributed_results, attributed_cost = self.insert_attributions_posthoc_paragraph_all(item, llama3_chat=llama3_chat)
total_cost += attributed_cost
item["output"] = attributed_results
item["output"] = item["output"].replace("[Response_Start]", "").replace("[Response_End]", "")
print(item["output"])
if "\n### References" in item["output"]:
item["output"] = item["output"].split("\n### References")[0]
return item, total_cost
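# An end-to-end usage sketch. The reranker and generator checkpoints are illustrative
# assumptions; only the class and function names come from this file:
#   reranker = FlagReranker("BAAI/bge-reranker-large", use_fp16=True)
#   model = vllm.LLM(model="path/to/generator-checkpoint")
#   scholar = OpenScholar(model, tokenizer=None, reranker=reranker, top_n=8)
#   item = process_input_data(raw_items, use_contexts=True)[0]
#   item, total_cost = scholar.run(item, ranking_ce=True, use_feedback=True, posthoc_at=True)
#   print(item["output"], total_cost)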
def process_paragraph(text):
    # Strip LaTeX citation placeholders ("<cit.>") and numeric citation markers.
    text = text.replace("<cit.>", "")
    text = remove_citations(text)
    return text
def process_input_data(data, use_contexts=True):
    """Normalize raw items: ensure "input" and "answer" fields exist, and clean,
    normalize, and deduplicate the retrieved contexts in "ctxs"."""
    processed_data = []
    for item in data:
        if "answer" not in item:
            item["answer"] = ""
        if "input" not in item:
            # Accept "question" or "query" as aliases for "input".
            if "question" in item:
                item["input"] = item["question"]
            if "query" in item:
                item["input"] = item["query"]
        new_ctxs = []
        if use_contexts is True:
            # Normalize the ctx format across retrieval APIs: flatten nested lists of dicts.
            for ctx in item["ctxs"]:
                if isinstance(ctx, list):
                    for c in ctx:
                        if isinstance(c, dict):
                            new_ctxs.append(c)
                if isinstance(ctx, dict):
                    new_ctxs.append(ctx)
            item["ctxs"] = new_ctxs
            # Clean each context, then remove duplicates.
            processed_paras = []
            for ctx in tqdm(item["ctxs"]):
                if "retrieval text" in ctx:
                    ctx["text"] = ctx["retrieval text"]
                if ctx["text"] is None or len(ctx["text"]) == 0:
                    continue
                if not isinstance(ctx["text"], str):
                    # Some retrieval APIs return {"contexts": [...]} instead of a plain string.
                    ctx["text"] = " ".join(ctx["text"]["contexts"])
                ctx["text"] = process_paragraph(ctx["text"])
                if "title" not in ctx:
                    ctx["title"] = ""
                processed_paras.append(ctx)
            # Deduplicate, keyed by the first 100 characters of text plus the title.
            processed_paras_dicts = {paper["text"][:100] + paper["title"]: paper for paper in processed_paras}
            processed_paras = list(processed_paras_dicts.values())
item["ctxs"] = processed_paras
item["original_ctxs"] = processed_paras
processed_data.append(item)
return processed_data
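# A minimal sketch of the expected input schema (field names from this file; values hypothetical):
#   raw_items = [{
#       "input": "How do retrieval-augmented LMs attribute claims to sources?",
#       "ctxs": [{"title": "Some Paper", "text": "Passage text ... <cit.>", "citation_counts": 12}],
#   }]
#   processed = process_input_data(raw_items, use_contexts=True)
#   # Each item now carries cleaned, deduplicated "ctxs", an "original_ctxs" reference,
#   # and an "answer" field.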