Skip to content

Commit

Permalink
Update .gitignore and refactor training code to use work_dir; enable …
Browse files Browse the repository at this point in the history
…TensorBoard and update load function
  • Loading branch information
cxnt committed Nov 25, 2024
1 parent 9951b30 commit 03e4fb4
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 5 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -176,5 +176,6 @@ temp/
imgaug.json
supervisely_integration/train/data
supervisely_integration/train/output
supervisely_integration/train/work_dir
supervisely
app_data
7 changes: 4 additions & 3 deletions supervisely_integration/train/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,15 @@

app_options = {
"enable_device_selector": False,
"enable_tensorboard": True,
"debug": {
"download_model": False,
"download_project": False,
},
}

current_file_dir = os.path.dirname(os.path.abspath(__file__))
work_dir = os.path.join(current_file_dir, "output")
work_dir = os.path.join(current_file_dir, "work_dir")
train = TrainApp("rt-detr", models, hyperparameters_path, app_options, work_dir)

inference_settings = {"confidence_threshold": 0.4}
Expand All @@ -59,11 +60,11 @@
# local_file_path = os.path.join(work_dir, "app_config.json")
# api.file.download(file, local_file_path)
# app_config = load_json_file(local_file_path)
# train.gui.load_from_config(app_config)
# train.gui.load_from_state(app_config)


# Debug
# utils.load_from_config(train, hyperparameters_path)
utils.load_from_state(train, hyperparameters_path)


@train.start
Expand Down
6 changes: 4 additions & 2 deletions supervisely_integration/train/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from multiprocessing import cpu_count

from supervisely.nn.training.train_app import TrainApp


def get_num_workers(batch_size: int):
num_workers = min(batch_size, 8, cpu_count())
return num_workers


# Debug
def load_from_config(train, hyperparameters_path: str):
def load_from_state(train: TrainApp, hyperparameters_path: str):
with open(hyperparameters_path, "r") as f:
hyper_params = f.read()

Expand Down Expand Up @@ -46,4 +48,4 @@ def load_from_config(train, hyperparameters_path: str):
},
}

train.gui.load_from_config(app_config)
train.gui.load_from_state(app_config)

0 comments on commit 03e4fb4

Please sign in to comment.