diff --git a/i6_models/parts/conformer/mhsa_rel_pos.py b/i6_models/parts/conformer/mhsa_rel_pos.py index 12e174c..c2cc655 100644 --- a/i6_models/parts/conformer/mhsa_rel_pos.py +++ b/i6_models/parts/conformer/mhsa_rel_pos.py @@ -252,7 +252,7 @@ def _sinusoidal_pe(pos_seq: torch.Tensor, embed_dim: int): sinusoid_input = torch.outer(pos_seq, inv_freq) - pos_emb = torch.zeros(pos_seq.shape[0], embed_dim) + pos_emb = torch.zeros(pos_seq.shape[0], embed_dim, device=pos_seq.device) pos_emb[:, 0::2] = sinusoid_input.sin() pos_emb[:, 1::2] = sinusoid_input.cos()