From d8b102d4a67eda32188e27d94ad6a04e8d4d155e Mon Sep 17 00:00:00 2001 From: almaz Date: Mon, 30 Dec 2024 16:18:06 +0100 Subject: [PATCH] updates and fixes in chart --- docker/Dockerfile | 4 +- src/main.py | 196 ++++++++++++++++++++++++++++++++++++++++------ src/run_utils.py | 3 +- 3 files changed, 174 insertions(+), 29 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index be2c492..400fc4e 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -11,7 +11,7 @@ RUN pip3 install transformers==4.33.2 timm==0.9.5 scikit-learn==1.3.1 umap-learn # Download metaclip base model RUN python -c 'import transformers; transformers.AutoModel.from_pretrained("facebook/metaclip-b16-fullcc2.5b")' -RUN pip3 install supervisely==6.73.258 +RUN pip3 install supervisely==6.73.266 RUN pip3 install fastapi==0.109.0 bokeh==3.1.1 -LABEL python_sdk_version=6.73.258 +LABEL python_sdk_version=6.73.266 diff --git a/src/main.py b/src/main.py index cf459fd..2d2c1c0 100644 --- a/src/main.py +++ b/src/main.py @@ -7,9 +7,8 @@ from supervisely.app.content import StateJson, DataJson from dotenv import load_dotenv import torch -import re +from typing import List, Union, Tuple from supervisely.app.widgets import ( - ScatterChart, Container, Card, LabeledImage, @@ -23,11 +22,12 @@ Button, Field, Progress, - SelectDataset, IFrame, Bokeh, SelectDatasetTree, NotificationBox, + Empty, + Flexbox, ) from . import run_utils @@ -35,7 +35,7 @@ def update_globals(new_dataset_ids): - global dataset_ids, project_id, workspace_id, team_id, project_info, project_meta, is_marked, tag_meta + global dataset_ids, project_id, workspace_id, team_id, project_info, project_meta, is_marked, tag_meta, issue_tag_meta dataset_ids = new_dataset_ids if dataset_ids: project_id = api.dataset.get_info_by_id(dataset_ids[0]).project_id @@ -57,15 +57,20 @@ def update_globals(new_dataset_ids): is_marked = False tag_meta = project_meta.get_tag_meta(tag_name) print("tag_meta is exists:", bool(tag_meta)) + issue_tag_meta = project_meta.get_tag_meta(issue_tag_name) + print("issue_tag_meta is exists:", bool(issue_tag_meta)) ### Globals init available_projection_methods = ["UMAP", "PCA", "t-SNE", "PCA-UMAP", "PCA-t-SNE"] tag_name = "MARKED" +issue_tag_name = "ISSUE" +instance_mode = None -load_dotenv("local.env") -load_dotenv(os.path.expanduser("~/supervisely.env")) -api = sly.Api() +if sly.is_development(): + load_dotenv("local.env") + load_dotenv(os.path.expanduser("~/supervisely.env")) +api = sly.Api.from_env() # if app had started from context menu, one of this has to be set: project_id = sly.env.project_id(raise_not_found=False) @@ -169,18 +174,50 @@ def update_globals(new_dataset_ids): content = Container([btn_run, check_force_recalculate, progress, info_run]) card_run = Card(title="Run", content=content) +### Embeddings Chart Settings +dot_size_btn = Button("Change dots size", button_size="small") +dot_size_num = InputNumber(min=0.01, value=0.05, step=0.01) +dot_size = Flexbox([Container([dot_size_num, dot_size_btn], direction="horizontal")]) ### Embeddings Chart -bokeh = Bokeh(plots=[], x_axis_visible=True, y_axis_visible=True, grid_visible=True) +bokeh = Bokeh( + plots=[], + x_axis_visible=True, + y_axis_visible=True, + grid_visible=True, + show_legend=True, + legend_location="right", + legend_click_policy="hide", +) bokeh_iframe = IFrame() -card_chart = Card(content=bokeh_iframe) +card_chart = Card(content=Container([dot_size, bokeh_iframe]), title="Embeddings chart", collapsable=True) labeled_image = LabeledImage() text = Text("no object selected") show_all_anns = False cur_info = None btn_toggle = Button(f"Show all annotations: {show_all_anns}", "default", button_size="small") btn_mark = Button(f"Assign tag 'MARKED'", button_size="small") -card_preview = Card(title="Object preview", content=Container(widgets=[labeled_image, text, btn_toggle, btn_mark])) +preview_widgets = Container([labeled_image, text, btn_toggle, btn_mark]) +preview_widgets.hide() + + +cur_infos = None +batch_text = Text() +issue_tag_text = Text() +add_issue_tag = Button(f"Asign 'ISSUE' tags", button_size="small", plain=True) +job_issue = Button(f"Create Labeling Job", button_size="small", plain=True) +batch_tagging_field = Field( + Container([Flexbox([add_issue_tag, Empty(), job_issue]), issue_tag_text]), + "Issues", + "Assign 'ISSUE' tag to IMAGES or Create Labeling Job", +) +batch_tagging_cont = Container([batch_text, batch_tagging_field]) +batch_tagging_cont.hide() + +card_preview = Card( + title="Preview card", + content=Container(widgets=[preview_widgets, batch_tagging_cont]), +) card_embeddings_chart = Container(widgets=[card_chart, card_preview], direction="horizontal", fractions=[3, 1]) card_embeddings_chart.hide() @@ -199,6 +236,12 @@ def update_globals(new_dataset_ids): ) +@dot_size_btn.click +def change_dot_size(): + bokeh.update_radii(dot_size_num.value) + bokeh_iframe.set(bokeh.html_route_with_timestamp, height="650px", width="100%") + + @btn_toggle.click def toggle_ann(): global show_all_anns @@ -209,10 +252,19 @@ def toggle_ann(): @bokeh.value_changed -def on_click(selected_idxs): - global global_idxs_mapping, all_info_list, project_meta, is_marked, tag_meta - if len(selected_idxs) >= 1: - info = all_info_list[selected_idxs[0]] +def on_click(selected_idxs: List[Tuple[Union[int, str], List[int]]]): + global global_idxs_mapping, all_info_list, project_meta, is_marked, tag_meta, cur_infos + + issue_tag_text.text = "" + batch_tagging_cont.show() + + selected_ids = [global_idxs_mapping[d.plot_id][i] for d in selected_idxs for i in d.selected_ids] + selected_cnt = len(selected_ids) + if selected_cnt == 1: + batch_text.text = "" + preview_widgets.show() + info = all_info_list[selected_ids[0]] + cur_infos = [info] if tag_meta is not None: tag = read_tag(info["image_id"], info["object_id"]) is_marked = bool(tag) @@ -220,6 +272,45 @@ def on_click(selected_idxs): show_image(info, project_meta) if btn_mark.is_hidden(): btn_mark.show() + elif selected_cnt > 1: + preview_widgets.hide() + cur_infos = [all_info_list[i] for i in selected_ids] + obj_clss = list(set([info["object_cls"] for info in cur_infos])) + is_objects = any([info["object_id"] is not None for info in cur_infos]) + is_images = any([info["object_id"] is None for info in cur_infos]) + both = is_objects and is_images + + t = f"{len(cur_infos)} " + t += "items. " if both else "images. " if is_images else "objects. " + t += f"Object classes: {str(obj_clss)}. " + batch_text.set(t, "info") + + +@job_issue.click +def create_labeling_job(): + global cur_infos + issue_tag_text.text = "" + if cur_infos is not None: + ds_id_to_img_ids = defaultdict(set) + for info in cur_infos: + ds_id_to_img_ids[info["dataset_id"]].add(info["image_id"]) + jobs = [] + for ds_id, img_ids in ds_id_to_img_ids.items(): + if len(img_ids) > 0: + jobs.extend( + api.labeling_job.create( + f"Labeling job for {project_info.name} project embeddings", + ds_id, + [api.user.get_my_info().id], + images_ids=list(img_ids), + # include_images_with_tags=[issue_tag_name], + ) + ) + if len(jobs) > 0: + ids = [job.id for job in jobs] + issue_tag_text.set(f"Labeling jobs created IDs: {ids}", "success") + else: + issue_tag_text.set("No objects to create labeling job", "warning") def update_marked(): @@ -230,6 +321,34 @@ def update_marked(): btn_mark.text = "Assign tag 'MARKED'" +@add_issue_tag.click +def issue_tagging(): + global project_meta, cur_infos, issue_tag_meta + if issue_tag_meta is None: + print("first marking, creating tag_meta") + issue_tag_meta = sly.TagMeta(issue_tag_name, sly.TagValueType.NONE) + project_meta, issue_tag_meta = get_or_create_tag_meta(project_id, issue_tag_meta) + + ds_ids_to_img_ids = defaultdict(set) + for info in cur_infos: + ds_ids_to_img_ids[info["dataset_id"]].add(info["image_id"]) + + added = 0 + for ds_id, img_ids in ds_ids_to_img_ids.items(): + img_ids_to_mark = [] + img_ids = list(img_ids) + for img_id, tag in zip(img_ids, read_img_tags(ds_id, img_ids, issue_tag_meta)): + if tag is None: + img_ids_to_mark.append(img_id) + + if len(img_ids_to_mark) > 0: + add_img_tags(list(img_ids_to_mark), issue_tag_meta) + added += len(img_ids_to_mark) + + if added > 0: + issue_tag_text.set(f"Assigned 'ISSUE' tags: {added} images", "success") + + @btn_mark.click def on_mark(): global project_info, project_meta, tag_meta, cur_info, is_marked @@ -280,7 +399,7 @@ def update_table(): @btn_run.click def run(): - global model_name, global_idxs_mapping, all_info_list # , project_meta, dataset_ids, project_id, workspace_id, team_id + global model_name, global_idxs_mapping, all_info_list, instance_mode selected_datasets = set() for dataset_id in dataset_selector.get_selected_ids(): @@ -385,17 +504,22 @@ def run(): print(f"n_classes = {len(obj_classes)}") series, pre_colors, global_idxs_mapping = run_utils.make_series(projections, all_info_list, project_meta) - series_len = len(series) - x_coordinates, y_coordinates, colors = [], [], [] - for s, color in zip(series, pre_colors): - x_coordinates.extend([i["x"] for i in s["data"]]) - y_coordinates.extend([i["y"] for i in s["data"]]) - colors.extend([color] * len(s["data"])) - - r = 0.15 if series_len > 1000 else 0.1 - plot = Bokeh.Circle(x_coordinates, y_coordinates, radii=r, colors=colors) bokeh.clear() - bokeh.add_plots([plot]) + plots = [] + for s, color in zip(series, pre_colors): + x_coordinates = [i["x"] for i in s["data"]] + y_coordinates = [i["y"] for i in s["data"]] + r = 0.05 + plot = Bokeh.Circle( + x_coordinates, + y_coordinates, + radii=r, + colors=[color] * len(s["data"]), + legend_label=s["name"], + plot_id=s["name"], + ) + plots.append(plot) + bokeh.add_plots(plots) bokeh_iframe.set(bokeh.html_route_with_timestamp, height="650px", width="100%") card_embeddings_chart.show() update_table() @@ -421,6 +545,24 @@ def get_tag_meta(project_id, name) -> sly.TagMeta: return project_meta.get_tag_meta(name) +def read_img_tags(ds_id, image_ids, tag_meta): + tags = [] + if len(image_ids) > 0: + filters = [{"field": "id", "operator": "in", "value": image_ids}] + image_infos = api.image.get_list(ds_id, filters=filters, force_metadata_for_links=False) + id_to_info = {img_info.id: img_info for img_info in image_infos} + for img_id in image_ids: + curr_tags = [tag for tag in id_to_info[img_id].tags if tag["tagId"] == tag_meta.sly_id] + tags.append(curr_tags[0] if len(curr_tags) == 1 else None) + return tags + + +def read_labels_tags(object_ids, tag_meta): + if len(object_ids) == 0: + return [] + return [read_label_tag(obj_id, tag_meta) for obj_id in object_ids] + + def read_img_tag(image_id, tag_meta): image_info = api.image.get_info_by_id(image_id) tags = [tag for tag in image_info.tags if tag["tagId"] == tag_meta.sly_id] @@ -444,6 +586,10 @@ def read_tag(image_id, object_id): return read_label_tag(object_id, tag_meta) +def add_img_tags(image_ids, tag_meta, value=None): + return api.image.add_tag_batch(image_ids=image_ids, tag_id=tag_meta.sly_id, value=value) + + def add_img_tag(image_id, tag_meta, value=None): return api.image.add_tag(image_id=image_id, tag_id=tag_meta.sly_id, value=value) diff --git a/src/run_utils.py b/src/run_utils.py index 5da34d9..4a2722c 100644 --- a/src/run_utils.py +++ b/src/run_utils.py @@ -9,7 +9,6 @@ import sklearn.cluster import sklearn.decomposition import umap -from matplotlib.colors import rgb2hex import re @@ -79,7 +78,7 @@ def make_series(projections, all_info_list, project_meta): global_idxs_mapping[obj_cls].append(i) series = [{"name": k, "data": v} for k, v in series.items()] - obj2color = {x.name: rgb2hex(np.array(x.color) / 255) for x in project_meta.obj_classes.items()} + obj2color = {x.name: sly.color.rgb2hex(x.color) for x in project_meta.obj_classes} obj2color["Image"] = "#222222" colors = [obj2color[s["name"]] for s in series]