VPL - Variational Prototype Learning for deep face recognition #88
Is it the current SOTA for the face recognition task?
I think an implementation is possible. I will try it and check whether it is better.
VPL mode is added. It can be enabled by passing `use_vpl=True` to `train.Train`, for example:

```python
import os
from tensorflow import keras
import tensorflow_addons as tfa
import losses, train, models  # modules from this repository
from keras_cv_attention_models import efficientnet

# Train in mixed precision (float16 compute)
keras.mixed_precision.set_global_policy("mixed_float16")

data_basic_path = '/datasets/ms1m-retinaface-t1'
data_path = data_basic_path + '_112x112_folders'
eval_paths = [os.path.join(data_basic_path, ii) for ii in ['lfw.bin', 'cfp_fp.bin', 'agedb_30.bin']]

# EfficientNetV2B0 backbone without a classifier head, plus a GDC output layer giving 512-d embeddings
basic_model = efficientnet.EfficientNetV2B0(input_shape=(112, 112, 3), num_classes=0)
basic_model = models.buildin_models(basic_model, dropout=0, emb_shape=512, output_layer='GDC', bn_epsilon=1e-4, bn_momentum=0.9, scale=True, use_bias=False)

tt = train.Train(data_path, eval_paths=eval_paths,
    save_path='TT_efv2_b0_swish_GDC_arc_emb512_dr0_adamw_5e4_bs512_ms1mv3_randaug_cos16_batch_float16_vpl.h5',
    basic_model=basic_model, model=None, lr_base=0.01, lr_decay=0.5, lr_decay_steps=16, lr_min=1e-6, lr_warmup_steps=3,
    batch_size=512, random_status=100, eval_freq=4000, output_weight_decay=1,
    use_vpl=True)  # enable VPL mode

# AdamW; normalization-layer gamma / beta weights are excluded from weight decay
optimizer = tfa.optimizers.AdamW(learning_rate=1e-2, weight_decay=5e-4, exclude_from_weight_decay=["/gamma", "/beta"])

# ArcFace loss with the scale raised in stages: 16 -> 32 -> 64
sch = [
    {"loss": losses.ArcfaceLoss(scale=16), "epoch": 4, "optimizer": optimizer},
    {"loss": losses.ArcfaceLoss(scale=32), "epoch": 3},
    {"loss": losses.ArcfaceLoss(scale=64), "epoch": 46},
]
tt.train(sch, 0)  # initial epoch 0
exit()
```
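The schedule above raises the ArcFace `scale` from 16 to 32 to 64. As a rough illustration of what `scale` controls, here is a minimal NumPy sketch of ArcFace logits, assuming the standard formulation from the ArcFace paper (an additive angular margin on the target class, then every logit multiplied by `scale`). It is not the `losses.ArcfaceLoss` implementation from this repository; `arcface_logits` and `softmax_cross_entropy` are hypothetical names, and the usual safeguard for theta + margin > pi is omitted.

```python
import numpy as np


def arcface_logits(embeddings, weights, labels, scale=64.0, margin=0.5):
    """Hypothetical helper: scaled ArcFace logits (target class gets cos(theta + margin))."""
    # L2-normalize each embedding (rows) and each class weight vector (columns)
    emb = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    w = weights / np.linalg.norm(weights, axis=0, keepdims=True)
    cos = np.clip(emb @ w, -1.0, 1.0)  # (batch, num_classes) cosine similarities
    is_target = np.zeros_like(cos, dtype=bool)
    is_target[np.arange(len(labels)), labels] = True
    # Additive angular margin on the ground-truth class only (no theta + m > pi guard)
    target_cos = np.cos(np.arccos(cos) + margin)
    logits = np.where(is_target, target_cos, cos)
    return scale * logits


def softmax_cross_entropy(logits, labels):
    """Mean cross-entropy, computed with the log-sum-exp trick for stability."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()


# Usage: the same random batch evaluated at the three scales used in the schedule
rng = np.random.default_rng(0)
emb, w = rng.normal(size=(8, 16)), rng.normal(size=(16, 10))
labels = rng.integers(0, 10, size=8)
for s in (16.0, 32.0, 64.0):
    print(s, softmax_cross_entropy(arcface_logits(emb, w, labels, scale=s), labels))
```

Because every logit is multiplied by `scale`, a small scale keeps the softmax flatter and a large one sharpens it; starting low is one plausible way to ease early training, though the thread does not state the reason for this schedule.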
The results in the VPL paper were quite different. It turns out reality is another story. What do you think about it? Maybe your implementation is a little different from theirs, though I'm not sure?
Ya, I have compared them several times. It seems the main parts are:
It's 2 parameters now,
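For readers comparing implementations, the core idea described in the VPL paper can be sketched as follows: each class prototype is a mix of the class weight and a recently memorized feature of that class, and a memorized feature is only used while it is fresh. This is a minimal NumPy sketch under assumptions, not the code from this repository or from insightface: `PrototypeMemory`, `vpl_lambda` (mixing weight) and `allowed_delta` (staleness window in iterations) are hypothetical names, and the exact weighting and normalization may differ from both.

```python
import numpy as np


def l2_normalize(x, axis=-1, eps=1e-12):
    return x / np.maximum(np.linalg.norm(x, axis=axis, keepdims=True), eps)


class PrototypeMemory:
    """Hypothetical helper: latest feature per class and the step when it was stored."""

    def __init__(self, num_classes, emb_dim, vpl_lambda, allowed_delta):
        self.features = np.zeros((num_classes, emb_dim))
        self.last_iter = np.full(num_classes, -np.inf)  # -inf means "never stored"
        self.vpl_lambda = vpl_lambda
        self.allowed_delta = allowed_delta

    def prototypes(self, class_weights, step):
        """Mix each normalized class weight with its memorized feature if it is fresh."""
        w = l2_normalize(class_weights, axis=1)  # (num_classes, emb_dim)
        fresh = (step - self.last_iter) <= self.allowed_delta
        lam = np.where(fresh, self.vpl_lambda, 0.0)[:, None]
        mixed = (1.0 - lam) * w + lam * self.features
        return l2_normalize(mixed, axis=1)

    def update(self, embeddings, labels, step):
        """Memorize normalized batch features; if a class repeats, one of its features is kept."""
        self.features[labels] = l2_normalize(embeddings, axis=1)
        self.last_iter[labels] = step
```

In a training loop, `prototypes` would stand in for the normalized class-weight matrix inside the margin loss at the current step, and `update` would run after the forward pass. Implementations may also delay this injection until some starting iteration, which is not modeled here.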
Thank you for your work. It's indeed worth a try. An additional question about the IJB validation dataset: did you try using their 1:N test?
I'm using my
Do you plan to implement their (insightface) latest work? It seems it does not work as well as they claimed in their papers. deepinsight/insightface#1801