-
Notifications
You must be signed in to change notification settings - Fork 0
/
streamlit_ui.py
150 lines (117 loc) · 4.76 KB
/
streamlit_ui.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import base64
import logging
import os
import re
from io import BytesIO
from urllib.parse import urlparse
import sqlite3
import numpy as np
import requests
import streamlit as st
from dotenv import load_dotenv
from PIL import Image
from ml_scripts.MongoDBDataset import get_vocab, save_vocab
import polars as pl
# Verbose logging for the whole app (configured on the root logger).
logging.getLogger().setLevel(logging.DEBUG)
# Load variables from a local .env file, letting them override values already
# present in the process environment. Done before reading any settings below.
load_dotenv(override=True)
# Address of the TensorFlow Serving container's REST predict endpoint.
# Can be overridden via the TF_SERVING_URL environment variable (or .env);
# defaults to a local container listening on port 8510.
TF_SERVING_URL = os.getenv(
    "TF_SERVING_URL", "http://localhost:8510/v1/models/flower:predict"
)
def submit():
    """Handle the ``on_change`` event of the URL text box.

    Copies the URL that was just entered into ``image_url`` in session state
    and then empties the text box so another URL can be typed right away.
    """
    state = st.session_state
    state["image_url"] = state["url_widget"]
    state["url_widget"] = ""
@st.cache_data
def process_image(image_url: str, timeout: float = 10.0) -> np.ndarray:
    """Load an image from a URL and convert it to a numpy array.

    Regular URLs are downloaded with ``requests``; ``data:`` URLs carrying a
    base64 payload (``data:image/<type>;base64,...``) are decoded in place.

    Args:
        image_url: The URL of the image to convert, either a regular URL or a
            base64 ``data:image/...`` URL.
        timeout: Seconds to wait for the remote server when downloading a
            regular URL before giving up.

    Returns:
        np.ndarray: The decoded image as produced by ``np.array`` on the PIL
        image. No resizing or colour-mode conversion is applied.

    Raises:
        requests.HTTPError: If downloading a regular URL returns an HTTP
            error status.
        requests.Timeout: If the remote server does not respond in time.
    """
    logging.info("Loading image and creating array.")
    parsed_url = urlparse(image_url)
    if parsed_url.scheme == "data":
        # Strip the "data:image/<type>;base64," prefix, keep only the payload.
        image_data = re.sub(r"^data:image/.+;base64,", "", image_url)
        image_bytes = base64.b64decode(image_data)
    else:
        # Without a timeout a stalled server would block the script forever.
        # raise_for_status() turns e.g. a 404 HTML page into a clear HTTP
        # error instead of an opaque "cannot identify image file" from PIL.
        response = requests.get(image_url, timeout=timeout)
        response.raise_for_status()
        image_bytes = response.content
    # Close the PIL image once its pixels have been copied into the array.
    with Image.open(BytesIO(image_bytes)) as image:
        # NOTE(review): RGBA or grayscale inputs yield a different channel
        # count than RGB; confirm what the served model expects.
        image_array = np.array(image)
    return image_array
# Function to make a prediction using TensorFlow Serving
@st.cache_data
def make_prediction(
    image_array: np.ndarray, category_flower_mapping: dict, timeout: float = 30.0
) -> str:
    """Classify an image by calling the TensorFlow Serving REST endpoint.

    Args:
        image_array: A numpy array of the image to classify. It is sent as
            the single instance of the request.
        category_flower_mapping: A dictionary mapping the predicted category
            index to the flower type (presumably int keys; confirm against
            ``get_vocab(..., reverse=True)``).
        timeout: Seconds to wait for TensorFlow Serving to answer.

    Returns:
        str: The predicted class of the image, i.e. the mapping value for the
        index with the highest score.

    Raises:
        RuntimeError: If TensorFlow Serving answers with an HTTP error status;
            the message includes the server's response body.
        requests.Timeout: If TensorFlow Serving does not respond in time.
    """
    logging.info("making prediction")
    payload = {"instances": [image_array.tolist()]}
    response = requests.post(TF_SERVING_URL, json=payload, timeout=timeout)
    if not response.ok:
        # TF Serving reports failures as {"error": "..."}; surface that text
        # instead of a bare KeyError on the missing "predictions" key.
        raise RuntimeError(
            f"TensorFlow Serving request failed ({response.status_code}): "
            f"{response.text}"
        )
    predictions = response.json()["predictions"]
    # One instance was sent, so predictions[0] holds its per-class scores.
    class_index = int(np.argmax(predictions[0]))
    predicted_class = category_flower_mapping[class_index]
    return predicted_class
