Skip to content

Commit

Permalink
updates and fixes in chart
Browse files Browse the repository at this point in the history
  • Loading branch information
almazgimaev committed Dec 30, 2024
1 parent 85a0308 commit d8b102d
Show file tree
Hide file tree
Showing 3 changed files with 174 additions and 29 deletions.
4 changes: 2 additions & 2 deletions docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ RUN pip3 install transformers==4.33.2 timm==0.9.5 scikit-learn==1.3.1 umap-learn
# Download metaclip base model
RUN python -c 'import transformers; transformers.AutoModel.from_pretrained("facebook/metaclip-b16-fullcc2.5b")'

RUN pip3 install supervisely==6.73.258
RUN pip3 install supervisely==6.73.266
RUN pip3 install fastapi==0.109.0 bokeh==3.1.1

LABEL python_sdk_version=6.73.258
LABEL python_sdk_version=6.73.266
196 changes: 171 additions & 25 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@
from supervisely.app.content import StateJson, DataJson
from dotenv import load_dotenv
import torch
import re
from typing import List, Union, Tuple
from supervisely.app.widgets import (
ScatterChart,
Container,
Card,
LabeledImage,
Expand All @@ -23,19 +22,20 @@
Button,
Field,
Progress,
SelectDataset,
IFrame,
Bokeh,
SelectDatasetTree,
NotificationBox,
Empty,
Flexbox,
)

from . import run_utils
from . import calculate_embeddings


def update_globals(new_dataset_ids):
global dataset_ids, project_id, workspace_id, team_id, project_info, project_meta, is_marked, tag_meta
global dataset_ids, project_id, workspace_id, team_id, project_info, project_meta, is_marked, tag_meta, issue_tag_meta
dataset_ids = new_dataset_ids
if dataset_ids:
project_id = api.dataset.get_info_by_id(dataset_ids[0]).project_id
Expand All @@ -57,15 +57,20 @@ def update_globals(new_dataset_ids):
is_marked = False
tag_meta = project_meta.get_tag_meta(tag_name)
print("tag_meta is exists:", bool(tag_meta))
issue_tag_meta = project_meta.get_tag_meta(issue_tag_name)
print("issue_tag_meta is exists:", bool(issue_tag_meta))


### Globals init
available_projection_methods = ["UMAP", "PCA", "t-SNE", "PCA-UMAP", "PCA-t-SNE"]
tag_name = "MARKED"
issue_tag_name = "ISSUE"
instance_mode = None

load_dotenv("local.env")
load_dotenv(os.path.expanduser("~/supervisely.env"))
api = sly.Api()
if sly.is_development():
load_dotenv("local.env")
load_dotenv(os.path.expanduser("~/supervisely.env"))
api = sly.Api.from_env()

# if app had started from context menu, one of this has to be set:
project_id = sly.env.project_id(raise_not_found=False)
Expand Down Expand Up @@ -169,18 +174,50 @@ def update_globals(new_dataset_ids):
content = Container([btn_run, check_force_recalculate, progress, info_run])
card_run = Card(title="Run", content=content)

### Embeddings Chart Settings
dot_size_btn = Button("Change dots size", button_size="small")
dot_size_num = InputNumber(min=0.01, value=0.05, step=0.01)
dot_size = Flexbox([Container([dot_size_num, dot_size_btn], direction="horizontal")])

### Embeddings Chart
bokeh = Bokeh(plots=[], x_axis_visible=True, y_axis_visible=True, grid_visible=True)
bokeh = Bokeh(
plots=[],
x_axis_visible=True,
y_axis_visible=True,
grid_visible=True,
show_legend=True,
legend_location="right",
legend_click_policy="hide",
)
bokeh_iframe = IFrame()
card_chart = Card(content=bokeh_iframe)
card_chart = Card(content=Container([dot_size, bokeh_iframe]), title="Embeddings chart", collapsable=True)
labeled_image = LabeledImage()
text = Text("no object selected")
show_all_anns = False
cur_info = None
btn_toggle = Button(f"Show all annotations: {show_all_anns}", "default", button_size="small")
btn_mark = Button(f"Assign tag 'MARKED'", button_size="small")
card_preview = Card(title="Object preview", content=Container(widgets=[labeled_image, text, btn_toggle, btn_mark]))
preview_widgets = Container([labeled_image, text, btn_toggle, btn_mark])
preview_widgets.hide()


