Skip to content

Commit

Permalink
Merge pull request #3 from Kawaeee/streamlit-optimization
Browse files Browse the repository at this point in the history
Streamlit optimization
  • Loading branch information
Kawaeee authored May 1, 2021
2 parents a51d573 + e9af098 commit ba9d4fb
Showing 1 changed file with 20 additions and 14 deletions.
34 changes: 20 additions & 14 deletions streamlit_app.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import streamlit as st
from streamlit.logger import get_logger

import time
import os
Expand All @@ -13,6 +14,10 @@
from torch.nn import functional as F
from torchvision import models, transforms

st_logger = get_logger(__name__)

st.set_option("deprecation.showfileUploaderEncoding", False)

st.set_page_config(
layout="centered",
page_title="Corgi butt or loaf of bread?",
Expand Down Expand Up @@ -89,7 +94,7 @@
)


@st.cache()
@st.cache(allow_output_mutation=True, max_entries=5, ttl=3600)
def initialize_model(device=processing_device):
"""Retrieves the butt_bread trained model and maps it to the CPU by default, can also specify GPU here."""
model = models.resnet152(pretrained=False).to(device)
Expand All @@ -104,8 +109,6 @@ def initialize_model(device=processing_device):

return model


@st.cache()
def predict(img, model):
"""Make a prediction on a single image"""
input_img = img_transformer(img).float()
Expand Down Expand Up @@ -133,15 +136,13 @@ def predict(img, model):

return json_output


@st.cache(suppress_st_warning=True)
def download_model():
"""Download model weight, if model does not exist in Streamlit server."""
if os.path.isfile("buttbread_resnet152_3.h5") == False:
print("Downloading butt_bread model !!")
req = requests.get(model_url_path, allow_redirects=True)
open("buttbread_resnet152_3.h5", "wb").write(req.content)
st.balloons()
return True

return True

Expand All @@ -153,21 +154,23 @@ def download_model():

download_model()
model = initialize_model()

st_logger.info("[INFO] Initialize %s model successfully", "buttbread_resnet152_3.h5", exc_info=0)

st.title("Corgi butt or loaf of bread? 🐕🍞")
st.markdown(version + " " + repo + " " + visitor + " " + follower, unsafe_allow_html=True)

upload_checkbox = st.checkbox("Upload")
processing_mode = st.radio("", ("Upload an image", "Select pre-configured image"))

if upload_checkbox:
processing_mode = "Upload"
img_file = st.file_uploader("Upload An Image", accept_multiple_files=False)
else:
processing_mode = "Select"
if processing_mode == "Upload an image":
img_file = st.file_uploader("Upload an image", accept_multiple_files=False)
elif processing_mode == "Select pre-configured image":
img_labels = st.selectbox("Pick a labels:", labels)

if img_labels == labels[0]:
corgi_list = st.selectbox("Pick your favorite corgi butt image 🐕:", corgi_images_name)
img_file = corgi_images_dict[corgi_list]

elif img_labels == labels[1]:
bread_list = st.selectbox("Pick your favorite loaf of bread image 🍞:", bread_images_name)
img_file = bread_images_dict[bread_list]
Expand All @@ -180,15 +183,18 @@ def download_model():
tmp_format = img.format
img = img.convert("RGB")
img.format = tmp_format
if processing_mode == "Upload":
if processing_mode == "Upload an image":
img.filename = img_file.name
else:
elif processing_mode == "Select pre-configured image":
img.filename = os.path.basename(img_file)

prediction = predict(img, model)

st_logger.info("[INFO] Predict %s image successfully", img.filename, exc_info=0)

except Exception as e:
st.error("ERROR: Unable to predict {} ({}) !!!".format(img_file.name, img_file.type))
st_logger.error("[ERROR] Unable to predict %s (%s) !!!", img_file.name, img_file.type, exc_info=0)
img_file = None
img = None
prediction = None
Expand Down

0 comments on commit ba9d4fb

Please sign in to comment.