Skip to content

Commit

Permalink
update upsert to caption as well
Browse files Browse the repository at this point in the history
  • Loading branch information
kenny1G committed Dec 14, 2023
1 parent 8f6f29d commit 8df9aeb
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 10 deletions.
2 changes: 1 addition & 1 deletion streamlit_app/pages/1_Setup_Demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def click_login_button(uid):
auth_url, state = flow.authorization_url()
st.write(
f"""<h1>
Please login using this <a target="_self"
Please login using this <a target="_blank"
href="{auth_url}">url</a></h1>""",
unsafe_allow_html=True,
)
Expand Down
83 changes: 74 additions & 9 deletions streamlit_app/pages/2_Upsert_Images.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,18 @@
import pinecone
import itertools
import pandas as pd
from tqdm import tqdm
from PIL import Image
import streamlit as st
from dotenv import load_dotenv
from datetime import date, timedelta, datetime
from google.auth.transport.requests import Request
from sentence_transformers import SentenceTransformer
from transformers import pipeline

if "credentials" not in st.session_state or "uid" not in st.session_state:
st.warning(
"You are not authenticated yet. Please go to Setup Demo to Authenticate."
"You are not authenticated yet. Please enter your unique ID in Setup Demo to Authenticate."
)
st.stop()

Expand All @@ -31,6 +33,9 @@
def load_model():
    """Instantiate the ``clip-ViT-B-32`` SentenceTransformer checkpoint.

    The returned object is what the module binds to ``clip_model`` and later
    calls ``.encode(...)`` on to turn caption text into embedding vectors.
    """
    model = SentenceTransformer("clip-ViT-B-32")
    return model

@st.cache_resource
def load_captioner():
    """Build the BLIP image-captioning pipeline.

    Wrapped in ``st.cache_resource`` so the pipeline object is constructed
    once and reused, rather than being rebuilt every time the page script
    runs.

    Returns:
        A Hugging Face ``image-to-text`` pipeline backed by
        ``Salesforce/blip-image-captioning-base``. Calling it on an image
        yields a list of dicts carrying a ``generated_text`` entry.
    """
    captioner = pipeline(
        task="image-to-text",
        model="Salesforce/blip-image-captioning-base",
    )
    return captioner

@st.cache_resource
def get_pinecone_image_index():
Expand All @@ -41,10 +46,11 @@ def get_pinecone_image_index():
)
if im_index_name not in pinecone.list_indexes():
pinecone.create_index(name=im_index_name, dimension=512, metric="cosine")
return pinecone.Index("im_index_name")
return pinecone.Index(im_index_name)


clip_model = load_model()
blip_model = load_captioner()
pinecone_index = get_pinecone_image_index()


Expand Down Expand Up @@ -103,18 +109,22 @@ def get_response_from_medium_api(year, month, day):
return res


@st.cache_data
@st.cache_data(ttl=600)
def get_images_in_date_range(uid, sdate, edate):
date_list = pd.date_range(sdate, edate - timedelta(days=1), freq="d")
media_items_df = pd.DataFrame()
print("get_image_in_date_range, listing media items")
for date in date_list:
progress_text = "Fetching Image Info Data"
progress_bar = st.progress(0, text=progress_text)
for i, date in enumerate(date_list):
items_df, media_items_df = list_of_media_items(
year=date.year,
month=date.month,
day=date.day,
media_items_df=media_items_df,
)
progress_bar.progress((i + 1) / len(date_list), text=progress_text)
progress_bar.empty()
if len(media_items_df) == 0:
return None
else:
Expand Down Expand Up @@ -168,6 +178,37 @@ def embed_images(_images, uid, sdate, edate):
img_embeddings.append(embedding.tolist())
return img_embeddings

