diff --git a/data.pth b/data.pth
new file mode 100644
index 0000000..ff50464
Binary files /dev/null and b/data.pth differ
diff --git a/requirements.txt b/requirements.txt
index b279a05..ea3e451 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,12 +1,16 @@
+einops==0.7.0
 fastapi==0.108.0
 google_search_results==2.4.2
 huggingface-hub==0.20.1
 langchain==0.0.352
 nltk==3.8.1
 openai==1.6.1
+Pillow==10.1.0
 psycopg2-binary==2.9.9
 python-dotenv==1.0.0
 python-jose==3.3.0
+python-multipart==0.0.6
 torch==2.1.2
+transformers==4.36.2
 uvicorn==0.25.0
 wikipedia==1.4.0
\ No newline at end of file
diff --git a/src/api/endpoints/story_teller.py b/src/api/endpoints/story_teller.py
new file mode 100644
index 0000000..bb0ef9d
--- /dev/null
+++ b/src/api/endpoints/story_teller.py
@@ -0,0 +1,47 @@
+import io
+
+import torch
+from fastapi import APIRouter, UploadFile, status
+from PIL import Image
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    BlipForConditionalGeneration,
+    BlipProcessor,
+)
+
+from api.validators.story import Story
+
+router = APIRouter(prefix="/story-teller", tags=["chatbot"])
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# BLIP captions the uploaded image; Phi-2 expands the caption into a story.
+processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
+img_model = BlipForConditionalGeneration.from_pretrained(
+    "Salesforce/blip-image-captioning-large"
+).to(device)
+model = AutoModelForCausalLM.from_pretrained(
+    "microsoft/phi-2", torch_dtype=torch.float32, trust_remote_code=True
+).to(device)
+tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
+
+
+@router.post("/", status_code=status.HTTP_200_OK, response_model=Story)
+async def chat(file: UploadFile):
+    image_content = await file.read()
+    raw_image = Image.open(io.BytesIO(image_content)).convert("RGB")
+
+    # Generate a caption for the uploaded image.
+    inputs = processor(raw_image, return_tensors="pt").to(device)
+    out = img_model.generate(**inputs, max_length=50)
+    caption = processor.decode(out[0], skip_special_tokens=True)
+
+    # Prompt the language model to turn the caption into a story.
+    prompt = f"Write a story on the following scenario: {caption}"
+    inputs = tokenizer(
+        prompt, return_tensors="pt", return_attention_mask=False
+    ).to(device)
+    outputs = model.generate(**inputs, max_length=200)
+    story = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+    return {"text": story[0]}
diff --git a/src/api/validators/story.py b/src/api/validators/story.py
new file mode 100644
index 0000000..c59baee
--- /dev/null
+++ b/src/api/validators/story.py
@@ -0,0 +1,5 @@
+from pydantic import BaseModel
+
+
+class Story(BaseModel):
+    text: str
diff --git a/src/main.py b/src/main.py
index 9fc3676..58667ae 100644
--- a/src/main.py
+++ b/src/main.py
@@ -7,6 +7,7 @@
 )
 from api.endpoints.chat import router as chat_router
 from api.endpoints.data_loader import router as data_loader_router
+from api.endpoints.story_teller import router as story_teller_router
 from api.models import response as response_model
 from api.models import user as user_model
 from core.database import engine
@@ -27,3 +28,4 @@
 app.include_router(business_idea_generator_router)
 app.include_router(chat_router)
 app.include_router(data_loader_router)
+app.include_router(story_teller_router)
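
For reference, a minimal client sketch for exercising the new endpoint, assuming the app is served by uvicorn on localhost:8000 and the requests package is installed; the URL, image path, and filename are hypothetical and should be adjusted to your deployment:

import requests

# Hypothetical base URL; the route is "/" under the "/story-teller" prefix,
# so the trailing slash matters.
URL = "http://localhost:8000/story-teller/"

with open("photo.jpg", "rb") as f:
    # python-multipart (added to requirements above) lets FastAPI parse this
    # multipart upload; the form field name must match the handler's
    # UploadFile parameter, "file".
    resp = requests.post(URL, files={"file": ("photo.jpg", f, "image/jpeg")})

resp.raise_for_status()
print(resp.json()["text"])  # the generated story, per the Story response model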