Skip to content

Commit

Permalink
Merge pull request #3 from WildMeOrg/pipeline-v3
Browse files Browse the repository at this point in the history
Pipeline v3
  • Loading branch information
TanyaStere42 authored Feb 6, 2024
2 parents da6abe7 + cebf90e commit 5d71b00
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 75 deletions.
43 changes: 0 additions & 43 deletions dasd

This file was deleted.

81 changes: 66 additions & 15 deletions scoutbot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def fetch(pull=False, config=None):
Raises:
AssertionError: If any model cannot be fetched.
"""
if config == 'v3':
if config in ['v3', 'v3-cls']:
loc.fetch(pull=pull, config=config)
else:
wic.fetch(pull=pull, config=None)
Expand Down Expand Up @@ -190,7 +190,17 @@ def pipeline(


def pipeline_v3(
filepath
filepath,
config,
batched_detection_model=None,
loc_thresh=0.45,
slice_height=512,
slice_width=512,
overlap_height_ratio=0.25,
overlap_width_ratio=0.25,
perform_standard_pred=False,
postprocess_class_agnostic=True

):
"""
Run the ML pipeline on a given image filepath and return the detections
Expand All @@ -211,31 +221,33 @@ def pipeline_v3(
"""

# Run Localizer
yolov8_model_path = loc.fetch(config='v3')

batched_detection_model = tile_batched.Yolov8DetectionModel(
model_path=yolov8_model_path,
confidence_threshold=0.45,
device='cuda:0'
)
if batched_detection_model is None:
yolov8_model_path = loc.fetch(config=config)

batched_detection_model = tile_batched.Yolov8DetectionModel(
model_path=yolov8_model_path,
confidence_threshold=loc_thresh,
device='cuda:0'
)

det_result = tile_batched.get_sliced_prediction_batched(
cv2.imread(filepath),
batched_detection_model,
slice_height=512,
slice_width=512,
overlap_height_ratio=0.25,
overlap_width_ratio=0.25,
perform_standard_pred=False,
postprocess_class_agnostic=True
slice_height=slice_height,
slice_width=slice_width,
overlap_height_ratio=overlap_height_ratio,
overlap_width_ratio=overlap_width_ratio,
perform_standard_pred=perform_standard_pred,
postprocess_class_agnostic=postprocess_class_agnostic
)

# Postprocess detections for WIC
coco_prediction_list = []
for object_prediction in det_result.object_prediction_list:
coco_prediction_list.append(object_prediction.to_coco_prediction(image_id=None).json)

wic_score = max([item['score'] for item in coco_prediction_list])
wic_score = max([item['score'] for item in coco_prediction_list], default=0)

# Convert to output formats

Expand Down Expand Up @@ -401,6 +413,45 @@ def batch(
return wic_list, detects_list


def batch_v3(
filepaths,
config,
loc_thresh=0.45,
slice_height=512,
slice_width=512,
overlap_height_ratio=0.25,
overlap_width_ratio=0.25,
perform_standard_pred=False,
postprocess_class_agnostic=True
):
yolov8_model_path = loc.fetch(config=config)

batched_detection_model = tile_batched.Yolov8DetectionModel(
model_path=yolov8_model_path,
confidence_threshold=loc_thresh,
device='cuda:0'
)

wic_list = []
detects_list = []
for filepath in filepaths:
wic_, detects = pipeline_v3(filepath,
config,
batched_detection_model=batched_detection_model,
loc_thresh=loc_thresh,
slice_height=slice_height,
slice_width=slice_width,
overlap_height_ratio=overlap_height_ratio,
overlap_width_ratio=overlap_width_ratio,
perform_standard_pred=perform_standard_pred,
postprocess_class_agnostic=postprocess_class_agnostic
)
wic_list.append(wic_)
detects_list.append(detects)

return wic_list, detects_list


def example():
"""
Run the pipeline on an example image with the default configuration
Expand Down
21 changes: 20 additions & 1 deletion scoutbot/loc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,28 @@
],
},
'v3': {
'hash': None, # 9e001aa3c10d05ba8a269103d3d358ceeb7d6f3bcc5758c1be4405ff743e0e90 #'46cbbccf922552703a1fe8a756544e43'
'hash': None,
'name': 'yolov8.kaza.pt',
'path': join(PWD, 'models', 'yolo', 'yolov8.kaza.pt'),
'thresh': 0.45,
'slice_height': 512,
'slice_width': 512,
'overlap_height_ratio': 0.25,
'overlap_width_ratio': 0.25,
'perform_standard_pred': False,
'postprocess_class_agnostic': True
},
'v3-cls': {
'hash': None,
'name': 'yolov8-cls.kaza.pt',
'path': join(PWD, 'models', 'yolo', 'yolov8-cls.kaza.pt'),
'thresh': 0.45,
'slice_height': 512,
'slice_width': 512,
'overlap_height_ratio': 0.25,
'overlap_width_ratio': 0.25,
'perform_standard_pred': False,
'postprocess_class_agnostic': True
}
}
CONFIGS[None] = CONFIGS[DEFAULT_CONFIG]
Expand Down
51 changes: 37 additions & 14 deletions scoutbot/scoutbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def pipeline_filepath_validator(ctx, param, value):
'--config',
help='Which ML models to use for inference',
default=None,
type=click.Choice(['phase1', 'mvp', 'old', 'new', 'v3']),
type=click.Choice(['phase1', 'mvp', 'old', 'new', 'v3', 'v3-cls']),
)
def fetch(config):
"""
Expand All @@ -45,7 +45,7 @@ def fetch(config):
'--config',
help='Which ML models to use for inference',
default=None,
type=click.Choice(['phase1', 'mvp', 'old', 'new', 'v3']),
type=click.Choice(['phase1', 'mvp', 'old', 'new', 'v3', 'v3-cls']),
)
@click.option(
'--output',
Expand Down Expand Up @@ -124,9 +124,18 @@ def pipeline(
agg_thresh /= 100.0
agg_nms_thresh /= 100.0

if config == 'v3':
if config in ['v3', 'v3-cls']:
wic_, detects = scoutbot.pipeline_v3(
filepath
filepath,
config,
loc_thresh=loc.CONFIGS[config]['thresh'],
slice_height=loc.CONFIGS[config]['slice_height'],
slice_width=loc.CONFIGS[config]['slice_width'],
overlap_height_ratio=loc.CONFIGS[config]['overlap_height_ratio'],
overlap_width_ratio=loc.CONFIGS[config]['overlap_width_ratio'],
perform_standard_pred=loc.CONFIGS[config]['perform_standard_pred'],
postprocess_class_agnostic=loc.CONFIGS[config]['postprocess_class_agnostic']

)
else:
wic_, detects = scoutbot.pipeline(
Expand Down Expand Up @@ -164,7 +173,7 @@ def pipeline(
'--config',
help='Which ML models to use for inference',
default=None,
type=click.Choice(['phase1', 'mvp', 'old', 'new']),
type=click.Choice(['phase1', 'mvp', 'old', 'new', 'v3', 'v3-cls']),
)
@click.option(
'--output',
Expand Down Expand Up @@ -260,15 +269,29 @@ def batch(

log.debug(f'Running batch on {len(filepaths)} files...')

wic_list, detects_list = scoutbot.batch(
filepaths,
config=config,
wic_thresh=wic_thresh,
loc_thresh=loc_thresh,
loc_nms_thresh=loc_nms_thresh,
agg_thresh=agg_thresh,
agg_nms_thresh=agg_nms_thresh,
)
if config in ['v3', 'v3-cls']:
wic_list, detects_list = scoutbot.batch_v3(
filepaths,
config,
loc_thresh=loc.CONFIGS[config]['thresh'],
slice_height=loc.CONFIGS[config]['slice_height'],
slice_width=loc.CONFIGS[config]['slice_width'],
overlap_height_ratio=loc.CONFIGS[config]['overlap_height_ratio'],
overlap_width_ratio=loc.CONFIGS[config]['overlap_width_ratio'],
perform_standard_pred=loc.CONFIGS[config]['perform_standard_pred'],
postprocess_class_agnostic=loc.CONFIGS[config]['postprocess_class_agnostic']

)
else:
wic_list, detects_list = scoutbot.batch(
filepaths,
config=config,
wic_thresh=wic_thresh,
loc_thresh=loc_thresh,
loc_nms_thresh=loc_nms_thresh,
agg_thresh=agg_thresh,
agg_nms_thresh=agg_nms_thresh,
)
results = zip(filepaths, wic_list, detects_list)

data = {}
Expand Down
4 changes: 2 additions & 2 deletions scoutbot/tile_batched/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def perform_inference(self, images: List[np.ndarray], batch_size=128):
batch_images = images[i:i + batch_size]
preds = self.model.predict(source=batch_images, verbose=False, device=self.device)

all_preds.extend(preds)
all_preds.extend(preds)

prediction_result = [
result.boxes.data[result.boxes.data[:, 4] >= self.confidence_threshold] for result in all_preds
Expand Down Expand Up @@ -247,7 +247,7 @@ def __len__(self):


def slice_image(
image: Union[str],
image: Union[str, np.ndarray],
slice_height: Optional[int] = None,
slice_width: Optional[int] = None,
overlap_height_ratio: float = 0.2,
Expand Down

0 comments on commit 5d71b00

Please sign in to comment.