-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
98 lines (83 loc) · 2.69 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, HttpUrl
from typing import Optional
from langchain.document_loaders import WebBaseLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain_openai import ChatOpenAI
import os
import logging
# Logging: a module-scoped logger, with the root logger configured at INFO.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# Application object that the route decorators in this module attach to.
app = FastAPI(title="Simple RAG Service")
# Request/response schemas
class RAGRequest(BaseModel):
    """Request body accepted by the ``/rag`` endpoint."""

    # Address of the page whose content is loaded and indexed.
    url: HttpUrl
    # Question to answer from that page's content.
    query: str
# Model clients shared by every request.
# The API key is looked up in the environment once and handed to both clients.
_openai_api_key = os.getenv("OPENAI_API_KEY")

embeddings = OpenAIEmbeddings(openai_api_key=_openai_api_key)
llm = ChatOpenAI(
    model_name="gpt-3.5-turbo",
    temperature=0,
    openai_api_key=_openai_api_key,
)
def process_url_and_query(url: str, query: str) -> dict:
    """Load a web page, index its text, and answer ``query`` from it.

    The page is fetched with ``WebBaseLoader``, split into overlapping
    chunks, embedded into an in-memory FAISS index, and queried through a
    "stuff" ``RetrievalQA`` chain backed by the module-level ``llm``.

    Args:
        url: Address of the page to load.
        query: Question to answer from the page content.

    Returns:
        A dict with two keys: ``"answer"`` (the chain's ``"result"`` value)
        and ``"source_documents"`` (a list of
        ``{"content": ..., "metadata": ...}`` dicts, one per document the
        chain returned as a source).

    Raises:
        ValueError: If splitting the loaded page yields no chunks.
        Exception: Any error raised while loading, indexing or querying is
            logged with its traceback and re-raised unchanged.
    """
    try:
        # Load content
        documents = WebBaseLoader(url).load()

        # Split text into overlapping chunks for embedding
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            length_function=len,
        )
        chunks = text_splitter.split_documents(documents)
        if not chunks:
            # Fail with a clear message instead of letting index
            # construction fail on an empty document list.
            raise ValueError(f"No text content could be extracted from {url}")

        # Create FAISS index
        vectorstore = FAISS.from_documents(chunks, embeddings)

        # Create QA chain
        qa_chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=vectorstore.as_retriever(),
            return_source_documents=True,
        )

        # Get answer; ``invoke`` replaces the deprecated ``chain(...)`` call.
        result = qa_chain.invoke({"query": query})

        return {
            "answer": result["result"],
            "source_documents": [
                {
                    "content": doc.page_content,
                    "metadata": doc.metadata,
                }
                for doc in result["source_documents"]
            ],
        }
    except Exception as e:
        # ``exception`` records the traceback alongside the message.
        logger.exception("Error processing URL %s: %s", url, e)
        raise
@app.post("/rag")
async def get_rag_response(request: RAGRequest):
    """Answer ``request.query`` using the content of ``request.url``.

    Args:
        request: Validated request body carrying the page URL and the query.

    Returns:
        The dict produced by ``process_url_and_query``.

    Raises:
        HTTPException: With status 500 and the error text as ``detail`` when
            processing fails for any reason.
    """
    # Imported locally so this handler does not depend on a new file-level
    # import; fastapi is already a dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    try:
        # The pipeline is synchronous (network fetch, embedding and LLM
        # calls); running it in a worker thread keeps the event loop free
        # to serve other requests meanwhile.
        return await run_in_threadpool(
            process_url_and_query,
            str(request.url),
            request.query,
        )
    except Exception as e:
        logger.error("Error in get_rag_response: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
# Liveness probe
@app.get("/health")
async def health_check():
    """Return a fixed payload reporting the service as healthy."""
    return dict(status="healthy")
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces.
    import uvicorn

    # PORT from the environment when set, otherwise 8000.
    port = int(os.getenv("PORT", 8000))
    uvicorn.run(app, host="0.0.0.0", port=port)