Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace NumPy with PyTorch in PositionalEncoding #207

Open
wants to merge 1 commit into
base: master
Choose a base branch
from

Conversation

ZYM66
Copy link

@ZYM66 ZYM66 commented Apr 28, 2023

Description:

This PR replaces the usage of NumPy with PyTorch in the PositionalEncoding class, specifically in the _get_sinusoid_encoding_table method. The goal is to make the function compatible with PyTorch and take advantage of GPU acceleration when available.

Changes:

Replace np.array with torch.tensor for creating the sinusoid table.
Replace np.power with torch.pow for calculating the position angle vector.
Change the data type from NumPy float array to PyTorch float tensor.

Here's the modified _get_sinusoid_encoding_table method:

def _get_sinusoid_encoding_table(self, n_position, d_hid):
    ''' Sinusoid position encoding table '''

    def get_position_angle_vec(position):
        return [position / torch.pow(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]

    sinusoid_table = torch.tensor([get_position_angle_vec(pos_i) for pos_i in range(n_position)], dtype=torch.float32)
    sinusoid_table[:, 0::2] = torch.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = torch.cos(sinusoid_table[:, 1::2])  # dim 2i+1

    return sinusoid_table.unsqueeze(0)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant