Skip to content

Commit

Permalink
fix skip classes with team color
Browse files Browse the repository at this point in the history
  • Loading branch information
phinik committed Jan 15, 2024
1 parent b580a77 commit 0dbd5af
Showing 1 changed file with 18 additions and 26 deletions.
44 changes: 18 additions & 26 deletions yoeo/scripts/createYOEOLabelsFromTORSO-21.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,6 @@
from tqdm import tqdm


# Available classes for YOEO
CLASSES = {
'bb_classes': ['ball', 'goalpost', 'robot'],
'bb_classes_with_robot_colors': ['ball', 'goalpost', 'robot_blue', 'robot_red', 'robot_unknown'],
'segmentation_classes': ['background', 'lines', 'field'],
'skip_classes': ['obstacle', 'L-Intersection', 'X-Intersection', 'T-Intersection'],
}


def range_limited_float_type_0_to_1(arg):
"""Type function for argparse - a float within some predefined bounds
Derived from 'https://stackoverflow.com/questions/55324449/how-to-specify-a-minimum-or-maximum-float-value-with-argparse/55410582#55410582'.
Expand All @@ -41,6 +32,13 @@ def range_limited_float_type_0_to_1(arg):
parser.add_argument("--robots-with-team-colors", action="store_true", help="The robot class will be subdivided into subclasses, one for each team color (currently either 'blue', 'red' or 'unknown').")
args = parser.parse_args()

# Available classes for YOEO
CLASSES = {
'bb_classes': ['ball', 'goalpost', 'robot'] if args.robots_with_team_colors else ['ball', 'goalpost', 'robot_blue', 'robot_red', 'robot_unknown'],
'segmentation_classes': ['background', 'lines', 'field'],
'skip_classes': ['obstacle', 'L-Intersection', 'X-Intersection', 'T-Intersection'],
}

# Remove skipped classes from CLASSES list
for skip_class in args.skip_classes:
if skip_class in CLASSES['bb_classes']:
Expand Down Expand Up @@ -124,13 +122,18 @@ def range_limited_float_type_0_to_1(arg):
annotations = []

for annotation in image_data['annotations']:
class_name = annotation['type']

if args.robots_with_team_colors and class_name == 'robot':
class_name += f"_{annotation['color']}"

# Skip annotations, if is not a bounding box or should be skipped or is blurred or concealed and user chooses to skip them
if (annotation['type'] in CLASSES['segmentation_classes'] or # Handled by segmentations
annotation['type'] in CLASSES['skip_classes'] or # Skip this annotation class
if (class_name in CLASSES['segmentation_classes'] or # Handled by segmentations
class_name in CLASSES['skip_classes'] or # Skip this annotation class
(args.skip_blurred and annotation.get('blurred', False)) or
(args.skip_concealed and annotation.get('concealed', False))):
continue
elif annotation['type'] in CLASSES['bb_classes']: # Handle bounding boxes
elif class_name in CLASSES['bb_classes']: # Handle bounding boxes
if annotation['in_image']: # If annotation is not in image, do nothing
min_x = min(map(lambda x: x[0], annotation['vector']))
max_x = max(map(lambda x: x[0], annotation['vector']))
Expand All @@ -148,21 +151,10 @@ def range_limited_float_type_0_to_1(arg):
relative_center_y = center_y / img_height

# Derive classID from index in predefined classes
if not args.robots_with_team_colors:
classID = CLASSES['bb_classes'].index(annotation['type'])
else:
class_name = annotation['type']

# If the annotation contains a robot, the team color has to be appended to the annotation type
# to get the full class name.
if class_name == 'robot':
class_name += f"_{annotation['color']}"

classID = CLASSES['bb_classes_with_robot_colors'].index(class_name)

classID = CLASSES['bb_classes'].index(class_name)
annotations.append(f"{classID} {relative_center_x} {relative_center_y} {relative_annotation_width} {relative_annotation_height}")
else:
print(f"The annotation type '{annotation['type']}' is not supported. Image: '{img_name_with_extension}'")
print(f"The annotation type '{class_name}' is not supported. Image: '{img_name_with_extension}'")

# Store bounding box annotations in .txt file
with open(os.path.join(labels_dir, img_name_without_extension + ".txt"), "w") as output:
Expand All @@ -184,7 +176,7 @@ def range_limited_float_type_0_to_1(arg):
# The names file contains the class names of bb detections and segmentations
names_path = os.path.join(destination_dir, "yoeo_names.yaml")
names = {
'detection': CLASSES['bb_classes'] if not args.robots_with_team_colors else CLASSES['bb_classes_with_robot_colors'],
'detection': CLASSES['bb_classes'],
'segmentation': CLASSES["segmentation_classes"],
}
with open(names_path, "w") as names_file:
Expand Down

0 comments on commit 0dbd5af

Please sign in to comment.