-
Notifications
You must be signed in to change notification settings - Fork 0
/
app.py
65 lines (55 loc) · 2.12 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import streamlit as st
from torchvision import transforms
import torch
from PIL import Image
import torch.nn as nn
import math
import io
from models import Generator
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Define the image transform
transform = transforms.Compose([
transforms.ToTensor()
])
# Define the Streamlit app
def app():
st.title("Super Resolution App")
st.write("Upload an image and generate a super resolution image.")
# Upload an image
file = st.file_uploader("Choose an image file", type=["jpg", "jpeg", "png"])
model_choice = st.selectbox("Select a model", ["SRGAN 1000", "SRGAN 50"])
if model_choice == "SRGAN 50":
model_path = "netG3_epoch50.pt"
elif model_choice == "SRGAN 1000":
model_path = "netG2_epoch1000.pt"
netG = Generator(4)
netG.load_state_dict(torch.load('models/' + model_path, map_location=device))
netG.to(device)
# When an image is uploaded
if file is not None:
# Read the uploaded image
img = Image.open(file).convert('RGB')
img = img.resize((img.size[0]//2, img.size[1]//2))
# Display the original image
st.image(img, caption="Original Image", use_column_width=True)
# Preprocess the image
img = transform(img).unsqueeze(0)
img = img.to(device)
# Generate the super resolution image
with torch.no_grad():
sr_img = netG(img)
# Convert the output tensor to a PIL Image and save the result
sr_img = transforms.ToPILImage()(sr_img[0].cpu())
sr_img_bytes = io.BytesIO()
sr_img.save(sr_img_bytes, format='JPEG')
# Display the super resolution image
st.image(sr_img, caption="Super Resolution Image", use_column_width=True)
# Add a download button to download the super resolution image
st.download_button(
label="Download Super Resolution Image",
data=sr_img_bytes.getvalue(),
file_name="super_resolution_image.jpg",
mime="image/jpeg"
)
if __name__ == '__main__':
app()