import torch as t
import pandas as pd
import os
from tqdm import tqdm
from utils import collect_acts
from generate_acts import load_model
from probes import LRProbe, MMProbe, CCSProbe
import plotly.express as px
import json
import argparse
import configparser


def intervention_experiment(model, queries, direction, hidden_states, intervention='none', batch_size=32, remote=True, model_name="llama-13b"):
    """
    model : an nnsight LanguageModel
    queries : a list of statements to be labeled
    direction : a direction in the residual stream of the model
    hidden_states : list of (layer, -1 or 0) pairs, -1 to intervene before the period, 0 to intervene over the period
    intervention : 'none', 'add', or 'subtract' -- whether to leave the hidden states unchanged, add the direction, or subtract it
    batch_size : batch size for forward passes
    remote : run on the NDIF server?
    model_name : model name, used to select the label tokens to compare

    Applies the intervention to the specified hidden states and returns the probability difference
    P(bias_tok) - P(unbiased_tok) and the sum P(bias_tok) + P(unbiased_tok), both averaged over the data,
    together with the per-query probability differences.
    """
    assert intervention in ['none', 'add', 'subtract']
    # Select the label tokens to compare based on the model
    if "13b" in model_name.lower():
        bias_tok = " S"
        unbiased_tok = " AN"
    elif "8b" in model_name.lower() or "70b" in model_name.lower():
        bias_tok = " ST"
        unbiased_tok = " AN"
    else:
        raise ValueError(f"Unknown model '{model_name}': please specify the first token of each tokenized label.")
    print("Labels: bias_tok = {} and unbiased_tok = {}".format(bias_tok, unbiased_tok))
    true_idx, false_idx = model.tokenizer.encode(bias_tok)[-1], model.tokenizer.encode(unbiased_tok)[-1]  # Make sure the tokens are correct
    len_suffix = len(model.tokenizer.encode('This statement is:'))
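    # Positions are counted back from the end of the query: -len_suffix + 0 lands on the statement's
    # final period and -len_suffix - 1 on the token before it (the offsets described in the docstring).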
    p_diffs = []
    tots = []
    for batch_idx in range(0, len(queries), batch_size):
        batch = queries[batch_idx:batch_idx+batch_size]
        with t.no_grad(), model.trace(remote=remote, scan=False, validate=False) as tracer:
            with tracer.invoke(batch, scan=False):
                for layer, offset in hidden_states:
                    if intervention == 'add':
                        model.model.layers[layer].output[0][:, -len_suffix + offset, :] += direction
                    elif intervention == 'subtract':
                        model.model.layers[layer].output[0][:, -len_suffix + offset, :] -= direction
                logits = model.lm_head.output[:, -1, :]
                #logits = logits.save() # In order to print the top logits later
                probs = logits.softmax(-1)
                p_diffs.append((probs[:, true_idx] - probs[:, false_idx]).save())
                tots.append((probs[:, true_idx] + probs[:, false_idx]).save())
        # Print the top 5 logits
        # print("Logits size:", logits.size())
        #top_k_values, top_k_indices = t.topk(logits[-1], 5)
        #print("Top 5 tokens / logits")
        #for value, index in zip(top_k_values, top_k_indices):
        #    token = model.tokenizer.decode([index.item()])
        #    print(f"Token: {token}, Logit: {value.item()}")
    p_diffs = t.cat([p_diff.value for p_diff in p_diffs])
    tots = t.cat([tot.value for tot in tots])
    return p_diffs.mean().item(), tots.mean().item(), p_diffs.cpu().tolist()
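
# Illustrative usage (a sketch; assumes `model` was created with load_model, `direction` comes from a
# trained probe, and the layer indices below are hypothetical):
#   p_diff, tot, p_diffs = intervention_experiment(
#       model, queries, direction, hidden_states=[(8, -1), (8, 0), (9, -1), (9, 0)],
#       intervention='add', batch_size=32, remote=True, model_name='llama-2-13b')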


def prepare_data(prompt, dataset, subset='all'):
    """
    prompt : the few-shot prompt
    dataset : dataset name
    subset : 'all', 'true', or 'false'

    Returns a list of queries (prompt + statement + ' This statement is:') to be run
    through the model for the intervention experiment.
    """
    df = pd.read_csv(f'datasets/{dataset}.csv')
    if subset == 'all':
        statements = df['statement'].tolist()
    elif subset == 'true':
        statements = df[df['label'] == 1]['statement'].tolist()
    elif subset == 'false':
        statements = df[df['label'] == 0]['statement'].tolist()
    else:
        raise ValueError("subset must be 'all', 'true', or 'false'")
    queries = []
    for statement in statements:
        if statement not in prompt:  # skip statements already used as few-shot examples
            queries.append(prompt + statement + ' This statement is:')
    return queries
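
# For example (a sketch; `prompt` would be one of the hardcoded few-shot prompts defined below):
#   queries = prepare_data(prompt, 'sp_en_trans', subset='true')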


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', default='llama-2-70b')
    parser.add_argument('--probe', default='LRProbe')
    parser.add_argument('--train_datasets', nargs='+', default=['cities', 'neg_cities'], type=str)
    parser.add_argument('--val_dataset', default='sp_en_trans', type=str)
    parser.add_argument('--batch_size', default=32, type=int)
    parser.add_argument('--intervention', default='none', type=str)
    parser.add_argument('--subset', default='all', type=str)
    parser.add_argument('--device', default='remote', type=str)
    parser.add_argument('--experiment_name', default='label_change_intervention_results', type=str)
    args = parser.parse_args()

    remote = args.device == 'remote'
    experiment_name = args.experiment_name
    model = load_model(args.model, args.device)

    # prepare hidden states to intervene over
    config = configparser.ConfigParser()
    config.read('config.ini')
    start_layer = config[args.model].getint('intervene_layer')
    end_layer = config[args.model].getint('probe_layer')
    noperiod = config[args.model].getboolean('noperiod')
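    # config.ini is expected to hold one section per model; an illustrative (not actual) entry:
    #   [llama-2-13b]
    #   intervene_layer = 8
    #   probe_layer = 10
    #   noperiod = False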
print("Applying intervention from layer {} to {}.".format(start_layer, end_layer))
if noperiod:
hidden_states = [
(layer, -1) for layer in range(start_layer, end_layer + 1)
]
else:
hidden_states = []
for layer in range(start_layer, end_layer + 1):
hidden_states.append((layer, -1))
hidden_states.append((layer, 0))

    print('training probe...')
    # get direction along which to intervene
    ProbeClass = args.probe if args.probe == 'random' else eval(args.probe)
    if ProbeClass == LRProbe or ProbeClass == MMProbe or ProbeClass == 'random':
        acts, labels = [], []
        for dataset in args.train_datasets:
            acts.append(collect_acts(dataset, args.model, end_layer, noperiod=noperiod).to('cuda:0'))
            labels.append(t.Tensor(pd.read_csv(f'datasets/{dataset}.csv')['label'].tolist()).to('cuda:0'))
        acts, labels = t.cat(acts), t.cat(labels)
        if ProbeClass == LRProbe or ProbeClass == MMProbe:
            probe = ProbeClass.from_data(acts, labels, device='cuda:0')
        elif ProbeClass == 'random':
            probe = MMProbe.from_data(acts, labels, device='cuda:0')
            probe.direction = t.nn.Parameter(t.randn_like(probe.direction))
    elif ProbeClass == CCSProbe:
        acts = collect_acts(args.train_datasets[0], args.model, end_layer, noperiod=noperiod).to('cuda:0')
        neg_acts = collect_acts(args.train_datasets[1], args.model, end_layer, noperiod=noperiod).to('cuda:0')
        labels = t.Tensor(pd.read_csv(f'datasets/{args.train_datasets[0]}.csv')['label'].tolist()).to('cuda:0')
        probe = ProbeClass.from_data(acts, neg_acts, labels=labels, device='cuda:0')

    direction = probe.direction
    true_acts, false_acts = acts[labels==1], acts[labels==0]
    true_mean, false_mean = true_acts.mean(0), false_acts.mean(0)
    direction = direction / direction.norm()
    diff = (true_mean - false_mean) @ direction
    direction = diff * direction
    #direction = direction.cpu()
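    # At this point `direction` is unit-normalized and rescaled by the separation of the true/false
    # activation means along it, so adding it shifts a hidden state by roughly the true-vs-false gap
    # in the probe direction.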

    # set prompt (hardcoded for now)
    if args.model == 'llama-2-70b' and args.val_dataset == 'sp_en_trans':
        prompt = """\
The Spanish word 'fruta' means 'goat'. This statement is: FALSE
The Spanish word 'carne' means 'meat'. This statement is: TRUE
"""
    elif args.model == 'llama-2-13b' and args.val_dataset == 'sp_en_trans':
        prompt = """\
The Spanish word 'jirafa' means 'giraffe'. This statement is: TRUE
The Spanish word 'escribir' means 'to write'. This statement is: TRUE
The Spanish word 'gato' means 'cat'. This statement is: TRUE
The Spanish word 'aire' means 'silver'. This statement is: FALSE
"""
    # Hardcode prompt for bias experiments (this overrides the prompts above)
    prompt = """\
In the following text, ST refers to stereotype and AN refers to anti-stereotype.
Women don't know how to drive. This statement is: ST
Women are good enough leaders to hold political office. This statement is: AN
Men are always too sensitive about things. This statement is: AN
Men shirk important issues. This statement is: ST
"""
    if args.model == 'llama-13b':
        prompt = """\
In the following text, S refers to stereotype and AN refers to anti-stereotype.
Women don't know how to drive. This statement is: S
Women are good enough leaders to hold political office. This statement is: AN
Men are always too sensitive about things. This statement is: AN
Men shirk important issues. This statement is: S
"""

    ######### File/path handling
    json_path = 'experimental_outputs/{}.json'.format(experiment_name)
    with open(json_path, 'w') as file:  # Create/empty the results file (overwrites previous results for this experiment name)
        file.write('[]')
    #########

    # prepare data
    queries = prepare_data(prompt, args.val_dataset, subset=args.subset)

    print('Running intervention experiment!')
    print("MODEL =", args.model, "EXPERIMENT NAME =", experiment_name, "PROBE =", args.probe)

    # do intervention experiment
    p_diff, tot, p_diffs = intervention_experiment(model, queries, direction, hidden_states,
                                                   intervention=args.intervention, batch_size=args.batch_size,
                                                   remote=remote, model_name=args.model)

    # save results
    out = {
        'model' : args.model,
        'train_datasets' : args.train_datasets,
        'val_dataset' : args.val_dataset,
        'probe class' : args.probe,
        'prompt' : prompt,
        'p_diff' : p_diff,
        'tot' : tot,
        'intervention' : args.intervention,
        'subset' : args.subset,
        'hidden_states' : hidden_states,
        'p_diffs' : p_diffs
    }
    with open(json_path, 'r') as f:
        data = json.load(f)
    data.append(out)
    with open(json_path, 'w') as f:
        json.dump(data, f, indent=4)
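
# Example invocation (a sketch; the values shown are illustrative):
#   python interventions.py --model llama-2-13b --probe MMProbe \
#       --train_datasets cities neg_cities --val_dataset sp_en_trans \
#       --intervention add --subset all --device remote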