# IntegratedGradientsTransformer.py
from SentenceTransformerWrapper import SentenceTransformerWrapper
from torch import nn
import torch
from sentence_transformers.util import dot_score, cos_sim
import sys

# gpl_improved lives outside this repository; extend the path before importing it.
sys.path.append("../master_thesis_ai")
from gpl_improved.utils import load_sbert

from captum.attr import IntegratedGradients, LayerIntegratedGradients
from beir.datasets.data_loader import GenericDataLoader
from utils import beir_path, concat_title_and_body

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class IRWrapperQuery(nn.Module):
    """Scores a query against a fixed document, so attributions flow to the query tokens."""

    def __init__(self, query_model, doc_model, pooler, doc_input_ids, doc_attention_mask):
        super(IRWrapperQuery, self).__init__()
        self.pooler = pooler
        # Both the query model and the doc model are needed for Integrated Gradients to work.
        self.query_model = query_model
        self.doc_model = doc_model
        # The document is fixed up front; we then inspect how the query interacts with it.
        self.d = doc_input_ids
        self.d_att = doc_attention_mask

    def forward_with_features(self, features, model):
        trans_features = {"input_ids": features["input_ids"], "attention_mask": features["attention_mask"]}
        if "token_type_ids" in features:
            trans_features["token_type_ids"] = features["token_type_ids"]
        output_states = model(**trans_features, return_dict=False)
        output_tokens = output_states[0]
        trans_features.update({"token_embeddings": output_tokens, "attention_mask": features["attention_mask"]})
        if model.config.output_hidden_states:
            all_layer_idx = 2
            if len(output_states) < 3:  # Some models only output last_hidden_states and all_hidden_states
                all_layer_idx = 1
            hidden_states = output_states[all_layer_idx]
            trans_features.update({"all_layer_embeddings": hidden_states})
        return self.pooler(trans_features)['sentence_embedding']

    def forward(self, query_input_ids, query_attention_mask):
        features_doc = {'input_ids': self.d, 'attention_mask': self.d_att}
        features_query = {'input_ids': query_input_ids, 'attention_mask': query_attention_mask}
        # Embed the query and the document with their respective encoders.
        q_emb = self.forward_with_features(features_query, self.query_model)
        doc_emb = self.forward_with_features(features_doc, self.doc_model)
        score = dot_score(doc_emb, q_emb)
        return score.diagonal()


class IRWrapperDoc(nn.Module):
    """Scores a fixed query against a document, so attributions flow to the document tokens."""

    def __init__(self, query_model, doc_model, pooler, query_input_ids, query_attention_mask):
        super(IRWrapperDoc, self).__init__()
        self.pooler = pooler
        self.query_model = query_model
        self.doc_model = doc_model
        # The query is fixed up front; we then inspect how the document interacts with it.
        self.q = query_input_ids
        self.q_att = query_attention_mask

    def forward_with_features(self, features, model):
        trans_features = {"input_ids": features["input_ids"], "attention_mask": features["attention_mask"]}
        if "token_type_ids" in features:
            trans_features["token_type_ids"] = features["token_type_ids"]
        output_states = model(**trans_features, return_dict=False)
        output_tokens = output_states[0]
        trans_features.update({"token_embeddings": output_tokens, "attention_mask": features["attention_mask"]})
        if model.config.output_hidden_states:
            all_layer_idx = 2
            if len(output_states) < 3:  # Some models only output last_hidden_states and all_hidden_states
                all_layer_idx = 1
            hidden_states = output_states[all_layer_idx]
            trans_features.update({"all_layer_embeddings": hidden_states})
        return self.pooler(trans_features)['sentence_embedding']

    def forward(self, doc_input_ids, doc_attention_mask):
        # sourcery skip: inline-immediately-returned-variable
        features_query = {'input_ids': self.q, 'attention_mask': self.q_att}
        features_doc = {'input_ids': doc_input_ids, 'attention_mask': doc_attention_mask}
        q_emb = self.forward_with_features(features_query, self.query_model)
        doc_emb = self.forward_with_features(features_doc, self.doc_model)
        score = dot_score(q_emb, doc_emb)
        return score.diagonal()


class IRWrapper(nn.Module):
    """Scores query-document pairs with neither side fixed; both inputs can receive attributions."""

    def __init__(self, query_model, doc_model, pooler):
        super(IRWrapper, self).__init__()
        self.pooler = pooler
        self.query_model = query_model
        self.doc_model = doc_model

    def forward_with_features(self, features, model):
        trans_features = {"input_ids": features["input_ids"], "attention_mask": features["attention_mask"]}
        if "token_type_ids" in features:
            trans_features["token_type_ids"] = features["token_type_ids"]
        output_states = model(**trans_features, return_dict=False)
        output_tokens = output_states[0]
        trans_features.update({"token_embeddings": output_tokens, "attention_mask": features["attention_mask"]})
        if model.config.output_hidden_states:
            all_layer_idx = 2
            if len(output_states) < 3:  # Some models only output last_hidden_states and all_hidden_states
                all_layer_idx = 1
            hidden_states = output_states[all_layer_idx]
            trans_features.update({"all_layer_embeddings": hidden_states})
        return self.pooler(trans_features)['sentence_embedding']

    def forward(self, query_input_ids, doc_input_ids, query_attention_mask, doc_attention_mask):
        # sourcery skip: inline-immediately-returned-variable
        features_query = {'input_ids': query_input_ids, 'attention_mask': query_attention_mask}
        features_doc = {'input_ids': doc_input_ids, 'attention_mask': doc_attention_mask}
        q_emb = self.forward_with_features(features_query, self.query_model)
        doc_emb = self.forward_with_features(features_doc, self.doc_model)
        # Unlike the wrappers above, this returns the full query-by-document score matrix.
        score = dot_score(q_emb, doc_emb)
        return score
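

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file): driving IRWrapperQuery
# with Captum's LayerIntegratedGradients to attribute the query-document score
# to the query tokens. The checkpoint name, example texts, and the PAD-token
# baseline below are illustrative assumptions, not fixed choices of this repo.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import copy
    from sentence_transformers import SentenceTransformer

    # Assumed bi-encoder with the usual Transformer + Pooling module layout.
    sbert = SentenceTransformer("sentence-transformers/msmarco-distilbert-base-tas-b")
    tokenizer = sbert.tokenizer
    query_encoder = sbert[0].auto_model.to(device).eval()
    # Use a separate copy for the document side so the Captum hook on the
    # query encoder's embeddings does not interfere with the document pass.
    doc_encoder = copy.deepcopy(query_encoder)
    pooler = sbert[1]

    query = "what causes tides"
    document = "Tides are caused by the gravitational pull of the moon and the sun."
    q_tok = tokenizer(query, return_tensors="pt").to(device)
    d_tok = tokenizer(document, return_tensors="pt").to(device)

    wrapper = IRWrapperQuery(
        query_model=query_encoder,
        doc_model=doc_encoder,
        pooler=pooler,
        doc_input_ids=d_tok["input_ids"],
        doc_attention_mask=d_tok["attention_mask"],
    )

    # Attribute the dot-product score to the query encoder's embedding layer.
    lig = LayerIntegratedGradients(wrapper, wrapper.query_model.embeddings)

    # Baseline query: keep the special tokens, replace the rest with [PAD].
    baseline_ids = torch.full_like(q_tok["input_ids"], tokenizer.pad_token_id)
    baseline_ids[:, 0] = q_tok["input_ids"][:, 0]
    baseline_ids[:, -1] = q_tok["input_ids"][:, -1]

    # The stored document has batch size 1, so run one integration step per
    # forward pass to keep the query and document batch sizes aligned.
    attributions = lig.attribute(
        inputs=q_tok["input_ids"],
        baselines=baseline_ids,
        additional_forward_args=(q_tok["attention_mask"],),
        n_steps=50,
        internal_batch_size=1,
    )

    # Collapse the embedding dimension to get one relevance score per token.
    token_scores = attributions.sum(dim=-1).squeeze(0)
    tokens = tokenizer.convert_ids_to_tokens(q_tok["input_ids"].squeeze(0).tolist())
    for tok, score in zip(tokens, token_scores.tolist()):
        print(f"{tok:>12s}  {score:+.4f}")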