Skip to content

Commit

Permalink
Merge branch 'master' of github.com:alibaba/EasyCV
Browse files Browse the repository at this point in the history
  • Loading branch information
Cathy0908 committed Oct 31, 2023
2 parents c4c797d + 8c3ba59 commit c00f5c1
Show file tree
Hide file tree
Showing 16 changed files with 159 additions and 96 deletions.
78 changes: 38 additions & 40 deletions configs/config_templates/yolox_itag.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,14 @@
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

train_pipeline = [
dict(type='MMMosaic', img_scale='${img_scale}', pad_val=114.0),
dict(type='MMMosaic', img_scale=tuple(img_scale), pad_val=114.0),
dict(
type='MMRandomAffine',
scaling_ratio_range='${scale_ratio}',
border=['-${img_scale}[0] // 2', '-${img_scale}[1] // 2']),
scaling_ratio_range=scale_ratio,
border=[img_scale[0] // 2, img_scale[1] // 2]),
dict(
type='MMMixUp', # s m x l; tiny nano will delete
img_scale='${img_scale}',
img_scale=tuple(img_scale),
ratio_range=(0.8, 1.6),
pad_val=114.0),
dict(
Expand All @@ -70,45 +70,43 @@
dict(type='MMPad', pad_to_square=True, pad_val=(114.0, 114.0, 114.0)),
dict(
type='MMNormalize',
mean='${img_norm_cfg.mean}',
std='${img_norm_cfg.std}',
to_rgb='${img_norm_cfg.to_rgb}'),
mean=img_norm_cfg['mean'],
std=img_norm_cfg['std'],
to_rgb=img_norm_cfg['to_rgb']),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
]
test_pipeline = [
dict(type='MMResize', img_scale='${img_scale}', keep_ratio=True),
dict(type='MMResize', img_scale=img_scale, keep_ratio=True),
dict(type='MMPad', pad_to_square=True, pad_val=(114.0, 114.0, 114.0)),
dict(
type='MMNormalize',
mean='${img_norm_cfg.mean}',
std='${img_norm_cfg.std}',
to_rgb='${img_norm_cfg.to_rgb}'),
mean=img_norm_cfg['mean'],
std=img_norm_cfg['std'],
to_rgb=img_norm_cfg['to_rgb']),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img'])
]

train_path = 'data/coco/train2017.manifest'
val_path = 'data/coco/val2017.manifest'

train_dataset = dict(
type='DetImagesMixDataset',
data_source=dict(type='DetSourcePAI', path=train_path, classes=CLASSES),
pipeline=train_pipeline,
dynamic_scale=tuple(img_scale))

val_dataset = dict(
type='DetImagesMixDataset',
imgs_per_gpu=2,
data_source=dict(type='DetSourcePAI', path=val_path, classes=CLASSES),
pipeline=test_pipeline,
dynamic_scale=None,
label_padding=False)

data = dict(
imgs_per_gpu=16,
workers_per_gpu=4,
train=dict(
type='DetImagesMixDataset',
data_source=dict(
type='DetSourcePAI',
path='data/coco/train2017.manifest',
classes='${CLASSES}'),
pipeline='${train_pipeline}',
dynamic_scale='${img_scale}'),
val=dict(
type='DetImagesMixDataset',
imgs_per_gpu=2,
data_source=dict(
type='DetSourcePAI',
path='data/coco/val2017.manifest',
classes='${CLASSES}'),
pipeline='${test_pipeline}',
dynamic_scale=None,
label_padding=False))
imgs_per_gpu=16, workers_per_gpu=4, train=train_dataset, val=val_dataset)

# additional hooks
interval = 10
Expand All @@ -120,38 +118,38 @@
priority=48),
dict(
type='SyncRandomSizeHook',
ratio_range='${random_size}',
img_scale='${img_scale}',
interval='${interval}',
ratio_range=random_size,
img_scale=img_scale,
interval=interval,
priority=48),
dict(
type='SyncNormHook',
num_last_epochs=15,
interval='${interval}',
interval=interval,
priority=48)
]

# evaluation
vis_num = 20
score_thr = 0.5
eval_config = dict(
interval='${interval}',
interval=interval,
gpu_collect=False,
visualization_config=dict(
vis_num='${vis_num}',
score_thr='${score_thr}',
vis_num=vis_num,
score_thr=score_thr,
) # show by TensorboardLoggerHookV2
)

eval_pipelines = [
dict(
mode='test',
data='${data.val}',
data=val_dataset,
evaluators=[dict(type='CocoDetectionEvaluator', classes=CLASSES)],
)
]

