Skip to content

Commit

Permalink
Sandbox run src/main.py
Browse files Browse the repository at this point in the history
  • Loading branch information
sweep-nightly[bot] authored Oct 24, 2023
1 parent fb2817e commit e5641e5
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,30 @@
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


class MNISTTrainer:
def __init__(self):
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
self.transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)
self.optimizer = None
self.criterion = nn.NLLLoss()
self.epochs = 3

def load_data(self):
"""Load and preprocess MNIST data."""
trainset = datasets.MNIST('.', download=True, train=True, transform=self.transform)
trainset = datasets.MNIST(
".", download=True, train=True, transform=self.transform
)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
return trainloader

def define_model(self):
"""Define the PyTorch Model."""

class Net(nn.Module):
def __init__(self):
super().__init__()
Expand Down Expand Up @@ -54,9 +57,10 @@ def save_model(self, model):
"""Save the trained model."""
torch.save(model.state_dict(), "mnist_model.pth")


# Create an instance of MNISTTrainer and call the methods in the correct order
trainer = MNISTTrainer()
trainloader = trainer.load_data()
model = trainer.define_model()
trainer.train_model(model, trainloader)
trainer.save_model(model)
trainer.save_model(model)

0 comments on commit e5641e5

Please sign in to comment.