Skip to content

Commit

Permalink
Merge pull request #34 from aws-samples/development
Browse files Browse the repository at this point in the history
Development
  • Loading branch information
MithilShah authored Aug 17, 2023
2 parents be779f4 + 2b33fb6 commit 409a3a9
Show file tree
Hide file tree
Showing 15 changed files with 372 additions and 28 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,7 @@ dmypy.json
.DS_Store

# vs code
.vscode
.vscode

# venv files
env*
38 changes: 28 additions & 10 deletions kendra_retriever_samples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,28 +31,46 @@ conda env create -f environment.yml
```

## Running samples
Ensure that the environment variables are set for the aws region, kendra index id and the provider/model used by the sample.
For example, for running the `kendra_chat_flan_xl.py` sample, these environment variables must be set: AWS_REGION, KENDRA_INDEX_ID
and FLAN_XL_ENDPOINT.
Before you run the sample, you need to deploy a Large Language Model (or get an API key if you using Anthropic or OPENAI). The samples in this repository have been tested on models deployed using SageMaker Jumpstart. The model id for the LLMS are specified in the table below.


| Model name | env var name | Jumpstart model id | streamlit provider name |
| -----------| -------- | ------------------ | ----------------- |
| Flan XL | FLAN_XL_ENDPOINT | huggingface-text2text-flan-t5-xl | flanxl |
| Flan XXL | FLAN_XXL_ENDPOINT | huggingface-text2text-flan-t5-xxl | flanxxl |
| Falcon 40B instruct | FALCON_40B_ENDPOINT | huggingface-llm-falcon-40b-instruct-bf16 | falcon40b |
| Llama2 70B instruct | LLAMA_2_ENDPOINT | meta-textgeneration-llama-2-70b-f | llama2 |


after deploying the LLM, set up environment variables for kendra id, aws_region and the endpoint name (or the API key for an external provider)

For example, for running the `kendra_chat_flan_xl.py` sample, these environment variables must be set: AWS_REGION, KENDRA_INDEX_ID and FLAN_XL_ENDPOINT.

You can use commands as below to set the environment variables. Only set the environment variable for the provider that you are using. For example, if you are using Flan-xl only set the FLAN_XXL_ENDPOINT. There is no need to set the other Endpoints and keys.

You can use commands as below to set the environment variables.
```bash
export AWS_REGION="<YOUR-AWS-REGION>"
export KENDRA_INDEX_ID="<YOUR-KENDRA-INDEX-ID>"
export FLAN_XL_ENDPOINT="<YOUR-SAGEMAKER-ENDPOINT-FOR-FLAN-T-XL>"
export FLAN_XXL_ENDPOINT="<YOUR-SAGEMAKER-ENDPOINT-FOR-FLAN-T-XXL>"
export OPENAI_API_KEY="<YOUR-OPEN-AI-API-KEY>"
export ANTHROPIC_API_KEY="<YOUR-ANTHROPIC-API-KEY>"
export FLAN_XL_ENDPOINT="<YOUR-SAGEMAKER-ENDPOINT-FOR-FLAN-T-XL>" # only if you are using FLAN_XL
export FLAN_XXL_ENDPOINT="<YOUR-SAGEMAKER-ENDPOINT-FOR-FLAN-T-XXL>" # only if you are using FLAN_XXL
export FALCON_40B_ENDPOINT="<YOUR-SAGEMAKER-ENDPOINT-FOR-FALCON>" # only if you are using falcon as the endpoint
export LLAMA_2_ENDPOINT="<YOUR-SAGEMAKER-ENDPOINT-FOR-LLAMA2>" #only if you are using llama2 as the endpoint

