diff --git a/lib/model.py b/lib/model.py index 8a25f46..d7b2b2e 100644 --- a/lib/model.py +++ b/lib/model.py @@ -315,7 +315,7 @@ def forward_d(self): def backward_g(self): """ Backpropagate through netG """ - self.err_g_adv = self.l_adv(self.feat_fake, self.feat_real) + self.err_g_adv = self.l_adv(self.netd(self.input)[1], self.netd(self.fake)[1]) self.err_g_con = self.l_con(self.fake, self.input) self.err_g_enc = self.l_enc(self.latent_o, self.latent_i) self.err_g = self.err_g_adv * self.opt.w_adv + \