-
Notifications
You must be signed in to change notification settings - Fork 0
/
pdf_knowledge_base.py
81 lines (67 loc) · 2.65 KB
/
pdf_knowledge_base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import streamlit as st
from PyPDF2 import PdfReader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from qdrant import build_qa_model, build_vector_store
from langchain.callbacks import get_openai_callback
def handle_pdf_knowledge_base(llm):
    """Render the PDF knowledge-base section of the app.

    A sidebar radio button chooses between the "PDF Upload" page and the
    "Ask My PDF(s)" page. After the chosen page has been rendered, the
    costs stored under the ``"costs"`` session-state key are listed in the
    sidebar together with their total.

    Args:
        llm: Passed through unchanged to ``page_ask_my_pdf``.
    """
    page = st.sidebar.radio("Go to", ["PDF Upload", "Ask My PDF(s)"])
    if page == "Ask My PDF(s)":
        page_ask_my_pdf(llm)
    elif page == "PDF Upload":
        page_pdf_upload_and_build_vector_db()

    # Fall back to an empty list so the sidebar still renders when no cost
    # has been recorded in this session.
    recorded_costs = st.session_state.get("costs", [])
    st.sidebar.markdown("## Costs")
    total = sum(recorded_costs)
    st.sidebar.markdown(f"**Total cost: ${total:.5f}**")
    for amount in recorded_costs:
        st.sidebar.markdown(f"- ${amount:.5f}")
def page_ask_my_pdf(llm):
    """Render the "Ask My PDF(s)" page: take a query and show the answer.

    When the user has entered a query and ``build_qa_model`` returns a
    truthy QA object, the query is run through ``ask``. The resulting cost
    is appended to the ``"costs"`` list in session state and the answer is
    written below the input.

    Args:
        llm: Passed through unchanged to ``build_qa_model``.
    """
    st.title("Ask My PDF(s)")

    # Create both containers up front so the input is laid out above the
    # answer.
    container = st.container()
    response_container = st.container()

    # Stays None when there is no query or no QA model could be built.
    answer = None
    with container:
        query = st.text_input("Query: ", key="input")
        if query:
            qa = build_qa_model(llm)
            if qa:
                with st.spinner("ChatGPT is typing ..."):
                    answer, cost = ask(qa, query)
                # The cost list may not have been initialised yet (the
                # sidebar summary reads it with a default); create it here
                # so the append cannot fail with AttributeError.
                if "costs" not in st.session_state:
                    st.session_state["costs"] = []
                st.session_state["costs"].append(cost)

    if answer:
        with response_container:
            st.markdown("## Answer")
            st.write(answer)
def page_pdf_upload_and_build_vector_db():
    """Render the "PDF Upload" page and index the uploaded PDF.

    Shows the upload widget via ``get_pdf_text``. When that returns a
    truthy value, it is handed to ``build_vector_store`` while a spinner is
    displayed.
    """
    st.title("PDF Upload")
    with st.container():
        chunks = get_pdf_text()
        if not chunks:
            # Nothing uploaded yet, or nothing came back from the PDF.
            return
        with st.spinner("Loading PDF ..."):
            build_vector_store(chunks)
def ask(qa, query):
    """Run ``query`` through ``qa`` and report what the call cost.

    Args:
        qa: Callable invoked as ``qa(query)``.
        query: The question to pass to ``qa``.

    Returns:
        A ``(answer, total_cost)`` tuple, where ``answer`` is whatever
        ``qa`` returned and ``total_cost`` is read from the OpenAI callback
        that was active during the call.
    """
    with get_openai_callback() as usage:
        # Per the original note, the result carries
        # query / result / source_documents.
        response = qa(query)
    return response, usage.total_cost
def get_pdf_text():
    """Show a PDF uploader and split the uploaded document into chunks.

    Returns:
        The result of ``split_text`` on the text of all pages joined with
        blank lines, or ``None`` when no file has been uploaded.
    """
    uploaded_file = st.file_uploader(
        label="Upload your PDF here😇",
        type="pdf",  # extension allowed for upload (several can be set)
    )
    if not uploaded_file:
        return None

    reader = PdfReader(uploaded_file)
    page_texts = [page.extract_text() for page in reader.pages]
    full_text = "\n\n".join(page_texts)

    # The right chunk size depends on the PDFs being queried, so it needs
    # tuning: too large and an answer cannot draw on several different
    # places in the document; too small and a single chunk does not hold
    # enough context.
    splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
        model_name="text-embedding-ada-002",
        chunk_size=500,
        chunk_overlap=0,
    )
    return splitter.split_text(full_text)