model.py
import os
from pathlib import Path

import torch
import torch.nn as nn
from transformers import GPT2LMHeadModel, GPTNeoForCausalLM


class GPTPromptTuningMixin:
    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        soft_prompt_path: str = None,
        n_tokens: int = None,
        initialize_from_vocab: bool = True,
        random_range: float = 0.5,
        **kwargs,
    ):
        model = super().from_pretrained(pretrained_model_name_or_path, **kwargs)

        # Make sure to freeze the Transformers model; only the soft prompt is trained.
        for param in model.parameters():
            param.requires_grad = False

        if soft_prompt_path is not None:
            model.set_soft_prompt_embeds(soft_prompt_path)
        elif n_tokens is not None:
            print("Initializing soft prompt...")
            model.initialize_soft_prompt(
                n_tokens=n_tokens,
                initialize_from_vocab=initialize_from_vocab,
                random_range=random_range,
            )

        return model

    def set_soft_prompt_embeds(
        self,
        soft_prompt_path: str,
    ) -> None:
        """
        Args:
            soft_prompt_path: torch soft prompt file path
        """
        self.soft_prompt = torch.load(
            soft_prompt_path, map_location=torch.device("cpu")
        )
        self.n_tokens = self.soft_prompt.num_embeddings
        print(f"Set soft prompt! (n_tokens: {self.n_tokens})")

    def initialize_soft_prompt(
        self,
        n_tokens: int = 20,
        initialize_from_vocab: bool = True,
        random_range: float = 0.5,
    ) -> None:
        self.n_tokens = n_tokens
        # Take the embedding width from the model's own word embedding matrix
        # instead of hard-coding it, so any GPT-2 / GPT-Neo size works.
        n_embd = self.transformer.wte.weight.size(1)
        if initialize_from_vocab:
            init_prompt_value = self.transformer.wte.weight[:n_tokens].clone().detach()
        else:
            init_prompt_value = torch.FloatTensor(n_tokens, n_embd).uniform_(
                -random_range, random_range
            )
        self.soft_prompt = nn.Embedding(n_tokens, n_embd)
        # Initialize weight
        self.soft_prompt.weight = nn.parameter.Parameter(init_prompt_value)

    def _cat_learned_embedding_to_input(self, input_ids) -> torch.Tensor:
        inputs_embeds = self.transformer.wte(input_ids)

        if len(list(inputs_embeds.shape)) == 2:
            inputs_embeds = inputs_embeds.unsqueeze(0)

        # [batch_size, n_tokens, n_embd]
        learned_embeds = self.soft_prompt.weight.repeat(inputs_embeds.size(0), 1, 1)

        inputs_embeds = torch.cat([learned_embeds, inputs_embeds], dim=1)

        return inputs_embeds

    def _extend_labels(self, labels, ignore_index=-100) -> torch.Tensor:
        if len(list(labels.shape)) == 1:
            labels = labels.unsqueeze(0)

        n_batches = labels.shape[0]
        # Soft prompt positions are excluded from the loss via ignore_index.
        return torch.cat(
            [
                torch.full((n_batches, self.n_tokens), ignore_index).to(self.device),
                labels,
            ],
            dim=1,
        )
def _extend_attention_mask(self, attention_mask):
if len(list(attention_mask.shape)) == 1:
attention_mask = attention_mask.unsqueeze(0)
n_batches = attention_mask.shape[0]
return torch.cat(
[torch.full((n_batches, self.n_tokens), 1).to(self.device), attention_mask],
dim=1,
)

    def save_soft_prompt(self, path: str, filename: str = "soft_prompt.model"):
        Path(path).mkdir(parents=True, exist_ok=True)
        torch.save(self.soft_prompt, os.path.join(path, filename))
        # print(f"Saved soft prompt: {os.path.join(path, filename)}")

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        if input_ids is not None:
            inputs_embeds = self._cat_learned_embedding_to_input(input_ids).to(
                self.device
            )

        if labels is not None:
            labels = self._extend_labels(labels).to(self.device)

        if attention_mask is not None:
            attention_mask = self._extend_attention_mask(attention_mask).to(self.device)

        # Drop most of the args for now
        return super().forward(
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            return_dict=return_dict,
        )


class GPT2PromptTuningLM(GPTPromptTuningMixin, GPT2LMHeadModel):
    def __init__(self, config):
        super().__init__(config)


class GPTNeoPromptTuningLM(GPTPromptTuningMixin, GPTNeoForCausalLM):
    def __init__(self, config):
        super().__init__(config)
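

# ---------------------------------------------------------------------------
# Example usage: a minimal sketch, not part of the original module. It assumes
# a standard Hugging Face "gpt2" checkpoint; the prompt length and learning
# rate below are illustrative. The point is that only `soft_prompt.weight`
# receives gradient updates while the rest of the transformer stays frozen.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2PromptTuningLM.from_pretrained(
        "gpt2",
        n_tokens=20,
        initialize_from_vocab=True,
    )

    # Only the soft prompt embedding is passed to the optimizer.
    optimizer = torch.optim.AdamW([model.soft_prompt.weight], lr=1e-4)

    inputs = tokenizer(
        "Prompt tuning keeps the base model frozen.", return_tensors="pt"
    )
    outputs = model(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        labels=inputs["input_ids"],
    )
    outputs.loss.backward()
    optimizer.step()

    model.save_soft_prompt("soft_prompts", filename="example_soft_prompt.model")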