Skip to content

Commit

Permalink
Sandbox run src/main.py
Browse files Browse the repository at this point in the history
  • Loading branch information
sweep-nightly[bot] authored Nov 25, 2023
1 parent a09133f commit 46ccf85
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,35 +2,37 @@
import torch
import torch.nn as nn
import torch.optim as optim
from cnn import CNN
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from cnn import CNN

# Step 1: Load MNIST Data and Preprocess
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

trainset = datasets.MNIST('.', download=True, train=True, transform=transform)
trainset = datasets.MNIST(".", download=True, train=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)


# Step 2: Define the PyTorch Model
class Net(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(28 * 28, 128)
self.fc2 = nn.Linear(128, 64)
self.fc3 = nn.Linear(64, 10)

def forward(self, x):
x = x.view(-1, 28 * 28)
x = nn.functional.relu(self.fc1(x))
x = nn.functional.relu(self.fc2(x))
x = self.fc3(x)
return nn.functional.log_softmax(x, dim=1)


# Step 3: Train the Model
model = CNN()
optimizer = optim.SGD(model.parameters(), lr=0.01)
Expand All @@ -46,4 +48,4 @@ def forward(self, x):
loss.backward()
optimizer.step()

torch.save(model.state_dict(), "mnist_model.pth")
torch.save(model.state_dict(), "mnist_model.pth")

0 comments on commit 46ccf85

Please sign in to comment.