Skip to content

Commit

Permalink
add new pretrained model
Browse files Browse the repository at this point in the history
  • Loading branch information
kkoutini committed May 11, 2023
1 parent c02c4b4 commit 3f8fb49
Showing 1 changed file with 19 additions and 1 deletion.
20 changes: 19 additions & 1 deletion hear21passt/models/passt.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,10 @@ def _cfg(url='', **kwargs):
url='https://github.com/kkoutini/PaSST/releases/download/v0.0.1-audioset/passt-s-f128-p16-s10-ap.476-swa.pt',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,
classifier=('head.1', 'head_dist'), num_classes=527),
'passt_s_kd_p16_128_ap486': _cfg(
url='https://github.com/kkoutini/PaSST/releases/download/v.0.0.9/passt-s-kd-ap.486.pt',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,
classifier=('head.1', 'head_dist'), num_classes=527),
'passt_s_p16_s16_128_ap468': _cfg(
url='https://github.com/kkoutini/PaSST/releases/download/v0.0.2-audioset/passt-s-f128-p16-s16-ap.468.pt',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(1, 128, 998), crop_pct=1.0,
Expand Down Expand Up @@ -710,6 +714,18 @@ def passt_s_swa_p16_128_ap476(pretrained=False, **kwargs):
'passt_s_swa_p16_128_ap476', pretrained=pretrained, distilled=True, **model_kwargs)
return model

def passt_s_kd_p16_128_ap486(pretrained=False, **kwargs):
    """Create the PaSST-S model registered as ``passt_s_kd_p16_128_ap486``.

    Announces the variant on stdout, assembles the transformer
    hyper-parameters (patch size 16, embedding dim 768, depth 12, 12 heads)
    together with any caller-supplied overrides, and warns when the
    requested ``stride`` is not the ``(10, 10)`` used for pre-training.

    Args:
        pretrained: Forwarded unchanged to ``_create_vision_transformer``.
        **kwargs: Extra model arguments forwarded to the factory. Passing
            one of the fixed hyper-parameter names again raises
            ``TypeError`` (duplicate keyword).

    Returns:
        Whatever ``_create_vision_transformer`` returns for this variant
        with ``distilled=True``.
    """
    print("\n\n Loading PaSST pre-trained on AudioSet (with KD) Patch 16 stride 10 structured patchout mAP=486 \n\n")
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)

    # The released weights correspond to a (fstride, tstride) of (10, 10);
    # any other value (including an absent one) only triggers a warning.
    expected_stride = (10, 10)
    requested_stride = model_kwargs.get("stride")
    if requested_stride != expected_stride:
        warnings.warn(
            f"This model was pre-trained with strides {expected_stride}, but now you set (fstride,tstride) to {requested_stride}.")

    return _create_vision_transformer(
        'passt_s_kd_p16_128_ap486', pretrained=pretrained, distilled=True, **model_kwargs)

def passt_s_p16_s16_128_ap468(pretrained=False, **kwargs):
""" PaSST pre-trained on AudioSet
"""
Expand Down Expand Up @@ -821,7 +837,7 @@ def fix_embedding_layer(model, embed="default"):
return model


def get_model(arch="passt_s_swa_p16_128_ap476", pretrained=True, n_classes=527, in_channels=1, fstride=10,
def get_model(arch="passt_s_kd_p16_128_ap486", pretrained=True, n_classes=527, in_channels=1, fstride=10,
tstride=10,
input_fdim=128, input_tdim=998, u_patchout=0, s_patchout_t=0, s_patchout_f=0,
):
Expand All @@ -848,6 +864,8 @@ def get_model(arch="passt_s_swa_p16_128_ap476", pretrained=True, n_classes=527,
model_func = deit_base_distilled_patch16_384
elif arch == "passt_s_swa_p16_128_ap476": # pretrained
model_func = passt_s_swa_p16_128_ap476
elif arch == "passt_s_kd_p16_128_ap486": # pretrained
model_func = passt_s_kd_p16_128_ap486
elif arch == "passt_s_p16_s16_128_ap468":
if fstride!=16 or tstride!=16:
raise ValueError("fstride and tstride must be 16 for arch=passt_s_p16_s16_128_ap468. "
Expand Down

0 comments on commit 3f8fb49

Please sign in to comment.