Skip to content

Commit

Permalink
Add type hinting to explorer.py (ultralytics#7388)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Burhan-Q and pre-commit-ci[bot] authored Jan 8, 2024
1 parent e19398a commit d0562d7
Showing 1 changed file with 41 additions and 21 deletions.
62 changes: 41 additions & 21 deletions ultralytics/data/explorer/explorer.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,31 @@
from io import BytesIO
from pathlib import Path
from typing import List
from typing import Any, List, Tuple, Union

import cv2
import numpy as np
import torch
from matplotlib import pyplot as plt
from pandas import DataFrame
from PIL import Image
from tqdm import tqdm

from ultralytics.data.augment import Format
from ultralytics.data.dataset import YOLODataset
from ultralytics.data.utils import check_det_dataset
from ultralytics.models.yolo.model import YOLO
from ultralytics.utils import LOGGER, checks
from ultralytics.utils import LOGGER, IterableSimpleNamespace, checks

from .utils import get_sim_index_schema, get_table_schema, plot_similar_images, sanitize_batch


class ExplorerDataset(YOLODataset):

def __init__(self, *args, data=None, **kwargs):
def __init__(self, *args, data: dict = None, **kwargs) -> None:
super().__init__(*args, data=data, **kwargs)

# NOTE: Load the image directly without any resize operations.
def load_image(self, i):
def load_image(self, i: int) -> Union[Tuple[np.ndarray, Tuple[int, int], Tuple[int, int]], Tuple[None, None, None]]:
"""Loads 1 image from dataset index 'i', returns (im, resized hw)."""
im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]
if im is None: # not cached in RAM
Expand All @@ -39,7 +40,7 @@ def load_image(self, i):

return self.ims[i], self.im_hw0[i], self.im_hw[i]