@st.cache_data
def caption_images(media_items_df):
    """Caption every image in ``media_items_df`` and embed each caption.

    Each row's ``baseUrl`` is passed to the BLIP captioning pipeline
    (``blip_model``); the generated caption text is then encoded with
    ``clip_model`` to obtain a caption embedding.

    Args:
        media_items_df: DataFrame with at least an ``id`` column and a
            ``baseUrl`` column (one row per image).

    Returns:
        A tuple ``(img_ids, caption_embeddings, img_captions)`` of three
        parallel lists. If captioning or embedding fails for any image, a
        Streamlit warning is shown and ``(None, None, None)`` is returned so
        the caller can abort the upsert.
    """
    # Pull the columns out once instead of re-materialising them per row.
    base_urls = media_items_df["baseUrl"].values
    item_ids = media_items_df["id"].values
    total = len(media_items_df)

    img_ids = []
    caption_embeddings = []
    img_captions = []

    progress_text = "Captioning Images"
    progress_bar = st.progress(0, text=progress_text)
    try:
        for i, (curr_id, burl) in tqdm(
            enumerate(zip(item_ids, base_urls)), total=total
        ):
            progress_bar.progress(i / total, text=progress_text)

            try:
                curr_desc = blip_model(burl)[0]["generated_text"]
                curr_embedding = clip_model.encode(curr_desc).tolist()
            except Exception as e:
                # Deliberately broad: any model/download failure aborts the
                # whole run, because downstream code needs an embedding for
                # every row.
                st.warning(f"Failed to caption image {burl}")
                print(f"Error occurred: {e}")
                return None, None, None

            img_ids.append(curr_id)
            caption_embeddings.append(curr_embedding)
            img_captions.append(curr_desc)
    finally:
        # Always remove the progress bar, including on the early-return
        # failure path, so a stale bar is not left on the page.
        progress_bar.empty()

    return img_ids, caption_embeddings, img_captions

def chunks(iterable, batch_size=100):
"""A helper function to break an iterable into chunks of size batch_size."""
it = iter(iterable)
Expand All @@ -176,13 +217,18 @@ def chunks(iterable, batch_size=100):
yield chunk
chunk = tuple(itertools.islice(it, batch_size))

def upsert_to_pinecone(namespace, media_items_df):
def upsert_to_pinecone(namespace, media_items_df, is_caption=False):
vector = media_items_df.caption_embeddings if is_caption else media_items_df.vector
namespace = namespace + "_captions" if is_caption else namespace

vectors = zip(
media_items_df.id,
media_items_df.vector,
vector,
media_items_df.metadata,
)

pinecone_index.delete(delete_all=True, namespace=namespace)

# Upsert data with 100 vectors per upsert request asynchronously
# - Create pinecone.Index with pool_threads=30 (limits to 30 simultaneous requests)
# - Pass async_req=True to index.upsert()
Expand All @@ -202,11 +248,18 @@ def upsert_to_pinecone(namespace, media_items_df):
st.warning("Error upserting to Pinecone Please try again.")
st.stop()

while index.describe_index_stats()["total_vector_count"] == 0:
while pinecone_index.describe_index_stats()["total_vector_count"] == 0:
print(index.describe_index_stats())

def click_date_range_button(start_date, end_date):
print("Fetching Images")
# Check if the date range is longer than 3 months
if (end_date - start_date).days > 93:
st.warning(
"The date range is longer than 3 months. Please select a shorter range.",
icon="⚠️",
)
return
media_items_df = get_images_in_date_range(uid, start_date, end_date)
if media_items_df is None:
st.warning("No images found in date range")
Expand All @@ -218,13 +271,25 @@ def click_date_range_button(start_date, end_date):
print("Embedding Images")
embeddings = embed_images(images, uid, start_date, end_date)
media_items_df["vector"] = embeddings

print("Generating and Embedding Captions")
caption_ids, caption_embeddings, captions = caption_images(media_items_df)
if caption_ids is None:
st.stop()
return
caption_embeddings_dict = dict(zip(caption_ids, caption_embeddings))
captions_dict = dict(zip(caption_ids, captions))

media_items_df['caption_embeddings'] = media_items_df['id'].map(caption_embeddings_dict)
media_items_df['captions'] = media_items_df['id'].map(captions_dict)

media_items_df["metadata"] = media_items_df.loc[
:, ["year", "month", "day"]
:, ["year", "month", "day", "captions"]
].to_dict("records")
print(image_dict)

print("Upserting to Pinecone")
upsert_to_pinecone(uid, media_items_df)
upsert_to_pinecone(uid, media_items_df, is_caption=True)
st.info(f"Upserted {len(media_items_df)} images from {start_date} to {end_date} ")

st.session_state["media_items_df"] = media_items_df
Expand Down

0 comments on commit 8df9aeb

Please sign in to comment.