Skip to content

Commit d66b1e7

Browse files
committed
Update code.
1 parent 163f0f7 commit d66b1e7

File tree

1 file changed

+9
-20
lines changed

1 file changed

+9
-20
lines changed

pytorch_fid/fid_score.py

+9-20
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
"""
3434
import os
3535
import pathlib
36-
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
36+
from argparse import Namespace
3737

3838
import numpy as np
3939
import torch
@@ -51,24 +51,6 @@ def tqdm(x):
5151

5252
from pytorch_fid.inception import InceptionV3
5353

54-
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
55-
parser.add_argument('--batch-size', type=int, default=50,
56-
help='Batch size to use')
57-
parser.add_argument('--num-workers', type=int,
58-
help=('Number of processes to use for data loading. '
59-
'Defaults to `min(8, num_cpus)`'))
60-
parser.add_argument('--device', type=str, default=None,
61-
help='Device to use. Like cuda, cuda:0 or cpu')
62-
parser.add_argument('--dims', type=int, default=2048,
63-
choices=list(InceptionV3.BLOCK_INDEX_BY_DIM),
64-
help=('Dimensionality of Inception features to use. '
65-
'By default, uses pool3 features'))
66-
parser.add_argument('--save-stats', action='store_true',
67-
help=('Generate an npz archive from a directory of samples. '
68-
'The first path is used as input and the second as output.'))
69-
parser.add_argument('--path', type=str, nargs=2, default=[r"C:\Users\PC\Desktop\output", "pytorch_fid/imagenet.npz"],
70-
help=('Paths to the generated images or '
71-
'to .npz statistic files'))
7254

7355
IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm',
7456
'tif', 'tiff', 'webp'}
@@ -286,7 +268,14 @@ def save_fid_stats(paths, batch_size, device, dims, num_workers=1):
286268

287269

288270
def main(path=None, dataset_name="imagenet_compatible"):
289-
args = parser.parse_args()
271+
args = Namespace(
272+
batch_size=50,
273+
num_workers=None, # You can manually set this value if needed
274+
device=None,
275+
dims=2048,
276+
save_stats=False,
277+
path=[r"C:\Users\PC\Desktop\output", "pytorch_fid/imagenet.npz"]
278+
)
290279

291280
if path is not None:
292281
args.path[0] = path

0 commit comments

Comments
 (0)