Skip to content

Commit

Permalink
correct device for pos_emb
Browse files Browse the repository at this point in the history
  • Loading branch information
JackTemaki committed Sep 15, 2024
1 parent 9c0fe3f commit 99f3b7d
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion i6_models/parts/conformer/mhsa_rel_pos.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def _sinusoidal_pe(pos_seq: torch.Tensor, embed_dim: int):

sinusoid_input = torch.outer(pos_seq, inv_freq)

pos_emb = torch.zeros(pos_seq.shape[0], embed_dim)
pos_emb = torch.zeros(pos_seq.shape[0], embed_dim, device=pos_seq.device)

pos_emb[:, 0::2] = sinusoid_input.sin()
pos_emb[:, 1::2] = sinusoid_input.cos()
Expand Down

0 comments on commit 99f3b7d

Please sign in to comment.