Skip to content

Commit

Permalink
Update Model.py
Browse files — browse the repository at this point in the history
  • Branch information
seq-to-mind authored Sep 25, 2023
1 parent 79ea10b commit 09ef49a
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions Model.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,10 +259,12 @@ def forward(self, batch, eval_mode=False):
""" Back Translation Training """
""" here the reward should have the dim (batch_size, ) """
with torch.no_grad():
# cyclic_advantage = self.agent.cyclic_generation(batch_target_style_list, trans_sen_text_sampling, batch_sample_input_text) - \
# self.agent.cyclic_generation(batch_target_style_list, trans_sen_text_greedy, batch_sample_input_text)

cyclic_advantage = self.agent.cyclic_generation(batch_target_style_list, trans_sen_text_sampling, batch_sample_input_text)
# calculate self-critical advantage
cyclic_advantage = self.agent.cyclic_generation(batch_target_style_list, trans_sen_text_sampling, batch_sample_input_text) - \
self.agent.cyclic_generation(batch_target_style_list, trans_sen_text_greedy, batch_sample_input_text)
cyclic_advantage = torch.clamp(cyclic_advantage, min=0.0) # optional: add the clamp operation when advantage < 0

# cyclic_advantage = self.agent.cyclic_generation(batch_target_style_list, trans_sen_text_sampling, batch_sample_input_text)

""" Style Discriminator Training """
if trans_sen_text_greedy is not None:
Expand Down

0 comments on commit 09ef49a

Please sign in to comment.