Switch split task to token based splitting #283
Open
ChrisJar wants to merge 29 commits into NVIDIA:main from ChrisJar:token-split
Commits (29)
918740b Switch split task to token based splitting
f45d7c9 Merge branch 'main' into token-split
c12df54 Move tokenizer out of loop
c88a4a2 Fix CLI
0cacc9e Add chunk_overlap parameter
2e495b2 Merge in main
d8a2de9 Fix broken tests
a21d065 Add chunk overlap
bc2bf48 Rename nemo document splitter to text splitter
cd18083 Temp fix
081dc4e Merge remote-tracking branch 'upstream/main' into token-split
bab7f3d Address reviews
125bb38 Merge remote-tracking branch 'upstream/main' into token-split
e09723f Merge remote-tracking branch 'upstream/main' into token-split
289998c Merge remote-tracking branch 'upstream/main' into token-split
e86c539 Merge remote-tracking branch 'upstream/main' into token-split
2f4b979 Change default chunk_size to 1024
b052942 Change default chunk_overlap to 20
317a426 Pass huggingface access token as param
e250eff Add llama license notice
3c2f246 Add support for filtering by file type
cf6d125 Merge remote-tracking branch 'upstream/main' into token-split
6e32a9c Fix offset mapping
0825feb Merge remote-tracking branch 'upstream/main' into token-split
67b7051 Add built with llama
b6829ed Merge upstream/main into token-split
7634030 Change default chunk_overlap to 150
66eb642 Merge remote-tracking branch 'upstream/main' into token-split
d6401bc Add option to predownload llama tokenizer
Changes from all commits
@@ -0,0 +1,9 @@
import os
from transformers import AutoTokenizer

if os.getenv("DOWNLOAD_LLAMA_TOKENIZER") == "True":
    tokenizer_path = "/workspace/models/tokenizer/"
    os.makedirs(tokenizer_path, exist_ok=True)

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B", token=os.getenv("HF_ACCESS_TOKEN"))
    tokenizer.save_pretrained(tokenizer_path)
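
For reference, a minimal sketch (not part of the diff) of how a consumer could prefer the predownloaded tokenizer and fall back to the Hugging Face Hub otherwise. The local path and environment variable names mirror the script above; everything else is an assumption, not nv-ingest API.

# Hedged sketch: load the tokenizer from the predownloaded path when present,
# otherwise pull it from the Hub (HF_ACCESS_TOKEN is needed for gated models).
import os
from transformers import AutoTokenizer

local_path = "/workspace/models/tokenizer/"
if os.path.isdir(local_path) and os.listdir(local_path):
    tokenizer = AutoTokenizer.from_pretrained(local_path)
else:
    tokenizer = AutoTokenizer.from_pretrained(
        "meta-llama/Llama-3.2-1B", token=os.getenv("HF_ACCESS_TOKEN")
    )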
@@ -0,0 +1,165 @@
# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0


import copy
import logging
import traceback
import uuid
from typing import Any
from typing import List

import mrc
import pandas as pd
from transformers import AutoTokenizer
from morpheus.messages import ControlMessage
from morpheus.messages import MessageMeta
from morpheus.utils.control_message_utils import cm_skip_processing_if_failed
from morpheus.utils.module_utils import ModuleLoaderFactory
from morpheus.utils.module_utils import register_module
from mrc.core import operators as ops
from pydantic import BaseModel

import cudf

from nv_ingest.schemas.metadata_schema import ContentTypeEnum
from nv_ingest.schemas.text_splitter_schema import TextSplitterSchema
from nv_ingest.util.exception_handlers.decorators import nv_ingest_node_failure_context_manager
from nv_ingest.util.flow_control import filter_by_task
from nv_ingest.util.modules.config_validator import fetch_and_validate_module_config
from nv_ingest.util.tracing import traceable

logger = logging.getLogger(__name__)


def _build_split_documents(row, chunks: List[str]) -> List[dict[str, Any]]:
    """Build documents from text chunks"""
    documents: List[dict] = []

    for i, text in enumerate(chunks):
        if text is None or not text.strip():
            continue

        metadata = row.metadata if hasattr(row, "metadata") and isinstance(row.metadata, dict) else {}
        metadata = copy.deepcopy(metadata)

        metadata["content"] = text

        documents.append({"document_type": ContentTypeEnum.TEXT.value, "metadata": metadata, "uuid": str(uuid.uuid4())})

    return documents


def _split_into_chunks(text, tokenizer, chunk_size=1024, chunk_overlap=20):
    # Tokenize the text into token IDs
    encoding = tokenizer.encode_plus(text, add_special_tokens=False, return_offsets_mapping=True)

    # Get the token IDs and offsets for splitting
    offsets = encoding["offset_mapping"]

    # Split the tokens into chunks of the desired size with the desired overlap
    chunks = [offsets[i : i + chunk_size] for i in range(0, len(offsets), chunk_size - chunk_overlap)]

    # Convert token chunks back to text while preserving original spacing and case
    text_chunks = []
    for chunk in chunks:
        text_chunk = text[chunk[0][0] : chunk[-1][0]]
        text_chunks.append(text_chunk)

    return text_chunks


MODULE_NAME = "text_splitter"
MODULE_NAMESPACE = "nv_ingest"

TextSplitterLoaderFactory = ModuleLoaderFactory(MODULE_NAME, MODULE_NAMESPACE, TextSplitterSchema)


@register_module(MODULE_NAME, MODULE_NAMESPACE)
def _text_splitter(builder: mrc.Builder):
    """
    A pipeline module that splits documents into smaller parts based on the specified criteria.
    """

    validated_config = fetch_and_validate_module_config(builder, TextSplitterSchema)

    @filter_by_task(["split"])
    @traceable(MODULE_NAME)
    @cm_skip_processing_if_failed
    @nv_ingest_node_failure_context_manager(
        annotation_id=MODULE_NAME,
        raise_on_failure=validated_config.raise_on_failure,
    )
    def split_and_forward(message: ControlMessage):
        try:
            # Assume that df is going to have a 'content' column
            task_props = message.remove_task("split")

            if isinstance(task_props, BaseModel):
                task_props = task_props.model_dump()

            # Validate that all 'content' values are not None
            with message.payload().mutable_dataframe() as mdf:
                df = mdf.to_pandas()

            # Filter to document type
            bool_index = df["document_type"] == ContentTypeEnum.TEXT
            df_filtered = df.loc[bool_index]

            if df_filtered.empty:
                return message

            # Override parameters if set
            tokenizer = task_props.get("tokenizer", validated_config.tokenizer)
            chunk_size = task_props.get("chunk_size", validated_config.chunk_size)
            chunk_overlap = task_props.get("chunk_overlap", validated_config.chunk_overlap)
            params = task_props.get("params", {})

            hf_access_token = params.get("hf_access_token", None)
            split_source_types = params.get("split_source_types", ["TEXT"])

            logger.debug(
                f"Splitting text with tokenizer: {tokenizer}, "
                f"chunk_size: {chunk_size} tokens, "
                f"chunk_overlap: {chunk_overlap}"
            )

            # Filter to file type
            bool_index = pd.json_normalize(df_filtered["metadata"])["source_metadata.source_type"].isin(
                split_source_types
            )
            df_filtered = df_filtered.loc[bool_index]

            if df_filtered.empty:
                return message

            tokenizer_model = AutoTokenizer.from_pretrained(tokenizer, token=hf_access_token)

            split_docs = []
            for _, row in df_filtered.iterrows():
                content = row["metadata"]["content"] if row["metadata"]["content"] is not None else ""

                chunks = _split_into_chunks(content, tokenizer_model, chunk_size, chunk_overlap)
                split_docs.extend(_build_split_documents(row, chunks))

            split_docs_df = pd.DataFrame(split_docs)

            # Return both processed text and other document types
            split_docs_df = pd.concat([split_docs_df, df[~bool_index]], axis=0).reset_index(drop=True)
            # Update control message with new payload
            split_docs_gdf = cudf.from_pandas(split_docs_df)

            message_meta = MessageMeta(df=split_docs_gdf)
            message.payload(message_meta)

            return message
        except Exception as e:
            traceback.print_exc()
            raise ValueError(f"Failed to split documents: {e}")

    split_node = builder.make_node("split_and_forward", ops.map(split_and_forward))

    # Register the input and output of the module
    builder.register_module_input("input", split_node)
    builder.register_module_output("output", split_node)
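
As a standalone illustration of the token-based splitting used by _split_into_chunks above: tokenize with offset mapping, window the offsets with overlap, then slice the original string so spacing and case are preserved. This is a hedged sketch, not the module's API; "gpt2" is only a stand-in for a fast tokenizer. One difference: it slices each chunk to the end offset of its last token, whereas the module slices to the last token's start offset, which can drop the final token of the last chunk.

# Hedged sketch of token-window splitting with character offsets (assumes a
# Hugging Face fast tokenizer; "gpt2" is just a stand-in model name).
from transformers import AutoTokenizer


def split_by_tokens(text, tokenizer, chunk_size=1024, chunk_overlap=20):
    encoding = tokenizer(text, add_special_tokens=False, return_offsets_mapping=True)
    offsets = encoding["offset_mapping"]  # one (char_start, char_end) pair per token
    if not offsets:
        return []
    stride = chunk_size - chunk_overlap
    windows = [offsets[i : i + chunk_size] for i in range(0, len(offsets), stride)]
    # Slice to the end offset of each window's last token so no trailing text is lost.
    return [text[w[0][0] : w[-1][1]] for w in windows]


tok = AutoTokenizer.from_pretrained("gpt2")
print(split_by_tokens("Token based splitting keeps chunk boundaries aligned to tokens.", tok, chunk_size=6, chunk_overlap=2))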
Review comment:
Matching documents should be collected and processed all at once instead of iterating over rows, and we should actually make this a multiprocessing stage so we're able to use the worker pool for CPU-bound tasks. But we can hold off if this needs to go in sooner.
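
One way the batched version could look, for discussion only: gather all matching rows, make a single tokenizer call, then build the split documents. Column names follow the module above, _build_split_documents is the helper defined there, and the rest is an assumption rather than the final implementation (it is not yet the multiprocessing stage the comment asks for).

# Hypothetical sketch: process all matching rows with one batched tokenizer call
# instead of calling encode_plus per row; reuses _build_split_documents from the
# module above.
def split_all_at_once(df_filtered, tokenizer_model, chunk_size, chunk_overlap):
    texts = [(row["metadata"]["content"] or "") for _, row in df_filtered.iterrows()]
    # A single batched call returns one offset-mapping list per input text.
    encodings = tokenizer_model(texts, add_special_tokens=False, return_offsets_mapping=True)

    split_docs = []
    for (_, row), offsets, text in zip(df_filtered.iterrows(), encodings["offset_mapping"], texts):
        stride = chunk_size - chunk_overlap
        windows = [offsets[i : i + chunk_size] for i in range(0, len(offsets), stride)]
        chunks = [text[w[0][0] : w[-1][1]] for w in windows if w]
        split_docs.extend(_build_split_documents(row, chunks))
    return split_docs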