def _get_sinusoid_encoding_table(self, n_position, d_hid):
    ''' Sinusoid position encoding table, built with torch only.

    Entry [pos, j] is derived from the angle
    pos / 10000 ** (2 * (j // 2) / d_hid): even columns hold its sine,
    odd columns hold its cosine.

    Args:
        n_position: number of positions (rows) to precompute.
        d_hid: width of the encoding (columns per position).

    Returns:
        A float32 tensor of shape (1, n_position, d_hid); the leading
        singleton axis is added by the final unsqueeze(0).
    '''
    # Work in float64 and cast once at the end, so the float32 result is
    # rounded from double-precision angles rather than accumulated in
    # single precision.
    position = torch.arange(n_position, dtype=torch.float64).unsqueeze(1)
    hid_j = torch.arange(d_hid)

    # For non-negative j, j - (j % 2) == 2 * (j // 2); this avoids tensor
    # floor-division, whose semantics changed across torch releases.
    exponent = (hid_j - hid_j % 2).to(torch.float64) / d_hid

    # torch.pow needs at least one tensor operand; calling it with two
    # Python numbers raises TypeError, so the exponent is kept as a tensor.
    angle = position / torch.pow(10000.0, exponent)

    sinusoid_table = torch.empty_like(angle)
    sinusoid_table[:, 0::2] = torch.sin(angle[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = torch.cos(angle[:, 1::2])  # dim 2i+1

    return sinusoid_table.to(torch.float32).unsqueeze(0)


def forward(self, x):
    ''' Add the precomputed position encodings to x.

    The table stored in self.pos_table is sliced to the first x.size(1)
    positions and added to x by broadcasting over the leading axis.
    The slice is cloned and detached, so the addition never back-propagates
    into the table.

    NOTE(review): assumes x is laid out as (batch, seq_len, d_hid) with
    seq_len <= n_position -- confirm against the callers.
    '''
    return x + self.pos_table[:, :x.size(1)].clone().detach()