diff --git a/.gitattributes b/.gitattributes index 865da2ca2d..e6436790e4 100644 --- a/.gitattributes +++ b/.gitattributes @@ -14,6 +14,7 @@ *.ico binary *.jpeg binary *.mp3 binary +*.mp4 binary *.zip binary *.bin binary diff --git a/config/config2.example.yaml b/config/config2.example.yaml index 4b155e5dc1..b82468eed0 100644 --- a/config/config2.example.yaml +++ b/config/config2.example.yaml @@ -60,6 +60,10 @@ iflytek_api_secret: "YOUR_API_SECRET" metagpt_tti_url: "YOUR_MODEL_URL" +omniparse: + api_key: "YOUR_API_KEY" + base_url: "YOUR_BASE_URL" + models: # "YOUR_MODEL_NAME_1 or YOUR_API_TYPE_1": # model: "gpt-4-turbo" # or gpt-3.5-turbo # api_type: "openai" # or azure / ollama / groq etc. diff --git a/examples/data/omniparse/test01.docx b/examples/data/omniparse/test01.docx new file mode 100644 index 0000000000..7b62517992 Binary files /dev/null and b/examples/data/omniparse/test01.docx differ diff --git a/examples/data/omniparse/test02.pdf b/examples/data/omniparse/test02.pdf new file mode 100644 index 0000000000..8cd15877f1 Binary files /dev/null and b/examples/data/omniparse/test02.pdf differ diff --git a/examples/data/omniparse/test03.mp4 b/examples/data/omniparse/test03.mp4 new file mode 100644 index 0000000000..54746f45dc Binary files /dev/null and b/examples/data/omniparse/test03.mp4 differ diff --git a/examples/data/omniparse/test04.mp3 b/examples/data/omniparse/test04.mp3 new file mode 100644 index 0000000000..2c8e149d8a Binary files /dev/null and b/examples/data/omniparse/test04.mp3 differ diff --git a/examples/rag/omniparse.py b/examples/rag/omniparse.py new file mode 100644 index 0000000000..b9159dae52 --- /dev/null +++ b/examples/rag/omniparse.py @@ -0,0 +1,64 @@ +import asyncio + +from metagpt.config2 import config +from metagpt.const import EXAMPLE_DATA_PATH +from metagpt.logs import logger +from metagpt.rag.parsers import OmniParse +from metagpt.rag.schema import OmniParseOptions, OmniParseType, ParseResultType +from metagpt.utils.omniparse_client 
# Sample inputs shipped with the repository, one per media kind exercised below.
TEST_DOCX = EXAMPLE_DATA_PATH / "omniparse/test01.docx"
TEST_PDF = EXAMPLE_DATA_PATH / "omniparse/test02.pdf"
TEST_VIDEO = EXAMPLE_DATA_PATH / "omniparse/test03.mp4"
TEST_AUDIO = EXAMPLE_DATA_PATH / "omniparse/test04.mp3"


async def omniparse_client_example():
    """Call the low-level OmniParseClient once per supported input kind and log each result."""
    omni_client = OmniParseClient(base_url=config.omniparse.base_url)

    # docx: sent as raw bytes, so a filename has to accompany the payload
    docx_bytes = TEST_DOCX.read_bytes()
    logger.info(await omni_client.parse_document(file_input=docx_bytes, bytes_filename="test_01.docx"))

    # pdf, video and audio are passed as paths; the client reads the files itself
    path_based_calls = (
        (omni_client.parse_pdf, TEST_PDF),
        (omni_client.parse_video, TEST_VIDEO),
        (omni_client.parse_audio, TEST_AUDIO),
    )
    for parse_call, sample_path in path_based_calls:
        logger.info(await parse_call(file_input=sample_path))


async def omniparse_example():
    """Drive the OmniParse reader: one PDF through load_data, then two files through aload_data."""
    options = OmniParseOptions(
        parse_type=OmniParseType.PDF,
        result_type=ParseResultType.MD,
        max_timeout=120,
        num_workers=3,
    )
    reader = OmniParse(
        api_key=config.omniparse.api_key,
        base_url=config.omniparse.base_url,
        parse_options=options,
    )

    # Synchronous entry point, single PDF
    logger.info(reader.load_data(file_path=TEST_PDF))

    # Switch the reader to the generic document endpoint and parse several files at once
    reader.parse_type = OmniParseType.DOCUMENT
    logger.info(await reader.aload_data(file_path=[TEST_DOCX, TEST_PDF]))


