Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
aimspot committed Feb 14, 2024
1 parent 9c5b811 commit 9fc4713
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 14 deletions.
21 changes: 10 additions & 11 deletions ODRS/train_utils/train_model/scripts/yolov5_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
# opt.name = 'exp'
# train.main(opt)

# ----------------------------------------------OLD version

def train_V5(IMG_SIZE, BATCH_SIZE, EPOCHS, CONFIG_PATH, MODEL_PATH, GPU_COUNT, SELECT_GPU):
"""
Runs yolov5 training using the parameters specified in the config.
:param IMG_SIZE: Size of input images as integer or w,h.
:param BATCH_SIZE: Batch size for training.
:param EPOCHS: Number of epochs to train for.
Expand All @@ -33,15 +33,14 @@ def train_V5(IMG_SIZE, BATCH_SIZE, EPOCHS, CONFIG_PATH, MODEL_PATH, GPU_COUNT, S
train_script_path = str(Path(file.parents[1]) / 'models' / 'yolov5' / 'train.py')

full_command = (
command +
f" {train_script_path}" +
f" --img {IMG_SIZE}" +
f" --batch {BATCH_SIZE}" +
f" --epochs {EPOCHS}" +
f" --data {CONFIG_PATH}" +
f" --cfg {MODEL_PATH}" +
f" --device {SELECT_GPU}" +
f" --project {CONFIG_PATH.parent}" +
f"{command} {train_script_path}"
f" --img {IMG_SIZE}"
f" --batch {BATCH_SIZE}"
f" --epochs {EPOCHS}"
f" --data {CONFIG_PATH}"
f" --cfg {MODEL_PATH}"
f" --device {SELECT_GPU}"
f" --project {CONFIG_PATH.parent}"
f" --name exp"
)
os.system(full_command)
7 changes: 5 additions & 2 deletions ODRS/utils/dataset_info.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
import cv2
import os
import numpy as np
import sys
from loguru import logger
from pathlib import Path
import cv2
import numpy as np

project_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.dirname(os.path.dirname(project_dir)))

from ODRS.utils.ml_plot import plot_class_balance
from ODRS.utils.ml_utils import dumpCSV



def load_class_names(classes_file):
""" Загрузка названий классов из файла. """
with open(classes_file, 'r') as file:
Expand Down
3 changes: 2 additions & 1 deletion ODRS/utils/ml_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import random
from pathlib import Path


def plot_class_balance(labels, output_path):
""" Построение и сохранение графика баланса классов с наклоненными метками и вывод среднего значения. """
class_counts = Counter(labels)
Expand All @@ -15,4 +16,4 @@ def plot_class_balance(labels, output_path):
plt.title('Class balance')
plt.xticks(rotation=90)
plt.tight_layout()
plt.savefig(output_file)
plt.savefig(output_file)

0 comments on commit 9fc4713

Please sign in to comment.