Skip to content

Commit 19f764c

Browse files
authored
Merge pull request #10 from KB-iGOT/main
main
2 parents 11e467e + 305a91c commit 19f764c

File tree

4 files changed

+129
-10
lines changed

4 files changed

+129
-10
lines changed

app/routers/course.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
import logging
1+
import time
22
from fastapi import APIRouter, HTTPException
33
from ..logger import logger
44
from ..models import ImageResponse, ImageVariationResponse
5-
from ..services.course import generate_course_summary, generate_image_prompt, generate_image
5+
from ..services.course import generate_course_summary, generate_image_prompt, generate_image, generate_public_url
66
from ..services.image_variation import generate_image_variations
77

88
router = APIRouter(
@@ -13,9 +13,11 @@
1313
@router.get("/course/{course_id}", response_model=ImageResponse,summary= "Create thumbnail from course title, description, and table of contents")
1414
def generate_course_image(course_id: str):
1515
try:
16-
final_summary = generate_course_summary(course_id)
16+
content, final_summary = generate_course_summary(course_id)
1717
image_prompt = generate_image_prompt(final_summary)
18-
image_url = generate_image(image_prompt)
18+
image_data = generate_image(image_prompt)
19+
image_url = generate_public_url(content, image_data)
20+
1921
return {
2022
"final_summary": final_summary,
2123
"image_prompt": image_prompt,
@@ -28,8 +30,10 @@ def generate_course_image(course_id: str):
2830
@router.get("/variations/course/{course_id}", response_model=ImageVariationResponse,summary= "Generate thumbnail variations from an existing course thumbnail")
2931
def generate_course_image_variations(course_id: str):
3032
try:
33+
start_time = time.time()
3134
logger.info(f"Course ID : {course_id}")
3235
image_urls = generate_image_variations(course_id)
36+
print("Time took to process the request and return response is {} sec".format(time.time() - start_time))
3337
return ImageVariationResponse(images=image_urls)
3438
except Exception as e:
3539
logger.exception("Error while generating the image variations")

app/services/course.py

+52-6
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,26 @@
1+
from PIL import Image
2+
import base64
3+
import io
14
import os
25
from dotenv import load_dotenv
36
from typing import Dict
7+
import urllib.parse
48
import requests
59
from langchain_openai import ChatOpenAI
610
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
7-
from langchain.chains.combine_documents import create_stuff_documents_chain
8-
from langchain_community.utilities.dalle_image_generator import DallEAPIWrapper
911
from langchain_core.output_parsers import StrOutputParser
12+
import openai
13+
from ..utils import get_extension_from_mimetype
1014

1115
# Load environment variables from .env file
1216
load_dotenv()
1317
KB_API_HOST = os.environ["KB_API_HOST"]
1418

19+
from ..libs.storage import GCPStorage
20+
storage = GCPStorage()
21+
STORAGE_THUMBNAIL_FOLDER=os.environ["STORAGE_THUMBNAIL_FOLDER"]
22+
STORAGE_PROXY_PATH=os.environ["STORAGE_PROXY_PATH"]
23+
1524
llm = ChatOpenAI(temperature=0, model_name="gpt-4o-mini")
1625
# Define prompt
1726
prompt = ChatPromptTemplate.from_messages(
@@ -107,7 +116,7 @@ def generate_course_summary(course_id: str):
107116
"description": course_details["result"]["content"]["description"],
108117
"toc": formatted_toc
109118
})
110-
return result
119+
return course_details["result"]["content"], result
111120

112121
def generate_image_prompt(summary: str):
113122
prompt = PromptTemplate.from_template(prompt_template)
@@ -125,12 +134,49 @@ def generate_image_prompt(summary: str):
125134
def generate_image(image_prompt: str):
126135
############################
127136
# Generate a thumbnail image
128-
image_url = DallEAPIWrapper(model="dall-e-3", size="1792x1024").run(f"""
137+
prompt = f"""
129138
Do not print any text on image, just use it AS-IS:
130139
{image_prompt}
131140
132141
Guidelines:
133142
- Please ensure that the image does not include any text or human imagery.
134143
- Generate image without map of india.
135-
""")
136-
return image_url
144+
"""
145+
response = openai.images.generate(
146+
prompt=prompt,
147+
model="dall-e-3",
148+
size="1024x1024",
149+
quality="standard",
150+
n=1,
151+
response_format="b64_json"
152+
)
153+
return response.data[0].b64_json
154+
155+
# Compress and convert image to JPEG
156+
def compress_image(image_data):
157+
# Decode the base64 image data
158+
image_bytes = base64.b64decode(image_data)
159+
160+
# Load the image into a PIL Image object
161+
image = Image.open(io.BytesIO(image_bytes))
162+
163+
# Convert to JPEG and compress the image
164+
compressed_buffer = io.BytesIO()
165+
image.save(compressed_buffer, format="JPEG", quality=80) # Adjust quality if needed
166+
compressed_image_data = compressed_buffer.getvalue()
167+
return compressed_image_data
168+
169+
def generate_public_url(content, image_data, mime_type = None):
170+
171+
if mime_type is None:
172+
mime_type = "image/jpeg"
173+
174+
compressed_image_data = compress_image(image_data)
175+
176+
extension = get_extension_from_mimetype(mime_type)
177+
filename = f"{content["name"]}.{extension}"
178+
filepath = os.path.join(STORAGE_THUMBNAIL_FOLDER, content["identifier"], filename)
179+
storage.write_file(filepath, compressed_image_data, mime_type)
180+
# image_urls.append(storage.public_url(filepath))
181+
public_url = urllib.parse.urljoin(KB_API_HOST, os.path.join(STORAGE_PROXY_PATH, content["identifier"], filename))
182+
return public_url

demo_app/README.md

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Image Variation Generator Streamlit App
2+
3+
## Prerequisites
4+
5+
Before you begin, ensure you have the following installed on your system:
6+
7+
- Thumbnail generation API
8+
- Streamlit library installed: `pip install streamlit`
9+
10+
## Start the Streamlit app:
11+
```
12+
streamlit run frontend.py
13+
```

demo_app/frontend.py

+56
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import requests
2+
import streamlit as st
3+
4+
def fetch_image_variations(course_id):
5+
"""Fetches image variations for a given course ID from the API.
6+
7+
Args:
8+
course_id (str): The ID of the course.
9+
10+
Returns:
11+
list[str]: A list of image URLs.
12+
13+
Raises:
14+
requests.exceptions.RequestException: If there's an error making the API request.
15+
"""
16+
17+
url = f"http://localhost:8000/v1/image/variations/course/{course_id}"
18+
headers = {"accept": "application/json"}
19+
20+
try:
21+
response = requests.get(url, headers=headers)
22+
response.raise_for_status()
23+
return response.json()["images"]
24+
except requests.exceptions.RequestException as e:
25+
raise e
26+
27+
def display_images(images):
28+
"""Displays a list of image URLs in a 4-column grid.
29+
30+
Args:
31+
images (list[str]): A list of image URLs.
32+
"""
33+
34+
cols = st.columns(4)
35+
for i, image_url in enumerate(images):
36+
cols[i].image(image_url)
37+
38+
def main():
39+
"""Main function for the Streamlit application."""
40+
41+
st.title("Image Variation Generator")
42+
43+
course_id = st.text_input("Enter Course ID:")
44+
45+
generate_button = st.button("Generate Images")
46+
if generate_button:
47+
with st.spinner("Processing..."):
48+
try:
49+
images = fetch_image_variations(course_id)
50+
st.success("Images generated successfully!")
51+
display_images(images)
52+
except requests.exceptions.RequestException as e:
53+
st.error(f"Error fetching images: {e}")
54+
55+
if __name__ == "__main__":
56+
main()

0 commit comments

Comments
 (0)