From 3b815c56daec0cd8a356f2e8edf344434111e90c Mon Sep 17 00:00:00 2001 From: max-unfinity Date: Mon, 25 Nov 2024 19:59:52 +0000 Subject: [PATCH] added custom models training --- .gitignore | 1 + .vscode/launch.json | 2 +- supervisely_integration/train/main.py | 38 ++++++++++++++++++--------- 3 files changed, 28 insertions(+), 13 deletions(-) diff --git a/.gitignore b/.gitignore index 5b984bc1..ad6b9511 100644 --- a/.gitignore +++ b/.gitignore @@ -178,6 +178,7 @@ rtdetr_pytorch/dataset/ rtdetrv2_pytorch/output/ rtdetrv2_pytorch/dataset/ rtdetrv2_pytorch/configs/rtdetrv2/custom_config.yml +rtdetrv2_pytorch/configs/rtdetrv2/model_config.yml supervisely_integration/train/output supervisely_integration/train/data supervisely diff --git a/.vscode/launch.json b/.vscode/launch.json index f57c80ab..f72573cd 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -15,7 +15,7 @@ "--ws", "websockets", "--app-dir", - "supervisely_integration/train" + "supervisely_integration/train", ], "justMyCode": false, "env": { diff --git a/supervisely_integration/train/main.py b/supervisely_integration/train/main.py index 71814c1f..b8022c16 100644 --- a/supervisely_integration/train/main.py +++ b/supervisely_integration/train/main.py @@ -1,7 +1,10 @@ +import os +import shutil import sys sys.path.insert(0, "rtdetrv2_pytorch") import yaml import supervisely as sly +from supervisely.nn.utils import ModelSource from supervisely.nn.training.train_app import TrainApp # from supervisely_integration.train.serve import RTDETRModelMB from supervisely_integration.train.sly2coco import get_coco_annotations @@ -16,7 +19,7 @@ "RT-DETRv2", f"{base_path}/models_v2.json", f"{base_path}/hyperparameters.yaml", - f"{base_path}/app_options", + f"{base_path}/app_options.yaml", ) # train.register_inference_class(RTDETRModelMB) @@ -24,29 +27,28 @@ @train.start def start_training(): - checkpoint = train.model_files["checkpoint"] train_ann_path, val_ann_path = convert_data() - custom_config = prepare_config() - custom_config_path = "rtdetrv2_pytorch/configs/rtdetrv2/custom_config.yml" - with open(custom_config_path, 'w') as f: - yaml.dump(custom_config, f) + checkpoint = train.model_files["checkpoint"] + custom_config_path = prepare_config() cfg = YAMLConfig( custom_config_path, tuning=checkpoint, ) output_dir = cfg.output_dir - tensorboard_logs = f"{output_dir}/summary" + os.makedirs(output_dir, exist_ok=True) + # dump resolved config model_config_path = f"{output_dir}/model_config.yml" with open(model_config_path, 'w') as f: yaml.dump(cfg.yaml_cfg, f) # train + tensorboard_logs = f"{output_dir}/summary" train.start_tensorboard(tensorboard_logs) solver = DetSolver(cfg) solver.fit() - model_name = train.model_name - # Gather experiment info + # gather experiment info experiment_info = { - "model_name": model_name, + "task_type": sly.nn.TaskType.OBJECT_DETECTION, + "model_name": train.model_name, "model_files": {"config": model_config_path}, "checkpoints": output_dir, "best_checkpoint": "best.pth", @@ -71,8 +73,16 @@ def convert_data(): def prepare_config(): + rtdetrv2_config_dir = "rtdetrv2_pytorch/configs/rtdetrv2" + if train.model_source == ModelSource.CUSTOM: + config_path = train.model_files["config"] + config = os.path.basename(config_path) + shutil.move(config_path, f"{rtdetrv2_config_dir}/{config}") + else: + config = train.model_files["config"] + custom_config = train.hyperparameters - custom_config["__include__"] = [train.model_files["config"]] + custom_config["__include__"] = [config] custom_config["remap_mscoco_category"] = False custom_config["num_classes"] = train.num_classes custom_config["print_freq"] = 50 @@ -91,4 +101,8 @@ def prepare_config(): custom_config["train_dataloader"]["num_workers"] = get_num_workers(custom_config["batch_size"]) custom_config["val_dataloader"]["num_workers"] = get_num_workers(custom_config["batch_size"]) - return custom_config + custom_config_path = f"{rtdetrv2_config_dir}/custom_config.yml" + with open(custom_config_path, 'w') as f: + yaml.dump(custom_config, f) + + return custom_config_path