Skip to content

Commit 308c696

Browse files
authored
Readme 30k fix (#65)
* fixes * fix * nota bene * link to supmat places section * python -> python3 fix * eval2_gpu fix
1 parent 4d6e17b commit 308c696

File tree

4 files changed

+45
-36
lines changed

4 files changed

+45
-36
lines changed

README.md

+16-9
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,10 @@ Then download models for _perceptual loss_:
162162
163163
164164
## Places
165+
166+
⚠️ NB: The FID/SSIM/LPIPS metric values for Places reported in the LaMa paper are computed on the 30000 images that we produce in the evaluation section below.
167+
For more details on the evaluation data, check [[Section 3. Dataset splits in the Supplementary](https://ashukha.com/projects/lama_21/lama_supmat_2021.pdf#subsection.3.1)] ⚠️
168+
165169
On the host machine:
166170
167171
# Download data from http://places2.csail.mit.edu/download.html
@@ -170,18 +174,20 @@ On the host machine:
170174
wget http://data.csail.mit.edu/places/places365/val_large.tar
171175
wget http://data.csail.mit.edu/places/places365/test_large.tar
172176
173-
# Unpack and etc.
177+
# Unpack train/test/val data and create .yaml config for it
174178
bash fetch_data/places_standard_train_prepare.sh
175179
bash fetch_data/places_standard_test_val_prepare.sh
176-
bash fetch_data/places_standard_evaluation_prepare_data.sh
177180
178181
# Sample images for test and viz at the end of epoch
179182
bash fetch_data/places_standard_test_val_sample.sh
180183
bash fetch_data/places_standard_test_val_gen_masks.sh
181184
182185
# Run training
183-
# You can change bs with data.batch_size=10
184-
python bin/train.py -cn lama-fourier location=places_standard
186+
python3 bin/train.py -cn lama-fourier location=places_standard
187+
188+
# To evaluate trained model and report metrics as in our paper
189+
# we need to sample previously unseen 30k images and generate masks for them
190+
bash fetch_data/places_standard_evaluation_prepare_data.sh
185191
186192
# Infer model on thick/thin/medium masks in 256 and 512 and run evaluation
187193
# like this:
@@ -191,9 +197,10 @@ On the host machine:
191197
outdir=$(pwd)/inference/random_thick_512 model.checkpoint=last.ckpt
192198
193199
python3 bin/evaluate_predicts.py \
194-
$(pwd)/configs/eval_2gpu.yaml \
200+
$(pwd)/configs/eval2_gpu.yaml \
195201
$(pwd)/places_standard_dataset/evaluation/random_thick_512/ \
196-
$(pwd)/inference/random_thick_512 $(pwd)/inference/random_thick_512_metrics.csv
202+
$(pwd)/inference/random_thick_512 \
203+
$(pwd)/inference/random_thick_512_metrics.csv
197204
198205
199206
@@ -216,7 +223,7 @@ On the host machine:
216223
bash fetch_data/celebahq_gen_masks.sh
217224
218225
# Run training
219-
python bin/train.py -cn lama-fourier-celeba data.batch_size=10
226+
python3 bin/train.py -cn lama-fourier-celeba data.batch_size=10
220227
221228
# Infer model on thick/thin/medium masks in 256 and run evaluation
222229
# like this:
@@ -335,7 +342,7 @@ On the host machine:
335342
336343
337344
# Run training
338-
python bin/train.py -cn lama-fourier location=my_dataset data.batch_size=10
345+
python3 bin/train.py -cn lama-fourier location=my_dataset data.batch_size=10
339346
340347
# Evaluation: LaMa training procedure picks best few models according to
341348
# scores on my_dataset/val/
@@ -353,7 +360,7 @@ On the host machine:
353360
354361
# metrics calculation:
355362
python3 bin/evaluate_predicts.py \
356-
$(pwd)/configs/eval_2gpu.yaml \
363+
$(pwd)/configs/eval2_gpu.yaml \
357364
$(pwd)/my_dataset/eval/random_<size>_512/ \
358365
$(pwd)/inference/my_dataset/random_<size>_512 \
359366
$(pwd)/inference/my_dataset/random_<size>_512_metrics.csv

fetch_data/eval_sampler.py

+7-8
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,18 @@
11
import os
22
import random
33

4-
5-
val_files_path = os.path.abspath('.') + '/places_standard_dataset/original/val/'
4+
val_files_path = os.path.abspath('.') + '/places_standard_dataset/original/val/'
5+
list_of_random_val_files = os.path.abspath('.') + '/places_standard_dataset/original/eval_random_files.txt'
66
val_files = [val_files_path + image for image in os.listdir(val_files_path)]
77

8-
print(f'found {len(val_files)} images in {val_files_path}')
8+
print(f'Sampling 30000 images out of {len(val_files)} images in {val_files_path}' + \
9+
f'and put their paths to {list_of_random_val_files}')
910

10-
random.shuffle(val_files)
11-
val_files_random = val_files[0:2000]
11+
print('In our paper we evaluate trained models on these 30k sampled (mask,image) pairs in our paper (check Sup. mat.)')
1212

13-
list_of_random_val_files = os.path.abspath('.') \
14-
+ '/places_standard_dataset/original/eval_random_files.txt'
13+
random.shuffle(val_files)
14+
val_files_random = val_files[0:30000]
1515

16-
print(f'copying 2000 random images to {list_of_random_val_files}')
1716
with open(list_of_random_val_files, 'w') as fw:
1817
for filename in val_files_random:
1918
fw.write(filename+'\n')

fetch_data/places_standard_evaluation_prepare_data.sh

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ mkdir -p places_standard_dataset/evaluation/random_thick_256/
77
mkdir -p places_standard_dataset/evaluation/random_thin_256/
88
mkdir -p places_standard_dataset/evaluation/random_medium_256/
99

10-
# 1. sample 2000 new images
10+
# 1. sample 30000 new images
1111
OUT=$(python3 fetch_data/eval_sampler.py)
1212
echo ${OUT}
1313

fetch_data/sampler.py

+21-18
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,40 @@
11
import os
22
import random
33

4-
test_files_path = os.path.abspath('.') + '/places_standard_dataset/original/test/'
5-
test_files = [test_files_path + image for image in os.listdir(test_files_path)]
6-
print(f'found {len(test_files)} images in {test_files_path}')
4+
test_files_path = os.path.abspath('.') + '/places_standard_dataset/original/test/'
5+
list_of_random_test_files = os.path.abspath('.') + '/places_standard_dataset/original/test_random_files.txt'
76

8-
random.shuffle(test_files)
9-
test_files_random = test_files[0:2000]
10-
#print(test_files_random[0:10])
7+
test_files = [
8+
test_files_path + image for image in os.listdir(test_files_path)
9+
]
1110

12-
list_of_random_test_files = os.path.abspath('.') \
13-
+ '/places_standard_dataset/original/test_random_files.txt'
11+
print(f'Sampling 2000 images out of {len(test_files)} images in {test_files_path}' + \
12+
f'and put their paths to {list_of_random_test_files}')
13+
print('Our training procedure will pick best checkpoints according to metrics, computed on these images.')
1414

15-
print(f'copying 100 random images to {list_of_random_test_files}')
15+
random.shuffle(test_files)
16+
test_files_random = test_files[0:2000]
1617
with open(list_of_random_test_files, 'w') as fw:
1718
for filename in test_files_random:
1819
fw.write(filename+'\n')
1920
print('...done')
2021

21-
# ----------------------------------------------------------------------------------
2222

23+
# --------------------------------
2324

24-
val_files_path = os.path.abspath('.') + '/places_standard_dataset/original/val/'
25-
val_files = [val_files_path + image for image in os.listdir(val_files_path)]
26-
print(f'found {len(val_files)} images in {val_files_path}')
25+
val_files_path = os.path.abspath('.') + '/places_standard_dataset/original/val/'
26+
list_of_random_val_files = os.path.abspath('.') + '/places_standard_dataset/original/val_random_files.txt'
2727

28-
random.shuffle(val_files)
29-
val_files_random = val_files[0:100]
28+
val_files = [
29+
val_files_path + image for image in os.listdir(val_files_path)
30+
]
3031

31-
list_of_random_val_files = os.path.abspath('.') \
32-
+ '/places_standard_dataset/original/val_random_files.txt'
32+
print(f'Sampling 100 images out of {len(val_files)} in {val_files_path} ' + \
33+
f'and put their paths to {list_of_random_val_files}')
34+
print('We use these images for visual check up of evolution of inpainting algorithm epoch to epoch' )
3335

34-
print(f'copying 100 random images to {list_of_random_val_files}')
36+
random.shuffle(val_files)
37+
val_files_random = val_files[0:100]
3538
with open(list_of_random_val_files, 'w') as fw:
3639
for filename in val_files_random:
3740
fw.write(filename+'\n')

0 commit comments

Comments
 (0)