SIDU: SImilarity Difference and Uniqueness method for explainable AI from the original paper
- PyTorch implementation of the SIDU method.
- Simple interface for loading pretrained models by specifying their weights as a string name (e.g. 'ResNet34_Weights.IMAGENET1K_V1')
- Clear interface for generating saliency maps
Some examples made with VGG19 on the Caltech-101 dataset:
pip install pytorch-sidu
Load a model from the pretrained ones available in PyTorch:
from pytorch_sidu import sidu
from pytorch_sidu.utils.utils import load_torch_model_by_string
model_name = 'ResNet34_Weights.IMAGENET1K_V1'
model = load_torch_model_by_string(model_name)
After instantiating your model, generate saliency maps from a DataLoader:
data_loader = <your dataloader>
target_layer = 'layer4.2.conv2'
image, _ = next(iter(data_loader))
saliency_maps = sidu(model, target_layer, image)
import torch
import torchvision
from matplotlib import pyplot as plt
from pytorch_sidu import sidu
from pytorch_sidu.utils.utils import load_torch_model_by_string

# Resize CIFAR-10 images to 224x224 and convert them to (C, H, W) tensors in [0, 1]
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
])
data_loader = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR10(root='./data', download=True, transform=transform), batch_size=2)

# Layer whose feature maps SIDU uses, and the pretrained model to explain
target_layer = 'layer4.2.conv2'
model_name = 'ResNet34_Weights.IMAGENET1K_V1'
model = load_torch_model_by_string(model_name)

for image, _ in data_loader:
    # One saliency map per image in the batch
    saliency_maps = sidu(model, target_layer, image)
    image, saliency_maps = image.cpu(), saliency_maps.cpu()
    for j in range(len(image)):
        plt.figure(figsize=(5, 2.5))
        # Left: the original image, permuted from (C, H, W) to (H, W, C) for matplotlib
        plt.subplot(1, 2, 1)
        plt.imshow(image[j].permute(1, 2, 0))
        plt.axis('off')
        # Right: the saliency map overlaid on the image
        plt.subplot(1, 2, 2)
        plt.imshow(image[j].permute(1, 2, 0))
        plt.imshow(saliency_maps[j].squeeze().detach().numpy(), cmap='jet', alpha=0.4)
        plt.axis('off')
        plt.show()
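The loop above hands each saliency map straight to imshow. If a map came back at a different spatial size than the image, or on an arbitrary scale (the output shape and range of sidu are not documented here, so this is an assumption), a small helper can resize and normalize it before overlaying. prepare_overlay below is hypothetical, not part of pytorch-sidu; it uses bilinear upsampling and min-max scaling to [0, 1].

```python
from typing import Tuple

import torch
import torch.nn.functional as F


def prepare_overlay(saliency: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
    """Hypothetical helper: return one saliency map as an (H, W) tensor in [0, 1]."""
    # Drop singleton dims, e.g. (1, h, w) -> (h, w).
    sal = saliency.detach().float().cpu().squeeze()
    if sal.dim() != 2:
        raise ValueError(f"expected a single 2-D map, got shape {tuple(saliency.shape)}")
    # Bilinearly resize to the image size if needed (interpolate expects N, C, H, W).
    if tuple(sal.shape) != tuple(size):
        sal = F.interpolate(sal[None, None], size=tuple(size), mode="bilinear",
                            align_corners=False)[0, 0]
    # Min-max scale; a constant map becomes all zeros instead of dividing by zero.
    lo, hi = sal.min(), sal.max()
    if hi > lo:
        return (sal - lo) / (hi - lo)
    return torch.zeros_like(sal)


# Example: a coarse 7x7 map upsampled to a 224x224 image.
coarse = torch.rand(1, 7, 7)
overlay = prepare_overlay(coarse, (224, 224))
print(tuple(overlay.shape), float(overlay.min()), float(overlay.max()))
```

Inside the plotting loop it would replace the direct call, e.g. plt.imshow(prepare_overlay(saliency_maps[j], tuple(image.shape[-2:])).numpy(), cmap='jet', alpha=0.4). Since imshow already rescales the colormap per map, the normalization mainly helps when comparing or thresholding maps across images.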
- integrate XAI metrics
- make methods work on both single images and DataLoaders
- add a device flag to the sidu function to allow device selection