From 3b05496e1edc780d69e144fa4249d4285d539b2f Mon Sep 17 00:00:00 2001 From: Alan Ng <15185920+alanngnet@users.noreply.github.com> Date: Sat, 2 Nov 2024 20:16:02 -0500 Subject: [PATCH] more ce > foc renaming see commit 4f87bad --- tests/test_trainer.py | 4 ++-- tools/train_tune.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 0fc476d..5da73f5 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -106,11 +106,11 @@ def _test_logging(self, device): expected_calls = [ # 1 because the step is logged for csi unittest.mock.call("csi/lr", unittest.mock.ANY, 1), - unittest.mock.call("csi/ce_loss", unittest.mock.ANY, 1), + unittest.mock.call("csi/foc_loss", unittest.mock.ANY, 1), unittest.mock.call("csi/tri_loss", unittest.mock.ANY, 1), unittest.mock.call("csi/total", unittest.mock.ANY, 1), # 0 because the epoch is logged for the rest - unittest.mock.call("csi_test/ce_loss", unittest.mock.ANY, 0), + unittest.mock.call("csi_test/foc_loss", unittest.mock.ANY, 0), unittest.mock.call("csi_test/tri_loss", unittest.mock.ANY, 0), unittest.mock.call("mAP/covers80", unittest.mock.ANY, 0), unittest.mock.call("hit_rate/covers80", unittest.mock.ANY, 0), diff --git a/tools/train_tune.py b/tools/train_tune.py index 8c8b3c7..09ce3da 100644 --- a/tools/train_tune.py +++ b/tools/train_tune.py @@ -118,7 +118,7 @@ def get_final_metrics_from_logs(log_dir, test_name): ea.Reload() # Extract the final validation loss and mAP - val_loss = ea.Scalars("csi_val/ce_loss")[-1].value + val_loss = ea.Scalars("csi_val/foc_loss")[-1].value mAP = ea.Scalars(f"mAP/{test_name}")[-1].value print(f"Final loss {val_loss}, mAP {mAP}")