Skip to content

Commit

Permalink
Merge pull request #22 from supervisely-ecosystem/inf-video-id-method
Browse files Browse the repository at this point in the history
Inference Video ID method added
  • Loading branch information
qanelph authored Jun 9, 2022
2 parents 26abcc5 + db57df8 commit 1284449
Show file tree
Hide file tree
Showing 3 changed files with 188 additions and 8 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ffmpeg-python==0.2.0
133 changes: 133 additions & 0 deletions serve/src/sly_apply_nn_to_video.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import os
import shutil
import time

import cv2
import ffmpeg

import supervisely as sly
from tqdm import tqdm


class InferenceVideoInterface:
    """Fetches frames of a remote Supervisely video for local inference.

    Frames are stored as PNG files under *imgs_dir*; their local paths are
    collected in ``self.images_paths``.  Small requests are served frame by
    frame through the API, large ones by downloading the whole video file
    and splitting it locally (see :meth:`download_frames`).
    """

    def __init__(self, api, start_frame_index, frames_count, frames_direction, video_info, imgs_dir):
        self.api = api

        self.video_info = video_info
        self.images_paths = []  # filled by download_frames()

        self.start_frame_index = start_frame_index
        self.frames_count = frames_count
        self.frames_direction = frames_direction  # 'forward' or 'backward'

        # frames_to_timecodes[1] is the timestamp of frame 1, i.e. the frame period.
        self._video_fps = round(1 / self.video_info.frames_to_timecodes[1])

        self._geometries = []
        self._frames_indexes = []

        self._add_frames_indexes()

        # Unique per-run directory so concurrent runs on the same video don't collide.
        self._frames_path = os.path.join(imgs_dir, f'video_inference_{video_info.id}_{time.time_ns()}', 'frames')
        self._imgs_dir = imgs_dir

        self._local_video_path = None

        os.makedirs(self._frames_path, exist_ok=True)

    def _add_frames_indexes(self):
        """Fill ``self._frames_indexes`` with the frame numbers to process.

        Walks forward or backward from ``start_frame_index`` for at most
        ``frames_count`` frames, clamped to the video bounds.
        """
        total_frames = self.video_info.frames_count
        cur_index = self.start_frame_index

        if self.frames_direction == 'forward':
            end_point = cur_index + self.frames_count if cur_index + self.frames_count < total_frames else total_frames
            self._frames_indexes = list(range(cur_index, end_point, 1))
        else:
            # range() end is exclusive, so clamping to -1 lets frame 0 be included.
            end_point = cur_index - self.frames_count if cur_index - self.frames_count > -1 else -1
            self._frames_indexes = list(range(cur_index, end_point, -1))
            # BUG FIX: the computed indexes were unconditionally reset to []
            # here, which made backward inference process zero frames.

    def _download_video_by_frames(self):
        """Download each requested frame individually through the API."""
        for index, frame_index in tqdm(enumerate(self._frames_indexes), desc='Downloading frames',
                                       total=len(self._frames_indexes)):
            frame_path = os.path.join(f"{self._frames_path}", f"frame{index:06d}.png")
            self.images_paths.append(frame_path)

            if os.path.isfile(frame_path):
                continue  # already on disk (e.g. from an interrupted previous attempt)

            img_rgb = self.api.video.frame.download_np(self.video_info.id, frame_index)
            # NOTE(review): download_np presumably returns RGB while cv2.imwrite
            # expects BGR — confirm channel order is handled upstream.
            cv2.imwrite(frame_path, img_rgb)  # save frame as PNG file

    def _download_entire_video(self):
        """Download the whole video file and split it into PNG frames.

        Returns a dict with the frames directory, the source FPS and the
        local path of the downloaded video file.

        NOTE(review): this path selects frames by absolute position
        ``[start, start + count - 1]`` and ignores ``frames_direction``
        ('backward' is honored only by the per-frame path) — confirm intended.
        """
        def videos_to_frames(video_path, frames_range=None):
            def check_rotation(path_video_file):
                """Return the cv2 rotate code implied by the stream's 'rotate' tag, if any."""
                # ffmpeg.probe returns the video's metadata as a dictionary;
                # the rotation lives under streams[0].tags.rotate.
                meta_dict = ffmpeg.probe(path_video_file)

                rotate_code = None
                try:
                    rotation = int(meta_dict['streams'][0]['tags']['rotate'])
                    if rotation == 90:
                        rotate_code = cv2.ROTATE_90_CLOCKWISE
                    elif rotation == 180:
                        rotate_code = cv2.ROTATE_180
                    elif rotation == 270:
                        rotate_code = cv2.ROTATE_90_COUNTERCLOCKWISE
                except Exception:
                    pass  # no rotation metadata — leave frames as-is

                return rotate_code

            def correct_rotation(frame, rotate_code):
                return cv2.rotate(frame, rotate_code)

            vidcap = cv2.VideoCapture(video_path)
            success, image = vidcap.read()
            count = 0
            rotate_code = check_rotation(video_path)

            while success:
                output_image_path = os.path.join(f"{self._frames_path}", f"frame{count:06d}.png")
                # Keep the frame only when it falls inside the requested range
                # (no range means every frame is kept).
                if not frames_range or frames_range[0] <= count <= frames_range[1]:
                    if rotate_code is not None:
                        image = correct_rotation(image, rotate_code)
                    cv2.imwrite(output_image_path, image)  # save frame as PNG file
                    self.images_paths.append(output_image_path)
                success, image = vidcap.read()
                count += 1

            fps = vidcap.get(cv2.CAP_PROP_FPS)
            vidcap.release()  # release the underlying video handle

            return {'frames_path': self._frames_path, 'fps': fps, 'video_path': video_path}

        self._local_video_path = os.path.join(self._imgs_dir, f'{time.time_ns()}_{self.video_info.name}')
        self.api.video.download_path(self.video_info.id, self._local_video_path)
        return videos_to_frames(self._local_video_path,
                                [self.start_frame_index, self.start_frame_index + self.frames_count - 1])

    def download_frames(self):
        """Materialize the requested frames on disk (fills ``images_paths``).

        Downloading the whole file is assumed cheaper than per-frame API
        calls once more than ~30% of the video is needed.
        """
        if self.frames_count > (self.video_info.frames_count * 0.3):
            sly.logger.debug('Downloading entire video')
            self._download_entire_video()
        else:
            sly.logger.debug('Downloading video frame by frame')
            self._download_video_by_frames()

    def __del__(self):
        # Best-effort cleanup; getattr guards against __init__ having failed
        # before these attributes were set.
        frames_path = getattr(self, '_frames_path', None)
        if frames_path and os.path.isdir(frames_path):
            # Remove the whole video_inference_* directory, not just 'frames'.
            shutil.rmtree(os.path.dirname(frames_path), ignore_errors=True)

        local_video_path = getattr(self, '_local_video_path', None)
        if local_video_path is not None and os.path.isfile(local_video_path):
            os.remove(local_video_path)
