diff --git a/predict.py b/predict.py index 46ff358..baefd51 100644 --- a/predict.py +++ b/predict.py @@ -8,7 +8,7 @@ from config import main_config from models import gmcnn_gan -from utils import training_utils +from utils import training_utils, constants log = training_utils.get_logger() @@ -49,6 +49,9 @@ def main(): parser.add_argument('--save_to', default='predicted.jpg', help='The save path of predicted image') + parser.add_argument('--exp_name', + requiered=True, + help='name of the experiment') args = parser.parse_args() @@ -59,7 +62,7 @@ def main(): img_width=config.training.img_width, num_channels=config.training.num_channels, warm_up_generator=False, - config=config) + config=config,output_paths=constants.OutputPaths(args.exp_name)) log.info('Loading GMCNN model...') gmcnn_model.load() log.info('GMCNN model successfully loaded.')