The incorrect implementation of multi-head attention! #17

Open
ZhuYun97 opened this issue Jun 20, 2022 · 0 comments

Assume the number of attention heads is 4. I find that self-attention is computed between the different heads rather than between different atoms. The attention-score tensor has shape (num_atoms, 4, 4, 4), whereas it should be (batch_size, max_num_atoms, max_num_atoms). The flattened atom features (num_atoms, node_fdim) should first be packed into padded batch data of shape (batch_size, max_num_atoms, node_fdim).
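To make the shape problem concrete, here is a minimal, self-contained sketch with dummy tensors (the sizes num_atoms=30 and hidden_size=128 are made up, and the linear projections are skipped because they do not change any shapes) that reproduces the same view/transpose/matmul sequence:

import math
import torch

num_atoms, num_heads, hidden_size = 30, 4, 128
d_k = hidden_size // num_heads

# what MTBlock produces: one (num_atoms, hidden_size) tensor per Head, stacked on dim 1
q = k = v = torch.randn(num_atoms, num_heads, hidden_size)

# what MultiHeadedAttention then does: batch_size is taken from dim 0, i.e. num_atoms
batch_size = q.size(0)
q_, k_, v_ = [x.view(batch_size, -1, num_heads, d_k).transpose(1, 2) for x in (q, k, v)]
print(q_.shape)       # torch.Size([30, 4, 4, 32])

scores = torch.matmul(q_, k_.transpose(-2, -1)) / math.sqrt(q_.size(-1))
print(scores.shape)   # torch.Size([30, 4, 4, 4]) -> attention over the 4 head outputs, not over atoms

The softmax in Attention is then taken over a dimension of size 4 (the per-head outputs of a single atom), so no information is ever exchanged between different atoms.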

For details, I extract only the main pieces of code here to illustrate why the self-attention is computed between different heads rather than between different atoms.

For the MTBlock class:

# in the __init__ function
self.attn = MultiHeadedAttention(h=num_attn_head,
                                 d_model=self.hidden_size,
                                 bias=bias,
                                 dropout=dropout)
for _ in range(num_attn_head):
    self.heads.append(Head(args, hidden_size=hidden_size, atom_messages=atom_messages))

# in the forward function
for head in self.heads:
    q, k, v = head(f_atoms, f_bonds, a2b, a2a, b2a, b2revb)
    queries.append(q.unsqueeze(1))
    keys.append(k.unsqueeze(1))
    values.append(v.unsqueeze(1))

queries = torch.cat(queries, dim=1) # (num_atoms, 4, hidden_size)
keys = torch.cat(keys, dim=1) # (num_atoms, 4, hidden_size)
values = torch.cat(values, dim=1) # (num_atoms, 4, hidden_size)

x_out = self.attn(queries, keys, values, past_key_value)  # multi-headed attention

Now the queries, keys and values are fed into the multi-head attention module to produce the new outputs.
For the MultiHeadedAttention class:

# in the __init__ function
self.attention = Attention()
self.d_k = d_model // h # equals hidden_size // num_attn_head
self.linear_layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(3)])  # why 3: query, key, value

# in the forward function
# 1) Do all the linear projections in batch from d_model => h x d_k
# note: batch_size here is query.size(0), i.e. num_atoms
query, key, value = [l(x).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
                     for l, x in zip(self.linear_layers, (query, key, value))]  # q, k, v shapes become (num_atoms, 4, 4, d_k)
x, _ = self.attention(query, key, value, mask=mask, dropout=self.dropout)

For the Attention class:

class Attention(nn.Module):
    """
    Compute Scaled Dot-Product Self-Attention.
    """

    def forward(self, query, key, value, mask=None, dropout=None):
        scores = torch.matmul(query, key.transpose(-2, -1)) \
                 / math.sqrt(query.size(-1)) # scores shape is (num_atoms, 4, 4, 4)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        p_attn = F.softmax(scores, dim=-1)

        return torch.matmul(p_attn, value), p_attn  # the output is (num_atoms, 4, 4, d_k), which is later reshaped back to (num_atoms, 4, hidden_size)

As you can see, the scores tensor has shape (num_atoms, 4, 4, 4): attention is computed between the different heads rather than between different atoms. That is, each atom's new representation is just a mixture of its own heads' outputs, which is meaningless.
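For comparison, here is a rough sketch (not the repository's code; the per-molecule (start, size) scopes, the function name and all parameters are illustrative assumptions) of what atom-level multi-head attention would look like once the flattened atom features are padded into a batch:

import math
import torch
import torch.nn.functional as F

def atom_level_attention(f_atoms, scopes, num_heads=4):
    """f_atoms: (num_atoms, hidden_size); scopes: list of (start, size), one per molecule."""
    hidden_size = f_atoms.size(-1)
    d_k = hidden_size // num_heads
    batch_size = len(scopes)
    max_num_atoms = max(size for _, size in scopes)

    # pack the flattened atom features into padded batch data (batch_size, max_num_atoms, hidden_size)
    padded = f_atoms.new_zeros(batch_size, max_num_atoms, hidden_size)
    mask = torch.zeros(batch_size, 1, 1, max_num_atoms, dtype=torch.bool, device=f_atoms.device)
    for i, (start, size) in enumerate(scopes):
        padded[i, :size] = f_atoms[start:start + size]
        mask[i, ..., :size] = True

    # split hidden_size into num_heads x d_k (the learned projections are omitted for brevity)
    q = k = v = padded.view(batch_size, max_num_atoms, num_heads, d_k).transpose(1, 2)

    # scores: (batch_size, num_heads, max_num_atoms, max_num_atoms) -- attention between atoms
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    scores = scores.masked_fill(~mask, -1e9)
    p_attn = F.softmax(scores, dim=-1)

    out = torch.matmul(p_attn, v)  # (batch_size, num_heads, max_num_atoms, d_k)
    return out.transpose(1, 2).reshape(batch_size, max_num_atoms, hidden_size)

For two molecules with 3 and 5 atoms, scopes = [(0, 3), (3, 5)] and f_atoms of shape (8, hidden_size), the internal scores tensor is (2, num_heads, 5, 5), i.e. attention weights between atoms of the same molecule, with the padded positions masked out.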
