diff --git a/exceptions.py b/exceptions.py new file mode 100644 index 0000000..83cc6b2 --- /dev/null +++ b/exceptions.py @@ -0,0 +1,11 @@ +"Определение пользовательских исключений." + + +class MissingSourceError(Exception): + """Класс представляет ошибку, возникающую при отсутствии источника изображения.""" + pass + + +class TwoSourcesError(Exception): + """Класс представляет ошибку, возникающую при указании двух источников изображений.""" + pass diff --git a/image_classification_streamlit.py b/main.py similarity index 51% rename from image_classification_streamlit.py rename to main.py index abcf3b5..f160f52 100644 --- a/image_classification_streamlit.py +++ b/main.py @@ -1,39 +1,11 @@ -"""Приложение Streamlit для классификации изображений.""" +"Файл для запуска приложения streamlit." -import requests import streamlit as st -import translators as ts -from PIL import Image, UnidentifiedImageError +from PIL import UnidentifiedImageError from requests.exceptions import MissingSchema -from transformers import ViTForImageClassification, ViTImageProcessor - - -class MissingSourceError(Exception): - """Класс представляет ошибку, - возникающую при отсутствии - источника изображения.""" - pass - - -class TwoSourcesError(Exception): - """Класс представляет ошибку, - возникающую при указании - двух источников изображений.""" - pass - - -@st.cache_resource -def load_model(): - """Загрузка модели""" - return (ViTForImageClassification - .from_pretrained("google/vit-base-patch16-224")) - - -@st.cache_resource -def load_processor(): - """Загрузка процессора для обработки изображений.""" - return (ViTImageProcessor - .from_pretrained("google/vit-base-patch16-224")) +from utils import load_image_from_url, load_image_from_file, translate_text +from model import load_model, load_processor, image_classification +from exceptions import MissingSourceError, TwoSourcesError def get_image_link(): @@ -46,33 +18,6 @@ def get_image_file(): return st.file_uploader("Или загрузите изображение из файла") -def load_image_from_url(url): - """Загрузка изображения из указанного URL-адреса - с помощью библиотеки requests.""" - img = Image.open(requests.get(url, stream=True).raw) - return img - - -def load_image_from_file(file): - """Загрузка изображения из файла.""" - img = Image.open(file) - return img - - -def image_classification(picture): - """Обработка и распознавание изображения. - - Принимает изображение, преобразует его в требуемый формат - с помощью процессора, пропускает его через модель, - получает вероятности классов и возвращает предсказанный класс. - """ - inputs = processor(images=picture, return_tensors="pt") - outputs = model(**inputs) - logits = outputs.logits - predicted_class_idx = logits.argmax(-1).item() - return model.config.id2label[predicted_class_idx] - - def show_results(results): """Вывод результатов""" st.write(results) @@ -101,11 +46,8 @@ def show_results(results): raise MissingSourceError st.image(loaded_image) with st.spinner("Идет обработка... Пожалуйста, подождите..."): - result = image_classification(loaded_image) - translated_result = ts.translate_text(result, - translator="bing", - from_language="en", - to_language="ru") + result = image_classification(loaded_image, processor, model) + translated_result = translate_text(result, "en", "ru") st.markdown(f"Результаты распознавания: {translated_result}") except MissingSourceError: st.error( diff --git a/model.py b/model.py new file mode 100644 index 0000000..12825b4 --- /dev/null +++ b/model.py @@ -0,0 +1,22 @@ +"Загрузка и работа с моделью и процессором." + +from transformers import ViTForImageClassification, ViTImageProcessor + + +def load_model(): + """Загрузка модели.""" + return ViTForImageClassification.from_pretrained("google/vit-base-patch16-224") + + +def load_processor(): + """Загрузка процессора для обработки изображений.""" + return ViTImageProcessor.from_pretrained("google/vit-base-patch16-224") + + +def image_classification(picture, processor, model): + """Обработка и распознавание изображения.""" + inputs = processor(images=picture, return_tensors="pt") + outputs = model(**inputs) + logits = outputs.logits + predicted_class_idx = logits.argmax(-1).item() + return model.config.id2label[predicted_class_idx] diff --git a/test_app.py b/test_app.py index 23671a1..c8ba9c8 100644 --- a/test_app.py +++ b/test_app.py @@ -1,20 +1,26 @@ """Тесты для проверки приложения Streamlit.""" import time - -from streamlit.testing.v1 import AppTest -from image_classification_streamlit import image_classification -from PIL import UnidentifiedImageError -from PIL import Image +import pytest import io +from PIL import Image, UnidentifiedImageError +from streamlit.testing.v1 import AppTest +from model import image_classification, load_processor, load_model + +# Создаем объект AppTest для тестирования приложения Streamlit +at = AppTest.from_file("main.py", default_timeout=1000).run() + -at = AppTest.from_file("image_classification_streamlit.py", - default_timeout=1000).run() +@pytest.fixture(scope="module") +def processor_and_model(): + """Фикстура для загрузки процессора и модели.""" + processor = load_processor() + model = load_model() + return processor, model def test_no_image_url(): - """Проверка ввода URL-адреса на объект, - который не является изображением""" + """Проверка ввода URL-адреса на объект, который не является изображением""" at.text_input[0].set_value("https://www.google.com/").run() at.button[0].click().run() assert at.error[0].value == ( @@ -48,7 +54,7 @@ def test_correct_url(): .run() ) at.button[0].click().run() - time.sleep(5) + time.sleep(5) # Добавляем ожидание 5 секунд assert at.markdown[0].value == ( "Результаты распознавания: табби, полосатый кот" ) @@ -65,13 +71,14 @@ def test_incorrect_url(): ) -def test_correct_image_file(): +def test_correct_image_file(processor_and_model): """Проверка загрузки изображения через файл.""" + processor, model = processor_and_model with open("test_image.jpg", "rb") as file: test_image_bytes = file.read() test_image = Image.open(io.BytesIO(test_image_bytes)) try: - result = image_classification(test_image) + result = image_classification(test_image, processor, model) assert result == "Egyptian cat" except UnidentifiedImageError: assert False, "Ошибка при обработке изображения" diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..16a2920 --- /dev/null +++ b/utils.py @@ -0,0 +1,22 @@ +"Вспомогательные функции для обработки изображений и работы с моделью." + +import requests +from PIL import Image +import translators as ts + + +def load_image_from_url(url): + """Загрузка изображения из указанного URL-адреса.""" + img = Image.open(requests.get(url, stream=True).raw) + return img + + +def load_image_from_file(file): + """Загрузка изображения из файла.""" + img = Image.open(file) + return img + + +def translate_text(text, from_language, to_language): + """Перевод текста с одного языка на другой.""" + return ts.translate_text(text, translator="bing", from_language=from_language, to_language=to_language)