Merge pull request #11 from NINAnor/ablation_study
Ablation study
BenCretois authored Jan 25, 2024
2 parents 98ada32 + 0774b49 commit 553a5f5
Showing 3 changed files with 76 additions and 40 deletions.
5 changes: 3 additions & 2 deletions CONFIG.yaml
@@ -39,8 +39,9 @@ trainer:
model:
distance: euclidean # other option is mahalanobis
lr: 1.0e-05
model_type: pann # beats, pann or baseline
model_path: /data/model/PANN/Cnn14_mAP=0.431.pth # /data/model/BEATs/BEATs_iter3_plus_AS2M.pt # # Or FOR INFERENCE: /data/lightning_logs/version_X/checkpoints/epoch=X-step=1500.ckpt
model_type: beats # beats, pann or baseline
state: None # train or validate or None if using beats or baseline // since we remove some layers from the original PANN we can't load the ckpts normally (see _build_model in prototraining.py)
model_path: /data/lightning_logs/version_96/checkpoints/epoch=85-step=8600.ckpt #/data/lightning_logs/pann/lightning_logs/version_1/checkpoints/epoch=99-step=10000.ckpt #/data/lightning_logs/pann/lightning_logs/version_1/checkpoints/epoch=99-step=10000.ckpt #/data/model/PANN/Cnn14_mAP=0.431.pth # /data/model/BEATs/BEATs_iter3_plus_AS2M.pt # # Or FOR INFERENCE: /data/lightning_logs/version_X/checkpoints/epoch=X-step=1500.ckpt
specaugment_params: null
# specaugment_params:
# application_ratio: 1.0
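For orientation, here is a minimal sketch of how the updated model block might be consumed when building the prototypical module. The config path, import path, and constructor keywords are assumptions pieced together from the fields in this diff and the __init__ parameters visible in prototypicalbeats/prototraining.py below; they are not the repository's verified wiring.

import yaml

# Hypothetical usage sketch (assumed import path and constructor keywords).
from prototypicalbeats.prototraining import ProtoBEATsModel

with open("CONFIG.yaml") as f:
    cfg = yaml.safe_load(f)

m = cfg["model"]
model = ProtoBEATsModel(
    model_type=m["model_type"],                  # "beats", "pann" or "baseline"
    state=m["state"],                            # "train", "validate" or None
    model_path=m["model_path"],                  # checkpoint interpreted according to `state`
    distance=m["distance"],                      # "euclidean" or "mahalanobis"
    specaugment_params=m["specaugment_params"],
)

Note that yaml.safe_load parses the unquoted None in `state: None` as the string "None" rather than a Python null; that is harmless for the branching in _build_model, which only tests for the "train", "validate" and "evaluate" values.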
53 changes: 37 additions & 16 deletions evaluate/evaluateDCASE.py
@@ -39,14 +39,15 @@ def to_dataframe(features, labels):


def train_model(
model_class=ProtoBEATsModel,
model_type="pann",
datamodule_class=DCASEDataModule,
milestones=[10, 20, 30],
max_epochs=15,
enable_model_summary=False,
num_sanity_val_steps=0,
seed=42,
pretrained_model=None,
state=None,
beats_path="/data/model/BEATs/BEATs_iter3_plus_AS2M.pt"
):
# create the lightning trainer object
trainer = pl.Trainer(
@@ -67,6 +68,7 @@ def train_model(
# logger=pl.loggers.TensorBoardLogger("logs/", name="my_model"),
)


# create the model object
model = model_class(milestones=milestones)

@@ -86,23 +88,29 @@ def train_model(
return model


def training(pretrained_model, custom_datamodule, max_epoch, milestones=[10, 20, 30]):
def training(model_type, pretrained_model, state, custom_datamodule, max_epoch, beats_path):

model = train_model(
ProtoBEATsModel,
model_type,
custom_datamodule,
milestones,
max_epochs=max_epoch,
enable_model_summary=False,
num_sanity_val_steps=0,
seed=42,
pretrained_model=pretrained_model,
state=state,
beats_path=beats_path
)

return model


def get_proto_coordinates(model, support_data, support_labels, n_way):
z_supports, _ = model.get_embeddings(support_data, padding_mask=None)
def get_proto_coordinates(model, model_type, support_data, support_labels, n_way):

if model_type == "beats":
z_supports, _ = model.get_embeddings(support_data, padding_mask=None)
else:
z_supports = model.get_embeddings(support_data, padding_mask=None)

# Get the coordinates of the NEG and POS prototypes
prototypes = model.get_prototypes(
@@ -115,6 +123,7 @@ def get_proto_coordinates(model, support_data, support_labels, n_way):

def predict_labels_query(
model,
model_type,
z_supports,
queryloader,
prototypes,
@@ -134,7 +143,7 @@ def predict_labels_query(
# Get POS prototype
POS_prototype = prototypes[pos_index].to("cuda")
d_supports_to_POS_prototypes, _ = calculate_distance(
z_supports.to("cuda"), POS_prototype
model_type, z_supports.to("cuda"), POS_prototype
)
mean_dist_supports = d_supports_to_POS_prototypes.mean(0)
std_dist_supports = d_supports_to_POS_prototypes.std(0)
@@ -151,8 +160,11 @@ def predict_labels_query(
# Get the embeddings for the query
feature, label = data
feature = feature.to("cuda")
q_embedding, _ = model.get_embeddings(feature, padding_mask=None)

if model_type == "beats":
q_embedding, _ = model.get_embeddings(feature, padding_mask=None)
else:
q_embedding = model.get_embeddings(feature, padding_mask=None)
# Calculate beginTime and endTime for each segment
# frame_shift is in milliseconds, so we divide by 1000 to get times in seconds
if i == 0:
@@ -163,11 +175,17 @@ def predict_labels_query(
end = begin + tensor_length * frame_shift / 1000

# Get the scores:
classification_scores, dists = calculate_distance(q_embedding, prototypes)
classification_scores, dists = calculate_distance(model_type, q_embedding, prototypes)

if model_type != "beats":
dists = dists.squeeze()
classification_scores = classification_scores.squeeze()

# Get the z_score:
z_score = compute_z_scores(
dists[pos_index], mean_dist_supports, std_dist_supports
dists[pos_index],
mean_dist_supports,
std_dist_supports
)

# Get the labels (either POS or NEG):
@@ -215,7 +233,7 @@ def euclidean_distance(x1, x2):
return torch.sqrt(torch.sum((x1 - x2) ** 2, dim=1))


def calculate_distance(z_query, z_proto):
def calculate_distance(model_type, z_query, z_proto):
# Compute the euclidean distance from queries to prototypes
dists = []
for q in z_query:
@@ -225,8 +243,9 @@ def calculate_distance(z_query, z_proto):
) # Contrary to prototraining I need to add a dimension to store the
dists = torch.cat(dists, dim=0)

# We drop the last dimension without changing the gradients
dists = dists.mean(dim=2).squeeze()
if model_type == "beats":
# We drop the last dimension without changing the gradients
dists = dists.mean(dim=2).squeeze()

scores = -dists

@@ -296,17 +315,17 @@ def main(

# Train the model with the support data
print("[INFO] TRAINING THE MODEL FOR {}".format(filename))

model = training(
cfg["model"]["model_path"],
custom_dcasedatamodule,
max_epoch=cfg["trainer"]["max_epochs"],
)


# Get the prototypes coordinates
a = custom_dcasedatamodule.test_dataloader()
s, sl, _, _, ways = a
prototypes, z_supports = get_proto_coordinates(model, s, sl, n_way=len(ways))
prototypes, z_supports = get_proto_coordinates(model, model_type, s, sl, n_way=len(ways))

### Get the query dataset ###
df_query = to_dataframe(query_spectrograms, query_labels)
@@ -327,6 +346,7 @@ def main(
z_score_pos,
) = predict_labels_query(
model,
model_type,
z_supports,
queryLoader,
prototypes,
@@ -388,6 +408,7 @@ def main(
z_score_pos,
) = predict_labels_query(
model,
model_type,
z_supports,
queryLoader,
prototypes,
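The thread running through these hunks is that evaluation now branches on model_type: BEATs' get_embeddings returns an (embeddings, padding_mask) tuple with a frame axis that the distance computation averages over, while the PANN and baseline encoders return a single clip-level tensor whose distances only need squeezing. A condensed, self-contained sketch of that dispatch follows; the helper name, tensor shapes, and standalone packaging are illustrative assumptions rather than the repository's exact API.

import torch

def get_query_embedding(model, model_type, feature):
    # Assumed convention: BEATs yields (batch, frames, dim) plus a padding mask,
    # PANN / baseline yield (batch, dim).
    if model_type == "beats":
        z, _ = model.get_embeddings(feature, padding_mask=None)
        return z
    return model.get_embeddings(feature, padding_mask=None)

def calculate_distance(model_type, z_query, z_proto):
    # Euclidean distance from every query embedding to every prototype,
    # mirroring the loop in evaluateDCASE.py above.
    dists = []
    for q in z_query:
        q_dists = torch.sqrt(torch.sum((q - z_proto) ** 2, dim=-1)).unsqueeze(0)
        dists.append(q_dists)
    dists = torch.cat(dists, dim=0)
    if model_type == "beats":
        # Frame-level embeddings: average the per-frame distances before scoring.
        dists = dists.mean(dim=2).squeeze()
    scores = -dists  # larger score = closer to the prototype
    return scores, dists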
58 changes: 36 additions & 22 deletions prototypicalbeats/prototraining.py
@@ -26,6 +26,8 @@ def __init__(
model_path: str = None,
distance: str = "euclidean",
specaugment_params = None,
state: str = None,
beats_path: str = "/data/model/BEATs/BEATs_iter3_plus_AS2M.pt",
**kwargs,
) -> None:
"""TransferLearningModel.
@@ -40,22 +42,13 @@
self.milestones = milestones
self.distance = distance
self.model_type = model_type

# If BEATS --> initialise BEATs model
if self.model_type == "beats":
self.checkpoint = torch.load(model_path)
self.cfg = BEATsConfig(
{
**self.checkpoint["cfg"],
"finetuned_model": False,
"specaugment_params": specaugment_params,
}
)

# If we are using the PANN model:
if self.model_type == "pann":
self.checkpoint = torch.load(model_path)
self.state = state
self.specaugment_params = specaugment_params
self.beats_path = beats_path

if model_path:
self.checkpoint = torch.load(model_path)

self._build_model()
self.save_hyperparameters()

@@ -68,20 +61,41 @@ def _build_model(self):
print("[MODEL] Loading the baseline model")
self.model = ProtoNet()

if self.state == "evaluate":
self.model.load_state_dict(self.checkpoint["state_dict"])

if self.model_type == "beats":
print("[MODEL] Loading the BEATs model")
self.beats = torch.load("/data/model/BEATs/BEATs_iter3_plus_AS2M.pt")
self.cfg = BEATsConfig(
{
**self.beats["cfg"],
"finetuned_model": False,
"specaugment_params": self.specaugment_params,
}
)
self.model = BEATs(self.cfg)
self.model.load_state_dict(self.checkpoint["model"])

if self.state == "train":
self.model.load_state_dict(self.checkpoint["model"])

if self.state == "validate":
self.model.load_state_dict(self.checkpoint["state_dict"], strict=False)

if self.model_type == "pann":
print("[MODEL] Loading the PANN model")
layers_to_remove = ["spectrogram_extractor.stft.conv_real.weight", "spectrogram_extractor.stft.conv_imag.weight", "logmel_extractor.melW",
"fc_audioset.weight", "fc_audioset.bias"]

for key in layers_to_remove:
del self.checkpoint["model"][key]
self.model = Cnn14()
self.model.load_state_dict(self.checkpoint["model"])

if self.state == "train":
layers_to_remove = ["spectrogram_extractor.stft.conv_real.weight", "spectrogram_extractor.stft.conv_imag.weight", "logmel_extractor.melW",
"fc_audioset.weight", "fc_audioset.bias"]
for key in layers_to_remove:
del self.checkpoint["model"][key]
self.model.load_state_dict(self.checkpoint["model"])

if self.state == "validate":
self.model.load_state_dict(self.checkpoint["state_dict"], strict=False) # we set strict = False because the names of the modules slightly vary

else:
print("[ERROR] the model specified is not included in the pipeline. Please use 'baseline', 'pann' or 'beats'")

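The new state flag routes checkpoint loading: "train" loads the original upstream weights (for PANN after dropping the spectrogram-extractor and AudioSet-head layers whose shapes no longer match this Cnn14 variant), "validate" loads a Lightning checkpoint with strict=False because module names vary slightly, and any other value loads nothing. Below is a standalone sketch of that routing for the PANN branch; the key names come from the diff above, but the helper itself is a hypothetical extraction, not a function in the repository.

import torch

PANN_KEYS_TO_DROP = [
    "spectrogram_extractor.stft.conv_real.weight",
    "spectrogram_extractor.stft.conv_imag.weight",
    "logmel_extractor.melW",
    "fc_audioset.weight",
    "fc_audioset.bias",
]

def load_pann_weights(model, checkpoint_path, state=None):
    # Hypothetical helper mirroring the PANN branch of _build_model().
    if state not in ("train", "validate"):
        return model  # nothing to load (e.g. state is None)
    checkpoint = torch.load(checkpoint_path)
    if state == "train":
        # Upstream AudioSet checkpoint: drop the layers this Cnn14 variant no longer has.
        for key in PANN_KEYS_TO_DROP:
            del checkpoint["model"][key]
        model.load_state_dict(checkpoint["model"])
    else:  # "validate"
        # Lightning checkpoint: module names differ slightly, hence strict=False.
        model.load_state_dict(checkpoint["state_dict"], strict=False)
    return model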
