From 783033fa6b12318bfa2f2e87f696f0ba94fb5b5e Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Mon, 8 Jan 2024 23:36:29 +0100
Subject: [PATCH] `ultralytics 8.0.238` Explorer Ask AI feature and fixes
 (#7408)

Co-authored-by: Kayzwer <68285002+Kayzwer@users.noreply.github.com>
Co-authored-by: uwer
Co-authored-by: Uwe Rosebrock
Co-authored-by: Ayush Chaurasia
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Laughing-q <1182102784@qq.com>
Co-authored-by: Muhammad Rizwan Munawar
Co-authored-by: AdamP
---
 docs/en/datasets/explorer/api.md              | 30 +++++++-
 docs/en/datasets/explorer/dash.md             |  0
 docs/en/datasets/explorer/dashboard.md        | 58 +++++++++++++++
 docs/en/datasets/explorer/explorer.ipynb      | 70 +++++++++++++++++--
 docs/en/datasets/explorer/index.md            | 17 +++--
 docs/en/guides/heatmaps.md                    |  2 +-
 .../instance-segmentation-and-tracking.md     | 20 +++---
 docs/mkdocs.yml                               |  2 +-
 ultralytics/__init__.py                       |  2 +-
 ultralytics/data/explorer/__init__.py         |  3 +
 ultralytics/data/explorer/explorer.py         | 59 ++++++++++++----
 ultralytics/data/explorer/gui/dash.py         | 68 +++++++++++++-----
 ultralytics/data/explorer/utils.py            | 70 ++++++++++++++++++-
 ultralytics/engine/model.py                   |  2 +-
 ultralytics/models/yolo/obb/predict.py        |  3 +-
 ultralytics/nn/modules/head.py                | 14 ++--
 ultralytics/solutions/distance_calculation.py |  6 +-
 ultralytics/solutions/heatmap.py              | 36 +++++++---
 ultralytics/utils/__init__.py                 |  1 +
 19 files changed, 387 insertions(+), 76 deletions(-)
 delete mode 100644 docs/en/datasets/explorer/dash.md
 create mode 100644 docs/en/datasets/explorer/dashboard.md

diff --git a/docs/en/datasets/explorer/api.md b/docs/en/datasets/explorer/api.md
index 8e11efc4c6c..ac5dfc456d2 100644
--- a/docs/en/datasets/explorer/api.md
+++ b/docs/en/datasets/explorer/api.md
@@ -119,7 +119,31 @@ You can also plot the similar images using the `plot_similar` method. This metho
     plt.show()
     ```
 
-## 2. SQL Querying
+## 2. Ask AI (Natural Language Querying)
+
+This lets you filter your dataset using natural language, so you don't have to be proficient in writing SQL queries. Our AI-powered query generator writes the query under the hood. For example, you can say "show me 100 images with exactly one person and 2 dogs. There can be other objects too" and it'll internally generate the query and show you those results.
+Note: This works using LLMs under the hood, so the results are probabilistic and might occasionally get things wrong.
+
+!!! Example "Ask AI"
+
+    ```python
+    from PIL import Image
+
+    from ultralytics import Explorer
+    from ultralytics.data.explorer import plot_query_result
+
+    # create an Explorer object
+    exp = Explorer(data='coco128.yaml', model='yolov8n.pt')
+    exp.create_embeddings_table()
+
+    df = exp.ask_ai("show me 100 images with exactly one person and 2 dogs. There can be other objects too")
+    print(df.head())
+
+    # plot the results (plot_query_result returns an image array)
+    img = plot_query_result(df)
+    Image.fromarray(img).show()
+    ```
+
+## 3. SQL Querying
 
 You can run SQL queries on your dataset using the `sql_query` method. This method takes a SQL query as input and returns a pandas dataframe with the results.
 
@@ -153,7 +177,7 @@ You can also plot the results of a SQL query using the `plot_sql_query` method. 
     print(df.head())
     ```
 
-## 3. Working with embeddings Table (Advanced)
+## 4. Working with embeddings Table (Advanced)
 
 You can also work with the embeddings table directly.
Once the embeddings table is created, you can access it using the `Explorer.table`

@@ -210,7 +234,7 @@ When using large datasets, you can also create a dedicated vector index for fast
 Find more details on the types of vector indices available and parameters [here](https://lancedb.github.io/lancedb/ann_indexes/#types-of-index)
 In the future, we will add support for creating vector indices directly from Explorer API.
 
-## 4. Embeddings Applications
+## 5. Embeddings Applications
 
 You can use the embeddings table to perform a variety of exploratory analysis. Here are some examples:
diff --git a/docs/en/datasets/explorer/dash.md b/docs/en/datasets/explorer/dash.md
deleted file mode 100644
index e69de29bb2d..00000000000
diff --git a/docs/en/datasets/explorer/dashboard.md b/docs/en/datasets/explorer/dashboard.md
new file mode 100644
index 00000000000..acb99cc2a0e
--- /dev/null
+++ b/docs/en/datasets/explorer/dashboard.md
@@ -0,0 +1,58 @@
+---
+comments: true
+description: Learn about Ultralytics Explorer GUI for semantic search, SQL queries, and AI-powered natural language search in CV datasets.
+keywords: Ultralytics, Explorer GUI, semantic search, vector similarity search, AI queries, SQL queries, computer vision, dataset exploration, image search, OpenAI integration
+---
+
+# Explorer GUI
+
+Explorer GUI is like a playground built using the [Ultralytics Explorer API](api.md). It allows you to run semantic/vector similarity search, SQL queries, and even natural-language search using our Ask AI feature powered by LLMs.
+
+### Installation
+
+```bash
+pip install ultralytics[explorer]
+```
+
+!!! note "Note"
+
+    The Ask AI feature works using OpenAI, so you'll be prompted to set the OpenAI API key when you first run the GUI.
+    You can set it like this: `yolo settings openai_api_key="..."`
+
+## Semantic Search / Vector Similarity Search
+
+Semantic search is a technique for finding similar images to a given image. It is based on the idea that similar images will have similar embeddings. In the UI, you can select one or more images and search for images similar to them. This can be useful when you want to find images similar to a given image, or to investigate a set of images that don't perform as expected.
+
+For example:
+In this VOC Exploration dashboard, a user selects a couple of aeroplane images like this:
+<p>
+<img alt="Screenshot 2024-01-08 at 8 46 33 PM">
+</p>
+
+On performing similarity search, you should see a similar result:
+<p>
+<img alt="Screenshot 2024-01-08 at 8 46 46 PM">
+</p>

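+The GUI drives the same calls that the Explorer API exposes, so the flow above can also be scripted. Here's a rough sketch (the dataset and model are assumptions chosen to match the VOC dashboard shown, and may differ from your setup):
+
+```python
+from ultralytics import Explorer
+
+# build the embeddings table once, then reuse it for all queries
+exp = Explorer(data='VOC.yaml', model='yolov8n.pt')
+exp.create_embeddings_table()
+
+# images selected in the UI become the query; table indices work too
+similar = exp.get_similar(idx=[0, 1], limit=25)  # pandas dataframe of matches
+print(similar.head())
+```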
+
+## Ask AI
+
+This lets you filter your dataset using natural language, so you don't have to be proficient in writing SQL queries. Our AI-powered query generator writes the query under the hood. For example, you can say "show me 100 images with exactly one person and 2 dogs. There can be other objects too" and it'll internally generate the query and show you those results. Here's the output when asked to "Show 10 images with exactly 5 persons":
+<p>
+<img alt="Screenshot 2024-01-08 at 7 19 48 PM (1)">
+</p>

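+Under the hood this runs the `ask_ai` API call. A minimal programmatic sketch (assuming the default `coco128.yaml` dataset; `ask_ai` can return `None` when the generated query is invalid):
+
+```python
+from ultralytics import Explorer
+
+exp = Explorer(data='coco128.yaml', model='yolov8n.pt')
+exp.create_embeddings_table()
+
+df = exp.ask_ai('Show 10 images with exactly 5 persons')
+if df is None or df.empty:  # the LLM may produce an invalid query
+    print('No results. Try rephrasing the prompt.')
+else:
+    print(df['im_file'].head())
+```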
+
+Note: This works using LLMs under the hood, so the results are probabilistic and might occasionally get things wrong.
+
+## Run SQL queries on your CV datasets
+
+You can run SQL queries on your dataset to filter it. It also works if you only provide the WHERE clause. For example, this SQL query keeps only the images that have at least one person and one dog in them:
+
+```sql
+WHERE labels LIKE '%person%' AND labels LIKE '%dog%'
+```
+
+<p>
+<img alt="Screenshot 2024-01-08 at 8 57 49 PM">
+</p>

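+The same filter can be run through the API's `sql_query` method, which accepts either a full query or just a WHERE clause. A sketch (assuming the default `coco128.yaml` dataset):
+
+```python
+from ultralytics import Explorer
+
+exp = Explorer(data='coco128.yaml', model='yolov8n.pt')
+exp.create_embeddings_table()
+
+# a bare WHERE clause is expanded to: SELECT * FROM 'table' WHERE ...
+df = exp.sql_query("WHERE labels LIKE '%person%' AND labels LIKE '%dog%'")
+print(df.head())
+```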
+
+This is a demo built using the Explorer API. You can use the API to build your own exploratory notebooks or scripts to get insights into your datasets. Learn more about the Explorer API [here](api.md).
diff --git a/docs/en/datasets/explorer/explorer.ipynb b/docs/en/datasets/explorer/explorer.ipynb
index 5e5b7b02443..002cb2db225 100644
--- a/docs/en/datasets/explorer/explorer.ipynb
+++ b/docs/en/datasets/explorer/explorer.ipynb
@@ -109,7 +109,10 @@
   "metadata": {},
   "source": [
    "You can also plot the similar samples directly using the `plot_similar` util\n",
-    "<img alt=\"Screenshot\">\n"
+    "<p>\n",
+    "\n",
+    "  <img alt=\"Screenshot\">\n",
+    "</p>\n"
   ]
  },
@@ -139,17 +142,74 @@
   "metadata": {},
   "source": [
    "<p>\n",
    "<img alt=\"Screenshot\">\n",
    "\n",
    "</p>"
   ]
  },
+  {
+   "cell_type": "markdown",
+   "id": "0cea63f1-71f1-46da-af2b-b1b7d8f73553",
+   "metadata": {},
+   "source": [
+    "## 2. Ask AI: Search or filter with Natural Language\n",
+    "You can prompt the Explorer object with the kind of data points you want to see, and it'll try to return a dataframe with those. Because it is powered by LLMs, it doesn't always get it right. In that case, it'll return None.\n",
+    "<p>\n",
+    "<img alt=\"Screenshot\">\n",
+    "</p>\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "92fb92ac-7f76-465a-a9ba-ea7492498d9c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "df = exp.ask_ai(\"show me images containing more than 10 objects with at least 2 persons\")\n",
+    "df.head(5)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "f2a7d26e-0ce5-4578-ad1a-b1253805280f",
+   "metadata": {},
+   "source": [
+    "For plotting these results you can use the `plot_query_result` util.\n",
+    "Example:\n",
+    "```\n",
+    "plt = plot_query_result(exp.ask_ai(\"show me 10 images containing exactly 2 persons\"))\n",
+    "Image.fromarray(plt)\n",
+    "```\n",
+    "<p>\n",
+    "  <img alt=\"Screenshot\">\n",
+    "</p>
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b1cfab84-9835-4da0-8e9a-42b30cf84511", + "metadata": {}, + "outputs": [], + "source": [ + "# plot\n", + "from ultralytics.data.explorer import plot_query_result\n", + "from PIL import Image\n", + "\n", + "plt = plot_query_result(exp.ask_ai(\"show me 10 images containing exactly 2 persons\"))\n", + "Image.fromarray(plt)" + ] + }, { "cell_type": "markdown", "id": "35315ae6-d827-40e4-8813-279f97a83b34", "metadata": {}, "source": [ - "## 2. Run SQL queries on your Dataset!\n", + "## 3. Run SQL queries on your Dataset!\n", "Sometimes you might want to investigate a certain type of entries in your dataset. For this Explorer allows you to execute SQL queries.\n", "It accepts either of the formats:\n", "- Queries beginning with \"WHERE\" will automatically select all columns. This can be thought of as a short-hand query\n", @@ -179,7 +239,7 @@ "metadata": {}, "source": [ "Just like similarity search, you also get a util to directly plot the sql queries using `exp.plot_sql_query`\n", - "\"Screenshot\n" + "\n" ] }, { @@ -419,7 +479,7 @@ "metadata": {}, "source": [ "You should see something like this\n", - "\"Screenshot\n" + "\n" ] }, { diff --git a/docs/en/datasets/explorer/index.md b/docs/en/datasets/explorer/index.md index ebfe189a652..0f94982a11b 100644 --- a/docs/en/datasets/explorer/index.md +++ b/docs/en/datasets/explorer/index.md @@ -16,6 +16,12 @@ Explorer depends on external libraries for some of its functionality. These are pip install ultralytics[explorer] ``` +### Explorer API + +This is a Python API for Exploring your datasets. It also powers the GUI Explorer. You can use this to create your own exploratory notebooks or scripts to get insights into your datasets. + +Learn more about the Explorer API [here](api.md). + ## GUI Explorer Usage The GUI demo runs in your browser allowing you to create embeddings for your dataset and search for similar images, run SQL queries and perform semantic search. It can be run using the following command: @@ -24,8 +30,11 @@ The GUI demo runs in your browser allowing you to create embeddings for your dat yolo explorer ``` -### Explorer API +!!! note "Note" + Ask AI feature works using OpenAI, so you'll be prompted to set the api key for OpenAI when you first run the GUI. + You can set it like this - `yolo settings openai_api_key="..."` -This is a Python API for Exploring your datasets. It also powers the GUI Explorer. You can use this to create your own exploratory notebooks or scripts to get insights into your datasets. - -Learn more about the Explorer API [here](api.md). +Example +
+<img alt="Screenshot 2024-01-08 at 7 19 48 PM (1)">
+</p>

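+If you prefer setting the OpenAI key from Python instead of the CLI, a sketch along these lines should work (assuming the `settings` object exported by `ultralytics`):
+
+```python
+from ultralytics import settings
+
+# equivalent of: yolo settings openai_api_key="..."
+settings.update({'openai_api_key': '...'})
+```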
diff --git a/docs/en/guides/heatmaps.md b/docs/en/guides/heatmaps.md index 09644f85838..e52f2383354 100644 --- a/docs/en/guides/heatmaps.md +++ b/docs/en/guides/heatmaps.md @@ -99,7 +99,7 @@ A heatmap generated with [Ultralytics YOLOv8](https://github.com/ultralytics/ult fps, (w, h)) - line_points = [(256, 409), (694, 532)] # line for object counting + line_points = [(20, 400), (1080, 404)] # line for object counting # Init heatmap heatmap_obj = heatmap.Heatmap() diff --git a/docs/en/guides/instance-segmentation-and-tracking.md b/docs/en/guides/instance-segmentation-and-tracking.md index 30db7a7b732..5c6fb589714 100644 --- a/docs/en/guides/instance-segmentation-and-tracking.md +++ b/docs/en/guides/instance-segmentation-and-tracking.md @@ -31,7 +31,7 @@ There are two types of instance segmentation tracking available in the Ultralyti from ultralytics import YOLO from ultralytics.utils.plotting import Annotator, colors - model = YOLO("yolov8n-seg.pt") + model = YOLO("yolov8n-seg.pt") # segmentation model names = model.model.names cap = cv2.VideoCapture("path/to/video/file.mp4") w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS)) @@ -45,15 +45,15 @@ There are two types of instance segmentation tracking available in the Ultralyti break results = model.predict(im0) - clss = results[0].boxes.cls.cpu().tolist() - masks = results[0].masks.xy - annotator = Annotator(im0, line_width=2) - for mask, cls in zip(masks, clss): - annotator.seg_bbox(mask=mask, - mask_color=colors(int(cls), True), - det_label=names[int(cls)]) + if results[0].masks is not None: + clss = results[0].boxes.cls.cpu().tolist() + masks = results[0].masks.xy + for mask, cls in zip(masks, clss): + annotator.seg_bbox(mask=mask, + mask_color=colors(int(cls), True), + det_label=names[int(cls)]) out.write(im0) cv2.imshow("instance-segmentation", im0) @@ -77,7 +77,7 @@ There are two types of instance segmentation tracking available in the Ultralyti track_history = defaultdict(lambda: []) - model = YOLO("yolov8n-seg.pt") + model = YOLO("yolov8n-seg.pt") # segmentation model cap = cv2.VideoCapture("path/to/video/file.mp4") w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS)) @@ -93,7 +93,7 @@ There are two types of instance segmentation tracking available in the Ultralyti results = model.track(im0, persist=True) - if results[0].boxes.id is not None: + if results[0].boxes.id is not None and results[0].masks is not None: masks = results[0].masks.xy track_ids = results[0].boxes.id.int().cpu().tolist() diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index e484e7094fe..bd072a2e621 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -222,7 +222,7 @@ nav: - Explorer: - datasets/explorer/index.md - Explorer API: datasets/explorer/api.md - - GUI Dashboard Demo: datasets/explorer/dash.md + - Explorer Dashboard: datasets/explorer/dashboard.md - VOC Exploration Example: datasets/explorer/explorer.ipynb - Detection: - datasets/detect/index.md diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index a607ed382ee..4656c5a9e35 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -__version__ = '8.0.237' +__version__ = '8.0.238' from ultralytics.data.explorer.explorer import Explorer from ultralytics.models import RTDETR, SAM, YOLO diff --git a/ultralytics/data/explorer/__init__.py b/ultralytics/data/explorer/__init__.py index e69de29bb2d..e4af304f945 100644 --- 
a/ultralytics/data/explorer/__init__.py +++ b/ultralytics/data/explorer/__init__.py @@ -0,0 +1,3 @@ +from .utils import plot_query_result + +__all__ = ['plot_query_result'] diff --git a/ultralytics/data/explorer/explorer.py b/ultralytics/data/explorer/explorer.py index 002b90028a1..064697e0214 100644 --- a/ultralytics/data/explorer/explorer.py +++ b/ultralytics/data/explorer/explorer.py @@ -16,7 +16,7 @@ from ultralytics.models.yolo.model import YOLO from ultralytics.utils import LOGGER, IterableSimpleNamespace, checks -from .utils import get_sim_index_schema, get_table_schema, plot_similar_images, sanitize_batch +from .utils import get_sim_index_schema, get_table_schema, plot_query_result, prompt_sql_query, sanitize_batch class ExplorerDataset(YOLODataset): @@ -58,7 +58,7 @@ def __init__(self, data: Union[str, Path] = 'coco128.yaml', model: str = 'yolov8n.pt', uri: str = '~/ultralytics/explorer') -> None: - checks.check_requirements(['lancedb', 'duckdb']) + checks.check_requirements(['lancedb>=0.4.3', 'duckdb']) import lancedb self.connection = lancedb.connect(uri) @@ -112,8 +112,7 @@ def create_embeddings_table(self, force: bool = False, split: str = 'train') -> # Create the table schema batch = dataset[0] vector_size = self.model.embed(batch['im_file'], verbose=False)[0].shape[0] - Schema = get_table_schema(vector_size) - table = self.connection.create_table(self.table_name, schema=Schema, mode='overwrite') + table = self.connection.create_table(self.table_name, schema=get_table_schema(vector_size), mode='overwrite') table.add( self._yield_batches(dataset, data_info, @@ -159,10 +158,7 @@ def query(self, raise ValueError('Table is not created. Please create the table first.') if isinstance(imgs, str): imgs = [imgs] - elif isinstance(imgs, list): - pass - else: - raise ValueError(f'img must be a string or a list of strings. Got {type(imgs)}') + assert isinstance(imgs, list), f'img must be a string or a list of strings. Got {type(imgs)}' embeds = self.model.embed(imgs) # Get avg if multiple images are passed (len > 1) embeds = torch.mean(torch.stack(embeds), 0).cpu().numpy() if len(embeds) > 1 else embeds[0].cpu().numpy() @@ -189,16 +185,19 @@ def sql_query(self, result = exp.sql_query(query) ``` """ + assert return_type in ['pandas', + 'arrow'], f'Return type should be either `pandas` or `arrow`, but got {return_type}' import duckdb if self.table is None: raise ValueError('Table is not created. Please create the table first.') # Note: using filter pushdown would be a better long term solution. Temporarily using duckdb for this. - table = self.table.to_arrow() # noqa + table = self.table.to_arrow() # noqa NOTE: Don't comment this. This line is used by DuckDB if not query.startswith('SELECT') and not query.startswith('WHERE'): raise ValueError( - 'Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE clause.') + f'Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE clause. 
found {query}' + ) if query.startswith('WHERE'): query = f"SELECT * FROM 'table' {query}" LOGGER.info(f'Running query: {query}') @@ -228,7 +227,10 @@ def plot_sql_query(self, query: str, labels: bool = True) -> Image.Image: ``` """ result = self.sql_query(query, return_type='arrow') - img = plot_similar_images(result, plot_labels=labels) + if len(result) == 0: + LOGGER.info('No results found.') + return None + img = plot_query_result(result, plot_labels=labels) return Image.fromarray(img) def get_similar(self, @@ -255,6 +257,8 @@ def get_similar(self, similar = exp.get_similar(img='https://ultralytics.com/images/zidane.jpg') ``` """ + assert return_type in ['pandas', + 'arrow'], f'Return type should be either `pandas` or `arrow`, but got {return_type}' img = self._check_imgs_or_idxs(img, idx) similar = self.query(img, limit=limit) @@ -288,7 +292,10 @@ def plot_similar(self, ``` """ similar = self.get_similar(img, idx, limit, return_type='arrow') - img = plot_similar_images(similar, plot_labels=labels) + if len(similar) == 0: + LOGGER.info('No results found.') + return None + img = plot_query_result(similar, plot_labels=labels) return Image.fromarray(img) def similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> DataFrame: @@ -299,7 +306,7 @@ def similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bo Args: max_dist (float): maximum L2 distance between the embeddings to consider. Defaults to 0.2. top_k (float): Percentage of the closest data points to consider when counting. Used to apply limit when running - vector search. Defaults to 0.01. + vector search. Defaults: None. force (bool): Whether to overwrite the existing similarity index or not. Defaults to True. Returns: @@ -401,6 +408,32 @@ def _check_imgs_or_idxs(self, img: Union[str, np.ndarray, List[str], List[np.nda return img if isinstance(img, list) else [img] + def ask_ai(self, query): + """ + Ask AI a question. + + Args: + query (str): Question to ask. + + Returns: + Answer from AI. + + Example: + ```python + exp = Explorer() + exp.create_embeddings_table() + answer = exp.ask_ai('Show images with 1 person and 2 dogs') + ``` + """ + result = prompt_sql_query(query) + try: + df = self.sql_query(result) + except Exception as e: + LOGGER.error('AI generated query is not valid. Please try again with a different prompt') + LOGGER.error(e) + return None + return df + def visualize(self, result): """ Visualize the results of a query. 
diff --git a/ultralytics/data/explorer/gui/dash.py b/ultralytics/data/explorer/gui/dash.py index 1de184406c0..e9e7ac2ade0 100644 --- a/ultralytics/data/explorer/gui/dash.py +++ b/ultralytics/data/explorer/gui/dash.py @@ -1,11 +1,13 @@ import time from threading import Thread +import pandas as pd + from ultralytics import Explorer -from ultralytics.utils import ROOT +from ultralytics.utils import ROOT, SETTINGS from ultralytics.utils.checks import check_requirements -check_requirements('streamlit') +check_requirements('streamlit>=1.29.0') check_requirements('streamlit-select>=0.2') import streamlit as st from streamlit_select import image_select @@ -35,9 +37,9 @@ def init_explorer_form(): with st.form(key='explorer_init_form'): col1, col2 = st.columns(2) with col1: - dataset = st.selectbox('Select dataset', ds, key='dataset', index=ds.index('coco128.yaml')) + st.selectbox('Select dataset', ds, key='dataset', index=ds.index('coco128.yaml')) with col2: - model = st.selectbox('Select model', models, key='model') + st.selectbox('Select model', models, key='model') st.checkbox('Force recreate embeddings', key='force_recreate_embeddings') st.form_submit_button('Explore', on_click=_get_explorer) @@ -47,11 +49,23 @@ def query_form(): with st.form('query_form'): col1, col2 = st.columns([0.8, 0.2]) with col1: - query = st.text_input('Query', '', label_visibility='collapsed', key='query') + st.text_input('Query', + "WHERE labels LIKE '%person%' AND labels LIKE '%dog%'", + label_visibility='collapsed', + key='query') with col2: st.form_submit_button('Query', on_click=run_sql_query) +def ai_query_form(): + with st.form('ai_query_form'): + col1, col2 = st.columns([0.8, 0.2]) + with col1: + st.text_input('Query', 'Show images with 1 person and 1 dog', label_visibility='collapsed', key='ai_query') + with col2: + st.form_submit_button('Ask AI', on_click=run_ai_query) + + def find_similar_imgs(imgs): exp = st.session_state['explorer'] similar = exp.get_similar(img=imgs, limit=st.session_state.get('limit'), return_type='arrow') @@ -64,12 +78,12 @@ def similarity_form(selected_imgs): with st.form('similarity_form'): subcol1, subcol2 = st.columns([1, 1]) with subcol1: - limit = st.number_input('limit', - min_value=None, - max_value=None, - value=25, - label_visibility='collapsed', - key='limit') + st.number_input('limit', + min_value=None, + max_value=None, + value=25, + label_visibility='collapsed', + key='limit') with subcol2: disabled = not len(selected_imgs) @@ -95,6 +109,7 @@ def similarity_form(selected_imgs): def run_sql_query(): + st.session_state['error'] = None query = st.session_state.get('query') if query.rstrip().lstrip(): exp = st.session_state['explorer'] @@ -102,9 +117,26 @@ def run_sql_query(): st.session_state['imgs'] = res.to_pydict()['im_file'] +def run_ai_query(): + if not SETTINGS['openai_api_key']: + st.session_state[ + 'error'] = 'OpenAI API key not found in settings. Please run yolo settings openai_api_key="..."' + return + st.session_state['error'] = None + query = st.session_state.get('ai_query') + if query.rstrip().lstrip(): + exp = st.session_state['explorer'] + res = exp.ask_ai(query) + if not isinstance(res, pd.DataFrame) or res.empty: + st.session_state['error'] = 'No results found using AI generated query. Try another query or rerun it.' 
+            return
+        st.session_state['imgs'] = res['im_file'].to_list()
+
+
 def reset_explorer():
     st.session_state['explorer'] = None
     st.session_state['imgs'] = None
+    st.session_state['error'] = None
 
 
@@ -112,10 +144,10 @@ def utralytics_explorer_docs_callback():
     st.image('https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg',
              width=100)
     st.markdown(
-        "<p>This demo is built using Ultralytics Explorer API. Visit <a href='https://docs.ultralytics.com/'>API docs</a> to try examples & learn more</p>",
+        "<p>This demo is built using Ultralytics Explorer API. Visit <a href='https://docs.ultralytics.com/datasets/explorer/'>API docs</a> to try examples & learn more</p>
", unsafe_allow_html=True, help=None) - st.link_button('Ultrlaytics Explorer API', 'https://docs.ultralytics.com/') + st.link_button('Ultrlaytics Explorer API', 'https://docs.ultralytics.com/datasets/explorer/') def layout(): @@ -129,9 +161,12 @@ def layout(): st.button(':arrow_backward: Select Dataset', on_click=reset_explorer) exp = st.session_state.get('explorer') col1, col2 = st.columns([0.75, 0.25], gap='small') - - imgs = st.session_state.get('imgs') or exp.table.to_lance().to_table(columns=['im_file']).to_pydict()['im_file'] - total_imgs = len(imgs) + imgs = [] + if st.session_state.get('error'): + st.error(st.session_state['error']) + else: + imgs = st.session_state.get('imgs') or exp.table.to_lance().to_table(columns=['im_file']).to_pydict()['im_file'] + total_imgs, selected_imgs = len(imgs), [] with col1: subcol1, subcol2, subcol3, subcol4, subcol5 = st.columns(5) with subcol1: @@ -159,6 +194,7 @@ def layout(): st.experimental_rerun() query_form() + ai_query_form() if total_imgs: imgs_displayed = imgs[start_idx:start_idx + num] selected_imgs = image_select( diff --git a/ultralytics/data/explorer/utils.py b/ultralytics/data/explorer/utils.py index 359c5868a69..7d2127411fe 100644 --- a/ultralytics/data/explorer/utils.py +++ b/ultralytics/data/explorer/utils.py @@ -1,9 +1,14 @@ +import getpass from typing import List import cv2 import numpy as np +import pandas as pd from ultralytics.data.augment import LetterBox +from ultralytics.utils import LOGGER as logger +from ultralytics.utils import SETTINGS +from ultralytics.utils.checks import check_requirements from ultralytics.utils.ops import xyxy2xywh from ultralytics.utils.plotting import plot_images @@ -47,15 +52,16 @@ def sanitize_batch(batch, dataset_info): return batch -def plot_similar_images(similar_set, plot_labels=True): +def plot_query_result(similar_set, plot_labels=True): """ Plot images from the similar set. Args: - similar_set (list): Pyarrow table containing the similar data points + similar_set (list): Pyarrow or pandas object containing the similar data points plot_labels (bool): Whether to plot labels or not """ - similar_set = similar_set.to_pydict() + similar_set = similar_set.to_dict( + orient='list') if isinstance(similar_set, pd.DataFrame) else similar_set.to_pydict() empty_masks = [[[]]] empty_boxes = [[]] images = similar_set.get('im_file', []) @@ -102,3 +108,61 @@ def plot_similar_images(similar_set, plot_labels=True): max_subplots=len(images), save=False, threaded=False) + + +def prompt_sql_query(query): + check_requirements('openai>=1.6.1') + from openai import OpenAI + + if not SETTINGS['openai_api_key']: + logger.warning('OpenAI API key not found in settings. Please enter your API key below.') + openai_api_key = getpass.getpass('OpenAI API key: ') + SETTINGS.update({'openai_api_key': openai_api_key}) + openai = OpenAI(api_key=SETTINGS['openai_api_key']) + + messages = [ + { + 'role': + 'system', + 'content': + ''' + You are a helpful data scientist proficient in SQL. You need to output exactly one SQL query based on + the following schema and a user request. 
You only need to output the format with fixed selection + statement that selects everything from "'table'", like `SELECT * from 'table'` + + Schema: + im_file: string not null + labels: list not null + child 0, item: string + cls: list not null + child 0, item: int64 + bboxes: list> not null + child 0, item: list + child 0, item: double + masks: list>> not null + child 0, item: list> + child 0, item: list + child 0, item: int64 + keypoints: list>> not null + child 0, item: list> + child 0, item: list + child 0, item: double + vector: fixed_size_list[256] not null + child 0, item: float + + Some details about the schema: + - the "labels" column contains the string values like 'person' and 'dog' for the respective objects + in each image + - the "cls" column contains the integer values on these classes that map them the labels + + Example of a correct query: + request - Get all data points that contain 2 or more people and at least one dog + correct query- + SELECT * FROM 'table' WHERE ARRAY_LENGTH(cls) >= 2 AND ARRAY_LENGTH(FILTER(labels, x -> x = 'person')) >= 2 AND ARRAY_LENGTH(FILTER(labels, x -> x = 'dog')) >= 1; + '''}, + { + 'role': 'user', + 'content': f'{query}'}, ] + + response = openai.chat.completions.create(model='gpt-3.5-turbo', messages=messages) + return response.choices[0].message.content diff --git a/ultralytics/engine/model.py b/ultralytics/engine/model.py index 374c872d3cf..ced3f18bad8 100644 --- a/ultralytics/engine/model.py +++ b/ultralytics/engine/model.py @@ -246,7 +246,7 @@ def predict(self, source=None, stream=False, predictor=None, **kwargs): prompts = args.pop('prompts', None) # for SAM-type models if not self.predictor: - self.predictor = (predictor or self._smart_load('predictor'))(overrides=args, _callbacks=self.callbacks) + self.predictor = predictor or self._smart_load('predictor')(overrides=args, _callbacks=self.callbacks) self.predictor.setup_model(model=self.model, verbose=is_cli) else: # only update args if predictor is already setup self.predictor.args = get_cfg(self.predictor.args, args) diff --git a/ultralytics/models/yolo/obb/predict.py b/ultralytics/models/yolo/obb/predict.py index 227a753001d..76b967ac320 100644 --- a/ultralytics/models/yolo/obb/predict.py +++ b/ultralytics/models/yolo/obb/predict.py @@ -41,8 +41,7 @@ def postprocess(self, preds, img, orig_imgs): orig_imgs = ops.convert_torch2numpy_batch(orig_imgs) results = [] - for i, pred in enumerate(preds): - orig_img = orig_imgs[i] + for i, (pred, orig_img) in enumerate(zip(preds, orig_imgs)): pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape, xywh=True) img_path = self.batch[0][i] # xywh, r, conf, cls diff --git a/ultralytics/nn/modules/head.py b/ultralytics/nn/modules/head.py index 35363fe2454..1bc07e624c3 100644 --- a/ultralytics/nn/modules/head.py +++ b/ultralytics/nn/modules/head.py @@ -61,13 +61,13 @@ def forward(self, x): dbox = self.decode_bboxes(box) if self.export and self.format in ('tflite', 'edgetpu'): - # Normalize xywh with image size to mitigate quantization error of TFLite integer models as done in YOLOv5: - # https://github.com/ultralytics/yolov5/blob/0c8de3fca4a702f8ff5c435e67f378d1fce70243/models/tf.py#L307-L309 - # See this PR for details: https://github.com/ultralytics/ultralytics/pull/1695 - img_h = shape[2] * self.stride[0] - img_w = shape[3] * self.stride[0] - img_size = torch.tensor([img_w, img_h, img_w, img_h], device=dbox.device).reshape(1, 4, 1) - dbox /= img_size + # Precompute normalization factor to increase numerical stability + # See 
https://github.com/ultralytics/ultralytics/issues/7371 + img_h = shape[2] + img_w = shape[3] + img_size = torch.tensor([img_w, img_h, img_w, img_h], device=box.device).reshape(1, 4, 1) + norm = self.strides / (self.stride[0] * img_size) + dbox = dist2bbox(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2], xywh=True, dim=1) y = torch.cat((dbox, cls.sigmoid()), 1) return y if self.export else (y, x) diff --git a/ultralytics/solutions/distance_calculation.py b/ultralytics/solutions/distance_calculation.py index 684b504f5d8..306a5bd2ea5 100644 --- a/ultralytics/solutions/distance_calculation.py +++ b/ultralytics/solutions/distance_calculation.py @@ -4,6 +4,7 @@ import cv2 +from ultralytics.utils.checks import check_imshow from ultralytics.utils.plotting import Annotator, colors @@ -37,6 +38,9 @@ def __init__(self): self.left_mouse_count = 0 self.selected_boxes = {} + # Check if environment support imshow + self.env_check = check_imshow(warn=True) + def set_args(self, names, pixels_per_meter=10, @@ -168,7 +172,7 @@ def start_process(self, im0, tracks): self.centroids = [] - if self.view_img: + if self.view_img and self.env_check: self.display_frames() return im0 diff --git a/ultralytics/solutions/heatmap.py b/ultralytics/solutions/heatmap.py index 8f50961d171..5d5b6d1fd3b 100644 --- a/ultralytics/solutions/heatmap.py +++ b/ultralytics/solutions/heatmap.py @@ -28,6 +28,8 @@ def __init__(self): self.imw = None self.imh = None self.im0 = None + self.view_in_counts = True + self.view_out_counts = True # Heatmap colormap and heatmap np array self.colormap = None @@ -67,6 +69,8 @@ def set_args(self, colormap=cv2.COLORMAP_JET, heatmap_alpha=0.5, view_img=False, + view_in_counts=True, + view_out_counts=True, count_reg_pts=None, count_txt_thickness=2, count_txt_color=(0, 0, 0), @@ -85,6 +89,8 @@ def set_args(self, imh (int): The height of the frame. heatmap_alpha (float): alpha value for heatmap display view_img (bool): Flag indicating frame display + view_in_counts (bool): Flag to control whether to display the incounts on video stream. + view_out_counts (bool): Flag to control whether to display the outcounts on video stream. 
count_reg_pts (list): Object counting region points count_txt_thickness (int): Text thickness for object counting display count_txt_color (RGB color): count text color value @@ -99,6 +105,8 @@ def set_args(self, self.imh = imh self.heatmap_alpha = heatmap_alpha self.view_img = view_img + self.view_in_counts = view_in_counts + self.view_out_counts = view_out_counts self.colormap = colormap # Region and line selection @@ -171,9 +179,10 @@ def generate_heatmap(self, im0, tracks): if self.count_reg_pts is not None: # Draw counting region - self.annotator.draw_region(reg_pts=self.count_reg_pts, - color=self.region_color, - thickness=self.region_thickness) + if self.view_in_counts or self.view_out_counts: + self.annotator.draw_region(reg_pts=self.count_reg_pts, + color=self.region_color, + thickness=self.region_thickness) for box, cls, track_id in zip(self.boxes, self.clss, self.track_ids): @@ -235,11 +244,22 @@ def generate_heatmap(self, im0, tracks): heatmap_normalized = cv2.normalize(self.heatmap, None, 0, 255, cv2.NORM_MINMAX) heatmap_colored = cv2.applyColorMap(heatmap_normalized.astype(np.uint8), self.colormap) - if self.count_reg_pts is not None: - incount_label = 'InCount : ' + f'{self.in_counts}' - outcount_label = 'OutCount : ' + f'{self.out_counts}' - self.annotator.count_labels(in_count=incount_label, - out_count=outcount_label, + incount_label = 'In Count : ' + f'{self.in_counts}' + outcount_label = 'OutCount : ' + f'{self.out_counts}' + + # Display counts based on user choice + counts_label = None + if not self.view_in_counts and not self.view_out_counts: + counts_label = None + elif not self.view_in_counts: + counts_label = outcount_label + elif not self.view_out_counts: + counts_label = incount_label + else: + counts_label = incount_label + ' ' + outcount_label + + if self.count_reg_pts is not None and counts_label is not None: + self.annotator.count_labels(counts=counts_label, count_txt_size=self.count_txt_thickness, txt_color=self.count_txt_color, color=self.count_color) diff --git a/ultralytics/utils/__init__.py b/ultralytics/utils/__init__.py index c275f78f5dc..bf73bed30b9 100644 --- a/ultralytics/utils/__init__.py +++ b/ultralytics/utils/__init__.py @@ -856,6 +856,7 @@ def __init__(self, file=SETTINGS_YAML, version='0.0.4'): 'uuid': hashlib.sha256(str(uuid.getnode()).encode()).hexdigest(), 'sync': True, 'api_key': '', + 'openai_api_key': '', 'clearml': True, # integrations 'comet': True, 'dvc': True,
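A quick sketch of the new `view_in_counts` / `view_out_counts` flags added to `Heatmap.set_args` above (illustrative only; the model, frame size and counting line are assumptions, not part of this patch):

```python
import cv2
from ultralytics import YOLO
from ultralytics.solutions import heatmap

model = YOLO('yolov8n.pt')

heatmap_obj = heatmap.Heatmap()
heatmap_obj.set_args(colormap=cv2.COLORMAP_JET,
                     imw=1280,  # frame width (assumed)
                     imh=720,  # frame height (assumed)
                     view_img=True,
                     view_in_counts=False,  # hide the in-count overlay
                     view_out_counts=True,  # keep the out-count overlay
                     count_reg_pts=[(20, 400), (1080, 404)])

# per frame: run tracking, then overlay the heatmap and counts
# results = model.track(im0, persist=True)
# im0 = heatmap_obj.generate_heatmap(im0, tracks=results)
```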