checkpoint_config = dict(interval='${interval}')
checkpoint_config = dict(interval=interval)
# optimizer
# basic_lr_per_img = 0.01 / 64.0
optimizer = dict(
Expand Down
31 changes: 28 additions & 3 deletions easycv/apis/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,10 +247,10 @@ def _export_yolox(model, cfg, filename):

if hasattr(cfg, 'export'):
export_type = getattr(cfg.export, 'export_type', 'raw')
default_export_type_list = ['raw', 'jit', 'blade']
default_export_type_list = ['raw', 'jit', 'blade', 'onnx']
if export_type not in default_export_type_list:
logging.warning(
'YOLOX-PAI only supports the export type as [raw,jit,blade], otherwise we use raw as default'
'YOLOX-PAI only supports the export type as [raw,jit,blade,onnx], otherwise we use raw as default'
)
export_type = 'raw'

Expand All @@ -276,7 +276,7 @@ def _export_yolox(model, cfg, filename):
len(img_scale) == 2
), 'Export YoloX predictor config contains img_scale must be (int, int) tuple!'

input = 255 * torch.rand((batch_size, 3) + img_scale)
input = 255 * torch.rand((batch_size, 3) + tuple(img_scale))

# assert use_trt_efficientnms only happens when static_opt=True
if static_opt is not True:
Expand Down Expand Up @@ -355,6 +355,31 @@ def _export_yolox(model, cfg, filename):

json.dump(config, ofile)

if export_type == 'onnx':

with io.open(
filename + '.config.json' if filename.endswith('onnx')
else filename + '.onnx.config.json', 'w') as ofile:
config = dict(
model=cfg.model,
export=cfg.export,
test_pipeline=cfg.test_pipeline,
classes=cfg.CLASSES)

json.dump(config, ofile)

torch.onnx.export(
model,
input.to(device),
filename if filename.endswith('onnx') else filename +
'.onnx',
export_params=True,
opset_version=12,
do_constant_folding=True,
input_names=['input'],
output_names=['output'],
)

if export_type == 'jit':
with io.open(filename + '.jit', 'wb') as ofile:
torch.jit.save(yolox_trace, ofile)
Expand Down
4 changes: 2 additions & 2 deletions easycv/core/evaluation/custom_cocotools/cocoeval.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,8 +467,8 @@ def accumulate(self, p=None):
fps = np.logical_and(
np.logical_not(dtm), np.logical_not(dtIg))

tp_sum = np.cumsum(tps, axis=1).astype(dtype=np.float)
fp_sum = np.cumsum(fps, axis=1).astype(dtype=np.float)
tp_sum = np.cumsum(tps, axis=1).astype(dtype=np.float32)
fp_sum = np.cumsum(fps, axis=1).astype(dtype=np.float32)
for t, (tp, fp) in enumerate(zip(tp_sum, fp_sum)):
tp = np.array(tp)
fp = np.array(fp)
Expand Down
2 changes: 1 addition & 1 deletion easycv/datasets/face/pipelines/face_keypoint_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def aug_clr_noise_blur(self, img):
skin_factor_list = [0.6, 0.8, 1.0, 1.2, 1.4]
skin_factor = np.random.choice(skin_factor_list)
img_ycrcb_raw[:, :, 0:1] = np.clip(
img_ycrcb_raw[:, :, 0:1].astype(np.float) * skin_factor, 0,
img_ycrcb_raw[:, :, 0:1].astype(np.float32) * skin_factor, 0,
255).astype(np.uint8)
img = cv2.cvtColor(img_ycrcb_raw, cv2.COLOR_YCR_CB2BGR)

