diff --git a/imagenet-classification/create_train_dir.py b/imagenet-classification/create_train_dir.py
index c5e076f3c..d6620c730 100644
--- a/imagenet-classification/create_train_dir.py
+++ b/imagenet-classification/create_train_dir.py
@@ -29,14 +29,20 @@
 dst_dir = args.outdir
 
 with tarfile.open(source_tar_file) as tf:
-    tar_tmp_dir = dst_dir + '/' + 'tmpdir'
-    tf.extractall(tar_tmp_dir)
-
-for tar_file in tqdm.tqdm(os.listdir(tar_tmp_dir)):
-    name, ext = os.path.splitext(os.path.basename(tar_file))
-    category_dir = dst_dir + '/' + name
-    os.mkdir(category_dir)
-    with tarfile.open(tar_tmp_dir + '/' + tar_file) as tf:
-        tf.extractall(category_dir)
-
-shutil.rmtree(tar_tmp_dir)
+    # Stream each per-class archive directly out of the outer tar so the
+    # nested archives never have to be unpacked into a temporary directory.
+    for tar_file_info in tqdm.tqdm(tf.getmembers()):
+        # extractfile() returns None for directories and other special
+        # members, which tarfile.open() cannot read; only regular files can
+        # be nested archives.
+        if not tar_file_info.isfile():
+            continue
+        name, _ext = os.path.splitext(os.path.basename(tar_file_info.name))
+        category_dir = os.path.join(dst_dir, name)
+        # Reuse an existing class directory so an interrupted run can be
+        # repeated without os.mkdir() raising.
+        if not os.path.isdir(category_dir):
+            os.makedirs(category_dir)
+        fileobj = tf.extractfile(tar_file_info)
+        with tarfile.open(fileobj=fileobj) as tf_class:
+            tf_class.extractall(category_dir)