From 0b3949d818f7911f1b7af1cbb0d507fbe6f28ea8 Mon Sep 17 00:00:00 2001 From: Ruotian Luo Date: Wed, 3 Jun 2020 13:17:53 -0500 Subject: [PATCH] fix dump_to_lmdb for py3(also added saving all keys in the lmdb mimicing vilbert) --- scripts/dump_to_lmdb.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/scripts/dump_to_lmdb.py b/scripts/dump_to_lmdb.py index 6a0976bc..f08a03dd 100644 --- a/scripts/dump_to_lmdb.py +++ b/scripts/dump_to_lmdb.py @@ -42,7 +42,7 @@ def __init__(self, db_path, fn_list=None): def __getitem__(self, index): env = self.env with env.begin(write=False) as txn: - byteflow = txn.get(self.keys[index]) + byteflow = txn.get(self.keys[index].encode()) # load image imgbuf = byteflow @@ -117,7 +117,7 @@ def __init__(self, root, loader, extension, fn_list=None): if fn_list: samples = [os.path.join(root, str(_)+extension) for _ in fn_list] else: - samples = make_dataset(self.root, extention) + samples = make_dataset(self.root, extension) self.loader = loader self.extension = extension @@ -161,14 +161,16 @@ def folder2lmdb(dpath, fn_list, write_frequency=5000): txn = db.begin(write=True) - tsvfile = open(args.output_file, 'ab') + tsvfile = open(args.output_file, 'a') writer = csv.DictWriter(tsvfile, delimiter='\t', fieldnames=FIELDNAMES) - names = [] + names = [] + all_keys = [] for idx, data in enumerate(tqdm.tqdm(data_loader)): # print(type(data), data) name, byte, npz = data[0] if npz is not None: - txn.put(name, byte) + txn.put(name.encode(), byte) + all_keys.append(name) names.append({'image_id': name, 'status': str(npz is not None)}) if idx % write_frequency == 0: print("[%d/%d]" % (idx, len(data_loader))) @@ -181,7 +183,8 @@ def folder2lmdb(dpath, fn_list, write_frequency=5000): names = [] tsvfile.flush() print('writing finished') - + # write all keys + txn.put("keys".encode(), pickle.dumps(all_keys)) # finish iterating through dataset txn.commit() for name in names: