Using the SSAST pretrained model for inference #36

Open
gavinwwf opened this issue Jun 26, 2024 · 0 comments

@gavinwwf

@YuanGongND I used the SSAST pretrained model for inference, but I get different results every time I run it, and the scores within each result are all close to one another. What could be causing this? For example, one run produced:
[{"label": "Electric toothbrush", "score": 0.849498987197876}, {"label": "Blender", "score": 0.8397527933120728}, {"label": "Tambourine", "score": 0.8310427665710449}, {"label": "Race car, auto racing", "score": 0.8218237161636353}, {"label": "Pink noise", "score": 0.8042027354240417}, {"label": "Writing", "score": 0.7958802580833435}, {"label": "Singing", "score": 0.7875975966453552}, {"label": "Telephone dialing, DTMF", "score": 0.7849113941192627}, {"label": "Ambulance (siren)", "score": 0.7678646445274353}, {"label": "Country", "score": 0.7541956901550293}]
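To make "different every time" and "close" measurable, here is a small sketch that parses a top-10 output in the JSON format shown above. It reports the score spread and the fraction of labels two runs share. It uses only the Python standard library, and `score_spread` and `label_overlap` are hypothetical helpers, not part of SSAST.

```python
import json

# Top-10 output from the run reported above.
run_1 = json.loads(
    '[{"label": "Electric toothbrush", "score": 0.849498987197876}, '
    '{"label": "Blender", "score": 0.8397527933120728}, '
    '{"label": "Tambourine", "score": 0.8310427665710449}, '
    '{"label": "Race car, auto racing", "score": 0.8218237161636353}, '
    '{"label": "Pink noise", "score": 0.8042027354240417}, '
    '{"label": "Writing", "score": 0.7958802580833435}, '
    '{"label": "Singing", "score": 0.7875975966453552}, '
    '{"label": "Telephone dialing, DTMF", "score": 0.7849113941192627}, '
    '{"label": "Ambulance (siren)", "score": 0.7678646445274353}, '
    '{"label": "Country", "score": 0.7541956901550293}]'
)


def score_spread(run):
    """Return max(score) - min(score) over one run's top-k entries."""
    scores = [item["score"] for item in run]
    return max(scores) - min(scores)


def label_overlap(run_a, run_b):
    """Fraction of top-k labels shared by two runs (1.0 means identical label sets)."""
    labels_a = {item["label"] for item in run_a}
    labels_b = {item["label"] for item in run_b}
    return len(labels_a & labels_b) / max(len(labels_a), len(labels_b))


print(f"score spread: {score_spread(run_1):.3f}")                # score spread: 0.095
print(f"overlap with itself: {label_overlap(run_1, run_1):.2f}")  # overlap with itself: 1.00
```

Passing the output of a second run on the same audio file as `run_b` would give an overlap below 1.0 whenever the top-10 label set changes. Because it compares sets, a mere reordering within the top 10 still counts as 1.0.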

My code is as follows:

import numpy as np
import torch

# ASTModel (from the SSAST repo), load_modified_checkpoint, load_label and
# make_features are defined elsewhere and not shown here.


def model_fn(model_dir):
    """
    Load the model and set weights
    """

    # Load the model
    input_tdim = 200
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    checkpoint_path = f'{model_dir}/SSAST-Tiny-Patch-400.pth'

    # fstride, tstride = int(checkpoint_path.split('/')[-1].split('_')[1]), int(
    #     checkpoint_path.split('/')[-1].split('_')[2].split('.')[0])

    # Fine-tuning-stage model (pretrain_stage=False): 527 AudioSet classes,
    # 16x16 patches with stride 10, built from the SSAST pretrained checkpoint
    ast_mdl = ASTModel(label_dim=527, fshape=16, tshape=16, fstride=10, tstride=10, input_tdim=input_tdim,
                       model_size='tiny', pretrain_stage=False, load_pretrained_mdl_path=checkpoint_path)

    audio_model = torch.nn.DataParallel(ast_mdl)
    checkpoint = load_modified_checkpoint(checkpoint_path, audio_model, device)
    audio_model.load_state_dict(checkpoint)
    audio_model = audio_model.to(device)
    audio_model.eval()  # inference mode: disables dropout

    labels = load_label(f'{model_dir}/class_labels_indices.csv')

    return audio_model, labels


def predict_fn(input_data, model):
    """
    The predict_fn is invoked with the return value of input_fn.
    """
    audio_model, labels = model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    input_tdim = 200
    feats = make_features(input_data, mel_bins=128, target_length=input_tdim)
    # Add a batch dimension: [1, input_tdim, 128]
    feats_data = feats.expand(1, input_tdim, 128).to(device)

    with torch.no_grad():
        output = audio_model(feats_data, task='ft_cls')
        output = torch.sigmoid(output)  # independent per-class scores (multi-label)

    result_output = output.data.cpu().numpy()[0]
    sorted_indexes = np.argsort(result_output)[::-1]  # class indices, highest score first

    top_k = 10
    top_k_labels = [(labels[idx], result_output[idx]) for idx in sorted_indexes[:top_k]]

    return top_k_labels
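Here is one possible source of the run-to-run variation, offered as an assumption to check rather than a confirmed diagnosis. Any parameters that `ASTModel(...)` initializes randomly and that the loaded checkpoint does not overwrite (for example, a newly created 527-class output layer) take different values in every process unless the random number generator is seeded. The toy simulation below is not SSAST code and uses only the Python standard library; `make_head` and `top_k_indices` are hypothetical helpers. It passes one fixed feature vector through randomly initialized linear heads and compares the top-10 class indices across heads built with and without a fixed seed.

```python
import math
import random


def make_head(n_in, n_out, seed=None):
    """Toy linear head: n_out rows of n_in weights drawn from a Gaussian (std 0.02)."""
    rng = random.Random(seed)  # seed=None: seeded from OS randomness (or the clock)
    return [[rng.gauss(0.0, 0.02) for _ in range(n_in)] for _ in range(n_out)]


def top_k_indices(head, feats, k):
    """Apply the head and a sigmoid; return the k highest-scoring class indices."""
    logits = [sum(w * x for w, x in zip(row, feats)) for row in head]
    scores = [1.0 / (1.0 + math.exp(-z)) for z in logits]
    return sorted(range(len(scores)), key=scores.__getitem__, reverse=True)[:k]


# A fixed feature vector stands in for the encoder output; the sizes
# (192-dim features, 527 classes) are illustrative assumptions.
feat_rng = random.Random(0)
feats = [feat_rng.uniform(-1.0, 1.0) for _ in range(192)]

unseeded_a = top_k_indices(make_head(192, 527), feats, k=10)
unseeded_b = top_k_indices(make_head(192, 527), feats, k=10)
seeded_a = top_k_indices(make_head(192, 527, seed=42), feats, k=10)
seeded_b = top_k_indices(make_head(192, 527, seed=42), feats, k=10)

print("unseeded heads agree:", unseeded_a == unseeded_b)  # almost surely False
print("seeded heads agree:", seeded_a == seeded_b)        # True
```

With the real model, calling `torch.manual_seed(...)` before `ASTModel(...)` should make the outputs repeatable, but repeatable random weights still would not give meaningful labels; that needs trained head weights. To check whether the checkpoint actually covers every parameter, loading the raw checkpoint with `load_state_dict(..., strict=False)` returns `missing_keys` and `unexpected_keys`. Classification-head entries in `missing_keys` would support this explanation.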