data_split_script.py
import os
import argparse
import random

parser = argparse.ArgumentParser(description='Split data into train, validation and test files.')
parser.add_argument("train_language", type=str, help="Language of the training data: English, Thai, or both.")
parser.add_argument("train_font", type=str, help="Font of the training data: normal, bold, bold_italic, italic, or all.")
parser.add_argument("train_dpi", type=str, help="DPI of the training data: 200, 300, 400, or all.")
parser.add_argument("directory", help="Path to the dataset, e.g. /scratch/lt2326-2926-h24/ThaiOCR/ThaiOCR-TrainigSet.")
parser.add_argument("--test_language", type=str, default="none", help="Language of the testing data: English, Thai, or both.")
parser.add_argument("--test_font", type=str, default="none", help="Font of the testing data: normal, bold, bold_italic, italic, or all.")
parser.add_argument("--test_dpi", type=str, default="none", help="DPI of the testing data: 200, 300, 400, or all.")
args = parser.parse_args()
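
# The functions below assume a dataset layout along these lines (inferred from the
# path handling in this script, not documented in the original):
#   <directory>/<Language>/<character class>/<dpi>/<font>/<image>.bmp
# The folder directly under the language directory is treated as the class label
# ("ocr_number"), the last path component as the font, and the dpi is matched
# anywhere in the path.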
def collect_from_folder(language, font, dpi, directory, images):
    """Walk <directory>/<language> and append (image path, class label) tuples to images."""
    lang_dir = os.path.join(directory, language)
    dpi_options = ["200", "300", "400"] if dpi == "all" else [str(dpi)]
    for root, _, files in os.walk(lang_dir):
        path_parts = root.split(os.sep)
        language_idx = path_parts.index(language)
        if font != "all":
            # Keep only folders that match the requested dpi and whose last component is the font.
            if any(d in root for d in dpi_options) and path_parts[-1] == font:
                ocr_number = path_parts[language_idx + 1]
                for file in files:
                    if file.endswith('.bmp'):
                        images.append((os.path.join(root, file), ocr_number))
        else:
            # With font == "all", only the dpi has to match.
            if any(d in root for d in dpi_options):
                ocr_number = path_parts[language_idx + 1]
                for file in files:
                    if file.endswith('.bmp'):
                        images.append((os.path.join(root, file), ocr_number))
def split_data(train_language, train_font, train_dpi, directory, test_language, test_font, test_dpi):
    """Collect images, split them into train/validation/test sets, and write one file per split."""
    train_images = []
    test_images = []
    valid_size = 0.1
    test_size = 0.1
    # Collect the training pool.
    if train_language == "both":
        collect_from_folder('English', train_font, train_dpi, directory, train_images)
        collect_from_folder('Thai', train_font, train_dpi, directory, train_images)
    else:
        collect_from_folder(train_language, train_font, train_dpi, directory, train_images)
    random.shuffle(train_images)
    total_train_size = len(train_images)
    if test_language == "none" or test_font == "none" or test_dpi == "none":
        # No separate test configuration: carve the test and validation sets out of the training pool.
        test_split = int(total_train_size * test_size)
        test_set = train_images[:test_split]
        valid_split = int(total_train_size * (test_size + valid_size))
        valid_set = train_images[test_split:valid_split]
        train_set = train_images[valid_split:]
    else:
        # Separate test configuration: collect its own pool, sized relative to the training pool.
        if test_language == "both":
            collect_from_folder('English', test_font, test_dpi, directory, test_images)
            collect_from_folder('Thai', test_font, test_dpi, directory, test_images)
        else:
            collect_from_folder(test_language, test_font, test_dpi, directory, test_images)
        random.shuffle(test_images)
        test_limit = int(total_train_size * test_size)
        test_set = test_images[:test_limit]
        valid_split = int(total_train_size * valid_size)
        valid_set = train_images[:valid_split]
        train_set = train_images[valid_split:]
    # Write one "<image path>,<class label>" line per image.
    with open("train_file.txt", 'w') as train:
        for path, ocr_number in train_set:
            train.write(f"{path},{ocr_number}\n")
    with open("test_file.txt", 'w') as test:
        for path, ocr_number in test_set:
            test.write(f"{path},{ocr_number}\n")
    with open("valid_file.txt", 'w') as valid:
        for path, ocr_number in valid_set:
            valid.write(f"{path},{ocr_number}\n")


split_data(args.train_language, args.train_font, args.train_dpi, args.directory,
           args.test_language, args.test_font, args.test_dpi)
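
# Example invocation (a sketch; the dataset path is the one mentioned in the argparse
# help above and may differ on other systems):
#   python data_split_script.py Thai normal 200 /scratch/lt2326-2926-h24/ThaiOCR/ThaiOCR-TrainigSet
# This writes train_file.txt, valid_file.txt, and test_file.txt to the working directory,
# each line of the form "<image path>,<class label>". Passing --test_language, --test_font,
# and --test_dpi together switches to a separately collected test set.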