62 changes: 54 additions & 8 deletions serve/src/sly_serve.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import traceback

import supervisely as sly
import functools
import sly_globals as g
import utils
import os
import sly_apply_nn_to_video as nn_to_video


def send_error_data(func):
Expand All @@ -12,13 +15,32 @@ def wrapper(*args, **kwargs):
try:
value = func(*args, **kwargs)
except Exception as e:
sly.logger.error(f"Error while processing data: {e}")
request_id = kwargs["context"]["request_id"]
g.my_app.send_response(request_id, data={"error": repr(e)})
# raise e
try:
g.my_app.send_response(request_id, data={"error": repr(e)})
print(traceback.format_exc())
except Exception as ex:
sly.logger.exception(f"Cannot send error response: {ex}")
return value

return wrapper


def inference_images_dir(img_paths, context, state, app_logger):
    """Run the model on every image file in *img_paths*.

    Each file is deleted from disk right after it has been processed.
    Returns the list of annotation JSONs in input order.
    """
    results = []
    for path in img_paths:
        results.append(utils.inference_image_path(image_path=path,
                                                  project_meta=g.meta,
                                                  context=context,
                                                  state=state,
                                                  app_logger=app_logger))
        sly.fs.silent_remove(path)
    return results


@g.my_app.callback("get_output_classes_and_tags")
@sly.timeit
def get_output_classes_and_tags(api: sly.Api, task_id, context, state, app_logger):
Expand All @@ -39,9 +61,11 @@ def get_custom_inference_settings(api: sly.Api, task_id, context, state, app_log
def get_session_info(api: sly.Api, task_id, context, state, app_logger):
    """Send basic metadata about this serving session back to the caller."""
    session_info = {
        "app": "MM Segmentation Serve",
        "type": "Semantic Segmentation",
        "device": g.device,
        "session_id": task_id,
        "classes_count": len(g.meta.obj_classes),
        "videos_support": True,
    }
    g.my_app.send_response(context["request_id"], data=session_info)
Expand Down Expand Up @@ -92,17 +116,39 @@ def inference_batch_ids(api: sly.Api, task_id, context, state, app_logger):
paths.append(os.path.join(g.my_app.data_dir, sly.rand_str(10) + info.name))
api.image.download_paths(infos[0].dataset_id, ids, paths)

results = []
for image_path in paths:
ann_json = utils.inference_image_path(image_path=image_path, project_meta=g.meta,
context=context, state=state, app_logger=app_logger)
results.append(ann_json)
sly.fs.silent_remove(image_path)
results = inference_images_dir(paths, context, state, app_logger)

request_id = context["request_id"]
g.my_app.send_response(request_id, data=results)


@g.my_app.callback("inference_video_id")
@sly.timeit
@send_error_data
def inference_video_id(api: sly.Api, task_id, context, state, app_logger):
    """Run inference on a frame range of a remote video.

    Start frame, frame count and direction come from *state*; the resulting
    annotations are returned under the 'ann' key of the response.
    """
    video_info = g.api.video.get_info_by_id(state['videoId'])

    sly.logger.info(f'inference {video_info.id=} started')
    video_interface = nn_to_video.InferenceVideoInterface(
        api=g.api,
        start_frame_index=state.get('startFrameIndex', 0),
        frames_count=state.get('framesCount', video_info.frames_count - 1),
        frames_direction=state.get('framesDirection', 'forward'),
        video_info=video_info,
        imgs_dir=os.path.join(g.my_app.data_dir, 'videoInference'),
    )

    video_interface.download_frames()

    annotations = inference_images_dir(img_paths=video_interface.images_paths,
                                       context=context,
                                       state=state,
                                       app_logger=app_logger)

    g.my_app.send_response(context["request_id"], data={'ann': annotations})
    sly.logger.info(f'inference {video_info.id=} done, {len(annotations)} annotations created')


@g.my_app.callback("run")
@g.my_app.ignore_errors_and_show_dialog_window()
def init_model(api: sly.Api, task_id, context, state, app_logger):
Expand All @@ -115,7 +161,7 @@ def init_model(api: sly.Api, task_id, context, state, app_logger):
{"field": "state.deployed", "payload": True},
]
g.api.app.set_fields(g.TASK_ID, fields)
sly.logger.info("Model has been successfully deployed")
sly.logger.info("🟩 Model has been successfully deployed")


def init_state_and_data(data, state):
Expand Down

0 comments on commit 1284449

Please sign in to comment.