# backend.py
# Forked from SeAH-Besteel-RAG/RAG_before_refactoring
import os
import json
import pandas as pd
import streamlit as st
from langchain.callbacks.base import BaseCallbackHandler
from langchain.text_splitter import RecursiveCharacterTextSplitter, NLTKTextSplitter
from langchain.docstore.document import Document
from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from tika import parser
# Retriever
from langchain.vectorstores.chroma import Chroma
from langchain.retrievers import EnsembleRetriever
from langchain.retrievers.bm25 import BM25Retriever
from chromadb.errors import InvalidDimensionException
# PDF preprocessing
from PDFProcess import PDFParser

def device_check():
    """Return 'cuda' if a GPU is available, otherwise 'cpu'."""
    import torch
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return device

# Clean up the Chroma collection after a run (temporary measure for now).
def cleanup_process():
    Chroma().delete_collection()
    st.markdown('===========END============')

def document_handler(uploaded_files):
    storage = []
    for file in uploaded_files:
        if file.name.endswith(".pdf"):
            parsed_pdf_file = PDFParser(file)
            parsed_pages = parsed_pdf_file.parse_pdf()
            # Create one Document object per parsed page.
            for page in parsed_pages:
                document = Document(page_content=page['text'], metadata={'source': file.name, 'page': page['page']})
                storage.append(document)
        # Other formats are parsed with Tika and wrapped in a single Document.
        elif file.name.endswith(('.docx', '.doc', '.xlsx', '.xls')):
            raw = parser.from_buffer(file, xmlContent=False)
            document = Document(page_content=raw['content'], metadata={'source': file.name})
            storage.append(document)
    return storage

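# Minimal usage sketch for document_handler(); the uploads are assumed to come
# from Streamlit's st.file_uploader(..., accept_multiple_files=True). Defined
# as a function so importing this module stays side-effect free.
def _demo_document_handler(uploaded_files):
    docs = document_handler(uploaded_files)
    # Each entry is a Document carrying 'source' (and 'page' for PDFs) metadata.
    st.write(f"Parsed {len(docs)} Document objects from {len(uploaded_files)} file(s)")
    return docs
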
def extract_text_from_file(uploaded_files):
    text_splitter = RecursiveCharacterTextSplitter(separators=['\n\n\n', '\n\n', '\n', ' ', '.', ''], length_function=len, add_start_index=True)
    storage = []
    # Priority filter (TODO): rank chunks based on what comes up in retrieval.
    for file in uploaded_files:
        raw = parser.from_buffer(file, xmlContent=False)
        documents = text_splitter.create_documents(texts=[raw['content']], metadatas=[{'source': file.name}])
        # create_documents returns a list, so extend rather than append.
        storage.extend(documents)
    return storage

def documentEnsembleRetreiver(files):
    # documents = extract_text_from_file(files)
    documents = document_handler(files)
    # Embed the documents and store them in Chroma.
    embedding_function = SentenceTransformerEmbeddings(
        model_name="BAAI/bge-base-en-v1.5", model_kwargs={'device': device_check()}, encode_kwargs={'normalize_embeddings': True}
    )
    # If an InvalidDimensionException is raised, delete the existing collection and retry.
    try:
        vectordb = Chroma.from_documents(documents=documents, embedding=embedding_function)
    except InvalidDimensionException:
        Chroma().delete_collection()
        vectordb = Chroma.from_documents(documents=documents, embedding=embedding_function)
    score_threshold = 0.2
    search_k = 2
    chroma_retriever = vectordb.as_retriever(search_type="similarity_score_threshold", search_kwargs={'k': search_k, 'score_threshold': score_threshold})
    # BM25 (keyword-based) retriever
    bm25_retriever = BM25Retriever.from_documents(documents=documents)
    bm25_retriever.k = 2
    # embedding_function = OpenAIEmbeddings()
    ensemble_retriever = EnsembleRetriever(retrievers=[bm25_retriever, chroma_retriever], weights=[0.4, 0.6])
    return ensemble_retriever

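# Hedged usage sketch: run a query through the ensemble retriever. The query
# string is a placeholder, and get_relevant_documents() is the standard
# retriever entry point for the langchain versions these imports target.
def _demo_ensemble_query(files, query="hardness requirement"):
    retriever = documentEnsembleRetreiver(files)
    hits = retriever.get_relevant_documents(query)
    for doc in hits:
        # Show where each hit came from plus a short preview.
        st.write(doc.metadata.get('source'), doc.page_content[:200])
    return hits
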
class StreamHandler(BaseCallbackHandler):
    def __init__(self, container: st.delta_generator.DeltaGenerator, initial_text: str = ''):
        self.container = container
        self.text = initial_text
        self.run_id_ignore_token = None
        self.previous_run_id = None

    def on_llm_start(self, serialized: dict, prompts: list, **kwargs):
        # Workaround to prevent showing the rephrased question as output.
        if prompts[0].startswith("Human"):
            self.run_id_ignore_token = kwargs.get("run_id")

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        current_run_id = kwargs.get("run_id", None)
        # If the current run_id differs from the previous one, a new question
        # has started, so insert a blank line as a separator.
        if self.previous_run_id and self.previous_run_id != current_run_id:
            self.text += "\n\n"
        if self.run_id_ignore_token == current_run_id:
            return
        self.text += token
        self.container.markdown(self.text)
        self.previous_run_id = current_run_id

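# Hedged sketch of how StreamHandler is typically wired up: pass it in the
# callbacks list of a streaming chat model. ChatOpenAI and its streaming flag
# are standard legacy-langchain APIs, but the model choice and the assumption
# that an OpenAI key is configured are ours.
def _demo_streaming_answer(question):
    from langchain.chat_models import ChatOpenAI
    # st.empty() gives the handler a container it can redraw on every token.
    handler = StreamHandler(st.empty())
    llm = ChatOpenAI(streaming=True, callbacks=[handler])
    return llm.predict(question)
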
# ======================= Review-item load/save functions =======================
## For: adding review items and editing the question sheet.
# Load the review-item file.
def load_data():
    with open("qr_dic.json", "r", encoding="utf-8") as file:
        data = json.load(file)
    return data

# Update the review-item file.
def save_data(data, filename="qr_dic.json"):
    with open(filename, "w", encoding="utf-8") as file:
        json.dump(data, file, indent=4)

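# Round-trip sketch for the two helpers above: load qr_dic.json, add one
# review item, and write the file back. The flat {item: question} layout is
# an assumption; adjust to the real structure of qr_dic.json.
def _demo_add_review_item(name="NewItem", question="What is the requirement?"):
    data = load_data()
    data[name] = question  # assumed layout: one question string per item
    save_data(data)
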
# =========================== Result-output functions ===========================
# Preprocess a single value.
def process_value(value):
    # Case: value is a dictionary with exactly one key-value pair.
    if isinstance(value, dict) and len(value) == 1:
        key, sub_value = list(value.items())[0]
        if isinstance(sub_value, str):
            return f"{key} {sub_value}"
    # Case: value is not a dictionary (or the pair's value is not a string).
    return value

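# Worked example for process_value(): a single-pair dict whose value is a
# string collapses to "key value"; anything else passes through unchanged.
# The sample inputs are illustrative only.
def _demo_process_value():
    assert process_value({"Hardness": "HRc 55~62"}) == "Hardness HRc 55~62"
    nested = {"Hardness": {"Min": "HRc 43"}}
    assert process_value(nested) == nested  # sub-value is a dict, so unchanged
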
# Return the full result and any separate table as DataFrames.
def total_req(message):
    # Preprocess the 'Specification' part of the message.
    message["Specification"] = process_value(message["Specification"])
    sub_df = pd.DataFrame()
    if isinstance(message["Specification"], dict):
        # Pull out only the Specification values and convert them to a DataFrame.
        # Pop 'Specification' explicitly; popitem() only worked when it
        # happened to be the last key in the dict.
        key, value = "Specification", message.pop("Specification")
        if len(value) == 1:
            sub_df = pd.DataFrame(value)
        else:
            sub_df = pd.DataFrame({key: value})
        message["Specification"] = "하단 표 참조"  # "See the table below."
    temp_df = pd.DataFrame([message])
    return temp_df, sub_df

data_2 = {
    "Name": "Hardness",
    "Reference": "DCF703-08(REV.0)",
    "Specification": {
        "Hardness": {
            "Min": "HRc 43 ↑",
            "Max": "HRc 55~62",
        }
    }
}

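# Quick local check of total_req() against the sample record above; run
# `python backend.py` to inspect the two DataFrames it returns.
if __name__ == "__main__":
    temp_df, sub_df = total_req(data_2)
    print(temp_df)
    print(sub_df)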