

MarcoParola/pytorch-sidu


pytorch-sidu


SIDU: SImilarity Difference and Uniqueness method for explainable AI, as proposed in the original paper.

  • PyTorch implementation of the SIDU method.
  • Simple interface for loading pretrained models by specifying a weights name string (e.g. 'ResNet34_Weights.IMAGENET1K_V1')
  • Clear interface for generating saliency maps
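To make the name concrete, here is a minimal sketch of the scoring step as we read it from the SIDU paper, written in plain Python over lists rather than tensors. Roughly: each feature map of the target layer becomes a mask, the model is run on the image multiplied by each mask, and each mask is weighted by its similarity difference (how little the prediction p_k moves away from the original prediction p, exp(-||p - p_k||^2 / (2 sigma^2))) times its uniqueness (the summed distance from p_k to the predictions of all other masks). The function names, the list-based layout and the default sigma = 0.25 are our assumptions; this is not the package's code.

```python
import math


def similarity_difference(p_orig, p_masked, sigma=0.25):
    """exp(-||p_orig - p_masked||^2 / (2 sigma^2)); close to 1 when masking barely changes the prediction.

    sigma = 0.25 is an assumed default.
    """
    sq_dist = sum((a - b) ** 2 for a, b in zip(p_orig, p_masked))
    return math.exp(-sq_dist / (2 * sigma ** 2))


def uniqueness(k, masked_preds):
    """Sum of Euclidean distances between masked prediction k and every other masked prediction."""
    return sum(
        math.dist(masked_preds[k], other)
        for j, other in enumerate(masked_preds)
        if j != k
    )


def sidu_weights(p_orig, masked_preds, sigma=0.25):
    """Weight of each mask = similarity difference * uniqueness."""
    return [
        similarity_difference(p_orig, p, sigma) * uniqueness(k, masked_preds)
        for k, p in enumerate(masked_preds)
    ]


def combine_masks(masks, weights):
    """Saliency map as the weighted average of flattened masks: (1/N) * sum_k w_k * M_k."""
    n = len(masks)
    return [
        sum(w * m[i] for m, w in zip(masks, weights)) / n
        for i in range(len(masks[0]))
    ]
```

In this toy form, a mask whose masked image keeps the original prediction and whose prediction differs from the other masks receives the largest weight, which is the intuition behind the method's name.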

Some examples generated with VGG19 on the Caltech-101 dataset:

[Figure: example saliency maps (VGG19, Caltech-101)]

Installation

pip install pytorch-sidu

Usage

Load one of the pretrained models available in PyTorch:

from pytorch_sidu import sidu
from pytorch_sidu.utils.utils import load_torch_model_by_string

model_name = 'ResNet34_Weights.IMAGENET1K_V1'
model = load_torch_model_by_string(model_name)
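The README does not show how load_torch_model_by_string resolves a name such as 'ResNet34_Weights.IMAGENET1K_V1'. One plausible approach, sketched below with hypothetical helper names, uses torchvision's model registry (torchvision.models.get_weight and torchvision.models.get_model, available since torchvision 0.14) together with the naming convention that a weights enum '<Arch>_Weights' pairs with the builder function '<arch>'. That convention holds for common models such as ResNet and VGG but is an assumption in general; this is an illustration, not the package's implementation.

```python
def weights_name_to_builder(weights_name):
    """Hypothetical helper: 'ResNet34_Weights.IMAGENET1K_V1' -> 'resnet34'.

    Assumes torchvision's '<Arch>_Weights.<TAG>' naming, where the model
    builder is the lower-cased '<Arch>' part.
    """
    enum_name, sep, tag = weights_name.partition('.')
    if not sep or not tag or not enum_name.endswith('_Weights'):
        raise ValueError(f'unexpected weights name: {weights_name!r}')
    return enum_name[: -len('_Weights')].lower()


def load_model_from_weights_name(weights_name):
    """Hypothetical loader built on torchvision's registry (torchvision >= 0.14)."""
    import torchvision  # imported lazily so the name parsing above needs no dependencies

    weights = torchvision.models.get_weight(weights_name)  # e.g. ResNet34_Weights.IMAGENET1K_V1
    model = torchvision.models.get_model(weights_name_to_builder(weights_name), weights=weights)
    return model.eval()  # inference mode: disables dropout, uses BatchNorm running statistics
```

Calling load_model_from_weights_name('ResNet34_Weights.IMAGENET1K_V1') downloads the weights on first use.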

After instantiating your model, generate saliency maps for a batch from a DataLoader:

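data_loader = ...  # your torch.utils.data.DataLoader yielding (image, label) batches
target_layer = 'layer4.2.conv2'  # convolutional layer whose feature maps SIDU uses to build its masks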
image, _ = next(iter(data_loader))
saliency_maps = sidu(model, target_layer, image)
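target_layer is given as a dotted module path. Assuming sidu accepts the same names that torch.nn.Module.named_modules() yields (the format 'layer4.2.conv2' follows), a small helper like the hypothetical conv_layer_names below lists the candidate convolutional layers of any model.

```python
import torch.nn as nn


def conv_layer_names(model: nn.Module) -> list[str]:
    """Return the dotted names of all Conv2d submodules, in registration order.

    These are the strings one would pass as `target_layer`.
    """
    return [
        name
        for name, module in model.named_modules()
        if isinstance(module, nn.Conv2d)
    ]
```

For torchvision's ResNet-34, the last entry of conv_layer_names(model) is 'layer4.2.conv2', the layer used in the examples here. Since SIDU builds its masks from the feature maps of a late convolutional layer, the deepest one is a natural default.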

A complete example on CIFAR-10

import torch
import torchvision
from matplotlib import pyplot as plt
from pytorch_sidu import sidu
from pytorch_sidu.utils.utils import load_torch_model_by_string


transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),  # upscale 32x32 CIFAR-10 images to the ImageNet input size
    torchvision.transforms.ToTensor(),          # PIL image -> float tensor in [0, 1]
])
data_loader = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR10(root='./data', download=True, transform=transform), batch_size=2)

target_layer = 'layer4.2.conv2'
model_name = 'ResNet34_Weights.IMAGENET1K_V1'
model = load_torch_model_by_string(model_name)

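# Iterates over the whole CIFAR-10 training set (one figure per image); break out early to inspect only a few batches.
for image, _ in data_loader:
    saliency_maps = sidu(model, target_layer, image)
    image, saliency_maps = image.cpu(), saliency_maps.cpu()  # move to CPU for matplotlib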

    for j in range(len(image)):
        plt.figure(figsize=(5, 2.5))
        plt.subplot(1, 2, 1)
        plt.imshow(image[j].permute(1, 2, 0))
        plt.axis('off')
        plt.subplot(1, 2, 2)
        plt.imshow(image[j].permute(1, 2, 0))
        plt.imshow(saliency_maps[j].squeeze().detach().numpy(), cmap='jet', alpha=0.4)
        plt.axis('off')
        plt.show()

Upcoming features:

  • integration of XAI metrics
  • support for both single images and dataloaders
  • a device flag in the sidu function to allow device selection
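Until single-image support lands, a single image can still be processed by giving it a batch dimension, assuming sidu accepts any (N, C, H, W) batch as the DataLoader examples suggest. The helper name as_batch is hypothetical.

```python
import torch


def as_batch(image: torch.Tensor) -> torch.Tensor:
    """Return an (N, C, H, W) tensor, adding a leading batch dim to a single (C, H, W) image."""
    if image.dim() == 3:
        return image.unsqueeze(0)  # (C, H, W) -> (1, C, H, W)
    if image.dim() == 4:
        return image  # already a batch
    raise ValueError(f'expected a 3-D or 4-D tensor, got shape {tuple(image.shape)}')
```

Usage would then look like saliency_map = sidu(model, target_layer, as_batch(image))[0].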