-
Notifications
You must be signed in to change notification settings - Fork 1
/
sensitivity_v_accuracy.py
186 lines (144 loc) · 6.65 KB
/
sensitivity_v_accuracy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import random
import string
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import pandas as pd
from tqdm import tqdm
import json
import random
import torch.nn.functional as F
import os
from semantic_uncertainty.calc_entropy import get_entropy_from_probabilities
from question_loader import *
import pickle
from utils import *
##### SETTINGS #####
cache_dir = '/tmp'
model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
# Answer letters the model is allowed to emit; every other token is masked out.
possible_outputs = ["A", "B", "C", "D", "E", "F", "G", "H"]
batch_size = 8
# When True, discard any cached copy of the model and fetch it again.
redownload = False
data_outpath = './data/edit_distance_v_accuracy_K_10'
######################

if redownload:
    import shutil

    model_cache_path = os.path.join(cache_dir, model_name)
    if os.path.exists(model_cache_path):
        # os.rmdir() only removes *empty* directories and raises OSError on a
        # populated model cache, so remove the whole tree instead.
        shutil.rmtree(model_cache_path)
    # NOTE(review): the Hugging Face cache normally stores models under a
    # "models--<org>--<name>" directory rather than cache_dir/model_name, so the
    # path above may never exist. force_download below is what guarantees a
    # fresh download.

tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    cache_dir=cache_dir,
    force_download=redownload,
)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    cache_dir=cache_dir,
    torch_dtype=torch.float16,
    device_map="auto",
    force_download=redownload,
)
# The tokenizer is given no dedicated pad token here; reuse EOS so that
# padding=True works when tokenizing a batch.
tokenizer.pad_token = tokenizer.eos_token
def torch_to_numpy(torch_tensor):
    """Return the values of ``torch_tensor`` as a NumPy array.

    The tensor is first detached from the autograd graph and moved to the
    CPU, which are the two preconditions for calling ``.numpy()``.
    """
    host_tensor = torch_tensor.detach().cpu()
    return host_tensor.numpy()
def get_next_token(prompt_batch, top_k=len(possible_outputs)):
    """Return the most likely next answer tokens for each prompt in a batch.

    The next-token distribution is restricted to the tokens listed in
    ``possible_outputs``: every other vocabulary entry gets a logit of -inf
    before the softmax, so the returned probabilities are renormalised over
    the allowed tokens only.

    Args:
        prompt_batch: list of prompt strings.
        top_k: number of candidates to return per prompt (defaults to the
            number of allowed output tokens).

    Returns:
        A tuple ``(top_k_responses, top_k_probs)`` where ``top_k_responses``
        is a list (one entry per prompt) of token strings ordered from most
        to least probable, and ``top_k_probs`` is a NumPy array of shape
        ``(len(prompt_batch), top_k)`` with the matching probabilities.
    """
    inputs = tokenizer(prompt_batch, padding=True, return_tensors="pt").to(model.device)
    allowed_tokens = tokenizer.convert_tokens_to_ids(possible_outputs)

    with torch.no_grad():
        logits = model(**inputs).logits

    # Pick the logits at the last *real* token of every prompt. Indexing with
    # -1 is only correct when no padding follows the prompt; if the tokenizer
    # pads on the right, position -1 of a shorter prompt is a pad token.
    # Locating the last attended position works for either padding side.
    attention_mask = inputs.attention_mask
    seq_len = attention_mask.shape[1]
    # argmax on the reversed mask gives the distance of the last 1 from the end.
    last_positions = seq_len - 1 - attention_mask.flip(dims=[1]).argmax(dim=1)
    batch_indices = torch.arange(logits.shape[0], device=logits.device)
    last_logits = logits[batch_indices, last_positions.to(logits.device)]

    # Additive mask: 0 for allowed tokens, -inf everywhere else. It is created
    # in float32, so the sum (and therefore the softmax) is promoted to float32
    # even when the model runs in half precision.
    logits_bias = torch.full(last_logits.shape, -float('inf'), device=last_logits.device)
    logits_bias[:, allowed_tokens] = 0
    next_token_logits = last_logits + logits_bias

    probs = F.softmax(next_token_logits, dim=-1)
    top_k_probs, top_k_indices = torch.topk(probs, k=top_k, dim=-1)

    top_k_responses = [
        tokenizer.convert_ids_to_tokens(row_indices.tolist())
        for row_indices in top_k_indices
    ]
    return top_k_responses, torch_to_numpy(top_k_probs)
