-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
48 lines (42 loc) · 1.49 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import streamlit as st
from streamlit_chat import message
import chromadb
from typing import Final
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
# Directory where the Chroma vector store is persisted on disk.
DB_DIR: Final = "embeddings/chroma/"

# Shared Chroma client configuration: DuckDB+Parquet storage backend persisted
# under DB_DIR, with anonymized telemetry turned off.
# NOTE(review): `chroma_db_impl="duckdb+parquet"` is the pre-0.4 chromadb
# configuration style — confirm the pinned chromadb version still accepts it.
client_settings = chromadb.config.Settings(
    anonymized_telemetry=False,
    chroma_db_impl="duckdb+parquet",
    persist_directory=DB_DIR,
)
def make_chain(
    api_key,
    db_dir=DB_DIR,
    *,
    collection_name="Yonsin_Annual_Report_2023_1-25_pages",
    model_name="gpt-3.5-turbo",
    temperature=0.2,
    k=5,
):
    """Build a conversational retrieval chain over the persisted Chroma index.

    Args:
        api_key: OpenAI API key, used for both the chat model and the
            embedding model. When falsy, no chain is built.
        db_dir: Directory the Chroma store is persisted in. Defaults to
            ``DB_DIR``.
        collection_name: Name of the Chroma collection to query.
        model_name: OpenAI chat model used to answer questions.
        temperature: Sampling temperature passed to the chat model.
        k: Number of documents the similarity retriever returns per query.

    Returns:
        A ``ConversationalRetrievalChain`` built with
        ``return_source_documents=True``, or ``None`` when ``api_key`` is falsy.
    """
    if not api_key:
        return None

    model = ChatOpenAI(
        model_name=model_name, temperature=temperature, openai_api_key=api_key
    )
    embeddings = OpenAIEmbeddings(
        openai_api_key=api_key,
        model="text-embedding-ada-002",
    )
    # Build client settings for this call instead of passing the shared
    # module-level ``client_settings``. That object pins persist_directory to
    # DB_DIR, so a caller-supplied ``db_dir`` may be ignored (and some
    # langchain versions may modify the settings object they are handed).
    settings = chromadb.config.Settings(
        chroma_db_impl="duckdb+parquet",
        persist_directory=db_dir,
        anonymized_telemetry=False,
    )
    vectordb = Chroma(
        persist_directory=db_dir,
        embedding_function=embeddings,
        collection_name=collection_name,
        client_settings=settings,
    )
    # Expose the index through the retriever interface (top-k similarity search).
    retriever = vectordb.as_retriever(
        search_type="similarity", search_kwargs={"k": k}
    )
    return ConversationalRetrievalChain.from_llm(
        model, retriever=retriever, return_source_documents=True
    )
def get_text(samp_select):
    """Render the "Query: " text box and return its value.

    The box is pre-filled with ``samp_select`` when that is truthy;
    otherwise it starts out empty. The widget is registered under the
    key ``"input"``.
    """
    default_value = samp_select or ""
    return st.text_input("Query: ", default_value, key="input")