From 0213fe0397e303a78dfa6941312bb580adadf793 Mon Sep 17 00:00:00 2001 From: Steven Braun Date: Wed, 12 Jun 2024 11:42:05 +0200 Subject: [PATCH] Update main.py --- main.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/main.py b/main.py index 8d7f1b3..703bfbf 100644 --- a/main.py +++ b/main.py @@ -89,7 +89,7 @@ def train(args, model: Union[Einet, EinetMixture], device, train_loader, optimiz epoch, batch_idx * len(data), len(train_loader.dataset), - loss.item() / args.batch_size, + loss.item(), acc_term, ) ) @@ -122,10 +122,6 @@ def test(model, device, loader, tag): test_loss += -1 * lls.sum() test_losses += lls.squeeze().cpu().tolist() - # Else compute negative log likelihoods - test_loss += -1 * outputs.sum() - test_losses += outputs.squeeze().cpu().tolist() - if args.classification: _, predicted = outputs.max(1) total += target.size(0) @@ -161,6 +157,9 @@ def test(model, device, loader, tag): elif args.dist == "normal": leaf_type = Normal leaf_kwargs = {} + elif args.dist == "normal_rat": + leaf_type = RatNormal + leaf_kwargs = {"min_sigma": args.min_sigma, "max_sigma": args.max_sigma} elif args.dist == "categorical": leaf_type = Categorical leaf_kwargs = {"num_bins": n_bins} @@ -178,7 +177,7 @@ def test(model, device, loader, tag): num_classes=num_classes, leaf_type=leaf_type, leaf_kwargs=leaf_kwargs, - layer_type="linsum", + layer_type=args.layer, dropout=0.0, ) @@ -218,15 +217,19 @@ def test(model, device, loader, tag): if args.train: for epoch in range(1, args.epochs + 1): train(args, model, device, train_loader, optimizer, epoch) - lr_scheduler.step() + # lr_scheduler.step() torch.save(model.state_dict(), os.path.join(result_dir, "model.pth")) + test(model, device, train_loader, "Train") + test(model, device, val_loader, "Val") + test(model, device, test_loader, "Test") else: model.load_state_dict(torch.load(os.path.join(result_dir, "model.pth"))) - # test(model, device, test_loader, "Train") - # test(model, device, test_loader, "Test") + test(model, device, train_loader, "Train") + test(model, device, val_loader, "Val") + test(model, device, test_loader, "Test") # Don't sample when doing classification if not args.classification: