chore: docstrings update
Anush008 committed Jan 3, 2024
1 parent a154a80 commit 7378350
Showing 3 changed files with 54 additions and 19 deletions.
fastembed/embedding.py: 32 changes (26 additions, 6 deletions)
@@ -134,8 +134,8 @@ def onnx_embed(self, documents: List[str]) -> Tuple[np.ndarray, np.ndarray]:
         embeddings = model_output[0]
         return embeddings, attention_mask
 
-    def split_text(self, text: str) -> List[str]:
-        return self.splitter.split_text(text)
+    def split_text(self, text: str, chunk_size: Optional[int] = None, chunk_overlap: Optional[int] = None) -> List[str]:
+        return self.splitter.split_text(text=text, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
 
 
 class EmbeddingWorker(Worker):
@@ -543,8 +543,18 @@ def embed(
             embeddings, _ = batch
             yield from normalize(embeddings[:, 0]).astype(np.float32)
 
-    def split_text(self, text: str) -> List[str]:
-        return self.model.split_text(text)
+    def split_text(self, text: str, chunk_size: Optional[int] = None, chunk_overlap: Optional[int] = None) -> List[str]:
+        """Splits text into chunks based on the tokenizer encoding size.
+
+        Args:
+            text (str): The text to split.
+            chunk_size (Optional[int], optional): Maximum size of chunks based on the tokenizer encoding.
+            chunk_overlap (Optional[int], optional): Allowed overlap in characters between chunks.
+
+        Returns:
+            List[str]: The list of strings.
+        """
+        return self.model.split_text(text, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
 
     @classmethod
     def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]:
@@ -673,8 +683,18 @@ def embed(
             embeddings, attn_mask = batch
             yield from normalize(self.mean_pooling(embeddings, attn_mask)).astype(np.float32)
 
-    def split_text(self, text: str) -> List[str]:
-        return self.model.split_text(text)
+    def split_text(self, text: str, chunk_size: Optional[int] = None, chunk_overlap: Optional[int] = None) -> List[str]:
+        """Splits text into chunks based on the tokenizer encoding size.
+
+        Args:
+            text (str): The text to split.
+            chunk_size (Optional[int], optional): Maximum size of chunks based on the tokenizer encoding.
+            chunk_overlap (Optional[int], optional): Allowed overlap in characters between chunks.
+
+        Returns:
+            List[str]: The list of strings.
+        """
+        return self.model.split_text(text, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
 
     @classmethod
     def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]:
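
For orientation, a minimal usage sketch of the split_text API after this commit. Only split_text and its new chunk_size/chunk_overlap keywords come from the diff above; the no-argument DefaultEmbedding() constructor is taken from tests/test_splitter.py below, and the sample text and sizes are illustrative assumptions.

# Usage sketch; see the hedging note above.
from fastembed.embedding import DefaultEmbedding

model = DefaultEmbedding()
text = "FastEmbed is a lightweight, fast library for embedding generation. " * 200

# As before this commit: chunking uses the defaults fixed at construction time.
chunks = model.split_text(text)

# New in this commit: per-call overrides, without rebuilding the model.
# Per the docstring, chunk_size is measured with the tokenizer encoding
# and chunk_overlap in characters.
small_chunks = model.split_text(text, chunk_size=256, chunk_overlap=32)

The practical effect is that one loaded model can serve several chunking policies; previously both values were fixed when the splitter was constructed.
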
fastembed/splitter.py: 40 changes (28 additions, 12 deletions)
@@ -1,4 +1,4 @@
-# Custom implementation based on Langchain's text splitter
+# Custom implementation based on the Langchain text splitter
 # Reference: https://python.langchain.com/docs/modules/data_connection/document_transformers/recursive_text_splitter
 
 import logging
@@ -15,6 +15,7 @@
 
 logger = logging.getLogger(__name__)
 
+
 def _split_text_with_regex(text: str, separator: str, keep_separator: bool) -> List[str]:
     # Now that we have the separator, split the text
     if separator:
@@ -76,20 +76,29 @@ def _join_docs(self, docs: List[str], separator: str) -> Optional[str]:
         else:
             return text
 
-    def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
+    def _merge_splits(
+        self,
+        splits: Iterable[str],
+        separator: str,
+        chunk_size: Optional[int] = None,
+        chunk_overlap: Optional[int] = None,
+    ) -> List[str]:
         # We now want to combine these smaller pieces into medium size
         # chunks to send to the LLM.
         separator_len = self._length_function(separator)
 
+        chunk_size = chunk_size or self._chunk_size
+        chunk_overlap = chunk_overlap or self._chunk_overlap
+
         docs = []
         current_doc: List[str] = []
         total = 0
         for d in splits:
             _len = self._length_function(d)
-            if total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size:
-                if total > self._chunk_size:
+            if total + _len + (separator_len if len(current_doc) > 0 else 0) > chunk_size:
+                if total > chunk_size:
                     logger.warning(
-                        f"Created a chunk of size {total}, " f"which is longer than the specified {self._chunk_size}"
+                        f"Created a chunk of size {total}, " f"which is longer than the specified {chunk_size}"
                     )
                 if len(current_doc) > 0:
                     doc = self._join_docs(current_doc, separator)
@@ -98,8 +108,8 @@ def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
                 # Keep on popping if:
                 # - we have a larger chunk than in the chunk overlap
                 # - or if we still have any chunks and the length is long
-                while total > self._chunk_overlap or (
-                    total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size and total > 0
+                while total > chunk_overlap or (
+                    total + _len + (separator_len if len(current_doc) > 0 else 0) > chunk_size and total > 0
                 ):
                     total -= self._length_function(current_doc[0]) + (separator_len if len(current_doc) > 1 else 0)
                     current_doc = current_doc[1:]
@@ -146,8 +156,14 @@ def _tokenizer_length(text: str) -> int:
         self._separators = separators or ["\n\n", "\n", " ", ""]
         self._is_separator_regex = is_separator_regex
 
-    def _split_text(self, text: str, separators: List[str]) -> List[str]:
+    def _split_text(
+        self, text: str, separators: List[str], chunk_size: Optional[int] = None, chunk_overlap: Optional[int] = None
+    ) -> List[str]:
         """Split incoming text and return chunks."""
+
+        chunk_size = chunk_size or self._chunk_size
+        chunk_overlap = chunk_overlap or self._chunk_overlap
+
         final_chunks = []
         # Get appropriate separator to use
         separator = separators[-1]
@@ -169,11 +185,11 @@ def _split_text(self, text: str, separators: List[str]) -> List[str]:
         _good_splits = []
         _separator = "" if self._keep_separator else separator
         for s in splits:
-            if self._length_function(s) < self._chunk_size:
+            if self._length_function(s) < chunk_size:
                 _good_splits.append(s)
             else:
                 if _good_splits:
-                    merged_text = self._merge_splits(_good_splits, _separator)
+                    merged_text = self._merge_splits(_good_splits, _separator, chunk_size, chunk_overlap)
                     final_chunks.extend(merged_text)
                     _good_splits = []
                 if not new_separators:
@@ -186,5 +202,5 @@ def _split_text(self, text: str, separators: List[str]) -> List[str]:
             final_chunks.extend(merged_text)
         return final_chunks
 
-    def split_text(self, text: str) -> List[str]:
-        return self._split_text(text, self._separators)
+    def split_text(self, text: str, chunk_size: Optional[int] = None, chunk_overlap: Optional[int] = None) -> List[str]:
+        return self._split_text(text, self._separators, chunk_size, chunk_overlap)
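
Both _merge_splits and _split_text resolve the new keywords with Python's `or` against the stored defaults. Below is a self-contained sketch of that fallback pattern; the Splitter class and effective_sizes method are hypothetical names for illustration only.

from typing import Optional, Tuple

class Splitter:
    def __init__(self, chunk_size: int = 512, chunk_overlap: int = 50):
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap

    def effective_sizes(
        self, chunk_size: Optional[int] = None, chunk_overlap: Optional[int] = None
    ) -> Tuple[int, int]:
        # Same resolution as in the diff above: any falsy override
        # (None, but also 0) falls back to the constructor default.
        chunk_size = chunk_size or self._chunk_size
        chunk_overlap = chunk_overlap or self._chunk_overlap
        return chunk_size, chunk_overlap

s = Splitter()
print(s.effective_sizes())                 # (512, 50): defaults
print(s.effective_sizes(chunk_size=256))   # (256, 50): per-call override
print(s.effective_sizes(chunk_overlap=0))  # (512, 50): 0 is falsy, so it falls back

One design consequence of `or` rather than an explicit `is None` check: chunk_overlap=0 cannot be requested per call, since 0 falls back to the default.
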
tests/test_splitter.py: 1 change (0 additions, 1 deletion)
@@ -4,7 +4,6 @@
 
 from fastembed.embedding import DefaultEmbedding
 
-
 @pytest.mark.parametrize(["chunk_size", "chunk_overlap"], [[500, 50], [1000, 100]])
 def test_embedding(chunk_size: int, chunk_overlap: int):
     is_ubuntu_ci = os.getenv("IS_UBUNTU_CI")
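
Finally, a hedged sketch of how the parametrized test above can exercise the new keywords; the decorator matches the real test file, while the body and assertions are illustrative assumptions, not the repository's actual test.

import pytest

from fastembed.embedding import DefaultEmbedding


@pytest.mark.parametrize(["chunk_size", "chunk_overlap"], [[500, 50], [1000, 100]])
def test_split_text_overrides(chunk_size: int, chunk_overlap: int):
    # Illustrative body: checks only that per-call overrides produce
    # non-empty string chunks.
    model = DefaultEmbedding()
    text = "fastembed splits long documents into overlapping chunks. " * 300
    chunks = model.split_text(text, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    assert chunks, "expected at least one chunk"
    assert all(isinstance(chunk, str) and chunk for chunk in chunks)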
