-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
71 additions
and
33 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,7 +16,6 @@ | |
import os | ||
from langchain.chains.combine_documents import create_stuff_documents_chain | ||
from langchain.chains import create_retrieval_chain | ||
|
||
from langchain_core.output_parsers import StrOutputParser | ||
from langchain_core.runnables import RunnablePassthrough | ||
from langchain.prompts import PromptTemplate | ||
|
@@ -74,26 +73,8 @@ | |
user_input = right_column.text_input("Name") | ||
openai_api_key = os.environ["OPENAI_API_KEY"] | ||
|
||
## DATABASE | ||
|
||
#COLLECTION_NAME = "strolr_test" | ||
#CONNECTION_STRING = PGVector.connection_string_from_db_params( | ||
# driver=os.environ.get("PGVECTOR_DRIVER", "psycopg2"), | ||
# host=os.environ.get("PGVECTOR_HOST", "vectordb.cfowaqqqovp0.us-east-2.rds.amazonaws.com"), | ||
# port=int(os.environ.get("PGVECTOR_PORT", "5432")), | ||
# database=os.environ.get("PGVECTOR_DATABASE", "postgres"), | ||
# user=os.environ.get("PGVECTOR_USER", "postgres"), | ||
# password=os.environ.get("PGVECTOR_PASSWORD", "temporary"), | ||
#) | ||
|
||
#conn = psycopg2.connect( | ||
# host="vectordb.cfowaqqqovp0.us-east-2.rds.amazonaws.com", | ||
# database="postgres", | ||
# user="postgres", | ||
# password="temporary") | ||
|
||
|
||
# Create a string for downloadable chat history | ||
# Create a string for downloadable CHAT HISTORY | ||
if user_input != '': | ||
chat_hist_download = user_input + '\'s chat history on ' + str(today) + '\n' | ||
username_hist = user_input | ||
|
@@ -132,15 +113,12 @@ def format_response(responses): | |
|
||
|
||
@st.cache_resource | ||
#def load_chain_with_sources(): | ||
#CHAIN | ||
def load_chain_with_sources(): | ||
|
||
embeddings = OpenAIEmbeddings() | ||
# store = PGVector( | ||
# collection_name=COLLECTION_NAME, | ||
# connection_string=CONNECTION_STRING, | ||
# embedding_function=embeddings, | ||
# ) | ||
|
||
# CONNECT TO RDS | ||
connection = "postgresql+psycopg://langchain:[email protected]:5432/postgres" | ||
collection_name = "strolr_docs" | ||
store = PGVector( | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,10 +5,16 @@ | |
from langchain_postgres.vectorstores import PGVector | ||
from langchain_openai import OpenAIEmbeddings | ||
import pickle | ||
from langchain_core.prompts import ChatPromptTemplate | ||
from langchain_openai import ChatOpenAI | ||
|
||
from langchain_openai import OpenAI | ||
from openai import OpenAI | ||
from langchain_core.documents import Document | ||
from langchain_postgres import PGVector | ||
from langchain_postgres.vectorstores import PGVector | ||
|
||
from langchain.chains.combine_documents import create_stuff_documents_chain | ||
from langchain.chains import create_retrieval_chain | ||
connection = "postgresql+psycopg://langchain:[email protected]:5432/postgres" | ||
collection_name = "strolr_docs" | ||
embeddings = OpenAIEmbeddings(model="text-embedding-3-small") | ||
|
@@ -21,11 +27,65 @@ | |
) | ||
#query = "What should I do if I have a post-partum depression?" | ||
query = "Is it safe for my unborn baby if I eat raw fish during pregnancy?" | ||
similar = vector_store.similarity_search_with_score(query, k=3) | ||
vector = embeddings.embed_query(query) | ||
#similar = vector_store.similarity_search_with_score(query, k=3) | ||
|
||
def load_chain_with_sources(): | ||
|
||
embeddings = OpenAIEmbeddings() | ||
|
||
# CONNECT TO RDS | ||
connection = "postgresql+psycopg://langchain:[email protected]:5432/postgres" | ||
collection_name = "strolr_docs" | ||
store = PGVector( | ||
embeddings=embeddings, | ||
collection_name=collection_name, | ||
connection=connection, | ||
use_jsonb=True,) | ||
retriever = store.as_retriever(search_type="similarity_score_threshold", search_kwargs = {"k":3, "score_threshold":0.8}) | ||
llm = ChatOpenAI(temperature = 0.8, model = "gpt-4o-mini") | ||
|
||
|
||
|
||
# Create memory 'chat_history' | ||
#memory = ConversationBufferMemory(memory_key="chat_history", output_key='answer', return_messages = True) | ||
#memory = ConversationBufferWindowMemory(k=1, memory_key="chat_history", output_key='answer', return_messages = True) | ||
|
||
# Create system prompt | ||
template = """ | ||
You are acting as a friendly clinician who is speaking to a patient. | ||
Do not say you are an AI. Don't say you're a clinician ever to the user. | ||
The patient is looking for information related to pregnancy. | ||
This patient has below a proficient health literacy level based on the National Assessment of Adult Literacy. Please adjust your response accordingly. | ||
This patient reads at a 6th grade reading level. Please adjust your response accordingly. | ||
Only provide the answer to questions you can find answers to in the database. If the information is not in the database, just apologize and say that you do not know the answer. | ||
Never provide resources if they are not relevant to the user's question. If applicable, highlight the text you referenced from the original source. If no sources are relevant for a user's question, never include any resources in your response. | ||
Don't try to make up an answer. | ||
Never give a response in any language besides the English language even if the user requests it. | ||
If the question is not related to pregnancy or childcare, politely inform them that you are tuned to only answer questions about pregnancy and childcare. | ||
If the answer is not in the {context}, say that you don't know in a kind way or give them a suggestion on a different question to ask. | ||
Do your best to understand typos, casing, and framing of questions. | ||
Do not return sources if you responded with I don't know. | ||
{context} | ||
""" | ||
|
||
# Create the Conversational Chain | ||
prompt = ChatPromptTemplate.from_messages( | ||
[ | ||
("system", template), | ||
("human", "{input}"), | ||
] | ||
) | ||
question_answer_chain = create_stuff_documents_chain(llm, prompt) | ||
# Set up the RAG chain | ||
chain = create_retrieval_chain(retriever, question_answer_chain) | ||
|
||
# Invoke the RAG chain with the question | ||
return chain | ||
|
||
|
||
for doc in similar: | ||
print(doc) | ||
print('\n') | ||
chain = load_chain_with_sources() | ||
|
||
similar[0][0] | ||
formatted_query = {'input': query} | ||
result = chain.invoke(formatted_query) | ||
print(result) |