From 96effb16763edd7b3747dd0f12aac395f279789f Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 16 Apr 2024 17:07:10 +0200 Subject: [PATCH] fix loss --- scripts/a2c/a2c.py | 4 +++- scripts/ppo/ppo.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/scripts/a2c/a2c.py b/scripts/a2c/a2c.py index a0ddc0e7..dc689872 100644 --- a/scripts/a2c/a2c.py +++ b/scripts/a2c/a2c.py @@ -319,7 +319,9 @@ def create_env_fn(): loss = loss_module(batch) loss = loss.named_apply( lambda name, value: ( - (value * mask).mean() if name.startswith("loss_") else value + (value * mask).mean() + if name.startswith("loss_") + else value # (value * mask).sum(-1).mean(-1) if name.startswith("loss_") else value ), batch_size=[], diff --git a/scripts/ppo/ppo.py b/scripts/ppo/ppo.py index 37c641e9..7b3b47e1 100644 --- a/scripts/ppo/ppo.py +++ b/scripts/ppo/ppo.py @@ -360,7 +360,9 @@ def create_env_fn(): loss = loss_module(batch) loss = loss.named_apply( lambda name, value: ( - value * mask).mean() if name.startswith("loss_") else value + (value * mask).mean() + if name.startswith("loss_") + else value # (value * mask).sum(-1).mean(-1) if name.startswith("loss_") else value ), batch_size=[],