Skip to content

Commit

Permalink
Merge pull request #100 from jain18ayush/documentation/main-demo
Browse files Browse the repository at this point in the history
Added documentation for main demo as listed in Issue facebookresearch#88
  • Loading branch information
soniajoseph authored Jul 2, 2024
2 parents 9d31d62 + 4a7ce2b commit b0ad5f9
Show file tree
Hide file tree
Showing 8 changed files with 199 additions and 25 deletions.
14 changes: 10 additions & 4 deletions src/vit_prisma/models/base_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,10 +228,16 @@ def run_with_cache(
else:
return out, cache_dict

def tokens_to_residual_directions(self, labels):
'''
Logit-lens related functions not implemented; see how we can implement a vision equivalent.
'''
def tokens_to_residual_directions(self, labels: torch.Tensor) -> torch.Tensor:
"""
Computes the residual directions for given labels.
Args:
labels (torch.Tensor): A 1D tensor of label indices with shape (batch_size,).
Returns:
torch.Tensor: The residual directions with shape (batch_size, d_model).
"""

answer_residual_directions = self.head.W_H[:,labels]
answer_residual_directions = einops.rearrange(
Expand Down
23 changes: 21 additions & 2 deletions src/vit_prisma/prisma_tools/hooked_root_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,28 @@ def run_with_hooks(
*model_args,
fwd_hooks: List[Tuple[Union[str, Callable], Callable]] = [],
bwd_hooks: List[Tuple[Union[str, Callable], Callable]] = [],
reset_hooks_end = True,
clear_contexts = False,
reset_hooks_end: bool = True,
clear_contexts: bool = False,
):
"""
Executes the model with specified forward and backward hooks.
Args:
model_args: The arguments to be passed to the model's forward method.
fwd_hooks (List[Tuple[Union[str, Callable], Callable]], optional): A list of tuples specifying forward hooks.
Each tuple contains a string (layer name) or a callable to match layers, and a callable hook function.
bwd_hooks (List[Tuple[Union[str, Callable], Callable]], optional): A list of tuples specifying backward hooks.
Each tuple contains a string (layer name) or a callable to match layers, and a callable hook function.
reset_hooks_end (bool, optional): Whether to reset the hooks at the end of the run. Default is True.
clear_contexts (bool, optional): Whether to clear contexts at the end of the run. Default is False.
Returns:
The output of the model's forward method.
Note:
No exception is raised for this case: if backward hooks are provided and reset_hooks_end is True, a warning is logged via logging.warning, because the hooks will be reset before a backward pass can occur.
"""

if len(bwd_hooks) > 0 and reset_hooks_end:
logging.warning(
"WARNING: Hooks will be reset at the end of run_with_hooks. This removes the backward hooks before a backward pass can occur."
Expand Down
40 changes: 38 additions & 2 deletions src/vit_prisma/prisma_tools/logit_lens.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,28 @@
import numpy as np
import torch
from collections import defaultdict
from typing import Union, Optional, Dict, List, Tuple

from vit_prisma.utils.data_utils.imagenet_dict import IMAGENET_DICT
from vit_prisma.utils.data_utils.imagenet_utils import imagenet_index_from_word


def get_patch_logit_directions(cache, all_answers, incl_mid=False, return_labels=True):
def get_patch_logit_directions(cache, all_answers: torch.Tensor, incl_mid: bool = False, return_labels: bool = True) -> tuple:
"""
Computes the patch logit directions based on accumulated residuals from the cache.
Args:
cache: An object that provides methods to access and process model residuals.
all_answers (torch.Tensor): A tensor containing all possible answers with shape (num_answers, d_model).
incl_mid (bool, optional): Whether to include intermediate layers. Default is False.
return_labels (bool, optional): Whether to return labels along with the result. Default is True.
Returns:
tuple: A tuple containing:
- result (torch.Tensor): The computed logit directions with shape (batch_size, num_patches, num_labels, num_answers).
- labels: Labels associated with the accumulated residuals, if `return_labels` is True.
"""

accumulated_residual, labels = cache.accumulated_resid(
layer=-1, incl_mid=incl_mid, return_labels=True
)
Expand All @@ -28,7 +44,27 @@ def get_patch_logit_directions(cache, all_answers, incl_mid=False, return_labels
result = result.permute(1, 2, 0, 3)
return result, labels

def get_patch_logit_dictionary(patch_logit_directions, batch_idx=0, rank_label=None):
def get_patch_logit_dictionary(
patch_logit_directions: Union[torch.Tensor, Tuple[torch.Tensor, ...]],
batch_idx: int = 0,
rank_label: Optional[str] = None
) -> Dict[int, List[Tuple[float, str, int, Optional[int]]]]:
"""
Constructs a dictionary of patch logit predictions for a given batch index.
Args:
patch_logit_directions (Union[torch.Tensor, Tuple[torch.Tensor, ...]]): A tensor or a tuple of tensors
containing the logit directions with shape
(batch_size, num_patches, num_labels, num_answers).
batch_idx (int, optional): The index of the batch to process. Default is 0.
rank_label (Optional[str], optional): A label to rank against the predictions. Default is None.
Returns:
Dict[int, List[Tuple[float, str, int, Optional[int]]]]: A dictionary where each key is a patch index and each value is a list of tuples.
Each tuple contains the logit, predicted class name, predicted index,
and optionally the rank of the rank_label.
"""

patch_dictionary = defaultdict(list)
# if tuple, get first entry
if isinstance(patch_logit_directions, tuple):
Expand Down
17 changes: 15 additions & 2 deletions src/vit_prisma/utils/data_utils/imagenet_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,20 @@

from vit_prisma.utils.data_utils.imagenet_dict import IMAGENET_DICT

def imagenet_index_from_word(search_term):
def imagenet_index_from_word(search_term: str) -> int:
"""
Finds the ImageNet index corresponding to a search term.
Args:
search_term (str): The search term to look up in the ImageNet dictionary.
Returns:
int: The index corresponding to the search term in the ImageNet dictionary.
Raises:
ValueError: If the search term is not found in the ImageNet dictionary.
"""

# Convert the search term to lowercase to ensure case-insensitive matching
search_term = search_term.lower()

Expand All @@ -11,4 +24,4 @@ def imagenet_index_from_word(search_term):
return key # Return the key directly once found

# If the loop completes without returning, the term was not found; raise an exception
raise ValueError(f"'{search_term}' not found in IMAGENET_DICT.")
raise ValueError(f"'{search_term}' not found in IMAGENET_DICT.")
14 changes: 13 additions & 1 deletion src/vit_prisma/utils/prisma_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,19 @@
from vit_prisma.utils.data_utils.imagenet_utils import imagenet_index_from_word


def test_prompt(example_data_point, model, example_answer=None, top_k=10):
def test_prompt(example_data_point: torch.Tensor, model: Any, example_answer: Optional[str] = None, top_k: int = 10) -> None:
"""
Evaluates a model's predictions on a given data point and prints the top-k predicted labels along with their logits and probabilities.
Args:
example_data_point (torch.Tensor): The input data point to be evaluated by the model.
model (Any): The model used for generating predictions.
example_answer (Optional[str], optional): The correct label for the data point, if available. Default is None.
top_k (int, optional): The number of top predictions to display. Default is 10.
Returns:
None
"""

logits = model(example_data_point.unsqueeze(0))
probs = logits.softmax(dim=-1)
Expand Down
68 changes: 60 additions & 8 deletions src/vit_prisma/visualization/patch_level_logit_lens.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,44 @@
import plotly.graph_objects as go
import plotly.express as px

from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast

import torch

from vit_prisma.utils.data_utils.imagenet_emoji import IMAGENET_EMOJI

def display_grid_on_image_with_heatmap(image, patch_dictionary, patch_size=32,
layer_idx=-1,
imagenet_class_to_emoji=IMAGENET_EMOJI,
emoji_font_size=30,
heatmap_mode='logit_values',
alpha_color=.6,
return_graph=False):
def display_grid_on_image_with_heatmap(
image: Union[np.ndarray, torch.Tensor],
patch_dictionary: Dict[int, List[Tuple[float, str, int, Optional[int]]]],
patch_size: int = 32,
layer_idx: int = -1,
imagenet_class_to_emoji: Optional[Dict[int, str]] = None,
emoji_font_size: int = 30,
heatmap_mode: str = 'logit_values',
alpha_color: float = 0.6,
return_graph: bool = False
) -> Optional[go.Figure]:
"""
Displays a grid overlay on the image with a heatmap and optional emoji annotations.
Args:
image (Union[np.ndarray, torch.Tensor]): The input image, either as a numpy array or a PyTorch tensor.
patch_dictionary (Dict[int, List[Tuple[float, str, int, Optional[int]]]]): A dictionary where each key is a patch index and each value is a list of tuples.
Each tuple contains the logit, predicted class name, predicted index, and optionally the rank of the rank_label.
patch_size (int, optional): The size of each patch in the grid. Default is 32.
layer_idx (int, optional): The layer index to use from the patch dictionary. Default is -1.
imagenet_class_to_emoji (Optional[Dict[int, str]], optional): A dictionary mapping ImageNet class indices to emojis. Default is None.
emoji_font_size (int, optional): The size of the emojis in the annotations. Default is 30.
heatmap_mode (str, optional): The mode for the heatmap. Options are 'logit_values' or 'emoji_colors'. Default is 'logit_values'.
alpha_color (float, optional): The opacity of the heatmap overlay. Default is 0.6.
return_graph (bool, optional): If True, the function returns the plotly figure object. If False, it displays the heatmap. Default is False.
Returns:
Optional[go.Figure]: The plotly figure object if return_graph is True, otherwise None.
Raises:
ValueError: If `heatmap_mode` is not one of the valid options ('logit_values', 'emoji_colors').
"""

valid_heatmap_modes = ['logit_values', 'emoji_colors']
if heatmap_mode not in valid_heatmap_modes:
Expand Down Expand Up @@ -101,7 +128,32 @@ def display_grid_on_image_with_heatmap(image, patch_dictionary, patch_size=32,

# Animal logit lens

def display_patch_logit_lens(patch_dictionary, width=1000, height=1200, emoji_size=26, return_graph=False, show_colorbar=True, labels=None):
def display_patch_logit_lens(
patch_dictionary: Dict[int, List[Tuple[float, str, int, Optional[int]]]],
width: int = 1000,
height: int = 1200,
emoji_size: int = 26,
return_graph: bool = False,
show_colorbar: bool = True,
labels: Optional[List[str]] = None
) -> Optional[go.Figure]:
"""
Displays an interactive heatmap of patch logit values with optional emoji annotations.
Args:
patch_dictionary (Dict[int, List[Tuple[float, str, int, Optional[int]]]]): A dictionary where each key is a patch index and each value is a list of tuples.
Each tuple contains the logit, predicted class name, predicted index, and optionally the rank of the rank_label.
width (int, optional): The width of the heatmap. Default is 1000.
height (int, optional): The height of the heatmap. Default is 1200.
emoji_size (int, optional): The size of the emojis in the annotations. Default is 26.
return_graph (bool, optional): If True, the function returns the plotly figure object. If False, it displays the heatmap. Default is False.
show_colorbar (bool, optional): If True, a colorbar is displayed. Default is True.
labels (Optional[List[str]], optional): A list of labels for the hover text. Default is None.
Returns:
Optional[go.Figure]: The plotly figure object if return_graph is True, otherwise None.
"""

num_patches = len(patch_dictionary)

# Assuming data_array_formatted is correctly shaped according to your data structure
Expand Down
26 changes: 24 additions & 2 deletions src/vit_prisma/visualization/visualize_attention_js.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import torch
import os
from typing import List
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast

from jinja2 import Template

def convert_to_3_channels(image):
Expand Down Expand Up @@ -64,8 +66,28 @@ def __init__(self, attn_head, image, name="No Name", cls_token=True):



def plot_javascript(list_of_attn_heads, list_of_images, list_of_names=None, ATTN_SCALING=8, cls_token=True):

def plot_javascript(
list_of_attn_heads: Union[torch.Tensor, List[np.ndarray]],
list_of_images: Union[List[np.ndarray], np.ndarray],
list_of_names: Optional[Union[torch.Tensor, List[str]]] = None,
ATTN_SCALING: int = 8,
cls_token: bool = True
) -> str:
"""
Generates HTML and JavaScript code to visualize attention heads with corresponding images.
Args:
list_of_attn_heads (Union[torch.Tensor, List[np.ndarray]]): A tensor of shape (num_heads, num_patches, num_patches)
or a list of numpy arrays with the same shape.
list_of_images (Union[List[np.ndarray], np.ndarray]): A list of images or a single image array, each image with shape
(height, width, channels).
list_of_names (Optional[Union[torch.Tensor, List[str]]], optional): A tensor or a list of names for the attention heads. Default is None.
ATTN_SCALING (int, optional): Scaling factor for attention visualization. Default is 8.
cls_token (bool, optional): Whether to include the CLS token. Default is True.
Returns:
str: Generated HTML and JavaScript code for visualizing the attention heads with corresponding images.
"""
# if list of attn heads is tensor
if type(list_of_attn_heads) == torch.Tensor:
list_of_attn_heads = [np.array(list_of_attn_heads[i]) for i in range(list_of_attn_heads.shape[0])]
Expand Down
22 changes: 18 additions & 4 deletions src/vit_prisma/visualization/visualize_image.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,25 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast

import torch
import numpy as np
import matplotlib.pyplot as plt
import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.figure


def display_grid_on_image(image: Union[np.ndarray, torch.Tensor], patch_size: int = 32, return_plot: bool = False) -> Optional[matplotlib.figure.Figure]:
"""
Separates an image into a grid of patches and overlays the grid on the image.
Args:
image (Union[np.ndarray, torch.Tensor]): The input image, either as a numpy array or a PyTorch tensor.
Dimensions of (H, W, C) if numpy or (C, H, W) if tensor
patch_size (int, optional): The size of each patch in the grid. Default is 32.
return_plot (bool, optional): If True, the function will return the plot figure. If False, it will display the plot. Default is False.
Returns:
matplotlib.figure.Figure or None: If return_plot is True, returns the matplotlib figure object. Otherwise, displays the image with the grid overlay.
"""

def display_grid_on_image(image, patch_size=32, return_plot=False):
if isinstance(image, torch.Tensor):
image = image.detach().numpy().transpose(1, 2, 0)
if image.shape[0] != 224:
Expand Down

0 comments on commit b0ad5f9

Please sign in to comment.