Skip to content

Commit

Permalink
add Negative Data training and testing
Browse files Browse the repository at this point in the history
  • Loading branch information
Simon committed Jan 18, 2024
1 parent bdecd3c commit 5f34554
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 12 deletions.
33 changes: 27 additions & 6 deletions yoeo/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,18 @@
from terminaltables import AsciiTable

import torch
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, ConcatDataset
from torch.autograd import Variable

from yoeo.models import load_model
from yoeo.utils.utils import load_classes, ap_per_class, get_batch_statistics, non_max_suppression, to_cpu, xywh2xyxy, \
print_environment_info, seg_iou
from yoeo.utils.datasets import ListDataset
from yoeo.utils.datasets import ListDataset, NegativeDataset
from yoeo.utils.transforms import DEFAULT_TRANSFORMS
from yoeo.utils.parse_config import parse_data_config


def evaluate_model_file(model_path, weights_path, img_path, class_names, batch_size=8, img_size=416,
def evaluate_model_file(model_path, weights_path, img_path, class_names, negative_img_dir="", negative_data_fraction=0.0, batch_size=8, img_size=416,
n_cpu=8, iou_thres=0.5, conf_thres=0.5, nms_thres=0.5, verbose=True,
robot_class_ids: Optional[List[int]] = None):
"""Evaluate model on validation dataset.
Expand All @@ -34,6 +34,10 @@ def evaluate_model_file(model_path, weights_path, img_path, class_names, batch_s
:type img_path: str
:param class_names: Dict containing detection and segmentation class names
:type class_names: Dict
:param negative_img_dir: Path to negative image folder, defaults to ""
:type negative_img_dir: str
:param negative_data_fraction: Fraction of negative data relative to positive data, defaults to 0.0
:type negative_data_fraction: float
:param batch_size: Size of each image batch, defaults to 8
:type batch_size: int, optional
:param img_size: Size of each image dimension for yolo, defaults to 416
Expand All @@ -53,7 +57,7 @@ def evaluate_model_file(model_path, weights_path, img_path, class_names, batch_s
:return: Returns precision, recall, AP, f1, ap_class
"""
dataloader = _create_validation_data_loader(
img_path, batch_size, img_size, n_cpu)
img_path, negative_img_dir, negative_data_fraction, batch_size, img_size, n_cpu)
model = load_model(model_path, weights_path)
metrics_output, seg_class_ious = _evaluate(
model,
Expand Down Expand Up @@ -180,12 +184,16 @@ def seg_iou_mean_without_nan(seg_iou: List[float]) -> np.ndarray:
return yolo_metrics_output, seg_class_ious


def _create_validation_data_loader(img_path, batch_size, img_size, n_cpu):
def _create_validation_data_loader(img_path, negative_img_dir, negative_data_fraction, batch_size, img_size, n_cpu):
"""
Creates a DataLoader for validation.
:param img_path: Path to file containing all paths to validation images.
:type img_path: str
:param negative_img_dir: Path to negative image folder
:type negative_img_dir: str
:param negative_data_fraction: Fraction of negative data relative to positive data
:type negative_data_fraction: float
:param batch_size: Size of each image batch
:type batch_size: int
:param img_size: Size of each image dimension for yolo
Expand All @@ -196,8 +204,19 @@ def _create_validation_data_loader(img_path, batch_size, img_size, n_cpu):
:rtype: DataLoader
"""
dataset = ListDataset(img_path, img_size=img_size, multiscale=False, transform=DEFAULT_TRANSFORMS)

Check warning on line 207 in yoeo/test.py

View workflow job for this annotation

GitHub Actions / linter

blank line contains whitespace
dataset_len = len(dataset)
negative_dataset_len = int(negative_data_fraction*dataset_len)

Check warning on line 210 in yoeo/test.py

View workflow job for this annotation

GitHub Actions / linter

blank line contains whitespace
negative_dataset = NegativeDataset(
negative_img_dir,
img_size=img_size,
negative_dataset_max_len=negative_dataset_len)

Check warning on line 215 in yoeo/test.py

View workflow job for this annotation

GitHub Actions / linter

blank line contains whitespace
concat_dataset = ConcatDataset([dataset, negative_dataset])

dataloader = DataLoader(
dataset,
concat_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=n_cpu,
Expand All @@ -214,6 +233,8 @@ def run():
parser.add_argument("-w", "--weights", type=str, default="weights/yoeo.pth",
help="Path to weights or checkpoint file (.weights or .pth)")
parser.add_argument("-d", "--data", type=str, default="config/torso.data", help="Path to data config file (.data)")
parser.add_argument("-n", "--negative_data_dir", default='', type=str, help="Path to negative data directory")
parser.add_argument("--negative_data_fraction", default=0, type=float, help="Fraction of negative data relative to positive data (default=0.0)")
parser.add_argument("-b", "--batch_size", type=int, default=8, help="Size of each image batch")
parser.add_argument("-v", "--verbose", action='store_true', help="Makes the validation more verbose")
parser.add_argument("--img_size", type=int, default=416, help="Size of each image dimension for yolo")
Expand Down
31 changes: 26 additions & 5 deletions yoeo/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import numpy as np

import torch
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, ConcatDataset
import torch.optim as optim
from torch.autograd import Variable

Expand All @@ -18,7 +18,7 @@
from yoeo.models import load_model
from yoeo.utils.logger import Logger
from yoeo.utils.utils import to_cpu, load_classes, print_environment_info, provide_determinism, worker_seed_set
from yoeo.utils.datasets import ListDataset
from yoeo.utils.datasets import ListDataset, NegativeDataset
from yoeo.utils.augmentations import AUGMENTATION_TRANSFORMS
from yoeo.utils.transforms import DEFAULT_TRANSFORMS

Check failure on line 23 in yoeo/train.py

View workflow job for this annotation

GitHub Actions / linter

'yoeo.utils.transforms.DEFAULT_TRANSFORMS' imported but unused
from yoeo.utils.parse_config import parse_data_config
Expand All @@ -30,11 +30,15 @@
from torchsummary import summary


def _create_data_loader(img_path, batch_size, img_size, n_cpu, multiscale_training=False):
def _create_data_loader(img_path, negative_img_dir, negative_data_fraction, batch_size, img_size, n_cpu, multiscale_training=False):
"""Creates a DataLoader for training.
:param img_path: Path to file containing all paths to training images.
:type img_path: str
:param negative_img_dir: Path to negative image folder
:type negative_img_dir: str
:param negative_data_fraction: Fraction of negative data relative to positive data
:type negative_data_fraction: float
:param batch_size: Size of each image batch
:type batch_size: int
:param img_size: Size of each image dimension for yolo
Expand All @@ -51,8 +55,20 @@ def _create_data_loader(img_path, batch_size, img_size, n_cpu, multiscale_traini
img_size=img_size,
multiscale=multiscale_training,
transform=AUGMENTATION_TRANSFORMS)

Check warning on line 58 in yoeo/train.py

View workflow job for this annotation

GitHub Actions / linter

blank line contains whitespace
dataset_len = len(dataset)
negative_dataset_len = int(negative_data_fraction*dataset_len)

Check warning on line 61 in yoeo/train.py

View workflow job for this annotation

GitHub Actions / linter

blank line contains whitespace
negative_dataset = NegativeDataset(
negative_img_dir,
img_size=img_size,
transform=AUGMENTATION_TRANSFORMS,
negative_dataset_max_len=negative_dataset_len)

Check warning on line 67 in yoeo/train.py

View workflow job for this annotation

GitHub Actions / linter

blank line contains whitespace
concat_dataset = ConcatDataset([dataset, negative_dataset])

dataloader = DataLoader(
dataset,
concat_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=n_cpu,
Expand All @@ -61,12 +77,13 @@ def _create_data_loader(img_path, batch_size, img_size, n_cpu, multiscale_traini
worker_init_fn=worker_seed_set)
return dataloader


def run():

Check failure on line 80 in yoeo/train.py

View workflow job for this annotation

GitHub Actions / linter

expected 2 blank lines, found 1
print_environment_info()
parser = argparse.ArgumentParser(description="Trains the YOLO model.")
parser.add_argument("-m", "--model", type=str, default="config/yoeo.cfg", help="Path to model definition file (.cfg)")
parser.add_argument("-d", "--data", type=str, default="config/torso.data", help="Path to data config file (.data)")
parser.add_argument("-n", "--negative_data_dir", default='', type=str, help="Path to negative data directory")
parser.add_argument("--negative_data_fraction", default=0, type=float, help="Fraction of negative data relative to positive data (default=0.0)")
parser.add_argument("-e", "--epochs", type=int, default=300, help="Number of epochs")
parser.add_argument("-v", "--verbose", action='store_true', help="Makes the training more verbose")
parser.add_argument("--n_cpu", type=int, default=8, help="Number of cpu threads to use during batch generation")
Expand Down Expand Up @@ -128,6 +145,8 @@ def run():
# Load training dataloader
dataloader = _create_data_loader(
train_path,
args.negative_data_dir,
args.negative_data_fraction,
mini_batch_size,
model.hyperparams['height'],
args.n_cpu,
Expand All @@ -136,6 +155,8 @@ def run():
# Load validation dataloader
validation_dataloader = _create_validation_data_loader(
valid_path,
args.negative_data_dir,
args.negative_data_fraction,
mini_batch_size,
model.hyperparams['height'],
args.n_cpu)
Expand Down
39 changes: 38 additions & 1 deletion yoeo/utils/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,44 @@ def __getitem__(self, index):

def __len__(self):
return len(self.files)


class NegativeDataset(Dataset):
def __init__(self, folder_path, img_size=416, transform=None,negative_dataset_max_len=0):
self.img_size = img_size
self.transform = transform
self.negative_dataset_max_len = negative_dataset_max_len
if folder_path:
self.files = sorted(glob.glob("%s/*.*" % folder_path))[:self.negative_dataset_max_len]
else:
self.files = []

def __getitem__(self, index):
img_path = self.files[index % len(self.files)]
img = np.array(
Image.open(img_path).convert('RGB'),
dtype=np.uint8)

# Label Placeholder
bb_targets = np.zeros((1, 5))
mask_targets = np.zeros_like(img)

# -----------
# Transform
# -----------
if self.transform:
try:
img, bb_targets, mask_targets = self.transform(
(img, bb_targets, mask_targets)
)
except Exception as e:
print(f"Could not apply transform.")
raise e

return img_path, img, bb_targets, mask_targets

def __len__(self):
return len(self.files)


class ListDataset(Dataset):
Expand Down Expand Up @@ -138,7 +176,6 @@ def __getitem__(self, index):
except Exception as e:
print(f"Could not apply transform.")
raise e
return

return img_path, img, bb_targets, mask_targets

Expand Down

0 comments on commit 5f34554

Please sign in to comment.