def check_if_correct(conn: sqlite3.Connection, predicted_class: str) -> None:
    """Show the prediction and record the user's feedback in the database.

    Renders a form with the image at ``st.session_state.image_url``, captioned
    with the predicted class, plus a checkbox asking whether the prediction is
    correct. On submit, a row is written to the ``feedback`` table and a thank
    you message is shown.

    Args:
        conn: A connection to the SQLite database.
        predicted_class: The predicted class of the image.
    """
    image_url = st.session_state.image_url
    # display result
    with st.form("Results"):
        st.image(image_url, caption=f"Predicted class: {predicted_class}")
        correct_prediction = st.checkbox(
            "Is this correct? Then please check this box.", value=False
        )
        submitted = st.form_submit_button("Submit")
        if submitted:
            # INSERT OR REPLACE: a row that conflicts on a UNIQUE/PRIMARY KEY
            # constraint (presumably image_url; confirm in feedback_schema.sql)
            # is overwritten with the latest feedback.
            insert_statement = """
            INSERT OR REPLACE INTO feedback (image_url, predicted_class, correct_prediction)
            VALUES (?, ?, ?)
            """
            # `with conn` commits on success and rolls back if the write fails.
            with conn:
                conn.execute(
                    insert_statement,
                    (image_url, predicted_class, correct_prediction),
                )
            # Thank the user only once the feedback has actually been stored.
            st.write("Thank you for your feedback! Here's a flower for you 🌸")
def setup():
    """Initialise per-session state: vocabulary, database and feedback table.

    Steps:
        1. Call ``save_vocab()`` if ``flower.db`` does not exist yet.
        2. Load the category -> flower mapping into session state.
        3. Open a SQLite connection and run ``feedback_schema.sql`` against it.
        4. Initialise ``image_url`` if it is missing.
        5. Set the ``setup`` flag so this function runs only once per session.
    """
    db_path = "flower.db"
    if not os.path.exists(db_path):
        # NOTE(review): presumably save_vocab() creates flower.db; confirm.
        save_vocab()
    st.session_state.category_flower_mapping = get_vocab(db_path, reverse=True)
    # The connection lives in session state and is reused across reruns, which
    # presumably run on different threads; hence check_same_thread=False.
    st.session_state.conn = sqlite3.connect(db_path, check_same_thread=False)
    # load schema from feedback_schema.sql file
    with open("feedback_schema.sql", encoding="utf-8") as feedback_schema_file:
        schema_sql = feedback_schema_file.read()
    # executescript() also accepts a schema file with several statements,
    # where cursor.execute() raises "You can only execute one statement at a time".
    st.session_state.conn.executescript(schema_sql)
    st.session_state.conn.commit()
    if "image_url" not in st.session_state:
        st.session_state.image_url = ""
    # set flag that setup is not needed anymore
    st.session_state.setup = True
def main():
    """Render the Flower Classifier page and run the classify/feedback flow.

    Shows a URL input box. Once a URL has been submitted, the image is
    classified and the feedback form for that prediction is displayed.
    """
    st.title("Flower Classifier")
    st.write("This app uses a TensorFlow model to classify images of flowers.")
    # One-time, per-session initialisation.
    if "setup" not in st.session_state:
        setup()
    st.text_input(
        "Enter the URL of an image:",
        key="url_widget",
        value="",
        on_change=submit,
    )
    state = st.session_state
    image_url = state.image_url
    if not image_url:
        # No URL submitted yet: only the input box is shown.
        return
    predicted_class = make_prediction(
        process_image(image_url), state.category_flower_mapping
    )
    check_if_correct(state.conn, predicted_class)


if __name__ == "__main__":
    main()