-
Notifications
You must be signed in to change notification settings - Fork 1
/
generate.py
67 lines (51 loc) · 1.86 KB
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import sys
from pathlib import Path
from subprocess import PIPE, Popen
import tensorflow as tf
# Command used to launch the PHP generator script (see generate()).
PHP_BINARY = "php"
# Directory that receives the "<name>.tfrecords" output files (see write()).
DATA_DIR = Path("./data/")
# Maximum number of bytes requested per read from the generator's stdout.
CHUNK_SIZE = 65536
def generate():
    """Yield ``(phrase, image)`` pairs parsed from the output of ``generate.php``.

    The PHP script is started as a subprocess and its stdout is consumed as a
    stream of records. Each record is an ASCII phrase immediately followed by
    a JPEG image; records are split on the JPEG start-of-image (``FF D8``)
    and end-of-image (``FF D9``) markers.

    The generator finishes when the subprocess closes its stdout. Whenever
    the generator ends or is closed early, the subprocess is terminated and
    reaped so it does not linger.

    Yields:
        tuple[str, bytes]: the decoded phrase and the raw JPEG bytes,
        including both markers.

    Raises:
        UnicodeDecodeError: if the bytes preceding an image are not ASCII.
    """
    soi_marker = b"\xff\xd8"
    eoi_marker = b"\xff\xd9"
    process = Popen([PHP_BINARY, "generate.php"], stdout=PIPE)
    try:
        buffer = bytearray()
        while True:
            chunk = process.stdout.read(CHUNK_SIZE)
            if not chunk:
                # EOF: the subprocess closed stdout. Without this check the
                # loop would spin forever on empty reads.
                break
            buffer += chunk
            # Drain every complete record currently held in the buffer.
            while True:
                soi = buffer.find(soi_marker)
                if soi < 0:
                    break
                # Look for the end marker only after the start marker so a
                # record is never sliced with an end that precedes its start.
                eoi = buffer.find(eoi_marker, soi + len(soi_marker))
                if eoi < 0:
                    break
                eoi += len(eoi_marker)
                phrase = buffer[:soi].decode("ascii")
                image = bytes(buffer[soi:eoi])
                del buffer[:eoi]
                yield phrase, image
    finally:
        # Runs on normal exhaustion, on generator close, and on errors.
        process.stdout.close()
        if process.poll() is None:
            process.terminate()
        process.wait()
# https://www.tensorflow.org/tutorials/load_data/tfrecord#data_types_for_tftrainexample
def _bytes_feature(value):
    """Wrap a single bytes value in a ``tf.train.Feature`` holding a BytesList.

    Eager tensors are converted with ``.numpy()`` first, because BytesList
    cannot unpack a string directly from an EagerTensor.
    """
    eager_tensor_type = type(tf.constant(0))
    if isinstance(value, eager_tensor_type):
        value = value.numpy()
    wrapped = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=wrapped)
# https://www.tensorflow.org/tutorials/load_data/tfrecord#write_the_tfrecord_file
def captcha_example(image_string: bytes, phrase: str):
    """Build a ``tf.train.Example`` for one captcha.

    The JPEG bytes are decoded and stored as a serialized tensor under the
    "image" key; the phrase is stored ASCII-encoded under the "phrase" key.
    """
    decoded = tf.image.decode_jpeg(image_string)
    feature = {}
    feature["phrase"] = _bytes_feature(phrase.encode("ascii"))
    feature["image"] = _bytes_feature(tf.io.serialize_tensor(decoded))
    features = tf.train.Features(feature=feature)
    return tf.train.Example(features=features)
def write(name: str):
    """Write every generated captcha to ``DATA_DIR/<name>.tfrecords``.

    Each ``(phrase, image)`` pair produced by ``generate()`` is converted to
    a ``tf.train.Example`` and appended to the file as one serialized record.
    Runs until ``generate()`` is exhausted or an exception (such as
    ``KeyboardInterrupt``) propagates.

    Args:
        name: base name of the output file, without the extension.
    """
    # Make sure the output directory exists so opening the record file does
    # not fail when ./data has not been created yet.
    DATA_DIR.mkdir(parents=True, exist_ok=True)
    with tf.io.TFRecordWriter(str(DATA_DIR / f"{name}.tfrecords")) as writer:
        for phrase, image in generate():
            tf_example = captcha_example(image, phrase)
            writer.write(tf_example.SerializeToString())
if __name__ == "__main__":
    # Use the first command-line argument as the output name; otherwise fall
    # back to a freshly generated random UUID.
    if len(sys.argv) > 1:
        record_name = sys.argv[1]
    else:
        from uuid import uuid4

        record_name = uuid4()
    # Ctrl-C is the expected way to stop generation, so exit quietly on it.
    try:
        write(record_name)
    except KeyboardInterrupt:
        pass