export OPENAI_API_KEY="<YOUR-OPEN-AI-API-KEY>" # only if you are using OPENAI as the endpoint
export ANTHROPIC_API_KEY="<YOUR-ANTHROPIC-API-KEY>" # only if you are using Anthropic as the endpoint
```


### Running samples from the streamlit app
The samples directory is bundled with an `app.py` file that can be run as a web app using streamlit.

```bash
streamlit run app.py anthropic
streamlit run app.py llama2
```

The above command will run the `kendra_chat_anthropic` as the LLM chain. In order to run a different chain, pass a different provider, for example for running the `open_ai` chain run this command `streamlit run app.py openai`.
The above command will run the `kendra_chat_llama_2` as the LLM chain. In order to run a different chain, pass a different provider, for example for running the `open_ai` chain run this command `streamlit run app.py openai`. Use the column 'streamlit provider name' from the table above to find out the provider name



### Running samples from the command line
```bash
Expand Down
13 changes: 11 additions & 2 deletions kendra_retriever_samples/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import kendra_chat_flan_xl as flanxl
import kendra_chat_flan_xxl as flanxxl
import kendra_chat_open_ai as openai

import kendra_chat_falcon_40b as falcon40b
import kendra_chat_llama_2 as llama2

USER_ICON = "images/user-icon.png"
AI_ICON = "images/ai-icon.png"
Expand All @@ -15,7 +16,9 @@
'openai': 'Open AI',
'anthropic': 'Anthropic',
'flanxl': 'Flan XL',
'flanxxl': 'Flan XXL'
'flanxxl': 'Flan XXL',
'falcon40b': 'Falcon 40B',
'llama2' : 'Llama 2'
}

# Check if the user ID is already stored in the session state
Expand All @@ -42,6 +45,12 @@
elif (sys.argv[1] == 'openai'):
st.session_state['llm_app'] = openai
st.session_state['llm_chain'] = openai.build_chain()
elif (sys.argv[1] == 'falcon40b'):
st.session_state['llm_app'] = falcon40b
st.session_state['llm_chain'] = falcon40b.build_chain()
elif (sys.argv[1] == 'llama2'):
st.session_state['llm_app'] = llama2
st.session_state['llm_chain'] = llama2.build_chain()
else:
raise Exception("Unsupported LLM: ", sys.argv[1])
else:
Expand Down
2 changes: 1 addition & 1 deletion kendra_retriever_samples/kendra_chat_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def build_chain():

llm = Anthropic(temperature=0, anthropic_api_key=ANTHROPIC_API_KEY, max_tokens_to_sample = 512)

retriever = AmazonKendraRetriever(index_id=kendra_index_id)
retriever = AmazonKendraRetriever(index_id=kendra_index_id,region_name=region)

prompt_template = """
Expand Down
118 changes: 118 additions & 0 deletions kendra_retriever_samples/kendra_chat_falcon_40b.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
from langchain.retrievers import AmazonKendraRetriever
from langchain.chains import ConversationalRetrievalChain
from langchain import SagemakerEndpoint
from langchain.llms.sagemaker_endpoint import LLMContentHandler
from langchain.prompts import PromptTemplate
import sys
import json
import os

class bcolors:
HEADER = '\033[95m'
OKBLUE = '\033[94m'
OKCYAN = '\033[96m'
OKGREEN = '\033[92m'
WARNING = '\033[93m'
FAIL = '\033[91m'
ENDC = '\033[0m'
BOLD = '\033[1m'
UNDERLINE = '\033[4m'

MAX_HISTORY_LENGTH = 5

def build_chain():
region = os.environ["AWS_REGION"]
kendra_index_id = os.environ["KENDRA_INDEX_ID"]
endpoint_name = os.environ["FALCON_40B_ENDPOINT"]

class ContentHandler(LLMContentHandler):
content_type = "application/json"
accepts = "application/json"

def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
prompt = prompt[:1023]
input_str = json.dumps({"inputs": prompt, "parameters": model_kwargs})
print("input_str", input_str)
return input_str.encode('utf-8')

def transform_output(self, output: bytes) -> str:
response_json = json.loads(output.read().decode("utf-8"))
print(response_json)
return response_json[0]["generated_text"]

content_handler = ContentHandler()

llm=SagemakerEndpoint(
endpoint_name=endpoint_name,
region_name=region,
model_kwargs={
"temperature": 0.8,
"max_length": 10000,
"max_new_tokens": 512,
"do_sample": True,
"top_p": 0.9,
"repetition_penalty": 1.03,
"stop": ["\nUser:","<|endoftext|>","</s>"]
},
content_handler=content_handler
)

retriever = AmazonKendraRetriever(index_id=kendra_index_id,region_name=region)

prompt_template = """
The following is a friendly conversation between a human and an AI.
The AI is talkative and provides lots of specific details from its context.
If the AI does not know the answer to a question, it truthfully says it
does not know.
{context}
Instruction: Based on the above documents, provide a detailed answer for, {question} Answer "don't know"
if not present in the document.
Solution:"""
PROMPT = PromptTemplate(
template=prompt_template, input_variables=["context", "question"]
)

condense_qa_template = """
Given the following conversation and a follow up question, rephrase the follow up question
to be a standalone question.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""
standalone_question_prompt = PromptTemplate.from_template(condense_qa_template)

