
Does the Binarize() function use STE? #9

Open
leejiajun opened this issue Aug 30, 2018 · 29 comments

Comments

@leejiajun

Does the Binarize() function use STE?
I haven't seen the STE algorithm in this whole project.

@mohdumar644

I want to know the same.

Probably the Binarize operation on the layer weights is not recorded (the weights are leaf nodes?). During the backward pass, the gradient is calculated w.r.t. the binary weights, and this same gradient is used to update the real weights, hence implicitly an STE.
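
To illustrate, here is a minimal standalone snippet (not taken from this repo) showing why an assignment to .data acts as an implicit STE: autograd never records the sign operation, so the gradient computed with the binary values can be applied directly to the stored real-valued copy.

import torch

w = torch.nn.Parameter(torch.tensor([0.7, -0.2, 0.1]))
w_real = w.data.clone()        # keep a copy of the real-valued weights
w.data = torch.sign(w.data)    # binarize in place; autograd never sees this op

x = torch.tensor([1.0, 2.0, 3.0])
loss = (w * x).sum()
loss.backward()

print(w.grad)                  # gradient computed with the binary weights in place
w.data.copy_(w_real)           # restore the real weights before an optimizer step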

@haichaoyu

Want to know the same +1

@Ashokvardhan

@mohdumar644, @haichaoyu, @leejiajun: I went through the code of binarized_modules.py, and in the class BinarizeLinear we can see that the line self.weight.data=Binarize(self.weight.org) replaces the weight values with binary ones. This means that the binary weights are used during both the forward and backward passes. However, for the gradient update we need the real weights, which are only available in self.weight.org, and I am not sure how these real weights are used by the optimizer.

@mohdumar644

Weights are binarized in the forward pass using self.weight.data=Binarize(self.weight.org).

Prior to the optimizer step, the real weights are copied back.
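
In rough sketch form (assuming model, criterion, optimizer and the data are already defined; names are illustrative, not copied verbatim from the repo), one training step looks something like:

optimizer.zero_grad()
output = model(inputs)              # BinarizeLinear layers set weight.data = Binarize(weight.org)
loss = criterion(output, targets)
loss.backward()                     # gradients are computed w.r.t. the binary weights
for p in model.parameters():
    if hasattr(p, 'org'):
        p.data.copy_(p.org)         # restore the real-valued weights
optimizer.step()                    # the update is applied to the real-valued weights
for p in model.parameters():
    if hasattr(p, 'org'):
        p.org.copy_(p.data.clamp_(-1, 1))   # save the clipped real weights for the next forward pass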

@mohdumar644

mohdumar644 commented Jul 8, 2019

Btw, a more understandable approach to constructing binarized/quantized neural networks in PyTorch can be found here, using a custom Quantizer/Binarizer module with an explicit STE in the backward pass.
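
For reference, an explicit STE can be written as a custom autograd Function, roughly like this (a sketch with my own naming, not the exact module from the linked code):

import torch

class BinarizeSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.sign(x)

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        # straight-through estimator with gating: pass the gradient only where |x| <= 1
        return grad_output * (x.abs() <= 1).float()

binarize = BinarizeSTE.apply   # usage: x_bin = binarize(x)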

@Ashokvardhan

@mohdumar644 thanks for the links! In the attached implementation too, the class BinarizeLinear uses the function QuantizeWeights, which is very similar to this PyTorch implementation. They also use the copying trick. So I don't see a clear difference between the implementations except for the Quantizer module, which has its own forward and backward functions.

@Ashokvardhan

Weights are binarized in the forward pass using self.weight.data=Binarize(self.weight.org).

Prior to the optimizer step, the real weights are copied back.

I see. I was only looking at the file main_mnist.py, which didn't have this copying trick, and hence got confused. So it seems the current implementation for MNIST is wrong because it updates the binary weights instead of the real weights. What do you think?

@haichaoyu

haichaoyu commented Jul 8, 2019

Hi @Ashokvardhan, thanks for your insights. I just focused on the resnet code for cifar10 (resnet_binary.py). In fact, I have two questions about the code.

  • One is the hardtanh here. This function is called between bn and the next conv layer. hardtanh keeps the output between -1 and +1, so the backward of the binarization inside each conv and linear layer has no effect (the gradients are always 1). My solution is to just remove the hardtanh layer and use an explicit STE layer as you mentioned above (the difference is that I used a gate function in the backward pass). After these modifications, performance remains the same.

  • The second possible problem is that the Binarization operation (inside each conv and linear) is applied directly to Tensor.data instead of Tensor. As a result, the binarization is not recorded in the gradient graph. An explicit STE on Tensor instead of Tensor.data may solve this problem.

Please correct me if there is any problem.

@mohdumar644

mohdumar644 commented Jul 9, 2019

@mohdumar644 thanks for the links! In the attached implementation too, the class BinarizeLinear uses the function QuantizeWeights, which is very similar to this PyTorch implementation. They also use the copying trick. So I don't see a clear difference between the implementations except for the Quantizer module, which has its own forward and backward functions.

My bad actually, since I am a bit rusty: the links I shared only change the way the intermediate activations are binarized/quantized, not the weights; for the weights the approach is the same, i.e. an extended Linear module. The code I shared just avoids having to call input.data = Binarize(input.data) inside this extended class.

Weights are binarized in the forward pass using self.weight.data=Binarize(self.weight.org).
Prior to the optimizer step, the real weights are copied back.

I see. I was only looking at the file main_mnist.py, which didn't have this copying trick, and hence got confused. So it seems the current implementation for MNIST is wrong because it updates the binary weights instead of the real weights. What do you think?

Did you look at this line?

@mohdumar644

mohdumar644 commented Jul 9, 2019

  • One is the hardtanh here. This function is called between bn and the next conv layer. hardtanh keeps the output between -1 and +1, so the backward of the binarization inside each conv and linear layer has no effect (the gradients are always 1). My solution is to just remove the hardtanh layer and use an explicit STE layer as you mentioned above (the difference is that I used a gate function in the backward pass). After these modifications, performance remains the same.

You basically integrated the Hardtanh into the Quantizer/Binarizer function by using a clamp/gate in the backward pass. Good!

@Ashokvardhan

@mohdumar644 thanks for the links! In the attached implementation too, the class BinarizeLinear uses the function QuantizeWeights, which is very similar to this PyTorch implementation. They also use the copying trick. So I don't see a clear difference between the implementations except for the Quantizer module, which has its own forward and backward functions.

My bad actually, since I am a bit rusty: the links I shared only change the way the intermediate activations are binarized/quantized, not the weights; for the weights the approach is the same, i.e. an extended Linear module. The code I shared just avoids having to call input.data = Binarize(input.data) inside this extended class.

Weights are binarized in the forward pass using self.weight.data=Binarize(self.weight.org).
Prior to the optimizer step, the real weights are copied back.

I see. I was only looking at the file main_mnist.py, which didn't have this copying trick, and hence got confused. So it seems the current implementation for MNIST is wrong because it updates the binary weights instead of the real weights. What do you think?

Did you look at this line?

Aah, didn't notice this. Thanks! Btw, this file main_binary.py assumes that one has access to the cifar-10, cifar-100 and imagenet datasets. Do you know in what format they are needed to run this code? In case you have already run these scripts, can you let me know? Thanks!

@mohdumar644

As per this,

  • the cifar-x datasets (and even the mnist dataset, when needed) will be automatically downloaded and extracted into a certain folder the first time the script is run, and later reused (a minimal sketch of this follows below).
  • imagenet is much larger (>100 GB) and needs to be downloaded and set up in the relevant directory yourself.
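
For the automatic download part, torchvision handles it; a minimal sketch (the path and transform are illustrative):

import torchvision.datasets as datasets
import torchvision.transforms as transforms

# downloads and extracts CIFAR-10 into ./data on the first run, reuses it afterwards
train_set = datasets.CIFAR10(root='./data', train=True, download=True,
                             transform=transforms.ToTensor())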

@Ashokvardhan

Hi @Ashokvardhan, thanks for your insights. I just focused on the resnet code for cifar10 (resnet_binary.py). In fact, I have two questions about the code.

  • One is the hardtanh here. This function is called between bn and the next conv layer. hardtanh keeps the output between -1 and +1, so the backward of the binarization inside each conv and linear layer has no effect (the gradients are always 1). My solution is to just remove the hardtanh layer and use an explicit STE layer as you mentioned above (the difference is that I used a gate function in the backward pass). After these modifications, performance remains the same.

Sorry, I didn't understand this. Can you please explain in detail?

  • The second possible problem is that the Binarization operation (inside each conv and linear) is applied directly to Tensor.data instead of Tensor. As a result, the binarization is not recorded in the gradient graph. An explicit STE on Tensor instead of Tensor.data may solve this problem.

Please correct me if there is any problem.

I realized that there is no problem with their implementation. The reason they didn't have to use STE() is that they modify only the Tensor.data, which won't be recorded for gradient computation. This is exactly what we need, because we are using the gradients on the binary weights to update the real weights directly.

@haichaoyu

Regarding the first point:

  • According to the paper (Equation 4), the gradient of the STE is 0 when abs(weight) > 1 and 1 when abs(weight) < 1. If activations are passed through torch.HardTanh before the STE, all activations have absolute value less than 1. Thus the gradients of the STE are always 1?

Actually, both of my concerns are summarized by leejiajun's issue title: does the network use an STE with a gated gradient? (Equation 4 in the paper)

@itayhubara
Owner

By using the HardTanh you ensure that the gradient of the STE is 0 when abs(activation) > 1 and 1 when abs(activation) < 1. Note that the STE is used over the activations (not the weights) to avoid running back-propagation through the sign function.
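
A quick illustrative check (not from the repo) that HardTanh's backward does exactly this clipping:

import torch

x = torch.tensor([-2.0, -0.5, 0.5, 2.0], requires_grad=True)
y = torch.nn.functional.hardtanh(x)
y.sum().backward()
print(x.grad)   # tensor([0., 1., 1., 0.]): zero gradient wherever abs(activation) > 1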

@haichaoyu

@itayhubara, sorry, it's my mistake. What about the weight STE?

@Ashokvardhan

Ashokvardhan commented Jul 10, 2019

@mohdumar644 Thanks for the links. Btw, though the paper is about binary weights and activations, in the code I only see the weights being binarized. Are the activations binarized anywhere? The code uses the nn.Hardtanh() function, which only restricts the activations to [-1, 1] but does not make them binary.

In fact, the other implementation you referred to actually uses quantization on the activations.

@mohdumar644

mohdumar644 commented Jul 10, 2019

The line

input.data=Binarize(input.data)

in binarized_modules.py is related to the activations.
It binarizes the output of the preceding HardTanh.
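
In simplified form the extended layer looks roughly like this (condensed from binarized_modules.py; bias handling and other details are omitted or approximated):

import torch.nn as nn
import torch.nn.functional as F

def Binarize(tensor):
    return tensor.sign()

class BinarizeLinear(nn.Linear):
    def forward(self, input):
        if input.size(1) != 784:                       # keep the raw input of the first layer real-valued
            input.data = Binarize(input.data)          # binarize activations (the output of the Hardtanh)
        if not hasattr(self.weight, 'org'):
            self.weight.org = self.weight.data.clone() # stash the real-valued weights
        self.weight.data = Binarize(self.weight.org)   # use binary weights in the forward pass
        return F.linear(input, self.weight, self.bias)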

@Ashokvardhan

@mohdumar644: I am confused now, because in the backward pass the gradients for the activations are computed at the binary activations and not the real ones. Suppose h is the real activation from the previous layer. Ignoring the batch norm stuff, if the real weights are w, then BinarizeLinear computes Hardtanh(Q(w) · Q(h)), where Q is the sign quantization function. Now when we do the backward pass, the gradients of the activations are computed at Q(h). Since there is no STE() for the activations, I don't see how the correct gradients are obtained.

@mohdumar644

mohdumar644 commented Jul 10, 2019

@mohdumar644: I am confused now, because in the backward pass the gradients for the activations are computed at the binary activations and not the real ones. Suppose h is the real activation from the previous layer. Ignoring the batch norm stuff, if the real weights are w, then BinarizeLinear computes Hardtanh(Q(w) · Q(h)), where Q is the sign quantization function. Now when we do the backward pass, the gradients of the activations are computed at Q(h). Since there is no STE() for the activations, I don't see how the correct gradients are obtained.

A typical BNN goes like a_bin_k = Sign(HardTanh(BatchNorm(Linear(w_b * a_bin_k_prev)))). Let's ignore BatchNorm.

The sign binarization function on input.data inside the forward function of BinarizeLinear is not recorded in the gradient computation graph.
PyTorch does not see the operation Sign(x) and does not use its gradient. But the binarized activations are available in BinarizeLinear's saved tensors for the backward pass, since we did binarize them using sign on input.data. So the gradient on the weights g_w_b is calculated using the binarized activations, and the gradient on the binarized activations g_a_b is calculated using the binarized weights (according to the backprop formulas of a linear layer).

Now in our PyTorch model we placed the HardTanh before the next Linear layer, so we backpropagate further from BinarizeLinear to the preceding HardTanh. We did not write a Sign function between the HardTanh and the Linear anywhere in our main model description, so its gradient is nowhere considered.
If we say y = HardTanh(x), then we already have g_y as g_a_b, since the sign function was not recorded in the computation graph and PyTorch thinks the output of this HardTanh, y, is the same tensor on which the gradient was calculated in BinarizeLinear. Thus we implicitly performed the STE. Now we can find g_x using the gradient definition of HardTanh and backprop further back. Of course, the HardTanh is only there to clip gradients for proper training in BNNs.

@Ashokvardhan

@mohdumar644 Thanks for the clarification. I was under the impression that the gradients with respect to the activations are always calculated w.r.t. the real activation values. If they are computed only w.r.t. the binary activations, then indeed the current implementation is correct and your explanation makes a lot of sense. With regard to the activations being binarized, the other implementation does this explicitly, with no room for confusion.

@mohdumar644

@mohdumar644 Thanks for the clarification. I was under the impression that the gradients with respect to the activations are always calculated w.r.t. the real activation values. If they are computed only w.r.t. the binary activations, then indeed the current implementation is correct and your explanation makes a lot of sense. With regard to the activations being binarized, the other implementation does this explicitly, with no room for confusion.

Always good to refer to the original BNN paper.

[image: excerpt from the original BNN paper]

@Ashokvardhan

@mohdumar644 I was confused because I didn't see nn.Hardtanh() in the paper, yet the implementation uses that function. Because of this ambiguity, it was not clear which one to follow. Now looking back, the forward pass in the paper is effectively a linear neural network, just with binary activations and weights, and with the gradients clamped in the backward pass when the activations are beyond 1, right? Of course, for implementation's sake one can view it as using nn.Hardtanh() in the forward pass to make the computations of the backward pass easier.

@mohdumar644

Yes, the paper's algorithm can be functionally achieved in a variety of ways in different frameworks.

@Ashokvardhan

As per this,

  • the cifar-x datasets (and even the mnist dataset, when needed) will be automatically downloaded and extracted into a certain folder the first time the script is run, and later reused.
  • imagenet is much larger (>100 GB) and needs to be downloaded and set up in the relevant directory yourself.

I see. Do you need to do any processing on the imagenet dataset, or are the raw images used directly by the code? Btw, where did you download it from?

@Ashokvardhan

@itayhubara Can you briefly describe the instructions for the ImageNet dataset needed for the experiments in the paper, such as what format the dataset needs to be in and where to obtain it from?

@codeprb

codeprb commented Nov 19, 2019

@itayhubara I too have the same question as @Ashokvardhan. Which imagenet dataset has been used here? I have downloaded a partial dataset from kaggle with 45000 images for training and 5000 images for validation. Please let me know if there is any specific requirement for the imagenet dataset.

@itayhubara
Owner

itayhubara commented Nov 19, 2019

You need to download the ImageNet dataset from here: http://www.image-net.org/challenges/LSVRC/2012/downloads (note that you must log in for that), and then you can use torchvision to create a dataloader.
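
For example, assuming the standard layout with one sub-directory per class, something along these lines should work (illustrative, not the exact code from the repo):

import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
train_set = datasets.ImageFolder(
    '/path/to/imagenet/train',
    transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]))
train_loader = torch.utils.data.DataLoader(train_set, batch_size=256,
                                           shuffle=True, num_workers=8)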

@codeprb

codeprb commented Mar 24, 2020

I am trying to run the ResNet model with the ImageNet dataset. Unfortunately, I get 100% precision soon after the first epoch of training. I am unable to trace the error. Could you please let me know what the error could be?
2020-03-24 16:36:48 - INFO - TRAINING - Epoch: [0][0/176] Time 59.987 (59.987) Data 2.543 (2.543) Loss 6.8668 (6.8668) Prec@1 0.000 (0.000) Prec@5 0.000 (0.000)
2020-03-24 17:30:24 - INFO - TRAINING - Epoch: [0][10/176] Time 49.430 (297.805) Data 0.001 (0.232) Loss 0.0000 (0.6243) Prec@1 100.000 (90.909) Prec@5 100.000 (90.909)
2020-03-24 17:38:45 - INFO - TRAINING - Epoch: [0][20/176] Time 49.411 (179.859) Data 0.000 (0.122) Loss 0.0000 (0.3270) Prec@1 100.000 (95.238) Prec@5 100.000 (95.238)

I am storing the training data under datasets/ImageNet_train/train/1/ and the test data under datasets/ImageNet_test/test/1/.
