diff --git a/merlin/models/tokenizers/__init__.py b/merlin/models/tokenizers/__init__.py
new file mode 100644
index 0000000000..d955878e32
--- /dev/null
+++ b/merlin/models/tokenizers/__init__.py
@@ -0,0 +1,2 @@
+from merlin.models.tokenizers.sentencepiece import SentencePieceTokenizer  # noqa: F401
+from merlin.models.tokenizers.tokenizer import Tokenizer  # noqa: F401
diff --git a/merlin/models/tokenizers/sentencepiece.py b/merlin/models/tokenizers/sentencepiece.py
new file mode 100644
index 0000000000..436d1484b8
--- /dev/null
+++ b/merlin/models/tokenizers/sentencepiece.py
@@ -0,0 +1,56 @@
+from typing import List
+
+from merlin.models.tokenizers.tokenizer import Tokenizer
+
+
+class SentencePieceTokenizer(Tokenizer):
+    """Tokenizer using SentencePiece [1].
+
+    References
+    ----------
+    [1] https://github.com/google/sentencepiece
+    """
+
+    def __init__(self, *, processor: "SentencePieceProcessor") -> None:  # noqa: F821
+        require_sentencepiece()
+
+        self.processor = processor
+        self.bos_id = self.processor.bos_id()
+        self.eos_id = self.processor.eos_id()
+        self.pad_id = self.processor.pad_id()
+
+    def encode(
+        self,
+        string: str,
+        bos: bool = False,
+        eos: bool = False,
+        max_length: int = -1,
+        pad: bool = False,
+    ) -> List[int]:
+        tokens = self.processor.encode(string)
+        if bos:
+            tokens = [self.bos_id] + tokens
+        if eos:
+            tokens = tokens + [self.eos_id]
+        if max_length > 0:
+            tokens = tokens[:max_length]
+        if pad and len(tokens) < max_length:
+            tokens += [self.pad_id] * (max_length - len(tokens))
+
+        return tokens
+
+    def decode(self, tokens: List[int]) -> str:
+        return self.processor.decode(tokens)
+
+    @property
+    def vocab_size(self) -> int:
+        return self.processor.vocab_size()
+
+
+def require_sentencepiece() -> None:
+    try:
+        from sentencepiece import SentencePieceProcessor, SentencePieceTrainer  # noqa: F401
+    except ImportError:
+        raise ImportError(
+            "This requires `sentencepiece`. Install it with `pip install sentencepiece`."
+        )
diff --git a/merlin/models/tokenizers/tokenizer.py b/merlin/models/tokenizers/tokenizer.py
new file mode 100644
index 0000000000..ae493b4341
--- /dev/null
+++ b/merlin/models/tokenizers/tokenizer.py
@@ -0,0 +1,19 @@
+from abc import ABC, abstractmethod
+from typing import List
+
+
+class Tokenizer(ABC):
+    """
+    Base class for all tokenizers.
+    """
+
+    def __call__(self, string: str):
+        return self.encode(string)
+
+    @abstractmethod
+    def decode(self, tokens: List[int]):
+        ...
+
+    @abstractmethod
+    def encode(self, string: str):
+        ...
diff --git a/merlin/models/torch/blocks/tokenizer.py b/merlin/models/torch/blocks/tokenizer.py
new file mode 100644
index 0000000000..42600280ad
--- /dev/null
+++ b/merlin/models/torch/blocks/tokenizer.py
@@ -0,0 +1,65 @@
+import os
+from pathlib import Path
+from typing import Optional
+
+import torch
+from sentencepiece import SentencePieceProcessor, SentencePieceTrainer
+
+
+class SentencePieceTokenizer:
+    """Tokenizer for LLaMA.
+
+    Example usage
+    -------------
+    >>> tokenizer_path = Path("llama/tokenizer.model")
+    >>> tokenizer = SentencePieceTokenizer(tokenizer_path)
+    >>> tokenizer.encode("Hello, my name is", bos=True, eos=False)
+    tensor([    1, 15043, 29892,   590,  1024,   338], dtype=torch.int32)
+    """
+
+    def __init__(self, model_path: Path) -> None:
+        try:
+            import sentencepiece  # noqa: F401
+        except ImportError:
+            raise ImportError(
+                "`sentencepiece` is required to use this feature. "
+                "Install it with `pip install sentencepiece`."
+            )
+
+        self.processor = SentencePieceProcessor(model_file=str(model_path))
+        self.bos_id = self.processor.bos_id()
+        self.eos_id = self.processor.eos_id()
+        self.pad_id = self.processor.pad_id()
+
+    @property
+    def vocab_size(self) -> int:
+        return self.processor.vocab_size()
+
+    def encode(
+        self,
+        string: str,
+        bos: bool = True,
+        eos: bool = False,
+        max_length: int = -1,
+        pad: bool = False,
+        device: Optional[torch.device] = None,
+    ) -> torch.Tensor:
+        tokens = self.processor.encode(string)
+        if bos:
+            tokens = [self.bos_id] + tokens
+        if eos:
+            tokens = tokens + [self.eos_id]
+        if max_length > 0:
+            tokens = tokens[:max_length]
+        if pad and len(tokens) < max_length:
+            tokens += [self.pad_id] * (max_length - len(tokens))
+
+        return torch.tensor(tokens, dtype=torch.int, device=device)
+
+    def decode(self, tokens: torch.Tensor) -> str:
+        return self.processor.decode(tokens.tolist())
+
+    @staticmethod
+    def train(input: str, destination: str, vocab_size=32000) -> None:
+        model_prefix = os.path.join(destination, "tokenizer")
+        SentencePieceTrainer.Train(input=input, model_prefix=model_prefix, vocab_size=vocab_size)
diff --git a/merlin/models/torch/tokenizers/__init__.py b/merlin/models/torch/tokenizers/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/merlin/models/torch/tokenizers/llama.py b/merlin/models/torch/tokenizers/llama.py
new file mode 100644
index 0000000000..f7d854d650
--- /dev/null
+++ b/merlin/models/torch/tokenizers/llama.py
@@ -0,0 +1,40 @@
+from pathlib import Path
+from typing import Optional, Union
+
+import torch
+
+from merlin.models.tokenizers.sentencepiece import SentencePieceTokenizer, require_sentencepiece
+
+
+class LlamaTokenizer(SentencePieceTokenizer):
+    def __init__(self, path: Union[str, Path]) -> None:
+        require_sentencepiece()
+
+        from sentencepiece import SentencePieceProcessor
+
+        if isinstance(path, Path):
+            path = str(path)
+        processor = SentencePieceProcessor(model_file=str(path))
+
+        super().__init__(processor=processor)
+
+    def encode(
+        self,
+        string: str,
+        bos: bool = True,
+        eos: bool = False,
+        max_length: int = -1,
+        pad: bool = False,
+        device: Optional[torch.device] = None,
+    ) -> torch.Tensor:
+        tokens = super().encode(
+            string=string,
+            bos=bos,
+            eos=eos,
+            max_length=max_length,
+            pad=pad,
+        )
+        return torch.tensor(tokens, dtype=torch.int, device=device)
+
+    def decode(self, tokens: torch.Tensor) -> str:
+        return self.processor.decode(tokens.tolist())
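
A minimal usage sketch of the `LlamaTokenizer` added in this diff (not part of the change itself); the checkpoint path below is hypothetical and assumes a pretrained LLaMA SentencePiece model file is available locally:

    from merlin.models.torch.tokenizers.llama import LlamaTokenizer

    # Load the SentencePiece model shipped with LLaMA (path is illustrative).
    tokenizer = LlamaTokenizer("checkpoints/llama/tokenizer.model")

    # Encode with a leading BOS token; the result is a torch.int32 tensor of token ids.
    token_ids = tokenizer.encode("Hello, my name is", bos=True)

    # Round-trip the ids back to text.
    print(token_ids)
    print(tokenizer.decode(token_ids))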