-
Notifications
You must be signed in to change notification settings - Fork 4
/
fer2013_process.py
55 lines (47 loc) · 1.98 KB
/
fer2013_process.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import csv
import os
from PIL import Image
import numpy as np
# Locations of the input dataset and of everything this script writes, all
# under a "data" directory inside the current working directory.
# The trailing empty component keeps a trailing separator on data_path.
data_path = os.path.join(os.getcwd(), 'data', '')
csv_file = os.path.join(data_path, 'fer2013.csv')  # source dataset to read
# Destination CSV files for the three splits of the dataset.
train_csv = os.path.join(data_path, 'train.csv')
val_csv = os.path.join(data_path, 'val.csv')
test_csv = os.path.join(data_path, 'test.csv')
# Folders in which the pixel rows of each split CSV are saved as images.
train_set = os.path.join(data_path, 'train')
val_set = os.path.join(data_path, 'val')
test_set = os.path.join(data_path, 'test')
# Start organizing the dataset: read the whole source CSV into memory.
# newline='' lets the csv module handle line endings itself.
with open(csv_file, newline='') as f:
    csv_r = csv.reader(f)
    header = next(csv_r)
    print(header)
    rows = list(csv_r)

# Partition the rows by the value of their last column; that column is
# dropped from every row (and from the header) in the split files.
trn = [row[:-1] for row in rows if row[-1] == 'Training']
val = [row[:-1] for row in rows if row[-1] == 'PublicTest']
tst = [row[:-1] for row in rows if row[-1] == 'PrivateTest']

# Write each split inside a `with` block so the file is flushed and closed
# before it is read back later, then report how many rows it received.
for split_path, split_rows in ((train_csv, trn), (val_csv, val), (test_csv, tst)):
    with open(split_path, 'w', newline='') as out:
        csv.writer(out, lineterminator='\n').writerows([header[:-1]] + split_rows)
    print(len(split_rows))
# Turn every row of each split CSV into a 48x48 image file, stored in one
# sub-folder per label value underneath the split's output folder.
# The loop variable is named split_csv so it does not rebind the
# module-level csv_file (the path of the source dataset).
for save_path, split_csv in [(train_set, train_csv), (val_set, val_csv), (test_set, test_csv)]:
    # exist_ok=True replaces the separate "does it exist?" check.
    os.makedirs(save_path, exist_ok=True)
    with open(split_csv, newline='') as f:
        csv_r = csv.reader(f)
        next(csv_r, None)  # skip the header row
        for i, (label, pixel) in enumerate(csv_r):
            # Sub-folders 0 - 6 correspond to the labels:
            # angry, disgust, fear, happy, sad, surprise, neutral
            pixel = np.asarray([float(p) for p in pixel.split()]).reshape(48, 48)
            sub_folder = os.path.join(save_path, label)
            os.makedirs(sub_folder, exist_ok=True)
            # Convert to PIL mode 'L' before saving as JPEG.
            im = Image.fromarray(pixel).convert('L')
            # File name is the zero-padded row index within this split.
            image_name = os.path.join(sub_folder, '{:05d}.jpg'.format(i))
            print(image_name)
            im.save(image_name)