Skip to content

Commit

Permalink
[ADD] filtering by p-values
Browse files Browse the repository at this point in the history
  • Loading branch information
BenCretois committed Feb 23, 2024
1 parent e5a8517 commit d52afd0
Show file tree
Hide file tree
Showing 9 changed files with 79 additions and 30 deletions.
19 changes: 7 additions & 12 deletions CONFIG.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,17 @@ data:
n_task_val: 100
target_fs: 16000 # used in preprocessing
resample: true # used in preprocessing
denoise: true # used in preprocessing
denoise: False # used in preprocessing
normalize: true # used in preprocessing
frame_length: 25.0 # used in preprocessing
tensor_length: 128 # used in preprocessing
n_shot: 5
n_query: 10
n_shot: 3 # number of images PER CLASS in the support set
n_query: 2 # number of images PER CLASS in the query set
overlap: 0.5 # used in preprocessing
n_subsample: 1
num_mel_bins: 128 # used in preprocessing
max_segment_length: 1.0 # used in preprocessing
status: validate # used in preprocessing, train or validate or evaluate
status: train # used in preprocessing, train or validate or evaluate


#################################
Expand All @@ -40,15 +40,10 @@ model:
distance: euclidean # other option is mahalanobis
lr: 1.0e-05
model_type: beats # beats, pann or baseline
state: None # train or validate or None if using beats or baseline // since we remove some layers from the original PANN we can't load the ckpts normally (see _build_model in prototraining.py)
model_path: /data/models/BEATs/BEATs_iter3_plus_AS2M.pt
state: train # train or validate - for which model should be loaded
model_path: None
specaugment_params: null
# specaugment_params:
# application_ratio: 1.0
# time_mask: 40
# freq_mask: 40

##################################################################
# PARAMETERS FOR RUNNING THE TRAINED MODEL ON THE EVALUATION SET #
##################################################################

# freq_mask: 40
10 changes: 5 additions & 5 deletions CONFIG_PREDICT.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ data:
normalize: True # used in preprocessing
frame_length: 25.0 # used in preprocessing
tensor_length: 128 # used in preprocessing
n_shot: 2
n_query: 3
n_shot: 5
n_query: 10
overlap: 0.5 # used in preprocessing
n_subsample: 1
num_mel_bins: 128 # used in preprocessing
Expand All @@ -42,8 +42,8 @@ model:
distance: euclidean # other option is mahalanobis
lr: 1.0e-05
model_type: beats # beats, pann or baseline
state: None # train or validate if pann or None if using beats or baseline // since we remove some layers from the original PANN we can't load the ckpts normally (see _build_model in prototraining.py)
model_path: "/data/lightning_logs/version_0/checkpoints/epoch=24-step=2500.ckpt"
state: validate # train or validate - for which model should be loaded
model_path: None
specaugment_params: null
# specaugment_params:
# application_ratio: 1.0
Expand All @@ -54,7 +54,7 @@ model:
# PARAMETERS FOR MODEL PREDICTION #
###################################
predict:
wav_save: True
wav_save: False
overwrite: True
n_self_detected_supports: 0
tolerance: 0
30 changes: 29 additions & 1 deletion evaluate/evaluateDCASE.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,31 @@ def merge_preds(df, tolerence, tensor_length):
result = df.groupby("group").agg({"Starttime": "min", "Endtime": "max"})
return result

def calculate_p_values(X_filtered):
    """Return, for every value, the fraction of values that are <= it.

    The values are ranked against a sorted copy of themselves, so the
    result is the empirical cumulative fraction of each element within
    ``X_filtered`` (the largest element always maps to 1.0).

    Args:
        X_filtered: Values to rank against each other.

    Returns:
        Array with one fraction per element of ``X_filtered``, in the
        same order as the input.
    """
    ordered = np.sort(X_filtered)
    # side='right' counts ties as "less than or equal", so equal values
    # share the same (upper) rank.
    ranks = np.searchsorted(ordered, X_filtered, side='right')
    n_values = len(X_filtered)
    return ranks / n_values

def update_labels_for_outliers(X, Y, target_class=1, upper_threshold=0.95):
    """Relabel the top-ranked members of ``target_class`` as class 0.

    The entries of ``X`` whose label in ``Y`` equals ``target_class`` are
    ranked against each other with ``calculate_p_values``. Every entry
    whose rank fraction is strictly greater than ``upper_threshold`` gets
    its label set to 0.

    NOTE: ``Y`` is modified in place; the returned object is the same
    array that was passed in.
    NOTE: ``calculate_p_values`` gives the largest value a fraction of
    1.0, so with ``upper_threshold`` below 1 at least one target entry is
    relabeled whenever any target entry exists.

    Args:
        X: Values to rank, indexed in parallel with ``Y`` (the call site
            in ``compute`` passes ``distances_to_pos``).
        Y: Labels, indexed in parallel with ``X``. Updated in place.
        target_class: Label value whose members are screened.
        upper_threshold: Rank fraction above which a member is treated
            as an outlier.

    Returns:
        ``Y`` after the outlier labels have been set to 0.
    """
    # Boolean mask selecting the entries labelled as the target class.
    is_target = Y == target_class

    target_values = X[is_target]
    # Positions of those entries in the original, unfiltered arrays.
    target_positions = np.arange(len(X))[is_target]

    # Flag members whose rank fraction exceeds the threshold.
    is_outlier = calculate_p_values(target_values) > upper_threshold

    # Translate the flags back to original positions and reset the labels.
    Y[target_positions[is_outlier]] = 0
    return Y

