-
Notifications
You must be signed in to change notification settings - Fork 2
/
create_record.py
137 lines (122 loc) · 5.22 KB
/
create_record.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import argparse
import io
import os
import subprocess
# import ray
import tensorflow.compat.v1 as tf
from PIL import Image
from psutil import cpu_count
from utils import *
from object_detection.utils import dataset_util, label_map_util
import os
# Location of the label map. Defaults to the container layout this script was
# written for; set the LABEL_MAP_PATH environment variable to run elsewhere.
LABEL_MAP_PATH = os.environ.get('LABEL_MAP_PATH', '/tf/code/label_map.pbtxt')
label_map = label_map_util.load_labelmap(LABEL_MAP_PATH)
# Mapping of label name -> integer id, as read from the label map.
label_map_dict = label_map_util.get_label_map_dict(label_map)
# Inverse lookup: integer id -> label name (consumed by class_text_to_int).
t2idict = {label_id: name for name, label_id in label_map_dict.items()}
def class_text_to_int(text):
    """Look up the label name registered for a numeric label id.

    Despite the function name, the lookup runs id -> name: ``t2idict`` is the
    inverse of the label map's name -> id dictionary, so ``text`` is expected
    to be the integer label id and the return value is the label name.

    Raises:
        KeyError: if the id does not appear in the label map.
    """
    name_for_id = t2idict
    label_name = name_for_id[text]
    return label_name
def create_tf_example(filename, encoded_jpeg, annotations):
    """Create a tf.train.Example from one Waymo camera image and its labels.

    args:
        - filename [str]: name of the image; stored as both ``image/filename``
          and ``image/source_id``
        - encoded_jpeg [bytes]: jpeg encoded image
        - annotations [iterable]: camera labels; each item must expose
          ``box.center_x``, ``box.center_y``, ``box.length``, ``box.width``
          (in pixels) and a numeric ``type`` that is a key of the label map
    returns:
        - tf_example [tf.train.Example]: example in the Object Detection API
          format, with bounding boxes normalized by the image size
    """
    # Only the image size is needed; the context manager closes the image.
    with Image.open(io.BytesIO(encoded_jpeg)) as image:
        width, height = image.size
    image_format = b'jpeg'

    xmins, xmaxs, ymins, ymaxs = [], [], [], []
    classes_text = []
    classes = []
    for row in annotations:
        box = row.box
        half_length = box.length / 2.0
        half_width = box.width / 2.0
        # Box length spans the x axis and width the y axis; coordinates are
        # normalized to [0, 1] by the image width / height.
        xmins.append((box.center_x - half_length) / width)
        xmaxs.append((box.center_x + half_length) / width)
        ymins.append((box.center_y - half_width) / height)
        ymaxs.append((box.center_y + half_width) / height)
        # class_text_to_int maps the numeric label id to its label name.
        classes_text.append(class_text_to_int(row.type).encode('utf8'))
        classes.append(row.type)

    encoded_filename = filename.encode('utf8')
    tf_example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': int64_feature(height),
        'image/width': int64_feature(width),
        'image/filename': bytes_feature(encoded_filename),
        'image/source_id': bytes_feature(encoded_filename),
        'image/encoded': bytes_feature(encoded_jpeg),
        'image/format': bytes_feature(image_format),
        'image/object/bbox/xmin': float_list_feature(xmins),
        'image/object/bbox/xmax': float_list_feature(xmaxs),
        'image/object/bbox/ymin': float_list_feature(ymins),
        'image/object/bbox/ymax': float_list_feature(ymaxs),
        'image/object/class/text': bytes_list_feature(classes_text),
        'image/object/class/label': int64_list_feature(classes),
    }))
    return tf_example
def process_tfr(filepath, data_dir,
                cameras=('FRONT', 'FRONT_LEFT', 'FRONT_RIGHT',
                         'SIDE_LEFT', 'SIDE_RIGHT')):
    """Convert a Waymo tf record into an Object Detection API tf record.

    One tf.train.Example is written per (frame, camera) pair: frames in the
    order they appear in the input, and cameras in ``cameras`` order within
    each frame. If the output file already exists, the input is skipped.

    args:
        - filepath [str]: path to the Waymo tf record file
        - data_dir [str]: destination directory; output is written to
          ``<data_dir>/processed/<basename of filepath>``
        - cameras [tuple of str]: camera names passed to ``parse_frame``
    """
    dest = os.path.join(data_dir, 'processed')
    os.makedirs(dest, exist_ok=True)
    file_name = os.path.basename(filepath)
    out_path = os.path.join(dest, file_name)
    logger = get_module_logger(__name__)
    if os.path.exists(out_path):
        return
    logger.info(f'Processing {filepath}')

    # Write to a temporary file and move it into place only after every frame
    # has been converted, so an interrupted run cannot leave a partial file
    # that the existence check above would later treat as finished.
    tmp_path = out_path + '.tmp'
    dataset = tf.data.TFRecordDataset(filepath, compression_type='')
    with tf.python_io.TFRecordWriter(tmp_path) as writer:
        for idx, data in enumerate(dataset):
            frame = open_dataset.Frame()
            frame.ParseFromString(bytearray(data.numpy()))
            for camera in cameras:
                encoded_jpeg, annotations = parse_frame(frame, camera)
                # The camera name keeps filename/source_id unique per image;
                # without it all cameras of one frame shared the same id.
                filename = file_name.replace(
                    '.tfrecord', f'_{idx}_{camera}.tfrecord')
                tf_example = create_tf_example(
                    filename, encoded_jpeg, annotations)
                writer.write(tf_example.SerializeToString())
    os.replace(tmp_path, out_path)
if __name__ == "__main__":
    # Command-line entry point: convert a single Waymo tf record file.
    parser = argparse.ArgumentParser(
        description='Process a Waymo tf record into the TF Object Detection '
                    'API format')
    parser.add_argument('--data_dir', required=True,
                        help='processed data directory')
    parser.add_argument('--filepath', required=True,
                        help='raw data path')
    args = parser.parse_args()
    process_tfr(args.filepath, args.data_dir)