-
Notifications
You must be signed in to change notification settings - Fork 0
/
model_v2.py
162 lines (131 loc) · 5.27 KB
/
model_v2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiHeadAttention(nn.Module):
    """Multi-head dot-product attention with a shared per-head projection.

    The embedding is split into ``heads`` chunks of ``head_dim`` features.
    A single ``head_dim -> head_dim`` linear map (no bias) is applied to the
    last dimension of the values, keys and queries, so the same projection
    weights are used for every head.

    Args:
        embed_size: Size of the embedding (last) dimension of the inputs.
        heads: Number of attention heads; must divide ``embed_size``.

    Raises:
        AssertionError: If ``embed_size`` is not divisible by ``heads``.
    """

    def __init__(self, embed_size, heads):
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        """Attend from ``query`` over ``keys`` / ``values``.

        Args:
            values: Tensor reshaped here to ``(N, value_len, heads, head_dim)``.
            keys: Tensor reshaped here to ``(N, key_len, heads, head_dim)``.
                ``key_len`` must equal ``value_len`` for the weighted sum.
            query: Tensor reshaped here to ``(N, query_len, heads, head_dim)``.
            mask: ``None`` or a tensor broadcastable to
                ``(N, heads, query_len, key_len)``; positions where
                ``mask == 0`` receive (effectively) zero attention weight.

        Returns:
            Tensor of shape ``(N, query_len, embed_size)``.
        """
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split the embedding into self.heads pieces, then project each piece.
        values = self.values(values.reshape(N, value_len, self.heads, self.head_dim))
        keys = self.keys(keys.reshape(N, key_len, self.heads, self.head_dim))
        queries = self.queries(query.reshape(N, query_len, self.heads, self.head_dim))

        # Batched query-key dot products per head:
        # (N, q, h, d) x (N, k, h, d) -> (N, h, q, k)
        energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)

        if mask is not None:
            # Use the most negative finite value of the working dtype rather
            # than a hard-coded -1e20, which cannot be represented in float16
            # and makes masked_fill fail under half precision / AMP.
            energy = energy.masked_fill(mask == 0, torch.finfo(energy.dtype).min)

        # NOTE(review): scores are scaled by sqrt(embed_size), not the more
        # common sqrt(head_dim). Kept as-is so existing trained weights keep
        # producing the same outputs - confirm before changing.
        attention = torch.softmax(energy / (self.embed_size ** 0.5), dim=3)

        # Weighted sum of values, then merge the heads back together:
        # (N, h, q, l) x (N, l, h, d) -> (N, q, h, d) -> (N, q, h * d)
        out = torch.einsum("nhql,nlhd->nqhd", attention, values).reshape(
            N, query_len, self.heads * self.head_dim
        )

        return self.fc_out(out)
class FeedForward(nn.Module):
    """Two-layer feed-forward network: linear -> ReLU -> linear.

    Args:
        embed_size: Feature size of both the input and the output.
        hidden_size: Feature size of the intermediate layer.
    """

    def __init__(self, embed_size, hidden_size):
        super().__init__()
        self.fc1 = nn.Linear(embed_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, embed_size)

    def forward(self, x):
        """Apply the network to the last dimension of ``x``."""
        return self.fc2(F.relu(self.fc1(x)))
class TransformerBlock(nn.Module):
    """Attention sub-layer followed by a feed-forward sub-layer.

    Each sub-layer output is added to its input (skip connection), layer
    normalized, and then passed through dropout.

    Args:
        embed_size: Feature size of the inputs and outputs.
        heads: Number of attention heads.
        dropout: Dropout probability applied after each normalization.
        forward_expansion: Multiplier for the feed-forward hidden size
            (``forward_expansion * embed_size``).
    """

    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super().__init__()
        self.attention = MultiHeadAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        hidden_size = forward_expansion * embed_size
        self.feed_forward = FeedForward(embed_size, hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        """Run one block; the query is the residual for the attention step."""
        attended = self.attention(value, key, query, mask)

        # Skip connection -> normalization -> dropout.
        residual = attended + query
        x = self.dropout(self.norm1(residual))

        # Same pattern around the feed-forward network.
        ff_out = self.feed_forward(x)
        return self.dropout(self.norm2(ff_out + x))
class HarmonizeTransformer(nn.Module):
    """Transformer encoder that mean-pools its sequence into one vector.

    Integer tokens are embedded, summed with learned position embeddings,
    passed through ``num_layers`` transformer blocks, averaged over the
    sequence dimension and projected to ``output_dim`` features.

    Args:
        input_dim: Number of entries in the token embedding table.
        hidden_dim: Embedding / model feature size.
        num_layers: Number of stacked ``TransformerBlock`` layers.
        num_heads: Number of attention heads per block.
        dropout: Dropout probability.
        max_length: Number of entries in the position embedding table; the
            flattened input length must not exceed it.
        forward_expansion: Feed-forward hidden-size multiplier per block.
        output_dim: Size of the final projection (defaults to 128, the
            previously hard-coded value).
    """

    def __init__(
        self,
        input_dim,
        hidden_dim,
        num_layers,
        num_heads,
        dropout,
        max_length,
        forward_expansion,
        output_dim=128,
    ):
        super().__init__()
        # Retained only for backward compatibility with code that reads this
        # attribute; forward() no longer relies on it.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.embed_size = hidden_dim
        self.word_embedding = nn.Embedding(input_dim, hidden_dim)
        self.position_embedding = nn.Embedding(max_length, hidden_dim)

        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    hidden_dim,
                    num_heads,
                    dropout=dropout,
                    forward_expansion=forward_expansion,
                )
                for _ in range(num_layers)
            ]
        )

        self.dropout = nn.Dropout(dropout)
        self.fc_out = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, mask=None):
        """Encode ``x`` and return one pooled vector per batch element.

        Args:
            x: Integer tensor of shape ``(N, sequence_length, chord_size)``;
                the last two dimensions are flattened into one sequence.
            mask: Optional attention mask forwarded unchanged to every block.

        Returns:
            Tensor of shape ``(N, output_dim)``.
        """
        N, sequence_length, chord_size = x.size()
        flat_len = sequence_length * chord_size

        # reshape (not view) so non-contiguous inputs are accepted too.
        x = x.reshape(N, flat_len)

        # Build the position indices on the input's own device. The device
        # guessed at construction time can differ from where the model and
        # data actually live (e.g. a CPU model on a CUDA-capable machine).
        positions = torch.arange(flat_len, device=x.device).expand(N, flat_len)

        out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

        for layer in self.layers:
            out = layer(out, out, out, mask)

        # Average pooling over the sequence dimension.
        out = out.mean(dim=1)
        return self.fc_out(out)
class NotePredictor(nn.Module):
    """Wrap ``HarmonizeTransformer`` and squash its outputs with a sigmoid.

    All constructor arguments are passed straight through to
    ``HarmonizeTransformer``.
    """

    def __init__(
        self,
        input_dim,
        hidden_dim,
        num_layers,
        num_heads,
        dropout,
        max_length,
        forward_expansion,
    ):
        super().__init__()
        self.harmonize_transformer = HarmonizeTransformer(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            num_heads=num_heads,
            dropout=dropout,
            max_length=max_length,
            forward_expansion=forward_expansion,
        )

    def forward(self, x, mask=None):
        """Return element-wise sigmoid of the transformer output.

        Each output value lies in (0, 1) and is intended as a per-note
        probability.
        """
        logits = self.harmonize_transformer(x, mask)
        return torch.sigmoid(logits)