Skip to content

Commit

Permalink
chore: move split_text to abstract class
Browse files Browse the repository at this point in the history
  • Loading branch information
Anush008 committed Jan 3, 2024
1 parent 7378350 commit 0641f6f
Showing 1 changed file with 14 additions and 26 deletions.
40 changes: 14 additions & 26 deletions fastembed/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,8 @@ class Embedding(ABC):
_type_: _description_
"""

model: EmbeddingModel

@abstractmethod
def embed(self, texts: Iterable[str], batch_size: int = 256, parallel: Optional[int] = None) -> List[np.ndarray]:
    """Encode each text in *texts* into an embedding vector.

    Abstract: concrete subclasses provide the actual encoding.

    Args:
        texts (Iterable[str]): The texts to embed.
        batch_size (int, optional): Number of texts processed per batch. Defaults to 256.
        parallel (Optional[int], optional): Presumably the number of parallel
            workers; None disables parallelism — TODO confirm against subclass
            implementations (not visible in this chunk).

    Returns:
        List[np.ndarray]: One embedding per input text.

    Raises:
        NotImplementedError: Always, on the abstract base.
    """
    raise NotImplementedError
Expand Down Expand Up @@ -447,6 +449,18 @@ def query_embed(self, query: str) -> Iterable[np.ndarray]:
query_embedding = self.embed([query])
return query_embedding

def split_text(self, text: str, chunk_size: Optional[int] = None, chunk_overlap: Optional[int] = None) -> List[str]:
    """Delegate chunking of *text* to the underlying embedding model.

    Args:
        text (str): The text to split.
        chunk_size (Optional[int], optional): Maximum size of chunks based on the tokenizer encoding.
        chunk_overlap (Optional[int], optional): Allowed overlap in characters between chunks.

    Returns:
        List[str]: The resulting chunk strings.
    """
    # All chunking logic lives on the model object; this is a thin pass-through.
    chunks = self.model.split_text(
        text,
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    return chunks

class FlagEmbedding(Embedding):
"""
Expand Down Expand Up @@ -543,19 +557,6 @@ def embed(
embeddings, _ = batch
yield from normalize(embeddings[:, 0]).astype(np.float32)

def split_text(self, text: str, chunk_size: Optional[int] = None, chunk_overlap: Optional[int] = None) -> List[str]:
    """Split *text* into chunks sized by the tokenizer encoding.

    The actual splitting is delegated to the wrapped model object.

    Args:
        text (str): The text to split.
        chunk_size (Optional[int], optional): Maximum size of chunks based on the tokenizer encoding.
        chunk_overlap (Optional[int], optional): Allowed overlap in characters between chunks.

    Returns:
        List[str]: The list of chunk strings.
    """
    return self.model.split_text(text, chunk_size=chunk_size, chunk_overlap=chunk_overlap)

@classmethod
def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]:
"""
Expand Down Expand Up @@ -683,19 +684,6 @@ def embed(
embeddings, attn_mask = batch
yield from normalize(self.mean_pooling(embeddings, attn_mask)).astype(np.float32)

def split_text(self, text: str, chunk_size: Optional[int] = None, chunk_overlap: Optional[int] = None) -> List[str]:
    """Split *text* into chunks based on the tokenizer encoding size.

    Pure delegation: the model held in ``self.model`` performs the split.

    Args:
        text (str): The text to split.
        chunk_size (Optional[int], optional): Maximum size of chunks based on the tokenizer encoding.
        chunk_overlap (Optional[int], optional): Allowed overlap in characters between chunks.

    Returns:
        List[str]: The list of chunk strings.
    """
    return self.model.split_text(text, chunk_size=chunk_size, chunk_overlap=chunk_overlap)

@classmethod
def list_supported_models(cls) -> List[Dict[str, Union[str, Union[int, float]]]]:
"""
Expand Down

0 comments on commit 0641f6f

Please sign in to comment.