✨ feat(api): combine retrieverqa and chat api
hcd233 committed Apr 19, 2024
1 parent 084234e commit 922df2a
Showing 1 changed file with 52 additions and 194 deletions.
246 changes: 52 additions & 194 deletions src/api/router/v1/session.py
@@ -1,33 +1,21 @@
 from datetime import datetime
 from json import dumps, loads
-from threading import Thread
-from typing import Any, Callable, Dict, Tuple
+from typing import Any, AsyncGenerator, Callable, Dict, Tuple
 
 from fastapi import APIRouter, Depends
 from fastapi.responses import StreamingResponse
-from langchain_core.chat_history import BaseChatMessageHistory
-from langchain_core.messages import AIMessage, HumanMessage
-from langchain_core.vectorstores import VectorStoreRetriever
-from langchain_openai import ChatOpenAI, OpenAIEmbeddings
 from sqlalchemy import or_
 
-from langchain.chains.base import Chain
-from src.langchain.callback import StreamCallbackHandler, TokenGenerator
+from src.langchain.callback import OUTPUT_PARSER_NAME
 from src.langchain.chain import init_chat_chain, init_retriever_qa_chain
-from src.langchain.embedding import init_embedding
-from src.langchain.llm import init_llm
-from src.langchain.memory import init_chat_memory, init_history
-from src.langchain.prompt import init_chat_prompt, init_retriever_prompt
-from src.langchain.retriever import init_retriever
 from src.logger import logger
 from src.middleware.mysql import session
-from src.middleware.mysql.models import (LLMSchema, MessageSchema,
-                                         SessionSchema, VectorDbSchema)
+from src.middleware.mysql.models import LLMSchema, MessageSchema, SessionSchema, VectorDbSchema
 from src.middleware.mysql.models.embeddings import EmbeddingSchema
 from src.middleware.redis import r
 
 from ...auth import sk_auth
-from ...model.request import ChatRequest, RetrieverQARequest
+from ...model.request import ChatRequest
 from ...model.response import SSEResponse, StandardResponse
 
 session_router = APIRouter(prefix="/session", tags=["session"])
@@ -235,7 +223,7 @@ async def delete_session(session_id: int, uid: int = None, info: Tuple[int, int]


 @session_router.post("/{session_id}/chat", dependencies=[Depends(sk_auth)])
-async def chat(session_id: int, request: ChatRequest, info: Tuple[int, int] = Depends(sk_auth)) -> StandardResponse | SSEResponse:
+async def retriever_qa(session_id: int, request: ChatRequest, info: Tuple[int, int] = Depends(sk_auth)) -> StandardResponse | SSEResponse:
     _uid, _ = info
 
     redis_lock = f"chat_lock:uid:{_uid}"
@@ -282,190 +270,60 @@ async def chat(session_id: int, request: ChatRequest, info: Tuple[int, int] = De
             conn.commit()
             logger.debug(f"Bind LLM: {request.llm_name} to Session: {session_id}")
 
-    try:
-        token_generator = TokenGenerator(redis_lock=redis_lock)
-        callback = StreamCallbackHandler(token_generator)
-
-        llm: ChatOpenAI = init_llm(
-            llm_type=_llm.llm_type,
-            llm_name=_llm.llm_name,
-            base_url=_llm.base_url,
-            api_key=_llm.api_key,
-            temperature=request.temperature,
-            max_tokens=_llm.max_tokens,
-            callbacks=[callback],
-        )
-        history = init_history(session_id=session_id)
-
-        memory = init_chat_memory(
-            history=history,
-            request_type=_llm.request_type,
-            user_name=_llm.user_name,
-            ai_name=_llm.ai_name,
-            k=8,
-        )
-        prompt = init_chat_prompt(
-            sys_prompt=_llm.sys_prompt,
-            request_type=_llm.request_type,
-            sys_name=_llm.sys_name,
-            user_name=_llm.user_name,
-            ai_name=_llm.ai_name,
-        )
-
-        chain = init_chat_chain(
-            llm=llm,
-            prompt=prompt,
-            memory=memory,
-            callbacks=[callback],
-        )
-    except Exception as e:
-        logger.error(f"Init langchain modules failed: {e}")
-        return StandardResponse(code=1, status="error", message="Chat init failed")
-
-    r.delete(f"session:{session_id}")
-    r.delete(f"uid:{_uid}:sessions")
-
-    Thread(target=chain.invoke, args=(request.message,)).start()
-    return StreamingResponse(token_generator, media_type="text/event-stream")
-
-
-@session_router.post("/{session_id}/retriever-qa", dependencies=[Depends(sk_auth)])
-async def retriever_qa(session_id: int, request: RetrieverQARequest, info: Tuple[int, int] = Depends(sk_auth)) -> StandardResponse | SSEResponse:
-    _uid, _ = info
-
-    redis_lock = f"chat_lock:uid:{_uid}"
-    if r.exists(redis_lock):
-        return StandardResponse(code=1, status="error", message="You are chatting, please wait a moment")
-    r.set(redis_lock, "lock", ex=30)
-
-    with session() as conn:
-        if not conn.is_active:
-            conn.rollback()
-            conn.close()
-        else:
-            conn.commit()
-
-        query = (
-            conn.query(SessionSchema.session_id, LLMSchema.llm_name)
-            .filter(SessionSchema.session_id == session_id)
-            .filter(SessionSchema.uid == _uid)
-            .join(LLMSchema, isouter=True)
-            .filter(or_(SessionSchema.delete_at.is_(None), datetime.now() < SessionSchema.delete_at))
-        )
-
-        result = query.first()
-        if not result:
-            r.delete(redis_lock)
-            return StandardResponse(code=1, status="error", message="Session not exist")
-
-        _, llm_name = result
-        if llm_name:
-            request.llm_name = llm_name
-            logger.debug(f"Use bind LLM: {llm_name}")
-        query = (
-            conn.query(LLMSchema)
-            .filter(LLMSchema.llm_name == request.llm_name)
-            .filter(or_(LLMSchema.delete_at.is_(None), datetime.now() < LLMSchema.delete_at))
-        )
-        _llm: LLMSchema | None = query.first()
-        if not _llm:
-            r.delete(redis_lock)
-            return StandardResponse(code=1, status="error", message="LLM not exist")
-
-        if not llm_name:
-            conn.query(SessionSchema).filter(SessionSchema.session_id == session_id).update({SessionSchema.llm_id: _llm.llm_id})
-            conn.commit()
-            logger.debug(f"Bind LLM: {request.llm_name} to Session: {session_id}")
-
-        query = (
-            conn.query(VectorDbSchema.embedding_id, VectorDbSchema.db_size)
-            .filter(VectorDbSchema.vector_db_id == request.vector_db_id)
-            .filter(or_(VectorDbSchema.delete_at.is_(None), datetime.now() < VectorDbSchema.delete_at))
-        )
-        result = query.first()
-        if not result:
-            return StandardResponse(code=1, status="error", message="Vector DB not exist")
-
-        (embedding_id, db_size) = result
-
-        if db_size == 0:
-            return StandardResponse(code=1, status="error", message="Vector DB is empty, please upload data first")
-
-        query = (
-            conn.query(EmbeddingSchema)
-            .filter(EmbeddingSchema.embedding_id == embedding_id)
-            .filter(or_(EmbeddingSchema.delete_at.is_(None), datetime.now() < EmbeddingSchema.delete_at))
-        )
-        _embedding: EmbeddingSchema | None = query.first()
-        if not _embedding:
-            return StandardResponse(code=1, status="error", message="Embedding not exist")
-
-    try:
-        token_generator = TokenGenerator(redis_lock=redis_lock)
-        callback = StreamCallbackHandler(token_generator)
-
-        llm: ChatOpenAI = init_llm(
-            llm_type=_llm.llm_type,
-            llm_name=_llm.llm_name,
-            base_url=_llm.base_url,
-            api_key=_llm.api_key,
-            temperature=request.temperature,
-            max_tokens=_llm.max_tokens,
-            callbacks=[callback],
-        )
-        if not llm:
-            return StandardResponse(code=1, status="error", message="LLM init failed")
-
-        history = init_history(session_id=session_id)
-
-        prompt = init_retriever_prompt(
-            sys_prompt=_llm.sys_prompt,
-            request_type=_llm.request_type,
-            sys_name=_llm.sys_name,
-            user_name=_llm.user_name,
-            ai_name=_llm.ai_name,
-        )
-
-        embedding: OpenAIEmbeddings = init_embedding(
-            _embedding.embedding_type,
-            embedding_name=_embedding.embedding_name,
-            api_key=_embedding.api_key,
-            base_url=_embedding.base_url,
-            chunk_size=_embedding.chunk_size,
-        )
+    if request.vector_db_id:
+        with session() as conn:
+            query = (
+                conn.query(VectorDbSchema.embedding_id, VectorDbSchema.db_size)
+                .filter(VectorDbSchema.vector_db_id == request.vector_db_id)
+                .filter(or_(VectorDbSchema.delete_at.is_(None), datetime.now() < VectorDbSchema.delete_at))
+            )
+            result = query.first()
+            if not result:
+                return StandardResponse(code=1, status="error", message="Vector DB not exist")
 
-        if not embedding:
-            return StandardResponse(code=1, status="error", message="Embedding init failed")
+            (embedding_id, db_size) = result
 
-        retriever: VectorStoreRetriever = init_retriever(
-            vector_db_id=request.vector_db_id,
-            embedding=embedding,
-        )
+            if db_size == 0:
+                return StandardResponse(code=1, status="error", message="Vector DB is empty, please upload data first")
 
-        chain = init_retriever_qa_chain(
-            llm=llm,
-            prompt=prompt,
-            retriever=retriever,
-            callbacks=[callback],
-        )
+            query = (
+                conn.query(EmbeddingSchema)
+                .filter(EmbeddingSchema.embedding_id == embedding_id)
+                .filter(or_(EmbeddingSchema.delete_at.is_(None), datetime.now() < EmbeddingSchema.delete_at))
+            )
 
-    except Exception as e:
-        logger.error(f"Init langchain modules failed: {e}")
-        return StandardResponse(code=1, status="error", message="Chat init failed")
+            _embedding: EmbeddingSchema | None = query.first()
+            if not _embedding:
+                return StandardResponse(code=1, status="error", message="Embedding not exist")
 
+        chain_func = init_retriever_qa_chain
+        chain_kwargs = {
+            "llm_schema": _llm,
+            "embedding_schema": _embedding,
+            "temperature": request.temperature,
+            "session_id": session_id,
+            "vector_db_id": request.vector_db_id,
+        }
+    else:
+        chain_func = init_chat_chain
+        chain_kwargs = {
+            "llm_schema": _llm,
+            "temperature": request.temperature,
+            "session_id": session_id,
+        }
+    try:
+        chain = chain_func(**chain_kwargs)
+    except Exception as e:
+        logger.exception(f"Init langchain modules failed: {e}")
+        return StandardResponse(code=1, status="error", message="Chat init failed")
 
     r.delete(f"session:{session_id}")
     r.delete(f"uid:{_uid}:sessions")
 
-    def _chat_after_save_history(chain: Chain, user_prompt: str, history: BaseChatMessageHistory):
-        output = chain.invoke(user_prompt)
-
-        docs = "```\n" + "\n```\n---\n```\n".join([doc.page_content for doc in output["context"]]) + "\n```"
-
-        user_input = f"context: {docs}\n---\nquestion: {user_prompt}"
-        llm_output = output["answer"]
-
-        history.add_message(HumanMessage(content=user_input))
-        history.add_message(AIMessage(content=llm_output))
+    async def _filter_event_stream() -> AsyncGenerator[str, None]:
+        async for event in chain.astream_events({"user_prompt": request.message}, version="v1", include_names=[OUTPUT_PARSER_NAME]):
+            if event["event"] not in ["on_parser_stream"]:
+                continue
+            yield f"data: {dumps(event, ensure_ascii=False)}\n\n"
+        r.delete(redis_lock)
 
-    Thread(target=_chat_after_save_history, args=(chain, request.message, history)).start()
-    return StreamingResponse(token_generator, media_type="text/event-stream")
+    return StreamingResponse(_filter_event_stream(), media_type="text/event-stream")
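The merged endpoint replaces the old callback-handler-plus-background-thread streaming with `astream_events`, filtered down to events from the named output parser. A minimal standalone sketch of that filtering pattern, assuming a simple prompt | llm | parser chain and an `OUTPUT_PARSER_NAME` value of `"output_parser"` (the real chains and constant live in `src.langchain` and are not shown in this diff):

```python
# Sketch of the astream_events filtering used above (assumed setup; the real
# chains come from init_chat_chain / init_retriever_qa_chain).
import asyncio

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI

OUTPUT_PARSER_NAME = "output_parser"  # assumed value of the constant in src.langchain.callback

# Name the parser so include_names can single it out in the event stream.
chain = (
    ChatPromptTemplate.from_template("{user_prompt}")
    | ChatOpenAI(model="gpt-3.5-turbo", streaming=True)
    | StrOutputParser().with_config({"run_name": OUTPUT_PARSER_NAME})
)


async def main() -> None:
    # include_names drops events from every other runnable in the chain;
    # "on_parser_stream" events carry the incremental text chunks.
    async for event in chain.astream_events({"user_prompt": "hello"}, version="v1", include_names=[OUTPUT_PARSER_NAME]):
        if event["event"] == "on_parser_stream":
            print(event["data"]["chunk"], end="", flush=True)


asyncio.run(main())
```

Note that `r.delete(redis_lock)` now runs only once the generator is exhausted, so if a client disconnects mid-stream the lock is released by the 30-second expiry set when the handler acquired it.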
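From the client's perspective the combined endpoint is plain SSE: each `data:` line wraps one `on_parser_stream` event as JSON. A hypothetical consumer using httpx (the URL prefix, auth header, and exact `ChatRequest` body are assumptions; only `message`, `temperature`, `llm_name`, and `vector_db_id` appear in this diff):

```python
# Hypothetical client for the combined endpoint; omitting vector_db_id falls
# back to the plain chat chain, per the dispatch logic above.
import asyncio
from json import loads

import httpx


async def chat(session_id: int, message: str, vector_db_id: int | None = None) -> None:
    payload = {"message": message, "temperature": 0.7, "vector_db_id": vector_db_id}
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "POST",
            f"http://localhost:8000/v1/session/{session_id}/chat",  # assumed host and prefix
            json=payload,
            headers={"Authorization": "Bearer sk-..."},  # sk_auth scheme assumed
        ) as resp:
            async for line in resp.aiter_lines():
                if not line.startswith("data: "):
                    continue  # skip the blank separator lines between events
                event = loads(line[len("data: "):])
                # For on_parser_stream events the chunk holds the streamed text.
                print(event["data"]["chunk"], end="", flush=True)


asyncio.run(chat(1, "Hello!"))
```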
