Merge pull request #25 from FloRul/feature/llama-index-migration
Feature/llama index migration
FloRul authored Feb 13, 2024
2 parents 67cc69e + a592a5b commit 1d415b0
Showing 3 changed files with 72 additions and 140 deletions.
134 changes: 50 additions & 84 deletions lambdas/inference/src/index.py
@@ -1,90 +1,71 @@
 import json
 import os
 import boto3
 from botocore.exceptions import ClientError
 from retrieval import Retrieval
 from history import History


-def get_secret():
-    try:
-        response = boto3.client("secretsmanager").get_secret_value(
-            SecretId=os.environ.get("PGVECTOR_PASSWORD_SECRET_NAME")
-        )
-        return response["SecretString"]
-    except ClientError as e:
-        raise e
-
-
-headers = {
+HEADERS = {
     "Access-Control-Allow-Origin": "*",
     "Access-Control-Allow-Methods": "GET, POST, OPTIONS",
     "Access-Control-Allow-Headers": "*",
 }

-PGVECTOR_DRIVER = os.environ.get("PGVECTOR_DRIVER", "psycopg2")
-PGVECTOR_HOST = os.environ.get("PGVECTOR_HOST", "localhost")
-PGVECTOR_PORT = int(os.environ.get("PGVECTOR_PORT", 5432))
-PGVECTOR_DATABASE = os.environ.get("PGVECTOR_DATABASE", "postgres")
-PGVECTOR_USER = os.environ.get("PGVECTOR_USER", "postgres")
-PGVECTOR_PASSWORD = get_secret()
-
-RELEVANCE_TRESHOLD = os.environ.get("RELEVANCE_TRESHOLD", 0.5)
-
-MODEL_ID = os.environ.get("MODEL_ID", "anthropic.claude-instant-v1")
-ACCEPT = "application/json"
-CONTENT_TYPE = "application/json"
+ENV_VARS = {
+    "relevance_treshold": os.environ.get("RELEVANCE_TRESHOLD", 0.5),
+    "model_id": os.environ.get("MODEL_ID", "anthropic.claude-instant-v1"),
+    "system_prompt": os.environ.get(
+        "SYSTEM_PROMPT", "Answer in four to five sentences.Answer in french."
+    ),
+    "enable_history": int(os.environ.get("ENABLE_HISTORY", 1)),
+    "enable_retrieval": int(os.environ.get("ENABLE_RETRIEVAL", 1)),
+    "max_tokens": int(os.environ.get("MAX_TOKENS", 100)),
+    "enable_inference": int(os.environ.get("ENABLE_INFERENCE", 1)),
+    "top_k": int(os.environ.get("TOP_K", 10)),
+    "top_p": float(os.environ.get("TOP_P", 0.9)),
+    "temperature": float(os.environ.get("TEMPERATURE", 0.3)),
+}


 def prepare_prompt(query: str, docs: list, history: list):
-    try:
-        system_prompt = os.environ.get(
-            "SYSTEM_PROMPT",
-            "Answer in four to five sentences.Answer in french.",
-        )
-        final_prompt = "{}{}\n\nAssistant:"
-
-        basic_prompt = (
-            f"""\n\nHuman: The user sent the following message : \"{query}\"."""
-        )
-
-        if len(docs) > 0:
-            docs_context = ".\n".join(map(lambda x: x.page_content, docs))
-            document_prompt = f"""Here is a set of quotes between <quotes></quotes> XML tags to help you answer: <quotes>{docs_context}</quotes>."""
-        if len(docs) == 0:
-            document_prompt = f"""I could not find any relevant quotes to help you answer the user's query."""
-
-        basic_prompt = f"""{basic_prompt}\n{document_prompt}"""
-
-        if len(history) > 0:
-            history_context = ".\n".join(
-                map(
-                    lambda x: f"""Human:{x['HumanMessage']}\nAssistant:{x['AssistantMessage']}""",
-                    history,
-                )
-            )
-            history_prompt = f"""Here is the history of the previous messages history between <history></history> XML tags: <history>{history_context}</history>."""
-            basic_prompt = f"""{basic_prompt}\n{history_prompt}"""
-
-        final_prompt = final_prompt.format(system_prompt, basic_prompt)
-        return final_prompt
-    except Exception as e:
-        print(f"Error while preparing prompt : {e}")
-        raise e
+    basic_prompt = f'\n\nHuman: The user sent the following message : "{query}".'
+    document_prompt = prepare_document_prompt(docs)
+    history_prompt = prepare_history_prompt(history)
+    final_prompt = f"{ENV_VARS['system_prompt']}{basic_prompt}\n{document_prompt}\n{history_prompt}\n\nAssistant:"
+    return final_prompt
+
+
+def prepare_document_prompt(docs):
+    if docs:
+        docs_context = ".\n".join(doc.page_content for doc in docs)
+        return f"Here is a set of quotes between <quotes></quotes> XML tags to help you answer: <quotes>{docs_context}</quotes>."
+    return "I could not find any relevant quotes to help you answer the user's query."
+
+
+def prepare_history_prompt(history):
+    if history:
+        history_context = ".\n".join(
+            f"Human:{x['HumanMessage']}\nAssistant:{x['AssistantMessage']}"
+            for x in history
+        )
+        return f"Here is the history of the previous messages history between <history></history> XML tags: <history>{history_context}</history>."
+    return ""


-def invoke_model(prompt: str, max_tokens: int, temperature: float, top_p: float):
+def invoke_model(prompt: str):
     body = json.dumps(
         {
             "prompt": prompt,
-            "max_tokens_to_sample": max_tokens,
-            "temperature": temperature,
-            "top_p": top_p,
+            "max_tokens_to_sample": ENV_VARS["max_tokens"],
+            "temperature": ENV_VARS["temperature"],
+            "top_p": ENV_VARS["top_p"],
         }
     )
     try:
         response = boto3.client("bedrock-runtime").invoke_model(
-            body=body, modelId=MODEL_ID, accept=ACCEPT, contentType=CONTENT_TYPE
+            body=body,
+            modelId=ENV_VARS["model_id"],
+            accept="application/json",
+            contentType="application/json",
         )
         body = response["body"].read().decode("utf-8")
         json_body = json.loads(body)
@@ -96,15 +77,6 @@ def invoke_model(prompt: str, max_tokens: int, temperature: float, top_p: float)

 def lambda_handler(event, context):
     response = "this is a dummy response"
-
-    enable_history = int(os.environ.get("ENABLE_HISTORY", 1))
-    enable_retrieval = int(os.environ.get("ENABLE_RETRIEVAL", 1))
-    max_tokens_to_sample = int(os.environ.get("MAX_TOKENS", 100))
-    enable_inference = int(os.environ.get("ENABLE_INFERENCE", 1))
-    top_k = int(os.environ.get("TOP_K", 10))
-    top_p = float(os.environ.get("TOP_P", 0.9))
-    temperature = float(os.environ.get("TEMPERATURE", 0.3))
-
     history = History(event["queryStringParameters"]["sessionId"])
     embedding_collection_name = event["queryStringParameters"]["collectionName"]

@@ -113,42 +85,36 @@ def lambda_handler(event, context):
         docs = []
         chat_history = []

-        if enable_inference != 0:
-            if enable_retrieval != 0:
+        if ENV_VARS["enable_inference"]:
+            if ENV_VARS["enable_retrieval"]:
                 retrieval = Retrieval(
-                    driver=PGVECTOR_DRIVER,
-                    host=PGVECTOR_HOST,
-                    port=PGVECTOR_PORT,
-                    database=PGVECTOR_DATABASE,
-                    user=PGVECTOR_USER,
-                    password=PGVECTOR_PASSWORD,
                     collection_name=embedding_collection_name,
-                    relevance_treshold=RELEVANCE_TRESHOLD,
+                    relevance_treshold=ENV_VARS["relevance_treshold"],
                 )
-                docs = retrieval.fetch_documents(query=query, top_k=top_k)
+                docs = retrieval.fetch_documents(query=query, top_k=ENV_VARS["top_k"])

-            if enable_history != 0:
+            if ENV_VARS["enable_history"]:
                 chat_history = json.loads(history.get(limit=10))

             # prepare the prompt
             prompt = prepare_prompt(query, docs, chat_history)
-            response = invoke_model(prompt, max_tokens_to_sample, temperature, top_p)
+            response = invoke_model(prompt)

-            if enable_history != 0:
+            if ENV_VARS["enable_history"]:
                 history.add(
                     human_message=query, assistant_message=response, prompt=prompt
                 )

         return {
             "statusCode": 200,
             "body": response,
-            "headers": headers,
+            "headers": HEADERS,
             "isBase64Encoded": False,
         }
     except Exception as e:
         print(e)
         return {
             "statusCode": 500,
             "body": json.dumps(e),
-            "headers": headers,
+            "headers": HEADERS,
         }
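
The net effect of these hunks: configuration is read once into ENV_VARS, and prompt construction is split into three helpers instead of one try/except block. Below is a minimal, dependency-free sketch of what the refactored helpers produce; FakeDoc and the sample query, document, and history are hypothetical stand-ins for the LangChain documents and stored history items the real handler receives.

# Hypothetical, dependency-free sketch of the refactored prompt assembly.
# FakeDoc and the sample data are stand-ins; only the logic mirrors the diff.
import os
from dataclasses import dataclass


@dataclass
class FakeDoc:
    page_content: str


SYSTEM_PROMPT = os.environ.get(
    "SYSTEM_PROMPT", "Answer in four to five sentences.Answer in french."
)


def prepare_document_prompt(docs):
    if docs:
        docs_context = ".\n".join(doc.page_content for doc in docs)
        return f"Here is a set of quotes between <quotes></quotes> XML tags to help you answer: <quotes>{docs_context}</quotes>."
    return "I could not find any relevant quotes to help you answer the user's query."


def prepare_history_prompt(history):
    if history:
        history_context = ".\n".join(
            f"Human:{x['HumanMessage']}\nAssistant:{x['AssistantMessage']}"
            for x in history
        )
        return f"Here is the history of the previous messages history between <history></history> XML tags: <history>{history_context}</history>."
    return ""


def prepare_prompt(query, docs, history):
    basic_prompt = f'\n\nHuman: The user sent the following message : "{query}".'
    return f"{SYSTEM_PROMPT}{basic_prompt}\n{prepare_document_prompt(docs)}\n{prepare_history_prompt(history)}\n\nAssistant:"


if __name__ == "__main__":
    docs = [FakeDoc(page_content="Our SLA is 99.9% uptime.")]
    history = [{"HumanMessage": "Bonjour", "AssistantMessage": "Bonjour !"}]
    # Prints the assembled prompt, wrapped in the Human/Assistant envelope.
    print(prepare_prompt("Quel est votre SLA ?", docs, history))

The \n\nHuman: ... \n\nAssistant: envelope is preserved because Anthropic's text-completion models on Bedrock expect prompts in that shape.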
2 changes: 1 addition & 1 deletion lambdas/inference/src/requirements.txt
@@ -3,4 +3,4 @@ langchain-community
 langchain
 psycopg2-binary
 pgvector
-boto3>=1.28.85
+boto3
76 changes: 21 additions & 55 deletions lambdas/inference/src/retrieval.py
@@ -7,48 +7,44 @@
 import boto3
 from langchain_community.embeddings import BedrockEmbeddings

-MODEL_ID = "anthropic.claude-instant-v1"
-ACCEPT = "application/json"
-CONTENT_TYPE = "application/json"
+def get_secret():
+    try:
+        response = boto3.client("secretsmanager").get_secret_value(
+            SecretId=os.environ.get("PGVECTOR_PASSWORD_SECRET_NAME")
+        )
+        return response["SecretString"]
+    except ClientError as e:
+        raise e


 class Retrieval:
     def __init__(
         self,
-        driver,
-        host,
-        port,
-        database,
-        user,
-        password,
         collection_name,
         relevance_treshold,
     ):
         self._relevance_treshold = relevance_treshold
+        PGVECTOR_DRIVER = os.environ.get("PGVECTOR_DRIVER", "psycopg2")
+        PGVECTOR_HOST = os.environ.get("PGVECTOR_HOST", "localhost")
+        PGVECTOR_PORT = int(os.environ.get("PGVECTOR_PORT", 5432))
+        PGVECTOR_DATABASE = os.environ.get("PGVECTOR_DATABASE", "postgres")
+        PGVECTOR_USER = os.environ.get("PGVECTOR_USER", "postgres")
+        PGVECTOR_PASSWORD = get_secret()
         self._vector_store = PGVector(
             connection_string=PGVector.connection_string_from_db_params(
-                driver=driver,
-                host=host,
-                port=port,
-                database=database,
-                user=user,
-                password=password,
+                driver=PGVECTOR_DRIVER,
+                host=PGVECTOR_HOST,
+                port=PGVECTOR_PORT,
+                database=PGVECTOR_DATABASE,
+                user=PGVECTOR_USER,
+                password=PGVECTOR_PASSWORD,
             ),
             collection_name=collection_name,
             embedding_function=BedrockEmbeddings(
                 client=boto3.client("bedrock-runtime")
             ),
         )

-    def _get_secret(self):
-        try:
-            response = boto3.client("secretsmanager").get_secret_value(
-                SecretId=os.environ.get("PGVECTOR_PASSWORD_SECRET_NAME")
-            )
-            return response["SecretString"]
-        except ClientError as e:
-            raise e
-
     def fetch_documents(self, query: str, top_k: int = 10):
         try:
             docs = self._vector_store.similarity_search_with_relevance_scores(
@@ -58,34 +54,4 @@ def fetch_documents(self, query: str, top_k: int = 10):
             return [x[0] for x in docs if x[1] > self._relevance_treshold]
         except Exception as e:
             print(f"Error while retrieving documents : {e}")
-            raise e
-
-
-# # Retrieve more documents with higher diversity
-# # Useful if your dataset has many similar documents
-# vectorstore.as_retriever(
-#     search_type="mmr",
-#     search_kwargs={"k": 6, "lambda_mult": 0.25}
-# )
-
-# # Fetch more documents for the MMR algorithm to consider
-# # But only return the top 5
-# vectorstore.as_retriever(
-#     search_type="mmr",
-#     search_kwargs={"k": 5, "fetch_k": 50}
-# )
-
-# # Only retrieve documents that have a relevance score
-# # Above a certain threshold
-# vectorstore.as_retriever(
-#     search_type="similarity_score_threshold",
-#     search_kwargs={"score_threshold": 0.8}
-# )
-
-# # Only get the single most similar document from the dataset
-# vectorstore.as_retriever(search_kwargs={"k": 1})
-
-# # Use a filter to only retrieve documents from a specific paper
-# docsearch.as_retriever(
-#     search_kwargs={"filter": {"paper_title": "GPT-4 Technical Report"}}
-# )
+            raise e
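
After this refactor, callers construct Retrieval from just a collection name and a relevance threshold; connection settings come from PGVECTOR_* environment variables and the database password from Secrets Manager. A caller-side sketch under assumed values (the host, secret name, and collection below are illustrative; running it requires AWS credentials, a reachable pgvector-enabled Postgres, and the working directory lambdas/inference/src so the import resolves):

# Hypothetical caller-side sketch; environment values are examples only.
import os

os.environ.setdefault("PGVECTOR_HOST", "db.example.internal")  # assumed host
os.environ.setdefault("PGVECTOR_PORT", "5432")
os.environ.setdefault("PGVECTOR_PASSWORD_SECRET_NAME", "pgvector-password")  # assumed secret name

from retrieval import Retrieval

# Only the collection and the relevance threshold remain caller-supplied.
retrieval = Retrieval(
    collection_name="my-collection",  # hypothetical collection
    relevance_treshold=0.5,
)
docs = retrieval.fetch_documents(query="Quel est votre SLA ?", top_k=5)
for doc in docs:
    print(doc.page_content)

A side effect worth noting: get_secret now runs when a Retrieval instance is created rather than at import of index.py, so invocations with ENABLE_RETRIEVAL=0 never call Secrets Manager.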
