diff --git a/ai/.python-version b/ai/.python-version
new file mode 100644
index 000000000..b326afbc9
--- /dev/null
+++ b/ai/.python-version
@@ -0,0 +1 @@
+3.9.15
diff --git a/ai/README.md b/ai/README.md
index ea44a7326..2960e3898 100644
--- a/ai/README.md
+++ b/ai/README.md
@@ -13,6 +13,18 @@ The server is built with FastAPI. Start the server by running `uvicorn main:app`
 Swagger Documentation: /docs
 Chat endpoint: /chat
 
+The storage context is pulled from S3, so the `main.py` script needs to know where to find it and how to authenticate.
+
+- Auth:
+  IRSA should work; otherwise you'll need to set the standard AWS env vars:
+  - `AWS_ACCESS_KEY_ID`
+  - `AWS_SECRET_ACCESS_KEY`
+- Path:
+  The script expects the S3 path in `PLURAL_AI_INDEX_S3_PATH` in the format `<bucket>/<path>`.
+  Defaults to `plural-assets/dagster/plural-ai/vector_store_index`.
+
+To be safe, `AWS_DEFAULT_REGION` should be set to the region of the bucket.
+
 ## Running scraper.py
 
 The scraper currently incorporates three datasources:
diff --git a/ai/main.py b/ai/main.py
index 05043562e..2183d15e8 100644
--- a/ai/main.py
+++ b/ai/main.py
@@ -1,15 +1,34 @@
 import os
 import openai
+import asyncio
 from fastapi import FastAPI, HTTPException
 from llama_index import StorageContext, load_index_from_storage, ServiceContext, set_global_service_context
 from llama_index.indices.postprocessor import SentenceEmbeddingOptimizer
 from llama_index.embeddings import OpenAIEmbedding
-
+from s3fs import S3FileSystem
 from pydantic import BaseModel
 
+def load_query_engine(s3_path: str):
+    storage_context = StorageContext.from_defaults(
+        # persist_dir format: "<bucket>/<path>"
+        persist_dir=s3_path,
+        fs=S3FileSystem()
+    )
+    index = load_index_from_storage(storage_context)
+    return index.as_query_engine(
+        node_postprocessors=[SentenceEmbeddingOptimizer(percentile_cutoff=0.5)],
+        response_mode="compact",
+        similarity_cutoff=0.7
+    )
+
 openai.api_key = os.environ["OPENAI_API_KEY"]
+PLURAL_AI_INDEX_S3_PATH = os.getenv("PLURAL_AI_INDEX_S3_PATH", "plural-assets/dagster/plural-ai/vector_store_index")
 
 app = FastAPI()
+embed_model = OpenAIEmbedding(embed_batch_size=10)
+service_context = ServiceContext.from_defaults(embed_model=embed_model)
+set_global_service_context(service_context)
+query_engine = load_query_engine(PLURAL_AI_INDEX_S3_PATH)
 
 class QueryRequest(BaseModel):
     question: str
@@ -17,22 +36,20 @@ class QueryRequest(BaseModel):
 class QueryResponse(BaseModel):
     answer: str
 
+async def reload_query_engine():
+    global query_engine
+    while True:
+        await asyncio.sleep(86400)  # reload the index from S3 once a day (86400 seconds)
+        query_engine = load_query_engine(PLURAL_AI_INDEX_S3_PATH)
 
-embed_model = OpenAIEmbedding(embed_batch_size=10)
-service_context = ServiceContext.from_defaults(embed_model=embed_model)
-set_global_service_context(service_context)
-
-storage_context = StorageContext.from_defaults(persist_dir="./storage")
-index = load_index_from_storage(storage_context)
-query_engine = index.as_query_engine(
-    node_postprocessors=[SentenceEmbeddingOptimizer(percentile_cutoff=0.5)],
-    response_mode="compact",
-    similarity_cutoff=0.7
-)
+@app.on_event("startup")
+async def schedule_reload_query_engine():
+    loop = asyncio.get_event_loop()
+    loop.create_task(reload_query_engine())
 
 @app.get("/")
 def read_root():
-    return {"Hello": "World"}
+    return {"Plural": "AI"}
 
 @app.post("/chat")
 def query_data(request: QueryRequest):
diff --git a/ai/requirements.txt b/ai/requirements.txt
index e8ce47a2b..7d810c54d 100644
--- a/ai/requirements.txt
+++ b/ai/requirements.txt
@@ -56,4 +56,5 @@ yarl==1.9.2
 python-graphql-client
 nltk
 config
-html2text
\ No newline at end of file
+html2text
+s3fs
\ No newline at end of file
diff --git a/ai/scraper.py b/ai/scraper.py
index 8861d39bf..403c40643 100755
--- a/ai/scraper.py
+++ b/ai/scraper.py
@@ -117,4 +117,4 @@ def scrape_discord():
 
 index = VectorStoreIndex.from_documents(list(chain))
 index.storage_context.persist()
-print("persisted new vector index")
\ No newline at end of file
+print("persisted new vector index")
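
A quick way to verify the auth and path setup described in the README before starting the server (a minimal sketch, not part of this diff; it assumes `s3fs` is installed and uses the same env vars `main.py` reads):

```python
# sanity_check.py -- hypothetical helper, not part of this diff.
# S3FileSystem() picks up IRSA credentials automatically, or falls back to
# AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY (plus AWS_DEFAULT_REGION).
import os
from s3fs import S3FileSystem

s3_path = os.getenv("PLURAL_AI_INDEX_S3_PATH", "plural-assets/dagster/plural-ai/vector_store_index")
fs = S3FileSystem()
# Should list the persisted index files (e.g. docstore.json, vector_store.json).
print(fs.ls(s3_path))
```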
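For reference, exercising the `/chat` endpoint once the server is up (a sketch; the host/port and the `requests` dependency are assumptions, only the `question`/`answer` fields come from the pydantic models in `main.py`):

```python
# chat_client.py -- hypothetical usage example, not part of this diff.
import requests

resp = requests.post(
    "http://localhost:8000/chat",  # assumed default uvicorn host/port
    json={"question": "How do I bring up a Plural cluster?"},  # QueryRequest
)
resp.raise_for_status()
print(resp.json()["answer"])  # QueryResponse
```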