async def main():
    """Run the client example first, then the reader example."""
    for example in (omniparse_client_example, omniparse_example):
        await example()


if __name__ == "__main__":
    asyncio.run(main())
index 258c5ba60f..3b0e047f81 100644 --- a/examples/rag_search.py +++ b/examples/rag/rag_search.py @@ -2,7 +2,7 @@ import asyncio -from examples.rag_pipeline import DOC_PATH, QUESTION +from examples.rag.rag_pipeline import DOC_PATH, QUESTION from metagpt.logs import logger from metagpt.rag.engines import SimpleEngine from metagpt.roles import Sales diff --git a/metagpt/config2.py b/metagpt/config2.py index fff1799a73..27b228b336 100644 --- a/metagpt/config2.py +++ b/metagpt/config2.py @@ -13,6 +13,7 @@ from metagpt.configs.browser_config import BrowserConfig from metagpt.configs.embedding_config import EmbeddingConfig +from metagpt.configs.file_parser_config import OmniParseConfig from metagpt.configs.llm_config import LLMConfig, LLMType from metagpt.configs.mermaid_config import MermaidConfig from metagpt.configs.redis_config import RedisConfig @@ -51,6 +52,9 @@ class Config(CLIParams, YamlModel): # RAG Embedding embedding: EmbeddingConfig = EmbeddingConfig() + # omniparse + omniparse: OmniParseConfig = OmniParseConfig() + # Global Proxy. 
class OmniParseConfig(YamlModel):
    """Settings for reaching an OmniParse server (the `omniparse` section of the config)."""

    # Key forwarded to the OmniParse client; empty by default.
    api_key: str = ""
    # Base URL of the OmniParse server; empty by default.
    base_url: str = ""
