Commit 1022b7b

add some loading pipelines
hellock committed Dec 24, 2019
1 parent 3207a1c commit 1022b7b
Showing 2 changed files with 71 additions and 33 deletions.
84 changes: 53 additions & 31 deletions mmaction/datasets/pipelines/loading.py
@@ -1,48 +1,70 @@
 import os.path as osp

 import av
 import mmcv
 import numpy as np

 from ..registry import PIPELINES


-def sample_clips(num_frames,
-                 clip_len,
-                 frame_interval=1,
-                 num_clips=1,
-                 temporal_jitter=False):
-    """Sample clips from the video.
-    Args:
-        num_frames (int): Total frames of the video.
+@PIPELINES.register_module
+class SampleFrames:
+    """Sample frames from the video.
+    Required keys are "filename", added or modified keys are "total_frames"
+    and "frame_inds".
+    Attributes:
         clip_len (int): Frames of each sampled output clip.
         frame_interval (int): Temporal interval of adjacent sampled frames.
         num_clips (int): Number of clips to be sampled.
         temporal_jitter (bool): Whether to apply temporal jittering.
-    Returns:
-        np.ndarray: Shape (num_clips, clip_len)
     """
-    ori_clip_len = clip_len * frame_interval
-    avg_interval = (num_frames - ori_clip_len) // num_clips
-    if avg_interval > 0:
-        base_offsets = np.arange(num_clips) * avg_interval
-        clip_offsets = base_offsets + np.random.randint(
-            avg_interval, size=num_clips)
-    elif num_frames > max(num_clips, ori_clip_len):
-        clip_offsets = np.sort(
-            np.random.randint(num_frames - ori_clip_len + 1, size=num_clips))
-    else:
-        clip_offsets = np.zeros((num_clips, ))
-
-    frame_inds = clip_offsets + np.arange(clip_len)[None, :] * frame_interval
-
-    if temporal_jitter:
-        perframe_offsets = np.random.randint((num_clips, frame_interval),
-                                             size=clip_len)
-        frame_inds += perframe_offsets
-
-    return frame_inds
+
+    def __init__(self,
+                 clip_len,
+                 frame_interval=1,
+                 num_clips=1,
+                 temporal_jitter=False):
+        self.clip_len = clip_len
+        self.frame_interval = frame_interval
+        self.num_clips = num_clips
+        self.temporal_jitter = temporal_jitter
+
+    def _sample_clips(self, num_frames):
+        ori_clip_len = self.clip_len * self.frame_interval
+        avg_interval = (num_frames - ori_clip_len) // self.num_clips
+        if avg_interval > 0:
+            base_offsets = np.arange(self.num_clips) * avg_interval
+            clip_offsets = base_offsets + np.random.randint(
+                avg_interval, size=self.num_clips)
+        elif num_frames > max(self.num_clips, ori_clip_len):
+            clip_offsets = np.sort(
+                np.random.randint(
+                    num_frames - ori_clip_len + 1, size=self.num_clips))
+        else:
+            clip_offsets = np.zeros((self.num_clips, ), dtype=np.int64)
+
+        return clip_offsets
+
+    def __call__(self, results):
+        video_reader = mmcv.VideoReader(results['filename'])
+        total_frames = len(video_reader)
+        results['total_frames'] = total_frames
+
+        clip_offsets = self._sample_clips(total_frames)
+
+        frame_inds = clip_offsets[:, None] + np.arange(
+            self.clip_len)[None, :] * self.frame_interval
+
+        if self.temporal_jitter:
+            perframe_offsets = np.random.randint(
+                self.frame_interval, size=(self.num_clips, self.clip_len))
+            frame_inds += perframe_offsets
+
+        results['frame_inds'] = frame_inds
+
+        return results


 @PIPELINES.register_module
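The sampling scheme: _sample_clips divides the usable frame range into num_clips evenly spaced segments and draws one clip start uniformly at random inside each; __call__ then expands every offset into clip_len indices spaced frame_interval apart, optionally jittering each frame within its interval. A minimal self-contained sketch of the same logic, runnable outside the pipeline (the function name sample_frames and the example numbers are illustrative, not part of this commit):

import numpy as np


def sample_frames(num_frames, clip_len, frame_interval=1, num_clips=1,
                  temporal_jitter=False):
    # Number of raw frames one clip spans in the source video.
    ori_clip_len = clip_len * frame_interval
    avg_interval = (num_frames - ori_clip_len) // num_clips
    if avg_interval > 0:
        # One evenly spaced segment per clip, random start inside each.
        base_offsets = np.arange(num_clips) * avg_interval
        clip_offsets = base_offsets + np.random.randint(
            avg_interval, size=num_clips)
    elif num_frames > max(num_clips, ori_clip_len):
        # Not enough room for spaced segments: sorted random starts instead.
        clip_offsets = np.sort(
            np.random.randint(num_frames - ori_clip_len + 1, size=num_clips))
    else:
        # Video shorter than one clip: start every clip at frame 0.
        clip_offsets = np.zeros((num_clips, ), dtype=np.int64)

    # (num_clips, 1) + (1, clip_len) broadcasts to (num_clips, clip_len).
    frame_inds = clip_offsets[:, None] + np.arange(
        clip_len)[None, :] * frame_interval
    if temporal_jitter:
        # Shift each sampled frame by up to frame_interval - 1 positions.
        frame_inds += np.random.randint(
            frame_interval, size=(num_clips, clip_len))
    return frame_inds


# 3 clips of 8 frames at stride 2 from a 300-frame video.
print(sample_frames(300, clip_len=8, frame_interval=2, num_clips=3))

Setting clip_len=1 with a larger num_clips reproduces TSN-style segment sampling, while clip_len > 1 yields the dense clips consumed by 3D models.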
20 changes: 18 additions & 2 deletions mmaction/datasets/video_dataset.py
@@ -1,3 +1,6 @@
+import copy
+import os.path as osp
+
 from torch.utils.data import Dataset


@@ -18,7 +21,20 @@ def load_annotations(self, ann_file):
         video_infos = []
         with open(ann_file, 'r') as fin:
             for line in fin:
-                video_infos.append({''})
+                filename, label = line.strip().split(' ')
+                filepath = osp.join(self.data_prefix, filename)
+                video_infos.append(dict(filename=filepath, label=int(label)))
         self.video_infos = video_infos
 
-    def __getitem__(self, idx):
+    def prepare_train_frames(self, idx):
+        results = copy.deepcopy(self.video_infos[idx])
+        return self.pipeline(results)
+
+    def prepare_test_frames(self, idx):
+        pass
+
+    def __getitem__(self, idx):
+        if self.test_mode:
+            return self.prepare_test_frames(idx)
+        else:
+            return self.prepare_train_frames(idx)
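load_annotations expects one "<relative path> <integer label>" pair per line of the annotation file, and each dataset item starts as that annotation dict before flowing through the pipeline. A hedged usage sketch (the class name VideoDataset, its module path, and its constructor signature are assumptions inferred from the attributes used above; none of them appear in this diff):

# Assumed annotation file, e.g. data/train_list.txt:
#   clips/video_001.mp4 3
#   clips/video_002.mp4 7

from mmaction.datasets.pipelines.loading import SampleFrames
from mmaction.datasets.video_dataset import VideoDataset  # assumed path

dataset = VideoDataset(
    ann_file='data/train_list.txt',
    data_prefix='data',  # joined onto each relative path
    # self.pipeline is invoked as a single callable; real configs presumably
    # compose several transforms, but one SampleFrames works on its own.
    pipeline=SampleFrames(clip_len=8, frame_interval=2, num_clips=1),
    test_mode=False)

item = dataset[0]
# -> {'filename': 'data/clips/video_001.mp4', 'label': 3,
#     'total_frames': ..., 'frame_inds': array of shape (1, 8)}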
