llm_utils.py
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
import torch


def get_score_from_generation(generation, get_dist=False):
    """Return the scores for the first generated token, optionally as a softmax distribution over the vocabulary."""
    scores = generation["scores"][0]
    if get_dist:
        return torch.softmax(scores, dim=-1)
    else:
        return scores


def generate_next_token(
    input: list, model: AutoModelForCausalLM, tokenizer: AutoTokenizer, answer_prepend=[], past_key_values=None
):
    """Greedily generate a single next token for chat-formatted `input`, optionally prepending
    `answer_prepend` token ids to the prompt and reusing `past_key_values` as a cache."""
    device = model.device
    tokenizer.pad_token = tokenizer.eos_token
    # Render the chat messages into a prompt string, then tokenize it.
    model_inputs = tokenizer.apply_chat_template(input, add_generation_prompt=True, tokenize=False)
    model_inputs = tokenizer(model_inputs, padding=True, add_special_tokens=False)
    # Append any forced answer-prefix tokens to the end of the prompt.
    for token in answer_prepend:
        model_inputs["input_ids"].append(token)
        model_inputs["attention_mask"].append(1)
    model_inputs = model_inputs.convert_to_tensors(tensor_type="pt", prepend_batch_axis=True)
    model_inputs = model_inputs.to(device)
    # Greedy decoding of exactly one new token, returning scores and logits alongside the sequence.
    generation_config = GenerationConfig(
        do_sample=False,
        max_new_tokens=1,
        pad_token_id=tokenizer.eos_token_id,
        output_scores=True,
        output_logits=True,
        return_dict_in_generate=True,
        use_cache=True,
    )
    generation = model.generate(**model_inputs, generation_config=generation_config, past_key_values=past_key_values)
    return generation
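

# --- Usage sketch (not part of the original file): a minimal, hedged example of calling
# generate_next_token and get_score_from_generation together. The model name below is an
# assumption chosen for illustration; substitute any chat-capable causal LM checkpoint.
if __name__ == "__main__":
    model_name = "Qwen/Qwen2.5-0.5B-Instruct"  # assumed checkpoint, for illustration only
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)

    messages = [{"role": "user", "content": "Answer with a single letter: A or B?"}]
    generation = generate_next_token(messages, model, tokenizer)

    # Probability distribution over the vocabulary for the single generated token.
    dist = get_score_from_generation(generation, get_dist=True)
    next_token_id = torch.argmax(dist, dim=-1)
    print(tokenizer.decode(next_token_id))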