# worker_huggingFace.py
import os
import torch
from langchain.prompts import PromptTemplate  # only needed for the optional prompt template below
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.llms import HuggingFaceHub

# Check for GPU availability and set the appropriate device for computation.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

# Global variables shared across the functions below
conversation_retrieval_chain = None
chat_history = []
llm_hub = None
embeddings = None

# Function to initialize the language model and its embeddings
def init_llm():
    global llm_hub, embeddings
    # Set the Hugging Face Hub API token (replace the placeholder with your own key).
    os.environ["HUGGINGFACEHUB_API_TOKEN"] = "YOUR API KEY"
    # Repo ID of the model to load from the Hugging Face Hub
    model_id = "tiiuae/falcon-7b-instruct"
    # Load the model through the HuggingFaceHub wrapper
    llm_hub = HuggingFaceHub(
        repo_id=model_id,
        model_kwargs={"temperature": 0.1, "max_new_tokens": 600, "max_length": 600},
    )
    # Initialize embeddings using a pre-trained model to represent the text data.
    embeddings = HuggingFaceInstructEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2", model_kwargs={"device": DEVICE}
    )
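    # Note: all-MiniLM-L6-v2 is a plain sentence-transformers model, so the
    # simpler HuggingFaceEmbeddings wrapper is an equally valid choice here.
    # A commented sketch of that alternative (an assumption, not what this
    # file uses):
    # from langchain.embeddings import HuggingFaceEmbeddings
    # embeddings = HuggingFaceEmbeddings(
    #     model_name="sentence-transformers/all-MiniLM-L6-v2",
    #     model_kwargs={"device": DEVICE},
    # )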

# Function to process a PDF document
def process_document(document_path):
    global conversation_retrieval_chain
    # Load the document
    loader = PyPDFLoader(document_path)
    documents = loader.load()
    # Split the document into chunks
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
    texts = text_splitter.split_documents(documents)
    # Create an embeddings database using Chroma from the split text chunks.
    db = Chroma.from_documents(texts, embedding=embeddings)
    # Build the QA chain, which uses the LLM and the retriever to answer questions.
    # By default, the vector store retriever uses similarity search. If the
    # underlying vector store supports maximum marginal relevance (MMR) search,
    # you can specify that as the search type (search_type="mmr"). You can also
    # pass search kwargs such as k, which controls how many retrieved chunks
    # are sent to the LLM.
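    # A minimal sketch of the optional prompt template mentioned below (an
    # assumption, not part of the original chain): the "stuff" chain fills in
    # {context} with the retrieved chunks and {question} with the user input.
    # Uncomment this block together with the chain_type_kwargs line below.
    # prompt = PromptTemplate(
    #     input_variables=["context", "question"],
    #     template=(
    #         "Use the following context to answer the question.\n"
    #         "Context: {context}\n"
    #         "Question: {question}\n"
    #         "Answer:"
    #     ),
    # )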
    conversation_retrieval_chain = RetrievalQA.from_chain_type(
        llm=llm_hub,
        chain_type="stuff",
        retriever=db.as_retriever(search_type="mmr", search_kwargs={"k": 6, "lambda_mult": 0.25}),
        return_source_documents=False,
        input_key="question",
        # chain_type_kwargs={"prompt": prompt},  # uncomment if you are using the prompt template above
    )

# Function to process a user prompt
def process_prompt(prompt):
    global conversation_retrieval_chain
    global chat_history
    # Query the chain. RetrievalQA only consumes the "question" key; the chat
    # history is tracked here for the application layer.
    output = conversation_retrieval_chain({"question": prompt, "chat_history": chat_history})
    answer = output["result"]
    # Update the chat history
    chat_history.append((prompt, answer))
    # Return the model's response
    return answer

# Initialize the language model
init_llm()
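
# A minimal usage sketch (an assumption, not part of the original module).
# The PDF path and the question below are hypothetical placeholders.
if __name__ == "__main__":
    process_document("example.pdf")  # hypothetical path; point this at a real PDF
    print(process_prompt("What is this document about?"))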