Sweep: refactor the training loop from a script to a class in main.py #6
Comments
Here's the PR! #163.
💎 Sweep Pro: I'm using GPT-4. You have unlimited GPT-4 tickets. (tracking ID: c863ad18ba)
Sandbox Execution ✓
Here are the sandbox execution logs prior to making any changes.
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np

# Step 1: Load MNIST Data and Preprocess
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
trainset = datasets.MNIST('.', download=True, train=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)

# Step 2: Define the PyTorch Model
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        x = self.fc3(x)
        return nn.functional.log_softmax(x, dim=1)

# Step 3: Train the Model
model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.NLLLoss()

# Training loop
epochs = 3
for epoch in range(epochs):
    for images, labels in trainloader:
        optimizer.zero_grad()
        output = model(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()
Lines 1 to 27 in 7284908
from fastapi import FastAPI, UploadFile, File
from PIL import Image
import torch
from torchvision import transforms
from main import Net  # Importing Net class from main.py

# Load the model
model = Net()
model.load_state_dict(torch.load("mnist_model.pth"))
model.eval()

# Transform used for preprocessing the image
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

app = FastAPI()

@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    image = Image.open(file.file).convert("L")
    image = transform(image)
    image = image.unsqueeze(0)  # Add batch dimension
    with torch.no_grad():
        output = model(image)
        _, predicted = torch.max(output.data, 1)
Lines 1 to 15 in 7284908
annotated-types==0.6.0
anyio==3.7.1
astroid==3.0.1
certifi==2022.12.7
charset-normalizer==2.1.1
click==8.1.7
dill==0.3.7
exceptiongroup==1.1.3
fastapi==0.104.0
filelock==3.9.0
fsspec==2023.4.0
h11==0.14.0
idna==3.4
iniconfig==2.0.0
isort==5.12.0
Step 2: ⌨️ Coding
- Modify src/main.py ✓ 35f37cd
Modify src/main.py with contents:
• Create a new class named `Trainer` at the top of the file.
• Move the model, optimizer, and loss function initialization from the global scope into the `__init__` method of the `Trainer` class. The `__init__` method should take the learning rate and the path to save the model as parameters, and use them to initialize the optimizer and save the model respectively.
• Move the training loop from the global scope into a new method of the `Trainer` class named `train`. This method should take the number of epochs as a parameter.
• Move the model saving code from the global scope into a new method of the `Trainer` class named `save_model`.
• Remove the global code for model initialization, training, and saving, as it is now encapsulated in the `Trainer` class.
• At the end of the file, create an instance of the `Trainer` class, and call its `train` and `save_model` methods to train and save the model.
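The plan above can be sketched roughly as follows. This is a minimal illustration, not the contents of commit 35f37cd: the `trainloader` constructor argument is an assumption added here so the sketch runs without downloading MNIST, and the defaults (`lr=0.01`, `epochs=3`, `"mnist_model.pth"`) are taken from the snippets quoted earlier in this issue.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset


class Net(nn.Module):
    """Same three-layer MLP as in the original script."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        return nn.functional.log_softmax(self.fc3(x), dim=1)


class Trainer:
    """Holds the model, optimizer, and loss that used to live at module scope."""

    def __init__(self, lr=0.01, model_path="mnist_model.pth", trainloader=None):
        self.model = Net()
        self.optimizer = optim.SGD(self.model.parameters(), lr=lr)
        self.criterion = nn.NLLLoss()
        self.model_path = model_path
        # Assumption: the data loader is injected rather than built globally.
        self.trainloader = trainloader

    def train(self, epochs=3):
        """Run the original training loop; return the last batch loss as a float."""
        self.model.train()
        last_loss = None
        for _ in range(epochs):
            for images, labels in self.trainloader:
                self.optimizer.zero_grad()
                output = self.model(images)
                loss = self.criterion(output, labels)
                loss.backward()
                self.optimizer.step()
                last_loss = loss.item()
        return last_loss

    def save_model(self):
        """Write the model weights to the path given at construction time."""
        torch.save(self.model.state_dict(), self.model_path)
```

At the end of `main.py`, one would then build the MNIST `DataLoader` as in the original script, construct `Trainer(lr=0.01, model_path="mnist_model.pth", trainloader=trainloader)`, and call `train(epochs=3)` followed by `save_model()`.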
- Ran sandbox for src/main.py. ✓ 52cb808
Sandbox logs for https://github.com/sweepai/evals/commit/52cb808b1b8bc91787d586dd565b782319ec01c4
trunk fmt src/main.py || exit 0
1/2 ✓
✔ Formatted src/main.py
Re-checking autofixed files...
Checked 1 file
✔ No issues
Run trunk upgrade to upgrade 1 linter
trunk check --fix --filter=-ruff --print-failures src/main.py
2/2 ✓
Checked 1 file
✔ No issues
- Modify src/api.py ✓ c1333ec
Modify src/api.py with contents:
• Import the `Trainer` class from main.py at the top of the file.
• Replace the global model loading code with the creation of a `Trainer` instance and a call to its `load_model` method.
• Replace the global model usage in the `predict` function with a call to a method of the `Trainer` instance that returns the model.
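A rough sketch of the loading and inference side of this plan is shown below. It is not the contents of commit c1333ec. `load_model` is named in the plan, but its body here is an assumption; `get_model` and `predict_digit` are hypothetical names chosen for illustration. The FastAPI route and the PIL/torchvision preprocessing are left out so the example stays self-contained, and `predict_digit` takes an already preprocessed tensor.

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    """Same three-layer MLP as in main.py."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        return nn.functional.log_softmax(self.fc3(x), dim=1)


class Trainer:
    """Only the parts of Trainer that api.py would need (training omitted)."""

    def __init__(self, model_path="mnist_model.pth"):
        self.model = Net()
        self.model_path = model_path

    def load_model(self):
        """Load saved weights and switch the model to evaluation mode."""
        self.model.load_state_dict(torch.load(self.model_path))
        self.model.eval()

    def get_model(self):
        """Hypothetical accessor: the plan only says 'a method that returns the model'."""
        return self.model


def predict_digit(trainer, image):
    """Classify one preprocessed image tensor of shape (1, 28, 28)."""
    batch = image.unsqueeze(0)  # add batch dimension, as in the original endpoint
    with torch.no_grad():
        output = trainer.get_model()(batch)
    _, predicted = torch.max(output, 1)
    return int(predicted.item())
```

Inside the `/predict/` handler, the uploaded file would still be opened with PIL and passed through the existing `transform` before being handed to a function like `predict_digit`.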
- Ran sandbox for src/api.py. ✓ b31fc72
Sandbox logs for https://github.com/sweepai/evals/commit/b31fc724747dc0cd6fdae265be435aa36b06ab4a
trunk fmt src/api.py || exit 0
1/2 ✓
✔ Formatted src/api.py
Re-checking autofixed files...
✔ Formatted src/api.py
Re-checking autofixed files...
Checked 1 file
✔ No issues
Run trunk upgrade to upgrade 1 linter
trunk check --fix --filter=-ruff --print-failures src/api.py
2/2 ✓
Checked 1 file
✔ No issues
Step 3: 🔁 Code Review
I have finished reviewing the code for completeness. I did not find errors for sweep/refactor-training-loop_27.
🎉 Latest improvements to Sweep:
- Sweep now uses the rope library to refactor Python! Check out Large Language Models are Bad at Refactoring Code. To have Sweep refactor your code, try sweep: Refactor <your_file>.py!
- Sweep finds and fixes bugs in your code by writing unit tests! Check out Having GPT-4 Iterate on Unit Tests like a Human.
💡 To recreate the pull request, edit the issue title or description. To tweak the pull request, leave a comment on the pull request.
Details
No response
Checklist
- Modify src/main.py ✓ 35f37cd
- Ran sandbox for src/main.py. ✓ 52cb808
- Modify src/api.py ✓ c1333ec
- Ran sandbox for src/api.py. ✓ b31fc72