diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..287a2f0 --- /dev/null +++ b/.gitignore @@ -0,0 +1,162 @@ +### Python template +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + diff --git a/image_classification.py b/image_classification.py index b70d597..5d09e6b 100644 --- a/image_classification.py +++ b/image_classification.py @@ -1,21 +1,23 @@ -"""Этот модуль подключения готовой модели классификации изображений.""" +"""Использование готовой модели для классификации изображений.""" -from transformers import ViTImageProcessor -from transformers import ViTForImageClassification -from PIL import Image import requests +from PIL import Image +from transformers import ViTForImageClassification, ViTImageProcessor -url = 'https://img.freepik.com/premium-photo/a-house-on-a-mountain-with-a-mountain-in-the-background_759095-3394.jpg' +url = ( + "https://img.freepik.com/premium-photo/" + "a-house-on-a-mountain-with-a-mountain-" + "in-the-background_759095-3394.jpg" +) image = Image.open(requests.get(url, stream=True).raw) processor = (ViTImageProcessor - .from_pretrained('google/vit-base-patch16-224')) + .from_pretrained("google/vit-base-patch16-224")) model = (ViTForImageClassification - .from_pretrained('google/vit-base-patch16-224')) + .from_pretrained("google/vit-base-patch16-224")) inputs = processor(images=image, return_tensors="pt") outputs = model(**inputs) logits = outputs.logits -# model predicts one of the 1000 ImageNet classes predicted_class_idx = logits.argmax(-1).item() print("Predicted class:", model.config.id2label[predicted_class_idx]) diff --git a/image_classification_fastapi.py b/image_classification_fastapi.py index ea02539..ac5e51b 100644 --- a/image_classification_fastapi.py +++ b/image_classification_fastapi.py @@ -1,15 +1,14 @@ -"""Готовая модель классификации изображений с FastAPI.""" +"""Модель классификации изображений с FastAPI.""" -from PIL import Image import requests from fastapi import FastAPI +from PIL import Image from pydantic import BaseModel -from transformers import ViTImageProcessor -from transformers import ViTForImageClassification +from transformers import ViTForImageClassification, ViTImageProcessor class ImageRequest(BaseModel): - """Класс запроса картинки.""" + """Класс запроса для обработки изображения.""" url: str @@ -17,23 +16,27 @@ class ImageRequest(BaseModel): app = FastAPI() # Процессор для представления изображений в требуемом формате - -processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224') +processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224") # Модель для классификации изображений - model = (ViTForImageClassification - .from_pretrained('google/vit-base-patch16-224')) + .from_pretrained("google/vit-base-patch16-224")) def load_image(url): - """Получение изображения.""" + """Загрузка изображения из указанного URL-адреса + с помощью библиотеки requests.""" img = Image.open(requests.get(url, stream=True).raw) return img def image_classification(picture): - """Обработка и распознавание изображения.""" + """Обработка и распознавание изображения. + + Принимает изображение, преобразует его в требуемый формат + с помощью процессора, пропускает его через модель, + получает вероятности классов и возвращает предсказанный класс. + """ inputs = processor(images=picture, return_tensors="pt") outputs = model(**inputs) logits = outputs.logits @@ -43,13 +46,21 @@ def image_classification(picture): @app.get("/") def root(): - """Маршрут для корневого URL-адреса.""" + """Маршрут для корневого URL-адреса. + + Возвращает сообщение, указывающее, что это API классификации изображений. + """ return {"message": "Image classification API"} @app.post("/classify-image") def classify_image(request: ImageRequest): - """Classify an image using a pre-trained ViT model.""" + """Классифицирует изображение с помощью готовой модели ViT. + + Принимает запрос с URL-адресом изображения, загружает изображение, + классифицирует его с помощью готовой модели ViT и возвращает результат. + Если изображение не может быть загружено, возвращается сообщение об ошибке. + """ try: loaded_image = load_image(request.url) result = image_classification(loaded_image) diff --git a/image_classification_streamlit.py b/image_classification_streamlit.py index 0c51bda..d7d6d23 100644 --- a/image_classification_streamlit.py +++ b/image_classification_streamlit.py @@ -1,37 +1,45 @@ -#Классификатор картинок +"""Приложение Streamlit для классификации изображений.""" -from transformers import ViTImageProcessor, ViTForImageClassification +import requests import streamlit as st from PIL import Image -import requests +from transformers import ViTForImageClassification, ViTImageProcessor @st.cache_resource -# Загрузка модели для классификации изображений def load_model(): - return ViTForImageClassification.from_pretrained('google/vit-base-patch16-224') + """Загрузка модели""" + return (ViTForImageClassification + .from_pretrained("google/vit-base-patch16-224")) @st.cache_resource -# Загрузка процессора для представления изображений в требуемом формате def load_processor(): - return ViTImageProcessor.from_pretrained('google/vit-base-patch16-224') + """Загрузка процессора для обработки изображений.""" + return (ViTImageProcessor + .from_pretrained("google/vit-base-patch16-224")) -# Отображение текстового поля для ввода ссылки на изображение def get_image_link(): + """Ввод URL-адреса с изображением.""" return st.text_input("Введите ссылку на изображение для распознавания") -# Получение изображения и вывод его на экран def load_image(url): + """Загрузка изображения из указанного URL-адреса + с помощью библиотеки requests.""" img = Image.open(requests.get(url, stream=True).raw) st.image(img) return img -# Обработка и распознавание изображения def image_classification(picture): + """Обработка и распознавание изображения. + + Принимает изображение, преобразует его в требуемый формат + с помощью процессора, пропускает его через модель, + получает вероятности классов и возвращает предсказанный класс. + """ inputs = processor(images=picture, return_tensors="pt") outputs = model(**inputs) logits = outputs.logits @@ -39,28 +47,30 @@ def image_classification(picture): return model.config.id2label[predicted_class_idx] -# Вывод результатов на экран def show_results(results): + """Вывод результатов""" st.write(results) processor = load_processor() model = load_model() -st.title('Модель для классификации изображений vit-base-patch16-224') +st.title("Модель для классификации изображений vit-base-patch16-224") link = get_image_link() -result = st.button('Распознать изображение') +result = st.button("Распознать изображение") if result: try: loaded_image = load_image(link) - with st.spinner('Идет обработка... Пожалуйста, подождите...'): + with st.spinner("Идет обработка... Пожалуйста, подождите..."): result = image_classification(loaded_image) - st.markdown(f'Результаты распознавания: :rainbow[{result}]') + st.markdown(f"Результаты распознавания: :rainbow[{result}]") st.snow() - # Обработка исключений, которые приведут к ошибке в случае отсутствия ссылки - # или указания ссылки на объект, который не является изображением except IOError: - st.error('Не удалось найти изображение по указанной ссылке. Попробуйте снова!', icon="😞") + st.error( + "Не удалось найти изображение по указанной ссылке. " + "Попробуйте снова!", + icon="😞", + ) diff --git a/test_app.py b/test_app.py index df253a5..024fa8e 100644 --- a/test_app.py +++ b/test_app.py @@ -1,22 +1,48 @@ -"""test_app.py""" -from streamlit.testing.v1 import AppTest +"""Тесты для проверки приложения Streamlit.""" + import time -at = AppTest.from_file("image_classification_streamlit.py", default_timeout=1000).run() +from streamlit.testing.v1 import AppTest + +at = AppTest.from_file("image_classification_streamlit.py", + default_timeout=1000).run() + def test_incorrect_url(): + """Проверка ввода URL-адреса на объект, + который не является изображением.""" at.text_input[0].set_value("https://www.google.com/").run() at.button[0].click().run() - assert at.error[0].value == 'Не удалось найти изображение по указанной ссылке. Попробуйте снова!' + assert at.error[0].value == ( + "Не удалось найти изображение по указанной ссылке. " + "Попробуйте снова!" + ) + def test_null_url(): + """Проверка ввода пустого URL-адреса.""" at.text_input[0].set_value("").run() at.button[0].click().run() - assert at.error[0].value == 'Не удалось найти изображение по указанной ссылке. Попробуйте снова!' + assert at.error[0].value == ( + "Не удалось найти изображение по указанной ссылке. " + "Попробуйте снова!" + ) + def test_correct_url(): - at.text_input[0].set_value("https://www.rgo.ru/sites/default/files/styles/head_image_article/public/node/61549/photo-2023-11-08-150058.jpeg").run() + """Проверка ввода корректного URL-адреса на изображение.""" + ( + at.text_input[0] + .set_value( + "https://www.rgo.ru/sites/default/files/" + "styles/head_image_article/public/node/" + "61549/photo-2023-11-08-150058.jpeg" + ) + .run() + ) at.button[0].click().run() time.sleep(5) - assert at.markdown[0].value == 'Результаты распознавания: :rainbow[tabby, tabby cat]' - + assert at.markdown[0].value == ( + "Результаты распознавания: " + ":rainbow[tabby, tabby cat]" + )