qa = ConversationalRetrievalChain.from_llm(
llm=llm,
retriever=retriever,
condense_question_prompt=standalone_question_prompt,
return_source_documents=True,
combine_docs_chain_kwargs={"prompt":PROMPT})
return qa

def run_chain(chain, prompt: str, history=[]):
return chain({"question": prompt, "chat_history": history})

if __name__ == "__main__":
chat_history = []
qa = build_chain()
print(bcolors.OKBLUE + "Hello! How can I help you?" + bcolors.ENDC)
print(bcolors.OKCYAN + "Ask a question, start a New search: or CTRL-D to exit." + bcolors.ENDC)
print(">", end=" ", flush=True)
for query in sys.stdin:
if (query.strip().lower().startswith("new search:")):
query = query.strip().lower().replace("new search:","")
chat_history = []
elif (len(chat_history) == MAX_HISTORY_LENGTH):
chat_history.pop(0)
result = run_chain(qa, query, chat_history)
chat_history.append((query, result["answer"]))
print(bcolors.OKGREEN + result['answer'] + bcolors.ENDC)
if 'source_documents' in result:
print(bcolors.OKGREEN + 'Sources:')
for d in result['source_documents']:
print(d.metadata['source'])
print(bcolors.ENDC)
print(bcolors.OKCYAN + "Ask a question, start a New search: or CTRL-D to exit." + bcolors.ENDC)
print(">", end=" ", flush=True)
print(bcolors.OKBLUE + "Bye" + bcolors.ENDC)
2 changes: 1 addition & 1 deletion kendra_retriever_samples/kendra_chat_flan_xl.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def transform_output(self, output: bytes) -> str:
content_handler=content_handler
)

retriever = AmazonKendraRetriever(index_id=kendra_index_id)
retriever = AmazonKendraRetriever(index_id=kendra_index_id,region_name=region)

prompt_template = """
The following is a friendly conversation between a human and an AI.
Expand Down
6 changes: 3 additions & 3 deletions kendra_retriever_samples/kendra_chat_flan_xxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@ class ContentHandler(LLMContentHandler):
accepts = "application/json"

def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
input_str = json.dumps({"inputs": prompt, "parameters": model_kwargs})
input_str = json.dumps({"text_inputs": prompt, **model_kwargs})
return input_str.encode('utf-8')

def transform_output(self, output: bytes) -> str:
response_json = json.loads(output.read().decode("utf-8"))
return response_json[0]["generated_text"]
return response_json["generated_texts"][0]

content_handler = ContentHandler()

Expand All @@ -46,7 +46,7 @@ def transform_output(self, output: bytes) -> str:
content_handler=content_handler
)

retriever = AmazonKendraRetriever(index_id=kendra_index_id)
retriever = AmazonKendraRetriever(index_id=kendra_index_id,region_name=region)

prompt_template = """
The following is a friendly conversation between a human and an AI.
Expand Down
118 changes: 118 additions & 0 deletions kendra_retriever_samples/kendra_chat_llama_2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
from langchain.retrievers import AmazonKendraRetriever
from langchain.chains import ConversationalRetrievalChain
from langchain.prompts import PromptTemplate
from langchain import SagemakerEndpoint
from langchain.llms.sagemaker_endpoint import LLMContentHandler
import sys
import json
import os