def perturb_input(input_text, edit_distance):
    """
    Randomly perturb the input text with ``edit_distance`` single-character edits.

    Each edit is a random insertion, deletion, or substitution at a random
    position, using ASCII letters for inserted/substituted characters. Every
    iteration applies a real edit: on empty text only insertion is possible,
    and a substitution never re-selects the character already present. The
    Levenshtein distance between input and output is therefore at most
    ``edit_distance`` (later edits may undo earlier ones).

    Args:
        input_text: the string to perturb.
        edit_distance: number of edits to apply; values <= 0 return the
            text unchanged.

    Returns:
        The perturbed string.
    """
    chars = list(input_text)
    for _ in range(edit_distance):
        # Delete/substitute need at least one character to act on; without
        # this fallback the edit would be silently skipped on empty text.
        if chars:
            operation = random.choice(['insert', 'delete', 'substitute'])
        else:
            operation = 'insert'

        if operation == 'insert':
            pos = random.randint(0, len(chars))
            chars.insert(pos, random.choice(string.ascii_letters))
        elif operation == 'delete':
            chars.pop(random.randrange(len(chars)))
        else:
            pos = random.randrange(len(chars))
            # Exclude the current character so the substitution is never a no-op.
            candidates = [c for c in string.ascii_letters if c != chars[pos]]
            chars[pos] = random.choice(candidates)
    return ''.join(chars)
def calculate_sensitivity(original_response, perturbed_responses):
    """
    Calculate the fraction of perturbed responses that differ from the original.

    Args:
        original_response: the response obtained on the unperturbed input.
        perturbed_responses: iterable of responses obtained on perturbed inputs.

    Returns:
        A float in [0, 1]: the share of ``perturbed_responses`` that are not
        equal to ``original_response``. Returns 0.0 when there are no
        perturbed responses (nothing was observed to change) instead of
        raising ZeroDivisionError.
    """
    # Materialise once so generators are supported and len() is available.
    responses = list(perturbed_responses)
    if not responses:
        return 0.0
    different_responses = sum(1 for resp in responses if resp != original_response)
    return different_responses / len(responses)
# Main execution: repeatedly sample a random question, record the model's
# answer and entropy on the original prompt, then measure how often the answer
# stays the same / stays correct when the prompt is perturbed.
tot_questions = get_data_len()
K = [5, 10, 15, 20, 25]  # Edit distances used for perturbation
n_samples = 5000
num_perturbations = 500  # Number of perturbations per input

# Resume from any previously saved results so reruns extend the same file.
res = load_data(data_outpath)
print("Loaded data from", data_outpath, "current rows:", len(res))

correct_count = 0
with tqdm(total=n_samples) as pbar:
    for sample_idx in range(n_samples):
        row = random.randrange(tot_questions)
        cur_prompt = get_row_query(row)

        # Get original response (batch of one -> index 0 selects the prompt,
        # the second index 0 selects the most probable token).
        original_response, original_probs = get_next_token([cur_prompt])
        original_ans = original_response[0][0]

        # Check accuracy
        cor_ans = get_correct_answer(row)
        is_correct = original_ans == cor_ans
        if is_correct:
            correct_count += 1

        # Calculate entropy
        entropy = get_entropy_from_probabilities(original_probs[0])

        # Update progress bar
        pbar.set_postfix({'Correct %': f'{(correct_count / (sample_idx + 1)) * 100:.2f}%'})
        pbar.update(1)

        mp = {
            "row": row,
            "original_entropy": entropy,
            "is_correct": is_correct,
            "model_prob": original_probs[0].tolist(),
            "model_response": original_response[0]
        }

        for k in K:
            perturbed_prompts = [get_peturbed_row_query(row, k) for _ in range(num_perturbations)]

            # Prompts are evaluated one at a time: sending all perturbed
            # prompts through the model as a single batch ran out of memory.
            res_peturbs = [get_next_token([prompt]) for prompt in perturbed_prompts]

            tot_correct = 0
            tot_same = 0
            peturbed_entropies = []
            for perturbed_response, perturbed_prob in res_peturbs:
                cur_ans = perturbed_response[0][0]
                if cur_ans == original_ans:
                    tot_same += 1
                if cur_ans == cor_ans:
                    tot_correct += 1
                # Each result is a batch of one; pass the 1-D probability row,
                # exactly as done for the original prompt, rather than the
                # (1, top_k) batch array.
                peturbed_entropies.append(get_entropy_from_probabilities(perturbed_prob[0]))

            sensitivity_correct = tot_correct / num_perturbations
            sensitivity_same = tot_same / num_perturbations
            avg_entropy = sum(peturbed_entropies) / num_perturbations

            mp[f"peturbed_entropy_{k}"] = avg_entropy
            mp[f"same_sensitivity_{k}"] = sensitivity_same
            mp[f"correct_sensitivity_{k}"] = sensitivity_correct

        print(mp)
        res.append(mp)

        # Periodic checkpoint so a crash loses at most 50 samples of work.
        if sample_idx % 50 == 0:
            print("Iteration:", sample_idx, "saved to", data_outpath)
            # Save results
            dump_data(res, data_outpath)

print(res)
dump_data(res, data_outpath)