
Commit

added checkpoints dir
TheoLisin committed Nov 10, 2023
1 parent 41e61ab commit 6d52984
Showing 3 changed files with 119 additions and 48 deletions.
3 changes: 2 additions & 1 deletion train/.gitignore
@@ -1,3 +1,4 @@
src/__pycache__/
local.env
tempfiles/
tempfiles/
runs/
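
The new runs/ entry keeps debug-mode checkpoints (written there by the globals.py change below) out of version control; the resulting train/.gitignore is simply:

src/__pycache__/
local.env
tempfiles/
runs/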
5 changes: 4 additions & 1 deletion train/src/globals.py
@@ -13,8 +13,11 @@

if sly.is_production():
app_session_id = sly.io.env.task_id()
root_model_checkpoint_dir = sly.app.get_synced_data_dir()
else:
app_session_id = 777 # for debug
root_model_checkpoint_dir = os.path.join(app_root_directory, "runs")


det_models_data_path = os.path.join(root_source_path, "models", "det_models_data.json")
seg_models_data_path = os.path.join(root_source_path, "models", "seg_models_data.json")
@@ -29,4 +32,4 @@
train_params_filepath = "training_params.yml" # for debug
train_counter, val_counter = 0, 0
center_matches = {}
keypoints_template = None
keypoints_template = None
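
Together, these two hunks give checkpoint output a single configurable root: the platform-synced data directory in production and a local, git-ignored runs/ folder while debugging. A minimal sketch of the resolved logic, assuming the module already imports os and supervisely as sly and defines app_root_directory earlier:

import os
import supervisely as sly

if sly.is_production():
    app_session_id = sly.io.env.task_id()
    # the synced data dir is preserved by the platform between app restarts
    root_model_checkpoint_dir = sly.app.get_synced_data_dir()
else:
    app_session_id = 777  # fixed session id for local debugging
    # keep debug checkpoints inside the repo, under the git-ignored runs/ folder
    root_model_checkpoint_dir = os.path.join(app_root_directory, "runs")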
159 changes: 113 additions & 46 deletions train/src/main.py
Expand Up @@ -108,9 +108,7 @@ def update_globals(new_dataset_ids):
reselect_data_button,
]
)
card_project_settings = Card(
title="Dataset selection", content=project_settings_content
)
card_project_settings = Card(title="Dataset selection", content=project_settings_content)


### 2. Project classes
@@ -658,10 +656,14 @@ def select_task(task_type):
def select_classes():
selected_classes = classes_table.get_selected_classes()
selected_shapes = [
cls.geometry_type.geometry_name() for cls in project_meta.obj_classes if cls.name in selected_classes
cls.geometry_type.geometry_name()
for cls in project_meta.obj_classes
if cls.name in selected_classes
]
task_type = task_type_select.get_value()
if task_type == "pose estimation" and ("graph" not in selected_shapes or "rectangle" not in selected_shapes):
if task_type == "pose estimation" and (
"graph" not in selected_shapes or "rectangle" not in selected_shapes
):
sly.app.show_dialog(
title="Pose estimation task requires input project to have at least one class of shape graph and one class of shape rectangle",
description="Please, select both classes of shape rectangle and graph or change task type",
@@ -762,7 +764,9 @@ def change_file_preview(value):
@additional_config_radio.value_changed
def change_radio(value):
if value == "import template from Team Files":
remote_templates_dir = os.path.join("/yolov8_train", task_type_select.get_value(), "param_templates")
remote_templates_dir = os.path.join(
"/yolov8_train", task_type_select.get_value(), "param_templates"
)
templates = api.file.list(team_id, remote_templates_dir)
if len(templates) == 0:
no_templates_notification.show()
@@ -777,7 +781,9 @@ def change_radio(value):

@additional_config_template_select.value_changed
def change_template(template):
remote_templates_dir = os.path.join("/yolov8_train", task_type_select.get_value(), "param_templates")
remote_templates_dir = os.path.join(
"/yolov8_train", task_type_select.get_value(), "param_templates"
)
remote_template_path = os.path.join(remote_templates_dir, template)
local_template_path = os.path.join(g.app_data_dir, template)
api.file.download(team_id, remote_template_path, local_template_path)
@@ -789,7 +795,9 @@ def change_template(template):
@save_template_button.click
def upload_template():
save_template_button.loading = True
remote_templates_dir = os.path.join("/yolov8_train", task_type_select.get_value(), "param_templates")
remote_templates_dir = os.path.join(
"/yolov8_train", task_type_select.get_value(), "param_templates"
)
additional_params = train_settings_editor.get_text()
ryaml = ruamel.yaml.YAML()
additional_params = ryaml.load(additional_params)
@@ -861,19 +869,20 @@ def change_logs_visibility():
@start_training_button.click
def start_training():
task_type = task_type_select.get_value()
if sly.is_production():
local_dir = g.root_source_path
else:
local_dir = g.app_root_directory

local_dir = g.root_model_checkpoint_dir
if task_type == "object detection":
necessary_geometries = ["rectangle"]
local_artifacts_dir = os.path.join(local_dir, "runs", "detect", "train")
checkpoint_dir = os.path.join(local_dir, "detect")
local_artifacts_dir = os.path.join(local_dir, "detect", "train")
elif task_type == "pose estimation":
necessary_geometries = ["graph", "rectangle"]
local_artifacts_dir = os.path.join(local_dir, "runs", "pose", "train")
checkpoint_dir = os.path.join(local_dir, "pose")
local_artifacts_dir = os.path.join(local_dir, "pose", "train")
elif task_type == "instance segmentation":
necessary_geometries = ["bitmap", "polygon"]
local_artifacts_dir = os.path.join(local_dir, "runs", "segment", "train")
checkpoint_dir = os.path.join(local_dir, "segment")
local_artifacts_dir = os.path.join(local_dir, "segment", "train")

sly.logger.info(f"Local artifacts dir: {local_artifacts_dir}")

@@ -904,10 +913,15 @@ def start_training():
if task_type != "object detection":
unnecessary_classes = []
for cls in project_meta.obj_classes:
if cls.name in selected_classes and cls.geometry_type.geometry_name() not in necessary_geometries:
if (
cls.name in selected_classes
and cls.geometry_type.geometry_name() not in necessary_geometries
):
unnecessary_classes.append(cls.name)
if len(unnecessary_classes) > 0:
sly.Project.remove_classes(g.project_dir, classes_to_remove=unnecessary_classes, inplace=True)
sly.Project.remove_classes(
g.project_dir, classes_to_remove=unnecessary_classes, inplace=True
)
# transfer project to detection task if necessary
if task_type == "object detection":
sly.Project.to_detection_task(g.project_dir, inplace=True)
@@ -928,7 +942,9 @@ def start_training():
description="Val split length is 0 after ignoring images. Please check your data",
status="error",
)
raise ValueError("Val split length is 0 after ignoring images. Please check your data")
raise ValueError(
"Val split length is 0 after ignoring images. Please check your data"
)
# split the data
train_val_split._project_fs = sly.Project(g.project_dir, sly.OpenMode.READ)
train_set, val_set = train_val_split.get_splits()
@@ -963,7 +979,9 @@ def download_monitor(monitor, api: sly.Api, progress: sly.Progress):
model_filename = selected_model.lower() + ".pt"
pretrained = True
weights_dst_path = os.path.join(g.app_data_dir, model_filename)
weights_url = f"https://github.com/ultralytics/assets/releases/download/v0.0.0/{model_filename}"
weights_url = (
f"https://github.com/ultralytics/assets/releases/download/v0.0.0/{model_filename}"
)
with urlopen(weights_url) as file:
weights_size = file.length

@@ -1110,10 +1128,16 @@ def on_results_file_changed(filepath, pbar):
# visualize train batch
batch = f"train_batch{x}.jpg"
local_train_batches_path = os.path.join(local_artifacts_dir, batch)
if os.path.exists(local_train_batches_path) and batch not in plotted_train_batches and x < 10:
if (
os.path.exists(local_train_batches_path)
and batch not in plotted_train_batches
and x < 10
):
plotted_train_batches.append(batch)
remote_train_batches_path = os.path.join(remote_images_path, batch)
tf_train_batches_info = api.file.upload(team_id, local_train_batches_path, remote_train_batches_path)
tf_train_batches_info = api.file.upload(
team_id, local_train_batches_path, remote_train_batches_path
)
train_batches_gallery.append(tf_train_batches_info.full_storage_url)
if x == 0:
train_batches_gallery_f.show()
@@ -1184,6 +1208,7 @@ def stop_on_batch_end_if_needed(*args, **kwargs):
workers=n_workers_input.get_value(),
optimizer=select_optimizer.get_value(),
pretrained=pretrained,
project=checkpoint_dir,
**additional_params,
)
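
The added project=checkpoint_dir argument is what makes the per-task directories above line up: Ultralytics writes each run to <project>/<name>, and with the default run name "train" the output lands in <checkpoint_dir>/train, which is exactly the new local_artifacts_dir. A minimal sketch with placeholder paths and data file (not the app's real values):

import os
from ultralytics import YOLO

checkpoint_dir = os.path.join("/path/to/synced_data_dir", "detect")  # resolved as in start_training()
model = YOLO("yolov8n.pt")  # any YOLOv8 checkpoint; name chosen only for illustration
# With project= set, run output goes to <checkpoint_dir>/train instead of the
# library's default runs/detect/train under the current working directory.
model.train(data="data.yaml", epochs=1, project=checkpoint_dir)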

@@ -1218,8 +1243,12 @@ def stop_on_batch_end_if_needed(*args, **kwargs):
# visualize additional training results
confusion_matrix_path = os.path.join(local_artifacts_dir, "confusion_matrix_normalized.png")
if os.path.exists(confusion_matrix_path):
remote_confusion_matrix_path = os.path.join(remote_images_path, "confusion_matrix_normalized.png")
tf_confusion_matrix_info = api.file.upload(team_id, confusion_matrix_path, remote_confusion_matrix_path)
remote_confusion_matrix_path = os.path.join(
remote_images_path, "confusion_matrix_normalized.png"
)
tf_confusion_matrix_info = api.file.upload(
team_id, confusion_matrix_path, remote_confusion_matrix_path
)
additional_gallery.append(tf_confusion_matrix_info.full_storage_url)
additional_gallery_f.show()
pr_curve_path = os.path.join(local_artifacts_dir, "PR_curve.png")
@@ -1240,18 +1269,24 @@ def stop_on_batch_end_if_needed(*args, **kwargs):
pose_f1_curve_path = os.path.join(local_artifacts_dir, "PoseF1_curve.png")
if os.path.exists(pose_f1_curve_path):
remote_pose_f1_curve_path = os.path.join(remote_images_path, "PoseF1_curve.png")
tf_pose_f1_curve_info = api.file.upload(team_id, pose_f1_curve_path, remote_pose_f1_curve_path)
tf_pose_f1_curve_info = api.file.upload(
team_id, pose_f1_curve_path, remote_pose_f1_curve_path
)
additional_gallery.append(tf_pose_f1_curve_info.full_storage_url)
mask_f1_curve_path = os.path.join(local_artifacts_dir, "MaskF1_curve.png")
if os.path.exists(mask_f1_curve_path):
remote_mask_f1_curve_path = os.path.join(remote_images_path, "MaskF1_curve.png")
tf_mask_f1_curve_info = api.file.upload(team_id, mask_f1_curve_path, remote_mask_f1_curve_path)
tf_mask_f1_curve_info = api.file.upload(
team_id, mask_f1_curve_path, remote_mask_f1_curve_path
)
additional_gallery.append(tf_mask_f1_curve_info.full_storage_url)

# rename best checkpoint file
results = pd.read_csv(watch_file)
results.columns = [col.replace(" ", "") for col in results.columns]
results["fitness"] = (0.1 * results["metrics/mAP50(B)"]) + (0.9 * results["metrics/mAP50-95(B)"])
results["fitness"] = (0.1 * results["metrics/mAP50(B)"]) + (
0.9 * results["metrics/mAP50-95(B)"]
)
print("Final results:")
print(results)
best_epoch = results["fitness"].idxmax()
@@ -1263,7 +1298,10 @@ def stop_on_batch_end_if_needed(*args, **kwargs):
# add geometry config to saved weights for pose estimation task
if task_type == "pose estimation":
for obj_class in project_meta.obj_classes:
if obj_class.geometry_type.geometry_name() == "graph" and obj_class.name in selected_classes:
if (
obj_class.geometry_type.geometry_name() == "graph"
and obj_class.name in selected_classes
):
geometry_config = obj_class.geometry_config
break
weights_filepath = os.path.join(local_artifacts_dir, "weights", best_filename)
@@ -1422,19 +1460,19 @@ def auto_train(request: Request):
card_train_progress.unlock()
card_train_progress.uncollapse()

if sly.is_production():
local_dir = g.root_source_path
else:
local_dir = g.app_root_directory
local_dir = g.root_model_checkpoint_dir
if task_type == "object detection":
necessary_geometries = ["rectangle"]
local_artifacts_dir = os.path.join(local_dir, "runs", "detect", "train")
checkpoint_dir = os.path.join(local_dir, "detect")
local_artifacts_dir = os.path.join(local_dir, "detect", "train")
elif task_type == "pose estimation":
necessary_geometries = ["graph", "rectangle"]
local_artifacts_dir = os.path.join(local_dir, "runs", "pose", "train")
checkpoint_dir = os.path.join(local_dir, "pose")
local_artifacts_dir = os.path.join(local_dir, "pose", "train")
elif task_type == "instance segmentation":
necessary_geometries = ["bitmap", "polygon"]
local_artifacts_dir = os.path.join(local_dir, "runs", "segment", "train")
checkpoint_dir = os.path.join(local_dir, "segment")
local_artifacts_dir = os.path.join(local_dir, "segment", "train")

sly.logger.info(f"Local artifacts dir: {local_artifacts_dir}")

@@ -1473,10 +1511,15 @@ def auto_train(request: Request):
if task_type != "object detection":
unnecessary_classes = []
for cls in project_meta.obj_classes:
if cls.name in selected_classes and cls.geometry_type.geometry_name() not in necessary_geometries:
if (
cls.name in selected_classes
and cls.geometry_type.geometry_name() not in necessary_geometries
):
unnecessary_classes.append(cls.name)
if len(unnecessary_classes) > 0:
sly.Project.remove_classes(g.project_dir, classes_to_remove=unnecessary_classes, inplace=True)
sly.Project.remove_classes(
g.project_dir, classes_to_remove=unnecessary_classes, inplace=True
)
# transfer project to detection task if necessary
if task_type == "object detection":
sly.Project.to_detection_task(g.project_dir, inplace=True)
@@ -1519,7 +1562,9 @@ def download_monitor(monitor, api: sly.Api, progress: sly.Progress):
model_filename = selected_model.lower() + ".pt"
pretrained = True
weights_dst_path = os.path.join(g.app_data_dir, model_filename)
weights_url = f"https://github.com/ultralytics/assets/releases/download/v0.0.0/{model_filename}"
weights_url = (
f"https://github.com/ultralytics/assets/releases/download/v0.0.0/{model_filename}"
)
with urlopen(weights_url) as file:
weights_size = file.length

@@ -1638,10 +1683,16 @@ def on_results_file_changed(filepath, pbar):
# visualize train batch
batch = f"train_batch{x}.jpg"
local_train_batches_path = os.path.join(local_artifacts_dir, batch)
if os.path.exists(local_train_batches_path) and batch not in plotted_train_batches and x < 10:
if (
os.path.exists(local_train_batches_path)
and batch not in plotted_train_batches
and x < 10
):
plotted_train_batches.append(batch)
remote_train_batches_path = os.path.join(remote_images_path, batch)
tf_train_batches_info = api.file.upload(team_id, local_train_batches_path, remote_train_batches_path)
tf_train_batches_info = api.file.upload(
team_id, local_train_batches_path, remote_train_batches_path
)
train_batches_gallery.append(tf_train_batches_info.full_storage_url)
if x == 0:
train_batches_gallery_f.show()
@@ -1685,6 +1736,7 @@ def train_batch_watcher_func():

model.train(
data=data_path,
project=checkpoint_dir,
epochs=state.get("n_epochs", n_epochs_input.get_value()),
patience=state.get("patience", patience_input.get_value()),
batch=state.get("batch_size", batch_size_input.get_value()),
@@ -1747,8 +1799,12 @@ def train_batch_watcher_func():
# visualize additional training results
confusion_matrix_path = os.path.join(local_artifacts_dir, "confusion_matrix_normalized.png")
if os.path.exists(confusion_matrix_path):
remote_confusion_matrix_path = os.path.join(remote_images_path, "confusion_matrix_normalized.png")
tf_confusion_matrix_info = api.file.upload(team_id, confusion_matrix_path, remote_confusion_matrix_path)
remote_confusion_matrix_path = os.path.join(
remote_images_path, "confusion_matrix_normalized.png"
)
tf_confusion_matrix_info = api.file.upload(
team_id, confusion_matrix_path, remote_confusion_matrix_path
)
additional_gallery.append(tf_confusion_matrix_info.full_storage_url)
additional_gallery_f.show()
pr_curve_path = os.path.join(local_artifacts_dir, "PR_curve.png")
@@ -1769,18 +1825,24 @@ def train_batch_watcher_func():
pose_f1_curve_path = os.path.join(local_artifacts_dir, "PoseF1_curve.png")
if os.path.exists(pose_f1_curve_path):
remote_pose_f1_curve_path = os.path.join(remote_images_path, "PoseF1_curve.png")
tf_pose_f1_curve_info = api.file.upload(team_id, pose_f1_curve_path, remote_pose_f1_curve_path)
tf_pose_f1_curve_info = api.file.upload(
team_id, pose_f1_curve_path, remote_pose_f1_curve_path
)
additional_gallery.append(tf_pose_f1_curve_info.full_storage_url)
mask_f1_curve_path = os.path.join(local_artifacts_dir, "MaskF1_curve.png")
if os.path.exists(mask_f1_curve_path):
remote_mask_f1_curve_path = os.path.join(remote_images_path, "MaskF1_curve.png")
tf_mask_f1_curve_info = api.file.upload(team_id, mask_f1_curve_path, remote_mask_f1_curve_path)
tf_mask_f1_curve_info = api.file.upload(
team_id, mask_f1_curve_path, remote_mask_f1_curve_path
)
additional_gallery.append(tf_mask_f1_curve_info.full_storage_url)

# rename best checkpoint file
results = pd.read_csv(watch_file)
results.columns = [col.replace(" ", "") for col in results.columns]
results["fitness"] = (0.1 * results["metrics/mAP50(B)"]) + (0.9 * results["metrics/mAP50-95(B)"])
results["fitness"] = (0.1 * results["metrics/mAP50(B)"]) + (
0.9 * results["metrics/mAP50-95(B)"]
)
print("Final results:")
print(results)
best_epoch = results["fitness"].idxmax()
@@ -1792,7 +1854,10 @@ def train_batch_watcher_func():
# add geometry config to saved weights for pose estimation task
if task_type == "pose estimation":
for obj_class in project_meta.obj_classes:
if obj_class.geometry_type.geometry_name() == "graph" and obj_class.name in selected_classes:
if (
obj_class.geometry_type.geometry_name() == "graph"
and obj_class.name in selected_classes
):
geometry_config = obj_class.geometry_config
break
weights_filepath = os.path.join(local_artifacts_dir, "weights", best_filename)
@@ -1807,7 +1872,9 @@ def train_batch_watcher_func():
print(app_url, file=text_file)

# upload training artifacts to team files
remote_artifacts_dir = os.path.join("/yolov8_train", task_type, project_info.name, str(g.app_session_id))
remote_artifacts_dir = os.path.join(
"/yolov8_train", task_type, project_info.name, str(g.app_session_id)
)

def upload_monitor(monitor, api: sly.Api, progress: sly.Progress):
value = monitor.bytes_read
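Taken together, both entry points (start_training and the auto_train endpoint) now write run output under a per-task subfolder of the new checkpoint root instead of a repo-local runs/ tree, and the upload step reads artifacts from <checkpoint_dir>/train. Roughly, with illustrative names:

<root_model_checkpoint_dir>/        # sly.app.get_synced_data_dir() in production, <app_root>/runs in debug
└── detect/ | pose/ | segment/      # checkpoint_dir, chosen by task type, passed as project=
    └── train/                      # local_artifacts_dir (Ultralytics' default run name)
        └── weights/, plots, ...    # later uploaded to Team Files under
                                    # /yolov8_train/<task_type>/<project_name>/<app_session_id>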
