Two differences from the original implementation #17

Open
p3i0t opened this issue Nov 9, 2017 · 5 comments

p3i0t commented Nov 9, 2017

I got the same result as you: a Pearson score of ~0.846. After checking the original implementation, I found two differences.

  • In your trainer.py file:
def train(self, dataset):
        self.model.train()
        self.optimizer.zero_grad()
        loss, k = 0.0, 0
        indices = torch.randperm(len(dataset))
        for idx in tqdm(range(len(dataset)),desc='Training epoch '+str(self.epoch+1)+''):
            ltree,lsent,rtree,rsent,label = dataset[indices[idx]]
            linput, rinput = Var(lsent), Var(rsent)
            target = Var(map_label_to_target(label,dataset.num_classes))
            if self.args.cuda:
                linput, rinput = linput.cuda(), rinput.cuda()
                target = target.cuda()
            output = self.model(ltree,linput,rtree,rinput)
            err = self.criterion(output, target)
            loss += err.data[0]
            err.backward()           # <------------
            k += 1
            if k%self.args.batchsize==0:
                self.optimizer.step()
                self.optimizer.zero_grad()
        self.epoch += 1
        return loss/len(dataset)

You call .backward() for each sample in the mini-batch and then perform one update step with self.optimizer.step(). Since backward() accumulates gradients automatically, it seems you need to average both the losses and the gradients over the mini-batch. So I think the marked line above should be changed to the following (a self-contained sketch is given after this list):

(err/self.args.batchsize).backward()
  • The original implementation does not really update the embeddings. It does not include the embedding parameters in the model, and all of the model's parameters are optimized with Adagrad. It updates the embedding parameters directly with gradient * learning_rate, but that learning rate is set to 0 (a sketch of this freezing scheme is also given below).
    Furthermore, I did some simple calculations. There are more than 700,000 embedding parameters versus 286,505 other model parameters. Considering that the training set contains just 4,500 examples, it is too small to fine-tune the embeddings.
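
A minimal self-contained sketch of the first change, written against current PyTorch rather than the repository's code (the model, data, and sizes below are placeholders, not the actual Tree-LSTM):

    # Illustration of per-sample backward() with gradient accumulation: dividing
    # each loss by the batch size makes the accumulated gradient equal to the
    # gradient of the mean loss over the mini-batch.
    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    model = nn.Linear(4, 1)                      # placeholder model
    optimizer = torch.optim.Adagrad(model.parameters(), lr=0.05)
    criterion = nn.MSELoss()
    batchsize = 5
    data = [(torch.randn(4), torch.randn(1)) for _ in range(batchsize)]

    optimizer.zero_grad()
    for x, y in data:
        err = criterion(model(x), y)
        (err / batchsize).backward()             # average instead of summing gradients
    optimizer.step()
    optimizer.zero_grad()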

After making the two modifications above, I get a Pearson score of 0.854 and an MSE of 0.274 with Adagrad (learning_rate=0.05).
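
For reference, a minimal sketch of the second change (again not the repository's code; the layer sizes are placeholders): disabling gradients for the embedding weights and handing only the remaining parameters to Adagrad has the same effect as the original implementation's embedding learning rate of 0.

    # Freeze pretrained embeddings: disable their gradients and exclude them
    # from the optimizer, so only the model parameters are updated by Adagrad.
    import torch
    import torch.nn as nn

    embedding = nn.Embedding(num_embeddings=2500, embedding_dim=300)  # placeholder sizes
    model = nn.Sequential(nn.Linear(300, 150), nn.Tanh(), nn.Linear(150, 5))

    embedding.weight.requires_grad = False
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adagrad(trainable, lr=0.05)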

ryh95 commented Nov 15, 2017

Yes. In Section 5.3 the paper says: "For the semantic relatedness task, word representations were held fixed as we did not observe any significant improvement when the representations were tuned."

dasguptar (Owner) commented

Hi @wangxin0716 and @ryh95,

As you have pointed out, the original paper mentions freezing the word embeddings. I had overlooked this, but have now rectified my mistake via a commit that adds an option to freeze the word embeddings during training. This results in a slight improvement in the metrics: we can now reach a Pearson coefficient of 0.8674 and an MSE of 0.2536.

We are now within ~0.0005 of the original paper, albeit with a different learning rate, so I do not really know if there is any way left to exactly match the numbers. Different libraries, platforms, OS, etc. might account for numerical precision differences within this ballpark.

dasguptar (Owner) commented

By the way, @wangxin0716, I also tried the change you suggested, i.e. (err/self.args.batchsize).backward(); however, I ended up with better final metrics keeping it as is. I believe this should not matter much, since it is a simple scaling of the gradient, and the same effect can be achieved with a different learning rate.
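
For plain SGD the two choices are indeed interchangeable: dividing the loss by a constant before backward() gives the same update as dividing the learning rate by that constant. A quick self-contained check, not taken from the repository (for Adagrad the per-parameter normalisation makes the correspondence only approximate):

    # For plain SGD, dividing the loss by a constant is the same as dividing
    # the learning rate by it.
    import torch

    def one_sgd_step(divide_loss, lr):
        w = torch.tensor([1.0, -2.0], requires_grad=True)
        x = torch.tensor([0.5, 3.0])
        loss = ((w * x).sum() - 1.0) ** 2
        if divide_loss:
            loss = loss / 4.0
        loss.backward()
        with torch.no_grad():
            w -= lr * w.grad
        return w.detach()

    print(one_sgd_step(divide_loss=True, lr=0.1))     # loss / 4, lr = 0.1
    print(one_sgd_step(divide_loss=False, lr=0.025))  # full loss, lr = 0.1 / 4
    # Both print the same weights. Adagrad normalises by accumulated squared
    # gradients, so there the correspondence is only approximate.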

thuqinyj16 commented

I ran with the parameters --lr 0.025 --wd 0.0001 --optim adagrad --batchsize 25 --freeze_embed; however, the result is 0.857, about 0.01 less than what it is supposed to be. What could possibly cause this?

jeenakk commented Jun 23, 2019

Thanks for the code. It was very helpful for understanding the paper. I ran the code with the following configuration:

Namespace(batchsize=25, cuda=False, epochs=50, expname='master', freeze_embed=True, hidden_dim=50, input_dim=300, lr=0.025, mem_dim=150, num_classes=5, optim='adagrad', save='checkpoints/', seed=123, sparse=False, wd=0.0001)

and got the best result at the 5th epoch:

Epoch 5, Test Loss: 0.10324564972114664 Pearson: 0.8587949275970459 MSE: 0.2709934413433075

which is less than what is claimed. Could you please suggest what I might be doing wrong? Is anyone else facing the same issue? Thanks.