def build_transforms(self, hyp=None):
def build_transforms(self, hyp: IterableSimpleNamespace = None):
return Format(
bbox_format='xyxy',
normalize=False,
Expand All @@ -53,7 +54,10 @@ def build_transforms(self, hyp=None):

class Explorer:

def __init__(self, data='coco128.yaml', model='yolov8n.pt', uri='~/ultralytics/explorer') -> None:
def __init__(self,
data: Union[str, Path] = 'coco128.yaml',
model: str = 'yolov8n.pt',
uri: str = '~/ultralytics/explorer') -> None:
checks.check_requirements(['lancedb', 'duckdb'])
import lancedb

Expand All @@ -68,7 +72,7 @@ def __init__(self, data='coco128.yaml', model='yolov8n.pt', uri='~/ultralytics/e
self.table = None
self.progress = 0

def create_embeddings_table(self, force=False, split='train'):
def create_embeddings_table(self, force: bool = False, split: str = 'train') -> None:
"""
Create LanceDB table containing the embeddings of the images in the dataset. The table will be reused if it
already exists. Pass force=True to overwrite the existing table.
Expand Down Expand Up @@ -118,7 +122,7 @@ def create_embeddings_table(self, force=False, split='train'):

self.table = table

def _yield_batches(self, dataset, data_info, model, exclude_keys: List):
def _yield_batches(self, dataset: ExplorerDataset, data_info: dict, model: YOLO, exclude_keys: List[str]):
# Implement Batching
for i in tqdm(range(len(dataset))):
self.progress = float(i + 1) / len(dataset)
Expand All @@ -129,7 +133,9 @@ def _yield_batches(self, dataset, data_info, model, exclude_keys: List):
batch['vector'] = model.embed(batch['im_file'], verbose=False)[0].detach().tolist()
yield [batch]

def query(self, imgs=None, limit=25):
def query(self,
imgs: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
limit: int = 25) -> Any: # pyarrow.Table
"""
Query the table for similar images. Accepts a single image or a list of images.
Expand Down Expand Up @@ -162,7 +168,9 @@ def query(self, imgs=None, limit=25):
embeds = torch.mean(torch.stack(embeds), 0).cpu().numpy() if len(embeds) > 1 else embeds[0].cpu().numpy()
return self.table.search(embeds).limit(limit).to_arrow()

def sql_query(self, query, return_type='pandas'):
def sql_query(self,
query: str,
return_type: str = 'pandas') -> Union[DataFrame, Any, None]: # pandas.dataframe or pyarrow.Table
"""
Run a SQL-Like query on the table. Utilizes LanceDB predicate pushdown.
Expand All @@ -177,7 +185,7 @@ def sql_query(self, query, return_type='pandas'):
```python
exp = Explorer()
exp.create_embeddings_table()
query = 'SELECT * FROM table WHERE labels LIKE "%person%"'
query = "SELECT * FROM 'table' WHERE labels LIKE '%person%'"
result = exp.sql_query(query)
```
"""
Expand All @@ -201,7 +209,7 @@ def sql_query(self, query, return_type='pandas'):
elif return_type == 'arrow':
return rs.arrow()

def plot_sql_query(self, query, labels=True):
def plot_sql_query(self, query: str, labels: bool = True) -> Image.Image:
"""
Plot the results of a SQL-Like query on the table.
Args:
Expand All @@ -215,15 +223,19 @@ def plot_sql_query(self, query, labels=True):
```python
exp = Explorer()
exp.create_embeddings_table()
query = 'SELECT * FROM table WHERE labels LIKE "%person%"'
query = "SELECT * FROM 'table' WHERE labels LIKE '%person%'"
result = exp.plot_sql_query(query)
```
"""
result = self.sql_query(query, return_type='arrow')
img = plot_similar_images(result, plot_labels=labels)
return Image.fromarray(img)

def get_similar(self, img=None, idx=None, limit=25, return_type='pandas'):
def get_similar(self,
img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
idx: Union[int, List[int]] = None,
limit: int = 25,
return_type: str = 'pandas') -> Union[DataFrame, Any]: # pandas.dataframe or pyarrow.Table
"""
Query the table for similar images. Accepts a single image or a list of images.
Expand Down Expand Up @@ -251,7 +263,11 @@ def get_similar(self, img=None, idx=None, limit=25, return_type='pandas'):
elif return_type == 'arrow':
return similar

def plot_similar(self, img=None, idx=None, limit=25, labels=True):
def plot_similar(self,
img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
idx: Union[int, List[int]] = None,
limit: int = 25,
labels: bool = True) -> Image.Image:
"""
Plot the similar images. Accepts images or indexes.
Expand All @@ -275,7 +291,7 @@ def plot_similar(self, img=None, idx=None, limit=25, labels=True):
img = plot_similar_images(similar, plot_labels=labels)
return Image.fromarray(img)

def similarity_index(self, max_dist=0.2, top_k=None, force=False):
def similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> DataFrame:
"""
Calculate the similarity index of all the images in the table. Here, the index will contain the data points that
are max_dist or closer to the image in the embedding space at a given index.
Expand Down Expand Up @@ -329,7 +345,7 @@ def _yield_sim_idx():
self.sim_index = sim_table
return sim_table.to_pandas()

def plot_similarity_index(self, max_dist=0.2, top_k=None, force=False):
def plot_similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> Image:
"""
Plot the similarity index of all the images in the table. Here, the index will contain the data points that are
max_dist or closer to the image in the embedding space at a given index.
Expand All @@ -341,13 +357,16 @@ def plot_similarity_index(self, max_dist=0.2, top_k=None, force=False):
force (bool): Whether to overwrite the existing similarity index or not. Defaults to True.
Returns:
PIL Image containing the plot.
PIL.PngImagePlugin.PngImageFile containing the plot.
Example:
```python
exp = Explorer()
exp.create_embeddings_table()
exp.plot_similarity_index()
similarity_idx_plot = exp.plot_similarity_index()
similarity_idx_plot.show() # view image preview
similarity_idx_plot.save('path/to/save/similarity_index_plot.png') # save contents to file
```
"""
sim_idx = self.similarity_index(max_dist=max_dist, top_k=top_k, force=force)
Expand All @@ -368,9 +387,10 @@ def plot_similarity_index(self, max_dist=0.2, top_k=None, force=False):
buffer.seek(0)

# Use Pillow to open the image from the buffer
return Image.open(buffer)
return Image.fromarray(np.array(Image.open(buffer)))

def _check_imgs_or_idxs(self, img, idx):
def _check_imgs_or_idxs(self, img: Union[str, np.ndarray, List[str], List[np.ndarray], None],
idx: Union[None, int, List[int]]) -> List[np.ndarray]:
if img is None and idx is None:
raise ValueError('Either img or idx must be provided.')
if img is not None and idx is not None:
Expand Down

0 comments on commit d0562d7

Please sign in to comment.