Does the Binarize() function use STE? #9
I haven't seen the STE algorithm in this whole project.
Comments
I want to know the same. Probably the Binarize operation on the layer weights is not recorded (the weights are leaf nodes?). During the backward pass, the gradient is calculated w.r.t. the binary weights. This same gradient is used to update the real weights, hence implicitly STE.
Want to know the same, +1
@mohdumar644, @haichaoyu, @leejiajun: I went through the code. The weights are binarized in the forward pass, and prior to the optimizer step the real weights are copied back.
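A minimal sketch of this copy-back pattern, assuming each binarized layer stashes its real-valued weights in a `.org` attribute on the parameter (the attribute name and loop structure are illustrative, not a verbatim excerpt of the repo's training loop):

```python
import torch

def train_step(model, optimizer, criterion, x, y):
    # Forward pass runs with binarized weights (the layers binarized them
    # in place and kept the real values in p.org).
    loss = criterion(model(x), y)

    optimizer.zero_grad()
    loss.backward()          # gradients are computed w.r.t. the binary weights

    # Copy the real-valued weights back before the optimizer step,
    # so the update is applied to the real weights, not the binary ones.
    for p in model.parameters():
        if hasattr(p, 'org'):
            p.data.copy_(p.org)

    optimizer.step()

    # Keep the real-valued copy clipped to [-1, 1] for the next iteration.
    for p in model.parameters():
        if hasattr(p, 'org'):
            p.org.copy_(p.data.clamp_(-1, 1))

    return loss.item()
```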
@mohdumar644 thanks for the links! In the attached implementation too, the class BinarizeLinear uses the function QuantizeWeights, which is very similar to this PyTorch implementation. They also use the copying trick. So I don't see a clear difference in the implementations except for the Quantizer module, which has its own forward and backward functions.
I see. I was only looking at the file main_mnist.py, which didn't have this copying trick, and hence got confused. So it seems the current implementation for MNIST is wrong, because it updates the binary weights instead of the real weights. What do you think?
Hi @Ashokvardhan, thanks for your insights. I just focused on the ResNet code for CIFAR-10 (resnet_binary.py). In fact, I have two questions about the code.
Please correct me if there is any problem.
My bad actually, since I am a bit rusty: the links I shared only change the way the intermediate activations are binarized/quantized, not the weights. For the weights the approach is the same, i.e. an extended Linear module. The code I shared just avoided having to
Did you look at this line?
You basically integrated Hardtanh into the Quantizer/Binarizer function by using a clamper/gate in the backward pass. Good!
Aah, didn't notice this. Thanks! Btw, this file main_binary.py assumes that one has access to the CIFAR-10, CIFAR-100, and ImageNet datasets. Do you know in what format they are needed to run this code? In case you already happened to run these scripts, can you let me know? Thanks!
As per this,
Sorry, I didn't understand this. Can you please explain in detail?
I realized that there is no problem with their implementation. The reason they didn't have to use STE() is that they are modifying only Tensor.data, which won't be recorded for gradient computation. This is exactly what we need, because we are using the gradients on the binary weights to update the real weights directly.
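A small standalone illustration of why modifying only .data is invisible to autograd (a toy example, not code from this repo):

```python
import torch

# A toy leaf tensor standing in for the real-valued weights.
w = torch.randn(3, requires_grad=True)

# Binarize in place via .data: this assignment is NOT recorded by autograd,
# so w remains the same leaf node, it just now holds +/-1 values.
w.data = torch.sign(w.data)

loss = (w ** 2 * torch.tensor([1.0, 2.0, 3.0])).sum()
loss.backward()

# The gradient 2 * w * c is evaluated at the binary values (+/-1),
# but it lands in w.grad, i.e. on the leaf that the optimizer would update
# after the real-valued weights are copied back.
print(w.grad)
```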
Regarding the first point:
Actually, both of my concerns are summarized in the title of this issue by leejiajun: does the network use STE with the gated gradient? (Equation 4 in the paper)
By using the HardTanh you ensure that the gradient of the STE is 0 when abs(activation) > 1 and 1 when abs(activation) < 1. Note that the STE is used over the activations (not the weights) to avoid running back-propagation through the sign function.
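For concreteness, that gated STE can be written as a custom autograd function; a minimal sketch of the idea (not necessarily how this repo implements it):

```python
import torch

class BinarizeSTE(torch.autograd.Function):
    """Sign in the forward pass; identity gradient gated to |x| <= 1 in the backward pass."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.sign(x)

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        # Pass the gradient through unchanged where |x| <= 1, zero it elsewhere.
        return grad_output * (x.abs() <= 1).to(grad_output.dtype)

# Usage: binary_activation = BinarizeSTE.apply(pre_activation)
```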
@itayhubara, sorry, it's my mistake. What about the weight STE?
@mohdumar644 Thanks for the links. Btw, though the paper is about binary weights and activations, in the code I only see the weights being binarized. Are activations binarized anywhere? The code uses the nn.Hardtanh() function, which only restricts activations to [-1, 1] but does not make them binary. In fact, the other implementation you referred to actually quantizes the activations.
The line
in binarized_modules.py is related to the activations.
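The quoted line did not survive the copy above; roughly, BinarizeLinear binarizes the incoming activations in place inside its forward method. A paraphrased sketch of such a layer (from memory, not a verbatim quote of the repo):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def binarize(t):
    # Stand-in for the repo's Binarize() helper.
    return t.sign()

class BinarizeLinear(nn.Linear):
    """Paraphrase: binarize incoming activations and weights in place via .data."""

    def forward(self, input):
        # The line being discussed: the activations are binarized via .data,
        # so the sign op never enters the autograd graph.
        input.data = binarize(input.data)

        # Keep a copy of the real-valued weights and binarize the working copy.
        if not hasattr(self.weight, 'org'):
            self.weight.org = self.weight.data.clone()
        self.weight.data = binarize(self.weight.org)

        return F.linear(input, self.weight, self.bias)
```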
@mohdumar644: I am confused now, because in the backward pass the gradients for the activations are computed at the binary activations and not the real ones. Suppose h is the real activation from the previous layer. Ignoring the batch-norm stuff, if the real weights are w, then BinarizeLinear computes Hardtanh(Q(w) · Q(h)), where Q denotes the binarization (sign) operation.
A typical BNN layer goes like a_bin_k = Sign(HardTanh(BatchNorm(Linear(w_b * a_bin_{k-1})))). Let's ignore BatchNorm. The sign binarization of input.data inside the forward function of BinarizeLinear is not recorded in the gradient computation graph. Now in our PyTorch model, we placed HardTanh before the next Linear layer, so we backpropagate from BinarizeLinear straight to the preceding HardTanh. We did not write a Sign function between HardTanh and Linear anywhere in our main model description, so its gradient is nowhere considered.
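Schematically (reusing the BinarizeLinear sketch from a few comments above; an illustration of the block ordering, not the repo's actual model definition):

```python
import torch.nn as nn

# The Sign on activations is not a separate module here: it happens inside the
# next BinarizeLinear's forward via input.data, so it never enters the autograd
# graph. HardTanh is what the backward pass sees between the two layers.
block = nn.Sequential(
    BinarizeLinear(784, 512),   # binarizes its incoming activations and its weights
    nn.BatchNorm1d(512),
    nn.Hardtanh(),
    BinarizeLinear(512, 256),   # the Sign for the previous block's output happens here
)
```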
@mohdumar644 Thanks for the clarification. I was under the impression that the gradients with respect to the activations are always calculated at the real activation values. If they are computed only at the binary activations, then indeed the current implementation is correct and your explanation makes a lot of sense. With regard to the activations being binarized, the other implementation does this explicitly, with no room for confusion.
Always good to refer to the original BNN paper.
@mohdumar644 I was confused because I didn't see nn.Hardtanh() in the paper, but the implementation uses that function. Because of this ambiguity, it was not clear which one to follow. Looking back now, the forward pass in the paper is effectively a linear neural network, just with binary activations and weights, and with the gradients clamped in the backward pass when activations are beyond 1, right? Of course, for implementation's sake one can view it as using nn.Hardtanh() in the forward pass to make the backward-pass computations easier.
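That equivalence is easy to check numerically: backpropagating through nn.Hardtanh produces exactly the 0/1 gate the paper writes as gradient clamping (a small standalone check):

```python
import torch
import torch.nn as nn

x = torch.tensor([-2.0, -0.5, 0.5, 2.0], requires_grad=True)
nn.Hardtanh()(x).sum().backward()

# Gradient is 1 where |x| < 1 and 0 where |x| > 1, i.e. the clamped STE gate
# of the paper, obtained simply by placing Hardtanh in the forward pass.
print(x.grad)   # tensor([0., 1., 1., 0.])
```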
Yes, the paper's algorithm can be realized functionally in a variety of ways in different frameworks.
I see. Do you need to do any processing on the ImageNet dataset, or are the raw images used directly by the code? Btw, where did you download it from?
@itayhubara Can you briefly describe what is needed for the ImageNet dataset to reproduce the experiments in the paper, such as what format the dataset should be in and where to obtain it from?
@itayhubara I have the same question as @Ashokvardhan. Which ImageNet dataset has been used here? I have downloaded a partial dataset from Kaggle with 45,000 images for training and 5,000 images for validation. Please let me know if there is any specific requirement for the ImageNet dataset.
You need to download the ImageNet dataset from http://www.image-net.org/challenges/LSVRC/2012/downloads (note that you must log in for that), and then you can use torchvision to create a dataloader.
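For example, a dataloader sketch using the standard ImageFolder layout and the usual ImageNet preprocessing (the path and hyperparameters are placeholders, not values from this repo):

```python
import torch
from torchvision import datasets, transforms

# Expects the standard layout: <root>/train/<class_folder>/*.JPEG, one folder per class.
train_dir = '/path/to/imagenet/train'   # placeholder path

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

train_set = datasets.ImageFolder(train_dir, transform=train_transform)
train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=256, shuffle=True, num_workers=8, pin_memory=True)
```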
I am trying to run the ResNet model with the ImageNet dataset. Unfortunately, I get 100% precision soon after the first epoch of training, and I am unable to trace the error. Could you please let me know what the problem could be? I am storing the training data under datasets/ImageNet_train/train/1/ and the test data under datasets/ImageNet_test/test/1/.