From 520f2b28e63c9e7744fc47367ee3ba360e9635b9 Mon Sep 17 00:00:00 2001 From: LashaO Date: Sat, 6 Jul 2024 10:07:07 +0400 Subject: [PATCH] Dev docs (#11) * initial commit * visualization modifications * Refactors submodules * Modify gitignore * Diasbles strict model loading * Adds examples * Updates model tag * Notebook checkpoint * Updates readme * Updates Readme --------- Co-authored-by: LashaO --- .gitignore | 2 + .ipynb_checkpoints/README-checkpoint.md | 104 +++ .ipynb_checkpoints/Untitled-checkpoint.ipynb | 6 + README.md | 136 +++- Untitled.ipynb | 33 + setup.py | 7 + wbia_miew_id/__init__.py | 3 - wbia_miew_id/_plugin.py | 28 +- wbia_miew_id/datasets/default_dataset.py | 3 +- wbia_miew_id/datasets/transforms.py | 39 +- wbia_miew_id/engine/__init__.py | 1 - wbia_miew_id/engine/eval_fn.py | 3 +- wbia_miew_id/engine/group_eval.py | 34 +- wbia_miew_id/engine/run_fn.py | 61 -- wbia_miew_id/engine/train_fn.py | 2 +- wbia_miew_id/etl/images.py | 4 +- wbia_miew_id/etl/preprocess.py | 17 +- wbia_miew_id/evaluate.py | 205 +++++ wbia_miew_id/examples/download_example.py | 50 ++ .../examples/extract_and_evaluate.ipynb | 480 +++++++++++ wbia_miew_id/examples/run_training.ipynb | 748 ++++++++++++++++++ wbia_miew_id/examples/split_dataset.ipynb | 291 +++++++ wbia_miew_id/helpers/__init__.py | 4 +- wbia_miew_id/helpers/config.py | 50 +- wbia_miew_id/helpers/split/__init__.py | 3 + wbia_miew_id/helpers/split/split.py | 28 +- wbia_miew_id/models/heads.py | 22 +- wbia_miew_id/models/model.py | 54 +- wbia_miew_id/models/model_helpers.py | 2 +- wbia_miew_id/sweep.py | 106 ++- wbia_miew_id/test.py | 136 ---- wbia_miew_id/train.py | 350 ++++---- wbia_miew_id/visualization/gradcam.py | 12 +- wbia_miew_id/visualization/match_vis.py | 16 +- 34 files changed, 2482 insertions(+), 558 deletions(-) create mode 100644 .ipynb_checkpoints/README-checkpoint.md create mode 100644 .ipynb_checkpoints/Untitled-checkpoint.ipynb create mode 100644 Untitled.ipynb create mode 100644 setup.py delete mode 100644 wbia_miew_id/engine/run_fn.py create mode 100644 wbia_miew_id/evaluate.py create mode 100644 wbia_miew_id/examples/download_example.py create mode 100644 wbia_miew_id/examples/extract_and_evaluate.ipynb create mode 100644 wbia_miew_id/examples/run_training.ipynb create mode 100644 wbia_miew_id/examples/split_dataset.ipynb create mode 100644 wbia_miew_id/helpers/split/__init__.py delete mode 100644 wbia_miew_id/test.py diff --git a/.gitignore b/.gitignore index 73e4c5c..9fa2c16 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,5 @@ wbia_miew_id/configs/* *.png wbia_miew_id/splits/ wbia_miew_id/helpers/split/configs/config_*.yaml +wbia_miew_id.egg* +wbia_miew_id/examples/beluga_example_miewid/* \ No newline at end of file diff --git a/.ipynb_checkpoints/README-checkpoint.md b/.ipynb_checkpoints/README-checkpoint.md new file mode 100644 index 0000000..99cdfb2 --- /dev/null +++ b/.ipynb_checkpoints/README-checkpoint.md @@ -0,0 +1,104 @@ + +# WILDBOOK IA - MIEW-ID Plugin + +A plugin for matching and interpreting embeddings for wildlife identification. + + +## Setup + +` pip install -r requirements.txt ` + +Optionally, these environment variables must be set to enable Weights and Biases logging +capability: +``` +WANDB_API_KEY={your_wanb_api_key} +WANDB_MODE={'online'/'offline'} +``` + +## Training +You can create a new line in a code block in markdown by using two spaces at the end of the line followed by a line break. 
Here's an example: + +``` +cd wbia_miew_id +python train.py +``` + +## Data files + +The data is expected to be in the coco JSON format. Paths to data files and the image directory are defined in the config YAML file. + +The beluga data can be downloaded from [here](https://cthulhu.dyn.wildme.io/public/datasets/beluga-model-data.zip). + +## Configuration file + +A config file path can be set by: +`python train.py --config {path_to_config}` + +- `exp_name`: Name of the experiment +- `project_name`: Name of the project +- `checkpoint_dir`: Directory for storing training checkpoints +- `comment`: Comment text for the experiment +- `viewpoint_list`: List of viewpoint values to keep for all subsets. +- `data`: Subfield for data-related settings + - `images_dir`: Directory containing the all of the dataset images + - `use_full_image_path`: Overrides the images_dir for path construction and instead uses an absolute path that should be defined in the `file_path` file path under the `images` entries for each entry in the COCO JSON. In such a case, `images_dir` can be set to `null` + - `crop_bbox`: Whether to use the `bbox` field of JSON annotations to crop the images. The crops will also be adjusted for rotation if the `theta` field is present for the annotations + - `preprocess_images` pre-applies cropping and resizing and caches the images for training + - `train`: Data parameters regarding the train set used in train.py + - `anno_path`: Path to the JSON file containing the annotations + - `n_filter_min`: Minimum number of samples per name (individual) to keep that individual in the set. Names under the threshold will be discarded + - `n_subsample_max`: Maximum number of samples per name to keep for the training set. Annotations for names over the threshold will be randomly subsampled once at the start of training + - `val`: Data parameters regarding the validation set used in train.py + - `anno_path` + - `n_filter_min` + - `n_subsample_max` + - `test`: Data parameters regarding the test set used in test.py + - `anno_path` + - `n_filter_min` + - `n_subsample_max` + - `checkpoint_path`: Path to model checkpoint to test + - `eval_groups`: Attributes for which to group the testing sets. For example, the value of `['viewpoint']` will create subsets of the test set for each unique value of the viewpoint and run one-vs-all evaluation for each subset separately. The value can be a list - `[['species', 'viewpoint']]` will run evaluation separately for each species+viewpoint combination. `['species', 'viewpoint']` will run grouped eval for each species, and then for each viewpoint. The corresponding fields to be grouped should be present under `annotation` entries in the COCO file. Can be left as `null` to do eval for the full test set. + - `name_keys`: List of keys used for defining a unique name (individual). Fields from multiple keys will be combined to form the final representation of a name. 
A common use-case is `name_keys: ['name', 'viewpoint']` for treating each name + viewpoint combination as a unique individual + - `image_size`: + - Image height to resize to + - Image width to resize to +- `engine`: Subfields for engine-related settings + - `num_workers`: Number of workers for data loading (default: 0) + - `train_batch_size`: Batch size for training + - `valid_batch_size`: Batch size for validation + - `epochs`: Number of training epochs + - `seed`: Random seed for reproducibility + - `device`: Device to be used for training + - `use_wandb`: Whether to use Weights and Biases for logging + - `use_swa`: Whether to use SWA during training +- `scheduler_params`: Subfields for learning rate scheduler parameters + - `lr_start`: Initial learning rate + - `lr_max`: Maximum learning rate + - `lr_min`: Minimum learning rate + - `lr_ramp_ep`: Number of epochs to ramp up the learning rate + - `lr_sus_ep`: Number of epochs to sustain the maximum learning rate + - `lr_decay`: Rate of learning rate decay per epoch +- `model_params`: Dictionary containing model-related settings + - `model_name`: Name of the model backbone architecture + - `use_fc`: Whether to use a fully connected layer after backbone extraction + - `fc_dim`: Dimension of the fully connected layer + - `dropout`: Dropout rate + - `loss_module`: Loss function module + - `s`: Scaling factor for the loss function + - `margin`: Margin for the loss function + - `pretrained`: Whether to use a pretrained model backbone + - `n_classes`: Number of classes in the training dataset, used for loading checkpoint +- `swa_params`: Subfields for SWA training + - `swa_lr`: SWA learning rate + - `swa_start`: Epoch number to begin SWA training +- `test`: Subfields for plugin-related settings + - `fliplr`: Whether to perform horizontal flipping during testing + - `fliplr_view`: List of viewpoints to apply horizontal flipping + - `batch_size`: Batch size for plugin inference + +## Testing +`python test.py --config {path_to_config} --visualize` + +The `--visualize` flag is optional and will produce top 5 match results for each individual in the test set, along with gradcam visualizations. + +The parameters for the test set are defined under `data.test` of the config.yaml file. diff --git a/.ipynb_checkpoints/Untitled-checkpoint.ipynb b/.ipynb_checkpoints/Untitled-checkpoint.ipynb new file mode 100644 index 0000000..363fcab --- /dev/null +++ b/.ipynb_checkpoints/Untitled-checkpoint.ipynb @@ -0,0 +1,6 @@ +{ + "cells": [], + "metadata": {}, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/README.md b/README.md index 6259e7c..7825a08 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,15 @@ -# WILDBOOK IA - MIEW-ID Plugin +# WILDBOOK IA - MiewID Plugin A plugin for matching and interpreting embeddings for wildlife identification. ## Setup -` pip install -r requirements.txt ` +``` +pip install -r requirements.txt +pip install -e . +``` Optionally, these environment variables must be set to enable Weights and Biases logging capability: @@ -15,19 +18,129 @@ WANDB_API_KEY={your_wanb_api_key} WANDB_MODE={'online'/'offline'} ``` -## Training -You can create a new line in a code block in markdown by using two spaces at the end of the line followed by a line break. 
Here's an example: +## Multispecies-V2 Model + +Model specs and dataset overview can be found at the [model card page for the Multispecies-v2 model](https://huggingface.co/conservationxlabs/miewid-msv2) + +### Pretrained Model Embeddings Extraction + +``` +import numpy as np +from PIL import Image +import torch +import torchvision.transforms as transforms +from transformers import AutoModel + +model_tag = f"conservationxlabs/miewid-msv2" +model = AutoModel.from_pretrained(model_tag, trust_remote_code=True) + +def generate_random_image(height=440, width=440, channels=3): + random_image = np.random.randint(0, 256, (height, width, channels), dtype=np.uint8) + return Image.fromarray(random_image) + +random_image = generate_random_image() + +preprocess = transforms.Compose([ + transforms.Resize((440, 440)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), +]) + +input_tensor = preprocess(random_image) +input_batch = input_tensor.unsqueeze(0) + +with torch.no_grad(): + output = model(input_batch) + +print(output) +print(output.shape) + +``` + +### Pretrained Model Evaluation +``` +import torch +from wbia_miew_id.evaluate import Evaluator +from transformers import AutoModel + +evaluator = Evaluator( + device=torch.device('cuda'), + seed=0, + anno_path='beluga_example_miewid/benchmark_splits/test.csv', + name_keys=['name'], + viewpoint_list=None, + use_full_image_path=True, + images_dir=None, + image_size=(440, 440), + crop_bbox=True, + valid_batch_size=12, + num_workers=8, + eval_groups=[['species', 'viewpoint']], + fliplr=False, + fliplr_view=[], + n_filter_min=2, + n_subsample_max=10, + model_params=None, + checkpoint_path=None, + model=model, + visualize=False, + visualization_output_dir='beluga_example_visualizations' +) +``` + +## Example Usage + +### Example dataset download ``` cd wbia_miew_id -python train.py +python examples/download_example.py ``` +### Training + +``` +python train.py --config=examples/beluga_example_miewid/benchmark_model/miew_id.msv2_all.yaml +``` + +### Evaluation + +``` +python evaluate.py --config=examples/beluga_example_miewid/benchmark_model/miew_id.msv2_all.yaml +``` + +Optional `--visualize` flag can be used to produce top 5 match results for each individual in the test set, along with gradcam visualizations. + +### Data Splitting, Training, and Evaluation Using Python Bindings + +Demo notebooks are avaliable at [examples directory](https://github.com/WildMeOrg/wbia-plugin-miew-id/tree/main/wbia_miew_id/examples) + ## Data files -The data is expected to be in the coco JSON format. Paths to data files and the image directory are defined in the config YAML file. +### Example dataset + +The data is expected to be in the CSV or COCO JSON Format. + +[Recommended] The CSV beluga data can be downlaoded from [here](https://cthulhu.dyn.wildme.io/public/datasets/beluga_example_miewid.tar.gz). + +The COCO beluga data can be downloaded from [here](https://cthulhu.dyn.wildme.io/public/datasets/beluga-model-data.zip). + +### Expected CSV data format + +- `theta`: Bounding box rotation in radians +- `viewpoint`: Viewpoint of the individual facing the camera. Used for calculating per-viewpoint stats or separating individuals based on viewpoint +- `name`: Individual ID +- `file_name`: File name +- `viewpoint`: Species name. 
Used for calculating per-species stats +- `file_path`: Full path to images +- `x, y, w, h`: Bounding box coordinates + +|theta |viewpoint |name |file_name|species|file_path|x |y |w |h | +|--------------|--------------------------------|-----|---------|-------|---------|-----|--------------------------------------------------------------------------------------------------------------------|----|---| +|0 |up |1030 |000000006040.jpg|beluga_whale|/datasets/beluga-440/000000006040.jpg|0 |0 |162 |440| +|0 |up |1030 |000000006043.jpg|beluga_whale|/datasets/beluga-440/000000006043.jpg|0 |0 |154 |440| +|0 |up |508 |000000006044.jpg|beluga_whale|/datasets/beluga-440/000000006044.jpg|0 |0 |166 |440| -The beluga data can be downloaded from [here](https://cthulhu.dyn.wildme.io/public/datasets/beluga-model-data.zip). ## Configuration file @@ -42,7 +155,7 @@ A config file path can be set by: - `data`: Subfield for data-related settings - `images_dir`: Directory containing the all of the dataset images - `use_full_image_path`: Overrides the images_dir for path construction and instead uses an absolute path that should be defined in the `file_path` file path under the `images` entries for each entry in the COCO JSON. In such a case, `images_dir` can be set to `null` - - `crop_bbox`: Whether to use the `bbox` field of JSON annotations to crop the images. The crops will also be adjusted for rotation if the `theta` field is present for the annotations + - `crop_bbox`: Whether to use the bounding box metadata to crop the images. The crops will also be adjusted for rotation if the `theta` field is present for the annotations - `preprocess_images` pre-applies cropping and resizing and caches the images for training - `train`: Data parameters regarding the train set used in train.py - `anno_path`: Path to the JSON file containing the annotations @@ -69,7 +182,6 @@ A config file path can be set by: - `epochs`: Number of training epochs - `seed`: Random seed for reproducibility - `device`: Device to be used for training - - `loss_module`: Loss function module - `use_wandb`: Whether to use Weights and Biases for logging - `use_swa`: Whether to use SWA during training - `scheduler_params`: Subfields for learning rate scheduler parameters @@ -97,9 +209,3 @@ A config file path can be set by: - `fliplr_view`: List of viewpoints to apply horizontal flipping - `batch_size`: Batch size for plugin inference -## Testing -`python test.py --config {path_to_config} --visualize` - -The `--visualize` flag is optional and will produce top 5 match results for each individual in the test set, along with gradcam visualizations. - -The parameters for the test set are defined under `data.test` of the config.yaml file. 
diff --git a/Untitled.ipynb b/Untitled.ipynb new file mode 100644 index 0000000..7d06d6d --- /dev/null +++ b/Untitled.ipynb @@ -0,0 +1,33 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "92a81cba-8dcd-4396-908d-79e2e21f2905", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..2e8f493 --- /dev/null +++ b/setup.py @@ -0,0 +1,7 @@ +from setuptools import setup, find_packages + +setup( + name='wbia_miew_id', + version='0.1.1', + packages=find_packages(), +) diff --git a/wbia_miew_id/__init__.py b/wbia_miew_id/__init__.py index 7406388..c57bfd5 100644 --- a/wbia_miew_id/__init__.py +++ b/wbia_miew_id/__init__.py @@ -1,4 +1 @@ -from wbia_miew_id import _plugin # NOQA - - __version__ = '0.0.0' diff --git a/wbia_miew_id/_plugin.py b/wbia_miew_id/_plugin.py index 66f3f5b..fbba1c8 100644 --- a/wbia_miew_id/_plugin.py +++ b/wbia_miew_id/_plugin.py @@ -22,7 +22,7 @@ from wbia_miew_id.models import get_model from wbia_miew_id.datasets import PluginDataset, get_test_transforms from wbia_miew_id.metrics import pred_light, compute_distance_matrix, eval_onevsall -from wbia_miew_id.visualization import draw_one, draw_batch +from wbia_miew_id.visualization import draw_batch (print, rrr, profile) = ut.inject2(__name__) @@ -282,30 +282,6 @@ def render_single_result(request, cm, aid, **kwargs): return out_image - # def render_single_result(request, cm, aid, **kwargs): - - # depc = request.depc - # ibs = depc.controller - - # # Load config - # species = ibs.get_annot_species_texts(aid) - - # config = None - # if config is None: - # config = CONFIGS[species] - # config = _load_config(config) - - # # Load model - # model = _load_model(config, MODELS[species], use_dataparallel=False) - - # # This list has to be in the format of [query_aid, db_aid] - # aid_list = [cm.qaid, aid] - # test_loader, test_dataset = _load_data(ibs, aid_list, config) - - # out_image = draw_one(config, test_loader, model, images_dir = '', method='gradcam_plus_plus', eigen_smooth=False, show=False) - - # return out_image - def render_batch_result(request, cm, aids): depc = request.depc @@ -476,7 +452,7 @@ def _load_data(ibs, aid_list, config, multithread=False): Load data, preprocess and create data loaders """ - test_transform = get_test_transforms(config) + test_transform = get_test_transforms((config.data.image_size[0], config.data.image_size[1])) image_paths = ibs.get_annot_image_paths(aid_list) bboxes = ibs.get_annot_bboxes(aid_list) names = ibs.get_annot_name_rowids(aid_list) diff --git a/wbia_miew_id/datasets/default_dataset.py b/wbia_miew_id/datasets/default_dataset.py index 9820fa4..20c2202 100644 --- a/wbia_miew_id/datasets/default_dataset.py +++ b/wbia_miew_id/datasets/default_dataset.py @@ -9,13 +9,14 @@ class MiewIdDataset(Dataset): - def __init__(self, csv, transforms=None, fliplr=False, fliplr_view=[], crop_bbox=False): + def __init__(self, csv, transforms=None, fliplr=False, fliplr_view=[], crop_bbox=False, n_train_classes=None): self.csv = csv#.reset_index() self.augmentations = transforms 
self.fliplr = fliplr self.fliplr_view = fliplr_view self.crop_bbox = crop_bbox + self.n_train_classes = n_train_classes def __len__(self): return self.csv.shape[0] diff --git a/wbia_miew_id/datasets/transforms.py b/wbia_miew_id/datasets/transforms.py index 060d25a..4d71c76 100644 --- a/wbia_miew_id/datasets/transforms.py +++ b/wbia_miew_id/datasets/transforms.py @@ -46,24 +46,18 @@ def __init__(self, p): def apply(self, img , **params): return triangle(img , self.p) -def get_train_transforms(config): +def get_train_transforms(image_size): return albumentations.Compose( - [ Triangle(p = 0.5), - # albumentations.Resize(config.data.image_size[0],config.data.image_size[1],always_apply=True), - PyTorchResize(config.data.image_size[0], config.data.image_size[1], always_apply=True), - # albumentations.HorizontalFlip(p=0.5), - #albumentations.VerticalFlip(p=0.5), - #albumentations.ImageCompression (quality_lower=50, quality_upper=100, p = 0.5), + [ + Triangle(p=0.5), + PyTorchResize(image_size[0], image_size[1], always_apply=True), albumentations.OneOf([ - albumentations.Sharpen(p=0.3), - albumentations.ToGray(p=0.3), - albumentations.CLAHE(p=0.3), + albumentations.Sharpen(p=0.3), + albumentations.ToGray(p=0.3), + albumentations.CLAHE(p=0.3), ], p=0.5), - #albumentations.Rotate(limit=30, p=0.8), - #albumentations.RandomBrightness(limit=(0.09, 0.6), p=0.7), - # albumentations.Cutout(num_holes=8, max_h_size=8, max_w_size=8, fill_value=0, always_apply=False, p=0.3), albumentations.ShiftScaleRotate( - shift_limit=0.25, scale_limit=0.2, rotate_limit=15,p = 0.5 + shift_limit=0.25, scale_limit=0.2, rotate_limit=15, p=0.5 ), albumentations.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1), albumentations.Normalize(), @@ -71,23 +65,20 @@ def get_train_transforms(config): ] ) -def get_valid_transforms(config): - +def get_valid_transforms(image_size): return albumentations.Compose( [ - albumentations.Resize(config.data.image_size[0],config.data.image_size[1],always_apply=True), + albumentations.Resize(image_size[0], image_size[1], always_apply=True), albumentations.Normalize(), - ToTensorV2(p=1.0) + ToTensorV2(p=1.0) ] ) - -def get_test_transforms(config): - +def get_test_transforms(image_size): return albumentations.Compose( [ - albumentations.Resize(config.data.image_size[0],config.data.image_size[1],always_apply=True), + albumentations.Resize(image_size[0], image_size[1], always_apply=True), albumentations.Normalize(), - ToTensorV2(p=1.0) + ToTensorV2(p=1.0) ] - ) + ) \ No newline at end of file diff --git a/wbia_miew_id/engine/__init__.py b/wbia_miew_id/engine/__init__.py index 1e10306..2255a07 100644 --- a/wbia_miew_id/engine/__init__.py +++ b/wbia_miew_id/engine/__init__.py @@ -1,4 +1,3 @@ from .train_fn import * from .eval_fn import * -from .run_fn import * from .group_eval import * \ No newline at end of file diff --git a/wbia_miew_id/engine/eval_fn.py b/wbia_miew_id/engine/eval_fn.py index 1871acc..20447d0 100644 --- a/wbia_miew_id/engine/eval_fn.py +++ b/wbia_miew_id/engine/eval_fn.py @@ -4,8 +4,7 @@ import numpy as np import wandb -from metrics import AverageMeter, compute_distance_matrix, compute_calibration, eval_onevsall, topk_average_precision, precision_at_k, get_accuracy -from helpers.swatools import extract_outputs +from wbia_miew_id.metrics import AverageMeter, compute_distance_matrix, compute_calibration, eval_onevsall, topk_average_precision, precision_at_k, get_accuracy from torch.cuda.amp import autocast def extract_embeddings(data_loader, model, device): diff --git 
a/wbia_miew_id/engine/group_eval.py b/wbia_miew_id/engine/group_eval.py index f298eeb..7229f5e 100644 --- a/wbia_miew_id/engine/group_eval.py +++ b/wbia_miew_id/engine/group_eval.py @@ -1,10 +1,11 @@ import torch import numpy as np -from datasets import MiewIdDataset, get_test_transforms -from .eval_fn import eval_fn, log_results -from etl import filter_min_names_df, subsample_max_df, preprocess_data +from wbia_miew_id.datasets import MiewIdDataset, get_test_transforms +from wbia_miew_id.engine import eval_fn, log_results +from wbia_miew_id.etl import filter_min_names_df, subsample_max_df, preprocess_data -def group_eval(config, df_test, eval_groups, model): + +def group_eval_run(df_test, eval_groups, model, n_filter_min, n_subsample_max, image_size, fliplr, fliplr_view, crop_bbox, valid_batch_size, device): print("** Calculating groupwise evaluation scores **") @@ -13,30 +14,28 @@ def group_eval(config, df_test, eval_groups, model): for group, df_group in df_test.groupby(eval_group): try: print('* Evaluating group:', group) - n_filter_min = config.data.test.n_filter_min if n_filter_min: print(len(df_group)) df_group = filter_min_names_df(df_group, n_filter_min) - n_subsample_max = config.data.test.n_subsample_max if n_subsample_max: df_group = subsample_max_df(df_group, n_subsample_max) test_dataset = MiewIdDataset( csv=df_group, - transforms=get_test_transforms(config), - fliplr=config.test.fliplr, - fliplr_view=config.test.fliplr_view, - crop_bbox=config.data.crop_bbox, + transforms=get_test_transforms((image_size[0], image_size[1])), + fliplr=fliplr, + fliplr_view=fliplr_view, + crop_bbox=crop_bbox, ) test_loader = torch.utils.data.DataLoader( test_dataset, - batch_size=config.engine.valid_batch_size, + batch_size=valid_batch_size, num_workers=0, shuffle=False, pin_memory=True, drop_last=False, ) - device = torch.device(config.engine.device) + device = torch.device(device) test_score, test_cmc, test_outputs = eval_fn(test_loader, model, device, use_wandb=False, return_outputs=True) except Exception as E: print('* Could not evaluate group:', group) @@ -49,6 +48,7 @@ def group_eval(config, df_test, eval_groups, model): return group_results + def group_eval_fn(config, eval_groups, model, use_wandb=True): print('Evaluating on groups') df_test_group = preprocess_data(config.data.test.anno_path, @@ -59,7 +59,15 @@ def group_eval_fn(config, eval_groups, model, use_wandb=True): n_subsample_max=None, use_full_image_path=config.data.use_full_image_path, images_dir = config.data.images_dir) - group_results = group_eval(config, df_test_group, eval_groups, model) + group_results = group_eval_run(df_test_group, eval_groups, model, + n_filter_min = config.data.test.n_filter_min, + n_subsample_max = config.data.test.n_subsample_max, + image_size = (config.data.image_size[0], config.data.image_size[1]), + fliplr = config.test.fliplr, + fliplr_view = config.test.fliplr_view, + crop_bbox = config.data.crop_bbox, + valid_batch_size = config.engine.valid_batch_size, + device = config.engine.device) group_scores = [] group_cmcs = [] diff --git a/wbia_miew_id/engine/run_fn.py b/wbia_miew_id/engine/run_fn.py deleted file mode 100644 index 896890e..0000000 --- a/wbia_miew_id/engine/run_fn.py +++ /dev/null @@ -1,61 +0,0 @@ -import torch -from tabulate import tabulate - -from .train_fn import train_fn -from .eval_fn import eval_fn -from .group_eval import group_eval_fn -from helpers.swatools import update_bn - -def run_fn(config, model, train_loader, valid_loader, criterion, optimizer, scheduler, device, 
checkpoint_dir, use_wandb=True, swa_model=None, swa_start=None, swa_scheduler=None): - - best_score = 0 - - #### To load the checkpoint - # optimizer.load_state_dict(checkpoint['optimizer_state_dict']) - # scheduler.load_state_dict(checkpoint['scheduler_state_dict']) # If you're using a scheduler - # best_score = checkpoint['best_score'] - # start_epoch = checkpoint['epoch'] + 1 # Resume from the next epoch - - - for epoch in range(config.engine.epochs): - train_loss = train_fn(train_loader, model,criterion, optimizer, device,scheduler=scheduler,epoch=epoch, use_wandb=use_wandb, swa_model=swa_model, swa_start=swa_start, swa_scheduler=swa_scheduler) - - torch.save({ - 'epoch': epoch, - 'model_state_dict': model.state_dict(), - 'optimizer_state_dict': optimizer.state_dict(), - 'scheduler_state_dict': scheduler.state_dict(), # If you're using a scheduler - 'best_score': best_score, - }, f'{checkpoint_dir}/checkpoint_latest.bin') - - # print("\nGetting metrics on train set...") - # train_score, train_cmc = eval_fn(train_loader, model, device, use_wandb=use_wandb, return_outputs=False) - - print("\nGetting metrics on validation set...") - eval_groups = config.data.test.eval_groups - - if eval_groups: - valid_score, valid_cmc = group_eval_fn(config, eval_groups, model) - print('Group average score: ', valid_score) - - else: - print('Evaluating on full test set') - valid_score, valid_cmc = eval_fn(valid_loader, model, device, use_wandb=use_wandb, return_outputs=False) - print('Valid score: ', valid_score) - - # print("\n") - # print(tabulate([["Train", 0], ["Valid", valid_score]], headers=["Split", "mAP"])) - # print("\n\n") - - if valid_score > best_score: - best_score = valid_score - torch.save(model.state_dict(), f'{checkpoint_dir}/model_best.bin') - print('best model found for epoch {}'.format(epoch)) - - # Update bn statistics for the swa_model at the end - if swa_model: - print("Updating SWA batchnorm statistics...") - update_bn(train_loader, swa_model, device=device) - torch.save(swa_model.state_dict(), f'{checkpoint_dir}/swa_model_{epoch}.bin') - - return best_score diff --git a/wbia_miew_id/engine/train_fn.py b/wbia_miew_id/engine/train_fn.py index 0d54047..4257e0a 100644 --- a/wbia_miew_id/engine/train_fn.py +++ b/wbia_miew_id/engine/train_fn.py @@ -2,7 +2,7 @@ import matplotlib.pyplot as plt from tqdm.auto import tqdm import wandb -from metrics import AverageMeter, compute_calibration +from wbia_miew_id.metrics import AverageMeter, compute_calibration def train_fn(dataloader, model, criterion, optimizer, device, scheduler, epoch, use_wandb=True, swa_model=None, swa_start=None, swa_scheduler=None): diff --git a/wbia_miew_id/etl/images.py b/wbia_miew_id/etl/images.py index 77eed6d..00c034b 100644 --- a/wbia_miew_id/etl/images.py +++ b/wbia_miew_id/etl/images.py @@ -4,8 +4,8 @@ from PIL import Image from concurrent.futures import ProcessPoolExecutor, as_completed from tqdm.auto import tqdm -from datasets import get_chip_from_img -from etl import preprocess_data +from wbia_miew_id.datasets import get_chip_from_img +from wbia_miew_id.etl import preprocess_data from torchvision import transforms def process_image(row, crop_bbox, preprocess_dir, chip_idx, target_size): diff --git a/wbia_miew_id/etl/preprocess.py b/wbia_miew_id/etl/preprocess.py index 87ded1f..3d04027 100644 --- a/wbia_miew_id/etl/preprocess.py +++ b/wbia_miew_id/etl/preprocess.py @@ -10,7 +10,7 @@ def load_json(file_path): data = json.load(f) return data -def load_to_df(anno_path): +def load_json_to_df(anno_path): data = 
load_json(anno_path) dfa = pd.DataFrame(data['annotations']) @@ -30,6 +30,12 @@ def load_to_df(anno_path): return df +def load_to_df(anno_path): + df = pd.read_csv(anno_path) + df['bbox'] = df[['x', 'y', 'w', 'h']].values.tolist() + + return df + def filter_viewpoint_df(df, viewpoint_list): df = df[df['viewpoint'].isin(viewpoint_list)] print(' ', len(df), 'annotations remain after filtering by viewpoint list', viewpoint_list) @@ -67,8 +73,12 @@ def filter_df(df, viewpoint_list, n_filter_min, n_subsample_max, filter_key='nam def preprocess_data(anno_path, name_keys=['name'], convert_names_to_ids=True, viewpoint_list=None, n_filter_min=None, n_subsample_max=None, use_full_image_path=False, images_dir=None): - df = load_to_df(anno_path) - + if anno_path.lower().endswith('json'): + df = load_json_to_df(anno_path) + elif anno_path.lower().endswith('csv'): + df = load_to_df(anno_path) + else: + raise NotImplementedError("Annotation file extension not supported.") df['name'] = df[name_keys].apply(lambda row: '_'.join(row.values.astype(str)), axis=1) @@ -77,6 +87,7 @@ def preprocess_data(anno_path, name_keys=['name'], convert_names_to_ids=True, vi df['name_species'] = df['name'] + '_' + df['species'] filter_key = 'name_species' else: + df['species'] = 'default_species' filter_key = 'name' df['name_orig'] = df['name'].copy() diff --git a/wbia_miew_id/evaluate.py b/wbia_miew_id/evaluate.py new file mode 100644 index 0000000..4d23c98 --- /dev/null +++ b/wbia_miew_id/evaluate.py @@ -0,0 +1,205 @@ +from wbia_miew_id.datasets import MiewIdDataset, get_train_transforms, get_valid_transforms, get_test_transforms +from wbia_miew_id.logging_utils import WandbContext +from wbia_miew_id.models import MiewIdNet +from wbia_miew_id.etl import preprocess_data, print_basic_stats +from wbia_miew_id.engine import eval_fn, group_eval_run +from wbia_miew_id.helpers import get_config +from wbia_miew_id.visualization import render_query_results +from wbia_miew_id.metrics import precision_at_k + +import os +import torch +import random +import numpy as np + +import argparse + +def parse_args(): + parser = argparse.ArgumentParser(description="Load configuration file.") + parser.add_argument( + '--config', + type=str, + default='configs/default_config.yaml', + help='Path to the YAML configuration file. 
Default: configs/default_config.yaml' + ) + + parser.add_argument('--visualize', '--vis', action='store_true') + + return parser.parse_args() + +class Evaluator: + def __init__(self, device, seed, anno_path, name_keys, viewpoint_list, use_full_image_path, images_dir, image_size, + crop_bbox, valid_batch_size, num_workers, eval_groups, fliplr, fliplr_view, n_filter_min, n_subsample_max, + model_params=None, checkpoint_path=None, model=None, visualize=False, visualization_output_dir='miewid_visualizations'): + self.device = device + self.visualize = visualize + self.seed = seed + self.model_params = model_params + self.checkpoint_path = checkpoint_path + self.anno_path = anno_path + self.name_keys = name_keys + self.viewpoint_list = viewpoint_list + self.use_full_image_path = use_full_image_path + self.images_dir = images_dir + self.image_size = image_size + self.crop_bbox = crop_bbox + self.valid_batch_size = valid_batch_size + self.num_workers = num_workers + self.eval_groups = eval_groups + self.fliplr = fliplr + self.fliplr_view = fliplr_view + self.n_filter_min = n_filter_min + self.n_subsample_max = n_subsample_max + self.visualization_output_dir = visualization_output_dir + + self.set_seed_torch(seed) + + if model is not None: + self.model = model.to(device) + else: + self.model = self.load_model(device, model_params, checkpoint_path) + + @staticmethod + def set_seed_torch(seed): + random.seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.backends.cudnn.deterministic = True + + @staticmethod + def load_model(device, model_params, checkpoint_path): + model = MiewIdNet(**model_params) + model.to(device) + + if checkpoint_path: + weights = torch.load(checkpoint_path, map_location=device) + n_train_classes = weights[list(weights.keys())[-1]].shape[-1] + if model_params['n_classes'] != n_train_classes: + print(f"WARNING: Overriding n_classes in config ({model_params['n_classes']}) which is different from actual n_train_classes in the checkpoint - ({n_train_classes}).") + model_params['n_classes'] = n_train_classes + model.load_state_dict(weights, strict=False) + print('loaded checkpoint from', checkpoint_path) + + return model + + @staticmethod + def preprocess_test_data(anno_path, name_keys, viewpoint_list, use_full_image_path, + images_dir, image_size, crop_bbox, valid_batch_size, num_workers, + fliplr, fliplr_view, n_filter_min, n_subsample_max): + df_test = preprocess_data( + anno_path, + name_keys=name_keys, + convert_names_to_ids=True, + viewpoint_list=viewpoint_list, + n_filter_min=n_filter_min, + n_subsample_max=n_subsample_max, + use_full_image_path=use_full_image_path, + images_dir=images_dir, + ) + + test_dataset = MiewIdDataset( + csv=df_test, + transforms=get_test_transforms((image_size[0], image_size[1])), + fliplr=fliplr, + fliplr_view=fliplr_view, + crop_bbox=crop_bbox, + ) + + test_loader = torch.utils.data.DataLoader( + test_dataset, + batch_size=valid_batch_size, + num_workers=num_workers, + shuffle=False, + pin_memory=True, + drop_last=False, + ) + + return test_loader, df_test + + @staticmethod + def evaluate_groups(self, eval_groups, anno_path, name_keys, viewpoint_list, + use_full_image_path, images_dir, model): + df_test_group = preprocess_data( + anno_path, + name_keys=name_keys, + convert_names_to_ids=True, + viewpoint_list=viewpoint_list, + n_filter_min=None, + n_subsample_max=None, + use_full_image_path=use_full_image_path, + images_dir=images_dir + ) + group_results = 
group_eval_run(df_test_group, eval_groups, model, + n_filter_min = self.n_filter_min, + n_subsample_max = self.n_subsample_max, + image_size = self.image_size, + fliplr = self.fliplr, + fliplr_view = self.fliplr_view, + crop_bbox = self.crop_bbox, + valid_batch_size = self.valid_batch_size, + device = self.device) + + @staticmethod + def visualize_results(test_outputs, df_test, test_dataset, model, device, k=5, valid_batch_size=2, output_dir='miewid_visualizations'): + embeddings, q_pids, distmat = test_outputs + ranks = list(range(1, k+1)) + score, match_mat, topk_idx, topk_names = precision_at_k(q_pids, distmat, ranks=ranks, return_matches=True) + match_results = (q_pids, topk_idx, topk_names, match_mat) + render_query_results(model, test_dataset, df_test, match_results, device, + k=k, valid_batch_size=valid_batch_size, output_dir=output_dir) + + def evaluate(self): + test_loader, df_test = self.preprocess_test_data( + self.anno_path, self.name_keys, self.viewpoint_list, + self.use_full_image_path, self.images_dir, self.image_size, self.crop_bbox, + self.valid_batch_size, self.num_workers, self.fliplr, + self.fliplr_view, self.n_filter_min, self.n_subsample_max + ) + test_score, cmc, test_outputs = eval_fn(test_loader, self.model, self.device, use_wandb=False, return_outputs=True) + + if self.eval_groups: + self.evaluate_groups(self, + self.eval_groups, self.anno_path, self.name_keys, + self.viewpoint_list, self.use_full_image_path, + self.images_dir, self.model + ) + + if self.visualize: + self.visualize_results(test_outputs, df_test, test_loader.dataset, self.model, self.device, + k=5, valid_batch_size=self.valid_batch_size,output_dir=self.visualization_output_dir ) + + return test_score + +if __name__ == '__main__': + args = parse_args() + config = get_config(args.config) + + visualization_output_dir = f"{config.checkpoint_dir}/{config.project_name}/{config.exp_name}/visualizations" + + evaluator = Evaluator( + device=torch.device(config.engine.device), + seed=config.engine.seed, + anno_path=config.data.test.anno_path, + name_keys=config.data.name_keys, + viewpoint_list=config.data.viewpoint_list, + use_full_image_path=config.data.use_full_image_path, + images_dir=config.data.images_dir, + image_size=(config.data.image_size[0], config.data.image_size[1]), + crop_bbox=config.data.crop_bbox, + valid_batch_size=config.engine.valid_batch_size, + num_workers=config.engine.num_workers, + eval_groups=config.data.test.eval_groups, + fliplr=config.test.fliplr, + fliplr_view=config.test.fliplr_view, + n_filter_min=config.data.test.n_filter_min, + n_subsample_max=config.data.test.n_subsample_max, + model_params=dict(config.model_params), + checkpoint_path=config.data.test.checkpoint_path, + model=None, + visualize=args.visualize, + visualization_output_dir=visualization_output_dir +) + + evaluator.evaluate() \ No newline at end of file diff --git a/wbia_miew_id/examples/download_example.py b/wbia_miew_id/examples/download_example.py new file mode 100644 index 0000000..abc849f --- /dev/null +++ b/wbia_miew_id/examples/download_example.py @@ -0,0 +1,50 @@ +import requests +import tarfile +import os + +def download_file(url, output_path): + response = requests.get(url, stream=True) + response.raise_for_status() + + with open(output_path, 'wb') as file: + for chunk in response.iter_content(chunk_size=8192): + file.write(chunk) + +def extract_tarfile(tar_path, extract_to): + with tarfile.open(tar_path, "r:gz") as tar: + top_level_dir = os.path.commonprefix(tar.getnames()) + + for member in 
tar.getmembers(): + member_path = os.path.join(extract_to, os.path.relpath(member.name, top_level_dir)) + if member.isdir(): + if not os.path.isdir(member_path): + os.makedirs(member_path) + else: + if not os.path.isdir(os.path.dirname(member_path)): + os.makedirs(os.path.dirname(member_path)) + with open(member_path, 'wb') as f: + f.write(tar.extractfile(member).read()) + +def main(): + url = "https://cthulhu.dyn.wildme.io/public/datasets/beluga_example_miewid.tar.gz" + + script_dir = os.path.dirname(os.path.realpath(__file__)) + tar_path = os.path.join(script_dir, "beluga_example_miewid.tar.gz") + extract_to = os.path.join(script_dir, "beluga_example_miewid") + + print(f"Downloading {url} to {tar_path}...") + download_file(url, tar_path) + print(f"Downloaded to {tar_path}") + + if not os.path.exists(extract_to): + os.makedirs(extract_to) + + print(f"Extracting {tar_path} to {extract_to}...") + extract_tarfile(tar_path, extract_to) + print("Extraction completed") + + os.remove(tar_path) + print(f"Removed {tar_path}") + +if __name__ == "__main__": + main() diff --git a/wbia_miew_id/examples/extract_and_evaluate.ipynb b/wbia_miew_id/examples/extract_and_evaluate.ipynb new file mode 100644 index 0000000..d3cbd83 --- /dev/null +++ b/wbia_miew_id/examples/extract_and_evaluate.ipynb @@ -0,0 +1,480 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 12, + "id": "2fbe73d6-5142-4b70-a2d7-d6b56a4c5482", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "from wbia_miew_id.evaluate import Evaluator\n", + "\n", + "from transformers import AutoModel" + ] + }, + { + "cell_type": "markdown", + "id": "bcac49ce-6a13-4a5f-a079-a1e2023927f8", + "metadata": {}, + "source": [ + "### Evaluate using local checkpoint" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "id": "6c0672ef-8f4c-420d-9703-b872d2b5f0f3", + "metadata": {}, + "outputs": [], + "source": [ + "model_params = {\n", + " 'model_name': 'efficientnetv2_rw_m',\n", + " 'use_fc': False,\n", + " 'fc_dim': 2048,\n", + " 'dropout': 0,\n", + " 'loss_module': 'arcface_subcenter_dynamic',\n", + " 's': 51.960399844266306,\n", + " 'margin': 0.32841442327915477,\n", + " 'pretrained': True,\n", + " 'n_classes': 11968,\n", + " 'k': 3\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "id": "97a0d91e-7bff-41e8-85d4-20839c471992", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Building Model Backbone for efficientnetv2_rw_m model\n", + "loaded checkpoint from beluga_example_miewid/benchmark_model/miew_id.msv2_all.bin\n" + ] + } + ], + "source": [ + "evaluator = Evaluator(\n", + " device=torch.device('cuda'),\n", + " seed=0,\n", + " anno_path='beluga_example_miewid/benchmark_splits/test.csv',\n", + " name_keys=['name'],\n", + " viewpoint_list=None,\n", + " use_full_image_path=True,\n", + " images_dir=None,\n", + " image_size=(440, 440),\n", + " crop_bbox=True,\n", + " valid_batch_size=12,\n", + " num_workers=8,\n", + " eval_groups=[['species', 'viewpoint']],\n", + " fliplr=False,\n", + " fliplr_view=[],\n", + " n_filter_min=2,\n", + " n_subsample_max=10,\n", + " model_params=model_params,\n", + " checkpoint_path='beluga_example_miewid/benchmark_model/miew_id.msv2_all.bin',\n", + " model=None,\n", + " visualize=False,\n", + " visualization_output_dir='beluga_example_visualizations'\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "id": "7a7af674-7f28-4d1a-916e-262437efc19c", + "metadata": {}, + "outputs": [ + { + 
"name": "stdout", + "output_type": "stream", + "text": [ + " 849 annotations remain after filtering by min 2 per name_species\n", + " 849 annotations remain after subsampling by max 10 per name_species\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "4d4ec0443fb64ff68a45a93b784b5957", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/71 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Tracking run with wandb version 0.16.5" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Run data is saved locally in /home/wbia-plugin-miew-id/wbia_miew_id/examples/wandb/run-20240706_004556-psep5t4t" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Syncing run beluga-example-exp-1 to Weights & Biases (docs)
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View project at https://wandb.ai/lashao/miewid-training" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run at https://wandb.ai/lashao/miewid-training/runs/psep5t4t/workspace" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "a0cce90f2eeb4f1b953b07f2ccd0ba03", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/172 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "

Run history:

Avg - Rank-1                ▁▂▅▇█
Avg - Rank-10               ▁▄▆▇█
Avg - Rank-20               ▁▄▆▇█
Avg - Rank-5                ▁▄▆▇█
Avg - mAP                   ▁▂▄▆█
beluga_whale-up - Rank-1    ▁▂▅▇█
beluga_whale-up - Rank-10   ▁▄▆▇█
beluga_whale-up - Rank-20   ▁▄▆▇█
beluga_whale-up - Rank-5    ▁▄▆▇█
beluga_whale-up - mAP       ▁▂▄▆█
epoch                       ▁▃▅▆█
lr                          ▁▃▄▆█
train loss                  ██▆▄▁

Run summary:

Avg - Rank-1                0.53946
Avg - Rank-10               0.72556
Avg - Rank-20               0.78916
Avg - Rank-5                0.66902
Avg - mAP                   0.42436
beluga_whale-up - Rank-1    0.53946
beluga_whale-up - Rank-10   0.72556
beluga_whale-up - Rank-20   0.78916
beluga_whale-up - Rank-5    0.66902
beluga_whale-up - mAP       0.42436
epoch                       4
lr                          0.00061
train loss                  10.53451

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run beluga-example-exp-1 at: https://wandb.ai/lashao/miewid-training/runs/psep5t4t/workspace
Synced 6 W&B file(s), 0 media file(s), 2 artifact file(s) and 0 other file(s)" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Find logs at: ./wandb/run-20240706_004556-psep5t4t/logs" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "0.42436280846595764" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "trainer = Trainer(config)\n", + "trainer.run()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/wbia_miew_id/examples/split_dataset.ipynb b/wbia_miew_id/examples/split_dataset.ipynb new file mode 100644 index 0000000..ab1f59e --- /dev/null +++ b/wbia_miew_id/examples/split_dataset.ipynb @@ -0,0 +1,291 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "2bb57c03-d165-422e-afa7-021856dbb0d4", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import pandas as pd\n", + "\n", + "from wbia_miew_id.helpers import split_df" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "4ccf6f5b-a8b5-43cb-94ee-e2946dbfce0f", + "metadata": {}, + "outputs": [], + "source": [ + "df_annot = pd.read_csv('beluga_example_miewid/annotations.csv')" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "820b50cd-d103-419d-ab78-8c5a473ca19e", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
|   |theta|viewpoint|name|file_name       |species     |file_path                            |x  |y  |w  |h  |
|---|-----|---------|----|----------------|------------|-------------------------------------|---|---|---|---|
|0  |0    |right    |411 |000000000001.jpg|beluga_whale|/datasets/beluga-440/000000000001.jpg|0  |0  |70 |440|
|1  |0    |right    |698 |000000000002.jpg|beluga_whale|/datasets/beluga-440/000000000002.jpg|0  |0  |91 |440|
|2  |0    |right    |700 |000000000003.jpg|beluga_whale|/datasets/beluga-440/000000000003.jpg|0  |0  |93 |440|
|3  |0    |right    |340 |000000000008.jpg|beluga_whale|/datasets/beluga-440/000000000008.jpg|0  |0  |113|440|
|4  |0    |right    |340 |000000000009.jpg|beluga_whale|/datasets/beluga-440/000000000009.jpg|0  |0  |102|440|
\n", + "
" + ], + "text/plain": [ + " theta viewpoint name file_name species \\\n", + "0 0 right 411 000000000001.jpg beluga_whale \n", + "1 0 right 698 000000000002.jpg beluga_whale \n", + "2 0 right 700 000000000003.jpg beluga_whale \n", + "3 0 right 340 000000000008.jpg beluga_whale \n", + "4 0 right 340 000000000009.jpg beluga_whale \n", + "\n", + " file_path x y w h \n", + "0 /datasets/beluga-440/000000000001.jpg 0 0 70 440 \n", + "1 /datasets/beluga-440/000000000002.jpg 0 0 91 440 \n", + "2 /datasets/beluga-440/000000000003.jpg 0 0 93 440 \n", + "3 /datasets/beluga-440/000000000008.jpg 0 0 113 440 \n", + "4 /datasets/beluga-440/000000000009.jpg 0 0 102 440 " + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_annot.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "635a5c6a-a683-4ef0-94b5-214646390d3f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Filtering...\n", + "Before filtering: 6055 annotations\n", + "After filtering: 5906 annotations\n", + "\n", + "--------------------------------------------------\n", + "\n", + "Calculating stats for combined subsets\n", + "** cross-set stats **\n", + "\n", + "- Counts: \n", + "Number of annotations - total: 5906\n", + "number of individuals in train: 370\n", + "number of annotations in train: 3999\n", + "\n", + "number of individuals in test: 263\n", + "number of annotations in test: 933\n", + "\n", + "number of individuals in val: 263\n", + "number of annotations in val: 974\n", + "\n", + "train ratio: 0.6771080257365392\n", + "average number of annotations per individual in train: 10.81\n", + "average number of annotations per individual in test: 3.55\n", + "average number of annotations per individual in val: 3.70\n", + "\n", + "- New individuals: \n", + "number of new (unseen) individuals in test: 133\n", + "ratio of new names to all individuals in test: 0.51\n", + "\n", + "number of new (unseen) individuals in val: 130\n", + "ratio of new names to all individuals in val: 0.49\n", + "- Individuals in sets: \n", + "number of overlapping individuals in train & test: 130\n", + "ratio of overlapping names to total individuals in train: 0.35\n", + "ratio of overlapping names to total individuals in test: 0.49\n", + "Number of annotations in train for overlapping individuals with test: 565\n", + "Number of annotations in test for overlapping individuals with train: 569\n", + "ratio of annotations in test for overlapping individuals with train: 0.5017636684303352\n", + "number of overlapping individuals in train & val: 133\n", + "ratio of overlapping names to total individuals in train: 0.36\n", + "ratio of overlapping names to total individuals in val: 0.51\n", + "Number of annotations in train for overlapping individuals with val: 624\n", + "Number of annotations in val for overlapping individuals with train: 621\n", + "ratio of annotations in val for overlapping individuals with train: 0.4987951807228916\n" + ] + } + ], + "source": [ + "df_tr, df_te, df_val = split_df(df_annot, train_ratio=0.7, unseen_ratio=0.5, is_val=True, stratify_col='name', print_key='name', verbose=True, random_state=0)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "923201e1-54e4-42db-8822-351d0f175734", + "metadata": {}, + "outputs": [], + "source": [ + "os.makedirs('beluga_example_miewid/splits', exist_ok=True)\n", + "\n", + "df_tr.to_csv('beluga_example_miewid/splits/train.csv')\n", + 
"df_val.to_csv('beluga_example_miewid/splits/val.csv')\n", + "df_te.to_csv('beluga_example_miewid/splits/test.csv')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "65ecbd63-31e6-4a7f-8077-f788ab59a195", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ea84ddb6-c44c-46ee-b360-c825ebc9d1e7", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5899a809-0d41-4c68-9720-b4ec123e9ece", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2c91530d-d0be-4f48-a02c-844a034551ad", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/wbia_miew_id/helpers/__init__.py b/wbia_miew_id/helpers/__init__.py index 3bd0734..0212ae6 100644 --- a/wbia_miew_id/helpers/__init__.py +++ b/wbia_miew_id/helpers/__init__.py @@ -1,3 +1,5 @@ from .config import * from .getters import * -from .tools import * \ No newline at end of file +from .tools import * +from .split import * +from .swatools import * \ No newline at end of file diff --git a/wbia_miew_id/helpers/config.py b/wbia_miew_id/helpers/config.py index 946fbc4..337ee79 100644 --- a/wbia_miew_id/helpers/config.py +++ b/wbia_miew_id/helpers/config.py @@ -176,4 +176,52 @@ def get_config(file_path: str) -> Config: def write_config(config: Config, file_path: str): config_dict = dict(config) with open(file_path, 'w') as file: - yaml.dump(config_dict, file) \ No newline at end of file + yaml.dump(config_dict, file) + +def yaml_to_formatted_string(file_path): + """ + Convert a YAML file to a formatted string. + + Args: + file_path (str): The path to the YAML file. + + Returns: + str: A formatted string representation of the YAML data. + """ + try: + # Read the YAML file + with open(file_path, 'r') as file: + data = yaml.safe_load(file) + # Convert the dictionary to a formatted string + formatted_str = yaml.dump(data, sort_keys=False, default_flow_style=False) + return formatted_str + except FileNotFoundError: + return f"Error: File not found - {file_path}" + except yaml.YAMLError as e: + return f"Error parsing YAML: {e}" + +def formatted_string_to_yaml(formatted_str, output_path): + """ + Convert a formatted string to a YAML string and write it to a file. + + Args: + formatted_str (str): A formatted string representation of data. + output_path (str): The path where the YAML file will be saved. + + Returns: + str: A message indicating success or failure. 
+ """ + try: + # Load the formatted string into a Python dictionary + data = yaml.safe_load(formatted_str) + # Convert the dictionary back to a YAML string + yaml_str = yaml.dump(data, sort_keys=False, default_flow_style=False) + # Write the YAML string to the specified output file + with open(output_path, 'w') as file: + file.write(yaml_str) + return f"YAML successfully written to {output_path}" + except yaml.YAMLError as e: + return f"Error parsing formatted string: {e}" + except IOError as e: + return f"Error writing to file: {e}" + diff --git a/wbia_miew_id/helpers/split/__init__.py b/wbia_miew_id/helpers/split/__init__.py new file mode 100644 index 0000000..1bee67a --- /dev/null +++ b/wbia_miew_id/helpers/split/__init__.py @@ -0,0 +1,3 @@ +from .split import * +from .stats import * +from .tools import * \ No newline at end of file diff --git a/wbia_miew_id/helpers/split/split.py b/wbia_miew_id/helpers/split/split.py index d4557b0..cad2cfc 100755 --- a/wbia_miew_id/helpers/split/split.py +++ b/wbia_miew_id/helpers/split/split.py @@ -1,12 +1,10 @@ import pandas as pd import matplotlib.pyplot as plt -from stats import intersect_stats +from .stats import intersect_stats import scipy from scipy import optimize import numpy as np -from tools import print_div, apply_filters - - +from .tools import print_div, apply_filters def split_classes_objective(r0, w, class_counts, train_ratio, unseen_ratio): """ @@ -30,8 +28,7 @@ def split_classes_objective(r0, w, class_counts, train_ratio, unseen_ratio): full = np.sum(class_counts) return np.abs(train_ratio * full - (train_full + train_part)) - -def split_df(df, train_ratio=0.7, unseen_ratio=0.5, is_val=True, stratify_col='name', print_key='name_viewpoint', verbose=False): +def split_df(df, train_ratio=0.7, unseen_ratio=0.5, is_val=True, stratify_col='name', print_key='name_viewpoint', verbose=False, random_state=None): """ Splits a DataFrame into training, testing (and optionally validation) sets based on specified ratios and stratification column. @@ -43,11 +40,14 @@ def split_df(df, train_ratio=0.7, unseen_ratio=0.5, is_val=True, stratify_col='n stratify_col (str, optional): The column on which to stratify the splits. Defaults to 'name'. print_key (str, optional): Key used for printing statistics if verbose is True. Defaults to 'name_viewpoint'. verbose (bool, optional): If True, prints additional information about the splits. Defaults to False. + random_state (int, optional): Seed for random number generator. Defaults to None. Returns: tuple: Depending on 'is_val', returns a tuple of (train_df, test_df) or (train_df, test_df, val_df). - """ + + if random_state is not None: + np.random.seed(random_state) # Assertions to check validity of ratio inputs assert (train_ratio > 0 and train_ratio < 1), "train_ratio must be between 0 and 1." 
@@ -82,7 +82,7 @@ def split_df(df, train_ratio=0.7, unseen_ratio=0.5, is_val=True, stratify_col='n ratios[i0:i1] = w # Perform the stratified split - dfa_train, dfa_test = stratified_split(df, sorted_classes, ratios, stratify_col) + dfa_train, dfa_test = stratified_split(df, sorted_classes, ratios, stratify_col, random_state) # If validation set is not required, return train and test sets if not is_val: @@ -107,9 +107,7 @@ def split_df(df, train_ratio=0.7, unseen_ratio=0.5, is_val=True, stratify_col='n # Return the datasets return dfa_train, dfa_test, dfa_val - - -def stratified_split(df, classes, ratios, class_col, shuffle=True): +def stratified_split(df, classes, ratios, class_col, random_state=None, shuffle=True): """ Perform a stratified split of a DataFrame into training and test sets based on specified classes and ratios. @@ -118,11 +116,15 @@ def stratified_split(df, classes, ratios, class_col, shuffle=True): - classes: List of unique classes used for stratification. - ratios: List of ratios for each class in the split. - class_col: Name of the column containing class labels. + - random_state: Seed for random number generator. Defaults to None. - shuffle: Boolean to control whether to shuffle the indices (default: True). Returns: - Two DataFrames: the training set and the test set. """ + if random_state is not None: + np.random.seed(random_state) + train_indices = np.zeros(0, np.int64) for c, ratio in zip(classes, ratios): indices = np.array((df[df[class_col] == c]).index) @@ -134,7 +136,3 @@ def stratified_split(df, classes, ratios, class_col, shuffle=True): train_df = df.loc[train_indices] test_df = df.drop(train_indices) return train_df, test_df - - - - diff --git a/wbia_miew_id/models/heads.py b/wbia_miew_id/models/heads.py index f437ee8..e692548 100644 --- a/wbia_miew_id/models/heads.py +++ b/wbia_miew_id/models/heads.py @@ -57,8 +57,10 @@ def l2_norm(input, axis = 1): output = torch.div(input, norm) return output + class ElasticArcFace(nn.Module): - def __init__(self, in_features, out_features, s=64.0, m=0.50,std=0.0125,plus=False, k=None): + + def __init__(self, loss_fn, in_features, out_features, s=64.0, m=0.50,std=0.0125,plus=False, k=None): super(ElasticArcFace, self).__init__() self.in_features = in_features self.out_features = out_features @@ -66,8 +68,10 @@ def __init__(self, in_features, out_features, s=64.0, m=0.50,std=0.0125,plus=Fal self.m = m self.kernel = nn.Parameter(torch.FloatTensor(in_features, out_features)) nn.init.normal_(self.kernel, std=0.01) - self.std=std - self.plus=plus + self.std = std + self.plus = plus + self.loss_fn = loss_fn + def forward(self, embbedings, label): embbedings = l2_norm(embbedings, axis=1) kernel_norm = l2_norm(self.kernel, axis=0) @@ -87,7 +91,10 @@ def forward(self, embbedings, label): cos_theta.acos_() cos_theta[index] += m_hot cos_theta.cos_().mul_(self.s) - return cos_theta + + loss = self.loss_fn(cos_theta, label) + + return loss ########## Subcenter Arcface with dynamic margin ########## @@ -112,7 +119,6 @@ def forward(self, features): class ArcFaceLossAdaptiveMargin(nn.modules.Module): def __init__(self, margins, out_dim, s): super().__init__() -# self.crit = nn.CrossEntropyLoss() self.s = s self.register_buffer('margins', torch.tensor(margins)) self.out_dim = out_dim @@ -137,6 +143,7 @@ def forward(self, logits, labels): class ArcFaceSubCenterDynamic(nn.Module): def __init__( self, + loss_fn, embedding_dim, output_classes, margins, @@ -155,7 +162,10 @@ def __init__( out_dim=self.output_classes, s=self.s) + self.loss_fn = 
loss_fn + def forward(self, features, labels): logits = self.wmetric_classify(features.float()) logits = self.warcface_margin(logits, labels) - return logits \ No newline at end of file + loss = self.loss_fn(logits, labels) + return loss \ No newline at end of file diff --git a/wbia_miew_id/models/model.py b/wbia_miew_id/models/model.py index 3a21272..c825f33 100644 --- a/wbia_miew_id/models/model.py +++ b/wbia_miew_id/models/model.py @@ -55,14 +55,8 @@ def __init__(self, use_fc=False, fc_dim=512, dropout=0.0, - loss_module='softmax', - s=30.0, - margin=0.50, - ls_eps=0.0, - theta_zero=0.785, pretrained=True, - margins=None, - k=None): + **kwargs): """ """ super(MiewIdNet, self).__init__() @@ -75,6 +69,8 @@ def __init__(self, final_in_features = self.backbone.classifier.in_features if model_name.startswith('swinv2'): final_in_features = self.backbone.norm.normalized_shape[0] + + self.final_in_features = final_in_features self.backbone.classifier = nn.Identity() self.backbone.global_pool = nn.Identity() @@ -91,26 +87,6 @@ def __init__(self, self.fc.apply(weights_init_classifier) final_in_features = fc_dim - self.loss_module = loss_module - if loss_module == 'arcface': - self.final = ElasticArcFace(final_in_features, n_classes, - s=s, m=margin) - elif loss_module == 'arcface_subcenter_dynamic': - if margins is None: - margins = [0.3] * n_classes - self.final = ArcFaceSubCenterDynamic( - embedding_dim=final_in_features, - output_classes=n_classes, - margins=margins, - s=s, - k=k ) - # elif loss_module == 'cosface': - # self.final = AddMarginProduct(final_in_features, n_classes, s=s, m=margin) - # elif loss_module == 'adacos': - # self.final = AdaCos(final_in_features, n_classes, m=margin, theta_zero=theta_zero) - else: - self.final = nn.Linear(final_in_features, n_classes) - def _init_params(self): nn.init.xavier_normal_(self.fc.weight) nn.init.constant_(self.fc.bias, 0) @@ -119,16 +95,8 @@ def _init_params(self): def forward(self, x, label=None): feature = self.extract_feat(x) - if not self.training: - return feature - else: - assert label is not None - if self.loss_module in ('arcface', 'arcface_subcenter_dynamic'): - logits = self.final(feature, label) - else: - logits = self.final(feature) - - return logits + + return feature def extract_feat(self, x): batch_size = x.shape[0] @@ -143,14 +111,4 @@ def extract_feat(self, x): x1 = self.bn(x1) x1 = self.fc(x1) - return x - - def extract_logits(self, x, label=None): - feature = self.extract_feat(x) - assert label is not None - if self.loss_module in ('arcface', 'arcface_subcenter_dynamic'): - logits = self.final(feature, label) - else: - logits = self.final(feature) - - return logits \ No newline at end of file + return x \ No newline at end of file diff --git a/wbia_miew_id/models/model_helpers.py b/wbia_miew_id/models/model_helpers.py index dfd02ef..cff64a9 100644 --- a/wbia_miew_id/models/model_helpers.py +++ b/wbia_miew_id/models/model_helpers.py @@ -12,7 +12,7 @@ def get_model(cfg, checkpoint_path=None, use_gpu=True): model.to(device) if checkpoint_path: - model.load_state_dict(torch.load(checkpoint_path)) + model.load_state_dict(torch.load(checkpoint_path), strict=False) print('loaded checkpoint from', checkpoint_path) return model \ No newline at end of file diff --git a/wbia_miew_id/sweep.py b/wbia_miew_id/sweep.py index 66392a5..dd0027a 100644 --- a/wbia_miew_id/sweep.py +++ b/wbia_miew_id/sweep.py @@ -34,31 +34,45 @@ def parse_args(): def objective(trial, config): # Specify the parameters you want to optimize - 
config.data.train.n_filter_min = trial.suggest_int("train.n_filter_min", 3, 5)
-    # image_size = 256 #trial.suggest_categorical("image_size", [192, 256, 384, 440, 512])
-    image_size = 256 #trial.suggest_categorical("image_size", [192, 256, 384, 440, 512])
-    loss_module = trial.suggest_categorical("loss_module", ['arcface_subcenter_dynamic', 'arcface'])
-    config.model_params.loss_module = loss_module
-    # config.data.image_size = [image_size, image_size]
-    n_epochs = 20#trial.suggest_int("epochs", 20, 40)
+
+    image_size = 440 #trial.suggest_categorical("image_size", [192, 256, 384, 440, 512])
+    n_epochs = 25#trial.suggest_int("epochs", 20, 40)
     config.engine.epochs = n_epochs
+
+    # if trial.number > 0:
+
+    #     loss_module = trial.suggest_categorical("loss_module", ['arcface_subcenter_dynamic', 'elastic_arcface'])
+    #     config.model_params.loss_module = loss_module
+    #     config.data.image_size = [image_size, image_size]
+    # Specify the parameters you want to optimize
+    # config.data.train.n_filter_min = trial.suggest_int("train.n_filter_min", 3, 5)
+    # image_size = 256 #trial.suggest_categorical("image_size", [192, 256, 384, 440, 512])
+
     config.model_params.s = trial.suggest_uniform("s", 30, 64)
-    if config.model_params.loss_module == 'arcface':
+    if config.model_params.loss_module == 'elastic_arcface':
         config.model_params.margin = trial.suggest_uniform("margin", 0.3, 0.7)
     if config.model_params.loss_module == 'arcface_subcenter_dynamic':
         config.model_params.k = trial.suggest_int("k", 2, 4)
     # The scheduler params are derived from one base parameter to minimize the number of parameters to optimize
-    lr_base = trial.suggest_loguniform("lr_base", 1e-5, 1e-3)
-    config.scheduler_params.lr_start = lr_base / 10
-    config.scheduler_params.lr_max = lr_base * 10
-    config.scheduler_params.lr_min = lr_base / 20
+    # lr_base = trial.suggest_loguniform("lr_base", 1e-5, 1e-3)
+    # config.scheduler_params.lr_start = lr_base / 100
+    # config.scheduler_params.lr_max = lr_base * 10
+    # config.scheduler_params.lr_min = lr_base / 100
+
+    lr_start = trial.suggest_loguniform("lr_start", 1e-7, 1e-4)
+    config.scheduler_params.lr_start = lr_start
+    lr_max = trial.suggest_loguniform("lr_max", 5e-5, 1e-3)
+    config.scheduler_params.lr_max = lr_max
+    lr_min = trial.suggest_loguniform("lr_min", 1e-7, 1e-4)
+    config.scheduler_params.lr_min = lr_min
+
+    # # SWA parameters to test
+    # config.engine.use_swa = trial.suggest_categorical("use_swa", [False, True])
+    # if config.engine.use_swa:
+    #     config.swa_params.swa_lr = trial.suggest_loguniform("swa_lr", 0.0001, 0.05)
+    #     config.swa_params.swa_start = trial.suggest_int("swa_start", 20, 25)
-    # SWA parameters to test
-    config.engine.use_swa = trial.suggest_categorical("use_swa", [False, True])
-    if config.engine.use_swa:
-        config.swa_params.swa_lr = trial.suggest_loguniform("swa_lr", 0.0001, 0.05)
-        config.swa_params.swa_start = trial.suggest_int("swa_start", 20, 25)
 
     print("trial number: ", trial.number)
     print("config: ", dict(config))
@@ -68,13 +82,21 @@ def objective(trial, config):
     else:
         config.exp_name = config.exp_name + f"_t{trial.number}"
 
-
-    try:
-        result = run(config)
-    except Exception as e:
-        print("Exception occured: ", e)
-        print(trial)
-        result = 0
+    if trial.number == 0:
+        return 0.738
+    if trial.number == 1:
+        return 0.59
+    if trial.number == 2:
+        return 0.7422174440141309
+    if trial.number == 3:
+        return 0.7422935392772942
+    else:
+        try:
+            result = run(config)
+        except Exception as e:
+            print("Exception occurred: ", e)
+            print(trial)
+            result = 0
 
     return result
 
@@ -105,6 
+127,42 @@ def signal_handler(signum, frame): sampler=study_sampler, pruner=MedianPruner(), direction="maximize", load_if_exists=True ) + study.enqueue_trial({ + 's': 49.32675426153405, + 'margin': 0.32841442327915477, + 'k': 2, + 'lr_start': 0.00000341898067433194, + 'lr_max': 0.001, + 'lr_min': 0.000002 + }) + + study.enqueue_trial({ + 's': 48.65965913352904, + 'margin': 0.32841442327915477, + 'k': 4, + 'lr_start': 0.00000643117205013199, + 'lr_max': 0.00000643117205013199, + 'lr_min': 0.00000643117205013199 + }) + + study.enqueue_trial({ + 's': 48.65965913352904, + 'margin': 0.32841442327915477, + 'k': 4, + 'lr_start': 0.00000643117205013199, + 'lr_max': 0.0002557875307967728, + 'lr_min': 0.0000018662266976518 + }) + + study.enqueue_trial({ + 's': 51.960399844266306, + 'margin': 0.32841442327915477, + 'k': 3, + 'lr_start': 0.0000473498930449948, + 'lr_max': 0.000896858981000587, + 'lr_min': 0.00000141359355517523 + }) + comb_objective = lambda trial: objective(trial, config) study.optimize(comb_objective, n_trials=n_trials) diff --git a/wbia_miew_id/test.py b/wbia_miew_id/test.py deleted file mode 100644 index 5df2914..0000000 --- a/wbia_miew_id/test.py +++ /dev/null @@ -1,136 +0,0 @@ -from datasets import MiewIdDataset, get_train_transforms, get_valid_transforms, get_test_transforms -from logging_utils import WandbContext -from models import MiewIdNet -from etl import preprocess_data, print_basic_stats -from engine import eval_fn, group_eval -from helpers import get_config -from visualization import render_query_results -from metrics import precision_at_k - -import os -import torch -import random -import numpy as np - -import argparse - -# os.environ['CUDA_LAUNCH_BLOCKING'] = "1" -# os.environ['TORCH_USE_CUDA_DSA'] = "1" - -def parse_args(): - parser = argparse.ArgumentParser(description="Load configuration file.") - parser.add_argument( - '--config', - type=str, - default='configs/default_config.yaml', - help='Path to the YAML configuration file. 
Default: configs/default_config.yaml' - ) - - parser.add_argument('--visualize', '--vis', action='store_true') - - return parser.parse_args() - -def run_test(config, visualize=False): - - def set_seed_torch(seed): - random.seed(seed) - os.environ['PYTHONHASHSEED'] = str(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.backends.cudnn.deterministic = True - - set_seed_torch(config.engine.seed) - - df_test = preprocess_data(config.data.test.anno_path, - name_keys=config.data.name_keys, - convert_names_to_ids=True, - viewpoint_list=config.data.viewpoint_list, - n_filter_min=config.data.test.n_filter_min, - n_subsample_max=config.data.test.n_subsample_max, - use_full_image_path=config.data.use_full_image_path, - images_dir = config.data.images_dir, - ) - - # top_names = df_test['name'].value_counts().index[:10] - # df_test = df_test[df_test['name'].isin(top_names)] - - test_dataset = MiewIdDataset( - csv=df_test, - transforms=get_test_transforms(config), - fliplr=config.test.fliplr, - fliplr_view=config.test.fliplr_view, - crop_bbox=config.data.crop_bbox, - ) - - test_loader = torch.utils.data.DataLoader( - test_dataset, - batch_size=config.engine.valid_batch_size, - num_workers=config.engine.num_workers, - shuffle=False, - pin_memory=True, - drop_last=False, - ) - - device = torch.device(config.engine.device) - - - - checkpoint_path = config.data.test.checkpoint_path - - if checkpoint_path: - print('loading checkpoint from', checkpoint_path) - weights = torch.load(config.data.test.checkpoint_path, map_location=torch.device(config.engine.device)) - n_train_classes = weights[list(weights.keys())[-1]].shape[-1] - if config.model_params.n_classes != n_train_classes: - print(f"WARNING: Overriding n_classes in config ({config.model_params.n_classes}) which is different from actual n_train_classes in the checkpoint - ({n_train_classes}).") - config.model_params.n_classes = n_train_classes - - model = MiewIdNet(**dict(config.model_params)) - model.to(device) - - model.load_state_dict(weights, strict=False) - print('loaded checkpoint from', checkpoint_path) - - else: - model = MiewIdNet(**dict(config.model_params)) - model.to(device) - - - test_score, cmc, test_outputs = eval_fn(test_loader, model, device, use_wandb=False, return_outputs=True) - - - - - eval_groups = config.data.test.eval_groups - - if eval_groups: - df_test_group = preprocess_data(config.data.test.anno_path, - name_keys=config.data.name_keys, - convert_names_to_ids=True, - viewpoint_list=config.data.viewpoint_list, - n_filter_min=None, - n_subsample_max=None, - use_full_image_path=config.data.use_full_image_path, - images_dir = config.data.images_dir) - group_results = group_eval(config, df_test_group, eval_groups, model) - - if visualize: - k=5 - embeddings, q_pids, distmat = test_outputs - ranks=list(range(1, k+1)) - score, match_mat, topk_idx, topk_names = precision_at_k(q_pids, distmat, ranks=ranks, return_matches=True) - match_results = (q_pids, topk_idx, topk_names, match_mat) - render_query_results(config, model, test_dataset, df_test, match_results, k=k) - - return test_score - -if __name__ == '__main__': - args = parse_args() - config_path = args.config - - config = get_config(config_path) - - visualize = args.visualize - - run_test(config, visualize=visualize) \ No newline at end of file diff --git a/wbia_miew_id/train.py b/wbia_miew_id/train.py index b7e8755..750c06e 100644 --- a/wbia_miew_id/train.py +++ b/wbia_miew_id/train.py @@ -1,191 +1,219 @@ -from datasets import 
MiewIdDataset, get_train_transforms, get_valid_transforms -from logging_utils import WandbContext -from models import MiewIdNet -from etl import preprocess_data, print_intersect_stats, load_preprocessed_mapping, preprocess_dataset -from losses import fetch_loss -from schedulers import MiewIdScheduler -from engine import run_fn -from helpers import get_config, write_config -from torch.optim.swa_utils import AveragedModel, SWALR +from wbia_miew_id.logging_utils import WandbContext +from wbia_miew_id.etl import preprocess_data, print_intersect_stats, load_preprocessed_mapping, preprocess_dataset +from wbia_miew_id.losses import fetch_loss +from wbia_miew_id.schedulers import MiewIdScheduler +from wbia_miew_id.helpers import get_config, write_config, update_bn +from wbia_miew_id.metrics import AverageMeter, compute_calibration +from wbia_miew_id.datasets import MiewIdDataset, get_train_transforms, get_valid_transforms +from wbia_miew_id.models import ArcMarginProduct, ElasticArcFace, ArcFaceSubCenterDynamic, MiewIdNet +from wbia_miew_id.engine import train_fn, eval_fn, group_eval_fn -import os + +from torch.optim.swa_utils import AveragedModel, SWALR import torch import random import numpy as np -from dotenv import load_dotenv - +import os import argparse +import matplotlib.pyplot as plt +from tqdm.auto import tqdm +import wandb -# os.environ['CUDA_LAUNCH_BLOCKING'] = "1" -# os.environ['TORCH_USE_CUDA_DSA'] = "1" -def parse_args(): - parser = argparse.ArgumentParser(description="Load configuration file.") - parser.add_argument( - '--config', - type=str, - default='configs/default_config.yaml', - help='Path to the YAML configuration file. Default: configs/default_config.yaml' - ) - return parser.parse_args() +class Trainer: + def __init__(self, config): + self.config = config -def run(config): - checkpoint_dir = f"{config.checkpoint_dir}/{config.project_name}/{config.exp_name}" - os.makedirs(checkpoint_dir, exist_ok=False) - print('Checkpoints will be saved at: ', checkpoint_dir) - - config_path_out = f'{checkpoint_dir}/{config.exp_name}.yaml' - config.data.test.checkpoint_path = f'{checkpoint_dir}/model_best.bin' - - def set_seed_torch(seed): + def set_seed_torch(self, seed): random.seed(seed) os.environ['PYTHONHASHSEED'] = str(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.backends.cudnn.deterministic = True - - set_seed_torch(config.engine.seed) - - df_train = preprocess_data(config.data.train.anno_path, - name_keys=config.data.name_keys, - convert_names_to_ids=True, - viewpoint_list=config.data.viewpoint_list, - n_filter_min=config.data.train.n_filter_min, - n_subsample_max=config.data.train.n_subsample_max, - use_full_image_path=config.data.use_full_image_path, - images_dir = config.data.images_dir, - ) - - df_val = preprocess_data(config.data.val.anno_path, - name_keys=config.data.name_keys, - convert_names_to_ids=True, - viewpoint_list=config.data.viewpoint_list, - n_filter_min=config.data.val.n_filter_min, - n_subsample_max=config.data.val.n_subsample_max, - use_full_image_path=config.data.use_full_image_path, - images_dir = config.data.images_dir - ) - - print_intersect_stats(df_train, df_val, individual_key='name_orig') - - n_train_classes = df_train['name'].nunique() - - crop_bbox = config.data.crop_bbox - # if config.data.preprocess_images.force_apply: - # preprocess_dir_images = os.path.join(checkpoint_dir, 'images') - # preprocess_dir_train = os.path.join(preprocess_dir_images, 'train') - # preprocess_dir_val = 
os.path.join(preprocess_dir_images, 'val') - # print("Preprocessing images. Destination: ", preprocess_dir_images) - # os.makedirs(preprocess_dir_train) - # os.makedirs(preprocess_dir_val) - - # target_size = (config.data.image_size[0],config.data.image_size[1]) - - # df_train = preprocess_images(df_train, crop_bbox, preprocess_dir_train, target_size) - # df_val = preprocess_images(df_val, crop_bbox, preprocess_dir_val, target_size) - - # crop_bbox = False - - if config.data.preprocess_images.apply: - - if config.data.preprocess_images.preprocessed_dir is None: - preprocess_dir_images = os.path.join(checkpoint_dir, 'images') - else: - preprocess_dir_images = config.data.preprocess_images.preprocessed_dir - - if os.path.exists(preprocess_dir_images) and not config.data.preprocess_images.force_apply: - print('Preprocessed images directory found at: ', preprocess_dir_images) - else: - preprocess_dataset(config, preprocess_dir_images) - df_train = load_preprocessed_mapping(df_train, preprocess_dir_images) - df_val = load_preprocessed_mapping(df_val, preprocess_dir_images) + def run_fn(self, model, train_loader, valid_loader, criterion, optimizer, scheduler, device, checkpoint_dir, use_wandb=True, swa_model=None, swa_start=None, swa_scheduler=None): + best_score = 0 + for epoch in range(self.config.engine.epochs): + train_loss = train_fn(train_loader, model, criterion, optimizer, device, scheduler=scheduler, epoch=epoch, use_wandb=use_wandb, swa_model=swa_model, swa_start=swa_start, swa_scheduler=swa_scheduler) + + print("\nGetting metrics on validation set...") + + eval_groups = self.config.data.test.eval_groups + + if eval_groups: + valid_score, valid_cmc = group_eval_fn(self.config, eval_groups, model) + print('Group average score: ', valid_score) + else: + print('Evaluating on full test set') + valid_score = eval_fn(valid_loader, model, device, use_wandb=use_wandb) + print('Valid score: ', valid_score) + + if valid_score > best_score: + best_score = valid_score + torch.save(model.state_dict(), f'{checkpoint_dir}/model_best.bin') + print('best model found for epoch {}'.format(epoch)) + + if swa_model: + print("Updating SWA batchnorm statistics...") + update_bn(train_loader, swa_model, device=device) + torch.save(swa_model.state_dict(), f'{checkpoint_dir}/swa_model_{epoch}.bin') + + return best_score + + def run(self): + config = self.config + checkpoint_dir = f"{config.checkpoint_dir}/{config.project_name}/{config.exp_name}" + os.makedirs(checkpoint_dir, exist_ok=True) + print('Checkpoints will be saved at: ', checkpoint_dir) + + config_path_out = f'{checkpoint_dir}/{config.exp_name}.yaml' + config.data.test.checkpoint_path = f'{checkpoint_dir}/model_best.bin' + + self.set_seed_torch(config.engine.seed) + + df_train = preprocess_data(config.data.train.anno_path, + name_keys=config.data.name_keys, + convert_names_to_ids=True, + viewpoint_list=config.data.viewpoint_list, + n_filter_min=config.data.train.n_filter_min, + n_subsample_max=config.data.train.n_subsample_max, + use_full_image_path=config.data.use_full_image_path, + images_dir=config.data.images_dir) + + df_val = preprocess_data(config.data.val.anno_path, + name_keys=config.data.name_keys, + convert_names_to_ids=True, + viewpoint_list=config.data.viewpoint_list, + n_filter_min=config.data.val.n_filter_min, + n_subsample_max=config.data.val.n_subsample_max, + use_full_image_path=config.data.use_full_image_path, + images_dir=config.data.images_dir) - crop_bbox = False - train_dataset = MiewIdDataset( - csv=df_train, - 
transforms=get_train_transforms(config), - fliplr=config.test.fliplr, - fliplr_view=config.test.fliplr_view, - crop_bbox=crop_bbox, - ) - - valid_dataset = MiewIdDataset( - csv=df_val, - transforms=get_valid_transforms(config), - fliplr=config.test.fliplr, - fliplr_view=config.test.fliplr_view, - crop_bbox=crop_bbox, - ) + + print_intersect_stats(df_train, df_val, individual_key='name_orig') + + n_train_classes = df_train['name'].nunique() + + crop_bbox = config.data.crop_bbox + + if config.data.preprocess_images.apply: + if config.data.preprocess_images.preprocessed_dir is None: + preprocess_dir_images = os.path.join(checkpoint_dir, 'images') + else: + preprocess_dir_images = config.data.preprocess_images.preprocessed_dir + + if os.path.exists(preprocess_dir_images) and not config.data.preprocess_images.force_apply: + print('Preprocessed images directory found at: ', preprocess_dir_images) + else: + preprocess_dataset(config, preprocess_dir_images) + + df_train = load_preprocessed_mapping(df_train, preprocess_dir_images) + df_val = load_preprocessed_mapping(df_val, preprocess_dir_images) + crop_bbox = False + + train_dataset = MiewIdDataset( + csv=df_train, + transforms=get_train_transforms((config.data.image_size[0], config.data.image_size[1])), + fliplr=config.test.fliplr, + fliplr_view=config.test.fliplr_view, + crop_bbox=crop_bbox) - train_loader = torch.utils.data.DataLoader( - train_dataset, - batch_size=config.engine.train_batch_size, - num_workers=config.engine.num_workers, - shuffle=True, - pin_memory=True, - drop_last=True, - ) - - valid_loader = torch.utils.data.DataLoader( - valid_dataset, - batch_size=config.engine.valid_batch_size, - num_workers=config.engine.num_workers, - shuffle=False, - pin_memory=True, - drop_last=False, - ) - - device = torch.device(config.engine.device) - - if config.model_params.n_classes != n_train_classes: - print(f"WARNING: Overriding n_classes in config ({config.model_params.n_classes}) which is different from actual n_train_classes in the dataset - ({n_train_classes}).") - config.model_params.n_classes = n_train_classes - - if config.model_params.loss_module == 'arcface_subcenter_dynamic': - margin_min = 0.2 - margin_max = config.model_params.margin #0.5 - tmp = np.sqrt(1 / np.sqrt(df_train['name'].value_counts().sort_index().values)) - margins = (tmp - tmp.min()) / (tmp.max() - tmp.min()) * (margin_max - margin_min) + margin_min - else: - margins = None - - model = MiewIdNet(**dict(config.model_params), margins=margins) - model.to(device) - - criterion = fetch_loss() - criterion.to(device) + valid_dataset = MiewIdDataset( + csv=df_val, + transforms=get_valid_transforms((config.data.image_size[0], config.data.image_size[1])), + fliplr=config.test.fliplr, + fliplr_view=config.test.fliplr_view, + crop_bbox=crop_bbox) + train_loader = torch.utils.data.DataLoader( + train_dataset, + batch_size=config.engine.train_batch_size, + num_workers=config.engine.num_workers, + shuffle=True, + pin_memory=True, + drop_last=True) + + valid_loader = torch.utils.data.DataLoader( + valid_dataset, + batch_size=config.engine.valid_batch_size, + num_workers=config.engine.num_workers, + shuffle=False, + pin_memory=True, + drop_last=False) + + device = torch.device(config.engine.device) + + if config.model_params.n_classes != n_train_classes: + print(f"WARNING: Overriding n_classes in config ({config.model_params.n_classes}) which is different from actual n_train_classes in the dataset - ({n_train_classes}).") + config.model_params.n_classes = n_train_classes + + if 
config.model_params.loss_module == 'arcface_subcenter_dynamic':
+            margin_min = 0.2
+            margin_max = config.model_params.margin
+            tmp = np.sqrt(1 / np.sqrt(df_train['name'].value_counts().sort_index().values))
+            margins = (tmp - tmp.min()) / (tmp.max() - tmp.min()) * (margin_max - margin_min) + margin_min
+        else:
+            margins = None
+
+        model = MiewIdNet(**dict(config.model_params), margins=margins)
+        model.to(device)
+
+        loss_fn = fetch_loss()
+
+        if config.model_params.loss_module == 'elastic_arcface':
+            criterion = ElasticArcFace(loss_fn=loss_fn, in_features=model.final_in_features, out_features=config.model_params.n_classes)
+            criterion.to(device)
+
+        elif config.model_params.loss_module == 'arcface_subcenter_dynamic':
+            if margins is None:
+                margins = [0.3] * config.model_params.n_classes
+            criterion = ArcFaceSubCenterDynamic(
+                loss_fn=loss_fn,
+                embedding_dim=model.final_in_features,
+                output_classes=config.model_params.n_classes,
+                margins=margins,
+                s=config.model_params.s,
+                k=config.model_params.k)
+            criterion.to(device)
+        else:
+            raise NotImplementedError("Loss module not recognized")
 
-    optimizer = torch.optim.Adam(model.parameters(), lr=config.scheduler_params.lr_start)
-    
-    scheduler = MiewIdScheduler(optimizer,**dict(config.scheduler_params))
+        optimizer = torch.optim.Adam(list(model.parameters()) + list(criterion.parameters()), lr=config.scheduler_params.lr_start)
+        scheduler = MiewIdScheduler(optimizer, **dict(config.scheduler_params))
 
-    if config.engine.use_swa:
-        swa_model = AveragedModel(model)
-        swa_model.to(device)
-        swa_scheduler = SWALR(optimizer=optimizer, swa_lr=config.swa_params.swa_lr)
-        swa_start = config.swa_params.swa_start
-    else:
-        swa_model = None
-        swa_scheduler = None
-        swa_start = None
+        if config.engine.use_swa:
+            swa_model = AveragedModel(model)
+            swa_model.to(device)
+            swa_scheduler = SWALR(optimizer=optimizer, swa_lr=config.swa_params.swa_lr)
+            swa_start = config.swa_params.swa_start
+        else:
+            swa_model = None
+            swa_scheduler = None
+            swa_start = None
 
-    write_config(config, config_path_out)
+        write_config(config, config_path_out)
 
+        with WandbContext(config):
+            best_score = self.run_fn(model, train_loader, valid_loader, criterion, optimizer, scheduler, device, checkpoint_dir, use_wandb=config.engine.use_wandb,
+                        swa_model=swa_model, swa_scheduler=swa_scheduler, swa_start=swa_start)
 
-    with WandbContext(config):
-        best_score = run_fn(config, model, train_loader, valid_loader, criterion, optimizer, scheduler, device, checkpoint_dir, use_wandb=config.engine.use_wandb,
-                        swa_model=swa_model, swa_scheduler=swa_scheduler, swa_start=swa_start)
+        return best_score
 
-    return best_score
 
+def parse_args():
+    parser = argparse.ArgumentParser(description="Load configuration file.")
+    parser.add_argument(
+        '--config',
+        type=str,
+        default='configs/default_config.yaml',
+        help='Path to the YAML configuration file. 
Default: configs/default_config.yaml' + ) + return parser.parse_args() if __name__ == '__main__': args = parse_args() config_path = args.config - config = get_config(config_path) - - run(config) \ No newline at end of file + trainer = Trainer(config) + trainer.run() diff --git a/wbia_miew_id/visualization/gradcam.py b/wbia_miew_id/visualization/gradcam.py index 1b34fdc..90fd81b 100644 --- a/wbia_miew_id/visualization/gradcam.py +++ b/wbia_miew_id/visualization/gradcam.py @@ -172,7 +172,7 @@ def draw_one(config, test_loader, model, images_dir = '', method='gradcam_plus_p comb_image = cv2.cvtColor(comb_image, cv2.COLOR_BGR2RGB) return comb_image -def generate_embeddings(config, model, test_loader): +def generate_embeddings(device, model, test_loader): print('generating embeddings') tk0 = tqdm(test_loader, total=len(test_loader)) embeddings = [] @@ -199,7 +199,7 @@ def generate_embeddings(config, model, test_loader): images.extend(batch_image) - batch_embeddings = model(batch_image.to(config.engine.device)) + batch_embeddings = model(batch_image.to(device)) batch_embeddings = batch_embeddings.detach().cpu().numpy() @@ -221,14 +221,14 @@ def generate_embeddings(config, model, test_loader): embeddings = pd.concat(embeddings) return embeddings, labels, images, paths, bboxes, thetas -def draw_batch(config, test_loader, model, images_dir = '', method='hires_cam', eigen_smooth=False, render_transformed=False, show=False, use_cuda=True): +def draw_batch(device, test_loader, model, images_dir = '', method='hires_cam', eigen_smooth=False, render_transformed=False, show=False, use_cuda=True): print('** draw_batch started') # Generate embeddings for query and db model.eval() - embeddings, labels, images, paths, bboxes, thetas = generate_embeddings(config, model, test_loader) + embeddings, labels, images, paths, bboxes, thetas = generate_embeddings(device, model, test_loader) print('*** embeddings generated ***') @@ -247,10 +247,10 @@ def draw_batch(config, test_loader, model, images_dir = '', method='hires_cam', db_idx = 1 qry_features = embeddings.iloc[qry_idx].values - qry_features = torch.Tensor(qry_features).to(config.engine.device) + qry_features = torch.Tensor(qry_features).to(device) db_features_batch = embeddings.iloc[db_idx:].values - db_features_batch = torch.Tensor(db_features_batch).to(config.engine.device) + db_features_batch = torch.Tensor(db_features_batch).to(device) tensors = [] stack_target = [] diff --git a/wbia_miew_id/visualization/match_vis.py b/wbia_miew_id/visualization/match_vis.py index 6135a51..9c6ab00 100644 --- a/wbia_miew_id/visualization/match_vis.py +++ b/wbia_miew_id/visualization/match_vis.py @@ -43,13 +43,13 @@ def stack_match_images(images, descriptions, match_mask, text_color=(0, 0, 0)): return result -def render_single_query_result(config, model, vis_loader, df_vis, qry_row, qry_idx, vis_match_mask, k=5): +def render_single_query_result(model, vis_loader, df_vis, qry_row, qry_idx, vis_match_mask, device, output_dir, k=5): - use_cuda = False if config.engine.device in ['mps', 'cpu'] else True + use_cuda = False if device in ['mps', 'cpu'] else True batch_images = draw_batch( - config, vis_loader, model, images_dir = 'dev_test', method='gradcam_plus_plus', eigen_smooth=False, + device, vis_loader, model, images_dir = 'dev_test', method='gradcam_plus_plus', eigen_smooth=False, render_transformed=True, show=False, use_cuda=use_cuda) viewpoints = df_vis['viewpoint'].values @@ -65,7 +65,6 @@ def render_single_query_result(config, model, vis_loader, df_vis, qry_row, qry_i 
vis_result = stack_match_images(batch_images, descriptions, vis_match_mask) - output_dir = f"{config.checkpoint_dir}/{config.project_name}/{config.exp_name}/visualizations" output_name = f"vis_{qry_name}_{qry_viewpoint}_{qry_loc_idx}_top{k}.jpg" output_path = os.path.join(output_dir, output_name) @@ -74,10 +73,13 @@ def render_single_query_result(config, model, vis_loader, df_vis, qry_row, qry_i print(f"Saved visualization to {output_path}") -def render_query_results(config, model, test_dataset, df_test, match_results, k=5): +def render_query_results(model, test_dataset, df_test, match_results, device, k=5, + valid_batch_size=2, output_dir='miewid_visualizations'): q_pids, topk_idx, topk_names, match_mat = match_results + os.makedirs(output_dir, exist_ok=True) + print("Generating visualizations...") for i in tqdm(range(len(q_pids))): # @@ -95,7 +97,7 @@ def render_query_results(config, model, test_dataset, df_test, match_results, k= vis_loader = torch.utils.data.DataLoader( test_dataset, - batch_size=config.engine.valid_batch_size, + batch_size=valid_batch_size, num_workers=0, shuffle=False, pin_memory=True, @@ -103,4 +105,4 @@ def render_query_results(config, model, test_dataset, df_test, match_results, k= sampler = idxSampler ) - render_single_query_result(config, model, vis_loader, df_vis, qry_row, qry_idx, vis_match_mask, k=k) + render_single_query_result(model, vis_loader, df_vis, qry_row, qry_idx, vis_match_mask, device, output_dir, k=k)
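
Usage sketch (illustrative only, not part of the patch): with this refactor, training is driven through the `Trainer` class in `wbia_miew_id/train.py` instead of a module-level `run()`. The snippet below is a minimal sketch assuming a valid config YAML exists at the default path used by `parse_args`.

```python
# Minimal sketch of the refactored training entry point.
# Assumes configs/default_config.yaml points at valid data and checkpoint paths.
from wbia_miew_id.helpers import get_config
from wbia_miew_id.train import Trainer

config = get_config('configs/default_config.yaml')
trainer = Trainer(config)
best_score = trainer.run()  # trains, evaluates each epoch, writes model_best.bin to the checkpoint dir
print('best validation score:', best_score)
```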