diff --git a/tutorials/mct_model_garden/evaluation_metrics/anomaly_eval.py b/tutorials/mct_model_garden/evaluation_metrics/anomaly_eval.py deleted file mode 100644 index 6dd678a59..000000000 --- a/tutorials/mct_model_garden/evaluation_metrics/anomaly_eval.py +++ /dev/null @@ -1,54 +0,0 @@ - -import os -from tqdm import tqdm -import numpy as np -import torch - -from torchvision import transforms -from sklearn.metrics import roc_auc_score -import tifffile -from tutorials.resources.utils.efficient_ad_utils import ImageFolderWithPath, predict_combined - -# Global constants -IMAGE_SIZE = 256 -OUT_CHANNELS = 384 -SEED = 42 - -# Transform definitions -DEFAULT_TRANSFORM = transforms.Compose([ - transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)), - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) -]) - - -def benchmark(unified_model, name, q_st_start=None, q_st_end=None, q_ae_start=None, q_ae_end=None): - """Benchmark the model by testing it on a dataset and printing the AUC score.""" - dataset_path = './mvtec_anomaly_detection' - test_output_dir = os.path.join('output', 'anomaly_maps', name, 'bottle', 'test') - test_set = ImageFolderWithPath(os.path.join(dataset_path, 'bottle', 'test')) - unified_model.eval() - auc = test(test_set=test_set, unified_model=unified_model, test_output_dir=test_output_dir, q_st_start=None, q_st_end=None, q_ae_start=None, q_ae_end=None, desc='Final inference') - print('Final image auc: {:.4f}'.format(auc)) - -def test(test_set, unified_model, test_output_dir=None, q_st_start=None, q_st_end=None, q_ae_start=None, q_ae_end=None, desc='Running inference'): - """Test the model and calculate the AUC score.""" - y_true, y_score = [], [] - for image, target, path in tqdm(test_set, desc=desc): - orig_width, orig_height = image.size - image = DEFAULT_TRANSFORM(image)[None] # Add batch dimension - if torch.cuda.is_available(): - image = image.cuda() - map_combined , _, _ = predict_combined(image, unified_model, q_st_start=None, q_st_end=None, q_ae_start=None, q_ae_end=None) - map_combined = torch.nn.functional.interpolate(map_combined, (orig_height, orig_width), mode='bilinear') - map_combined = map_combined[0, 0].detach().cpu().numpy() - defect_class = os.path.basename(os.path.dirname(path)) - if test_output_dir: - img_nm = os.path.split(path)[1].split('.')[0] - defect_dir = os.path.join(test_output_dir, defect_class) - os.makedirs(defect_dir, exist_ok=True) - tifffile.imwrite(os.path.join(defect_dir, img_nm + '.tiff'), map_combined) - y_true.append(0 if defect_class == 'good' else 1) - y_score.append(np.max(map_combined)) - auc = roc_auc_score(y_true=y_true, y_score=y_score) - return auc * 100 \ No newline at end of file diff --git a/tutorials/mct_model_garden/models_pytorch/Efficient_Anomaly_Det/efficient_ad.py b/tutorials/mct_model_garden/models_pytorch/Efficient_Anomaly_Det/efficient_ad.py deleted file mode 100644 index 0cd6e4e4f..000000000 --- a/tutorials/mct_model_garden/models_pytorch/Efficient_Anomaly_Det/efficient_ad.py +++ /dev/null @@ -1,227 +0,0 @@ -# The following code was mostly duplicated from https://github.com/nelson1425/EfficientAD -# and changed to generate an equivalent PyTorch model suitable for quantization. -# Main changes: -# * Modify layers to make them more suitable for quantization. 
-# * Inheritance class from HuggingFace -# * Uninfied version of model combining the three subversions -# ============================================================================== -""" -Efficient Anomaly Detection Model - PyTorch implementation - -This code contains a PyTorch implementation of efficient ad model, following -https://github.com/nelson1425/EfficientAD. This implementation includes a unified version of the model that combines the three submodels -into one to ease the process of quantization and deployment. - -The code is organized as follows: -- -- primary model definition - UnifiedAnomalyDetectionModel -- sub models -- auto encoder - get_autoencoder -- student and teacher models - get_pdn_small - -For more details on the model, refer to the original repository: -https://github.com/nelson1425/EfficientAD - -""" -from torch import nn -from torchvision.datasets import ImageFolder -import torch -import json - -def get_autoencoder(out_channels=384): - """ - Constructs an autoencoder model with specified output channels. - - Parameters: - - out_channels (int): The number of output channels in the final convolutional layer. - - Returns: - - nn.Sequential: A PyTorch sequential model representing the autoencoder. - """ - return nn.Sequential( - # encoder - nn.Conv2d(in_channels=3, out_channels=32, kernel_size=4, stride=2, - padding=1), - nn.ReLU(inplace=True), - nn.Conv2d(in_channels=32, out_channels=32, kernel_size=4, stride=2, - padding=1), - nn.ReLU(inplace=True), - nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2, - padding=1), - nn.ReLU(inplace=True), - nn.Conv2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, - padding=1), - nn.ReLU(inplace=True), - nn.Conv2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, - padding=1), - nn.ReLU(inplace=True), - nn.Conv2d(in_channels=64, out_channels=64, kernel_size=8), - # decoder - nn.Upsample(size=3, mode='bilinear'), - nn.Conv2d(in_channels=64, out_channels=64, kernel_size=4, stride=1, - padding=2), - nn.ReLU(inplace=True), - nn.Dropout(0.2), - nn.Upsample(size=8, mode='bilinear'), - nn.Conv2d(in_channels=64, out_channels=64, kernel_size=4, stride=1, - padding=2), - nn.ReLU(inplace=True), - nn.Dropout(0.2), - nn.Upsample(size=15, mode='bilinear'), - nn.Conv2d(in_channels=64, out_channels=64, kernel_size=4, stride=1, - padding=2), - nn.ReLU(inplace=True), - nn.Dropout(0.2), - nn.Upsample(size=32, mode='bilinear'), - nn.Conv2d(in_channels=64, out_channels=64, kernel_size=4, stride=1, - padding=2), - nn.ReLU(inplace=True), - nn.Dropout(0.2), - nn.Upsample(size=63, mode='bilinear'), - nn.Conv2d(in_channels=64, out_channels=64, kernel_size=4, stride=1, - padding=2), - nn.ReLU(inplace=True), - nn.Dropout(0.2), - nn.Upsample(size=127, mode='bilinear'), - nn.Conv2d(in_channels=64, out_channels=64, kernel_size=4, stride=1, - padding=2), - nn.ReLU(inplace=True), - nn.Dropout(0.2), - nn.Upsample(size=56, mode='bilinear'), - nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, - padding=1), - nn.ReLU(inplace=True), - nn.Conv2d(in_channels=64, out_channels=out_channels, kernel_size=3, - stride=1, padding=1) - ) - -def get_pdn_small(out_channels=384, padding=False): - """ - Constructs a small PDN (Pyramidal Decomposition Network) model. - - Parameters: - - out_channels (int): The number of output channels in the final convolutional layer. - - padding (bool): If True, applies padding to convolutional layers. - - Returns: - - nn.Sequential: A PyTorch sequential model representing the small PDN. 
- """ - pad_mult = 1 if padding else 0 - return nn.Sequential( - nn.Conv2d(in_channels=3, out_channels=128, kernel_size=4, - padding=3 * pad_mult), - nn.ReLU(inplace=True), - nn.AvgPool2d(kernel_size=2, stride=2, padding=1 * pad_mult), - nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, - padding=3 * pad_mult), - nn.ReLU(inplace=True), - nn.AvgPool2d(kernel_size=2, stride=2, padding=1 * pad_mult), - nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, - padding=1 * pad_mult), - nn.ReLU(inplace=True), - nn.Conv2d(in_channels=256, out_channels=out_channels, kernel_size=4) - ) - -class UnifiedAnomalyDetectionModel(nn.Module): - """ - A unified model for anomaly detection combining teacher, student, and autoencoder models. - - Parameters: - - teacher (nn.Module): The teacher model. - - student (nn.Module): The student model. - - autoencoder (nn.Module): The autoencoder model. - - out_channels (int): Number of output channels in the student's output used for comparison. - - teacher_mean (float): Mean used for normalizing the teacher's output. - - teacher_std (float): Standard deviation used for normalizing the teacher's output. - - q_st_start (float, optional): Start quantile for student-teacher comparison normalization. - - q_st_end (float, optional): End quantile for student-teacher comparison normalization. - - q_ae_start (float, optional): Start quantile for autoencoder-student comparison normalization. - - q_ae_end (float, optional): End quantile for autoencoder-student comparison normalization. - - Methods: - - forward(input_image): Processes the input image through the model. - - save_model(filepath): Saves the model state to a file. - - load_model(filepath, teacher_model, student_model, autoencoder_model): Loads the model state from a file. 
- """ - def __init__(self, teacher, student, autoencoder, out_channels, teacher_mean, teacher_std, q_st_start=None, q_st_end=None, q_ae_start=None, q_ae_end=None): - super(UnifiedAnomalyDetectionModel, self).__init__() - self.teacher = teacher - self.student = student - self.autoencoder = autoencoder - self.out_channels = out_channels - self.teacher_mean = teacher_mean - self.teacher_std = teacher_std - self.q_st_start = q_st_start - self.q_st_end = q_st_end - self.q_ae_start = q_ae_start - self.q_ae_end = q_ae_end - - def forward(self, input_image): - teacher_output = self.teacher(input_image) - student_output = self.student(input_image) - autoencoder_output = self.autoencoder(input_image) - teacher_output = (teacher_output - self.teacher_mean) / self.teacher_std - student_output_st = student_output[:, :self.out_channels] - student_output_ae = student_output[:, self.out_channels:] - - # Calculate MSE between teacher-student and autoencoder-student - mse_st = (teacher_output - student_output_st) * (teacher_output - student_output_st) - mse_ae = (autoencoder_output - student_output_ae) * (autoencoder_output - student_output_ae) - - return mse_st , mse_ae - - def save_model(self, filepath): - """ Save the entire model including sub-models and parameters """ - model_info = { - 'model_state_dict': self.state_dict(), - 'out_channels': self.out_channels, - 'teacher_mean': self.teacher_mean.tolist(), - 'teacher_std': self.teacher_std.tolist(), - 'q_st_start': self.q_st_start.tolist() if self.q_st_start is not None else None, - 'q_st_end': self.q_st_end.tolist() if self.q_st_end is not None else None, - 'q_ae_start': self.q_ae_start.tolist() if self.q_ae_start is not None else None, - 'q_ae_end': self.q_ae_end.tolist() if self.q_ae_end is not None else None - } - torch.save(model_info, filepath) - - @staticmethod - def load_model(filepath, teacher_model, student_model, autoencoder_model): - """ Load the entire model including sub-models and parameters """ - model_info = torch.load(filepath) - model = UnifiedAnomalyDetectionModel( - teacher=teacher_model, - student=student_model, - autoencoder=autoencoder_model, - out_channels=model_info['out_channels'], - teacher_mean=torch.tensor(model_info['teacher_mean']), - teacher_std=torch.tensor(model_info['teacher_std']), - q_st_start=torch.tensor(model_info['q_st_start']) if model_info['q_st_start'] is not None else None, - q_st_end=torch.tensor(model_info['q_st_end']) if model_info['q_st_end'] is not None else None, - q_ae_start=torch.tensor(model_info['q_ae_start']) if model_info['q_ae_start'] is not None else None, - q_ae_end=torch.tensor(model_info['q_ae_end']) if model_info['q_ae_end'] is not None else None - ) - model.load_state_dict(model_info['model_state_dict']) - return model - - - -"""Model taining example usage - google colab colab - -from tutorials.mct_model_garden.models_pytorch.Efficient_Anomaly_Det import get_pdn_small, get_autoencoder -from tutorials.resources.utils.efficient_ad_utils import train_ad - -dataset_path = './mvtec_anomaly_detection' -sub_dataset = 'bottle' -train_steps = 70000 -out_channels = 384 -image_size = 256 -teacher_weights = 'drive/MyDrive/anom/teacher_final.pth' -teacher = get_pdn_small(out_channels) -student = get_pdn_small(2 * out_channels) -loaded_model = torch.load(teacher_weights, map_location='cpu') -# Extract the state_dict from the loaded model -state_dict = loaded_model.state_dict() -teacher.load_state_dict(state_dict) -autoencoder = get_autoencoder(out_channels) - -train_ad(train_steps, dataset_path, 
sub_dataset, autoencoder, teacher, student)""" \ No newline at end of file diff --git a/tutorials/notebooks/imx500_notebooks/pytorch/pytorch_efficient_anomaly_detection.ipynb b/tutorials/notebooks/imx500_notebooks/pytorch/pytorch_efficient_anomaly_detection.ipynb deleted file mode 100644 index 6d5917bd0..000000000 --- a/tutorials/notebooks/imx500_notebooks/pytorch/pytorch_efficient_anomaly_detection.ipynb +++ /dev/null @@ -1,842 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "xGqjJDA1AaJo" - }, - "source": [ - "# Anomaly Detection Training Benchmark and Quantization for IMX500" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jHiBhX_DkkGH" - }, - "source": [ - "[Run this tutorial in Google Colab](https://colab.research.google.com/github/sony/model_optimization/blob/main/tutorials/notebooks/imx500_notebooks/pytorch/pytorch_efficient_anomaly_detection.ipynb)\n", - "\n", - "### Overview\n", - "\n", - "In this tutorial we demonstrate training, quantization and benchmarking of an anomaly detection model. The resulting model will be imx500 compatible" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "LF92AFYDJGd0" - }, - "source": [ - "Classification models are powerful and reliable, but what if you have little or no examples of one of your classes, what if one of your classes contains too much unpredictable variation?\n", - "\n", - "Here we go through the process of building, training and quantizing an anomaly detection model designed to solve exactly these problems.\n", - "\n", - "Anomaly detection models are useful as they only require your typical images to train and can in theory determin anything that is not typical.\n", - "\n", - "We use Efficient ad, one of the top performing anomaly detection models on the mvtec benchmark. benchmark leader board can be found [here](https://paperswithcode.com/sota/anomaly-detection-on-mvtec-ad)\n", - "\n", - "This particular model uses the teacher student method. Where the student model is trained to both mimic the feature map output of a simple pre-trained CNN aswell as mimic the output of an auto encoder that is its self also trained on the normal images.\n", - "\n", - "We use the [mvtec](https://www.mvtec.com/company/research/datasets/mvtec-ad) dataset to benchmark and train this model." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "kvXjuRWEnA6f" - }, - "source": [ - "## Summary\n", - "\n", - "In this tutorial we will cover for an anomaly detection model:\n", - "\n", - "1. Mvtec Benchmark\n", - "2. Post training quantization.\n", - "3. Visulization\n", - "4. Training this model." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "0y6W3GB6M3eZ" - }, - "source": [ - "## Setup\n", - "\n", - "### install relevant packages" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "tCXBzuGxAFxU", - "outputId": "d45e5967-9471-47ba-e92d-da32ed116852" - }, - "outputs": [], - "source": [ - "!pip install torch\n", - "!pip install torchvision\n", - "!pip install tifffile\n", - "!pip install tqdm\n", - "!pip install scikit-learn\n", - "!pip install Pillow\n", - "!pip install scipy\n", - "!pip install tabulate" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Install MCT (if it’s not already installed). 
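The teacher-student idea described above boils down to a small amount of tensor arithmetic. Here is a minimal, self-contained sketch with toy tensors (the 384-channel split and map combination follow this tutorial's model; the real computation lives in `UnifiedAnomalyDetectionModel.forward` and `predict_combined` elsewhere in this diff):

```python
# Toy sketch of the EfficientAD scoring idea: the student is compared against
# the teacher and against the autoencoder, and the two squared-error maps are
# averaged into one anomaly heatmap. Shapes here are illustrative only.
import torch

out_channels = 384
teacher_out = torch.randn(1, out_channels, 56, 56)        # teacher feature map
student_out = torch.randn(1, 2 * out_channels, 56, 56)    # first half tracks the teacher, second half the autoencoder
ae_out = torch.randn(1, out_channels, 56, 56)              # autoencoder reconstruction of the teacher features

# Per-pixel anomaly maps: mean squared difference over the channel dimension.
map_st = torch.mean((teacher_out - student_out[:, :out_channels]) ** 2, dim=1, keepdim=True)
map_ae = torch.mean((ae_out - student_out[:, out_channels:]) ** 2, dim=1, keepdim=True)

# Combined heatmap and a single image-level anomaly score.
map_combined = 0.5 * map_st + 0.5 * map_ae
image_score = map_combined.max().item()
print(f"image-level anomaly score: {image_score:.4f}")
```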
Additionally, in order to use all the necessary utility functions for this tutorial, we also copy [MCT tutorials folder](https://github.com/sony/model_optimization/tree/main/tutorials) and add it to the system path." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import sys\n", - "import importlib\n", - "\n", - "if not importlib.util.find_spec('model_compression_toolkit'):\n", - " !pip install model_compression_toolkit\n", - "!git clone https://github.com/sony/model_optimization.git temp_mct && mv temp_mct/tutorials . && \\rm -rf temp_mct\n", - "sys.path.insert(0,\"tutorials\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tGqXJH15oLvm" - }, - "source": [ - "Download and extract the mvtec benchmark dataset. This is used for both training and evaluation. Link below is a direct link from the mcvtec website. \n", - "\n", - "For more information on the Mvtec Benchmark dataset, visit: https://www.mvtec.com/company/research/datasets/mvtec-ad" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "qOqdM0peKsZ9", - "outputId": "2c8c5730-71a8-4fe9-8108-046a24780903" - }, - "outputs": [], - "source": [ - "!mkdir mvtec_anomaly_detection\n", - "!wget https://www.mydrive.ch/shares/38536/3830184030e49fe74747669442f0f282/download/420938113-1629952094/mvtec_anomaly_detection.tar.xz\n", - "!tar -xvf mvtec_anomaly_detection.tar.xz -C mvtec_anomaly_detection" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "TX_WMo1zomnk" - }, - "source": [ - "Finally download the official mvtec benchmark. Link below is a direct link from the mcvtec website." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "uTRG_ebOZQZJ", - "outputId": "b8218f5a-856c-4969-84b5-a6f09f67c8f4" - }, - "outputs": [], - "source": [ - "!wget https://www.mydrive.ch/shares/60736/698155e0e6d0467c4ff6203b16a31dc9/download/439517473-1665667812/mvtec_ad_evaluation.tar.xz\n", - "!tar -xvf mvtec_ad_evaluation.tar.xz\n", - "!rm mvtec_ad_evaluation.tar.xz" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "_nvWU5--PJy0" - }, - "source": [ - "## Model Quantization\n", - "\n", - "### Download and Build Model\n", - "\n", - "We have pretrained a model on the bottle dataset from mvtec. Here we will load the combined model from huggingface, a combination of (teacher, student, and autoenoder)." 
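For reference, a minimal sketch of this load step (assuming the cloned `tutorials` folder is already on `sys.path`, as set up earlier): it builds the three submodels and restores the unified checkpoint downloaded from Hugging Face.

```python
# Sketch of the model-loading step (assumes `tutorials` is on sys.path).
from huggingface_hub import hf_hub_download
from tutorials.mct_model_garden.models_pytorch.Efficient_Anomaly_Det.efficient_ad import (
    get_pdn_small, get_autoencoder, UnifiedAnomalyDetectionModel)

out_channels = 384

# Download the pretrained unified checkpoint (teacher + student + autoencoder).
model_path = hf_hub_download(repo_id="SSI-DNN/Efficient_Anomaly_Detection",
                             filename="efficientAD_bottle.pth")

# Build the three submodels; the student predicts both the teacher's and the
# autoencoder's feature maps, hence 2 * out_channels output channels.
teacher = get_pdn_small(out_channels)
student = get_pdn_small(2 * out_channels)
autoencoder = get_autoencoder(out_channels)

# Restore the unified model (weights plus normalization statistics).
model = UnifiedAnomalyDetectionModel.load_model(model_path, teacher, student, autoencoder)
model.eval()
```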
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 201, - "referenced_widgets": [ - "0a490fcaf8e44211a7138ff84d617788", - "a95f884036e648739e1c3b7d87269219", - "dddd26a508d64fd4bec206b4a05eb382", - "756e67b40ca44fec9f2941aaa547a130", - "34f0d6010c3b4eba83efd1a113e6312d", - "ac6f54e6ee604cd99208674531fcca17", - "3cbc9af0fb2148edb893e369649c2edc", - "7618819d67624f4ead9cd06fdedb5de5", - "5dd006f38b314a88b573a79e4d4955fc", - "5d8a2e15ec8c47158ede018efe0d8353", - "c7723d0bb97b4a1a9c5bd4434e6a17ff" - ] - }, - "id": "PuTz0cryyRD6", - "outputId": "dd922691-e229-410d-bea9-35977f4b6279" - }, - "outputs": [], - "source": [ - "from huggingface_hub import hf_hub_download\n", - "import tutorials.mct_model_garden.models_pytorch.Efficient_Anomaly_Det.efficient_ad import get_pdn_small, get_autoencoder, UnifiedAnomalyDetectionModel\n", - "\n", - "out_channels = 384\n", - "\n", - "model_path = hf_hub_download(repo_id=\"SSI-DNN/Efficient_Anomaly_Detection\", filename=\"efficientAD_bottle.pth\")\n", - "\n", - "teacher = get_pdn_small(out_channels)\n", - "student = get_pdn_small(2 * out_channels)\n", - "autoencoder = get_autoencoder(out_channels)\n", - "\n", - "model = efficient_ad.UnifiedAnomalyDetectionModel.load_model(model_path, teacher, student, autoencoder)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "BCOLpRKLqi6Z" - }, - "source": [ - "### Post training quantization using Model Compression Toolkit\n", - "\n", - "To perform model quantization we require a representative dataset. Because of the specific requirements of anomaly detection models (the assumption that anomalous images are not seen) we restrict the representative dataset to as such." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "481cEIZSa0bl", - "outputId": "a77f69b4-341d-498d-f559-8f7978fab0d9" - }, - "outputs": [], - "source": [ - "import model_compression_toolkit as mct\n", - "from torch.utils.data import DataLoader\n", - "import torchvision.transforms as transforms\n", - "from typing import Iterator, List\n", - "from tutorials.resources.utils.efficient_ad_utils import ImageFolderWithoutTarget, train_transform\n", - "\n", - "\n", - "def train_dataset_generator(train_loader: DataLoader) -> Iterator[List]:\n", - " while True:\n", - " for data, _ in train_loader:\n", - " yield [data.numpy()]\n", - "\n", - "def get_representative_dataset(n_iter: int, dataset_loader: Iterator[List]) -> Iterator[List]:\n", - " def representative_dataset() -> Iterator[List]:\n", - " ds_iter = iter(dataset_loader)\n", - " for _ in range(n_iter):\n", - " yield next(ds_iter)\n", - " return representative_dataset\n", - "\n", - "train_set = ImageFolderWithoutTarget(\n", - " os.path.join('./mvtec_anomaly_detection', 'bottle', 'train'),\n", - " transform=transforms.Lambda(train_transform))\n", - "\n", - "train_loader = DataLoader(train_set, batch_size=4, shuffle=True) # Ensure this matches your batch size and other DataLoader settings\n", - "train_dataset = train_dataset_generator(train_loader)\n", - "representative_dataset_gen = get_representative_dataset(n_iter=20, dataset_loader=train_dataset)\n", - "\n", - "# Set target platform capabilities\n", - "tpc = mct.get_target_platform_capabilities(fw_name=\"pytorch\", target_platform_name='imx500', target_platform_version='v1')\n", - "\n", - "# Perform post training quantization\n", - "quant_model, _ = 
mct.ptq.pytorch_post_training_quantization(in_module=model,\n", - " representative_data_gen=representative_dataset_gen,\n", - " target_platform_capabilities=tpc)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "fZXaCUenskKG" - }, - "source": [ - "## Model Export\n", - "\n", - "This model can be converted to run on imx500.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "jrKrsmSasm34" - }, - "outputs": [], - "source": [ - "import model_compression_toolkit as mct\n", - "\n", - "mct.exporter.pytorch_export_model(model=quant_model,\n", - " save_model_path='./quant_model.onnx',\n", - " repr_dataset=representative_dataset_gen)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tZXDd4bXA_4E" - }, - "source": [ - "## Float Benchmark\n", - "\n", - "Mvtec benchmark provides its own code. we first need to run the model on the test images and save the output heat maps and anomaly confidence. then run mvtecs benchmark." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#We first need to calculate the normalisation values\n", - "tutorials.resources.utils.efficient_ad_utils import map_normalization, teacher_normalization\n", - "\n", - "teacher_mean, teacher_std = teacher_normalization(teacher, train_loader)\n", - "q_st_start, q_st_end, q_ae_start, q_ae_end = map_normalization(\n", - " validation_loader=train_loader, teacher=teacher, student=student,\n", - " autoencoder=autoencoder, teacher_mean=teacher_mean,\n", - " teacher_std=teacher_std, desc='Final map normalization')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "yP2FGLY8BMlV", - "outputId": "0092a064-6893-41b1-f084-1afd8c807333" - }, - "outputs": [], - "source": [ - "from tutorials.mct_model_garden.evaluation_metrics.anomaly_eval import benchmark\n", - "benchmark(model, 'mvtec_ad', q_st_start, q_st_end, q_ae_start, q_ae_end)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ITfY-C6ylrUi" - }, - "source": [ - "### Mvtec benchmark\n", - "\n", - "This results in a classification accuracy AU-ROC and a segmentation accuracy AU-PRO" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "KeelewrtZdUA", - "outputId": "7b2cf763-29a5-4b2a-fadd-b18347a05d6f" - }, - "outputs": [], - "source": [ - "!python ./mvtec_ad_evaluation/evaluate_experiment.py --dataset_base_dir './mvtec_anomaly_detection/' --anomaly_maps_dir './output/anomaly_maps/mvtec_ad/' --output_dir './output/metrics/mvtec_ad/' --evaluated_objects bottle" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "_fPTNtJLgPc5" - }, - "source": [ - "## Quantized model benchmark\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "siM716G2506g", - "outputId": "3e5253b7-c3ac-408d-fd15-3e586744523d" - }, - "outputs": [], - "source": [ - "from tutorials.mct_model_garden.evaluation_metrics.anomaly_eval import benchmark\n", - "benchmark(quant_model, 'mvtec_ad_quant', q_st_start, q_st_end, q_ae_start, q_ae_end)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "uVbfzRVv6YJC", - "outputId": "9d5cb7da-d239-47cd-d019-3de2921f5d60" - }, - "outputs": [], - "source": [ - 
"!python ./mvtec_ad_evaluation/evaluate_experiment.py --dataset_base_dir './mvtec_anomaly_detection/' --anomaly_maps_dir './output/anomaly_maps/mvtec_ad_quant/' --output_dir './output/metrics/mvtec_ad/' --evaluated_objects bottle" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "xD2E4k0XibZ_" - }, - "source": [ - "## Anomaly Map Visulization\n", - "\n", - "We can visulize the heatmap of the predicted anomalies with the code below. Here red spots indicate locations with a high likely hood of defect, based on its training images. We also print the models prediction.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 1000 - }, - "id": "SfK27g5J95Kd", - "outputId": "3266c83d-2ea6-4b1b-ea15-290053553a67" - }, - "outputs": [], - "source": [ - "from tutorials.resources.utils.efficient_ad_utils import visualize_anomalies\n", - "import os\n", - "name = 'visulize'\n", - "dataset_path = './mvtec_anomaly_detection'\n", - "test_output_dir = os.path.join('output', 'anomaly_maps',\n", - " name, 'bottle', 'test')\n", - "model.eval()\n", - "visualize_anomalies(model, dataset_path, test_output_dir, q_st_start, q_st_end, q_ae_start, q_ae_end)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "n1zwMJyri8by" - }, - "source": [ - "## Conclusion\n", - "\n", - "In this notebook we provide examples on how to quantize and benchmark the latest anomaly detection model as well as providing code to visulize the models output." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "\\\n", - "Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.\n", - "\n", - "Licensed under the Apache License, Version 2.0 (the \"License\");\n", - "you may not use this file except in compliance with the License.\n", - "You may obtain a copy of the License at\n", - "\n", - " http://www.apache.org/licenses/LICENSE-2.0\n", - "\n", - "Unless required by applicable law or agreed to in writing, software\n", - "distributed under the License is distributed on an \"AS IS\" BASIS,\n", - "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", - "See the License for the specific language governing permissions and\n", - "limitations under the License." 
- ] - } - ], - "metadata": { - "colab": { - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "name": "python" - }, - "widgets": { - "application/vnd.jupyter.widget-state+json": { - "0a490fcaf8e44211a7138ff84d617788": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_a95f884036e648739e1c3b7d87269219", - "IPY_MODEL_dddd26a508d64fd4bec206b4a05eb382", - "IPY_MODEL_756e67b40ca44fec9f2941aaa547a130" - ], - "layout": "IPY_MODEL_34f0d6010c3b4eba83efd1a113e6312d" - } - }, - "34f0d6010c3b4eba83efd1a113e6312d": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "3cbc9af0fb2148edb893e369649c2edc": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "5d8a2e15ec8c47158ede018efe0d8353": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - 
"margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "5dd006f38b314a88b573a79e4d4955fc": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "756e67b40ca44fec9f2941aaa547a130": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_5d8a2e15ec8c47158ede018efe0d8353", - "placeholder": "​", - "style": "IPY_MODEL_c7723d0bb97b4a1a9c5bd4434e6a17ff", - "value": " 32.3M/32.3M [00:02<00:00, 13.9MB/s]" - } - }, - "7618819d67624f4ead9cd06fdedb5de5": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "a95f884036e648739e1c3b7d87269219": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_ac6f54e6ee604cd99208674531fcca17", - "placeholder": "​", - "style": "IPY_MODEL_3cbc9af0fb2148edb893e369649c2edc", - "value": "efficientAD_bottle.pth: 100%" - } - }, - "ac6f54e6ee604cd99208674531fcca17": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": 
"LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "c7723d0bb97b4a1a9c5bd4434e6a17ff": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "dddd26a508d64fd4bec206b4a05eb382": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "FloatProgressModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "success", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_7618819d67624f4ead9cd06fdedb5de5", - "max": 32265074, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_5dd006f38b314a88b573a79e4d4955fc", - "value": 32265074 - } - } - } - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/tutorials/resources/utils/efficient_ad_utils.py b/tutorials/resources/utils/efficient_ad_utils.py deleted file mode 100644 index cb0c69167..000000000 --- a/tutorials/resources/utils/efficient_ad_utils.py +++ /dev/null @@ -1,237 +0,0 @@ -# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== - -import os -import random -import shutil -import itertools -from tqdm import tqdm -import numpy as np -import torch -from torch.utils.data import DataLoader -from torchvision import transforms, datasets -from sklearn.metrics import roc_auc_score -import tifffile -from PIL import Image -import matplotlib.pyplot as plt -import matplotlib.cm as cm - - -# Global constants -IMAGE_SIZE = 256 -OUT_CHANNELS = 384 -SEED = 42 - -# Transform definitions -DEFAULT_TRANSFORM = transforms.Compose([ - transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)), - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) -]) - -TRANSFORM_AE = transforms.RandomChoice([ - transforms.ColorJitter(brightness=0.2), - transforms.ColorJitter(contrast=0.2), - transforms.ColorJitter(saturation=0.2) -]) - -class ImageFolderWithoutTarget(datasets.ImageFolder): - """Custom dataset that includes only images, no labels.""" - def __getitem__(self, index): - sample, _ = super().__getitem__(index) - return sample - -class ImageFolderWithPath(datasets.ImageFolder): - """Custom dataset that includes image paths along with images and labels.""" - def __getitem__(self, index): - path, target = self.samples[index] - sample, target = super().__getitem__(index) - return sample, target, path - -def infinite_dataloader(loader): - """Create an infinite dataloader that cycles through the dataset.""" - for data in itertools.cycle(loader): - yield data - -def train_transform(image): - """Apply transformations to the training images.""" - return DEFAULT_TRANSFORM(image), DEFAULT_TRANSFORM(TRANSFORM_AE(image)) - -def visualize_anomalies(unified_model, dataset_path, test_output_dir=None, q_st_start=None, q_st_end=None, q_ae_start=None, q_ae_end=None, desc='Running inference'): - """Visualize anomalies by overlaying heatmaps on the original images.""" - test_set = ImageFolderWithPath(os.path.join(dataset_path, 'bottle', 'test')) - images_to_display = random.sample(list(test_set), 10) # Randomly select 10 images to display - y_true, y_score = [], [] - - for image, target, path in tqdm(images_to_display, desc=desc): - orig_width, orig_height = image.size - image_tensor = DEFAULT_TRANSFORM(image)[None] # Add batch dimension - map_combined , _, _ = predict_combined(image_tensor, unified_model, q_st_start=None, q_st_end=None, q_ae_start=None, q_ae_end=None) - map_combined = torch.nn.functional.interpolate(map_combined, (orig_height, orig_width), mode='bilinear') - map_combined = map_combined[0, 0].detach().cpu().numpy() - - heatmap = cm.jet(map_combined) # Apply colormap - heatmap = np.uint8(cm.ScalarMappable(cmap='jet').to_rgba(map_combined) * 255) - heatmap_pil = Image.fromarray(heatmap, 'RGBA').convert('RGB') # Convert RGBA to RGB - image_pil = image.convert('RGB') # Ensure the original image is in RGB - - combined_image = Image.new('RGB', (orig_width * 2, orig_height)) - combined_image.paste(image_pil, (0, 0)) - combined_image.paste(heatmap_pil, (orig_width, 0)) - - defect_class = os.path.basename(os.path.dirname(path)) - y_true_image = 0 if defect_class == 'good' else 1 - y_score_image = np.max(map_combined) - y_true.append(y_true_image) - y_score.append(y_score_image) - - prediction_correct = 'Correct' if (y_score_image > 0.5) == y_true_image else 'Incorrect' - defect_status = 'Defect' if y_true_image == 1 else 'No Defect' - - plt.figure(figsize=(12, 6)) - plt.imshow(combined_image) - plt.title(f"Actual: {defect_status}, 
Prediction: {prediction_correct}") - plt.axis('off') - plt.show() - - auc = roc_auc_score(y_true=y_true, y_score=y_score) - return auc * 100 - -def train_ad(train_steps, dataset_path, sub_dataset, autoencoder, teacher, student): - """Train the anomaly detection model.""" - torch.manual_seed(SEED) - np.random.seed(SEED) - random.seed(SEED) - - train_output_dir = os.path.join('output', 'trainings', 'mvtec_ad', sub_dataset) - test_output_dir = os.path.join('output', 'anomaly_maps', 'mvtec_ad', sub_dataset, 'test') - shutil.rmtree(train_output_dir, ignore_errors=True) - shutil.rmtree(test_output_dir, ignore_errors=True) - os.makedirs(train_output_dir) - os.makedirs(test_output_dir) - - full_train_set = ImageFolderWithoutTarget(os.path.join(dataset_path, sub_dataset, 'train'), transform=transforms.Lambda(train_transform)) - train_size = int(0.9 * len(full_train_set)) - validation_size = len(full_train_set) - train_size - train_set, validation_set = torch.utils.data.random_split(full_train_set, [train_size, validation_size], generator=torch.Generator().manual_seed(SEED)) - - train_loader = DataLoader(train_set, batch_size=1, shuffle=True, num_workers=4, pin_memory=True) - validation_loader = DataLoader(validation_set, batch_size=1) - train_loader_infinite = infinite_dataloader(train_loader) - - teacher.eval() - student.train() - autoencoder.train() - - if torch.cuda.is_available(): - teacher.cuda() - student.cuda() - autoencoder.cuda() - - teacher_mean, teacher_std = teacher_normalization(teacher, train_loader) - optimizer = torch.optim.Adam(itertools.chain(student.parameters(), autoencoder.parameters()), lr=1e-4, weight_decay=1e-5) - scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(0.95 * train_steps), gamma=0.1) - - for iteration, (image_st, image_ae) in enumerate(train_loader_infinite, start=1): - if torch.cuda.is_available(): - image_st = image_st.cuda() - image_ae = image_ae.cuda() - - student_output_st = student(image_st)[:, :OUT_CHANNELS] - loss_st = torch.mean((teacher(image_st) - student_output_st) ** 2) - - ae_output = autoencoder(image_ae) - loss_ae = torch.mean((teacher(image_ae) - ae_output) ** 2) - - loss_total = loss_st + loss_ae - optimizer.zero_grad() - loss_total.backward() - optimizer.step() - scheduler.step() - - if iteration % 10 == 0: - tqdm.write(f"Step {iteration}, Loss: {loss_total.item():.4f}") - - if iteration >= train_steps: - break - - torch.save(teacher, os.path.join(train_output_dir, 'teacher_final.pth')) - torch.save(student, os.path.join(train_output_dir, 'student_final.pth')) - torch.save(autoencoder, os.path.join(train_output_dir, 'autoencoder_final.pth')) - -@torch.no_grad() -def predict(image, teacher, student, autoencoder, teacher_mean, teacher_std, - q_st_start=None, q_st_end=None, q_ae_start=None, q_ae_end=None): - """Predict using the trained models and calculate anomaly maps.""" - teacher_output = teacher(image) - teacher_output = (teacher_output - teacher_mean) / teacher_std - student_output = student(image) - autoencoder_output = autoencoder(image) - map_st = torch.mean((teacher_output - student_output[:, :OUT_CHANNELS])**2, dim=1, keepdim=True) - map_ae = torch.mean((autoencoder_output - student_output[:, OUT_CHANNELS:])**2, dim=1, keepdim=True) - if q_st_start is not None: - map_st = 0.1 * (map_st - q_st_start) / (q_st_end - q_st_start) - if q_ae_start is not None: - map_ae = 0.1 * (map_ae - q_ae_start) / (q_ae_end - q_ae_start) - map_combined = 0.5 * map_st + 0.5 * map_ae - return map_combined, map_st, map_ae - -@torch.no_grad() 
-def predict_combined(image, unified_model, q_st_start=None, q_st_end=None, q_ae_start=None, q_ae_end=None): - """Predict using the trained models and calculate anomaly maps.""" - map_st, map_ae = unified_model(image) - map_st = torch.mean(map_st, dim=1, keepdim=True) - map_ae = torch.mean(map_ae, dim=1, keepdim=True) - if q_st_start is not None: - map_st = 0.1 * (map_st - q_st_start) / (q_st_end - q_st_start) - if q_ae_start is not None: - map_ae = 0.1 * (map_ae - q_ae_start) / (q_ae_end - q_ae_start) - map_combined = 0.5 * map_st + 0.5 * map_ae - return map_combined, map_st, map_ae - -@torch.no_grad() -def map_normalization(validation_loader, teacher, student, autoencoder, teacher_mean, teacher_std, desc='Map normalization'): - """Normalize the anomaly maps generated by the models.""" - maps_st, maps_ae = [], [] - for image, _ in tqdm(validation_loader, desc=desc): - if torch.cuda.is_available(): - image = image.cuda() - map_combined, map_st, map_ae = predict(image, teacher, student, autoencoder, teacher_mean, teacher_std) - maps_st.append(map_st) - maps_ae.append(map_ae) - maps_st = torch.cat(maps_st) - maps_ae = torch.cat(maps_ae) - q_st_start = torch.quantile(maps_st, 0.9) - q_st_end = torch.quantile(maps_st, 0.995) - q_ae_start = torch.quantile(maps_ae, 0.9) - q_ae_end = torch.quantile(maps_ae, 0.995) - return q_st_start, q_st_end, q_ae_start, q_ae_end - -@torch.no_grad() -def teacher_normalization(teacher, train_loader): - """Calculate the normalization parameters for the teacher model outputs.""" - mean_outputs, mean_distances = [], [] - for train_image, _ in tqdm(train_loader, desc='Computing normalization parameters'): - if torch.cuda.is_available(): - train_image = train_image.cuda() - teacher_output = teacher(train_image) - mean_output = torch.mean(teacher_output, dim=[0, 2, 3]) - mean_outputs.append(mean_output) - distance = (teacher_output - mean_output[None, :, None, None]) ** 2 - mean_distance = torch.mean(distance, dim=[0, 2, 3]) - mean_distances.append(mean_distance) - channel_mean = torch.mean(torch.stack(mean_outputs), dim=0)[None, :, None, None] - channel_std = torch.sqrt(torch.mean(torch.stack(mean_distances), dim=0))[None, :, None, None] - return channel_mean, channel_std \ No newline at end of file
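
Taken together, the utilities in this file support the flow used by the notebook above: compute the teacher's per-channel statistics on normal images, derive the quantiles used to rescale the two anomaly maps, then score a test image. A hedged usage sketch follows (the dataset path, checkpoints, and test file name are placeholders; the function signatures match the file as deleted here, and untrained submodels are used purely for illustration):

```python
# Hedged end-to-end sketch of how these utilities are used together.
import os
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image

from tutorials.resources.utils.efficient_ad_utils import (
    ImageFolderWithoutTarget, train_transform, DEFAULT_TRANSFORM,
    teacher_normalization, map_normalization, predict)
from tutorials.mct_model_garden.models_pytorch.Efficient_Anomaly_Det.efficient_ad import (
    get_pdn_small, get_autoencoder)

OUT_CHANNELS = 384
dataset_path = './mvtec_anomaly_detection'  # assumed local MVTec AD copy

# Sub-models; in practice their weights are loaded from trained checkpoints.
teacher = get_pdn_small(OUT_CHANNELS).eval()
student = get_pdn_small(2 * OUT_CHANNELS).eval()
autoencoder = get_autoencoder(OUT_CHANNELS).eval()

# Loader over normal training images; train_transform yields two views per image.
train_set = ImageFolderWithoutTarget(
    os.path.join(dataset_path, 'bottle', 'train'),
    transform=transforms.Lambda(train_transform))
train_loader = DataLoader(train_set, batch_size=1, shuffle=True)

# 1) Per-channel mean and std of the teacher's output on normal images.
teacher_mean, teacher_std = teacher_normalization(teacher, train_loader)

# 2) Quantiles used to rescale the student-teacher and autoencoder-student maps.
q_st_start, q_st_end, q_ae_start, q_ae_end = map_normalization(
    train_loader, teacher, student, autoencoder, teacher_mean, teacher_std)

# 3) Anomaly map and image-level score for a single test image (file name is illustrative).
image = Image.open(os.path.join(dataset_path, 'bottle', 'test', 'good', '000.png')).convert('RGB')
map_combined, _, _ = predict(DEFAULT_TRANSFORM(image)[None], teacher, student, autoencoder,
                             teacher_mean, teacher_std,
                             q_st_start, q_st_end, q_ae_start, q_ae_end)
print('image-level anomaly score:', map_combined.max().item())
```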