app.py
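"""Streamlit app for semantic search over NeurIPS 2024 conference posters.

Embeds the user's query with OpenAI's text-embedding-3-small model and ranks
precomputed poster embeddings by cosine similarity.
"""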

import os
import pickle
import urllib.parse
from typing import List, Dict

import numpy as np
import streamlit as st
from openai import OpenAI
from tenacity import retry, stop_after_attempt, wait_random_exponential


class EmbeddingProcessor:
    def __init__(self, api_key: str, model: str = "text-embedding-3-small"):
        """Initialize the embedding processor with OpenAI credentials."""
        self.client = OpenAI(api_key=api_key)
        self.model = model

    @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
    def get_embedding(self, text: str) -> List[float]:
        """Get embedding for a single text using OpenAI's API with retry logic."""
        response = self.client.embeddings.create(
            input=text,
            model=self.model
        )
        return response.data[0].embedding
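
    # Cosine similarity between the query vector q and each stored vector e_i:
    #   sim(q, e_i) = (e_i · q) / (||e_i|| * ||q||)
    # computed below in a single vectorized pass over the embedding matrix.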
    def search_similar(self, query: str, embedding_data: Dict, top_k: int = 5) -> List[Dict]:
        """Search for similar documents using cosine similarity."""
        query_embedding = self.get_embedding(query)
        similarities = np.dot(embedding_data['embeddings'], query_embedding) / (
            np.linalg.norm(embedding_data['embeddings'], axis=1) * np.linalg.norm(query_embedding)
        )
        top_indices = np.argsort(similarities)[-top_k:][::-1]
        results = []
        for idx in top_indices:
            doc = embedding_data['documents'][idx]
            result = {
                'poster_number': doc['poster_number'],
                'title': doc['title'],
                'abstract': doc['abstract'],
                'authors': doc['authors'],
                'similarity': float(similarities[idx]),
                'session_info': {
                    'session_name': doc['session_info']['session_name'],
                    'location': doc['session_info']['location'],
                    'time': doc['session_info']['time'],
                    'date': doc['session_info']['date']
                }
            }
            results.append(result)
        return results


def create_google_search_url(title: str) -> str:
    """Create a Google search URL for a given title."""
    encoded_title = urllib.parse.quote(title)
    return f"https://www.google.com/search?q={encoded_title}"


def main():
    st.set_page_config(
        page_title="Poster Search",
        page_icon="🔍",
        layout="wide"
    )

    # Minimal styling without background colors
    st.markdown("""
        <style>
        .stMarkdown {
            font-size: 1.2rem;
        }
        .session-title {
            font-size: 1.2rem;
            font-weight: 600;
            margin-bottom: 1rem;
        }
        .session-item {
            margin: 0.5rem 0;
            font-size: 1.1rem;
        }
        footer {
            visibility: visible;
            position: relative;
            clear: both;
            margin-top: 50px;
            padding: 20px;
            text-align: right;
            font-size: 14px;
        }
        </style>
    """, unsafe_allow_html=True)

    # Main title and description
    st.title("🔍 NeurIPS 2024 Poster Search")
    st.markdown("Search through conference posters using natural language queries.")

    # Load embeddings
    @st.cache_resource
    def load_embeddings(pkl_path):
        with open(pkl_path, 'rb') as f:
            return pickle.load(f)
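
    # Expected pickle layout (inferred from how embedding_data is used below):
    #   {'embeddings': array of shape (n_docs, dim),
    #    'documents': list of dicts with poster metadata and 'session_info'}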
    try:
        embedding_data = load_embeddings("embeddings/poster_embeddings.pkl")
    except Exception as e:
        st.error(f"Error loading embeddings: {str(e)}")
        return

    # Search interface with better layout
    col1, col2 = st.columns([3, 1])
    with col1:
        query = st.text_input("Enter your search query:", placeholder="e.g., machine learning in healthcare")
    with col2:
        num_results = st.slider("Number of results:", min_value=1, max_value=20, value=5)

    # Initialize processor with environment API key
    api_key = os.getenv('OPENAI_API_KEY')
    if not api_key:
        st.error("OpenAI API key not found in environment variables.")
        return
    processor = EmbeddingProcessor(api_key=api_key)
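    # Streamlit reruns this script on every interaction; the OpenAI client is
    # cheap to construct, so recreating the processor each run is acceptable.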

    if query:
        try:
            with st.spinner("Searching..."):
                results = processor.search_similar(query, embedding_data, num_results)

            st.header(f"🎯 Top {len(results)} Results")
            for i, result in enumerate(results, 1):
                google_url = create_google_search_url(result['title'])
                st.markdown("---")
                st.markdown(f"### [{result['title']}]({google_url})")
                st.markdown(f"*Similarity Score: {result['similarity']:.3f}*")

                col1, col2 = st.columns([2, 1])
                with col1:
                    st.markdown(f"**Poster Number:** {result['poster_number']}")
                    st.markdown(f"**Authors:** {result['authors']}")
                    st.markdown("**Abstract:**")
                    st.markdown(result['abstract'])
                with col2:
                    session = result['session_info']
                    st.markdown("### 📍 Session Information")
                    st.markdown(f" **🏷️ Session:** {session['session_name']}")
                    st.markdown(f" **📍 Location:** {session['location']}")
                    st.markdown(f" **🕒 Time:** {session['time']}")
                    st.markdown(f" **📅 Date:** {session['date']}")
        except Exception as e:
            st.error(f"Error during search: {str(e)}")

    ###################### Footer ######################
    st.markdown("")  # Add some space
    st.markdown("")  # Add more space
    st.markdown("---")  # Horizontal line
    st.markdown(
        """
        <footer>
            made with ❤️ by <a href="https://www.linkedin.com/in/chenml/" target="_blank">Mei Chen</a>
        </footer>
        """,
        unsafe_allow_html=True
    )


if __name__ == "__main__":
    main()
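
# To run locally (assumes embeddings/poster_embeddings.pkl exists):
#   export OPENAI_API_KEY=...
#   streamlit run app.py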