class bcolors:
HEADER = '\033[95m'
OKBLUE = '\033[94m'
OKCYAN = '\033[96m'
OKGREEN = '\033[92m'
WARNING = '\033[93m'
FAIL = '\033[91m'
ENDC = '\033[0m'
BOLD = '\033[1m'
UNDERLINE = '\033[4m'

MAX_HISTORY_LENGTH = 5

def build_chain():
region = os.environ["AWS_REGION"]
kendra_index_id = os.environ["KENDRA_INDEX_ID"]
endpoint_name = os.environ["LLAMA_2_ENDPOINT"]

class ContentHandler(LLMContentHandler):
content_type = "application/json"
accepts = "application/json"

def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
input_str = json.dumps({"inputs":
[[
#{"role": "system", "content": ""},
{"role": "user", "content": prompt},
]],
**model_kwargs
})
print(input_str)
return input_str.encode('utf-8')

def transform_output(self, output: bytes) -> str:
response_json = json.loads(output.read().decode("utf-8"))

return response_json[0]['generation']['content']

content_handler = ContentHandler()

llm=SagemakerEndpoint(
endpoint_name=endpoint_name,
region_name=region,
model_kwargs={"max_new_tokens": 1000, "top_p": 0.9,"temperature":0.6},
endpoint_kwargs={"CustomAttributes":"accept_eula=true"},
content_handler=content_handler,
)

retriever = AmazonKendraRetriever(index_id=kendra_index_id,region_name=region)

prompt_template = """
The following is a friendly conversation between a human and an AI.
The AI is talkative and provides lots of specific details from its context.
If the AI does not know the answer to a question, it truthfully says it
does not know.
{context}
Instruction: Based on the above documents, provide a detailed answer for, {question} Answer "don't know"
if not present in the document.
Solution:"""
PROMPT = PromptTemplate(
template=prompt_template, input_variables=["context", "question"],
)

condense_qa_template = """
Given the following conversation and a follow up question, rephrase the follow up question
to be a standalone question.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""
standalone_question_prompt = PromptTemplate.from_template(condense_qa_template)

qa = ConversationalRetrievalChain.from_llm(
llm=llm,
retriever=retriever,
condense_question_prompt=standalone_question_prompt,
return_source_documents=True,
combine_docs_chain_kwargs={"prompt":PROMPT},
)
return qa

def run_chain(chain, prompt: str, history=[]):
print(prompt)
return chain({"question": prompt, "chat_history": history})

if __name__ == "__main__":
chat_history = []
qa = build_chain()
print(bcolors.OKBLUE + "Hello! How can I help you?" + bcolors.ENDC)
print(bcolors.OKCYAN + "Ask a question, start a New search: or CTRL-D to exit." + bcolors.ENDC)
print(">", end=" ", flush=True)
for query in sys.stdin:
if (query.strip().lower().startswith("new search:")):
query = query.strip().lower().replace("new search:","")
chat_history = []
elif (len(chat_history) == MAX_HISTORY_LENGTH):
chat_history.pop(0)
result = run_chain(qa, query, chat_history)
chat_history.append((query, result["answer"]))
print(bcolors.OKGREEN + result['answer'] + bcolors.ENDC)
if 'source_documents' in result:
print(bcolors.OKGREEN + 'Sources:')
for d in result['source_documents']:
print(d.metadata['source'])
print(bcolors.ENDC)
print(bcolors.OKCYAN + "Ask a question, start a New search: or CTRL-D to exit." + bcolors.ENDC)
print(">", end=" ", flush=True)
print(bcolors.OKBLUE + "Bye" + bcolors.ENDC)
2 changes: 1 addition & 1 deletion kendra_retriever_samples/kendra_chat_open_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def build_chain():

llm = OpenAI(batch_size=5, temperature=0, max_tokens=300)

retriever = AmazonKendraRetriever(index_id=kendra_index_id)
retriever = AmazonKendraRetriever(index_id=kendra_index_id, region_name=region)

prompt_template = """
The following is a friendly conversation between a human and an AI.
Expand Down
Loading

0 comments on commit 409a3a9

Please sign in to comment.