Skip to content

Commit

Permalink
Update main.py
Browse files Browse the repository at this point in the history
  • Loading branch information
braun-steven committed Jun 12, 2024
1 parent 77dfb46 commit 0213fe0
Showing 1 changed file with 12 additions and 9 deletions.
21 changes: 12 additions & 9 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def train(args, model: Union[Einet, EinetMixture], device, train_loader, optimiz
epoch,
batch_idx * len(data),
len(train_loader.dataset),
loss.item() / args.batch_size,
loss.item(),
acc_term,
)
)
Expand Down Expand Up @@ -122,10 +122,6 @@ def test(model, device, loader, tag):
test_loss += -1 * lls.sum()
test_losses += lls.squeeze().cpu().tolist()

# Else compute negative log likelihoods
test_loss += -1 * outputs.sum()
test_losses += outputs.squeeze().cpu().tolist()

if args.classification:
_, predicted = outputs.max(1)
total += target.size(0)
Expand Down Expand Up @@ -161,6 +157,9 @@ def test(model, device, loader, tag):
elif args.dist == "normal":
leaf_type = Normal
leaf_kwargs = {}
elif args.dist == "normal_rat":
leaf_type = RatNormal
leaf_kwargs = {"min_sigma": args.min_sigma, "max_sigma": args.max_sigma}
elif args.dist == "categorical":
leaf_type = Categorical
leaf_kwargs = {"num_bins": n_bins}
Expand All @@ -178,7 +177,7 @@ def test(model, device, loader, tag):
num_classes=num_classes,
leaf_type=leaf_type,
leaf_kwargs=leaf_kwargs,
layer_type="linsum",
layer_type=args.layer,
dropout=0.0,
)

Expand Down Expand Up @@ -218,15 +217,19 @@ def test(model, device, loader, tag):
if args.train:
for epoch in range(1, args.epochs + 1):
train(args, model, device, train_loader, optimizer, epoch)
lr_scheduler.step()
# lr_scheduler.step()

torch.save(model.state_dict(), os.path.join(result_dir, "model.pth"))
test(model, device, train_loader, "Train")
test(model, device, val_loader, "Val")
test(model, device, test_loader, "Test")

else:
model.load_state_dict(torch.load(os.path.join(result_dir, "model.pth")))

# test(model, device, test_loader, "Train")
# test(model, device, test_loader, "Test")
test(model, device, train_loader, "Train")
test(model, device, val_loader, "Val")
test(model, device, test_loader, "Test")

# Don't sample when doing classification
if not args.classification:
Expand Down

0 comments on commit 0213fe0

Please sign in to comment.