Skip to content

Commit

Permalink
added custom models training
Browse files Browse the repository at this point in the history
  • Loading branch information
max-unfinity committed Nov 25, 2024
1 parent b226774 commit 3b815c5
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 13 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ rtdetr_pytorch/dataset/
rtdetrv2_pytorch/output/
rtdetrv2_pytorch/dataset/
rtdetrv2_pytorch/configs/rtdetrv2/custom_config.yml
rtdetrv2_pytorch/configs/rtdetrv2/model_config.yml
supervisely_integration/train/output
supervisely_integration/train/data
supervisely
Expand Down
2 changes: 1 addition & 1 deletion .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"--ws",
"websockets",
"--app-dir",
"supervisely_integration/train"
"supervisely_integration/train",
],
"justMyCode": false,
"env": {
Expand Down
38 changes: 26 additions & 12 deletions supervisely_integration/train/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import os
import shutil
import sys
sys.path.insert(0, "rtdetrv2_pytorch")
import yaml
import supervisely as sly
from supervisely.nn.utils import ModelSource
from supervisely.nn.training.train_app import TrainApp
# from supervisely_integration.train.serve import RTDETRModelMB
from supervisely_integration.train.sly2coco import get_coco_annotations
Expand All @@ -16,37 +19,36 @@
"RT-DETRv2",
f"{base_path}/models_v2.json",
f"{base_path}/hyperparameters.yaml",
f"{base_path}/app_options",
f"{base_path}/app_options.yaml",
)

# train.register_inference_class(RTDETRModelMB)


@train.start
def start_training():
checkpoint = train.model_files["checkpoint"]
train_ann_path, val_ann_path = convert_data()
custom_config = prepare_config()
custom_config_path = "rtdetrv2_pytorch/configs/rtdetrv2/custom_config.yml"
with open(custom_config_path, 'w') as f:
yaml.dump(custom_config, f)
checkpoint = train.model_files["checkpoint"]
custom_config_path = prepare_config()
cfg = YAMLConfig(
custom_config_path,
tuning=checkpoint,
)
output_dir = cfg.output_dir
tensorboard_logs = f"{output_dir}/summary"
os.makedirs(output_dir, exist_ok=True)
# dump resolved config
model_config_path = f"{output_dir}/model_config.yml"
with open(model_config_path, 'w') as f:
yaml.dump(cfg.yaml_cfg, f)
# train
tensorboard_logs = f"{output_dir}/summary"
train.start_tensorboard(tensorboard_logs)
solver = DetSolver(cfg)
solver.fit()
model_name = train.model_name
# Gather experiment info
# gather experiment info
experiment_info = {
"model_name": model_name,
"task_type": sly.nn.TaskType.OBJECT_DETECTION,
"model_name": train.model_name,
"model_files": {"config": model_config_path},
"checkpoints": output_dir,
"best_checkpoint": "best.pth",
Expand All @@ -71,8 +73,16 @@ def convert_data():


def prepare_config():
rtdetrv2_config_dir = "rtdetrv2_pytorch/configs/rtdetrv2"
if train.model_source == ModelSource.CUSTOM:
config_path = train.model_files["config"]
config = os.path.basename(config_path)
shutil.move(config_path, f"{rtdetrv2_config_dir}/{config}")
else:
config = train.model_files["config"]

custom_config = train.hyperparameters
custom_config["__include__"] = [train.model_files["config"]]
custom_config["__include__"] = [config]
custom_config["remap_mscoco_category"] = False
custom_config["num_classes"] = train.num_classes
custom_config["print_freq"] = 50
Expand All @@ -91,4 +101,8 @@ def prepare_config():
custom_config["train_dataloader"]["num_workers"] = get_num_workers(custom_config["batch_size"])
custom_config["val_dataloader"]["num_workers"] = get_num_workers(custom_config["batch_size"])

return custom_config
custom_config_path = f"{rtdetrv2_config_dir}/custom_config.yml"
with open(custom_config_path, 'w') as f:
yaml.dump(custom_config, f)

return custom_config_path

0 comments on commit 3b815c5

Please sign in to comment.