Skip to content

Commit

Permalink
image to story generator api
Browse files Browse the repository at this point in the history
  • Loading branch information
emon5122 committed Dec 31, 2023
1 parent dcf5ceb commit 6d87b3e
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 0 deletions.
Binary file added data.pth
Binary file not shown.
4 changes: 4 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
einops==0.7.0
fastapi==0.108.0
google_search_results==2.4.2
huggingface-hub==0.20.1
langchain==0.0.352
nltk==3.8.1
openai==1.6.1
Pillow==10.1.0
psycopg2-binary==2.9.9
python-dotenv==1.0.0
python-jose==3.3.0
python-multipart==0.0.6
torch==2.1.2
transformers==4.36.2
uvicorn==0.25.0
wikipedia==1.4.0
39 changes: 39 additions & 0 deletions src/api/endpoints/story_teller.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import io

import torch
from fastapi import APIRouter, UploadFile, status
from PIL import Image
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BlipForConditionalGeneration,
BlipProcessor,
)

from api.validators.story import Story

# Router for the image-to-story endpoint; tagged "chatbot" like the sibling chat routes.
router = APIRouter(prefix="/story-teller", tags=["chatbot"])

# Prefer GPU when available; used below for the BLIP captioning model.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# BLIP image-captioning pair: processor prepares pixel inputs, model generates captions.
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
img_model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-large"
).to(device)
# phi-2 turns the caption into a story.
# NOTE(review): unlike img_model, this model is never moved to `device`, so it
# runs on CPU even when CUDA is available — confirm whether that is intentional
# (its inputs in the handler are CPU tensors, so moving it requires moving both).
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/phi-2", torch_dtype=torch.float32, trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)


@router.post("/", status_code=status.HTTP_200_OK, response_model=Story)
async def Chat(file: UploadFile):
    """Generate a short story from an uploaded image.

    The image is captioned with BLIP, and the caption is embedded into a
    prompt for phi-2, which writes the story.

    Parameters:
        file: uploaded image; any format Pillow can open.

    Returns:
        dict matching the ``Story`` model: ``{"text": <generated story>}``.
    """
    image_content = await file.read()
    raw_image = Image.open(io.BytesIO(image_content)).convert("RGB")
    # Pure inference: disable autograd to avoid tracking gradients/activations.
    with torch.inference_mode():
        inputs = processor(raw_image, return_tensors="pt").to(device)
        # generate() already returns tensors on the model's device — the
        # original `.to(device)` calls on its output were redundant (and the
        # second one needlessly moved a CPU tensor to CUDA before decoding).
        caption_ids = img_model.generate(**inputs, max_length=50)
        caption = processor.decode(caption_ids[0], skip_special_tokens=True)
        # Typo fixed: "scenerio" -> "scenario".
        prompt = f"Write a story on the following scenario:-> {caption}"
        # phi-2 lives on CPU (see module setup), so its inputs stay on CPU too.
        prompt_inputs = tokenizer(
            prompt, return_tensors="pt", return_attention_mask=False
        )
        story_ids = model.generate(**prompt_inputs, max_length=200)
    # skip_special_tokens drops the trailing <|endoftext|> marker from the reply.
    story = tokenizer.batch_decode(story_ids, skip_special_tokens=True)
    return {"text": story[0]}
5 changes: 5 additions & 0 deletions src/api/validators/story.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from pydantic import BaseModel


class Story(BaseModel):
    """Response schema for the story-teller endpoint."""

    # The generated story text returned to the client.
    text: str
2 changes: 2 additions & 0 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
)
from api.endpoints.chat import router as chat_router
from api.endpoints.data_loader import router as data_loader_router
from api.endpoints.story_teller import router as story_teller_router
from api.models import response as response_model
from api.models import user as user_model
from core.database import engine
Expand All @@ -27,3 +28,4 @@
app.include_router(business_idea_generator_router)
app.include_router(chat_router)
app.include_router(data_loader_router)
app.include_router(story_teller_router)

0 comments on commit 6d87b3e

Please sign in to comment.