feat: PreProcessor split by token (tiktoken & Hugging Face) #5276

Merged (10 commits) on Nov 23, 2023
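For reference, a minimal usage sketch of the feature added here (not part of the diff): it assumes the Haystack v1.x API, with `PreProcessor` imported from `haystack.nodes` and `Document` from `haystack.schema`, and that `tiktoken`/`transformers` are installed; the sample text and parameter values are illustrative.

from haystack.nodes import PreProcessor
from haystack.schema import Document

doc = Document(content="Haystack can now split documents into chunks of a fixed number of tokens. " * 40)

# Default token splitter: tiktoken's cl100k_base encoding (the encoding used by OpenAI's GPT-3.5/GPT-4).
preprocessor = PreProcessor(
    split_by="token",
    split_length=128,
    split_overlap=16,
    split_respect_sentence_boundary=True,  # this PR makes sentence boundaries compatible with token splits
    tokenizer="tiktoken",
)
token_docs = preprocessor.process([doc])

# Alternatively, name any Hugging Face tokenizer (loaded with AutoTokenizer on first use),
# or pass a PreTrainedTokenizerBase instance directly.
hf_preprocessor = PreProcessor(
    split_by="token",
    split_length=128,
    split_overlap=0,
    split_respect_sentence_boundary=False,
    tokenizer="bert-base-uncased",
)
hf_docs = hf_preprocessor.process([doc])
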
14 changes: 10 additions & 4 deletions haystack/nodes/preprocessor/base.py
@@ -2,6 +2,8 @@

from abc import abstractmethod

from transformers import PreTrainedTokenizerBase

from haystack.nodes.base import BaseComponent
from haystack.schema import Document

@@ -17,10 +19,11 @@ def process(
clean_header_footer: Optional[bool] = False,
clean_empty_lines: Optional[bool] = True,
remove_substrings: Optional[List[str]] = None,
split_by: Literal["word", "sentence", "passage", None] = "word",
split_by: Literal["token", "word", "sentence", "passage", None] = "word",
split_length: Optional[int] = 1000,
split_overlap: Optional[int] = None,
split_respect_sentence_boundary: Optional[bool] = True,
tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = "tiktoken",
id_hash_keys: Optional[List[str]] = None,
) -> List[Document]:
"""
@@ -44,10 +47,11 @@ def clean(
def split(
self,
document: Union[dict, Document],
split_by: Literal["word", "sentence", "passage", None],
split_by: Literal["token", "word", "sentence", "passage", None],
split_length: int,
split_overlap: int,
split_respect_sentence_boundary: bool,
tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = None,
) -> List[Document]:
raise NotImplementedError

@@ -57,10 +61,11 @@ def run( # type: ignore
clean_whitespace: Optional[bool] = None,
clean_header_footer: Optional[bool] = None,
clean_empty_lines: Optional[bool] = None,
split_by: Literal["word", "sentence", "passage", None] = None,
split_by: Literal["token", "word", "sentence", "passage", None] = None,
split_length: Optional[int] = None,
split_overlap: Optional[int] = None,
split_respect_sentence_boundary: Optional[bool] = None,
tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = None,
id_hash_keys: Optional[List[str]] = None,
):
processed_documents = self.process(
@@ -83,10 +88,11 @@ def run_batch( # type: ignore
clean_whitespace: Optional[bool] = None,
clean_header_footer: Optional[bool] = None,
clean_empty_lines: Optional[bool] = None,
split_by: Literal["word", "sentence", "passage", None] = None,
split_by: Literal["token", "word", "sentence", "passage", None] = None,
split_length: Optional[int] = None,
split_overlap: Optional[int] = None,
split_respect_sentence_boundary: Optional[bool] = None,
tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = None,
id_hash_keys: Optional[List[str]] = None,
):
return self.run(
112 changes: 84 additions & 28 deletions haystack/nodes/preprocessor/preprocessor.py
@@ -1,4 +1,4 @@
from typing import List, Optional, Generator, Set, Union, Tuple, Dict, Literal
from typing import List, Optional, Generator, Set, Union, Tuple, Dict, Literal, Callable, Any

import logging
import re
@@ -17,14 +17,18 @@
from haystack.schema import Document
from haystack.lazy_imports import LazyImport

with LazyImport("Run 'pip install transformers'") as transformers_import:
from transformers import PreTrainedTokenizerBase
from transformers import AutoTokenizer

logger = logging.getLogger(__name__)
with LazyImport("Run 'pip install tiktoken'") as tiktoken_import:
import tiktoken

logger = logging.getLogger(__name__)

with LazyImport("Run 'pip install farm-haystack[preprocessing]' or 'pip install nltk'") as nltk_import:
import nltk


iso639_to_nltk = {
"ru": "russian",
"sl": "slovene",
@@ -55,11 +59,12 @@ def __init__(
clean_header_footer: bool = False,
clean_empty_lines: bool = True,
remove_substrings: Optional[List[str]] = None,
split_by: Optional[Literal["word", "sentence", "passage"]] = "word",
split_by: Optional[Literal["token", "word", "sentence", "passage"]] = "word",
split_length: int = 200,
split_overlap: int = 0,
split_respect_sentence_boundary: bool = True,
tokenizer_model_folder: Optional[Union[str, Path]] = None,
tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = "tiktoken",
language: str = "en",
id_hash_keys: Optional[List[str]] = None,
progress_bar: bool = True,
@@ -86,6 +91,9 @@ def __init__(
:param split_respect_sentence_boundary: Whether to split in partial sentences if split_by is `word` or `token`. If set
to True, the individual split will always have complete sentences and
the number of words/tokens will be <= split_length.
:param tokenizer: Specifies the tokenizer to use if split_by="token". Supported options are "tiktoken"
(for OpenAI's GPT-3.5 and GPT-4) and any HuggingFace tokenizer (e.g. 'bert-base-uncased').
HuggingFace tokenizers can also be passed directly as a PreTrainedTokenizerBase object.
:param language: The language used by "nltk.tokenize.sent_tokenize" in iso639 format.
Available options: "ru","sl","es","sv","tr","cs","da","nl","en","et","fi","fr","de","el","it","no","pl","pt","ml"
:param tokenizer_model_folder: Path to the folder containing the NLTK PunktSentenceTokenizer models, if loading a model from a local path. Leave empty otherwise.
@@ -124,6 +132,7 @@ def __init__(
self.split_length = split_length
self.split_overlap = split_overlap
self.split_respect_sentence_boundary = split_respect_sentence_boundary
self.tokenizer = tokenizer
self.language = language
self.tokenizer_model_folder = tokenizer_model_folder
self.print_log: Set[str] = set()
@@ -139,10 +148,11 @@ def process(
clean_header_footer: Optional[bool] = None,
clean_empty_lines: Optional[bool] = None,
remove_substrings: Optional[List[str]] = None,
split_by: Optional[Literal["word", "sentence", "passage"]] = None,
split_by: Optional[Literal["token", "word", "sentence", "passage"]] = None,
split_length: Optional[int] = None,
split_overlap: Optional[int] = None,
split_respect_sentence_boundary: Optional[bool] = None,
tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = None,
id_hash_keys: Optional[List[str]] = None,
) -> List[Document]:
"""
@@ -167,6 +177,7 @@ def process(
"split_length": split_length,
"split_overlap": split_overlap,
"split_respect_sentence_boundary": split_respect_sentence_boundary,
"tokenizer": tokenizer,
}

if id_hash_keys is None:
@@ -219,10 +230,11 @@ def _process_single(
clean_header_footer: Optional[bool] = None,
clean_empty_lines: Optional[bool] = None,
remove_substrings: Optional[List[str]] = None,
split_by: Optional[Literal["word", "sentence", "passage"]] = None,
split_by: Optional[Literal["token", "word", "sentence", "passage"]] = None,
split_length: Optional[int] = None,
split_overlap: Optional[int] = None,
split_respect_sentence_boundary: Optional[bool] = None,
tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = None,
id_hash_keys: Optional[List[str]] = None,
) -> List[Document]:
if remove_substrings is None:
@@ -243,6 +255,8 @@ def _process_single(
split_overlap = self.split_overlap
if split_respect_sentence_boundary is None:
split_respect_sentence_boundary = self.split_respect_sentence_boundary
if tokenizer is None:
tokenizer = self.tokenizer

cleaned_document = self.clean(
document=document,
@@ -258,6 +272,7 @@ def _process_single(
split_length=split_length,
split_overlap=split_overlap,
split_respect_sentence_boundary=split_respect_sentence_boundary,
tokenizer=tokenizer,
id_hash_keys=id_hash_keys,
)

@@ -332,10 +347,11 @@ def clean(
def split(
self,
document: Union[dict, Document],
split_by: Optional[Literal["word", "sentence", "passage"]],
split_by: Optional[Literal["token", "word", "sentence", "passage"]],
split_length: int,
split_overlap: int,
split_respect_sentence_boundary: bool,
tokenizer: Optional[Union[str, PreTrainedTokenizerBase]] = None,
id_hash_keys: Optional[List[str]] = None,
) -> List[Document]:
"""Perform document splitting on a single document. This method can split on different units, at different lengths,
@@ -359,8 +375,10 @@ def split(
if not split_length:
raise Exception("split_length needs be set when using split_by.")

if split_respect_sentence_boundary and split_by != "word":
raise NotImplementedError("'split_respect_sentence_boundary=True' is only compatible with split_by='word'.")
if split_respect_sentence_boundary and split_by not in ["word", "token"]:
raise NotImplementedError(
"'split_respect_sentence_boundary=True' is only compatible with split_by='word' or 'token'."
)

if type(document.content) is not str:
logger.error("Document content is not of type str. Nothing to split.")
@@ -369,13 +387,17 @@ def split(
text = document.content
headlines = document.meta["headlines"] if "headlines" in document.meta else []

if split_respect_sentence_boundary and split_by == "word":
text_splits, splits_pages, splits_start_idxs = self._split_by_word_respecting_sent_boundary(
text=text, split_length=split_length, split_overlap=split_overlap
if split_respect_sentence_boundary and split_by in ["word", "token"]:

def split_function(text):
return self._split_tokens(text, tokenizer=tokenizer) if split_by == "token" else text.split()

text_splits, splits_pages, splits_start_idxs = self._split_into_units_respecting_sent_boundary(
text=text, split_length=split_length, split_overlap=split_overlap, split_function=split_function
)
else:
# create individual "elements" of passage, sentence, or word
elements, split_at = self._split_into_units(text=text, split_by=split_by)
elements, split_at = self._split_into_units(text=text, split_by=split_by, tokenizer=tokenizer)

# concatenate individual elements based on split_length & split_stride
text_splits, splits_pages, splits_start_idxs = self._concatenate_units(
@@ -467,47 +489,47 @@ def _remove_substring(text: str, substring: str, headlines: List[Dict]) -> Tuple
cleaned_text = text.replace(substring, "")
return cleaned_text, headlines

def _split_by_word_respecting_sent_boundary(
self, text: str, split_length: int, split_overlap: int
def _split_into_units_respecting_sent_boundary(
self, text: str, split_length: int, split_overlap: int, split_function: Callable
) -> Tuple[List[str], List[int], List[int]]:
"""
Splits the text into parts of split_length units (words or tokens) while respecting sentence boundaries.
"""
sentences = self._split_sentences(text)

word_count_slice = 0
unit_count_slice = 0
cur_page = 1
cur_start_idx = 0
splits_pages = []
list_splits = []
splits_start_idxs = []
current_slice: List[str] = []
for sen in sentences:
word_count_sen = len(sen.split())
unit_count_sen = len(split_function(sen))

if word_count_sen > split_length:
if unit_count_sen > split_length:
long_sentence_message = (
"We found one or more sentences whose word count is higher than the split length."
"We found one or more sentences whose split count is higher than the split length."
)
if long_sentence_message not in self.print_log:
self.print_log.add(long_sentence_message)
logger.warning(long_sentence_message)

if word_count_slice + word_count_sen > split_length:
if unit_count_slice + unit_count_sen > split_length:
# Number of units exceeds split_length -> save current slice and start a new one
if current_slice:
list_splits.append(current_slice)
splits_pages.append(cur_page)
splits_start_idxs.append(cur_start_idx)

if split_overlap:
processed_sents, current_slice, word_count_slice = self._get_overlap_from_slice(
current_slice, split_length, split_overlap
processed_sents, current_slice, unit_count_slice = self._get_overlap_from_slice(
current_slice, split_length, split_overlap, split_function
)
else:
processed_sents = current_slice
current_slice = []
word_count_slice = 0
unit_count_slice = 0

cur_start_idx += len("".join(processed_sents))

@@ -522,7 +544,7 @@ def _split_by_word_respecting_sent_boundary(
cur_page += num_page_breaks

current_slice.append(sen)
word_count_slice += word_count_sen
unit_count_slice += unit_count_sen

if current_slice:
list_splits.append(current_slice)
@@ -539,7 +561,7 @@ def _split_by_word_respecting_sent_boundary(

@staticmethod
def _get_overlap_from_slice(
current_slice: List[str], split_length: int, split_overlap: int
current_slice: List[str], split_length: int, split_overlap: int, split_function: Callable
) -> Tuple[List[str], List[str], int]:
"""
Returns a tuple with the following elements:
@@ -553,7 +575,7 @@ def _get_overlap_from_slice(
current_slice_copy = deepcopy(current_slice)
# Next overlapping Document should not start exactly the same as the previous one, so we skip the first sentence
for idx, s in reversed(list(enumerate(current_slice))[1:]):
sen_len = len(s.split())
sen_len = len(split_function(s))
if word_count_overlap < split_overlap and sen_len < split_length:
overlap.append(s)
word_count_overlap += sen_len
Expand All @@ -566,7 +588,7 @@ def _get_overlap_from_slice(

return processed_sents, next_slice, word_count_slice

def _split_into_units(self, text: str, split_by: str) -> Tuple[List[str], str]:
def _split_into_units(self, text: str, split_by: str, tokenizer: Any) -> Tuple[List[str], str]:
if split_by == "passage":
elements = text.split("\n\n")
split_at = "\n\n"
@@ -576,8 +598,13 @@ def _split_into_units(self, text: str, split_by: str) -> Tuple[List[str], str]:
elif split_by == "word":
elements = text.split(" ")
split_at = " "
elif split_by == "token":
elements = self._split_tokens(text, tokenizer)
split_at = ""
else:
raise NotImplementedError("PreProcessor only supports 'passage', 'sentence' or 'word' split_by options.")
raise NotImplementedError(
"PreProcessor only supports 'passage', 'sentence', 'word' or 'token' split_by options."
)

return elements, split_at

@@ -823,6 +850,35 @@ def _split_sentences(self, text: str) -> List[str]:
sentences = sentence_tokenizer.tokenize(text)
return sentences

def _split_tokens(self, text: str, tokenizer: Any) -> List[str]:
if tokenizer == "tiktoken":
tiktoken_import.check()
enc = tiktoken.get_encoding("cl100k_base") # tiktoken is reversible and lossless
integer_tokens = enc.encode(text, disallowed_special=())
elements = [enc.decode_single_token_bytes(token).decode(errors="ignore") for token in integer_tokens]
return elements
if isinstance(tokenizer, str):
transformers_import.check()
try:
tokenizer = AutoTokenizer.from_pretrained(tokenizer)
except Exception:
raise ValueError(
f"Could not load tokenizer '{tokenizer}' from HuggingFace model hub. "
f"Please make sure that the tokenizer is correct and exists."
)
if isinstance(tokenizer, PreTrainedTokenizerBase):
encoded = tokenizer.encode_plus(text, return_offsets_mapping=True, add_special_tokens=False)
elements = []
for i in range(l := len(encoded.offset_mapping)):
start_current = encoded.offset_mapping[i][0]
start_next = encoded.offset_mapping[i + 1][0] if i < l - 1 else len(text)
elements.append(text[start_current:start_next])
return elements
raise ValueError(
f"Unsupported tokenizer specification {tokenizer}. "
f"Please provide either the string 'tiktoken' or a HuggingFace tokenizer (PreTrainedTokenizerBase)."
)

def _load_sentence_tokenizer(self, language_name: Optional[str]) -> "nltk.tokenize.punkt.PunktSentenceTokenizer":
# Try to load a custom model from 'tokenizer_model_path'
if self.tokenizer_model_folder is not None:
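To make the `_split_tokens` logic above concrete, here is a standalone sketch (not part of the PR) of its two tokenizer paths, assuming `tiktoken` and `transformers` are installed. Both paths produce the text of each token, so joining the pieces reconstructs the original string for plain ASCII input like the sample below.

import tiktoken
from transformers import AutoTokenizer

text = "PreProcessor can now split documents by token."

# tiktoken path: encode to integer tokens, then decode each token back into its text piece.
enc = tiktoken.get_encoding("cl100k_base")
tiktoken_pieces = [
    enc.decode_single_token_bytes(t).decode(errors="ignore")
    for t in enc.encode(text, disallowed_special=())
]
assert "".join(tiktoken_pieces) == text  # holds here because the sample text is plain ASCII

# Hugging Face path: slice the original text using offset mappings, so every character
# (including whitespace) is attached to exactly one token piece.
tok = AutoTokenizer.from_pretrained("bert-base-uncased")
encoded = tok.encode_plus(text, return_offsets_mapping=True, add_special_tokens=False)
offsets = encoded["offset_mapping"]
hf_pieces = [
    text[offsets[i][0] : (offsets[i + 1][0] if i + 1 < len(offsets) else len(text))]
    for i in range(len(offsets))
]
assert "".join(hf_pieces) == text
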
2 changes: 2 additions & 0 deletions releasenotes/notes/split-by-token-b9a4f954d4077ecc.yaml
@@ -0,0 +1,2 @@
features:
- Add the option to split by token in `PreProcessor`, using either tiktoken or a Hugging Face tokenizer