Expand Down
2 changes: 1 addition & 1 deletion easycv/models/utils/pos_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float)
omega = np.arange(embed_dim // 2, dtype=np.float32)
omega /= embed_dim / 2.
omega = 1. / 10000**omega # (D/2,)

Expand Down
37 changes: 31 additions & 6 deletions easycv/predictors/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@
from .interface import PredictorInterface


def onnx_to_numpy(tensor):
    """Convert a tensor to a NumPy ndarray.

    The tensor is moved to the CPU before conversion. A tensor that
    requires grad is detached first; any other tensor is converted
    directly.

    Args:
        tensor: the tensor to convert (must expose ``requires_grad``,
            ``detach()``, ``cpu()`` and ``numpy()``).

    Returns:
        The ndarray produced by ``tensor.cpu().numpy()``.
    """
    if tensor.requires_grad:
        # Drop the autograd link first; numpy() cannot be called on a
        # tensor that still requires grad.
        tensor = tensor.detach()
    return tensor.cpu().numpy()


class DetInputProcessor(InputProcessor):

def build_processor(self):
Expand Down Expand Up @@ -349,9 +355,11 @@ def __init__(self,
self.model_type = 'jit'
elif model_path.endswith('blade'):
self.model_type = 'blade'
elif model_path.endswith('onnx'):
self.model_type = 'onnx'
else:
self.model_type = 'raw'
assert self.model_type in ['raw', 'jit', 'blade']
assert self.model_type in ['raw', 'jit', 'blade', 'onnx']

if self.model_type == 'blade' or self.use_trt_efficientnms:
import torch_blade
Expand Down Expand Up @@ -381,8 +389,16 @@ def __init__(self,

def _build_model(self):
if self.model_type != 'raw':
with io.open(self.model_path, 'rb') as infile:
model = torch.jit.load(infile, self.device)
if self.model_type != 'onnx':
with io.open(self.model_path, 'rb') as infile:
model = torch.jit.load(infile, self.device)
else:
import onnxruntime
if onnxruntime.get_device() == 'GPU':
model = onnxruntime.InferenceSession(
self.model_path, providers=['CUDAExecutionProvider'])
else:
model = onnxruntime.InferenceSession(self.model_path)
else:
from easycv.utils.misc import reparameterize_models
model = super()._build_model()
Expand All @@ -394,8 +410,9 @@ def prepare_model(self):
If the model is not loaded from a configuration file, e.g. torch jit model, you need to reimplement it.
"""
model = self._build_model()
model.to(self.device)
model.eval()
if self.model_type != 'onnx':
model.to(self.device)
model.eval()
if self.model_type == 'raw':
load_checkpoint(model, self.model_path, map_location='cpu')
return model
Expand All @@ -406,7 +423,15 @@ def model_forward(self, inputs):
"""
if self.model_type != 'raw':
with torch.no_grad():
outputs = self.model(inputs['img'])
if self.model_type != 'onnx':
outputs = self.model(inputs['img'])
else:
outputs = self.model.run(
None, {
self.model.get_inputs()[0].name:
onnx_to_numpy(inputs['img'])
})[0]
outputs = torch.from_numpy(outputs)
outputs = {'results': outputs} # convert to dict format
else:
outputs = super().model_forward(inputs)
Expand Down
2 changes: 1 addition & 1 deletion easycv/thirdparty/mot/bytetrack/byte_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class STrack(BaseTrack):
def __init__(self, tlwh, score):

# wait activate
self._tlwh = np.asarray(tlwh, dtype=np.float)
self._tlwh = np.asarray(tlwh, dtype=np.float32)
self.kalman_filter = None
self.mean, self.covariance = None, None
self.is_activated = False
Expand Down
12 changes: 6 additions & 6 deletions easycv/thirdparty/mot/bytetrack/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,15 @@ def ious(atlbrs, btlbrs):
:rtype ious np.ndarray
"""
ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float)
ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float32)
if ious.size == 0:
return ious

from cython_bbox import bbox_overlaps as bbox_ious

ious = bbox_ious(
np.ascontiguousarray(atlbrs, dtype=np.float),
np.ascontiguousarray(btlbrs, dtype=np.float))
np.ascontiguousarray(atlbrs, dtype=np.float32),
np.ascontiguousarray(btlbrs, dtype=np.float32))

return ious

Expand Down Expand Up @@ -151,15 +151,15 @@ def embedding_distance(tracks, detections, metric='cosine'):
:return: cost_matrix np.ndarray
"""

cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float)
cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float32)
if cost_matrix.size == 0:
return cost_matrix
det_features = np.asarray([track.curr_feat for track in detections],
dtype=np.float)
dtype=np.float32)
#for i, track in enumerate(tracks):
#cost_matrix[i, :] = np.maximum(0.0, cdist(track.smooth_feat.reshape(1,-1), det_features, metric))
track_features = np.asarray([track.smooth_feat for track in tracks],
dtype=np.float)
dtype=np.float32)
cost_matrix = np.maximum(0.0, cdist(track_features, det_features,
metric)) # Normalized features
return cost_matrix
Expand Down
4 changes: 2 additions & 2 deletions easycv/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
# GENERATED VERSION FILE
# TIME: Thu Nov 5 14:17:50 2020

__version__ = '0.11.3'
short_version = '0.11.3'
__version__ = '0.11.4'
short_version = '0.11.4'
1 change: 1 addition & 0 deletions requirements/runtime.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ lmdb
numba
numpy
nuscenes-devkit
onnxruntime
opencv-python
oss2
packaging
Expand Down
Loading

0 comments on commit c00f5c1

Please sign in to comment.