Changing number of classes #282

Open
vikashg opened this issue Jul 31, 2024 · 0 comments

vikashg commented Jul 31, 2024

Hi,
I am trying to retrain the linear classifier with a different number of classes. This is what I did.
First, I trained the features using the following command:

out_dir=./main_dino_output
python -m torch.distributed.launch main_dino.py --arch vit_small --data_path </path/to/my/datadir> --output_dir $out_dir --epochs 1000

Now that I have trained my features, I run the linear classifier as follows:

python eval_linear.py --pretrained_weights $out_dir/checkpoint.pth --num_labels 5 --data_path $data_dir --epochs 500 --arch vit_small

This part executes properly. Now I would like to evaluate the trained model, which this step saves as ./checkpoint.pth.tar. So I run the above command with the --evaluate flag turned on:

python eval_linear.py --evaluate --pretrained_weights ./checkpoint.pth.tar --num_labels 5 --data_path $data_dir

However, in this case I get the following error:
[Screenshot of the terminal output: the weights being downloaded, followed by the error]

Loading fails because the checkpoint contains a linear layer for 1000 classes, while my model has 5:

size mismatch for module.linear.weight: copying a param with shape torch.Size([1000, 1536]) from checkpoint, the shape in current model is torch.Size([5, 1536]).
size mismatch for module.linear.bias: copying a param with shape torch.Size([1000]) from checkpoint, the shape in current model is torch.Size([5]).
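
Both messages follow the same pattern: a parameter exists in the checkpoint and in the freshly built model, but with different shapes. A minimal sketch of that comparison is given here, using plain dictionaries of shapes as stand-ins for the two state dicts. The helper name `find_shape_mismatches` is hypothetical and is not part of DINO or PyTorch; the shapes are the ones quoted in the error.

```python
def find_shape_mismatches(checkpoint_shapes, model_shapes):
    """Return {name: (checkpoint_shape, model_shape)} for parameters that
    appear in both mappings but with different shapes."""
    mismatches = {}
    for name, ckpt_shape in checkpoint_shapes.items():
        model_shape = model_shapes.get(name)
        if model_shape is not None and tuple(model_shape) != tuple(ckpt_shape):
            mismatches[name] = (tuple(ckpt_shape), tuple(model_shape))
    return mismatches


# Shapes copied from the error message (1000-class checkpoint, 5-class model).
checkpoint_shapes = {
    "module.linear.weight": (1000, 1536),
    "module.linear.bias": (1000,),
}
model_shapes = {
    "module.linear.weight": (5, 1536),
    "module.linear.bias": (5,),
}

mismatches = find_shape_mismatches(checkpoint_shapes, model_shapes)
for name, (ckpt_shape, model_shape) in mismatches.items():
    print(f"{name}: checkpoint {ckpt_shape} vs model {model_shape}")
```

With real PyTorch objects, the two mappings could presumably be built with something like `{k: tuple(v.shape) for k, v in state_dict.items()}` before calling the helper.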

When I look at the code (dino/eval_linear.py, lines 79 to 83 in 7c446df):

if args.evaluate:
    utils.load_pretrained_linear_weights(linear_classifier, args.arch, args.patch_size)
    test_stats = validate_network(val_loader, model, linear_classifier, args.n_last_blocks, args.avgpool_patchtokens)
    print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
    return

The code downloads new weights and runs the evaluation on them instead of on my checkpoint.
I think this is a bug: if I provide the weights, the script should not download others, as it does in utils.load_pretrained_linear_weights.
When I comment out line 80 in eval_linear.py, the code runs without the error.
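
The behaviour requested here could be sketched roughly as follows: prefer a checkpoint the user supplies, and fall back to the download only when none is given. This is an illustration of the control flow only, not the repository's implementation. The function `load_linear_head`, its `local_checkpoint` parameter, and the two stand-in loaders are all hypothetical names introduced for this sketch.

```python
def load_linear_head(local_checkpoint, load_local, download_default):
    """Prefer a user-supplied checkpoint; download only when none is given.

    All three parameters are hypothetical: load_local and download_default
    are caller-supplied callables standing in for the real loading code.
    """
    if local_checkpoint:
        return load_local(local_checkpoint)
    return download_default()


def fake_load_local(path):
    # Stand-in for reading the head weights saved by the linear training run.
    return f"weights from {path}"


def fake_download_default():
    # Stand-in for utils.load_pretrained_linear_weights, which per the
    # report above downloads weights instead of using the local file.
    return "downloaded weights"


print(load_linear_head("./checkpoint.pth.tar", fake_load_local, fake_download_default))
print(load_linear_head(None, fake_load_local, fake_download_default))
```

Passing the loaders in as arguments keeps the sketch runnable without PyTorch; in the real script the two branches would call the actual loading functions.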

Is this the right thing to do? Please let me know.

P.S.: I know that posting screenshots is generally not the norm, but I wanted to show that it is downloading new weights.
