diff --git a/train.py b/train.py index 6ef4a61..70595e0 100644 --- a/train.py +++ b/train.py @@ -419,7 +419,7 @@ def main(): d0 = IMFPatchDiscriminator(ndf=config.discriminator.ndf) d1 = MultiScalePatchDiscriminator(input_nc=3, ndf=64, n_layers=3, num_D=3) - discriminator = d1 if config.training.use_multiscale_discriminator else do + discriminator = d1 if config.training.use_multiscale_discriminator else d0 add_gradient_hooks(discriminator)