* Implement `Tokenizer` and `MemMapDataset`
* Support concatenation of `MemMapDataset`s
* Add script for generating memmap file
* Minor improvements
* Smaller test fixtures
* Add test
* Add validation of final array
* Remove duplicate job
* Clean up progress
* Add "files" word
* Add `Tokenizer.vocab_size()` method
* Add "-j/--workers" argument
* Clean up
Showing 21 changed files with 581 additions and 10 deletions.
@@ -0,0 +1 @@
test_fixtures/*.json.gz binary
@@ -0,0 +1,55 @@
from typing import List

import pytest

from dolma.data.tokenizer import Tokenizer

TEST_MODEL = "gpt2"

LOREM_IPSUM_1 = """
Lorem ipsum dolor sit amet, consectetur adipiscing elit,
sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip
ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit
esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat
non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
"""

LOREM_IPSUM_2 = """
Sed ut perspiciatis unde omnis iste natus error sit voluptatem accusantium doloremque
laudantium, totam rem aperiam, eaque ipsa quae ab illo inventore veritatis et quasi
architecto beatae vitae dicta sunt explicabo. Nemo enim ipsam voluptatem quia
voluptas sit aspernatur aut odit aut fugit, sed quia consequuntur magni dolores
eos qui ratione voluptatem sequi nesciunt. Neque porro quisquam est, qui dolorem
ipsum quia dolor sit amet, consectetur, adipisci velit, sed quia non numquam eius
modi tempora incidunt ut labore et dolore magnam aliquam quaerat voluptatem.
Ut enim ad minima veniam, quis nostrum exercitationem ullam corporis suscipit
laboriosam, nisi ut aliquid ex ea commodi consequatur? Quis autem vel eum iure
reprehenderit qui in ea voluptate velit esse quam nihil molestiae consequatur,
vel illum qui dolorem eum fugiat quo voluptas nulla pariatur?
"""


@pytest.fixture(scope="module")
def pretrained_tokenizer_name() -> str:
    return TEST_MODEL


@pytest.fixture(scope="function")
def tokenizer() -> Tokenizer:
    return Tokenizer.from_pretrained(TEST_MODEL)


# Function scope: a module-scoped fixture cannot depend on the function-scoped
# `tokenizer` fixture without a pytest ScopeMismatch error.
@pytest.fixture(scope="function")
def eos_token_id(tokenizer: Tokenizer) -> int:
    return tokenizer.eos_token_id


@pytest.fixture(scope="module")
def lorem_ipsum() -> str:
    return LOREM_IPSUM_1.replace("\n", " ").strip()


@pytest.fixture(scope="module")
def lorem_ipsum_docs() -> List[str]:
    return [text.replace("\n", " ").strip() for text in (LOREM_IPSUM_1, LOREM_IPSUM_2)]
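These fixtures feed the tokenizer tests added in this commit. As a hedged sketch (the assertions below are illustrative, not copied from the diff), a test consuming them could look like this:

from dolma.data.tokenizer import Tokenizer


def test_encode_appends_eos(tokenizer: Tokenizer, lorem_ipsum: str, eos_token_id: int):
    # By default, Tokenizer.encode() appends the EOS token ID.
    token_ids = tokenizer.encode(lorem_ipsum)
    assert token_ids[-1] == eos_token_id

    # With add_special_tokens=False the raw encoding comes back unchanged.
    raw_ids = tokenizer.encode(lorem_ipsum, add_special_tokens=False)
    assert raw_ids == token_ids[:-1]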
@@ -0,0 +1,7 @@
from os import PathLike
from typing import Union

__all__ = ["PathOrStr"]


PathOrStr = Union[str, PathLike]
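`PathOrStr` is just a typing convenience: functions that take either a plain string or a `PathLike` object can annotate the argument with it, as `MemMapDataset` does below. A tiny sketch (the `file_size` helper is hypothetical, not part of the commit):

import os
from pathlib import Path

from dolma.aliases import PathOrStr


def file_size(path: PathOrStr) -> int:
    # Accepts both str and pathlib.Path.
    return os.stat(path).st_size


file_size("requirements.txt")
file_size(Path("requirements.txt"))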
@@ -0,0 +1,4 @@
from .memmap_dataset import MemMapDataset
from .tokenizer import Tokenizer, TruncationDirection

__all__ = ["MemMapDataset", "Tokenizer", "TruncationDirection"]
@@ -0,0 +1,95 @@
from typing import List, Optional, Tuple, cast

import numpy as np
import torch
from torch.utils.data import Dataset

from ..aliases import PathOrStr

__all__ = ["MemMapDataset"]


class MemMapDataset(Dataset[torch.LongTensor]):
    """
    A PyTorch :class:`~torch.utils.data.Dataset` backed by one or more numpy memory-mapped arrays
    of token IDs. Token IDs are chunked together into contiguous blocks of ``chunk_size``
    to create instances.

    If the length of a memory-mapped array is not a multiple of ``chunk_size``, the
    remainder of the tokens will be ignored.

    No special tokens are added to the input IDs, so it's assumed that if you want
    EOS tokens between documents, for example, those will already be in the memory-mapped array.

    :param paths: Paths to memory-mapped token arrays.
    :param chunk_size: The number of tokens to chunk together into a single instance.
        Generally this should correspond to your model's maximum input length.
    :param memmap_dtype: The numpy datatype of the memory-mapped array.
    """

    def __init__(self, *paths: PathOrStr, chunk_size: int = 1024, memmap_dtype=np.uint16):
        if not paths:
            raise ValueError("At least one path is required")
        self._memmap_paths = paths
        self._chunk_size = chunk_size
        self._mmaps: Optional[List[np.memmap]] = None
        self._mmap_offsets: Optional[List[Tuple[int, int]]] = None
        self._num_instances: Optional[int] = None
        self.dtype = memmap_dtype

    @property
    def memmaps(self) -> List[np.memmap]:
        if self._mmaps is None:
            self._mmaps = []
            for path in self._memmap_paths:
                mmap = np.memmap(path, mode="r", dtype=self.dtype)
                self._mmaps.append(mmap)
        return self._mmaps

    @property
    def offsets(self) -> List[Tuple[int, int]]:
        if self._mmap_offsets is None:
            start_offset = 0
            self._mmap_offsets = []
            for mmap in self.memmaps:
                length = mmap.shape[0] // self._chunk_size
                end_offset = start_offset + length
                self._mmap_offsets.append((start_offset, end_offset))
                start_offset += length
        return self._mmap_offsets

    def __len__(self) -> int:
        if self._num_instances is None:
            self._num_instances = self.offsets[-1][1]
        return self._num_instances

    def __getitem__(self, index: int) -> torch.LongTensor:
        pos_index = index if index >= 0 else len(self) + index

        # The index of the memmap array within 'self.memmaps'.
        memmap_index: Optional[int] = None
        # The 'index' relative to the corresponding memmap array.
        memmap_local_index: Optional[int] = None
        for i, (offset_start, offset_end) in enumerate(self.offsets):
            if offset_start <= pos_index < offset_end:
                memmap_index = i
                memmap_local_index = pos_index - offset_start

        if memmap_index is None or memmap_local_index is None:
            raise IndexError(f"{index} is out of bounds for dataset of size {len(self)}")

        memmap = self.memmaps[memmap_index]
        index_start = memmap_local_index * self._chunk_size
        index_stop = (memmap_local_index + 1) * self._chunk_size
        data = memmap[index_start:index_stop].astype(np.int_)
        return cast(torch.LongTensor, torch.tensor(data, dtype=torch.long))

    def __add__(self, other: "MemMapDataset") -> "MemMapDataset":
        """
        Concatenate one :class:`MemMapDataset` with another.
        """
        if not isinstance(other, MemMapDataset):
            raise NotImplementedError(f"Expected another MemMapDataset but got {type(other)}")
        return MemMapDataset(
            *(self._memmap_paths + other._memmap_paths), chunk_size=self._chunk_size, memmap_dtype=self.dtype
        )
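To make the chunking and offset bookkeeping concrete, here is a small hedged example that builds two raw token files and concatenates the datasets with `+` (the file names and token values are made up for illustration):

import numpy as np

from dolma.data import MemMapDataset

# Two flat uint16 token files (uint16 is the dataset's default dtype).
np.arange(10, dtype=np.uint16).tofile("part1.bin")  # 10 tokens
np.arange(7, dtype=np.uint16).tofile("part2.bin")   # 7 tokens

ds1 = MemMapDataset("part1.bin", chunk_size=4)  # 10 // 4 = 2 instances; 2 trailing tokens dropped
ds2 = MemMapDataset("part2.bin", chunk_size=4)  # 7 // 4 = 1 instance; 3 trailing tokens dropped
combined = ds1 + ds2                            # offsets: [(0, 2), (2, 3)]

assert len(combined) == 3
assert combined[2].tolist() == [0, 1, 2, 3]     # first chunk of the second file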
@@ -0,0 +1,125 @@
from contextlib import contextmanager
from typing import Generator, List, Optional, Union

from tokenizers import Tokenizer as BaseTokenizer

from ..util import StrEnum

__all__ = ["Tokenizer", "TruncationDirection"]


class TruncationDirection(StrEnum):
    right = "right"
    left = "left"


class Tokenizer:
    """
    A :class:`Tokenizer` is a light-weight wrapper around :class:`tokenizers.Tokenizer`.

    :param base_tokenizer: The :class:`tokenizers.Tokenizer` to use.
    :param eos_token_id: The EOS token ID. If not set, we default to using the last token
        in the vocabulary, which is usually correct for GPT tokenizers.
    :param truncate_to: Truncate sequences to this number of token IDs when tokenizing.
    :param truncate_direction: The direction to truncate in. "right" means truncate the tokens
        on the right. "left" means truncate the tokens on the left. If ``truncate_to`` is null,
        this setting has no effect.
    """

    def __init__(
        self,
        base_tokenizer: BaseTokenizer,
        eos_token_id: Optional[int] = None,
        truncate_to: Optional[int] = None,
        truncate_direction: Union[str, TruncationDirection] = TruncationDirection.right,
    ):
        self.base_tokenizer = base_tokenizer
        self.eos_token_id = eos_token_id if eos_token_id is not None else base_tokenizer.get_vocab_size() - 1
        self.truncate_to = truncate_to
        self.truncate_direction = TruncationDirection(truncate_direction)

    @property
    def vocab_size(self) -> int:
        return self.base_tokenizer.get_vocab_size()

    @classmethod
    def from_pretrained(cls, identifier: str, **kwargs) -> "Tokenizer":
        """
        Initialize a tokenizer from a pretrained tokenizer on the HuggingFace Hub.

        :param identifier: The identifier of a model on the Hub that contains a
            ``tokenizer.json`` file.
        """
        base_tokenizer = BaseTokenizer.from_pretrained(identifier)
        return cls(base_tokenizer, **kwargs)

    def add_special_tokens(self, input_ids: List[int]) -> List[int]:
        """
        Add special tokens in-place (if not already present) to the given token IDs.
        """
        if not input_ids or input_ids[-1] != self.eos_token_id:
            input_ids.append(self.eos_token_id)
        return input_ids

    def num_special_tokens_to_add(self, is_pair: bool = False) -> int:
        return 2 if is_pair else 1

    @contextmanager
    def _truncation(
        self, truncate_to: Optional[int], direction: Union[str, TruncationDirection] = TruncationDirection.right
    ) -> Generator["Tokenizer", None, None]:
        """
        A context manager to temporarily enable/disable truncation.
        """
        truncation = self.base_tokenizer.truncation

        try:
            if truncate_to is not None:
                self.base_tokenizer.enable_truncation(truncate_to, direction=str(direction))
            else:
                self.base_tokenizer.no_truncation()
            yield self
        finally:
            if truncation is None:
                self.base_tokenizer.no_truncation()
            else:
                self.base_tokenizer.enable_truncation(**truncation)

    def encode(self, input: str, add_special_tokens: bool = True) -> List[int]:
        """
        Encode a string into token IDs.
        """
        truncate_to = self.truncate_to
        if truncate_to is not None and add_special_tokens:
            truncate_to -= self.num_special_tokens_to_add(False)

        with self._truncation(truncate_to, direction=self.truncate_direction):
            input_ids = self.base_tokenizer.encode(input).ids

        if add_special_tokens:
            input_ids = self.add_special_tokens(input_ids)

        return input_ids

    def encode_batch(self, inputs: List[str], add_special_tokens: bool = True) -> List[List[int]]:
        """
        Encode a batch of strings into token IDs.
        """
        truncate_to = self.truncate_to
        if truncate_to is not None and add_special_tokens:
            truncate_to -= self.num_special_tokens_to_add(False)

        with self._truncation(truncate_to, direction=self.truncate_direction):
            batch_encoding = self.base_tokenizer.encode_batch(inputs)

        all_input_ids = []
        for encoding in batch_encoding:
            input_ids = encoding.ids
            if add_special_tokens:
                input_ids = self.add_special_tokens(input_ids)
            all_input_ids.append(input_ids)

        return all_input_ids

    def decode(self, token_ids: List[int]) -> str:
        return self.base_tokenizer.decode(token_ids)
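A short, hedged usage sketch of the wrapper (the `gpt2` identifier mirrors the test fixtures; exact token counts depend on the tokenizer you load). Note how `encode()` reserves one slot for the EOS token when truncation is enabled:

from dolma.data import Tokenizer, TruncationDirection

tokenizer = Tokenizer.from_pretrained(
    "gpt2", truncate_to=8, truncate_direction=TruncationDirection.right
)

# At most 8 IDs come back, with the EOS token included at the end.
ids = tokenizer.encode("The quick brown fox jumps over the lazy dog.")
assert len(ids) <= 8 and ids[-1] == tokenizer.eos_token_id

# encode_batch() applies the same truncation and EOS handling to each document.
batch = tokenizer.encode_batch(["first document", "second document"])
print(tokenizer.vocab_size, [len(doc_ids) for doc_ids in batch])

print(tokenizer.decode(ids))  # round-trip back to (truncated) text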
@@ -0,0 +1,13 @@
from enum import Enum

__all__ = ["StrEnum"]


class StrEnum(str, Enum):
    """
    This is equivalent to Python's :class:`enum.StrEnum` since version 3.11.
    We include this here for compatibility with older versions of Python.
    """

    def __str__(self) -> str:
        return self.value
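For illustration, a quick hedged sketch of what this shim buys you, using the `TruncationDirection` enum added in this commit:

from dolma.data import TruncationDirection

# Members compare equal to their string values and stringify cleanly,
# which is what makes `direction=str(direction)` work in Tokenizer._truncation().
assert TruncationDirection.left == "left"
assert str(TruncationDirection.right) == "right"
assert TruncationDirection("right") is TruncationDirection.right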
@@ -1 +1,5 @@
numpy
torch
tokenizers
click
rich