From 375f6d6936db353ec07b3eab4c1f7ecea20bb644 Mon Sep 17 00:00:00 2001 From: "sweep-nightly[bot]" <131841235+sweep-nightly[bot]@users.noreply.github.com> Date: Sun, 26 Nov 2023 01:22:15 +0000 Subject: [PATCH] feat: Updated src/main.py --- src/main.py | 50 +++++++++++++++++++++++--------------------------- 1 file changed, 23 insertions(+), 27 deletions(-) diff --git a/src/main.py b/src/main.py index 243a31e..c3b5743 100644 --- a/src/main.py +++ b/src/main.py @@ -1,48 +1,44 @@ -from PIL import Image +import numpy as np import torch import torch.nn as nn import torch.optim as optim -from torchvision import datasets, transforms +from cnn import CNN, train +from PIL import Image from torch.utils.data import DataLoader -import numpy as np +from torchvision import datasets, transforms # Step 1: Load MNIST Data and Preprocess transform = transforms.Compose([ - transforms.ToTensor(), +transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,)) ]) trainset = datasets.MNIST('.', download=True, train=True, transform=transform) trainloader = DataLoader(trainset, batch_size=64, shuffle=True) -# Step 2: Define the PyTorch Model -class Net(nn.Module): - def __init__(self): - super().__init__() - self.fc1 = nn.Linear(28 * 28, 128) - self.fc2 = nn.Linear(128, 64) - self.fc3 = nn.Linear(64, 10) + + + + + + + - def forward(self, x): - x = x.view(-1, 28 * 28) - x = nn.functional.relu(self.fc1(x)) - x = nn.functional.relu(self.fc2(x)) - x = self.fc3(x) - return nn.functional.log_softmax(x, dim=1) - -# Step 3: Train the Model + + + + + + + + model = Net() +model = CNN() optimizer = optim.SGD(model.parameters(), lr=0.01) criterion = nn.NLLLoss() # Training loop -epochs = 3 -for epoch in range(epochs): - for images, labels in trainloader: - optimizer.zero_grad() - output = model(images) - loss = criterion(output, labels) - loss.backward() - optimizer.step() +train(model, trainloader, optimizer) + torch.save(model.state_dict(), "mnist_model.pth") \ No newline at end of file