model.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(MultiHeadAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads
        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split the embedding into self.heads pieces of size head_dim
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # Batched query-key dot products per head:
        # (N, query_len, heads, head_dim) x (N, key_len, heads, head_dim)
        # -> (N, heads, query_len, key_len)
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

        # Mask padded positions so their attention weights become ~0 after the softmax
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)

        # Weighted sum of the values, then concatenate the heads back together
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )
        out = self.fc_out(out)
        return out


class FeedForward(nn.Module):
    def __init__(self, embed_size, hidden_size):
        super(FeedForward, self).__init__()
        self.fc1 = nn.Linear(embed_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, embed_size)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super(TransformerBlock, self).__init__()
        self.attention = MultiHeadAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.feed_forward = FeedForward(embed_size, forward_expansion * embed_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        attention = self.attention(value, key, query, mask)

        # Add skip connection, run through normalization and finally dropout
        x = self.dropout(self.norm1(attention + query))
        forward = self.feed_forward(x)
        out = self.dropout(self.norm2(forward + x))
        return out


class MusicTransformer(nn.Module):
    def __init__(
        self,
        chord_vocab_size,
        embed_size,
        num_layers,
        heads,
        device,
        forward_expansion,
        dropout,
        max_length,
    ):
        super(MusicTransformer, self).__init__()
        self.embed_size = embed_size
        self.device = device
        self.word_embedding = nn.Embedding(chord_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)
        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    embed_size,
                    heads,
                    dropout=dropout,
                    forward_expansion=forward_expansion,
                )
                for _ in range(num_layers)
            ]
        )
        self.dropout = nn.Dropout(dropout)

        # Output layer to predict two notes (assuming each note is represented by a single number)
        self.fc_out = nn.Linear(embed_size, 128)

    def forward(self, x, mask):
        # x: (N, sequence_length, chord_size) of token indices into the chord vocabulary
        N, sequence_length, chord_size = x.size()

        # Flatten each chord so the model sees one long token sequence per example;
        # max_length must be at least sequence_length * chord_size
        x = x.view(N, sequence_length * chord_size)
        positions = (
            torch.arange(0, sequence_length * chord_size)
            .expand(N, sequence_length * chord_size)
            .to(self.device)
        )

        out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))
        for layer in self.layers:
            out = layer(out, out, out, mask)

        # Average pooling over the sequence (using the first token's output instead would also work)
        out = out.mean(dim=1)
        out = self.fc_out(out)
        return out
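

if __name__ == "__main__":
    # Minimal smoke test sketching how the model might be used. The hyperparameters,
    # vocabulary size, and input shapes below are illustrative assumptions, not values
    # taken from this project.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = MusicTransformer(
        chord_vocab_size=100,  # assumed vocabulary size
        embed_size=128,
        num_layers=2,
        heads=4,               # must divide embed_size evenly
        device=device,
        forward_expansion=4,
        dropout=0.1,
        max_length=256,        # must cover sequence_length * chord_size
    ).to(device)

    # Batch of 8 sequences, 16 chords each, 4 tokens per chord -> 64 flattened positions.
    x = torch.randint(0, 100, (8, 16, 4), device=device)

    # No padding in this toy batch, so no mask is applied. A padding mask would need to
    # broadcast against the (N, heads, query_len, key_len) attention scores,
    # e.g. a tensor of shape (N, 1, 1, key_len).
    out = model(x, mask=None)
    print(out.shape)  # expected: torch.Size([8, 128])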