class OmniParse(BaseReader):
    """Reader that sends files to an OmniParse server and wraps the parsed content in Documents.

    Parsing failures are logged and produce an empty result instead of raising, so one bad
    file does not abort a batch.
    """

    def __init__(
        self, api_key: str = None, base_url: str = "http://localhost:8000", parse_options: OmniParseOptions = None
    ):
        """
        Args:
            api_key: Default None, can be used for authentication later.
            base_url: OmniParse Base URL for the API.
            parse_options: Optional settings for OmniParse. Default is OmniParseOptions with default values.
        """
        self.parse_options = parse_options or OmniParseOptions()
        self.omniparse_client = OmniParseClient(api_key, base_url, max_timeout=self.parse_options.max_timeout)

    @property
    def parse_type(self):
        """Which OmniParse endpoint family is used (stored in `parse_options.parse_type`)."""
        return self.parse_options.parse_type

    @property
    def result_type(self):
        """Which representation of the parsed result becomes the Document text."""
        return self.parse_options.result_type

    @parse_type.setter
    def parse_type(self, parse_type: Union[str, OmniParseType]):
        # Accept the plain string value as well as the enum member.
        if isinstance(parse_type, str):
            parse_type = OmniParseType(parse_type)
        self.parse_options.parse_type = parse_type

    @result_type.setter
    def result_type(self, result_type: Union[str, ParseResultType]):
        # Accept the plain string value as well as the enum member.
        if isinstance(result_type, str):
            result_type = ParseResultType(result_type)
        self.parse_options.result_type = result_type

    def _extract_content(self, parsed_result) -> str:
        """Pick the text for the Document out of a parsed result, according to `result_type`.

        JSON has no dedicated field on the result model; looking up an attribute named
        "json" would return a bound pydantic method rather than text, so the whole result
        is serialised instead.
        """
        result_type = ParseResultType(self.result_type)
        if result_type == ParseResultType.JSON:
            return parsed_result.model_dump_json()
        return getattr(parsed_result, result_type.value)

    async def _aload_data(
        self,
        file_path: Union[str, bytes, Path],
        extra_info: Optional[dict] = None,
    ) -> List[Document]:
        """
        Parse a single input and wrap the result in a one-element Document list.

        Args:
            file_path: File path or file byte data.
            extra_info: Optional dictionary containing additional information. It becomes the
                Document metadata; its "filename" entry names byte input for non-PDF parsing.

        Returns:
            List[Document]: One Document on success, an empty list if anything failed.
        """
        # Copy so that Documents produced from the same `extra_info` never share one dict.
        metadata = dict(extra_info) if extra_info else {}
        try:
            if self.parse_type == OmniParseType.PDF:
                # pdf parse
                parsed_result = await self.omniparse_client.parse_pdf(file_path)
            else:
                # Every other type goes through parse_document; byte input additionally
                # needs a filename, taken from extra_info["filename"].
                parsed_result = await self.omniparse_client.parse_document(
                    file_path, bytes_filename=metadata.get("filename")
                )

            docs = [Document(text=self._extract_content(parsed_result), metadata=metadata)]
        except Exception as e:
            # Deliberate best effort: report and continue with no documents for this input.
            logger.error(f"OMNI Parse Error: {e}")
            docs = []

        return docs

    async def aload_data(
        self,
        file_path: Union[List[Union[str, bytes, Path]], str, bytes, Path],
        extra_info: Optional[dict] = None,
    ) -> List[Document]:
        """
        Load data from the input file_path.

        Args:
            file_path: File path or file byte data, or a list/tuple of them.
            extra_info: Optional dictionary containing additional information.

        Notes:
            This method ultimately calls _aload_data for processing. Multiple inputs are parsed
            concurrently with `parse_options.num_workers` workers.

        Returns:
            List[Document]
        """
        if isinstance(file_path, (str, bytes, Path)):
            # Processing single file
            return await self._aload_data(file_path, extra_info)

        if isinstance(file_path, (list, tuple)):
            if not file_path:
                return []
            # Concurrently process multiple files
            parse_jobs = [self._aload_data(file_item, extra_info) for file_item in file_path]
            per_file_docs = await run_jobs(jobs=parse_jobs, workers=self.parse_options.num_workers)
            return [doc for file_docs in per_file_docs for doc in file_docs]

        # Unsupported input still yields no documents, but it is now reported.
        logger.warning(f"OmniParse: unsupported file_path type {type(file_path).__name__}, nothing loaded")
        return []

    def load_data(
        self,
        file_path: Union[List[Union[str, bytes, Path]], str, bytes, Path],
        extra_info: Optional[dict] = None,
    ) -> List[Document]:
        """
        Load data from the input file_path.

        Args:
            file_path: File path or file byte data, or a list/tuple of them.
            extra_info: Optional dictionary containing additional information.

        Notes:
            This method ultimately calls aload_data for processing.

        Returns:
            List[Document]
        """
        NestAsyncio.apply_once()  # Ensure compatibility with nested async calls
        return asyncio.run(self.aload_data(file_path, extra_info))