cur_infos = None
batch_text = Text()
issue_tag_text = Text()
add_issue_tag = Button(f"Asign 'ISSUE' tags", button_size="small", plain=True)
job_issue = Button(f"Create Labeling Job", button_size="small", plain=True)
batch_tagging_field = Field(
Container([Flexbox([add_issue_tag, Empty(), job_issue]), issue_tag_text]),
"Issues",
"Assign 'ISSUE' tag to IMAGES or Create Labeling Job",
)
batch_tagging_cont = Container([batch_text, batch_tagging_field])
batch_tagging_cont.hide()

card_preview = Card(
title="Preview card",
content=Container(widgets=[preview_widgets, batch_tagging_cont]),
)
card_embeddings_chart = Container(widgets=[card_chart, card_preview], direction="horizontal", fractions=[3, 1])
card_embeddings_chart.hide()

Expand All @@ -199,6 +236,12 @@ def update_globals(new_dataset_ids):
)


@dot_size_btn.click
def change_dot_size():
bokeh.update_radii(dot_size_num.value)
bokeh_iframe.set(bokeh.html_route_with_timestamp, height="650px", width="100%")


@btn_toggle.click
def toggle_ann():
global show_all_anns
Expand All @@ -209,17 +252,65 @@ def toggle_ann():


@bokeh.value_changed
def on_click(selected_idxs):
global global_idxs_mapping, all_info_list, project_meta, is_marked, tag_meta
if len(selected_idxs) >= 1:
info = all_info_list[selected_idxs[0]]
def on_click(selected_idxs: List[Tuple[Union[int, str], List[int]]]):
global global_idxs_mapping, all_info_list, project_meta, is_marked, tag_meta, cur_infos

issue_tag_text.text = ""
batch_tagging_cont.show()

selected_ids = [global_idxs_mapping[d.plot_id][i] for d in selected_idxs for i in d.selected_ids]
selected_cnt = len(selected_ids)
if selected_cnt == 1:
batch_text.text = ""
preview_widgets.show()
info = all_info_list[selected_ids[0]]
cur_infos = [info]
if tag_meta is not None:
tag = read_tag(info["image_id"], info["object_id"])
is_marked = bool(tag)
update_marked()
show_image(info, project_meta)
if btn_mark.is_hidden():
btn_mark.show()
elif selected_cnt > 1:
preview_widgets.hide()
cur_infos = [all_info_list[i] for i in selected_ids]
obj_clss = list(set([info["object_cls"] for info in cur_infos]))
is_objects = any([info["object_id"] is not None for info in cur_infos])
is_images = any([info["object_id"] is None for info in cur_infos])
both = is_objects and is_images

t = f"{len(cur_infos)} "
t += "items. " if both else "images. " if is_images else "objects. "
t += f"Object classes: {str(obj_clss)}. "
batch_text.set(t, "info")


@job_issue.click
def create_labeling_job():
global cur_infos
issue_tag_text.text = ""
if cur_infos is not None:
ds_id_to_img_ids = defaultdict(set)
for info in cur_infos:
ds_id_to_img_ids[info["dataset_id"]].add(info["image_id"])
jobs = []
for ds_id, img_ids in ds_id_to_img_ids.items():
if len(img_ids) > 0:
jobs.extend(
api.labeling_job.create(
f"Labeling job for {project_info.name} project embeddings",
ds_id,
[api.user.get_my_info().id],
images_ids=list(img_ids),
# include_images_with_tags=[issue_tag_name],
)
)
if len(jobs) > 0:
ids = [job.id for job in jobs]
issue_tag_text.set(f"Labeling jobs created IDs: {ids}", "success")
else:
issue_tag_text.set("No objects to create labeling job", "warning")


def update_marked():
Expand All @@ -230,6 +321,34 @@ def update_marked():
btn_mark.text = "Assign tag 'MARKED'"


@add_issue_tag.click
def issue_tagging():
global project_meta, cur_infos, issue_tag_meta
if issue_tag_meta is None:
print("first marking, creating tag_meta")
issue_tag_meta = sly.TagMeta(issue_tag_name, sly.TagValueType.NONE)
project_meta, issue_tag_meta = get_or_create_tag_meta(project_id, issue_tag_meta)

ds_ids_to_img_ids = defaultdict(set)
for info in cur_infos:
ds_ids_to_img_ids[info["dataset_id"]].add(info["image_id"])

added = 0
for ds_id, img_ids in ds_ids_to_img_ids.items():
img_ids_to_mark = []
img_ids = list(img_ids)
for img_id, tag in zip(img_ids, read_img_tags(ds_id, img_ids, issue_tag_meta)):
if tag is None:
img_ids_to_mark.append(img_id)

