Commit 1288e3a

wip

almazgimaev committed Dec 12, 2024
1 parent b78510c commit 1288e3a
Showing 4 changed files with 55 additions and 39 deletions.
12 changes: 12 additions & 0 deletions train/src/ui/hyperparameters.html
@@ -45,6 +45,18 @@
<span class="ml5">iterations</span>
</div>
</sly-field>
<sly-field title="Run Model Benchmark evaluation">
<el-switch v-model="state.runBenchmark"
on-color="#13ce66" off-color="#B8B8B8"
:disabled="data.done6">
</el-switch>
</sly-field>
<sly-field title="Run Speed test">
<el-switch v-model="state.runSpeedTest"
on-color="#13ce66" off-color="#B8B8B8"
:disabled="data.done6">
</el-switch>
</sly-field>
</el-tab-pane>
<el-tab-pane label="Checkpoints" name="checkpoints">
<sly-field title="Checkpoints interval"
4 changes: 3 additions & 1 deletion train/src/ui/hyperparameters.py
@@ -9,6 +9,8 @@ def init_general(state):

state["valInterval"] = 1
state["logConfigInterval"] = 5
state["runBenchmark"] = True
state["runSpeedTest"] = True

def init_checkpoints(state):
state["checkpointInterval"] = 1
@@ -99,4 +101,4 @@ def use_hyp(api: sly.Api, task_id, context, state, app_logger):
{"field": "state.disabled7", "payload": False},
{"field": "state.activeStep", "payload": 7},
]
-g.api.app.set_fields(g.task_id, fields)
\ No newline at end of file
+g.api.app.set_fields(g.task_id, fields)
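Taken together with the two `el-switch` toggles above, these defaults form a simple state round-trip: the switches bind to `state.runBenchmark` / `state.runSpeedTest` via `v-model`, `init_general` seeds both to `True`, and the training code in monitoring.py (below) branches on them. A minimal sketch of that flow, assuming Supervisely's usual dict-based app state:

```python
# Minimal sketch (assumption: the app passes UI state to server-side
# callbacks as a plain dict; init_general mirrors the defaults added here).
def init_general(state: dict) -> None:
    state["runBenchmark"] = True
    state["runSpeedTest"] = True

state = {}
init_general(state)            # defaults seeded before the UI renders
state["runSpeedTest"] = False  # the el-switch writes back via v-model
if state["runBenchmark"]:      # the branch consulted in monitoring.py's train()
    print("benchmark enabled; speed test:", state["runSpeedTest"])
```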
10 changes: 5 additions & 5 deletions train/src/ui/monitoring.html
@@ -79,15 +79,15 @@
>
</div>
</div>
-<!-- <div v-if="data.benchmarkUrl">
-  <sly-field title="" :description="Open the Model Benchmark evaluation report.">
+<div v-if="data.benchmarkUrl">
+  <sly-field title="" description="Open the Model Benchmark evaluation report.">
<a slot="title" target="_blank" :href="data.benchmarkUrl">Evaluation Report</a>
<sly-icon
slot="icon"
:options="{ color: '#dcb0ff', bgColor: '#faebff', className: 'zmdi zmdi-assignment' }"
/>
</sly-field>
-</div> -->
+</div>
<sly-field v-if="state.preparingData" class="mt10">
<b style="color: #20a0ff"
>Preparing segmentation data (it may take a few minutes)...</b
@@ -112,13 +112,13 @@
</div>
<el-progress :percentage="data.progressPercentUploadDir"></el-progress>
</div>
-<!-- <div v-if="data.progressBenchmark && !data.benchmarkUrl" class="mt10"></div>
+<div v-if="state.benchmarkInProgress && data.progressBenchmark" class="mt10"></div>
<div style="color: #20a0ff">
{{data.progressBenchmark}}: {{data.progressCurrentBenchmark}} /
{{data.progressTotalBenchmark}}
</div>
<el-progress :percentage="data.progressPercentBenchmark"></el-progress>
-</div> -->
+</div>
<div v-if="data.progressEpoch" class="mt10">
<div style="color: #20a0ff">
{{data.progressEpoch}}: {{data.progressCurrentEpoch}} /
68 changes: 35 additions & 33 deletions train/src/ui/monitoring.py
@@ -28,7 +28,7 @@
def external_callback(progress: sly.tqdm_sly):
percent = math.floor(progress.n / progress.total * 100)
fields = [
{"field": f"data.progressBenchmark", "payload": progress.message},
{"field": f"data.progressBenchmark", "payload": progress.desc},
{"field": f"data.progressCurrentBenchmark", "payload": progress.n},
{"field": f"data.progressTotalBenchmark", "payload": progress.total},
{"field": f"data.progressPercentBenchmark", "payload": percent},
@@ -45,6 +45,7 @@ def update(self, n=1):
_open_lnk_name = "open_app.lnk"
m = None


def init(data, state):
init_progress("Epoch", data)
init_progress("Iter", data)
@@ -74,13 +75,14 @@ def init(data, state):
data["outputName"] = None
data["outputUrl"] = None
data["benchmarkUrl"] = None
data["benchmarkInProgress"] = False


def init_devices():
try:
from torch import cuda
except ImportError as ie:
-sly.logger.warn(
+sly.logger.warning(
"Unable to import Torch. Please, run 'pip install torch' to resolve the issue.",
extra={"error message": str(ie)},
)
@@ -89,8 +91,8 @@ def init_devices():
devices = []
cuda.init()
if not cuda.is_available():
sly.logger.warn("CUDA is not available")
return
sly.logger.warning("CUDA is not available")
return devices

for idx in range(cuda.device_count()):
current_device = f"cuda:{idx}"
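Besides renaming `warn` to `warning`, the bare `return` becomes `return devices`, so the no-CUDA path yields an empty list instead of `None` and callers can iterate the result without a guard:

```python
# Illustration: with `return devices`, this loop is safe even when CUDA is
# unavailable (it iterates an empty list); a bare `return` would hand back
# None and make the loop raise TypeError.
for device in init_devices():
    print(device)
```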
@@ -328,7 +330,6 @@ def run_benchmark(api: sly.Api, task_id, classes, cfg, state, remote_dir):
global m

benchmark_report_template = None
-# if run_model_benchmark_checkbox.is_checked():
try:
from sly_mmsegm import MMSegmentationModelBench
import torch
@@ -402,7 +403,7 @@ def run_benchmark(api: sly.Api, task_id, classes, cfg, state, remote_dir):
arch_type=arch_type,
)
m._load_model(deploy_params)
-asyncio.set_event_loop(asyncio.new_event_loop())  # fix for the issue with the event loop
+asyncio.set_event_loop(asyncio.new_event_loop())
m.serve()

import requests
@@ -418,7 +419,7 @@ def run_app():

while True:
try:
-response = requests.get("http://localhost:8000")
+requests.get("http://localhost:8000")
print("✅ Local server is ready")
break
except requests.exceptions.ConnectionError:
@@ -448,7 +449,10 @@ def get_image_infos_by_split(split: list):
ds_infos_dict = {ds_info.name: ds_info for ds_info in dataset_infos}
image_names_per_dataset = {}
for item in split:
-image_names_per_dataset.setdefault(item.dataset_name, []).append(item.name)
+name = item.name
+if name[1] == "_":
+    name = name[2:]
+image_names_per_dataset.setdefault(item.dataset_name, []).append(name)
image_infos = []
for dataset_name, image_names in image_names_per_dataset.items():
ds_info = ds_infos_dict[dataset_name]
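The new prefix handling assumes split items may carry a short prefix such as `1_image.jpg` and strips it before the names are matched against project images. Note that `name[1] == "_"` only catches single-character prefixes; a hedged, more general variant (assumption: prefixes have the form `<digits>_`; this is not the committed code) could use a regex:

```python
import re

# Hypothetical generalization: strip any leading "<digits>_" prefix so
# "12_image.jpg" also maps back to "image.jpg".
def strip_split_prefix(name: str) -> str:
    return re.sub(r"^\d+_", "", name)

assert strip_split_prefix("1_image.jpg") == "image.jpg"
assert strip_split_prefix("12_image.jpg") == "image.jpg"
assert strip_split_prefix("image.jpg") == "image.jpg"
```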
@@ -473,14 +477,15 @@ def get_image_infos_by_split(split: list):
benchmark_images_ids = [img_info.id for img_info in val_image_infos]
train_images_ids = [img_info.id for img_info in train_image_infos]

-model_benchmark_pbar = TqdmBenchmark
+state["benchmarkInProgress"] = True
bm = sly.nn.benchmark.SemanticSegmentationBenchmark(
api,
g.project_info.id,
output_dir=g.data_dir + "/benchmark",
gt_dataset_ids=benchmark_dataset_ids,
gt_images_ids=benchmark_images_ids,
-    progress=model_benchmark_pbar,
+    progress=TqdmBenchmark,
+    progress_secondary=TqdmBenchmark,
classes_whitelist=classes,
)

@@ -507,36 +512,33 @@ def get_image_infos_by_split(split: list):
bm.upload_eval_results(eval_res_dir + "/evaluation/")

# # 6. Speed test
-try:
-    session_info = session.get_session_info()
-    support_batch_inference = session_info.get("batch_inference_support", False)
-    max_batch_size = session_info.get("max_batch_size")
-    batch_sizes = (1, 8, 16)
-    if not support_batch_inference:
-        batch_sizes = (1,)
-    elif max_batch_size is not None:
-        batch_sizes = tuple([bs for bs in batch_sizes if bs <= max_batch_size])
-    bm.run_speedtest(session, g.project_info.id, batch_sizes=batch_sizes)
-    bm.upload_speedtest_results(eval_res_dir + "/speedtest/")
-except Exception as e:
-    sly.logger.warning(f"Speedtest failed. Skipping. {e}")
+if state["runSpeedTest"]:
+    try:
+        session_info = session.get_session_info()
+        support_batch_inference = session_info.get("batch_inference_support", False)
+        max_batch_size = session_info.get("max_batch_size")
+        batch_sizes = (1, 8, 16)
+        if not support_batch_inference:
+            batch_sizes = (1,)
+        elif max_batch_size is not None:
+            batch_sizes = tuple([bs for bs in batch_sizes if bs <= max_batch_size])
+        bm.run_speedtest(session, g.project_info.id, batch_sizes=batch_sizes)
+        bm.upload_speedtest_results(eval_res_dir + "/speedtest/")
+    except Exception as e:
+        sly.logger.warning(f"Speedtest failed. Skipping. {e}")

# 7. Prepare visualizations, report and
bm.visualize()
remote_dir = bm.upload_visualizations(eval_res_dir + "/visualizations/")
report = bm.upload_report_link(remote_dir)

# 8. UI updates
-benchmark_report_template = api.file.get_info_by_path(
-    sly.env.team_id(), remote_dir + "template.vue"
-)
-lnk = f"/model-benchmark?id={benchmark_report_template.id}"
-lnk = abs_url(lnk) if is_development() or is_debug_with_sly_net() else lnk
+benchmark_report_template = bm.report

fields = [
{"field": f"data.progressBenchmark", "payload": False},
{"field": f"data.benchmarkUrl", "payload": lnk},
{"field": f"data.benchmarkUrl", "payload": bm.get_report_link()},
]
state["benchmarkInProgress"] = False
api.app.set_fields(g.task_id, fields)
sly.logger.info(
f"Predictions project name: {bm.dt_project_info.name}. Workspace_id: {bm.dt_project_info.workspace_id}"
@@ -552,6 +554,7 @@ def get_image_infos_by_split(split: list):

return benchmark_report_template


@g.my_app.callback("train")
@sly.timeit
@g.my_app.ignore_errors_and_show_dialog_window()
@@ -621,10 +624,9 @@ def train(api: sly.Api, task_id, context, state, app_logger):
g.api.app.set_fields(g.task_id, fields)

benchmark_report_template = None

# run benchmark
-# if state["runModelBenchmark"]:
-benchmark_report_template = run_benchmark(api, task_id, classes, cfg, state, remote_dir)
+if state["runBenchmark"]:
+    benchmark_report_template = run_benchmark(api, task_id, classes, cfg, state, remote_dir)

w.workflow_input(api, g.project_info, state)
w.workflow_output(api, g.sly_mmseg_generated_metadata, state, benchmark_report_template)
