diff --git a/image_classification_streamlit.py b/image_classification_streamlit.py index d7d6d23..07cce2f 100644 --- a/image_classification_streamlit.py +++ b/image_classification_streamlit.py @@ -2,10 +2,26 @@ import requests import streamlit as st -from PIL import Image +import translators as ts +from PIL import Image, UnidentifiedImageError +from requests.exceptions import MissingSchema from transformers import ViTForImageClassification, ViTImageProcessor +class MissingSourceError(Exception): + """Класс представляет ошибку, + возникающую при отсутствии + источника изображения.""" + pass + + +class TwoSourcesError(Exception): + """Класс представляет ошибку, + возникающую при указании + двух источников изображений.""" + pass + + @st.cache_resource def load_model(): """Загрузка модели""" @@ -25,11 +41,21 @@ def get_image_link(): return st.text_input("Введите ссылку на изображение для распознавания") -def load_image(url): +def get_image_file(): + """Загрузка файла с изображением.""" + return st.file_uploader("Или загрузите изображение из файла") + + +def load_image_from_url(url): """Загрузка изображения из указанного URL-адреса с помощью библиотеки requests.""" img = Image.open(requests.get(url, stream=True).raw) - st.image(img) + return img + + +def load_image_from_file(file): + """Загрузка изображения из файла.""" + img = Image.open(file) return img @@ -57,20 +83,58 @@ def show_results(results): st.title("Модель для классификации изображений vit-base-patch16-224") -link = get_image_link() +image_link = get_image_link() +image_file = get_image_file() result = st.button("Распознать изображение") if result: try: - loaded_image = load_image(link) + loaded_image = "" + if image_link != "" and image_file is not None: + raise TwoSourcesError + elif image_link != "": + loaded_image = load_image_from_url(image_link) + elif image_file is not None: + loaded_image = load_image_from_file(image_file) + else: + raise MissingSourceError + st.image(loaded_image) with st.spinner("Идет обработка... Пожалуйста, подождите..."): - result = image_classification(loaded_image) - st.markdown(f"Результаты распознавания: :rainbow[{result}]") - st.snow() - except IOError: + result = image_classification(loaded_image).split(',')[0] + st.markdown( + f"Результаты распознавания: " + f":rainbow[{ts.translate_text(result, + translator="bing", + from_language="en", + to_language="ru")}]" + ) + except MissingSourceError: + st.error( + "Вы не предоставили источник " + "для загрузки изображения. " + "Загрузите файл с изображением или укажите ссылку " + "и попробуйте снова!", + icon="😞", + ) + except MissingSchema: + st.error( + "Некорректная ссылка! " + "Укажите корректную ссылку " + "и попробуйте снова!", + icon="😞", + ) + except UnidentifiedImageError: + st.error( + "Ваша ссылка или файл не содержат изображения. " + "Предоставьте корректную ссылку или файл " + "и попробуйте снова!", + icon="😞", + ) + except TwoSourcesError: st.error( - "Не удалось найти изображение по указанной ссылке. " - "Попробуйте снова!", + "Вы указали два источника. " + "Удалите один из источников " + "и попробуйте снова!", icon="😞", ) diff --git a/requirements.txt b/requirements.txt index 527bb23..a9ce732 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,4 +4,5 @@ transformers[torch] requests Pillow pydantic -streamlit \ No newline at end of file +streamlit +translators \ No newline at end of file