Skip to content

Commit

Permalink
some linting things
Browse files Browse the repository at this point in the history
  • Loading branch information
FelixBenning committed Mar 12, 2024
1 parent 7ef70df commit 243e740
Showing 1 changed file with 20 additions and 17 deletions.
37 changes: 20 additions & 17 deletions mnistSimpleCNN/train.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# imports -------------------------------------------------------------------------#
from ctypes import ArgumentError
import os
import numpy as np
import torch
Expand Down Expand Up @@ -78,11 +79,13 @@ def run(OptClass=lr_scheduled_optimizer(optim.Adam), p_seed=0, p_epochs=10, p_ke

# model selection -------------------------------------------------------------#
if KERNEL_SIZE == 3:
model = ModelM3().to(device)
model:torch.nn.Module = ModelM3().to(device)
elif KERNEL_SIZE == 5:
model = ModelM5().to(device)
model:torch.nn.Module = ModelM5().to(device)
elif KERNEL_SIZE == 7:
model = ModelM7().to(device)
model:torch.nn.Module= ModelM7().to(device)
else:
raise ArgumentError(f"Kernel Size {KERNEL_SIZE} not suppoted")

summary(model, (1, 28, 28))

Expand Down Expand Up @@ -114,9 +117,10 @@ def run(OptClass=lr_scheduled_optimizer(optim.Adam), p_seed=0, p_epochs=10, p_ke
)
for data, target in pbar:
data, target = data.to(device), target.to(device, dtype=torch.int64)
# pylint: disable=cell-var-from-loop
def loss_closure():
optimizer.zero_grad()
output = model(data)
output = model(data) # pylint: disable=not-callable
loss = F.nll_loss(output, target)
train_pred = output.argmax(dim=1, keepdim=True)

Expand Down Expand Up @@ -147,7 +151,7 @@ def loss_closure():
with torch.no_grad():
for data, target in test_pg:
data, target = data.to(device), target.to(device, dtype=torch.int64)
output = model(data)
output = model(data) # pylint: disable=not-callable
test_loss += F.nll_loss(output, target, reduction="sum").item()
pred = output.argmax(dim=1, keepdim=True)
total_pred = np.append(total_pred, pred.cpu().numpy())
Expand All @@ -172,19 +176,18 @@ def loss_closure():
"Accuracy": best_test_accuracy
})

f = open(OUTPUT_FILE, "a")
f.write(
" %3d %12.6f %9.3f %12.6f %9.3f %9.3f\n"
% (
epoch,
train_loss,
train_accuracy,
test_loss,
test_accuracy,
best_test_accuracy,
with open(OUTPUT_FILE, "a") as f:
f.write(
" %3d %12.6f %9.3f %12.6f %9.3f %9.3f\n"
% (
epoch,
train_loss,
train_accuracy,
test_loss,
test_accuracy,
best_test_accuracy,
)
)
)
f.close()

# --------------------------------------------------------------------------#
# update learning rate scheduler #
Expand Down

0 comments on commit 243e740

Please sign in to comment.