worker.py
import os
import torch
from langchain import PromptTemplate
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.llms import HuggingFaceHub
from ibm_watson_machine_learning.foundation_models.extensions.langchain import WatsonxLLM
from ibm_watson_machine_learning.foundation_models.utils.enums import ModelTypes, DecodingMethods
from ibm_watson_machine_learning.metanames import GenTextParamsMetaNames as GenParams
from ibm_watson_machine_learning.foundation_models import Model
# Check for GPU availability and set the appropriate device for computation.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
# Global variables
conversation_retrieval_chain = None
chat_history = []
llm_hub = None
embeddings = None
Watsonx_API = "Your WatsonX API"
Project_id = "Your Project ID"
# Function to initialize the language model and its embeddings
def init_llm():
    global llm_hub, embeddings

    params = {
        GenParams.MAX_NEW_TOKENS: 250,   # The maximum number of tokens that the model can generate in a single run.
        GenParams.MIN_NEW_TOKENS: 1,     # The minimum number of tokens that the model should generate in a single run.
        GenParams.DECODING_METHOD: DecodingMethods.SAMPLE,  # The method used by the model for decoding/generating new tokens. In this case, it uses the sampling method.
        GenParams.TEMPERATURE: 0.1,      # Controls the randomness of token generation. A lower value makes the generation more deterministic, while a higher value introduces more randomness.
        GenParams.TOP_K: 50,             # Restricts token generation to the K most likely tokens at each step, which can help to focus the generation and avoid irrelevant tokens.
        GenParams.TOP_P: 1               # Nucleus sampling: restricts generation to the subset of tokens with a cumulative probability of at most P, balancing diversity and quality of the generated text.
    }

    credentials = {
        'url': "https://us-south.ml.cloud.ibm.com",
        'apikey': Watsonx_API
    }

    LLAMA2_model = Model(
        model_id='meta-llama/llama-2-70b-chat',
        credentials=credentials,
        params=params,
        project_id=Project_id
    )

    llm_hub = WatsonxLLM(model=LLAMA2_model)
    # Initialize embeddings using a pre-trained model to represent the text data.
    # NOTE: the model name below is the class default and serves as an example; substitute your preferred embedding model.
    embeddings = HuggingFaceInstructEmbeddings(
        model_name="hkunlp/instructor-large",
        model_kwargs={"device": DEVICE}
    )
# Function to process a PDF document
def process_document(document_path):
    global conversation_retrieval_chain

    # Load the document
    loader = PyPDFLoader(document_path)
    documents = loader.load()

    # Split the document into chunks of 1024 characters with an overlap of 64 characters
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
    texts = text_splitter.split_documents(documents)

    # Create an embeddings database using Chroma from the split text chunks.
    db = Chroma.from_documents(texts, embedding=embeddings)

    # Build the QA chain, which utilizes the LLM and retriever for answering questions.
    conversation_retrieval_chain = RetrievalQA.from_chain_type(
        llm=llm_hub,
        chain_type="stuff",
        retriever=db.as_retriever(search_type="mmr", search_kwargs={'k': 6, 'lambda_mult': 0.25}),
        return_source_documents=False
    )
# Function to process a user prompt
def process_prompt(prompt):
    global conversation_retrieval_chain
    global chat_history

    # Query the retrieval chain with the user's prompt (RetrievalQA expects a "query" key)
    output = conversation_retrieval_chain({"query": prompt})
    answer = output["result"]

    # Update the chat history with the prompt and the bot's response
    chat_history.append((prompt, answer))

    # Return the model's response
    return answer
# Initialize the language model
init_llm()
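
# A minimal usage sketch: the PDF path and the question below are placeholders used only to
# illustrate the expected call order (process_document to index a file, then process_prompt
# to ask about it); the module-level init_llm() call above has already set up the LLM.
if __name__ == "__main__":
    process_document("example.pdf")  # placeholder path to a local PDF
    print(process_prompt("What is this document about?"))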