diff --git a/scripts/reinforce/config_scaffold.yaml b/scripts/reinforce/config_scaffold.yaml index d16e9b33..25a5182e 100644 --- a/scripts/reinforce/config_scaffold.yaml +++ b/scripts/reinforce/config_scaffold.yaml @@ -29,4 +29,3 @@ model: gru # gru, lstm, or gpt2 lr: 0.0001 eps: 1.0e-08 weight_decay: 0.0 - diff --git a/scripts/reinforce/reinforce.py b/scripts/reinforce/reinforce.py index dda05f06..53ce86d0 100644 --- a/scripts/reinforce/reinforce.py +++ b/scripts/reinforce/reinforce.py @@ -290,7 +290,7 @@ def compute_loss(data, model): agent_log_prob = get_log_prob(data, model) agent_likelihood = (agent_log_prob * mask).sum(-1) reward = data.get(("next", "reward")).squeeze(-1).sum(-1) - loss = - agent_likelihood * reward + loss = -agent_likelihood * reward return data, loss