Fix an error on Windows about 32, 64-bit integer #48

Open: wants to merge 1 commit into base: master

@kkppll-ss kkppll-ss commented Apr 20, 2018

I ran translate.py on my Windows machine (Windows 10 64-bit, with Anaconda and the PyTorch build from the Anaconda package) and ran into the following error:

Traceback (most recent call last):
  File "C:/Users/yao_z/Documents/attention-is-all-you-need-pytorch/translate.py", line 66, in <module>
    main()
  File "C:/Users/yao_z/Documents/attention-is-all-you-need-pytorch/translate.py", line 58, in main
    all_hyp, all_scores = translator.translate_batch(batch)
  File "C:\Users\yao_z\Documents\attention-is-all-you-need-pytorch\transformer\Translator.py", line 107, in translate_batch
    dec_partial_seq, dec_partial_pos, src_seq, enc_output)
  File "C:\Users\yao_z\Miniconda3\envs\pytorch\lib\site-packages\torch\nn\modules\module.py", line 357, in __call__
    result = self.forward(*input, **kwargs)
  File "C:\Users\yao_z\Documents\attention-is-all-you-need-pytorch\transformer\Models.py", line 111, in forward
    dec_input = self.tgt_word_emb(tgt_seq)
  File "C:\Users\yao_z\Miniconda3\envs\pytorch\lib\site-packages\torch\nn\modules\module.py", line 357, in __call__
    result = self.forward(*input, **kwargs)
  File "C:\Users\yao_z\Miniconda3\envs\pytorch\lib\site-packages\torch\nn\modules\sparse.py", line 103, in forward
    self.scale_grad_by_freq, self.sparse
  File "C:\Users\yao_z\Miniconda3\envs\pytorch\lib\site-packages\torch\nn\_functions\thnn\sparse.py", line 59, in forward
    output = torch.index_select(weight, 0, indices.view(-1))
TypeError: torch.index_select received an invalid combination of arguments - got (torch.cuda.FloatTensor, int, !torch.cuda.IntTensor!), but expected (torch.cuda.FloatTensor source, int dim, torch.cuda.LongTensor index)

This happens because, in the following snippet from Beam.py, dec_seq is built from a NumPy ndarray whose dtype is np.int32 rather than np.int64, so torch.from_numpy returns an IntTensor instead of a LongTensor.

def get_tentative_hypothesis(self):
    "Get the decoded sequence for the current timestep."

    if len(self.next_ys) == 1:
        dec_seq = self.next_ys[0].unsqueeze(1)
    else:
        _, keys = self.sort_scores()
        hyps = [self.get_hypothesis(k) for k in keys]
        hyps = [[Constants.BOS] + h for h in hyps]
        dec_seq = torch.from_numpy(np.array(hyps))

    return dec_seq

On Windows, even with 64-bit Windows and 64-bit Anaconda, NumPy uses a 32-bit integer as its default integer type, as illustrated in the following screenshot (the original post also links to further details):

[Image: capture]
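
A quick way to check this on a given machine is to build an integer array without an explicit dtype and inspect it. This is only a minimal sketch; the sample values are hypothetical, and the printed result depends on the platform and NumPy version.

```python
import numpy as np

# Hypothetical token ids, standing in for the hypotheses built in Beam.py.
hyps = [[2, 5, 7], [2, 9, 4]]

# No dtype given: NumPy picks its platform default integer type.
default_arr = np.array(hyps)

print(default_arr.dtype)     # platform-dependent (the PR reports int32 on Windows)
print(default_arr.itemsize)  # bytes per element: 4 for int32, 8 for int64
```

If the first line prints int32, any tensor created from such an array with torch.from_numpy will be an IntTensor.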

As a result, in the following snippet from Translator.py, dec_partial_seq is an IntTensor instead of a LongTensor, which causes the error shown above.

# -- Decoding -- #
dec_output, *_ = self.model.decoder(
    dec_partial_seq, dec_partial_pos, src_seq, enc_output)

Therefore, I added a dtype=np.int64 argument to the np.array call. With that change the error disappears and the script runs smoothly on Windows.
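
The idea behind the change can be sketched in isolation as follows. The helper name build_dec_seq_array and the BOS value are hypothetical and only serve the illustration; they are not taken from the repository, and the actual diff may look different. Because torch.from_numpy keeps the dtype of the array it wraps, an int64 array should yield the LongTensor that the embedding lookup expects.

```python
import numpy as np

BOS = 2  # hypothetical stand-in for Constants.BOS


def build_dec_seq_array(hyps):
    """Prepend BOS to each hypothesis and return an int64 ndarray."""
    hyps = [[BOS] + h for h in hyps]
    # Requesting int64 explicitly avoids relying on the platform default.
    return np.array(hyps, dtype=np.int64)


dec_seq_array = build_dec_seq_array([[5, 7], [9, 4]])
print(dec_seq_array.dtype)  # int64
```

Pinning the dtype at array creation keeps the behavior identical across operating systems, rather than converting the tensor afterwards.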
