diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..287a2f0 --- /dev/null +++ b/.gitignore @@ -0,0 +1,162 @@ +### Python template +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + diff --git a/image_classification.py b/image_classification.py index 833e7bc..5d09e6b 100644 --- a/image_classification.py +++ b/image_classification.py @@ -1,16 +1,23 @@ -from transformers import ViTImageProcessor, ViTForImageClassification -from PIL import Image +"""Использование готовой модели для классификации изображений.""" + import requests +from PIL import Image +from transformers import ViTForImageClassification, ViTImageProcessor -url = 'https://img.freepik.com/premium-photo/a-house-on-a-mountain-with-a-mountain-in-the-background_759095-3394.jpg' +url = ( + "https://img.freepik.com/premium-photo/" + "a-house-on-a-mountain-with-a-mountain-" + "in-the-background_759095-3394.jpg" +) image = Image.open(requests.get(url, stream=True).raw) -processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224') -model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224') +processor = (ViTImageProcessor + .from_pretrained("google/vit-base-patch16-224")) +model = (ViTForImageClassification + .from_pretrained("google/vit-base-patch16-224")) inputs = processor(images=image, return_tensors="pt") outputs = model(**inputs) logits = outputs.logits -# model predicts one of the 1000 ImageNet classes predicted_class_idx = logits.argmax(-1).item() print("Predicted class:", model.config.id2label[predicted_class_idx]) diff --git a/image_classification_fastapi.py b/image_classification_fastapi.py index b95e49a..ac5e51b 100644 --- a/image_classification_fastapi.py +++ b/image_classification_fastapi.py @@ -1,29 +1,42 @@ -from PIL import Image +"""Модель классификации изображений с FastAPI.""" + import requests from fastapi import FastAPI +from PIL import Image from pydantic import BaseModel -from transformers import ViTImageProcessor, ViTForImageClassification +from transformers import ViTForImageClassification, ViTImageProcessor class ImageRequest(BaseModel): + """Класс запроса для обработки изображения.""" + url: str app = FastAPI() + # Процессор для представления изображений в требуемом формате -processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224') +processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224") + # Модель для классификации изображений -model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224') +model = (ViTForImageClassification + .from_pretrained("google/vit-base-patch16-224")) -# Получение изображения def load_image(url): + """Загрузка изображения из указанного URL-адреса + с помощью библиотеки requests.""" img = Image.open(requests.get(url, stream=True).raw) return img -# Обработка и распознавание изображения def image_classification(picture): + """Обработка и распознавание изображения. + + Принимает изображение, преобразует его в требуемый формат + с помощью процессора, пропускает его через модель, + получает вероятности классов и возвращает предсказанный класс. + """ inputs = processor(images=picture, return_tensors="pt") outputs = model(**inputs) logits = outputs.logits @@ -31,16 +44,22 @@ def image_classification(picture): return model.config.id2label[predicted_class_idx] -# Маршрут для корневого URL-адреса @app.get("/") def root(): + """Маршрут для корневого URL-адреса. + + Возвращает сообщение, указывающее, что это API классификации изображений. + """ return {"message": "Image classification API"} @app.post("/classify-image") def classify_image(request: ImageRequest): - """ - Classify an image using a pre-trained ViT model. + """Классифицирует изображение с помощью готовой модели ViT. + + Принимает запрос с URL-адресом изображения, загружает изображение, + классифицирует его с помощью готовой модели ViT и возвращает результат. + Если изображение не может быть загружено, возвращается сообщение об ошибке. """ try: loaded_image = load_image(request.url) diff --git a/image_classification_streamlit.py b/image_classification_streamlit.py index 1aef750..d7d6d23 100644 --- a/image_classification_streamlit.py +++ b/image_classification_streamlit.py @@ -1,35 +1,45 @@ -from transformers import ViTImageProcessor, ViTForImageClassification +"""Приложение Streamlit для классификации изображений.""" + +import requests import streamlit as st from PIL import Image -import requests +from transformers import ViTForImageClassification, ViTImageProcessor -@st.cache -# Загрузка модели для классификации изображений +@st.cache_resource def load_model(): - return ViTForImageClassification.from_pretrained('google/vit-base-patch16-224') + """Загрузка модели""" + return (ViTForImageClassification + .from_pretrained("google/vit-base-patch16-224")) -@st.cache -# Загрузка процессора для представления изображений в требуемом формате +@st.cache_resource def load_processor(): - return ViTImageProcessor.from_pretrained('google/vit-base-patch16-224') + """Загрузка процессора для обработки изображений.""" + return (ViTImageProcessor + .from_pretrained("google/vit-base-patch16-224")) -# Отображение текстового поля для ввода ссылки на изображение def get_image_link(): + """Ввод URL-адреса с изображением.""" return st.text_input("Введите ссылку на изображение для распознавания") -# Получение изображения и вывод его на экран def load_image(url): + """Загрузка изображения из указанного URL-адреса + с помощью библиотеки requests.""" img = Image.open(requests.get(url, stream=True).raw) st.image(img) return img -# Обработка и распознавание изображения def image_classification(picture): + """Обработка и распознавание изображения. + + Принимает изображение, преобразует его в требуемый формат + с помощью процессора, пропускает его через модель, + получает вероятности классов и возвращает предсказанный класс. + """ inputs = processor(images=picture, return_tensors="pt") outputs = model(**inputs) logits = outputs.logits @@ -37,28 +47,30 @@ def image_classification(picture): return model.config.id2label[predicted_class_idx] -# Вывод результатов на экран def show_results(results): + """Вывод результатов""" st.write(results) processor = load_processor() model = load_model() -st.title('Модель для классификации изображений vit-base-patch16-224') +st.title("Модель для классификации изображений vit-base-patch16-224") link = get_image_link() -result = st.button('Распознать изображение') +result = st.button("Распознать изображение") if result: try: loaded_image = load_image(link) - with st.spinner('Идет обработка... Пожалуйста, подождите...'): + with st.spinner("Идет обработка... Пожалуйста, подождите..."): result = image_classification(loaded_image) - st.markdown(f'Результаты распознавания: :rainbow[{result}]') + st.markdown(f"Результаты распознавания: :rainbow[{result}]") st.snow() - # Обработка исключений, которые приведут к ошибке в случае отсутствия ссылки - # или указания ссылки на объект, который не является изображением except IOError: - st.error(' Не удалось найти изображение по указанной ссылке. Попробуйте снова!', icon="😞") + st.error( + "Не удалось найти изображение по указанной ссылке. " + "Попробуйте снова!", + icon="😞", + ) diff --git a/test_app.py b/test_app.py index d12faed..024fa8e 100644 --- a/test_app.py +++ b/test_app.py @@ -1,23 +1,48 @@ +"""Тесты для проверки приложения Streamlit.""" + +import time + from streamlit.testing.v1 import AppTest -at = AppTest.from_file("image_classification_streamlit.py", default_timeout=1000).run() +at = AppTest.from_file("image_classification_streamlit.py", + default_timeout=1000).run() def test_incorrect_url(): - """ - Пользователь вводит некорректную ссылку на изображение - (ссылку на объект не являющийся изображением) - """ + """Проверка ввода URL-адреса на объект, + который не является изображением.""" at.text_input[0].set_value("https://www.google.com/").run() at.button[0].click().run() - assert at.error[0].value == "Не удалось найти изображение по указанной ссылке. Попробуйте снова!" + assert at.error[0].value == ( + "Не удалось найти изображение по указанной ссылке. " + "Попробуйте снова!" + ) def test_null_url(): - """ - Пользователь не вводит ссылку на изображение - (оставляет поле для ввода ссылки пустым) - """ + """Проверка ввода пустого URL-адреса.""" at.text_input[0].set_value("").run() at.button[0].click().run() - assert at.error[0].value == "Не удалось найти изображение по указанной ссылке. Попробуйте снова!" + assert at.error[0].value == ( + "Не удалось найти изображение по указанной ссылке. " + "Попробуйте снова!" + ) + + +def test_correct_url(): + """Проверка ввода корректного URL-адреса на изображение.""" + ( + at.text_input[0] + .set_value( + "https://www.rgo.ru/sites/default/files/" + "styles/head_image_article/public/node/" + "61549/photo-2023-11-08-150058.jpeg" + ) + .run() + ) + at.button[0].click().run() + time.sleep(5) + assert at.markdown[0].value == ( + "Результаты распознавания: " + ":rainbow[tabby, tabby cat]" + )