class OmniParseClient:
    """
    OmniParse Server Client
    This client interacts with the OmniParse server to parse different types of media, documents.

    OmniParse API Documentation: https://docs.cognitivelab.in/api

    Attributes:
        ALLOWED_DOCUMENT_EXTENSIONS (set): A set of supported document file extensions.
        ALLOWED_AUDIO_EXTENSIONS (set): A set of supported audio file extensions.
        ALLOWED_VIDEO_EXTENSIONS (set): A set of supported video file extensions.
    """

    ALLOWED_DOCUMENT_EXTENSIONS = {".pdf", ".ppt", ".pptx", ".doc", ".docx"}
    ALLOWED_AUDIO_EXTENSIONS = {".mp3", ".wav", ".aac"}
    ALLOWED_VIDEO_EXTENSIONS = {".mp4", ".mkv", ".avi", ".mov"}

    def __init__(self, api_key: str = None, base_url: str = "http://localhost:8000", max_timeout: int = 120):
        """
        Args:
            api_key: Default None, can be used for authentication later.
            base_url: Base URL for the API.
            max_timeout: Maximum request timeout in seconds.
        """
        self.api_key = api_key
        self.base_url = base_url
        self.max_timeout = max_timeout

        self.parse_media_endpoint = "/parse_media"
        self.parse_website_endpoint = "/parse_website"
        self.parse_document_endpoint = "/parse_document"

    def _build_url(self, endpoint: str) -> str:
        """Join `base_url` and `endpoint` with exactly one slash between them."""
        return f"{(self.base_url or '').rstrip('/')}/{endpoint.lstrip('/')}"

    def _build_headers(self, headers: dict = None) -> dict:
        """Return a new header dict with the bearer token added when an api_key is set.

        The caller's dict is copied, never modified.
        """
        merged = dict(headers or {})
        if self.api_key:
            merged["Authorization"] = f"Bearer {self.api_key}"
        return merged

    async def _request_parse(
        self,
        endpoint: str,
        method: str = "POST",
        files: dict = None,
        params: dict = None,
        data: dict = None,
        json: dict = None,
        headers: dict = None,
        **kwargs,
    ) -> dict:
        """
        Request OmniParse API to parse a document.

        Args:
            endpoint (str): API endpoint.
            method (str, optional): HTTP method to use. Default is "POST".
            files (dict, optional): Files to include in the request.
            params (dict, optional): Query string parameters.
            data (dict, optional): Form data to include in the request body.
            json (dict, optional): JSON data to include in the request body.
            headers (dict, optional): HTTP headers to include in the request.
            **kwargs: Additional keyword arguments for httpx.AsyncClient.request()

        Raises:
            httpx.HTTPStatusError: If the server answers with an error status.

        Returns:
            dict: JSON response data.
        """
        async with httpx.AsyncClient() as client:
            response = await client.request(
                url=self._build_url(endpoint),
                method=method.upper(),
                files=files,
                params=params,
                json=json,
                data=data,
                headers=self._build_headers(headers),
                timeout=self.max_timeout,
                **kwargs,
            )
            response.raise_for_status()
            return response.json()

    async def parse_document(
        self, file_input: Union[str, bytes, Path], bytes_filename: str = None
    ) -> "OmniParsedResult":
        """
        Parse document-type data (supports ".pdf", ".ppt", ".pptx", ".doc", ".docx").

        Args:
            file_input: File path or file byte data.
            bytes_filename: Filename for byte data, useful for determining MIME type for the HTTP request.

        Raises:
            ValueError: If the file extension is not allowed, or byte data comes without a filename.

        Returns:
            OmniParsedResult: The result of the document parsing.
        """
        self.verify_file_ext(file_input, self.ALLOWED_DOCUMENT_EXTENSIONS, bytes_filename)
        file_info = await self.get_file_info(file_input, bytes_filename)
        resp = await self._request_parse(self.parse_document_endpoint, files={"file": file_info})
        return OmniParsedResult(**resp)

    async def parse_pdf(self, file_input: Union[str, bytes, Path]) -> "OmniParsedResult":
        """
        Parse pdf document.

        Args:
            file_input: File path or file byte data.

        Raises:
            ValueError: If the file extension is not allowed.

        Returns:
            OmniParsedResult: The result of the pdf parsing.
        """
        self.verify_file_ext(file_input, {".pdf"})
        # parse_pdf supports parsing by accepting only the byte data of the file.
        file_info = await self.get_file_info(file_input, only_bytes=True)
        resp = await self._request_parse(f"{self.parse_document_endpoint}/pdf", files={"file": file_info})
        return OmniParsedResult(**resp)

    async def parse_video(self, file_input: Union[str, bytes, Path], bytes_filename: str = None) -> dict:
        """
        Parse video-type data (supports ".mp4", ".mkv", ".avi", ".mov").

        Args:
            file_input: File path or file byte data.
            bytes_filename: Filename for byte data, useful for determining MIME type for the HTTP request.

        Raises:
            ValueError: If the file extension is not allowed, or byte data comes without a filename.

        Returns:
            dict: JSON response data.
        """
        self.verify_file_ext(file_input, self.ALLOWED_VIDEO_EXTENSIONS, bytes_filename)
        file_info = await self.get_file_info(file_input, bytes_filename)
        return await self._request_parse(f"{self.parse_media_endpoint}/video", files={"file": file_info})

    async def parse_audio(self, file_input: Union[str, bytes, Path], bytes_filename: str = None) -> dict:
        """
        Parse audio-type data (supports ".mp3", ".wav", ".aac").

        Args:
            file_input: File path or file byte data.
            bytes_filename: Filename for byte data, useful for determining MIME type for the HTTP request.

        Raises:
            ValueError: If the file extension is not allowed, or byte data comes without a filename.

        Returns:
            dict: JSON response data.
        """
        self.verify_file_ext(file_input, self.ALLOWED_AUDIO_EXTENSIONS, bytes_filename)
        file_info = await self.get_file_info(file_input, bytes_filename)
        return await self._request_parse(f"{self.parse_media_endpoint}/audio", files={"file": file_info})

    @staticmethod
    def verify_file_ext(file_input: Union[str, bytes, Path], allowed_file_extensions: set, bytes_filename: str = None):
        """
        Verify the file extension (case-insensitively).

        Args:
            file_input: File path or file byte data.
            allowed_file_extensions: Set of allowed file extensions.
            bytes_filename: Filename to use for verification when `file_input` is byte data.

        Raises:
            ValueError: If the file extension is not allowed.

        Returns:
            None. Byte data without a filename is not verified at all.
        """
        verify_file_path = None
        if isinstance(file_input, (str, Path)):
            verify_file_path = str(file_input)
        elif isinstance(file_input, bytes) and bytes_filename:
            verify_file_path = bytes_filename

        if not verify_file_path:
            # Do not verify if only byte data is provided
            return

        file_ext = os.path.splitext(verify_file_path)[1].lower()
        if file_ext not in allowed_file_extensions:
            raise ValueError(f"Not allowed {file_ext} File extension must be one of {allowed_file_extensions}")

    @staticmethod
    async def get_file_info(
        file_input: Union[str, bytes, Path],
        bytes_filename: str = None,
        only_bytes: bool = False,
    ) -> Union[bytes, tuple]:
        """
        Get file information.

        Args:
            file_input: File path or file byte data.
            bytes_filename: Filename to use when uploading byte data, useful for determining MIME type.
            only_bytes: Whether to return only byte data. Default is False, which returns a tuple.

        Raises:
            ValueError: If bytes_filename is not provided when file_input is bytes or if file_input is not a valid type.

        Notes:
            Since `parse_document`,`parse_video`, `parse_audio` supports parsing various file types,
            the MIME type of the file must be specified when uploading.

        Returns: [bytes, tuple]
            Returns bytes if only_bytes is True, otherwise returns a tuple (filename, file_bytes, mime_type).
        """
        if isinstance(file_input, (str, Path)):
            file_bytes = await aread_bin(file_input)
            if only_bytes:
                return file_bytes

            path_str = str(file_input)
            # The MIME type is guessed from the name only and may be None for unknown extensions.
            return os.path.basename(path_str), file_bytes, mimetypes.guess_type(path_str)[0]

        if isinstance(file_input, bytes):
            if only_bytes:
                return file_input
            if not bytes_filename:
                raise ValueError("bytes_filename must be set when passing bytes")

            return bytes_filename, file_input, mimetypes.guess_type(bytes_filename)[0]

        raise ValueError("file_input must be a string (file path) or bytes.")
def test_get_file_extractor(self, mocker):
    """_get_file_extractor registers an OmniParse PDF reader only when a base_url is configured."""
    omniparse_cfg = mocker.patch("metagpt.rag.engines.simple.config.omniparse", autospec=True)

    # mock no omniparse config
    omniparse_cfg.base_url = ""
    assert SimpleEngine._get_file_extractor() == {}

    # mock have omniparse config
    omniparse_cfg.base_url = "http://localhost:8000"
    extractors = SimpleEngine._get_file_extractor()
    assert ".pdf" in extractors
    assert isinstance(extractors[".pdf"], OmniParse)
# test data
TEST_DOCX = EXAMPLE_DATA_PATH / "omniparse/test01.docx"
TEST_PDF = EXAMPLE_DATA_PATH / "omniparse/test02.pdf"
TEST_VIDEO = EXAMPLE_DATA_PATH / "omniparse/test03.mp4"
TEST_AUDIO = EXAMPLE_DATA_PATH / "omniparse/test04.mp3"


class TestOmniParseClient:
    """Tests for OmniParseClient with the HTTP layer (`_request_parse`) mocked out."""

    # One client shared by all tests in this class; it is built with the default arguments.
    parse_client = OmniParseClient()

    @pytest.fixture
    def mock_request_parse(self, mocker):
        """Replace OmniParseClient._request_parse so no request leaves the test."""
        return mocker.patch("metagpt.rag.parsers.omniparse.OmniParseClient._request_parse")

    @pytest.mark.asyncio
    async def test_parse_pdf(self, mock_request_parse):
        """parse_pdf turns the mocked response dict back into an equal OmniParsedResult."""
        mock_content = "#test title\ntest content"
        mock_parsed_ret = OmniParsedResult(text=mock_content, markdown=mock_content)
        mock_request_parse.return_value = mock_parsed_ret.model_dump()
        parse_ret = await self.parse_client.parse_pdf(TEST_PDF)
        assert parse_ret == mock_parsed_ret

    @pytest.mark.asyncio
    async def test_parse_document(self, mock_request_parse):
        """parse_document rejects bare bytes and accepts bytes accompanied by a filename."""
        mock_content = "#test title\ntest_parse_document"
        mock_parsed_ret = OmniParsedResult(text=mock_content, markdown=mock_content)
        mock_request_parse.return_value = mock_parsed_ret.model_dump()

        with open(TEST_DOCX, "rb") as f:
            file_bytes = f.read()

        with pytest.raises(ValueError):
            # bytes data must provide bytes_filename
            await self.parse_client.parse_document(file_bytes)

        parse_ret = await self.parse_client.parse_document(file_bytes, bytes_filename="test.docx")
        assert parse_ret == mock_parsed_ret

    @pytest.mark.asyncio
    async def test_parse_video(self, mock_request_parse):
        """parse_video rejects a non-video extension and returns the response dict unchanged."""
        mock_content = "#test title\ntest_parse_video"
        mock_request_parse.return_value = {
            "text": mock_content,
            "metadata": {},
        }
        with pytest.raises(ValueError):
            # Wrong file extension test
            await self.parse_client.parse_video(TEST_DOCX)

        parse_ret = await self.parse_client.parse_video(TEST_VIDEO)
        assert "text" in parse_ret and "metadata" in parse_ret
        assert parse_ret["text"] == mock_content

    @pytest.mark.asyncio
    async def test_parse_audio(self, mock_request_parse):
        """parse_audio returns the response dict unchanged."""
        mock_content = "#test title\ntest_parse_audio"
        mock_request_parse.return_value = {
            "text": mock_content,
            "metadata": {},
        }
        parse_ret = await self.parse_client.parse_audio(TEST_AUDIO)
        assert "text" in parse_ret and "metadata" in parse_ret
        assert parse_ret["text"] == mock_content


class TestOmniParse:
    """Tests for the OmniParse reader with the client's HTTP layer mocked out."""

    @pytest.fixture
    def mock_omniparse(self):
        """A real OmniParse reader (not a mock) configured for PDF input and markdown output."""
        parser = OmniParse(
            parse_options=OmniParseOptions(
                parse_type=OmniParseType.PDF,
                result_type=ParseResultType.MD,
                max_timeout=120,
                num_workers=3,
            )
        )
        return parser

    @pytest.fixture
    def mock_request_parse(self, mocker):
        """Replace OmniParseClient._request_parse so no request leaves the test."""
        return mocker.patch("metagpt.rag.parsers.omniparse.OmniParseClient._request_parse")

    @pytest.mark.asyncio
    async def test_load_data(self, mock_omniparse, mock_request_parse):
        """load_data handles one file; aload_data returns one Document per file in a list."""
        # mock
        mock_content = "#test title\ntest content"
        mock_parsed_ret = OmniParsedResult(text=mock_content, markdown=mock_content)
        mock_request_parse.return_value = mock_parsed_ret.model_dump()

        # single file
        # NOTE: the synchronous load_data is called from inside this async test.
        documents = mock_omniparse.load_data(file_path=TEST_PDF)
        doc = documents[0]
        assert isinstance(doc, Document)
        assert doc.text == mock_parsed_ret.text == mock_parsed_ret.markdown

        # multi files
        file_paths = [TEST_DOCX, TEST_PDF]
        # Switch to the generic document endpoint, since the list contains a .docx file.
        mock_omniparse.parse_type = OmniParseType.DOCUMENT
        documents = await mock_omniparse.aload_data(file_path=file_paths)
        doc = documents[0]

        # assert
        assert isinstance(doc, Document)
        assert len(documents) == len(file_paths)
        assert doc.text == mock_parsed_ret.text == mock_parsed_ret.markdown