changed t-SNE output default path to model_dir
alanngnet committed Mar 26, 2024
1 parent 20dae20 commit c282f62
Showing 4 changed files with 19 additions and 8 deletions.
8 changes: 4 additions & 4 deletions README.md
@@ -75,14 +75,14 @@ The training script's output consists of checkpoint files and embedding vectors,
This script evaluates your trained model by providing mAP and MR1 metrics and an optional t-SNE clustering plot (compare Fig. 3 in the CoverHunter paper).

1. Have a pre-trained CoverHunter model's output checkpoint files available. You only need your best set (typically your highest-numbered one). If you use the original CoverHunter's pre-trained model from https://drive.google.com/file/d/1rDZ9CDInpxQUvXRLv87mr-hfDfnV7Y-j/view, unzip it and move it to a folder that you rename to, in this example, 'pretrained_model'.
-2. Run your query data through extract_csi_features. In the hparams.yaml file for the feature extraction, turn off all augmentation. See data/covers80_testset/hparams.yaml for an example configuration to treat covers80 as the query data:<br> `python3 -m tools.extract_csi_features data/covers80_testset`<br>
-The important output from that is full.txt and the cqt_feat subfolder's contents.
-3. Run the evaluation script:<br>
+2. Run your query data through `extract_csi_features.py`. In the `hparams.yaml` file for the feature extraction, turn off all augmentation. See `data/covers80_testset/hparams.yaml` for an example configuration to treat covers80 as the query data:<br> `python3 -m tools.extract_csi_features data/covers80_testset`<br>
+The important output from that is `full.txt` and the `cqt_feat` subfolder's contents.
+3. Run the evaluation script. This example assumes you are using the trained model you created in `egs/covers80` and you want to use all the optional features I added in this fork:<br>
`python3 -m tools.eval_testset egs/covers80 data/covers80_testset/full.txt data/covers80_testset/full.txt -plot_name="egs/covers80/tSNE.png" -dist_name='distmatrix' -test_only_labels='data/covers80/dev-only-song-ids.txt'`

CoverHunter only shared an evaluation example for the case when query and reference data are identical, presumably to do a self-similarity evaluation of the model. But there is an optional 4th parameter for `query_in_ref_path` that would be relevant if query and reference are not identical. See the "query_in_ref" heading below under "Input and Output Files."

-The optional `plot_name` argument is a path where you want to save the t-SNE plot output. See example plot below. Note that your query and reference files must be identical to generate a t-SNE plot (to do a self-similarity evaluation).
+The optional `plot_name` argument is a path or just a filename where you want to save the t-SNE plot output. If you provide just a filename, `model_dir` will be used as the path. See example plot below. Note that your query and reference files must be identical to generate a t-SNE plot (to do a self-similarity evaluation).

The optional `test_only_labels` argument is a path to the text file generated by `extract_csi_features.py` if its hyperparameters asked for some song_ids to be reserved exclusively for the test aka "dev" dataset. The t-SNE plot will then mark those for you to see how well your model can cluster classes (song_ids) it has never seen before.
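The file this argument points to holds one integer song_id per line, matching the `int()` conversion `eval_testset.py` applies before plotting. A minimal loader sketch (the helper name `load_test_only_labels` is hypothetical, not part of this repo):

```python
def load_test_only_labels(path: str) -> list[int]:
    """Parse a test-only-labels file: one integer song_id per line.

    Mirrors the conversion eval_testset.py performs, since the t-SNE
    plotting code expects integer class labels, not strings.
    """
    with open(path) as f:
        # skip blank lines so a trailing newline doesn't raise ValueError
        return [int(line.strip()) for line in f if line.strip()]
```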

Binary file modified tSNE-example.png
Binary file removed tSNE.png
Binary file not shown.
19 changes: 15 additions & 4 deletions tools/eval_testset.py
@@ -23,7 +23,7 @@ def _main():
    parser.add_argument('query_path')
    parser.add_argument('ref_path')
    parser.add_argument('-query_in_ref_path', default='', type=str)
-    parser.add_argument('-plot_name', default='', type=str, help='Save a t-SNE plot of the distance matrix to this path')
+    parser.add_argument('-plot_name', default='', type=str, help='Save a t-SNE plot of the distance matrix to this path. Default path is model_dir if plot_name is just a filename.')
    parser.add_argument('-test_only_labels', default='', type=str, help='Path to list of song_ids reserved for test dataset for use in t-SNE plot.')
    parser.add_argument('-dist_name', default='', type=str, help='Save the distance matrix to this path')

@@ -56,16 +56,27 @@ def _main():
    epoch = model.load_model_parameters(checkpoint_dir, device=device)

    embed_dir = os.path.join(model_dir, "embed_{}_{}".format(epoch, "tmp"))

+    if args.plot_name:
+        plot_name = args.plot_name
+        path = os.path.dirname(plot_name)
+        if path != '':
+            assert os.path.isdir(path), f"Invalid plot path: {plot_name}"
+        else:
+            # put the plot in model_dir as default location
+            plot_name = os.path.join(model_dir, plot_name)
+    else:
+        plot_name = ''

+    if args.test_only_labels:
+        # convert list of song IDs from strings to integers as _cluster_plot() expects
+        test_only_labels = [int(n) for n in read_lines(args.test_only_labels)]
-    # convert list of song IDs from strings to integers as _cluster_plot() expects
-    test_only_labels = [int(n) for n in read_lines(args.test_only_labels)]

    mean_ap, hit_rate, rank1 = eval_for_map_with_feat(
        hp, model, embed_dir, query_path=query_path,
        ref_path=ref_path, query_in_ref_path=query_in_ref_path,
        batch_size=64, logger=logger, test_only_labels=test_only_labels,
-        plot_name=args.plot_name, dist_name=args.dist_name)
+        plot_name=plot_name, dist_name=args.dist_name)

    return
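The default-path rule this hunk introduces can be exercised in isolation. A minimal sketch as a standalone function (`resolve_plot_path` is a hypothetical name; the script itself inlines this logic in `_main()`):

```python
import os

def resolve_plot_path(plot_name: str, model_dir: str) -> str:
    """Resolve where the t-SNE plot will be written.

    An empty plot_name disables plotting. A bare filename falls back
    to model_dir; an explicit path is validated (its directory must
    already exist) and used unchanged.
    """
    if not plot_name:
        return ''
    path = os.path.dirname(plot_name)
    if path != '':
        # explicit path: fail fast if the target directory is missing
        assert os.path.isdir(path), f"Invalid plot path: {plot_name}"
        return plot_name
    # bare filename: default to model_dir
    return os.path.join(model_dir, plot_name)
```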