if len(img_ids_to_mark) > 0:
add_img_tags(list(img_ids_to_mark), issue_tag_meta)
added += len(img_ids_to_mark)

if added > 0:
issue_tag_text.set(f"Assigned 'ISSUE' tags: {added} images", "success")


@btn_mark.click
def on_mark():
global project_info, project_meta, tag_meta, cur_info, is_marked
Expand Down Expand Up @@ -280,7 +399,7 @@ def update_table():

@btn_run.click
def run():
global model_name, global_idxs_mapping, all_info_list # , project_meta, dataset_ids, project_id, workspace_id, team_id
global model_name, global_idxs_mapping, all_info_list, instance_mode

selected_datasets = set()
for dataset_id in dataset_selector.get_selected_ids():
Expand Down Expand Up @@ -385,17 +504,22 @@ def run():
print(f"n_classes = {len(obj_classes)}")
series, pre_colors, global_idxs_mapping = run_utils.make_series(projections, all_info_list, project_meta)

series_len = len(series)
x_coordinates, y_coordinates, colors = [], [], []
for s, color in zip(series, pre_colors):
x_coordinates.extend([i["x"] for i in s["data"]])
y_coordinates.extend([i["y"] for i in s["data"]])
colors.extend([color] * len(s["data"]))

r = 0.15 if series_len > 1000 else 0.1
plot = Bokeh.Circle(x_coordinates, y_coordinates, radii=r, colors=colors)
bokeh.clear()
bokeh.add_plots([plot])
plots = []
for s, color in zip(series, pre_colors):
x_coordinates = [i["x"] for i in s["data"]]
y_coordinates = [i["y"] for i in s["data"]]
r = 0.05
plot = Bokeh.Circle(
x_coordinates,
y_coordinates,
radii=r,
colors=[color] * len(s["data"]),
legend_label=s["name"],
plot_id=s["name"],
)
plots.append(plot)
bokeh.add_plots(plots)
bokeh_iframe.set(bokeh.html_route_with_timestamp, height="650px", width="100%")
card_embeddings_chart.show()
update_table()
Expand All @@ -421,6 +545,24 @@ def get_tag_meta(project_id, name) -> sly.TagMeta:
return project_meta.get_tag_meta(name)


def read_img_tags(ds_id, image_ids, tag_meta):
tags = []
if len(image_ids) > 0:
filters = [{"field": "id", "operator": "in", "value": image_ids}]
image_infos = api.image.get_list(ds_id, filters=filters, force_metadata_for_links=False)
id_to_info = {img_info.id: img_info for img_info in image_infos}
for img_id in image_ids:
curr_tags = [tag for tag in id_to_info[img_id].tags if tag["tagId"] == tag_meta.sly_id]
tags.append(curr_tags[0] if len(curr_tags) == 1 else None)
return tags


def read_labels_tags(object_ids, tag_meta):
if len(object_ids) == 0:
return []
return [read_label_tag(obj_id, tag_meta) for obj_id in object_ids]


def read_img_tag(image_id, tag_meta):
image_info = api.image.get_info_by_id(image_id)
tags = [tag for tag in image_info.tags if tag["tagId"] == tag_meta.sly_id]
Expand All @@ -444,6 +586,10 @@ def read_tag(image_id, object_id):
return read_label_tag(object_id, tag_meta)


def add_img_tags(image_ids, tag_meta, value=None):
return api.image.add_tag_batch(image_ids=image_ids, tag_id=tag_meta.sly_id, value=value)


def add_img_tag(image_id, tag_meta, value=None):
return api.image.add_tag(image_id=image_id, tag_id=tag_meta.sly_id, value=value)

Expand Down
3 changes: 1 addition & 2 deletions src/run_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import sklearn.cluster
import sklearn.decomposition
import umap
from matplotlib.colors import rgb2hex
import re


Expand Down Expand Up @@ -79,7 +78,7 @@ def make_series(projections, all_info_list, project_meta):
global_idxs_mapping[obj_cls].append(i)

series = [{"name": k, "data": v} for k, v in series.items()]
obj2color = {x.name: rgb2hex(np.array(x.color) / 255) for x in project_meta.obj_classes.items()}
obj2color = {x.name: sly.color.rgb2hex(x.color) for x in project_meta.obj_classes}
obj2color["Image"] = "#222222"
colors = [obj2color[s["name"]] for s in series]

Expand Down

0 comments on commit d8b102d

Please sign in to comment.