Skip to content

Commit

Permalink
code improvement #2
Browse files Browse the repository at this point in the history
  • Loading branch information
Bulrush3 committed May 16, 2024
1 parent 629b537 commit 4232530
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,6 @@ streamlit run image_classification_streamlit.py
uvicorn image_classification_fastapi:app
```
## P.S. Идеи по улучшению качества кода
1) Контекстный менеджер для работы с изображением with. Использование его гарантирует, что ресурсы будут корректно закрыты после использования. Таким образом, код автоматически закроет файл изображения после завершения работы с ним. (файл: image_classification.py)
1) Контекстный менеджер для работы с изображением with. Использование его гарантирует, что ресурсы будут корректно закрыты после использования. Таким образом, код автоматически закроет файл изображения после завершения работы с ним. (файл: image_classification.py)

2) Использованы строки f-формата для форматирования строк с переменными, что делает код более ясным и легко читаемым. Добавлены комментарии для каждого этапа. (файл: image_classification.py)
10 changes: 9 additions & 1 deletion image_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from PIL import Image
from transformers import ViTForImageClassification, ViTImageProcessor

# URL изображения
url = (
"https://img.freepik.com/premium-photo/"
"a-house-on-a-mountain-with-a-mountain-"
Expand All @@ -12,13 +13,20 @@
# Используем контекстный менеджер для открытия изображения
with Image.open(requests.get(url, stream=True).raw) as image:

# Инициализация процессора и модели
processor = (ViTImageProcessor
.from_pretrained("google/vit-base-patch16-224"))
model = (ViTForImageClassification
.from_pretrained("google/vit-base-patch16-224"))

# Подготовка изображения для модели
inputs = processor(images=image, return_tensors="pt")

# Получение предсказания от модели
outputs = model(**inputs)
logits = outputs.logits
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])

# Вывод предсказанного класса
predicted_class = model.config.id2label[predicted_class_idx]
print(f"Predicted class: {predicted_class}")

0 comments on commit 4232530

Please sign in to comment.