def compute(
cfg,
meta_df,
Expand Down Expand Up @@ -413,10 +438,13 @@ def compute(
overlap=cfg["data"]["overlap"],
pos_index=pos_index,
)

# Identify outliers
updated_labels = update_labels_for_outliers(distances_to_pos, predicted_labels)

# Compute the scores for the analysed file -- just as information
acc, recall, precision, f1score = compute_scores(
predicted_labels=predicted_labels,
predicted_labels=updated_labels,
gt_labels=labels,
)
with open(
Expand Down
10 changes: 5 additions & 5 deletions prototypicalbeats/prototraining.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(
self.specaugment_params = specaugment_params
self.beats_path = beats_path

if model_path:
if model_path != "None":
self.checkpoint = torch.load(model_path)
if self.state == "validate":
self.adjusted_state_dict= OrderedDict()
Expand Down Expand Up @@ -76,6 +76,10 @@ def _build_model(self):
print("LOADING THE FINE-TUNED MODEL")
self.model.load_state_dict(self.adjusted_state_dict, strict=True)

#else:
# print("NOT LOADING ANY FINE-TUNED MODEL")
# self.model = self.model

if self.model_type == "beats":
print("[MODEL] Loading the BEATs model")
self.beats = torch.load(self.beats_path)
Expand Down Expand Up @@ -115,10 +119,6 @@ def _build_model(self):
print("LOADING THE FINE-TUNED MODEL")
self.model.load_state_dict(self.adjusted_state_dict, strict=True)

#else:
# print("[ERROR] the model specified is not included in the pipeline. Please use 'baseline', 'pann' or 'beats'")


def euclidean_distance(self, x1, x2):
    """Return the Euclidean distance between ``x1`` and ``x2``.

    The squared element-wise differences are summed over dimension 1
    and the square root of that sum is returned.
    """
    difference = x1 - x2
    squared_sum = torch.sum(difference ** 2, dim=1)
    return torch.sqrt(squared_sum)

Expand Down
4 changes: 2 additions & 2 deletions shell_scripts/train_baseline.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@ fi

docker run -v $BASE_FOLDER:/data -v $PWD:/app --gpus all beats poetry run prototypicalbeats/trainer.py fit \
--config $CONFIG_PATH \
--model_type baseline \
--state train \
--model.model_type baseline \
--model.state train \
--trainer.default_root_dir /data/lightning_logs/BASELINE/
1 change: 1 addition & 0 deletions shell_scripts/train_pann.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ fi

docker run -v $BASE_FOLDER:/data -v $PWD:/app --gpus all beats poetry run prototypicalbeats/trainer.py fit \
--config $CONFIG_PATH \
--model.model_type=pann \
--trainer.default_root_dir /data/lightning_logs/PANN/ \
--model.model_path /data/models/PANN/Cnn14_mAP=0.431.pth \
--trainer.default_root_dir /data/lightning_logs/PANN/
13 changes: 13 additions & 0 deletions shell_scripts/validate_baseline.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#!/bin/bash
# Run evaluate/evaluateDCASE.py for the baseline model inside the `beats`
# Docker image, loading a fine-tuned checkpoint in "validate" state.
#
# Usage: ./validate_baseline.sh BASE_DIR
#   BASE_DIR  host directory mounted at /data inside the container

# Fail early with a usage message instead of handing docker an empty mount.
if [ -z "$1" ]; then
    echo "Usage: $0 BASE_DIR" >&2
    exit 1
fi

BASE_DIR=$1

# The parent directory is mounted at /app below; this assumes the script is
# launched from inside shell_scripts/. Abort if the directory change fails.
cd .. || exit 1

# Quote the expansions so paths containing spaces are passed as one argument.
docker run -v "$BASE_DIR":/data -v "$PWD":/app \
    --gpus all \
    beats \
    poetry run python /app/evaluate/evaluateDCASE.py \
    'model.model_type="baseline"' \
    'model.state="validate"' \
    'model.model_path="/data/lightning_logs/BASELINE/lightning_logs/version_1/checkpoints/epoch=50-step=5100.ckpt"'
8 changes: 3 additions & 5 deletions shell_scripts/validate_beats.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ docker run -v $BASE_DIR:/data -v $PWD:/app \
--gpus all \
beats \
poetry run python /app/evaluate/evaluateDCASE.py \
--config "/app/CONFIG_PREDICT.yaml" \
--overwrite
#data.status="validate" \
#model.model_type="beats" \
#model.model_path="/data/lightning_logs/BEATs/lightning_logs/version_2/checkpoints/epoch=99-step=10000.ckpt"
'model.model_type="beats"' \
'model.state="train"' \
'model.model_path="/data/models/BEATs/BEATs_iter3_plus_AS2M.pt"'
14 changes: 14 additions & 0 deletions shell_scripts/validate_pann.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#!/bin/bash
# Run evaluate/evaluateDCASE.py for the PANN model inside the `beats`
# Docker image, pointing it at the Cnn14 weights with model.state set to
# "None".
#
# Usage: ./validate_pann.sh BASE_DIR
#   BASE_DIR  host directory mounted at /data inside the container

# Fail early with a usage message instead of handing docker an empty mount.
if [ -z "$1" ]; then
    echo "Usage: $0 BASE_DIR" >&2
    exit 1
fi

BASE_DIR=$1

# The parent directory is mounted at /app below; this assumes the script is
# launched from inside shell_scripts/. Abort if the directory change fails.
cd .. || exit 1

# Quote the expansions so paths containing spaces are passed as one argument.
docker run -v "$BASE_DIR":/data \
    -v "$PWD":/app \
    --gpus all \
    beats \
    poetry run python /app/evaluate/evaluateDCASE.py \
    'model.model_type="pann"' \
    'model.model_path="/data/models/PANN/Cnn14_mAP=0.431.pth"' \
    'model.state="None"'

0 comments on commit d52afd0

Please sign in to comment.