From d0562d7a2fc2af435f0b1ae08a8d45b1e12d4920 Mon Sep 17 00:00:00 2001 From: Burhan <62214284+Burhan-Q@users.noreply.github.com> Date: Mon, 8 Jan 2024 12:57:53 -0500 Subject: [PATCH] Add type hinting to explorer.py (#7388) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- ultralytics/data/explorer/explorer.py | 62 ++++++++++++++++++--------- 1 file changed, 41 insertions(+), 21 deletions(-) diff --git a/ultralytics/data/explorer/explorer.py b/ultralytics/data/explorer/explorer.py index 49fd37dad02..002b90028a1 100644 --- a/ultralytics/data/explorer/explorer.py +++ b/ultralytics/data/explorer/explorer.py @@ -1,11 +1,12 @@ from io import BytesIO from pathlib import Path -from typing import List +from typing import Any, List, Tuple, Union import cv2 import numpy as np import torch from matplotlib import pyplot as plt +from pandas import DataFrame from PIL import Image from tqdm import tqdm @@ -13,18 +14,18 @@ from ultralytics.data.dataset import YOLODataset from ultralytics.data.utils import check_det_dataset from ultralytics.models.yolo.model import YOLO -from ultralytics.utils import LOGGER, checks +from ultralytics.utils import LOGGER, IterableSimpleNamespace, checks from .utils import get_sim_index_schema, get_table_schema, plot_similar_images, sanitize_batch class ExplorerDataset(YOLODataset): - def __init__(self, *args, data=None, **kwargs): + def __init__(self, *args, data: dict = None, **kwargs) -> None: super().__init__(*args, data=data, **kwargs) # NOTE: Load the image directly without any resize operations. - def load_image(self, i): + def load_image(self, i: int) -> Union[Tuple[np.ndarray, Tuple[int, int], Tuple[int, int]], Tuple[None, None, None]]: """Loads 1 image from dataset index 'i', returns (im, resized hw).""" im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i] if im is None: # not cached in RAM @@ -39,7 +40,7 @@ def load_image(self, i): return self.ims[i], self.im_hw0[i], self.im_hw[i] - def build_transforms(self, hyp=None): + def build_transforms(self, hyp: IterableSimpleNamespace = None): return Format( bbox_format='xyxy', normalize=False, @@ -53,7 +54,10 @@ def build_transforms(self, hyp=None): class Explorer: - def __init__(self, data='coco128.yaml', model='yolov8n.pt', uri='~/ultralytics/explorer') -> None: + def __init__(self, + data: Union[str, Path] = 'coco128.yaml', + model: str = 'yolov8n.pt', + uri: str = '~/ultralytics/explorer') -> None: checks.check_requirements(['lancedb', 'duckdb']) import lancedb @@ -68,7 +72,7 @@ def __init__(self, data='coco128.yaml', model='yolov8n.pt', uri='~/ultralytics/e self.table = None self.progress = 0 - def create_embeddings_table(self, force=False, split='train'): + def create_embeddings_table(self, force: bool = False, split: str = 'train') -> None: """ Create LanceDB table containing the embeddings of the images in the dataset. The table will be reused if it already exists. Pass force=True to overwrite the existing table. @@ -118,7 +122,7 @@ def create_embeddings_table(self, force=False, split='train'): self.table = table - def _yield_batches(self, dataset, data_info, model, exclude_keys: List): + def _yield_batches(self, dataset: ExplorerDataset, data_info: dict, model: YOLO, exclude_keys: List[str]): # Implement Batching for i in tqdm(range(len(dataset))): self.progress = float(i + 1) / len(dataset) @@ -129,7 +133,9 @@ def _yield_batches(self, dataset, data_info, model, exclude_keys: List): batch['vector'] = model.embed(batch['im_file'], verbose=False)[0].detach().tolist() yield [batch] - def query(self, imgs=None, limit=25): + def query(self, + imgs: Union[str, np.ndarray, List[str], List[np.ndarray]] = None, + limit: int = 25) -> Any: # pyarrow.Table """ Query the table for similar images. Accepts a single image or a list of images. @@ -162,7 +168,9 @@ def query(self, imgs=None, limit=25): embeds = torch.mean(torch.stack(embeds), 0).cpu().numpy() if len(embeds) > 1 else embeds[0].cpu().numpy() return self.table.search(embeds).limit(limit).to_arrow() - def sql_query(self, query, return_type='pandas'): + def sql_query(self, + query: str, + return_type: str = 'pandas') -> Union[DataFrame, Any, None]: # pandas.dataframe or pyarrow.Table """ Run a SQL-Like query on the table. Utilizes LanceDB predicate pushdown. @@ -177,7 +185,7 @@ def sql_query(self, query, return_type='pandas'): ```python exp = Explorer() exp.create_embeddings_table() - query = 'SELECT * FROM table WHERE labels LIKE "%person%"' + query = "SELECT * FROM 'table' WHERE labels LIKE '%person%'" result = exp.sql_query(query) ``` """ @@ -201,7 +209,7 @@ def sql_query(self, query, return_type='pandas'): elif return_type == 'arrow': return rs.arrow() - def plot_sql_query(self, query, labels=True): + def plot_sql_query(self, query: str, labels: bool = True) -> Image.Image: """ Plot the results of a SQL-Like query on the table. Args: @@ -215,7 +223,7 @@ def plot_sql_query(self, query, labels=True): ```python exp = Explorer() exp.create_embeddings_table() - query = 'SELECT * FROM table WHERE labels LIKE "%person%"' + query = "SELECT * FROM 'table' WHERE labels LIKE '%person%'" result = exp.plot_sql_query(query) ``` """ @@ -223,7 +231,11 @@ def plot_sql_query(self, query, labels=True): img = plot_similar_images(result, plot_labels=labels) return Image.fromarray(img) - def get_similar(self, img=None, idx=None, limit=25, return_type='pandas'): + def get_similar(self, + img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None, + idx: Union[int, List[int]] = None, + limit: int = 25, + return_type: str = 'pandas') -> Union[DataFrame, Any]: # pandas.dataframe or pyarrow.Table """ Query the table for similar images. Accepts a single image or a list of images. @@ -251,7 +263,11 @@ def get_similar(self, img=None, idx=None, limit=25, return_type='pandas'): elif return_type == 'arrow': return similar - def plot_similar(self, img=None, idx=None, limit=25, labels=True): + def plot_similar(self, + img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None, + idx: Union[int, List[int]] = None, + limit: int = 25, + labels: bool = True) -> Image.Image: """ Plot the similar images. Accepts images or indexes. @@ -275,7 +291,7 @@ def plot_similar(self, img=None, idx=None, limit=25, labels=True): img = plot_similar_images(similar, plot_labels=labels) return Image.fromarray(img) - def similarity_index(self, max_dist=0.2, top_k=None, force=False): + def similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> DataFrame: """ Calculate the similarity index of all the images in the table. Here, the index will contain the data points that are max_dist or closer to the image in the embedding space at a given index. @@ -329,7 +345,7 @@ def _yield_sim_idx(): self.sim_index = sim_table return sim_table.to_pandas() - def plot_similarity_index(self, max_dist=0.2, top_k=None, force=False): + def plot_similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> Image: """ Plot the similarity index of all the images in the table. Here, the index will contain the data points that are max_dist or closer to the image in the embedding space at a given index. @@ -341,13 +357,16 @@ def plot_similarity_index(self, max_dist=0.2, top_k=None, force=False): force (bool): Whether to overwrite the existing similarity index or not. Defaults to True. Returns: - PIL Image containing the plot. + PIL.PngImagePlugin.PngImageFile containing the plot. Example: ```python exp = Explorer() exp.create_embeddings_table() - exp.plot_similarity_index() + + similarity_idx_plot = exp.plot_similarity_index() + similarity_idx_plot.show() # view image preview + similarity_idx_plot.save('path/to/save/similarity_index_plot.png') # save contents to file ``` """ sim_idx = self.similarity_index(max_dist=max_dist, top_k=top_k, force=force) @@ -368,9 +387,10 @@ def plot_similarity_index(self, max_dist=0.2, top_k=None, force=False): buffer.seek(0) # Use Pillow to open the image from the buffer - return Image.open(buffer) + return Image.fromarray(np.array(Image.open(buffer))) - def _check_imgs_or_idxs(self, img, idx): + def _check_imgs_or_idxs(self, img: Union[str, np.ndarray, List[str], List[np.ndarray], None], + idx: Union[None, int, List[int]]) -> List[np.ndarray]: if img is None and idx is None: raise ValueError('Either img or idx must be provided.') if img is not None and idx is not None: