diff --git a/Dockerfile b/Dockerfile index 4a5e73c0..fb974f0d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -12,6 +12,8 @@ FROM $BASE_IMG:$BASE_IMG_TAG AS base ARG RELEASE_TYPE="dev" ARG VERSION="" ARG VERSION_REV="0" +ARG DOWNLOAD_LLAMA_TOKENIZER="" +ARG HF_ACCESS_TOKEN="" # Embed the `git rev-parse HEAD` as a Docker metadata label # Allows for linking container builds to git commits @@ -71,6 +73,9 @@ WORKDIR /workspace # Copy custom entrypoint script COPY ./docker/scripts/entrypoint.sh /workspace/docker/entrypoint.sh +# Copy post-build triggers script +COPY ./docker/scripts/post_build_triggers.py /workspace/docker/post_build_triggers.py + FROM base AS nv_ingest_install # Copy the module code COPY setup.py setup.py @@ -124,6 +129,12 @@ RUN --mount=type=cache,target=/opt/conda/pkgs\ && pip install ./api/dist/*.whl \ && pip install ./client/dist/*.whl + +RUN --mount=type=cache,target=/opt/conda/pkgs \ + --mount=type=cache,target=/root/.cache/pip \ + source activate nv_ingest_runtime \ + && python3 /workspace/docker/post_build_triggers.py + RUN rm -rf src FROM nv_ingest_install AS runtime diff --git a/README.md b/README.md index 67a27293..9aad924f 100644 --- a/README.md +++ b/README.md @@ -411,6 +411,12 @@ https://pypi.org/project/pdfservices-sdk/ required if you want to use the Adobe extraction service for PDF decomposition. Please review the [license agreement](https://github.com/adobe/pdfservices-python-sdk?tab=License-1-ov-file) for the pdfservices-sdk before enabling this option. +- **`DOWNLOAD_LLAMA_TOKENIZER` (Built With Llama):** + - **Description**: The Split task uses the `meta-llama/Llama-3.2-1B` tokenizer, which will be downloaded + from HuggingFace at build time if `DOWNLOAD_LLAMA_TOKENIZER` is set to `True`. Please review the + [license agreement](https://huggingface.co/meta-llama/Llama-3.2-1B) for Llama 3.2 materials before using this. + This is a gated model, so you'll need to [request access](https://huggingface.co/meta-llama/Llama-3.2-1B) and + set `HF_ACCESS_TOKEN` to your HuggingFace access token in order to use it.
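For reference, here is a minimal sketch (not part of this diff) of how runtime code could pick up the tokenizer that the build step pre-downloads. The `/workspace/models/tokenizer/` path, the `meta-llama/Llama-3.2-1B` model ID, and the `HF_ACCESS_TOKEN` variable come from `docker/scripts/post_build_triggers.py` and the Dockerfile above; the offline-fallback logic itself is only an assumption about how a consumer might use them.

```python
# Hypothetical consumer of the tokenizer baked in by post_build_triggers.py.
# Assumes the image was built with DOWNLOAD_LLAMA_TOKENIZER=True.
import os

from transformers import AutoTokenizer

TOKENIZER_CACHE = "/workspace/models/tokenizer/"  # save_pretrained() target in post_build_triggers.py

if os.path.isdir(TOKENIZER_CACHE) and os.listdir(TOKENIZER_CACHE):
    # Load the local snapshot baked into the image; no network or HF token needed at runtime.
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_CACHE)
else:
    # Fall back to downloading the gated model, which requires an access token.
    tokenizer = AutoTokenizer.from_pretrained(
        "meta-llama/Llama-3.2-1B", token=os.getenv("HF_ACCESS_TOKEN")
    )
```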
### Contributing diff --git a/client/src/nv_ingest_client/primitives/tasks/split.py b/client/src/nv_ingest_client/primitives/tasks/split.py index 7bf63dea..12f1b2d7 100644 --- a/client/src/nv_ingest_client/primitives/tasks/split.py +++ b/client/src/nv_ingest_client/primitives/tasks/split.py @@ -8,10 +8,8 @@ import logging from typing import Dict -from typing import Literal -from typing import Optional -from pydantic import BaseModel, field_validator +from pydantic import BaseModel from .task_base import Task @@ -19,18 +17,10 @@ class SplitTaskSchema(BaseModel): - split_by: Optional[str] = "sentence" - split_length: Optional[int] = 10 - split_overlap: Optional[int] = 0 - max_character_length: Optional[int] = 1024 - sentence_window_size: Optional[int] = 0 - - @field_validator("split_by") - def split_by_must_be_valid(cls, v): - valid_criteria = ["page", "size", "word", "sentence"] - if v not in valid_criteria: - raise ValueError(f"split_by must be one of {valid_criteria}") - return v + tokenizer: str = "meta-llama/Llama-3.2-1B" + chunk_size: int = 1024 + chunk_overlap: int = 150 + params: dict = {} class Config: extra = "forbid" @@ -41,25 +31,21 @@ class SplitTask(Task): Object for document splitting task """ - _TypeSplitBy = Literal["word", "sentence", "passage"] - def __init__( self, - split_by: _TypeSplitBy = None, - split_length: int = None, - split_overlap: int = None, - max_character_length: int = None, - sentence_window_size: int = None, + tokenizer: str = "meta-llama/Llama-3.2-1B", + chunk_size: int = 1024, + chunk_overlap: int = 150, + params: dict = {}, ) -> None: """ Setup Split Task Config """ super().__init__() - self._split_by = split_by - self._split_length = split_length - self._split_overlap = split_overlap - self._max_character_length = max_character_length - self._sentence_window_size = sentence_window_size + self._tokenizer = tokenizer + self._chunk_size = chunk_size + self._chunk_overlap = chunk_overlap + self._params = params def __str__(self) -> str: """ @@ -67,11 +53,11 @@ def __str__(self) -> str: """ info = "" info += "Split Task:\n" - info += f" split_by: {self._split_by}\n" - info += f" split_length: {self._split_length}\n" - info += f" split_overlap: {self._split_overlap}\n" - info += f" split_max_character_length: {self._max_character_length}\n" - info += f" split_sentence_window_size: {self._sentence_window_size}\n" + info += f" tokenizer: {self._tokenizer}\n" + info += f" chunk_size: {self._chunk_size}\n" + info += f" chunk_overlap: {self._chunk_overlap}\n" + for key, value in self._params.items(): + info += f" {key}: {value}\n" return info def to_dict(self) -> Dict: @@ -80,15 +66,13 @@ def to_dict(self) -> Dict: """ split_params = {} - if self._split_by is not None: - split_params["split_by"] = self._split_by - if self._split_length is not None: - split_params["split_length"] = self._split_length - if self._split_overlap is not None: - split_params["split_overlap"] = self._split_overlap - if self._max_character_length is not None: - split_params["max_character_length"] = self._max_character_length - if self._sentence_window_size is not None: - split_params["sentence_window_size"] = self._sentence_window_size + if self._tokenizer is not None: + split_params["tokenizer"] = self._tokenizer + if self._chunk_size is not None: + split_params["chunk_size"] = self._chunk_size + if self._chunk_overlap is not None: + split_params["chunk_overlap"] = self._chunk_overlap + if self._params is not None: + split_params["params"] = self._params return {"type": "split", 
"task_properties": split_params} diff --git a/docker-compose.yaml b/docker-compose.yaml index 44ed1107..451fc280 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -155,6 +155,9 @@ services: context: ${NV_INGEST_ROOT:-.} dockerfile: "./Dockerfile" target: runtime + args: + DOWNLOAD_LLAMA_TOKENIZER: ${DOWNLOAD_LLAMA_TOKENIZER:-False} + HF_ACCESS_TOKEN: ${HF_ACCESS_TOKEN:-hfaccesstoken} volumes: - ${DATASET_ROOT:-./data}:/workspace/data ports: diff --git a/docker/scripts/post_build_triggers.py b/docker/scripts/post_build_triggers.py new file mode 100644 index 00000000..676139ab --- /dev/null +++ b/docker/scripts/post_build_triggers.py @@ -0,0 +1,9 @@ +import os +from transformers import AutoTokenizer + +if os.getenv("DOWNLOAD_LLAMA_TOKENIZER") == "True": + tokenizer_path = "/workspace/models/tokenizer/" + os.makedirs(tokenizer_path, exist_ok=True) + + tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B", token=os.getenv("HF_ACCESS_TOKEN")) + tokenizer.save_pretrained(tokenizer_path) diff --git a/src/nv_ingest/modules/transforms/__init__.py b/src/nv_ingest/modules/transforms/__init__.py index 4a39c32a..941cd120 100644 --- a/src/nv_ingest/modules/transforms/__init__.py +++ b/src/nv_ingest/modules/transforms/__init__.py @@ -3,6 +3,6 @@ # SPDX-License-Identifier: Apache-2.0 from .associate_nearby_text import AssociateNearbyTextLoaderFactory -from .nemo_doc_splitter import NemoDocSplitterLoaderFactory +from .text_splitter import TextSplitterLoaderFactory -__all__ = ["NemoDocSplitterLoaderFactory", "AssociateNearbyTextLoaderFactory"] +__all__ = ["TextSplitterLoaderFactory", "AssociateNearbyTextLoaderFactory"] diff --git a/src/nv_ingest/modules/transforms/text_splitter.py b/src/nv_ingest/modules/transforms/text_splitter.py new file mode 100644 index 00000000..a6e09939 --- /dev/null +++ b/src/nv_ingest/modules/transforms/text_splitter.py @@ -0,0 +1,165 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + + +import copy +import logging +import traceback +import uuid +from typing import Any +from typing import List + +import mrc +import pandas as pd +from transformers import AutoTokenizer +from morpheus.messages import ControlMessage +from morpheus.messages import MessageMeta +from morpheus.utils.control_message_utils import cm_skip_processing_if_failed +from morpheus.utils.module_utils import ModuleLoaderFactory +from morpheus.utils.module_utils import register_module +from mrc.core import operators as ops +from pydantic import BaseModel + +import cudf + +from nv_ingest.schemas.metadata_schema import ContentTypeEnum +from nv_ingest.schemas.text_splitter_schema import TextSplitterSchema +from nv_ingest.util.exception_handlers.decorators import nv_ingest_node_failure_context_manager +from nv_ingest.util.flow_control import filter_by_task +from nv_ingest.util.modules.config_validator import fetch_and_validate_module_config +from nv_ingest.util.tracing import traceable + +logger = logging.getLogger(__name__) + + +def _build_split_documents(row, chunks: List[str]) -> List[dict[str, Any]]: + """Build documents from text chunks""" + documents: List[dict] = [] + + for i, text in enumerate(chunks): + if text is None or not text.strip(): + continue + + metadata = row.metadata if hasattr(row, "metadata") and isinstance(row.metadata, dict) else {} + metadata = copy.deepcopy(metadata) + + metadata["content"] = text + + documents.append({"document_type": ContentTypeEnum.TEXT.value, "metadata": metadata, "uuid": str(uuid.uuid4())}) + + return documents + + +def _split_into_chunks(text, tokenizer, chunk_size=1024, chunk_overlap=20): + # Tokenize the text into token IDs + encoding = tokenizer.encode_plus(text, add_special_tokens=False, return_offsets_mapping=True) + + # Get the token IDs and offsets for splitting + offsets = encoding["offset_mapping"] + + # Split the tokens into chunks of the desired size with the desired overlap + chunks = [offsets[i : i + chunk_size] for i in range(0, len(offsets), chunk_size - chunk_overlap)] + + # Convert token chunks back to text while preserving original spacing and case + text_chunks = [] + for chunk in chunks: + text_chunk = text[chunk[0][0] : chunk[-1][0]] + text_chunks.append(text_chunk) + + return text_chunks + + +MODULE_NAME = "text_splitter" +MODULE_NAMESPACE = "nv_ingest" + +TextSplitterLoaderFactory = ModuleLoaderFactory(MODULE_NAME, MODULE_NAMESPACE, TextSplitterSchema) + + +@register_module(MODULE_NAME, MODULE_NAMESPACE) +def _text_splitter(builder: mrc.Builder): + """ + A pipeline module that splits documents into smaller parts based on the specified criteria. 
+ """ + + validated_config = fetch_and_validate_module_config(builder, TextSplitterSchema) + + @filter_by_task(["split"]) + @traceable(MODULE_NAME) + @cm_skip_processing_if_failed + @nv_ingest_node_failure_context_manager( + annotation_id=MODULE_NAME, + raise_on_failure=validated_config.raise_on_failure, + ) + def split_and_forward(message: ControlMessage): + try: + # Assume that df is going to have a 'content' column + task_props = message.remove_task("split") + + if isinstance(task_props, BaseModel): + task_props = task_props.model_dump() + + # Validate that all 'content' values are not None + with message.payload().mutable_dataframe() as mdf: + df = mdf.to_pandas() + + # Filter to document type + bool_index = df["document_type"] == ContentTypeEnum.TEXT + df_filtered = df.loc[bool_index] + + if df_filtered.empty: + return message + + # Override parameters if set + tokenizer = task_props.get("tokenizer", validated_config.tokenizer) + chunk_size = task_props.get("chunk_size", validated_config.chunk_size) + chunk_overlap = task_props.get("chunk_overlap", validated_config.chunk_overlap) + params = task_props.get("params", {}) + + hf_access_token = params.get("hf_access_token", None) + split_source_types = params.get("split_source_types", ["TEXT"]) + + logger.debug( + f"Splitting text with tokenizer: {tokenizer}, " + f"chunk_size: {chunk_size} tokens, " + f"chunk_overlap: {chunk_overlap}" + ) + + # Filter to file type + bool_index = pd.json_normalize(df_filtered["metadata"])["source_metadata.source_type"].isin( + split_source_types + ) + df_filtered = df_filtered.loc[bool_index] + + if df_filtered.empty: + return message + + tokenizer_model = AutoTokenizer.from_pretrained(tokenizer, token=hf_access_token) + + split_docs = [] + for _, row in df_filtered.iterrows(): + content = row["metadata"]["content"] if row["metadata"]["content"] is not None else "" + + chunks = _split_into_chunks(content, tokenizer_model, chunk_size, chunk_overlap) + split_docs.extend(_build_split_documents(row, chunks)) + + split_docs_df = pd.DataFrame(split_docs) + + # Return both processed text and other document types + split_docs_df = pd.concat([split_docs_df, df[~bool_index]], axis=0).reset_index(drop=True) + # Update control message with new payload + split_docs_gdf = cudf.from_pandas(split_docs_df) + + message_meta = MessageMeta(df=split_docs_gdf) + message.payload(message_meta) + + return message + except Exception as e: + traceback.print_exc() + raise ValueError(f"Failed to split documents: {e}") + + split_node = builder.make_node("split_and_forward", ops.map(split_and_forward)) + + # Register the input and output of the module + builder.register_module_input("input", split_node) + builder.register_module_output("output", split_node) diff --git a/src/nv_ingest/schemas/__init__.py b/src/nv_ingest/schemas/__init__.py index 43a94055..f3ff4659 100644 --- a/src/nv_ingest/schemas/__init__.py +++ b/src/nv_ingest/schemas/__init__.py @@ -13,13 +13,13 @@ from .message_broker_source_schema import MessageBrokerTaskSourceSchema from .metadata_injector_schema import MetadataInjectorSchema from .metadata_schema import validate_metadata -from .nemo_doc_splitter_schema import DocumentSplitterSchema +from .text_splitter_schema import TextSplitterSchema from .pdf_extractor_schema import PDFExtractorSchema from .task_injection_schema import TaskInjectionSchema from .vdb_task_sink_schema import VdbTaskSinkSchema __all__ = [ - "DocumentSplitterSchema", + "TextSplitterSchema", "ImageCaptionExtractionSchema", "ImageStorageModuleSchema", 
"IngestJobSchema", diff --git a/src/nv_ingest/schemas/ingest_job_schema.py b/src/nv_ingest/schemas/ingest_job_schema.py index 0bddf825..0341394f 100644 --- a/src/nv_ingest/schemas/ingest_job_schema.py +++ b/src/nv_ingest/schemas/ingest_job_schema.py @@ -8,7 +8,6 @@ from typing import Any from typing import Dict from typing import List -from typing import Literal from typing import Optional from typing import Union @@ -60,16 +59,15 @@ class TracingOptionsSchema(BaseModelNoExt): class IngestTaskSplitSchema(BaseModelNoExt): - split_by: Literal["word", "sentence", "passage"] - split_length: Annotated[int, Field(gt=0)] - split_overlap: Annotated[int, Field(ge=0)] - max_character_length: Optional[Annotated[int, Field(gt=0)]] = None - sentence_window_size: Optional[Annotated[int, Field(ge=0)]] = None - - @field_validator("sentence_window_size") - def check_sentence_window_size(cls, v, values, **kwargs): - if v is not None and v > 0 and values.data["split_by"] != "sentence": - raise ValueError("When using sentence_window_size, split_by must be 'sentence'.") + tokenizer: str + chunk_size: Annotated[int, Field(gt=0)] + chunk_overlap: Annotated[int, Field(ge=0)] + params: dict + + @field_validator("chunk_overlap") + def check_chunk_overlap(cls, v, values, **kwargs): + if v is not None and "chunk_size" in values.data and v >= values.data["chunk_size"]: + raise ValueError("chunk_overlap must be less than chunk_size") return v diff --git a/src/nv_ingest/schemas/ingest_pipeline_config_schema.py b/src/nv_ingest/schemas/ingest_pipeline_config_schema.py index fe5debd6..1ce2e469 100644 --- a/src/nv_ingest/schemas/ingest_pipeline_config_schema.py +++ b/src/nv_ingest/schemas/ingest_pipeline_config_schema.py @@ -18,7 +18,7 @@ from nv_ingest.schemas.message_broker_sink_schema import MessageBrokerTaskSinkSchema from nv_ingest.schemas.message_broker_source_schema import MessageBrokerTaskSourceSchema from nv_ingest.schemas.metadata_injector_schema import MetadataInjectorSchema -from nv_ingest.schemas.nemo_doc_splitter_schema import DocumentSplitterSchema +from nv_ingest.schemas.text_splitter_schema import TextSplitterSchema from nv_ingest.schemas.otel_meter_schema import OpenTelemetryMeterSchema from nv_ingest.schemas.otel_tracer_schema import OpenTelemetryTracerSchema from nv_ingest.schemas.pdf_extractor_schema import PDFExtractorSchema @@ -30,7 +30,7 @@ class PipelineConfigSchema(BaseModel): chart_extractor_module: ChartExtractorSchema = ChartExtractorSchema() - document_splitter_module: DocumentSplitterSchema = DocumentSplitterSchema() + text_splitter_module: TextSplitterSchema = TextSplitterSchema() embedding_storage_module: EmbeddingStorageModuleSchema = EmbeddingStorageModuleSchema() embed_extractions_module: EmbedExtractionsSchema = EmbedExtractionsSchema() image_caption_extraction_module: ImageCaptionExtractionSchema = ImageCaptionExtractionSchema() diff --git a/src/nv_ingest/schemas/nemo_doc_splitter_schema.py b/src/nv_ingest/schemas/nemo_doc_splitter_schema.py deleted file mode 100644 index 58b6a0b4..00000000 --- a/src/nv_ingest/schemas/nemo_doc_splitter_schema.py +++ /dev/null @@ -1,26 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. -# All rights reserved. 
-# SPDX-License-Identifier: Apache-2.0 - - -from typing import Literal -from typing import Optional - -from pydantic import Field, BaseModel, field_validator - -from typing_extensions import Annotated - - -class DocumentSplitterSchema(BaseModel): - split_by: Literal["word", "sentence", "passage"] = "word" - split_length: Annotated[int, Field(gt=0)] = 60 - split_overlap: Annotated[int, Field(ge=0)] = 10 - max_character_length: Optional[Annotated[int, Field(gt=0)]] = 450 - sentence_window_size: Optional[Annotated[int, Field(ge=0)]] = 0 - raise_on_failure: bool = False - - @field_validator("sentence_window_size") - def check_sentence_window_size(cls, v, values, **kwargs): - if v is not None and v > 0 and values.data["split_by"] != "sentence": - raise ValueError("When using sentence_window_size, split_by must be 'sentence'.") - return v diff --git a/src/nv_ingest/schemas/text_splitter_schema.py b/src/nv_ingest/schemas/text_splitter_schema.py new file mode 100644 index 00000000..7eb6d93c --- /dev/null +++ b/src/nv_ingest/schemas/text_splitter_schema.py @@ -0,0 +1,20 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from pydantic import Field, BaseModel, field_validator + +from typing_extensions import Annotated + + +class TextSplitterSchema(BaseModel): + tokenizer: str = "meta-llama/Llama-3.2-1B" + chunk_size: Annotated[int, Field(gt=0)] = 1024 + chunk_overlap: Annotated[int, Field(ge=0)] = 150 + raise_on_failure: bool = False + + @field_validator("chunk_overlap") + def check_chunk_overlap(cls, v, values, **kwargs): + if v is not None and "chunk_size" in values.data and v >= values.data["chunk_size"]: + raise ValueError("chunk_overlap must be less than chunk_size") + return v diff --git a/src/nv_ingest/util/pipeline/__init__.py b/src/nv_ingest/util/pipeline/__init__.py index 5754c667..6957b2ea 100644 --- a/src/nv_ingest/util/pipeline/__init__.py +++ b/src/nv_ingest/util/pipeline/__init__.py @@ -17,7 +17,7 @@ add_table_extractor_stage, add_chart_extractor_stage, add_image_caption_stage, - add_nemo_splitter_stage, + add_text_splitter_stage, add_embed_extractions_stage, add_embedding_storage_stage, add_image_storage_stage, @@ -39,7 +39,7 @@ "add_table_extractor_stage", "add_chart_extractor_stage", "add_image_caption_stage", - "add_nemo_splitter_stage", + "add_text_splitter_stage", "add_embed_extractions_stage", "add_embedding_storage_stage", "add_image_storage_stage", diff --git a/src/nv_ingest/util/pipeline/pipeline_builders.py b/src/nv_ingest/util/pipeline/pipeline_builders.py index efeca97f..36624bc1 100644 --- a/src/nv_ingest/util/pipeline/pipeline_builders.py +++ b/src/nv_ingest/util/pipeline/pipeline_builders.py @@ -47,7 +47,7 @@ def setup_ingestion_pipeline( ######################################################################################################## ## Transforms and data synthesis ######################################################################################################## - nemo_splitter_stage = add_nemo_splitter_stage(pipe, morpheus_pipeline_config, ingest_config) + text_splitter_stage = add_text_splitter_stage(pipe, morpheus_pipeline_config, ingest_config) embed_extractions_stage = add_embed_extractions_stage(pipe, morpheus_pipeline_config, ingest_config) ######################################################################################################## ## Storage and output @@ -80,8 +80,8 @@ def setup_ingestion_pipeline( pipe.add_edge(image_dedup_stage, 
image_filter_stage) pipe.add_edge(image_filter_stage, table_extraction_stage) pipe.add_edge(table_extraction_stage, chart_extraction_stage) - pipe.add_edge(chart_extraction_stage, nemo_splitter_stage) - pipe.add_edge(nemo_splitter_stage, image_caption_stage) + pipe.add_edge(chart_extraction_stage, text_splitter_stage) + pipe.add_edge(text_splitter_stage, image_caption_stage) pipe.add_edge(image_caption_stage, embed_extractions_stage) pipe.add_edge(embed_extractions_stage, image_storage_stage) pipe.add_edge(image_storage_stage, embedding_storage_stage) diff --git a/src/nv_ingest/util/pipeline/stage_builders.py b/src/nv_ingest/util/pipeline/stage_builders.py index 2b3fb2ba..f6cb8745 100644 --- a/src/nv_ingest/util/pipeline/stage_builders.py +++ b/src/nv_ingest/util/pipeline/stage_builders.py @@ -20,7 +20,7 @@ from nv_ingest.modules.telemetry.job_counter import JobCounterLoaderFactory from nv_ingest.modules.telemetry.otel_meter import OpenTelemetryMeterLoaderFactory from nv_ingest.modules.telemetry.otel_tracer import OpenTelemetryTracerLoaderFactory -from nv_ingest.modules.transforms.nemo_doc_splitter import NemoDocSplitterLoaderFactory +from nv_ingest.modules.transforms.text_splitter import TextSplitterLoaderFactory from nv_ingest.stages.docx_extractor_stage import generate_docx_extractor_stage from nv_ingest.stages.extractors.image_extractor_stage import generate_image_extractor_stage from nv_ingest.stages.filters import generate_dedup_stage @@ -359,15 +359,15 @@ def add_image_filter_stage(pipe, morpheus_pipeline_config, ingest_config, defaul return image_filter_stage -def add_nemo_splitter_stage(pipe, morpheus_pipeline_config, ingest_config): - nemo_splitter_loader = NemoDocSplitterLoaderFactory.get_instance( - module_name="nemo_doc_splitter", +def add_text_splitter_stage(pipe, morpheus_pipeline_config, ingest_config): + text_splitter_loader = TextSplitterLoaderFactory.get_instance( + module_name="text_splitter", module_config=ingest_config.get("text_splitting_module", {}), ) - nemo_splitter_stage = pipe.add_stage( + text_splitter_stage = pipe.add_stage( LinearModulesStage( morpheus_pipeline_config, - nemo_splitter_loader, + text_splitter_loader, input_type=ControlMessage, output_type=ControlMessage, input_port_name="input", @@ -375,7 +375,7 @@ def add_nemo_splitter_stage(pipe, morpheus_pipeline_config, ingest_config): ) ) - return nemo_splitter_stage + return text_splitter_stage def add_image_caption_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count): diff --git a/tests/nv_ingest/modules/sources/test_message_broker_task_source.py b/tests/nv_ingest/modules/sources/test_message_broker_task_source.py index cdb389e2..3cb9297e 100644 --- a/tests/nv_ingest/modules/sources/test_message_broker_task_source.py +++ b/tests/nv_ingest/modules/sources/test_message_broker_task_source.py @@ -39,9 +39,10 @@ def job_payload(): { "type": "split", "task_properties": { - "split_by": "word", - "split_length": 100, - "split_overlap": 0, + "tokenizer": "intfloat/e5-large-unsupervised", + "chunk_size": 100, + "chunk_overlap": 0, + "params": {}, }, }, { diff --git a/tests/nv_ingest/schemas/test_ingest_job_schema.py b/tests/nv_ingest/schemas/test_ingest_job_schema.py index d0338666..f9d559ca 100644 --- a/tests/nv_ingest/schemas/test_ingest_job_schema.py +++ b/tests/nv_ingest/schemas/test_ingest_job_schema.py @@ -26,11 +26,10 @@ def valid_task_properties(task_type): """Returns valid task properties based on the task type.""" if task_type == TaskTypeEnum.split: return { - "split_by": "sentence", 
- "split_length": 10, - "split_overlap": 0, - "max_character_length": 100, - "sentence_window_size": None, # This is valid when not required + "tokenizer": "intfloat/e5-large-unsupervised", + "chunk_size": 300, + "chunk_overlap": 0, + "params": {}, } elif task_type == TaskTypeEnum.extract: return {"document_type": "pdf", "method": "OCR", "params": {"language": "en"}} @@ -117,14 +116,14 @@ def test_field_type_correctness(): def test_custom_validator_logic_for_sentence_window_size(): - """Tests custom validator logic related to sentence_window_size in split tasks.""" + """Tests custom validator logic related to chunk_size and chunk_overlap in split tasks.""" task = { "type": "split", "task_properties": { - "split_by": "word", # Incorrect usage of sentence_window_size - "split_length": 10, - "split_overlap": 5, - "sentence_window_size": 5, # Should not be set when split_by is not 'sentence' + "tokanizer": "intfloat/e5-large-unsupervised", + "chunk_size": 200, + "chunk_overlap": 250, # chunk_overlap should always be less than chunk_size + "params": {}, }, } job_data = { @@ -134,7 +133,7 @@ def test_custom_validator_logic_for_sentence_window_size(): } with pytest.raises(ValidationError) as exc_info: validate_ingest_job(job_data) - assert "sentence_window_size" in str(exc_info.value) and "must be 'sentence'" in str(exc_info.value) + assert "chunk_overlap must be less than chunk_size" in str(exc_info.value) def test_multiple_task_types(): @@ -150,9 +149,10 @@ def test_multiple_task_types(): { "type": "split", "task_properties": { - "split_by": "word", - "split_length": 100, - "split_overlap": 0, + "tokenizer": "intfloat/e5-large-unsupervised", + "chunk_size": 100, + "chunk_overlap": 0, + "params": {}, }, }, { @@ -244,9 +244,10 @@ def test_incorrect_property_types(): { "type": "split", "task_properties": { - "split_by": "word", - "split_length": {"not an int": 123}, # Incorrect type (should be int) - "split_overlap": 0, + "tokenizer": "intfloat/e5-large-unsupervised", + "chunk_size": {"not an int": 123}, # Incorrect type (should be int) + "chunk_overlap": 0, + "params": {}, }, } ], @@ -263,8 +264,9 @@ def test_missing_required_fields(): { "type": "split", "task_properties": { - "split_by": "sentence", # Missing split_length - "split_overlap": 0, + "tokenizer": "intfloat/e5-large-unsupervised", # Missing chunk_size + "chunk_overlap": 0, + "params": {}, }, } ], diff --git a/tests/nv_ingest/schemas/test_nemo_doc_splitter_schema.py b/tests/nv_ingest/schemas/test_nemo_doc_splitter_schema.py deleted file mode 100644 index 00f4e606..00000000 --- a/tests/nv_ingest/schemas/test_nemo_doc_splitter_schema.py +++ /dev/null @@ -1,69 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. -# All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -import pytest -from pydantic import ValidationError - -from nv_ingest.schemas import DocumentSplitterSchema - - -def test_document_splitter_schema_defaults(): - """ - Test the DocumentSplitterSchema with default values. - """ - schema = DocumentSplitterSchema() - assert schema.split_by == "word" - assert schema.split_length == 60 - assert schema.split_overlap == 10 - assert schema.max_character_length == 450 - assert schema.sentence_window_size == 0 - assert schema.raise_on_failure is False - - -@pytest.mark.parametrize("invalid_value", [-1, 0]) -def test_document_splitter_schema_invalid_split_length(invalid_value): - """ - Test DocumentSplitterSchema with invalid split_length values. 
- """ - with pytest.raises(ValidationError): - DocumentSplitterSchema(split_length=invalid_value) - - -@pytest.mark.parametrize( - "split_by, sentence_window_size, is_valid", - [ - ("sentence", 5, True), # Valid use of sentence_window_size - ( - "word", - 0, - True, - ), # Valid when split_by is not 'sentence' but sentence_window_size is 0 - ( - "word", - 5, - False, - ), # Invalid because sentence_window_size > 0 requires split_by to be 'sentence' - ], -) -def test_document_splitter_schema_sentence_window_size_validation(split_by, sentence_window_size, is_valid): - """ - Parametrized test for validating the sentence_window_size logic in DocumentSplitterSchema. - """ - if is_valid: - schema = DocumentSplitterSchema(split_by=split_by, sentence_window_size=sentence_window_size) - assert schema.sentence_window_size == sentence_window_size - assert schema.split_by == split_by - else: - with pytest.raises(ValidationError) as excinfo: - DocumentSplitterSchema(split_by=split_by, sentence_window_size=sentence_window_size) - assert "split_by must be 'sentence'" in str(excinfo.value) - - -def test_document_splitter_schema_optional_fields_none(): - """ - Test DocumentSplitterSchema with optional fields set to None. - """ - schema = DocumentSplitterSchema(max_character_length=None, sentence_window_size=None) - assert schema.max_character_length is None - assert schema.sentence_window_size is None diff --git a/tests/nv_ingest/schemas/test_text_splitter_schema.py b/tests/nv_ingest/schemas/test_text_splitter_schema.py new file mode 100644 index 00000000..8cfc9d97 --- /dev/null +++ b/tests/nv_ingest/schemas/test_text_splitter_schema.py @@ -0,0 +1,85 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import pytest +from pydantic import ValidationError + +from nv_ingest.schemas import TextSplitterSchema + + +def test_text_splitter_schema_defaults(): + """ + Test the TextSplitterSchema with default values. + """ + schema = TextSplitterSchema() + assert schema.tokenizer == "meta-llama/Llama-3.2-1B" + assert schema.chunk_size == 1024 + assert schema.chunk_overlap == 150 + assert schema.raise_on_failure is False + + +def test_text_splitter_schema_custom_values(): + """ + Test the TextSplitterSchema with custom values. + """ + tokenizer = "meta-llama/Llama-3.2-1B" + chunk_size = 500 + chunk_overlap = 10 + schema = TextSplitterSchema( + tokenizer=tokenizer, chunk_size=chunk_size, chunk_overlap=chunk_overlap, raise_on_failure=True + ) + assert schema.tokenizer == tokenizer + assert schema.chunk_size == chunk_size + assert schema.chunk_overlap == chunk_overlap + assert schema.raise_on_failure is True + + +@pytest.mark.parametrize("invalid_value", [50, 5.5]) +def test_text_splitter_schema_invalid_tokenizer(invalid_value): + """ + Test TextSplitterSchema with invalid tokenizer values. + """ + with pytest.raises(ValidationError): + TextSplitterSchema(tokenizer=invalid_value) + + +@pytest.mark.parametrize("invalid_value", [-1, 0]) +def test_text_splitter_schema_invalid_chunk_size(invalid_value): + """ + Test TextSplitterSchema with invalid chunk_size values. + """ + with pytest.raises(ValidationError): + TextSplitterSchema(chunk_size=invalid_value) + + +@pytest.mark.parametrize("invalid_value", [-1, "a"]) +def test_text_splitter_schema_invalid_chunk_overlap(invalid_value): + """ + Test TextSplitterSchema with invalid chunk_overlap values. 
+ """ + with pytest.raises(ValidationError): + TextSplitterSchema(chunk_overlap=invalid_value) + + +@pytest.mark.parametrize( + "chunk_size, chunk_overlap, is_valid", + [ + (300, 50, True), + (150, 0, True), + (100, 100, False), + (50, 200, False), + ], +) +def test_text_splitter_schema_chunk_overlap_validation(chunk_size, chunk_overlap, is_valid): + """ + Parametrized test for validating the chunk_overlap logic in TextSplitterSchema. + """ + if is_valid: + schema = TextSplitterSchema(chunk_size=chunk_size, chunk_overlap=chunk_overlap) + assert schema.chunk_size == chunk_size + assert schema.chunk_overlap == chunk_overlap + else: + with pytest.raises(ValidationError) as excinfo: + TextSplitterSchema(chunk_size=chunk_size, chunk_overlap=chunk_overlap) + assert "chunk_overlap must be less than chunk_size" in str(excinfo.value) diff --git a/tests/nv_ingest_client/cli/util/test_click.py b/tests/nv_ingest_client/cli/util/test_click.py index f4f15ebd..a8ed1b0a 100644 --- a/tests/nv_ingest_client/cli/util/test_click.py +++ b/tests/nv_ingest_client/cli/util/test_click.py @@ -93,7 +93,7 @@ def test_debug_print_click_options(mock_pprint): def test_validate_task_with_valid_split(): """Test with valid split task options.""" - value = ['split:{"split_by": "page", "split_length": 10}'] + value = ['split:{"tokenizer": "intfloat/e5-large-unsupervised", "chunk_size": 300}'] result = click_validate_task(None, None, value) assert "split" in result diff --git a/tests/nv_ingest_client/client/test_client.py b/tests/nv_ingest_client/client/test_client.py index 05c90425..5a45936c 100644 --- a/tests/nv_ingest_client/client/test_client.py +++ b/tests/nv_ingest_client/client/test_client.py @@ -276,7 +276,7 @@ def test_correct_storage_of_job_details(nv_ingest_client): def test_successful_task_creation(nv_ingest_client_with_jobs): job_id = "12345678-1234-5678-1234-567812345678" task_type = TaskType.SPLIT - task_params = {"split_by": "sentence"} + task_params = {"tokenizer": "intfloat/e5-large-unsupervised"} # Assuming task_factory and task creation are implemented nv_ingest_client_with_jobs.create_task(job_id, task_type, task_params) @@ -288,7 +288,9 @@ def test_successful_task_creation(nv_ingest_client_with_jobs): def test_non_existent_job(nv_ingest_client): with pytest.raises(ValueError): - nv_ingest_client.create_task("nonexistent_job_id", TaskType.SPLIT, {"split_by": "sentence"}) + nv_ingest_client.create_task( + "nonexistent_job_id", TaskType.SPLIT, {"tokenizer": "intfloat/e5-large-unsupervised"} + ) def test_add_task_post_submission(nv_ingest_client_with_jobs): @@ -297,13 +299,13 @@ def test_add_task_post_submission(nv_ingest_client_with_jobs): nv_ingest_client_with_jobs._job_states[job_id].state = JobStateEnum.PROCESSING with pytest.raises(ValueError): - nv_ingest_client_with_jobs.create_task(job_id, TaskType.SPLIT, {"split_by": "sentence"}) + nv_ingest_client_with_jobs.create_task(job_id, TaskType.SPLIT, {"tokenizer": "intfloat/e5-large-unsupervised"}) def test_parameter_validation(nv_ingest_client_with_jobs): job_id = "12345678-1234-5678-1234-567812345678" task_type = TaskType.SPLIT - task_params = {"split_by": "sentence", "split_length": 128} + task_params = {"tokenizer": "intfloat/e5-large-unsupervised", "chunk_size": 128} nv_ingest_client_with_jobs.create_task(job_id, task_type, task_params) job_state = nv_ingest_client_with_jobs._job_states[job_id] @@ -580,8 +582,8 @@ def test_create_jobs_for_batch_duplicate_task(nv_ingest_client, mock_create_job_ files = ["file1.pdf"] duplicate_tasks = { - "split": 
SplitTask(split_by="sentence"), - "store": SplitTask(split_by="sentence"), # Duplicate task + "split": SplitTask(tokenizer="intfloat/e5-large-unsupervised"), + "store": SplitTask(tokenizer="intfloat/e5-large-unsupervised"), # Duplicate task } with pytest.raises(ValueError, match="Duplicate task detected"): diff --git a/tests/nv_ingest_client/client/test_interface.py b/tests/nv_ingest_client/client/test_interface.py index 83a556b8..8d8e810b 100644 --- a/tests/nv_ingest_client/client/test_interface.py +++ b/tests/nv_ingest_client/client/test_interface.py @@ -155,12 +155,12 @@ def test_split_task_no_args(ingestor): def test_split_task_some_args(ingestor): - ingestor.split(split_by="word", split_length=42) + ingestor.split(tokenizer="intfloat/e5-large-unsupervised", chunk_size=42) task = ingestor._job_specs.job_specs["pdf"][0]._tasks[0] assert isinstance(task, SplitTask) - assert task._split_by == "word" - assert task._split_length == 42 + assert task._tokenizer == "intfloat/e5-large-unsupervised" + assert task._chunk_size == 42 def test_store_task_no_args(ingestor): diff --git a/tests/nv_ingest_client/primitives/tasks/test_split.py b/tests/nv_ingest_client/primitives/tasks/test_split.py index 3fb0dbeb..533ef0c5 100644 --- a/tests/nv_ingest_client/primitives/tasks/test_split.py +++ b/tests/nv_ingest_client/primitives/tasks/test_split.py @@ -10,31 +10,24 @@ def test_split_task_initialization(): task = SplitTask( - split_by="word", - split_length=100, - split_overlap=10, - max_character_length=1000, - sentence_window_size=5, + tokenizer="meta-llama/Llama-3.2-1B", + chunk_size=1024, + chunk_overlap=0, + params={}, ) - assert task._split_by == "word" - assert task._split_length == 100 - assert task._split_overlap == 10 - assert task._max_character_length == 1000 - assert task._sentence_window_size == 5 + assert task._tokenizer == "meta-llama/Llama-3.2-1B" + assert task._chunk_size == 1024 + assert task._chunk_overlap == 0 + assert task._params == {} # String Representation Tests def test_split_task_str_representation(): - task = SplitTask(split_by="sentence", split_length=50, split_overlap=5) + task = SplitTask(tokenizer="intfloat/e5-large-unsupervised", chunk_size=50, chunk_overlap=5) expected_str = ( - "Split Task:\n" - " split_by: sentence\n" - " split_length: 50\n" - " split_overlap: 5\n" - " split_max_character_length: None\n" - " split_sentence_window_size: None\n" + "Split Task:\n" " tokenizer: intfloat/e5-large-unsupervised\n" " chunk_size: 50\n" " chunk_overlap: 5\n" ) assert str(task) == expected_str @@ -43,42 +36,37 @@ def test_split_task_str_representation(): @pytest.mark.parametrize( - "split_by, split_length, split_overlap, max_character_length, sentence_window_size", + "tokenizer, chunk_size, chunk_overlap, params", [ - ("word", 100, 10, 1000, 5), - ("sentence", 50, 5, None, None), - ("passage", None, None, 1500, 3), - (None, None, None, None, None), # Test default parameters + ("intfloat/e5-large-unsupervised", 100, 10, {}), + ("microsoft/deberta-large", 50, 5, None), + ("meta-llama/Llama-3.2-1B", 1024, 0, {"hf_access_token": "TOKEN"}), ], ) def test_split_task_to_dict( - split_by, - split_length, - split_overlap, - max_character_length, - sentence_window_size, + tokenizer, + chunk_size, + chunk_overlap, + params, ): task = SplitTask( - split_by=split_by, - split_length=split_length, - split_overlap=split_overlap, - max_character_length=max_character_length, - sentence_window_size=sentence_window_size, + tokenizer=tokenizer, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + 
params=params, ) expected_dict = {"type": "split", "task_properties": {}} # Only add properties to expected_dict if they are not None - if split_by is not None: - expected_dict["task_properties"]["split_by"] = split_by - if split_length is not None: - expected_dict["task_properties"]["split_length"] = split_length - if split_overlap is not None: - expected_dict["task_properties"]["split_overlap"] = split_overlap - if max_character_length is not None: - expected_dict["task_properties"]["max_character_length"] = max_character_length - if sentence_window_size is not None: - expected_dict["task_properties"]["sentence_window_size"] = sentence_window_size + if tokenizer is not None: + expected_dict["task_properties"]["tokenizer"] = tokenizer + if chunk_size is not None: + expected_dict["task_properties"]["chunk_size"] = chunk_size + if chunk_overlap is not None: + expected_dict["task_properties"]["chunk_overlap"] = chunk_overlap + if params is not None: + expected_dict["task_properties"]["params"] = params assert task.to_dict() == expected_dict, "The to_dict method did not return the expected dictionary representation" @@ -89,14 +77,20 @@ def test_split_task_to_dict( def test_split_task_default_params(): task = SplitTask() expected_str_contains = [ - "split_by: None", - "split_length: None", - "split_overlap: None", - "split_max_character_length: None", - "split_sentence_window_size: None", + "tokenizer: meta-llama/Llama-3.2-1B", + "chunk_size: 1024", + "chunk_overlap: 150", ] for expected_part in expected_str_contains: assert expected_part in str(task) - expected_dict = {"type": "split", "task_properties": {}} + expected_dict = { + "type": "split", + "task_properties": { + "tokenizer": "meta-llama/Llama-3.2-1B", + "chunk_size": 1024, + "chunk_overlap": 150, + "params": {}, + }, + } assert task.to_dict() == expected_dict
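For illustration, a minimal sketch of the reworked client-side split API after this change. The `SplitTask` class, its defaults, and the `params` keys (`hf_access_token`, `split_source_types`) all appear in the diff above; the surrounding job-submission flow is omitted, and the exact dictionary ordering shown is only an expectation based on `to_dict()`.

```python
# Sketch: token-based splitting replaces the old split_by/split_length options.
from nv_ingest_client.primitives.tasks.split import SplitTask

task = SplitTask(
    tokenizer="meta-llama/Llama-3.2-1B",  # default tokenizer; gated on Hugging Face
    chunk_size=1024,                      # chunk size in tokens
    chunk_overlap=150,                    # token overlap; must be less than chunk_size
    params={"hf_access_token": "<your HF token>", "split_source_types": ["TEXT"]},
)

# to_dict() yields the task_properties validated server-side by IngestTaskSplitSchema.
print(task.to_dict())
# {'type': 'split', 'task_properties': {'tokenizer': 'meta-llama/Llama-3.2-1B',
#   'chunk_size': 1024, 'chunk_overlap': 150,
#   'params': {'hf_access_token': '<your HF token>', 'split_source_types': ['TEXT']}}}
```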