Skip to content

Commit

Permalink
Merge pull request #10 from aniemore/dev-master
Browse files Browse the repository at this point in the history
Hotfix for the Fusion module: restored the forward function
  • Loading branch information
Ar4ikov authored May 15, 2023
2 parents 6a6e70c + 386128f commit 15c481d
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions aniemore/custom/modeling_wavlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,17 @@ def __init__(self, audio_dim, text_dim, num_heads):
self.audio_norm = torch.nn.LayerNorm(self.dimension)
self.text_norm = torch.nn.LayerNorm(self.dimension)

def forward(self, audio_output, text_output):
    """Fuse the audio and text streams with cross attention.

    Each stream is passed as the first argument to its own attention
    module, with the *other* stream passed as the second and third
    arguments. The attention result is then added back to the stream
    it was computed for (residual connection) and layer-normalized.

    NOTE(review): ``a_self_attention`` / ``t_self_attention`` are
    created in ``__init__`` (not fully visible here); presumably they
    follow the ``(query, key, value) -> (output, weights)`` convention
    of ``torch.nn.MultiheadAttention`` -- confirm.

    Args:
        audio_output: Representation of the audio stream.
        text_output: Representation of the text stream.

    Returns:
        A ``(audio, text)`` tuple holding the normalized, fused
        representation of each stream.
    """
    # Cross attention: audio attends over text, then text over audio.
    # The second element returned by each module is discarded.
    attended_audio, _ = self.a_self_attention(
        audio_output, text_output, text_output
    )
    attended_text, _ = self.t_self_attention(
        text_output, audio_output, audio_output
    )

    # Residual add followed by LayerNorm. No dropout is applied in this
    # method itself.
    fused_audio = self.audio_norm(audio_output + attended_audio)
    fused_text = self.text_norm(text_output + attended_text)
    return fused_audio, fused_text


class WavLMForVoiceClassification(BaseModelForVoiceBaseClassification):
"""WavLMForVoiceClassification is a model for voice classification task
Expand Down

0 comments on commit 15c481d

Please sign in to comment.