diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000000..cac78559d6
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,25 @@
+FROM python:3.10-slim-buster
+
+ENV PYTHONDONTWRITEBYTECODE 1
+ENV PYTHONUNBUFFERED 1
+ENV CARGO_HOME=/root/.cargo
+ENV PATH=$CARGO_HOME/bin:$PATH
+
+RUN apt update && apt install -y \
+    curl \
+    build-essential \
+    && curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y \
+    && cargo --version \
+    && rustc --version
+
+WORKDIR /app
+
+COPY requirements.txt .
+
+RUN pip install --no-cache-dir -r requirements.txt
+
+COPY . .
+
+EXPOSE 20213
+
+CMD ["uvicorn", "webserver.main:app", "--host", "0.0.0.0", "--port", "20213"]
diff --git a/LICENSE b/LICENSE
index 9e841e7a26..435e5e2180 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,6 +1,7 @@
 MIT License
 
 Copyright (c) Microsoft Corporation.
+Copyright (c) KylinMountain.
 
 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal
diff --git a/README.md b/README.md
index 35f2eb2a1e..eff0cc6e3f 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,110 @@
+# GraphRAG customized by KylinMountain
+- I have added a web server that supports true streaming output.
+- I have fixed an error that occurs when using a local embedding service such as LM Studio.
+- I have fixed an index error after prompt tuning.
+- I have fixed the strategy not being loaded when entity extraction is configured to use NLTK.
+- I have added an advice-question API.
+- I have added reference links for the entities, reports and relationships cited in the output, so you can open them directly.
+- Support any desktop or web application compatible with the OpenAI SDK.
+- Support Docker deployment; the image is available as kylinmountain/graphrag-server:0.3.1.
+
+# GraphRAG 定制版
+- 我添加了Web服务器,以支持真即时流式输出。
+- 我修复了使用本地嵌入服务(如LM Studio)时的错误。
+- 我修复了提示调整后索引错误的问题。
+- 我修复了在使用NLTK设置实体提取时策略未加载的问题。
+- 我添加了建议问题API。
+- 我添加了实体或者关系等链接到输出中,你可以直接点击访问参考实体、关系、数据源或者报告。
+- 支持任意兼容OpenAI大模型桌面应用或者Web应用UI接入。
+- 增加Docker构建,最新版本0.3.1,kylinmountain/graphrag-server:0.3.1
+
+![image](https://github.com/user-attachments/assets/c251d434-4925-4012-88e7-f3b2ff40471f)
+
+![image](https://github.com/user-attachments/assets/ab7a8d2e-aeec-4a0c-afb9-97086b9c7b2a)
+
+# 如何安装 How to install
+你可以使用Docker安装,也可以拉取本项目使用。You can install via Docker or pull this repo and run it from source.
+## 拉取源码 Pull the source code
+- 克隆本项目 Clone the repo
+```
+git clone https://github.com/KylinMountain/graphrag.git
+cd graphrag
+```
+- 建立虚拟环境 Create a virtual env
+```
+conda create -n graphrag python=3.10
+conda activate graphrag
+```
+- 安装poetry Install poetry
+```
+curl -sSL https://install.python-poetry.org | python3 -
+```
+- 安装依赖 Install dependencies
+```
+poetry install
+pip install -r webserver/requirements.txt
+```
+或者 or
+```
+pip install -r requirements.txt
+```
+- 初始化GraphRAG Initialize GraphRAG
+```
+poetry run poe index --init --root .
+# 或者 or
+python -m graphrag.index --init --root .
+```
+- 创建input文件夹 Create the input folder
+```
+mkdir input
+```
+- 配置settings.yaml Config settings.yaml
+按照GraphRAG官方配置文档配置。Configure it following the official [GraphRAG Configuration](https://microsoft.github.io/graphrag/posts/config/json_yaml/) docs.
+- 配置webserver Config the webserver
+
+你可能需要配置以下设置,但默认即可支持本地运行。You may need to configure the following items, but the defaults are enough for running locally.
+```python
+    server_host: str = "http://localhost"
+    server_port: int = 20213
+    data: str = (
+        "./output"
+    )
+    lancedb_uri: str = (
+        "./lancedb"
+    )
+```
+- 启动web server Start the web server
+```bash
+python -m webserver.main
+```
+更多的参考配置,可以访问[公众号文章](https://mp.weixin.qq.com/mp/appmsgalbum?__biz=MzI0OTAzNTEwMw==&action=getalbum&album_id=3429606151455670272&uin=&key=&devicetype=iMac+MacBookPro17%2C1+OSX+OSX+14.4+build(23E214)&version=13080710&lang=zh_CN&nettype=WIFI&ascene=0&fontScale=100)和[B站视频](https://www.bilibili.com/video/BV113v8e6EZn)。For more configuration references, see the linked WeChat article and Bilibili video.
+
+## 使用Docker安装 Install via Docker
+- 拉取镜像 Pull the Docker image
+```
+docker pull kylinmountain/graphrag-server:0.3.1
+```
+- 启动 Start
+在启动前,你可以创建output、input、prompts等目录和settings.yaml等文件。Before starting, you can create the output, input and prompts directories and a settings.yaml file.
+```
+docker run -v ./output:/app/output \
+    -v ./input:/app/input \
+    -v ./prompts:/app/prompts \
+    -v ./settings.yaml:/app/settings.yaml \
+    -v ./lancedb:/app/lancedb -p 20213:20213 kylinmountain/graphrag-server:0.3.1
+```
+- 索引 Index
+```
+docker run kylinmountain/graphrag-server:0.3.1 python -m graphrag.index --root .
+```
+
+-------
 # GraphRAG
 
 👉 [Use the GraphRAG Accelerator solution](https://github.com/Azure-Samples/graphrag-accelerator)
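Since the web server exposes an OpenAI-compatible API (`/v1/models` and `/v1/chat/completions`), any OpenAI SDK client can talk to it directly. The sketch below is a minimal, illustrative example and not part of the patch: it assumes the server is running locally on the default port 20213, and it reads the model id (local, global or DRIFT search) from `/v1/models` instead of hard-coding it, because the exact identifiers live in `webserver/utils/consts.py`, which is not shown in this diff.

```python
# Minimal sketch: query the GraphRAG web server through its OpenAI-compatible API.
# Assumes the server runs on the default port 20213; the model ids are discovered
# via /v1/models rather than hard-coded.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:20213/v1", api_key="not-needed")

# Pick one of the exposed search engines ("models"), e.g. local, global or drift.
model_id = client.models.list().data[0].id

# Stream the answer; the server appends reference links in its final chunk.
stream = client.chat.completions.create(
    model=model_id,
    messages=[{"role": "user", "content": "What are the main topics in my documents?"}],
    stream=True,
)
for chunk in stream:
    delta = chunk.choices[0].delta.content
    if delta:
        print(delta, end="", flush=True)
```

The `[^Data:...]` citations in the answer resolve against the `/v1/references/{index_id}/{datatype}/{id}` endpoint added in `webserver/main.py`, so any chat UI can render them as clickable links.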
diff --git a/graphrag/query/structured_search/drift_search/action.py b/graphrag/query/structured_search/drift_search/action.py index 6c6405174c..3684105a49 100644 --- a/graphrag/query/structured_search/drift_search/action.py +++ b/graphrag/query/structured_search/drift_search/action.py @@ -7,6 +7,7 @@ import logging from typing import Any +from graphrag.llm.openai.utils import try_parse_json_object from graphrag.query.llm.text_utils import num_tokens log = logging.getLogger(__name__) @@ -71,7 +72,7 @@ async def asearch(self, search_engine: Any, global_query: str, scorer: Any = Non ) try: - response = json.loads(search_result.response) + _, response = try_parse_json_object(search_result.response) except json.JSONDecodeError: error_message = "Failed to parse search response" log.exception("%s: %s", error_message, search_result.response) @@ -198,7 +199,7 @@ def from_primer_response( # If response is a string, attempt to parse as JSON if isinstance(response, str): try: - parsed_response = json.loads(response) + _, parsed_response = try_parse_json_object(response) if isinstance(parsed_response, dict): return cls.from_primer_response(query, parsed_response) error_message = "Parsed response must be a dictionary." diff --git a/graphrag/query/structured_search/drift_search/primer.py b/graphrag/query/structured_search/drift_search/primer.py index 1a3d7b27df..064070afc0 100644 --- a/graphrag/query/structured_search/drift_search/primer.py +++ b/graphrag/query/structured_search/drift_search/primer.py @@ -14,6 +14,7 @@ from tqdm.asyncio import tqdm_asyncio from graphrag.config.models.drift_config import DRIFTSearchConfig +from graphrag.llm.openai.utils import try_parse_json_object from graphrag.model import CommunityReport from graphrag.query.llm.base import BaseTextEmbedding from graphrag.query.llm.oai.chat_openai import ChatOpenAI @@ -139,7 +140,7 @@ async def decompose_query( messages, response_format={"type": "json_object"} ) - parsed_response = json.loads(response) + _, parsed_response = try_parse_json_object(response) token_ct = num_tokens(prompt + response, self.token_encoder) return parsed_response, token_ct diff --git a/graphrag/query/structured_search/global_search/map_system_prompt.py b/graphrag/query/structured_search/global_search/map_system_prompt.py index db1a649df3..39a2718405 100644 --- a/graphrag/query/structured_search/global_search/map_system_prompt.py +++ b/graphrag/query/structured_search/global_search/map_system_prompt.py @@ -23,22 +23,21 @@ The response should be JSON formatted as follows: {{ "points": [ - {{"description": "Description of point 1 [Data: Reports (report ids)]", "score": score_value}}, - {{"description": "Description of point 2 [Data: Reports (report ids)]", "score": score_value}} + {{"description": "Description of point 1 [^Data:Reports(report id)][^Data:Reports(report id)]", "score": score_value}}, + {{"description": "Description of point 2 [^Data:Reports(report id)][^Data:Reports(report id)]", "score": score_value}} ] }} The response shall preserve the original meaning and use of modal verbs such as "shall", "may" or "will". Points supported by data should list the relevant reports as references as follows: -"This is an example sentence supported by data references [Data: Reports (report ids)]" +"This is an example sentence supported by data references [^Data:Reports(report id)][^Data:Reports(report id)]" **Do not list more than 5 record ids in a single reference**. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more. 
For example: -"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (2, 7, 64, 46, 34, +more)]. He is also CEO of company X [Data: Reports (1, 3)]" - -where 1, 2, 3, 7, 34, 46, and 64 represent the id (not the index) of the relevant data report in the provided tables. +"Person X is the owner of Company Y and subject to many allegations of wrongdoing [^Data:Reports(2)] [^Data:Reports(7)] [^Data:Reports(34)] [^Data:Reports(46)] [^Data:Reports(64,+more)]. He is also CEO of company X [^Data:Reports(1)] [^Data:Reports(3)]" +where 1, 2, 3, 7, 34, 46, and 64 represent the id (not the index) of the relevant data record. Do not include information where the supporting evidence for it is not provided. @@ -80,3 +79,4 @@ ] }} """ + diff --git a/graphrag/query/structured_search/global_search/reduce_system_prompt.py b/graphrag/query/structured_search/global_search/reduce_system_prompt.py index 701717817c..6b3a591be7 100644 --- a/graphrag/query/structured_search/global_search/reduce_system_prompt.py +++ b/graphrag/query/structured_search/global_search/reduce_system_prompt.py @@ -25,11 +25,11 @@ The response should also preserve all the data references previously included in the analysts' reports, but do not mention the roles of multiple analysts in the analysis process. -**Do not list more than 5 record ids in a single reference**. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more. +**References should be listed with a single record ID per citation**, with each citation containing only one record ID. For example, [^Data:Relationships(38)] [^Data:Relationships(55)], instead of [^Data:Relationships(38, 55)]. For example: -"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (2, 7, 34, 46, 64, +more)]. He is also CEO of company X [Data: Reports (1, 3)]" +"Person X is the owner of Company Y and subject to many allegations of wrongdoing [^Data:Reports(2)] [^Data:Reports(7)] [^Data:Reports(34)] [^Data:Reports(46)] [^Data:Reports(64,+more)]. He is also CEO of company X [^Data:Reports(1)] [^Data:Reports(3)]" where 1, 2, 3, 7, 34, 46, and 64 represent the id (not the index) of the relevant data record. @@ -60,11 +60,11 @@ The response should also preserve all the data references previously included in the analysts' reports, but do not mention the roles of multiple analysts in the analysis process. -**Do not list more than 5 record ids in a single reference**. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more. +**References should be listed with a single record ID per citation**, with each citation containing only one record ID. For example, [^Data:Relationships(38)] [^Data:Relationships(55)], instead of [^Data:Relationships(38, 55)]. For example: -"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (2, 7, 34, 46, 64, +more)]. He is also CEO of company X [Data: Reports (1, 3)]" +"Person X is the owner of Company Y and subject to many allegations of wrongdoing [^Data:Reports(2)] [^Data:Reports(7)] [^Data:Reports(34)] [^Data:Reports(46)] [^Data:Reports(64,+more)]. He is also CEO of company X [^Data:Reports(1)] [^Data:Reports(3)]" where 1, 2, 3, 7, 34, 46, and 64 represent the id (not the index) of the relevant data record. 
diff --git a/graphrag/query/structured_search/local_search/system_prompt.py b/graphrag/query/structured_search/local_search/system_prompt.py
index 70b1d12fc3..73253dba53 100644
--- a/graphrag/query/structured_search/local_search/system_prompt.py
+++ b/graphrag/query/structured_search/local_search/system_prompt.py
@@ -17,13 +17,15 @@
 
 Points supported by data should list their data references as follows:
 
-"This is an example sentence supported by multiple data references [Data: (record ids); (record ids)]."
+"This is an example sentence supported by multiple data references [^Data:(record id)] [^Data:(record id)]."
 
-Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more.
+The data type should be one of Entities, Relationships, Claims, Sources, Reports.
+
+**References should be listed with a single record ID per citation**, with each citation containing only one record ID. For example, [^Data:Relationships(38)] [^Data:Relationships(55)], instead of [^Data:Relationships(38, 55)].
 
 For example:
 
-"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Sources (15, 16), Reports (1), Entities (5, 7); Relationships (23); Claims (2, 7, 34, 46, 64, +more)]."
+"Person X is the owner of Company Y and subject to many allegations of wrongdoing [^Data:Sources(15)] [^Data:Sources(16)] [^Data:Reports(1)] [^Data:Entities(5)] [^Data:Entities(7)] [^Data:Relationships(23)] [^Data:Claims(2)] [^Data:Claims(7)] [^Data:Claims(34)] [^Data:Claims(46)] [^Data:Claims(64,+more)]."
 
 where 15, 16, 1, 5, 7, 23, 2, 7, 34, 46, and 64 represent the id (not the index) of the relevant data record.
 
@@ -39,31 +41,4 @@
 
 {context_data}
 
-
----Goal---
-
-Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge.
-
-If you don't know the answer, just say so. Do not make anything up.
-
-Points supported by data should list their data references as follows:
-
-"This is an example sentence supported by multiple data references [Data: (record ids); (record ids)]."
-
-Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more.
-
-For example:
-
-"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Sources (15, 16), Reports (1), Entities (5, 7); Relationships (23); Claims (2, 7, 34, 46, 64, +more)]."
-
-where 15, 16, 1, 5, 7, 23, 2, 7, 34, 46, and 64 represent the id (not the index) of the relevant data record.
-
-Do not include information where the supporting evidence for it is not provided.
-
-
----Target response length and format---
-
-{response_type}
-
-Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown.
""" diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000..a115cfbdfd --- /dev/null +++ b/requirements.txt @@ -0,0 +1,118 @@ +aiofiles==24.1.0 ; python_version >= "3.10" and python_version < "3.13" +aiolimiter==1.1.0 ; python_version >= "3.10" and python_version < "3.13" +annotated-types==0.7.0 ; python_version >= "3.10" and python_version < "3.13" +anyio==4.6.2.post1 ; python_version >= "3.10" and python_version < "3.13" +anytree==2.12.1 ; python_version >= "3.10" and python_version < "3.13" +asttokens==2.4.1 ; python_version >= "3.10" and python_version < "3.13" +attrs==24.2.0 ; python_version >= "3.10" and python_version < "3.13" +autograd==1.7.0 ; python_version >= "3.10" and python_version < "3.13" +azure-common==1.1.28 ; python_version >= "3.10" and python_version < "3.13" +azure-core==1.32.0 ; python_version >= "3.10" and python_version < "3.13" +azure-identity==1.19.0 ; python_version >= "3.10" and python_version < "3.13" +azure-search-documents==11.5.2 ; python_version >= "3.10" and python_version < "3.13" +azure-storage-blob==12.23.1 ; python_version >= "3.10" and python_version < "3.13" +beartype==0.18.5 ; python_version >= "3.10" and python_version < "3.13" +cachetools==5.5.0 ; python_version >= "3.10" and python_version < "3.13" +certifi==2024.8.30 ; python_version >= "3.10" and python_version < "3.13" +cffi==1.17.1 ; python_version >= "3.10" and python_version < "3.13" and platform_python_implementation != "PyPy" +charset-normalizer==3.4.0 ; python_version >= "3.10" and python_version < "3.13" +click==8.1.7 ; python_version >= "3.10" and python_version < "3.13" +colorama==0.4.6 ; python_version >= "3.10" and python_version < "3.13" and platform_system == "Windows" +contourpy==1.3.0 ; python_version >= "3.10" and python_version < "3.13" +cryptography==43.0.3 ; python_version >= "3.10" and python_version < "3.13" +cycler==0.12.1 ; python_version >= "3.10" and python_version < "3.13" +datashaper==0.0.49 ; python_version >= "3.10" and python_version < "3.13" +decorator==5.1.1 ; python_version >= "3.10" and python_version < "3.13" +deprecation==2.1.0 ; python_version >= "3.10" and python_version < "3.13" +devtools==0.12.2 ; python_version >= "3.10" and python_version < "3.13" +diskcache==5.6.3 ; python_version >= "3.10" and python_version < "3.13" +distro==1.9.0 ; python_version >= "3.10" and python_version < "3.13" +environs==11.0.0 ; python_version >= "3.10" and python_version < "3.13" +exceptiongroup==1.2.2 ; python_version >= "3.10" and python_version < "3.11" +executing==2.1.0 ; python_version >= "3.10" and python_version < "3.13" +fonttools==4.54.1 ; python_version >= "3.10" and python_version < "3.13" +future==1.0.0 ; python_version >= "3.10" and python_version < "3.13" +gensim==4.3.3 ; python_version >= "3.10" and python_version < "3.13" +graspologic-native==1.2.1 ; python_version >= "3.10" and python_version < "3.13" +graspologic==3.4.1 ; python_version >= "3.10" and python_version < "3.13" +h11==0.14.0 ; python_version >= "3.10" and python_version < "3.13" +httpcore==1.0.6 ; python_version >= "3.10" and python_version < "3.13" +httpx==0.27.2 ; python_version >= "3.10" and python_version < "3.13" +hyppo==0.4.0 ; python_version >= "3.10" and python_version < "3.13" +idna==3.10 ; python_version >= "3.10" and python_version < "3.13" +isodate==0.7.2 ; python_version >= "3.10" and python_version < "3.13" +jiter==0.7.0 ; python_version >= "3.10" and python_version < "3.13" +joblib==1.4.2 ; python_version >= "3.10" and python_version < "3.13" 
+json-repair==0.30.1 ; python_version >= "3.10" and python_version < "3.13" +jsonschema-specifications==2024.10.1 ; python_version >= "3.10" and python_version < "3.13" +jsonschema==4.23.0 ; python_version >= "3.10" and python_version < "3.13" +kiwisolver==1.4.7 ; python_version >= "3.10" and python_version < "3.13" +lancedb==0.13.0 ; python_version >= "3.10" and python_version < "3.13" +llvmlite==0.43.0 ; python_version >= "3.10" and python_version < "3.13" +markdown-it-py==3.0.0 ; python_version >= "3.10" and python_version < "3.13" +marshmallow==3.23.1 ; python_version >= "3.10" and python_version < "3.13" +matplotlib==3.9.2 ; python_version >= "3.10" and python_version < "3.13" +mdurl==0.1.2 ; python_version >= "3.10" and python_version < "3.13" +msal-extensions==1.2.0 ; python_version >= "3.10" and python_version < "3.13" +msal==1.31.0 ; python_version >= "3.10" and python_version < "3.13" +networkx==3.4.2 ; python_version >= "3.10" and python_version < "3.13" +nltk==3.9.1 ; python_version >= "3.10" and python_version < "3.13" +numba==0.60.0 ; python_version >= "3.10" and python_version < "3.13" +numpy==1.26.4 ; python_version >= "3.10" and python_version < "3.13" +openai==1.54.0 ; python_version >= "3.10" and python_version < "3.13" +overrides==7.7.0 ; python_version >= "3.10" and python_version < "3.13" +packaging==24.1 ; python_version >= "3.10" and python_version < "3.13" +pandas==2.2.3 ; python_version >= "3.10" and python_version < "3.13" +patsy==0.5.6 ; python_version >= "3.10" and python_version < "3.13" +pillow==11.0.0 ; python_version >= "3.10" and python_version < "3.13" +portalocker==2.10.1 ; python_version >= "3.10" and python_version < "3.13" +pot==0.9.4 ; python_version >= "3.10" and python_version < "3.13" +py==1.11.0 ; python_version >= "3.10" and python_version < "3.13" +pyaml-env==1.2.1 ; python_version >= "3.10" and python_version < "3.13" +pyarrow==15.0.2 ; python_version >= "3.10" and python_version < "3.13" +pycparser==2.22 ; python_version >= "3.10" and python_version < "3.13" and platform_python_implementation != "PyPy" +pydantic-core==2.23.4 ; python_version >= "3.10" and python_version < "3.13" +pydantic==2.9.2 ; python_version >= "3.10" and python_version < "3.13" +pygments==2.18.0 ; python_version >= "3.10" and python_version < "3.13" +pyjwt[crypto]==2.9.0 ; python_version >= "3.10" and python_version < "3.13" +pylance==0.17.0 ; python_version >= "3.10" and python_version < "3.13" +pynndescent==0.5.13 ; python_version >= "3.10" and python_version < "3.13" +pyparsing==3.2.0 ; python_version >= "3.10" and python_version < "3.13" +python-dateutil==2.9.0.post0 ; python_version >= "3.10" and python_version < "3.13" +python-dotenv==1.0.1 ; python_version >= "3.10" and python_version < "3.13" +pytz==2024.2 ; python_version >= "3.10" and python_version < "3.13" +pywin32==308 ; python_version >= "3.10" and python_version < "3.13" and platform_system == "Windows" +pyyaml==6.0.2 ; python_version >= "3.10" and python_version < "3.13" +referencing==0.35.1 ; python_version >= "3.10" and python_version < "3.13" +regex==2024.9.11 ; python_version >= "3.10" and python_version < "3.13" +requests==2.32.3 ; python_version >= "3.10" and python_version < "3.13" +retry==0.9.2 ; python_version >= "3.10" and python_version < "3.13" +rich==13.9.4 ; python_version >= "3.10" and python_version < "3.13" +rpds-py==0.20.1 ; python_version >= "3.10" and python_version < "3.13" +scikit-learn==1.5.2 ; python_version >= "3.10" and python_version < "3.13" +scipy==1.12.0 ; python_version >= 
"3.10" and python_version < "3.13" +seaborn==0.13.2 ; python_version >= "3.10" and python_version < "3.13" +shellingham==1.5.4 ; python_version >= "3.10" and python_version < "3.13" +six==1.16.0 ; python_version >= "3.10" and python_version < "3.13" +smart-open==7.0.5 ; python_version >= "3.10" and python_version < "3.13" +sniffio==1.3.1 ; python_version >= "3.10" and python_version < "3.13" +statsmodels==0.14.4 ; python_version >= "3.10" and python_version < "3.13" +tenacity==9.0.0 ; python_version >= "3.10" and python_version < "3.13" +threadpoolctl==3.5.0 ; python_version >= "3.10" and python_version < "3.13" +tiktoken==0.7.0 ; python_version >= "3.10" and python_version < "3.13" +tqdm==4.66.6 ; python_version >= "3.10" and python_version < "3.13" +typer==0.12.5 ; python_version >= "3.10" and python_version < "3.13" +typing-extensions==4.12.2 ; python_version >= "3.10" and python_version < "3.13" +tzdata==2024.2 ; python_version >= "3.10" and python_version < "3.13" +umap-learn==0.5.7 ; python_version >= "3.10" and python_version < "3.13" +urllib3==2.2.3 ; python_version >= "3.10" and python_version < "3.13" +wrapt==1.16.0 ; python_version >= "3.10" and python_version < "3.13" + +#graphrag web server requirements +pydantic_settings==2.3.4 +uvicorn~=0.30.4 +fastapi~=0.103.0 +jinja2~=3.1.4 +aiocache~=0.12.2 +milvus-model~=0.2.5 +python-multipart~=0.0.9 \ No newline at end of file diff --git a/scripts/neo4jvisualization.py b/scripts/neo4jvisualization.py new file mode 100644 index 0000000000..40fbfd5120 --- /dev/null +++ b/scripts/neo4jvisualization.py @@ -0,0 +1,167 @@ +import os +import time + +import pandas as pd +from neo4j import GraphDatabase + +NEO4J_URI = "neo4j://localhost" # or neo4j+s://xxxx.databases.neo4j.io +NEO4J_USERNAME = "neo4j" +NEO4J_PASSWORD = "password" +NEO4J_DATABASE = "neo4j" + +# Create a Neo4j driver +driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) + +GRAPHRAG_FOLDER = os.path.join("..", "output", "20240802-112937", "artifacts") + + +def batched_import(statement, df, batch_size=1000): + """ + Import a dataframe into Neo4j using a batched approach. + Parameters: statement is the Cypher query to execute, df is the dataframe to import, and batch_size is the number of rows to import in each batch. 
+ """ + total = len(df) + start_s = time.time() + for start in range(0, total, batch_size): + batch = df.iloc[start: min(start + batch_size, total)] + result = driver.execute_query("UNWIND $rows AS value " + statement, + rows=batch.to_dict('records'), + database_=NEO4J_DATABASE) + print(result.summary.counters) + print(f'{total} rows in {time.time() - start_s} s.') + return total + + +# create constraints, idempotent operation + +statements = """ +create constraint chunk_id if not exists for (c:__Chunk__) require c.id is unique; +create constraint document_id if not exists for (d:__Document__) require d.id is unique; +create constraint entity_id if not exists for (c:__Community__) require c.community is unique; +create constraint entity_id if not exists for (e:__Entity__) require e.id is unique; +create constraint entity_title if not exists for (e:__Entity__) require e.name is unique; +create constraint entity_title if not exists for (e:__Covariate__) require e.title is unique; +create constraint related_id if not exists for ()-[rel:RELATED]->() require rel.id is unique; +""".split(";") + +for statement in statements: + if len((statement or "").strip()) > 0: + print(statement) + driver.execute_query(statement) + +doc_df = pd.read_parquet(os.path.join(GRAPHRAG_FOLDER, "create_final_documents.parquet"), columns=["id", "title"]) +doc_df.head(2) + +# import documents +statement = """ +MERGE (d:__Document__ {id:value.id}) +SET d += value {.title} +""" + +batched_import(statement, doc_df) + +text_df = pd.read_parquet(os.path.join(GRAPHRAG_FOLDER, "create_final_text_units.parquet"), + columns=["id", "text", "n_tokens", "document_ids"]) +text_df.head(2) + +statement = """ +MERGE (c:__Chunk__ {id:value.id}) +SET c += value {.text, .n_tokens} +WITH c, value +UNWIND value.document_ids AS document +MATCH (d:__Document__ {id:document}) +MERGE (c)-[:PART_OF]->(d) +""" + +batched_import(statement, text_df) + +entity_df = pd.read_parquet(os.path.join(GRAPHRAG_FOLDER, "create_final_entities.parquet"), + columns=["name", "type", "description", "human_readable_id", "id", "description_embedding", + "text_unit_ids"]) +entity_df.head(2) + +entity_statement = """ +MERGE (e:__Entity__ {id:value.id}) +SET e += value {.human_readable_id, .description, name:replace(value.name,'"','')} +WITH e, value +CALL db.create.setNodeVectorProperty(e, "description_embedding", value.description_embedding) +CALL apoc.create.addLabels(e, case when coalesce(value.type,"") = "" then [] else [apoc.text.upperCamelCase(replace(value.type,'"',''))] end) yield node +UNWIND value.text_unit_ids AS text_unit +MATCH (c:__Chunk__ {id:text_unit}) +MERGE (c)-[:HAS_ENTITY]->(e) +""" + +batched_import(entity_statement, entity_df) + +rel_df = pd.read_parquet(os.path.join(GRAPHRAG_FOLDER, "create_final_relationships.parquet"), + columns=["source", "target", "id", "rank", "weight", "human_readable_id", "description", + "text_unit_ids"]) +rel_df.head(2) + +rel_statement = """ + MATCH (source:__Entity__ {name:replace(value.source,'"','')}) + MATCH (target:__Entity__ {name:replace(value.target,'"','')}) + // not necessary to merge on id as there is only one relationship per pair + MERGE (source)-[rel:RELATED {id: value.id}]->(target) + SET rel += value {.rank, .weight, .human_readable_id, .description, .text_unit_ids} + RETURN count(*) as createdRels +""" + +batched_import(rel_statement, rel_df) + +community_df = pd.read_parquet(os.path.join(GRAPHRAG_FOLDER, "create_final_communities.parquet"), + columns=["id", "level", "title", "text_unit_ids", 
"relationship_ids"]) + +community_df.head(2) + +statement = """ +MERGE (c:__Community__ {community:value.id}) +SET c += value {.level, .title} +/* +UNWIND value.text_unit_ids as text_unit_id +MATCH (t:__Chunk__ {id:text_unit_id}) +MERGE (c)-[:HAS_CHUNK]->(t) +WITH distinct c, value +*/ +WITH * +UNWIND value.relationship_ids as rel_id +MATCH (start:__Entity__)-[:RELATED {id:rel_id}]->(end:__Entity__) +MERGE (start)-[:IN_COMMUNITY]->(c) +MERGE (end)-[:IN_COMMUNITY]->(c) +RETURn count(distinct c) as createdCommunities +""" + +batched_import(statement, community_df) + +community_report_df = pd.read_parquet(os.path.join(GRAPHRAG_FOLDER, "create_final_community_reports.parquet"), + columns=["id", "community", "level", "title", "summary", "findings", "rank", + "rank_explanation", "full_content"]) +community_report_df.head(2) + +# community_df['findings'][0] + +# import communities +community_statement = """MATCH (c:__Community__ {community: value.community}) +SET c += value {.level, .title, .rank, .rank_explanation, .full_content, .summary} +WITH c, value +UNWIND range(0, size(value.findings)-1) AS finding_idx +WITH c, value, finding_idx, value.findings[finding_idx] as finding +MERGE (c)-[:HAS_FINDING]->(f:Finding {id: finding_idx}) +SET f += finding""" +batched_import(community_statement, community_report_df) + +# cov_df = pd.read_parquet(f'{GRAPHRAG_FOLDER}/create_final_covariates.parquet'), +# # columns=["id","text_unit_id"]) +# cov_df.head(2) +# # Subject id do not match entity ids +# +# +# # import covariates +# cov_statement = """ +# MERGE (c:__Covariate__ {id:value.id}) +# SET c += apoc.map.clean(value, ["text_unit_id", "document_ids", "n_tokens"], [NULL, ""]) +# WITH c, value +# MATCH (ch:__Chunk__ {id: value.text_unit_id}) +# MERGE (ch)-[:HAS_COVARIATE]->(c) +# """ +# batched_import(cov_statement, cov_df) diff --git a/webserver/__init__.py b/webserver/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/webserver/configs/__init__.py b/webserver/configs/__init__.py new file mode 100644 index 0000000000..84a6cc56d6 --- /dev/null +++ b/webserver/configs/__init__.py @@ -0,0 +1,3 @@ +from .settings import settings + +__all__ = ["settings"] diff --git a/webserver/configs/settings.py b/webserver/configs/settings.py new file mode 100644 index 0000000000..6bbc19d18d --- /dev/null +++ b/webserver/configs/settings.py @@ -0,0 +1,74 @@ +import os + +import yaml +from azure.identity import get_bearer_token_provider, DefaultAzureCredential +from dotenv import load_dotenv +from pydantic_settings import BaseSettings + +from graphrag.config import LLMParameters, TextEmbeddingConfig, LocalSearchConfig, GlobalSearchConfig, LLMType +from graphrag.query.llm.oai import OpenaiApiType + + +class Settings(BaseSettings): + server_port: int = 20213 + website_address: str = f"http://127.0.0.1:{server_port}" + cors_allowed_origins: list = ["*"] # Edit the list to restrict access. 
+ data: str = ( + "./output" + ) + llm: LLMParameters + embeddings: TextEmbeddingConfig + global_search: GlobalSearchConfig + local_search: LocalSearchConfig + encoding_model: str = "cl100k_base" + + def is_azure_client(self): + return self.llm.type == LLMType.AzureOpenAIChat or settings.llm.type == LLMType.AzureOpenAI + + def get_api_type(self): + return OpenaiApiType.AzureOpenAI if self.is_azure_client() else OpenaiApiType.OpenAI + + def azure_ad_token_provider(self): + is_azure_client = ( + settings.llm.type == LLMType.AzureOpenAIChat + or settings.llm.type == LLMType.AzureOpenAI + ) + + audience = ( + settings.llm.audience + if settings.llm.audience + else "https://cognitiveservices.azure.com/.default" + ) + + if is_azure_client and not settings.llm.api_key: + return get_bearer_token_provider(DefaultAzureCredential(), audience) + else: + return None + + +def load_settings_from_yaml(file_path: str) -> Settings: + with open(file_path, 'r') as file: + config = yaml.safe_load(file) + llm_config = config['llm'] + embeddings_config = config['embeddings'] + global_search_config = config['global_search'] + local_search_config = config['local_search'] + encoding_model = config['encoding_model'] + + # Manually setting the API keys from environment variables if specified + load_dotenv() + llm_params = LLMParameters(**llm_config) + llm_params.api_key = os.environ.get("GRAPHRAG_API_KEY", llm_config['api_key']) + text_embedding = TextEmbeddingConfig(**embeddings_config) + text_embedding.llm.api_key = os.environ.get("GRAPHRAG_API_KEY", embeddings_config['llm']['api_key']) + + return Settings( + llm=llm_params, + embeddings=text_embedding, + global_search=GlobalSearchConfig(**global_search_config if global_search_config else {}), + local_search=LocalSearchConfig(**local_search_config if local_search_config else {}), + encoding_model=encoding_model + ) + + +settings = load_settings_from_yaml("settings.yaml") diff --git a/webserver/gtypes/__init__.py b/webserver/gtypes/__init__.py new file mode 100644 index 0000000000..01ec5cea18 --- /dev/null +++ b/webserver/gtypes/__init__.py @@ -0,0 +1,3 @@ +from .chat_request import CompletionCreateParamsBase as ChatCompletionRequest +from .chat_result import TypedFuture, QuestionGenResult +from .chat_request import ChatQuestionGen, Model, ModelList diff --git a/webserver/gtypes/chat_request.py b/webserver/gtypes/chat_request.py new file mode 100644 index 0000000000..56c33f429f --- /dev/null +++ b/webserver/gtypes/chat_request.py @@ -0,0 +1,70 @@ +from typing import Optional, Dict, List, Union, Literal, Any + +from pydantic import BaseModel + + +class ChatCompletionMessageParam(BaseModel): + content: str + role: str = "user" + + +class ResponseFormat(BaseModel): + type: str + + +class ChatCompletionStreamOptionsParam(BaseModel): + enable: bool + + +class ChatCompletionToolParam(BaseModel): + name: str + description: str + + +class CompletionCreateParamsBase(BaseModel): + messages: List[ChatCompletionMessageParam] + model: str + frequency_penalty: Optional[float] = None + logit_bias: Optional[Dict[str, int]] = None + logprobs: Optional[bool] = None + max_tokens: Optional[int] = None + n: Optional[int] = None + parallel_tool_calls: bool = False + presence_penalty: Optional[float] = None + response_format: ResponseFormat = ResponseFormat(type="text") + seed: Optional[int] = None + service_tier: Optional[Literal["auto", "default"]] = None + stop: Optional[Union[str, List[str]]] = None + stream: Optional[bool] = False + stream_options: Optional[Dict] = None + 
temperature: Optional[float] = 0.0 + tools: List[ChatCompletionToolParam] = None + top_logprobs: Optional[int] = None + top_p: Optional[float] = 1.0 + user: Optional[str] = None + + def llm_chat_params(self) -> dict[str, Any]: + return { + "temperature": self.temperature, + "seed": self.seed + } + + +class ChatQuestionGen(BaseModel): + messages: List[ChatCompletionMessageParam] + model: str + max_tokens: Optional[int] = None + temperature: Optional[float] = 0.0 + n: Optional[int] = None + + +class Model(BaseModel): + id: str + object: Literal["model"] + created: int + owned_by: str + + +class ModelList(BaseModel): + object: Literal["list"] = "list" + data: List[Model] diff --git a/webserver/gtypes/chat_result.py b/webserver/gtypes/chat_result.py new file mode 100644 index 0000000000..6bc0889d23 --- /dev/null +++ b/webserver/gtypes/chat_result.py @@ -0,0 +1,15 @@ +import asyncio +from typing import Any + +from pydantic import BaseModel + + +class TypedFuture(asyncio.Future): + pass + + +class QuestionGenResult(BaseModel): + questions: list[str] + completion_time: float + llm_calls: int + prompt_tokens: int diff --git a/webserver/main.py b/webserver/main.py new file mode 100644 index 0000000000..3524e01841 --- /dev/null +++ b/webserver/main.py @@ -0,0 +1,290 @@ +import logging +import os +import time +import uuid + +import tiktoken +from fastapi import FastAPI, HTTPException +from fastapi.encoders import jsonable_encoder +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse, HTMLResponse, StreamingResponse +from fastapi.staticfiles import StaticFiles +from jinja2 import Template +from openai.types import CompletionUsage +from openai.types.chat import ChatCompletion, ChatCompletionMessage, ChatCompletionChunk +from openai.types.chat.chat_completion_chunk import Choice, ChoiceDelta + +from graphrag.query.context_builder.conversation_history import ConversationHistory +from graphrag.query.llm.oai import ChatOpenAI, OpenAIEmbedding +from graphrag.query.question_gen.local_gen import LocalQuestionGen +from graphrag.query.structured_search.base import BaseSearch +from graphrag.query.structured_search.drift_search.search import DRIFTSearch +from graphrag.query.structured_search.global_search.search import GlobalSearch +from graphrag.query.structured_search.local_search.search import LocalSearch +from webserver import gtypes +from webserver import search +from webserver import utils +from webserver.configs import settings +from webserver.search.localsearch import build_drift_search_context +from webserver.utils import consts + +app = FastAPI() +app.add_middleware( + CORSMiddleware, + allow_origins=settings.cors_allowed_origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) +app.mount("/static", StaticFiles(directory="webserver/static"), name="static") + +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +llm = ChatOpenAI( + api_key=settings.llm.api_key, + api_base=settings.llm.api_base, + model=settings.llm.model, + api_type=settings.get_api_type(), + max_retries=settings.llm.max_retries, + azure_ad_token_provider=settings.azure_ad_token_provider(), + deployment_name=settings.llm.deployment_name, + api_version=settings.llm.api_version, + organization=settings.llm.organization, + request_timeout=settings.llm.request_timeout, +) + +text_embedder = OpenAIEmbedding( + api_key=settings.embeddings.llm.api_key, + 
api_base=settings.embeddings.llm.api_base, + api_type=settings.get_api_type(), + api_version=settings.embeddings.llm.api_version, + model=settings.embeddings.llm.model, + max_retries=settings.embeddings.llm.max_retries, + max_tokens=settings.embeddings.llm.max_tokens, + azure_ad_token_provider=settings.azure_ad_token_provider(), + deployment_name=settings.embeddings.llm.deployment_name, + organization=settings.embeddings.llm.organization, + encoding_name=settings.encoding_model, + request_timeout=settings.embeddings.llm.request_timeout, +) + +token_encoder = tiktoken.get_encoding("cl100k_base") + +local_search: LocalSearch +global_search: GlobalSearch +question_gen: LocalQuestionGen + + +@app.on_event("startup") +async def startup_event(): + global local_search + global global_search + global question_gen + local_search = await search.build_local_search_engine(llm, token_encoder=token_encoder) + global_search = await search.build_global_search_engine(llm, token_encoder=token_encoder) + question_gen = await search.build_local_question_gen(llm, token_encoder=token_encoder) + + +@app.get("/") +async def index(): + html_file_path = os.path.join("webserver", "templates", "index.html") + with open(html_file_path, "r", encoding="utf-8") as file: + html_content = file.read() + return HTMLResponse(content=html_content) + + +async def initialize_search(request: gtypes.ChatCompletionRequest, search: BaseSearch, index: str = None): + search.context_builder = await switch_context(index=index) + search.llm_params.update(request.llm_chat_params()) + return search + + +async def handle_sync_response(request, search, conversation_history, drift_search: bool = False): + result = await search.asearch(request.messages[-1].content, conversation_history=conversation_history) + if drift_search: + response = result.response + # context_data = _reformat_context_data(result.context_data) # type: ignore + response = response["nodes"][0]["answer"] + else: + response = result.response + + reference = utils.get_reference(response) + if reference: + response += f"\n{utils.generate_ref_links(reference, request.model)}" + from openai.types.chat.chat_completion import Choice + completion = ChatCompletion( + id=f"chatcmpl-{uuid.uuid4().hex}", + created=int(time.time()), + model=request.model, + object="chat.completion", + choices=[ + Choice( + index=0, + finish_reason="stop", + message=ChatCompletionMessage( + role="assistant", + content=response + ) + ) + ], + usage=CompletionUsage( + completion_tokens=-1, + prompt_tokens=result.prompt_tokens, + total_tokens=-1 + ) + ) + return JSONResponse(content=jsonable_encoder(completion)) + + +async def handle_stream_response(request, search, conversation_history): + async def wrapper_astream_search(): + token_index = 0 + chat_id = f"chatcmpl-{uuid.uuid4().hex}" + full_response = "" + async for token in search.astream_search(request.messages[-1].content, conversation_history): # 调用原始的生成器 + if token_index == 0: + token_index += 1 + continue + + chunk = ChatCompletionChunk( + id=chat_id, + created=int(time.time()), + model=request.model, + object="chat.completion.chunk", + choices=[ + Choice( + index=token_index - 1, + finish_reason=None, + delta=ChoiceDelta( + role="assistant", + content=token + ) + ) + ] + ) + yield f"data: {chunk.json()}\n\n" + token_index += 1 + full_response += token + + content = "" + reference = utils.get_reference(full_response) + if reference: + content = f"\n{utils.generate_ref_links(reference, request.model)}" + finish_reason = 'stop' + chunk = 
ChatCompletionChunk(
+            id=chat_id,
+            created=int(time.time()),
+            model=request.model,
+            object="chat.completion.chunk",
+            choices=[
+                Choice(
+                    index=token_index,
+                    finish_reason=finish_reason,
+                    delta=ChoiceDelta(
+                        role="assistant",
+                        # content=result.context_data["entities"].head().to_string()
+                        content=content
+                    )
+                ),
+            ],
+        )
+        yield f"data: {chunk.json()}\n\n"
+        yield "data: [DONE]\n\n"
+
+    return StreamingResponse(wrapper_astream_search(), media_type="text/event-stream")
+
+
+@app.post("/v1/chat/completions")
+async def chat_completions(request: gtypes.ChatCompletionRequest):
+    if not local_search or not global_search:
+        logger.error("GraphRAG search engines are not initialized")
+        raise HTTPException(status_code=500, detail="GraphRAG search engines are not initialized")
+
+    try:
+        history = request.messages[:-1]
+        conversation_history = ConversationHistory.from_list([message.dict() for message in history])
+
+        if request.model == consts.INDEX_GLOBAL:
+            search_engine = await initialize_search(request, global_search, request.model)
+        elif request.model == consts.INDEX_LOCAL:
+            search_engine = await initialize_search(request, local_search, request.model)
+        elif request.model == consts.INDEX_DRIFT:
+            context_builder = await build_drift_search_context(llm, settings.data, text_embedder)
+            drift_search = await search.build_drift_search_engine(llm, context_builder=context_builder, token_encoder=token_encoder)
+            search_engine = await initialize_search(request, drift_search, request.model)
+            # DRIFT search doesn't support streaming search yet, so always answer synchronously.
+            return await handle_sync_response(request, search_engine, conversation_history, drift_search=True)
+        else:
+            raise NotImplementedError(f"model {request.model} is not supported")
+
+        if not request.stream:
+            return await handle_sync_response(request, search_engine, conversation_history)
+        else:
+            return await handle_stream_response(request, search_engine, conversation_history)
+    except Exception as e:
+        logger.error(msg=f"chat_completions error: {e}", exc_info=True)
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+@app.post("/v1/advice_questions", response_model=gtypes.QuestionGenResult)
+async def get_advice_question(request: gtypes.ChatQuestionGen):
+    if request.model == consts.INDEX_LOCAL:
+        local_context = await switch_context(index=request.model)
+        question_gen.context_builder = local_context
+    else:
+        raise NotImplementedError(f"model {request.model} is not supported")
+    question_history = [message.content for message in request.messages if message.role == "user"]
+    candidate_questions = await question_gen.agenerate(
+        question_history=question_history, context_data=None, question_count=5
+    )
+    # each generated question looks like "- What about xxx?", so strip the leading dash
+ questions: list[str] = [question.removeprefix("-").strip() for question in candidate_questions.response] + resp = gtypes.QuestionGenResult(questions=questions, + completion_time=candidate_questions.completion_time, + llm_calls=candidate_questions.llm_calls, + prompt_tokens=candidate_questions.prompt_tokens) + return resp + + +@app.get("/v1/models", response_model=gtypes.ModelList) +async def list_models(): + models: list[gtypes.Model] = [ + gtypes.Model(id=consts.INDEX_LOCAL, object="model", created=1644752340, owned_by="graphrag"), + gtypes.Model(id=consts.INDEX_GLOBAL, object="model", created=1644752340, owned_by="graphrag"), + gtypes.Model(id=consts.INDEX_DRIFT, object="model", created=1644752340, owned_by="graphrag")] + return gtypes.ModelList(data=models) + + +@app.get("/v1/references/{index_id}/{datatype}/{id}", response_class=HTMLResponse) +async def get_reference(index_id: str, datatype: str, id: int): + if not os.path.exists(settings.data): + raise HTTPException(status_code=404, detail=f"{index_id} not found") + if datatype not in ["entities", "claims", "sources", "reports", "relationships"]: + raise HTTPException(status_code=404, detail=f"{datatype} not found") + + data = await search.get_index_data(settings.data, datatype, id) + html_file_path = os.path.join("webserver", "templates", f"{datatype}_template.html") + with open(html_file_path, 'r') as file: + html_content = file.read() + template = Template(html_content) + html_content = template.render(data=data) + return HTMLResponse(content=html_content) + + +async def switch_context(index: str): + if index == consts.INDEX_GLOBAL: + context_builder = await search.load_global_context(settings.data, token_encoder) + elif index == consts.INDEX_LOCAL: + context_builder = await search.load_local_context(settings.data, text_embedder, token_encoder) + elif index == consts.INDEX_DRIFT: + context_builder = await build_drift_search_context(llm, settings.data, text_embedder) + else: + raise NotImplementedError(f"{index} is not supported") + return context_builder + + +if __name__ == "__main__": + import uvicorn + + uvicorn.run(app, port=settings.server_port) diff --git a/webserver/search/__init__.py b/webserver/search/__init__.py new file mode 100644 index 0000000000..1a749bd987 --- /dev/null +++ b/webserver/search/__init__.py @@ -0,0 +1,3 @@ +from .localsearch import build_local_question_gen, build_local_search_engine, load_local_context, build_drift_search_engine +from .globalsearch import build_global_search_engine, load_global_context +from .indexdata import get_index_data diff --git a/webserver/search/globalsearch.py b/webserver/search/globalsearch.py new file mode 100644 index 0000000000..3c10f8f6b6 --- /dev/null +++ b/webserver/search/globalsearch.py @@ -0,0 +1,80 @@ +import pandas as pd +import tiktoken + +from graphrag.callbacks.global_search_callbacks import GlobalSearchLLMCallback +from graphrag.query.context_builder.builders import GlobalContextBuilder +from graphrag.query.indexer_adapters import read_indexer_entities, read_indexer_reports +from graphrag.query.llm.base import BaseLLM +from graphrag.query.structured_search.global_search.community_context import ( + GlobalCommunityContext, +) +from graphrag.query.structured_search.global_search.search import GlobalSearch + +from webserver.configs import settings +from webserver.utils import consts + + +async def load_global_context(input_dir: str, + token_encoder: tiktoken.Encoding | None = None) -> GlobalContextBuilder: + final_nodes = 
pd.read_parquet(f"{input_dir}/{consts.ENTITY_TABLE}.parquet") + final_community_reports = pd.read_parquet(f"{input_dir}/{consts.COMMUNITY_REPORT_TABLE}.parquet") + final_entities = pd.read_parquet(f"{input_dir}/{consts.ENTITY_EMBEDDING_TABLE}.parquet") + + reports = read_indexer_reports(final_community_reports, final_nodes, consts.COMMUNITY_LEVEL) + entities = read_indexer_entities(final_nodes, final_entities, consts.COMMUNITY_LEVEL) + + context_builder = GlobalCommunityContext( + community_reports=reports, + entities=entities, # default to None if you don't want to use community weights for ranking + token_encoder=token_encoder, + ) + return context_builder + + +async def build_global_search_engine(llm: BaseLLM, context_builder=None, callback: GlobalSearchLLMCallback = None, + token_encoder: tiktoken.Encoding | None = None) -> GlobalSearch: + context_builder_params = { + "use_community_summary": False, + "shuffle_data": True, + "include_community_rank": True, + "min_community_rank": 0, + "community_rank_name": "rank", + "include_community_weight": True, + "community_weight_name": "occurrence weight", + "normalize_community_weight": True, + "max_tokens": settings.global_search.max_tokens, + "context_name": "Reports", + } + + map_llm_params = { + "max_tokens": settings.global_search.map_max_tokens, + "temperature": settings.global_search.temperature, + "top_p": settings.global_search.top_p, + "n": settings.global_search.n, + "response_format": {"type": "json_object"}, + } + + reduce_llm_params = { + "max_tokens": settings.global_search.reduce_max_tokens, + "temperature": settings.global_search.temperature, + "top_p": settings.global_search.top_p, + "n": settings.global_search.n, + } + + search_engine = GlobalSearch( + llm=llm, + context_builder=context_builder, + token_encoder=token_encoder, + max_data_tokens=settings.global_search.data_max_tokens, + map_llm_params=map_llm_params, + reduce_llm_params=reduce_llm_params, + allow_general_knowledge=False, + json_mode=settings.llm.model_supports_json, # set this to False if your LLM model does not support JSON mode. + context_builder_params=context_builder_params, + concurrent_coroutines=settings.global_search.concurrency, + callbacks=[callback] if callback else None, + # free form text describing the response type and format, can be anything, + # e.g. 
prioritized list, single paragraph, multiple paragraphs, multiple-page report + response_type="multiple paragraphs", + ) + return search_engine diff --git a/webserver/search/indexdata.py b/webserver/search/indexdata.py new file mode 100644 index 0000000000..79cafe9429 --- /dev/null +++ b/webserver/search/indexdata.py @@ -0,0 +1,81 @@ +import os +from typing import Optional + +import pandas as pd + +from graphrag.model import Relationship, Covariate, Entity, CommunityReport, TextUnit +from graphrag.query.indexer_adapters import read_indexer_relationships, read_indexer_covariates, read_indexer_entities, \ + read_indexer_reports, read_indexer_text_units +from ..utils import consts + + +async def get_index_data(input_dir: str, datatype: str, id: Optional[int] = None): + if datatype == "entities": + return await get_entity(input_dir, id) + elif datatype == "claims": + return await get_claim(input_dir, id) + elif datatype == "sources": + return await get_source(input_dir, id) + elif datatype == "reports": + return await get_report(input_dir, id) + elif datatype == "relationships": + return await get_relationship(input_dir, id) + else: + raise ValueError(f"Unknown datatype: {datatype}") + + +async def get_entity(input_dir: str, row_id: Optional[int] = None) -> Entity: + entity_df = pd.read_parquet(f"{input_dir}/{consts.ENTITY_TABLE}.parquet") + entity_embedding_df = pd.read_parquet(f"{input_dir}/{consts.ENTITY_EMBEDDING_TABLE}.parquet") + + entities = read_indexer_entities(entity_df, entity_embedding_df, consts.COMMUNITY_LEVEL) + # TODO optimize performance using like database or dict in memory + for entity in entities: + if int(entity.short_id) == row_id: + return entity + raise ValueError(f"Not Found entity id {row_id}") + + +async def get_claim(input_dir: str, row_id: Optional[int] = None) -> Covariate: + covariate_file = f"{input_dir}/{consts.COVARIATE_TABLE}.parquet" + if os.path.exists(covariate_file): + covariate_df = pd.read_parquet(covariate_file) + claims = read_indexer_covariates(covariate_df) + else: + raise ValueError(f"No claims {input_dir} of id {row_id} found") + # TODO optimize performance using like database or dict in memory + for claim in claims: + if int(claim.short_id) == row_id: + return claim + raise ValueError(f"Not Found claim id {row_id}") + + +async def get_source(input_dir: str, row_id: Optional[int] = None) -> TextUnit: + text_unit_df = pd.read_parquet(f"{input_dir}/{consts.TEXT_UNIT_TABLE}.parquet") + text_units = read_indexer_text_units(text_unit_df) + # TODO optimize performance using like database or dict in memory + for text_unit in text_units: + if int(text_unit.short_id) == row_id: + return text_unit + raise ValueError(f"Not Found source id {row_id}") + + +async def get_report(input_dir: str, row_id: Optional[int] = None) -> CommunityReport: + entity_df = pd.read_parquet(f"{input_dir}/{consts.ENTITY_TABLE}.parquet") + report_df = pd.read_parquet(f"{input_dir}/{consts.COMMUNITY_REPORT_TABLE}.parquet") + reports = read_indexer_reports(report_df, entity_df, consts.COMMUNITY_LEVEL) + # TODO optimize performance using like database or dict in memory + for report in reports: + if int(report.short_id) == row_id: + return report + raise ValueError(f"Not Found report id {row_id}") + + +async def get_relationship(input_dir: str, row_id: Optional[int] = None) -> Relationship: + relationship_df = pd.read_parquet(f"{input_dir}/{consts.RELATIONSHIP_TABLE}.parquet") + relationships = read_indexer_relationships(relationship_df) + # TODO optimize performance using like 
database or dict in memory
+    for relationship in relationships:
+        if int(relationship.short_id) == row_id:
+            return relationship
+    raise ValueError(f"Not Found relationship id {row_id}")
diff --git a/webserver/search/localsearch.py b/webserver/search/localsearch.py
new file mode 100644
index 0000000000..2812e5ef71
--- /dev/null
+++ b/webserver/search/localsearch.py
@@ -0,0 +1,225 @@
+import logging
+import os
+from pathlib import Path
+
+import pandas as pd
+import tiktoken
+
+from graphrag.query.context_builder.builders import LocalContextBuilder
+from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey
+from graphrag.query.indexer_adapters import (
+    read_indexer_covariates,
+    read_indexer_entities,
+    read_indexer_relationships,
+    read_indexer_reports,
+    read_indexer_text_units, read_indexer_report_embeddings,
+)
+from graphrag.query.llm.base import BaseLLM, BaseTextEmbedding
+from graphrag.query.question_gen.local_gen import LocalQuestionGen
+from graphrag.query.structured_search.drift_search.drift_context import DRIFTSearchContextBuilder
+from graphrag.query.structured_search.drift_search.search import DRIFTSearch
+from graphrag.query.structured_search.local_search.mixed_context import (
+    LocalSearchMixedContext,
+)
+from graphrag.query.structured_search.local_search.search import LocalSearch
+from graphrag.vector_stores import VectorStoreType, VectorStoreFactory, BaseVectorStore
+from webserver.configs import settings
+from webserver.utils import consts
+
+logger = logging.getLogger(__name__)
+
+
+async def load_local_context(input_dir: str, embedder: BaseTextEmbedding,
+                             token_encoder: tiktoken.Encoding | None = None) -> LocalContextBuilder:
+    # read nodes table to get community and degree data
+    entity_df = pd.read_parquet(f"{input_dir}/{consts.ENTITY_TABLE}.parquet")
+    entity_embedding_df = pd.read_parquet(f"{input_dir}/{consts.ENTITY_EMBEDDING_TABLE}.parquet")
+
+    entities = read_indexer_entities(entity_df, entity_embedding_df, consts.COMMUNITY_LEVEL)
+
+    vector_store_type = settings.embeddings.vector_store.get("type", VectorStoreType.LanceDB)  # type: ignore
+    vector_store_args = settings.embeddings.vector_store
+    if vector_store_type == VectorStoreType.LanceDB:
+        db_uri = settings.embeddings.vector_store["db_uri"]  # type: ignore
+        lancedb_dir = Path(settings.data).parent.resolve() / db_uri
+        vector_store_args["db_uri"] = str(lancedb_dir)  # type: ignore
+
+    description_embedding_store = _get_embedding_store(
+        config_args=vector_store_args,  # type: ignore
+        container_suffix="entity-description",
+    )
+
+    relationship_df = pd.read_parquet(f"{input_dir}/{consts.RELATIONSHIP_TABLE}.parquet")
+    relationships = read_indexer_relationships(relationship_df)
+
+    covariate_file = f"{input_dir}/{consts.COVARIATE_TABLE}.parquet"
+    if os.path.exists(covariate_file):
+        covariate_df = pd.read_parquet(covariate_file)
+        claims = read_indexer_covariates(covariate_df)
+        covariates = {"claims": claims}
+    else:
+        covariates = None
+
+    report_df = pd.read_parquet(f"{input_dir}/{consts.COMMUNITY_REPORT_TABLE}.parquet")
+    reports = read_indexer_reports(report_df, entity_df, consts.COMMUNITY_LEVEL)
+
+    text_unit_df = pd.read_parquet(f"{input_dir}/{consts.TEXT_UNIT_TABLE}.parquet")
+    text_units = read_indexer_text_units(text_unit_df)
+
+    context_builder = LocalSearchMixedContext(
+        community_reports=reports,
+        text_units=text_units,
+        entities=entities,
+        relationships=relationships,
+        covariates=covariates,
+        entity_text_embeddings=description_embedding_store,
+        embedding_vectorstore_key=EntityVectorStoreKey.ID,
+        # if the vectorstore uses entity title as ids, set this to EntityVectorStoreKey.TITLE
+        text_embedder=embedder,
+        token_encoder=token_encoder,
+    )
+    return context_builder
+
+
+def _get_embedding_store(
+    config_args: dict,
+    container_suffix: str,
+) -> BaseVectorStore:
+    """Get the embedding description store."""
+    vector_store_type = config_args["type"]
+    collection_name = (
+        f"{config_args.get('container_name', 'default')}-{container_suffix}"
+    )
+    embedding_store = VectorStoreFactory.get_vector_store(
+        vector_store_type=vector_store_type,
+        kwargs={**config_args, "collection_name": collection_name},
+    )
+    embedding_store.connect(**config_args)
+    return embedding_store
+
+
+async def build_local_question_gen(llm: BaseLLM, context_builder: LocalContextBuilder = None,
+                                   token_encoder: tiktoken.Encoding | None = None) -> LocalQuestionGen:
+    local_context_params = {
+        "text_unit_prop": settings.local_search.text_unit_prop,
+        "community_prop": settings.local_search.community_prop,
+        "conversation_history_max_turns": settings.local_search.conversation_history_max_turns,
+        "conversation_history_user_turns_only": True,
+        "top_k_mapped_entities": settings.local_search.top_k_entities,
+        "top_k_relationships": settings.local_search.top_k_relationships,
+        "include_entity_rank": True,
+        "include_relationship_weight": True,
+        "include_community_rank": False,
+        "return_candidate_context": False,
+        "embedding_vectorstore_key": EntityVectorStoreKey.ID,
+        # set this to EntityVectorStoreKey.TITLE if the vectorstore uses entity title as ids
+        "max_tokens": settings.local_search.max_tokens,
+    }
+
+    llm_params = {
+        "max_tokens": settings.local_search.llm_max_tokens,
+        "temperature": settings.local_search.temperature,
+        "top_p": settings.local_search.top_p,
+        "n": settings.local_search.n,
+    }
+
+    question_generator = LocalQuestionGen(
+        llm=llm,
+        context_builder=context_builder,
+        token_encoder=token_encoder,
+        llm_params=llm_params,
+        context_builder_params=local_context_params,
+    )
+
+    return question_generator
+
+
+async def build_local_search_engine(llm: BaseLLM, context_builder: LocalContextBuilder = None,
+                                    token_encoder: tiktoken.Encoding | None = None) -> LocalSearch:
+    local_context_params = {
+        "text_unit_prop": settings.local_search.text_unit_prop,
+        "community_prop": settings.local_search.community_prop,
+        "conversation_history_max_turns": settings.local_search.conversation_history_max_turns,
+        "conversation_history_user_turns_only": True,
+        "top_k_mapped_entities": settings.local_search.top_k_entities,
+        "top_k_relationships": settings.local_search.top_k_relationships,
+        "include_entity_rank": True,
+        "include_relationship_weight": True,
+        "include_community_rank": False,
+        "return_candidate_context": False,
+        "embedding_vectorstore_key": EntityVectorStoreKey.ID,
+        # set this to EntityVectorStoreKey.TITLE if the vectorstore uses entity title as ids
+        "max_tokens": settings.local_search.max_tokens,
+    }
+
+    llm_params = {
+        "max_tokens": settings.local_search.llm_max_tokens,
+        "temperature": settings.local_search.temperature,
+        "top_p": settings.local_search.top_p,
+        "n": settings.local_search.n,
+    }
+
+    search_engine = LocalSearch(
+        llm=llm,
+        context_builder=context_builder,
+        token_encoder=token_encoder,
+        llm_params=llm_params,
+        context_builder_params=local_context_params,
+        response_type="multiple paragraphs",
+    )
+    return search_engine
+
+
+async def build_drift_search_context(llm: BaseLLM, input_dir: str,
+                                     embedder: BaseTextEmbedding) -> DRIFTSearchContextBuilder:
+    # read nodes table to get community and degree data
+    entity_df = pd.read_parquet(f"{input_dir}/{consts.ENTITY_TABLE}.parquet")
+    entity_embedding_df = pd.read_parquet(f"{input_dir}/{consts.ENTITY_EMBEDDING_TABLE}.parquet")
+    entities = read_indexer_entities(entity_df, entity_embedding_df, consts.COMMUNITY_LEVEL)
+
+    relationship_df = pd.read_parquet(f"{input_dir}/{consts.RELATIONSHIP_TABLE}.parquet")
+    relationships = read_indexer_relationships(relationship_df)
+
+    report_df = pd.read_parquet(f"{input_dir}/{consts.COMMUNITY_REPORT_TABLE}.parquet")
+    reports = read_indexer_reports(report_df, entity_df, consts.COMMUNITY_LEVEL)
+
+    text_unit_df = pd.read_parquet(f"{input_dir}/{consts.TEXT_UNIT_TABLE}.parquet")
+    text_units = read_indexer_text_units(text_unit_df)
+
+    # vector store
+    vector_store_type = settings.embeddings.vector_store.get("type", VectorStoreType.LanceDB)  # type: ignore
+    vector_store_args = settings.embeddings.vector_store
+    if vector_store_type == VectorStoreType.LanceDB:
+        db_uri = settings.embeddings.vector_store["db_uri"]  # type: ignore
+        lancedb_dir = Path(settings.data).parent.resolve() / db_uri
+        vector_store_args["db_uri"] = str(lancedb_dir)  # type: ignore
+
+    description_embedding_store = _get_embedding_store(
+        config_args=vector_store_args,  # type: ignore
+        container_suffix="entity-description",
+    )
+
+    full_content_embedding_store = _get_embedding_store(
+        config_args=vector_store_args,  # type: ignore
+        container_suffix="community-full_content",
+    )
+    read_indexer_report_embeddings(reports, full_content_embedding_store)
+
+    return DRIFTSearchContextBuilder(
+        chat_llm=llm,
+        text_embedder=embedder,
+        entities=entities,
+        relationships=relationships,
+        reports=reports,
+        entity_text_embeddings=description_embedding_store,
+        text_units=text_units,
+    )
+
+
+async def build_drift_search_engine(llm: BaseLLM, context_builder: DRIFTSearchContextBuilder = None,
+                                    token_encoder: tiktoken.Encoding | None = None) -> DRIFTSearch:
+    return DRIFTSearch(
+        llm=llm,
+        context_builder=context_builder,
+        token_encoder=token_encoder,
+    )
diff --git a/webserver/static/wechat_qr_code.png b/webserver/static/wechat_qr_code.png
new file mode 100644
index 0000000000..970b419fd2
Binary files /dev/null and b/webserver/static/wechat_qr_code.png differ
diff --git a/webserver/templates/claims_template.html b/webserver/templates/claims_template.html
new file mode 100644
index 0000000000..65de42d26c
--- /dev/null
+++ b/webserver/templates/claims_template.html
@@ -0,0 +1,36 @@
+ + + + Covariate Information + + + +

Claims Information

+ + + + + + + + +
AttributeValue
Subject ID{{ data.subject_id }}
Subject Type{{ data.subject_type }}
Covariate Type{{ data.covariate_type }}
Text Unit IDs{{ data.text_unit_ids }}
Document IDs{{ data.document_ids }}
Attributes{{ data.attributes }}
+ +
\ No newline at end of file
diff --git a/webserver/templates/entities_template.html b/webserver/templates/entities_template.html
new file mode 100644
index 0000000000..c07190c622
--- /dev/null
+++ b/webserver/templates/entities_template.html
@@ -0,0 +1,41 @@
+ + + + Entity Information + + + +

Entity Information

+ + + + + + + + + + + + + +
AttributeValue
Name{{ data.title }}
Type{{ data.type }}
Description{{ data.description }}
Description EmbeddingLength: {{ data.description_embedding | length }}
Name Embedding{{ data.name_embedding }}
Graph Embedding{{ data.graph_embedding }}
Community IDs{{ data.community_ids }}
Text Unit IDs{{ data.text_unit_ids }}
Document IDs{{ data.document_ids }}
Rank{{ data.rank }}
Attributes{{ data.attributes }}
+ +
\ No newline at end of file
diff --git a/webserver/templates/index.html b/webserver/templates/index.html
new file mode 100644
index 0000000000..e4299bb574
--- /dev/null
+++ b/webserver/templates/index.html
@@ -0,0 +1,95 @@
+ + + + GraphRAG Web Server Homepage + + + +
+

GraphRAG Web Server

+
+

邮箱: kose2livs@gmail.com

+

公众号: 深入LLM Agent应用开发

+ 微信公众号二维码 +
+
+

项目介绍:

+
    +
  • 1. 输入输出兼容OpenAI SDK
  • +
  • 2. 支持真流式输出,响应快速方便接入各种UI支持
  • +
  • 3. 支持自动获取GraphRAG索引文件
  • +
  • 4. 支持一键可视化到Neo4j
  • +
  • 5. [TODO] 集成huggingface embedding
  • +
  • 6. [TODO] 支持PDF输入
  • +
+
+ +
+ +
\ No newline at end of file
diff --git a/webserver/templates/relationships_template.html b/webserver/templates/relationships_template.html
new file mode 100644
index 0000000000..554bcb61f3
--- /dev/null
+++ b/webserver/templates/relationships_template.html
@@ -0,0 +1,38 @@
+ + + + Relationship Information + + + +

Relationship Information

+ + + + + + + + + + +
AttributeValue
Source{{ data.source }}
Target{{ data.target }}
Weight{{ data.weight }}
Description{{ data.description }}
Description Embedding{{ data.description_embedding }}
Text Unit IDs{{ data.text_unit_ids }}
Document IDs{{ data.document_ids }}
Attributes{{ data.attributes }}
+ +
\ No newline at end of file
diff --git a/webserver/templates/reports_template.html b/webserver/templates/reports_template.html
new file mode 100644
index 0000000000..a8c8f70c4b
--- /dev/null
+++ b/webserver/templates/reports_template.html
@@ -0,0 +1,38 @@
+ + + + Community Report Information + + + +

Community Report Information

+ + + + + + + + + + +
AttributeValue
Name{{ data.title }}
Community ID{{ data.community_id }}
Summary{{ data.summary }}
Full Content{{ data.full_content }}
Rank{{ data.rank }}
Summary Embedding{{ data.summary_embedding }}
Full Content Embedding{{ data.full_content_embedding }}
Attributes{{ data.attributes }}
+ +
\ No newline at end of file
diff --git a/webserver/templates/sources_template.html b/webserver/templates/sources_template.html
new file mode 100644
index 0000000000..82c8bbf39c
--- /dev/null
+++ b/webserver/templates/sources_template.html
@@ -0,0 +1,38 @@
+ + + + Source Information + + + +

Text Unit Information

+ + + + + + + + + + +
AttributeValue
Text{{ data.text }}
Text Embedding{{ data.text_embedding }}
Entity IDs{{ data.entity_ids }}
Relationship IDs{{ data.relationship_ids }}
Covariate IDs{{ data.covariate_ids }}
Number of Tokens{{ data.n_tokens }}
Document IDs{{ data.document_ids }}
Attributes{{ data.attributes }}
+ +
\ No newline at end of file
diff --git a/webserver/utils/__init__.py b/webserver/utils/__init__.py
new file mode 100644
index 0000000000..7524dab6c1
--- /dev/null
+++ b/webserver/utils/__init__.py
@@ -0,0 +1 @@
+from .refer import get_reference, generate_ref_links
\ No newline at end of file
diff --git a/webserver/utils/consts.py b/webserver/utils/consts.py
new file mode 100644
index 0000000000..93da76e784
--- /dev/null
+++ b/webserver/utils/consts.py
@@ -0,0 +1,16 @@
+# parquet files generated from indexing pipeline
+COMMUNITY_REPORT_TABLE = "create_final_community_reports"
+ENTITY_TABLE = "create_final_nodes"
+ENTITY_EMBEDDING_TABLE = "create_final_entities"
+
+# community level in the Leiden community hierarchy from which we will load the community reports
+# higher value means we use reports from more fine-grained communities (at the cost of higher computation cost)
+COMMUNITY_LEVEL = 2
+
+RELATIONSHIP_TABLE = "create_final_relationships"
+COVARIATE_TABLE = "create_final_covariates"
+TEXT_UNIT_TABLE = "create_final_text_units"
+
+INDEX_LOCAL = "local"
+INDEX_GLOBAL = "global"
+INDEX_DRIFT = "drift"
diff --git a/webserver/utils/refer.py b/webserver/utils/refer.py
new file mode 100644
index 0000000000..5f066570cb
--- /dev/null
+++ b/webserver/utils/refer.py
@@ -0,0 +1,28 @@
+import re
+from collections import defaultdict
+from typing import Set, Dict
+
+from webserver.configs import settings
+
+pattern = re.compile(r'\[\^Data:(\w+?)\((\d+(?:,\d+)*)\)\]')
+
+
+def get_reference(text: str) -> dict:
+    data_dict = defaultdict(set)
+    for match in pattern.finditer(text):
+        key = match.group(1).lower()
+        value = match.group(2)
+
+        ids = value.replace(" ", "").split(',')
+        data_dict[key].update(ids)
+
+    return dict(data_dict)
+
+
+def generate_ref_links(data: Dict[str, Set[int]], index_id: str) -> str:
+    base_url = f"{settings.website_address}/v1/references"
+    lines = []
+    for key, values in data.items():
+        for value in values:
+            lines.append(f'[^Data:{key.capitalize()}({value})]: [{key.capitalize()}: {value}]({base_url}/{index_id}/{key}/{value})')
+    return "\n".join(lines)
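
A minimal sketch of what the reference helpers in `webserver/utils/refer.py` do, kept self-contained by re-declaring the same regex instead of importing the webserver package. The sample answer text, the `mydocs` index id, and the `http://localhost:20213` base URL are illustrative assumptions; in the real module the base URL comes from `settings.website_address`.

```python
import re
from collections import defaultdict

# Same pattern as webserver/utils/refer.py: footnote-style markers such as
# [^Data:Entities(11,24)] that the webserver expects in the model's answer text.
pattern = re.compile(r'\[\^Data:(\w+?)\((\d+(?:,\d+)*)\)\]')

# Hypothetical answer text, used only to illustrate the parsing step.
answer = (
    "Scrooge employs Bob Cratchit as his clerk [^Data:Entities(11,24)], "
    "as summarized in the community report [^Data:Reports(5)]."
)

# Group ids by lower-cased reference type, as get_reference() does.
refs = defaultdict(set)
for match in pattern.finditer(answer):
    refs[match.group(1).lower()].update(match.group(2).split(","))

print(dict(refs))
# -> {'entities': {'11', '24'}, 'reports': {'5'}}  (set order may vary)

# generate_ref_links() would then append one markdown footnote definition per id,
# e.g. with the assumed index id "mydocs" and server address http://localhost:20213:
# [^Data:Entities(11)]: [Entities: 11](http://localhost:20213/v1/references/mydocs/entities/11)
```

The grouping by lower-cased type is what lets each marker be resolved to a URL of the form `/v1/references/{index_id}/{type}/{id}`, which presumably renders the entity, relationship, report, claim, or text-unit templates added above.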
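A hedged sketch of how the builders in `webserver/search/localsearch.py` are meant to fit together. The `llm` and `text_embedder` arguments are placeholders for the OpenAI-compatible clients the webserver constructs from `settings` elsewhere in this change, and the final call assumes the `asearch()` coroutine exposed by graphrag's `LocalSearch`; treat this as an illustration under those assumptions, not the server's actual startup code.

```python
import asyncio

import tiktoken

from webserver.configs import settings
from webserver.search.localsearch import build_local_search_engine, load_local_context


async def ask(llm, text_embedder, question: str) -> str:
    # llm / text_embedder are placeholders for the chat and embedding clients
    # the webserver builds from its settings.
    token_encoder = tiktoken.get_encoding("cl100k_base")

    # Load entities, relationships, reports, text units and the LanceDB
    # embedding store from the parquet files under settings.data (./output).
    context_builder = await load_local_context(settings.data, text_embedder, token_encoder)

    # Wrap the context in a LocalSearch engine configured from settings.local_search.
    engine = await build_local_search_engine(llm, context_builder, token_encoder)

    # asearch() returns a SearchResult; .response is the answer text whose
    # reference markers refer.py later rewrites into clickable links.
    result = await engine.asearch(question)
    return result.response


# asyncio.run(ask(my_llm, my_embedder, "Who is Scrooge?"))  # with concrete clients
```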