np.memmap memory leak and correct val sampling #16

Open
wants to merge 2 commits into base: main
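From the diff, this PR removes the arxiv dataset variants and changes the remaining data loaders (openwebtext2, shakespeare, slimpajama, wikitext) to return the paths of the tokenized `.bin` files instead of long-lived `np.memmap` objects. The `Dataset` class in `src/data/utils.py` now re-opens the memmap inside `__len__` and `__getitem__`, so no process-wide mapping stays alive for the whole run.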
2 changes: 1 addition & 1 deletion src/config/base.py
@@ -27,7 +27,7 @@ def parse_args(base_parser, args, namespace):
parser.add_argument('--results_base_folder', default="./exps", type=str)
parser.add_argument('--grad_clip', default=0.0, type=float) # default value is 1.0 in NanoGPT
# Dataset params
-parser.add_argument('--dataset', default='slimpajama', choices=['slimpajama', 'wikitext', "shakespeare-char", 'arxiv', "arxiv2000", "arxiv+wiki", 'openwebtext2'])
+parser.add_argument('--dataset', default='slimpajama', choices=['slimpajama', 'wikitext', "shakespeare-char", 'openwebtext2'])
parser.add_argument('--vocab_size', default=50304, type=int)
parser.add_argument('--data_in_ram', action='store_true') # force the data to RAM, mostly useless except for openwebtext2
# Model params
115 changes: 0 additions & 115 deletions src/data/arxiv.py

This file was deleted.

5 changes: 1 addition & 4 deletions src/data/openwebtext2.py
@@ -51,8 +51,5 @@ def process(example):
idx += len(arr_batch)
arr.flush()

-train_data = np.memmap(os.path.join(OWT2_DATA_PATH, 'train.bin'), dtype=np.uint16, mode='r')
-val_data = np.memmap(os.path.join(OWT2_DATA_PATH, 'val.bin'), dtype=np.uint16, mode='r')
-
-return {'train': train_data, 'val': val_data}
+return {'train': os.path.join(OWT2_DATA_PATH, 'train.bin'), 'val': os.path.join(OWT2_DATA_PATH, 'val.bin')}
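A minimal sketch of the contract change above (the directory name is hypothetical): the loader now hands back only file paths, and callers create a short-lived `np.memmap` whenever they actually need the tokens, instead of one mapping that lives for the entire run.

```python
import os
import numpy as np

OWT2_DIR = "./datasets/openwebtext2"  # hypothetical location of the .bin files

def get_data_paths():
    # nothing is mapped into memory here; only the locations of the token files
    return {
        "train": os.path.join(OWT2_DIR, "train.bin"),
        "val": os.path.join(OWT2_DIR, "val.bin"),
    }

def count_tokens(bin_path):
    # the memmap exists only inside this function; it is released when `arr`
    # goes out of scope instead of living for the whole training run
    arr = np.memmap(bin_path, dtype=np.uint16, mode="r")
    return len(arr)
```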

3 changes: 1 addition & 2 deletions src/data/shakespeare.py
@@ -47,5 +47,4 @@ def get_shakespeare_data():
mem[:] = x_test

# at this point we know that the binfile was properly created so we load it
return {"train": np.memmap(train_path, dtype=np.uint16, mode="r"),
"val": np.memmap(test_path, dtype=np.uint16, mode="r")}
return {"train": train_path, "val": test_path}
71 changes: 4 additions & 67 deletions src/data/slimpajama.py
@@ -6,7 +6,6 @@


SPJ_DATA_PATH = os.path.join(os.path.dirname(__file__), "datasets/slimpajama6B/")
-SPJ_CHUNK_1_DATA_PATH = os.path.join(SPJ_DATA_PATH, "chunk1")


tknzr = tiktoken.get_encoding("gpt2")
@@ -60,69 +59,7 @@ def process(example):
idx += len(arr_batch)
arr.flush()

-train_data = np.memmap(
-os.path.join(SPJ_DATA_PATH, "train.bin"), dtype=np.uint16, mode="r"
-)
-val_data = np.memmap(
-os.path.join(SPJ_DATA_PATH, "val.bin"), dtype=np.uint16, mode="r"
-)
-
-return {"train": train_data, "val": val_data}
-
-
-def get_slimpajama_chunk1(num_proc=40):
-if not os.path.exists(os.path.join(SPJ_CHUNK_1_DATA_PATH, "train.bin")):
-os.makedirs(SPJ_DATA_PATH, exist_ok=True)
-dataset = load_dataset("cerebras/SlimPajama-627B", split="train/chunk1")
-
-split_dataset = dataset["train"].train_test_split(
-test_size=0.0005, seed=2357, shuffle=True
-)
-split_dataset["val"] = split_dataset.pop("test")
-
-def process(example):
-ids = tknzr.encode_ordinary(
-example["text"]
-) # encode_ordinary ignores any special tokens
-ids.append(
-tknzr.eot_token
-) # add the end of text token, e.g. 50256 for gpt2 bpe
-out = {"ids": ids, "len": len(ids)}
-return out
-
-# tokenize the dataset
-tokenized = split_dataset.map(
-process,
-remove_columns=["text"],
-desc="tokenizing the splits",
-num_proc=num_proc,
-)
-
-# concatenate all the ids in each dataset into one large file we can use for training
-for split, dset in tokenized.items():
-arr_len = np.sum(dset["len"])
-filename = os.path.join(SPJ_DATA_PATH, f"{split}.bin")
-dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16)
-arr = np.memmap(filename, dtype=dtype, mode="w+", shape=(arr_len,))
-total_batches = min(1024, len(dset))
-
-idx = 0
-for batch_idx in tqdm(range(total_batches), desc=f"writing {filename}"):
-# Batch together samples for faster write
-batch = dset.shard(
-num_shards=total_batches, index=batch_idx, contiguous=True
-).with_format("numpy")
-arr_batch = np.concatenate(batch["ids"])
-# Write into mmap
-arr[idx : idx + len(arr_batch)] = arr_batch
-idx += len(arr_batch)
-arr.flush()
-
-train_data = np.memmap(
-os.path.join(SPJ_DATA_PATH, "train.bin"), dtype=np.uint16, mode="r"
-)
-val_data = np.memmap(
-os.path.join(SPJ_DATA_PATH, "val.bin"), dtype=np.uint16, mode="r"
-)
-
-return {"train": train_data, "val": val_data}
+return {
+"train": os.path.join(SPJ_DATA_PATH, "train.bin"),
+"val": os.path.join(SPJ_DATA_PATH, "val.bin"),
+}
41 changes: 16 additions & 25 deletions src/data/utils.py
@@ -4,56 +4,47 @@

from .shakespeare import get_shakespeare_data
from .wikitext import get_wikitext_data
-from .arxiv import get_arxiv_2000, get_arxiv_full
from .openwebtext2 import get_openwebtext2_data
from .slimpajama import get_slimpajama_data


def get_dataset(args) -> Dict[str, np.ndarray]:
""" Fetch the right dataset given by the args.dataset parameter. The logic for each dataset is
contained in its own python file. The expected format at the moment is a dictionary of np.memmap
containing two keys: 'train' and 'val', corresponding to the tokenized training and validation data. """
if args.dataset == 'wikitext':
"""Fetch the right dataset given by the args.dataset parameter. The logic for each dataset is
contained in its own python file. The expected format at the moment is a dictionary of np.memmap
containing two keys: 'train' and 'val', corresponding to the tokenized training and validation data.
This just returns a dictionary of the paths to the np.memmap objects, and does not load the data into memory.
"""
if args.dataset == "wikitext":
return get_wikitext_data()
if args.dataset == "shakespeare-char":
return get_shakespeare_data()
if args.dataset == "arxiv2000":
return get_arxiv_2000()
if args.dataset == "arxiv":
return get_arxiv_full()
if args.dataset == "arxiv+wiki":
arxiv_data = get_arxiv_full()
wiki_data = get_wikitext_data()
train_data = np.concatenate((arxiv_data['train'], wiki_data['train']))
val_data = np.concatenate((arxiv_data['val'], wiki_data['val']))
return {'train': train_data, 'val': val_data}
if args.dataset == 'openwebtext2':
if args.dataset == "openwebtext2":
return get_openwebtext2_data()
if args.dataset == "slimpajama":
return get_slimpajama_data()
else:
raise NotImplementedError(f"Unknow dataset key '{args.dataset}'")


class Dataset(torch.utils.data.Dataset):
-def __init__(self, data, sequence_length):
+def __init__(self, data_path, sequence_length):
super().__init__()
-self.data = data
+self.data_path = data_path
self.sequence_length = sequence_length

def __len__(self):
-total_length = len(self.data)
+data = np.memmap(self.data_path, dtype=np.uint16, mode="r")
+total_length = len(data)
# chunk the data into sequences of length `sequence_length`
-# NOTE: we discard the last remainding sequence if it's not of length `sequence_length`
+# NOTE: we discard the last remaining sequence if it's not of length `sequence_length`
return (total_length - 1) // self.sequence_length

def __getitem__(self, idx):
+data = np.memmap(self.data_path, dtype=np.uint16, mode="r")
seq_length = self.sequence_length
idx = idx * seq_length
-x = torch.from_numpy((self.data[idx : idx + seq_length]).astype(np.int64))
-
-y = torch.from_numpy(
-(self.data[idx + 1 : idx + 1 + seq_length]).astype(np.int64)
-)
+x = torch.from_numpy((data[idx : idx + seq_length]).astype(np.int64))
+y = torch.from_numpy((data[idx + 1 : idx + 1 + seq_length]).astype(np.int64))
return x, y
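A hedged usage sketch for the reworked `Dataset` (path, batch size, and worker count are hypothetical): because `__len__` and `__getitem__` open the memmap themselves, the object handed to DataLoader workers only carries a path string, and each access maps the file briefly rather than keeping one mapping alive across the run.

```python
import torch

from src.data.utils import Dataset  # assumes the repo root is on PYTHONPATH

# hypothetical path to a tokenized .bin file produced by one of the loaders above
train_set = Dataset("src/data/datasets/slimpajama6B/train.bin", sequence_length=512)

train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=32,
    shuffle=True,   # draws non-overlapping sequence_length chunks in random order
    num_workers=4,  # workers re-open the memmap per item instead of inheriting one
)

for x, y in train_loader:
    # x, y are (32, 512) int64 tensors; y is x shifted by one token
    break
```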


5 changes: 1 addition & 4 deletions src/data/wikitext.py
@@ -36,7 +36,4 @@ def get_wikitext_data():
eval_tokenized.tofile(os.path.join(WIKITEXT_DATA_PATH, 'val.bin'))
print("completed the tokenization process!")

-train_data = np.memmap(os.path.join(WIKITEXT_DATA_PATH, 'train.bin'), dtype=np.uint16, mode='r')
-val_data = np.memmap(os.path.join(WIKITEXT_DATA_PATH, 'val.bin'), dtype=np.uint16, mode='r')
-
-return {'train': train_data, 'val': val_data}
+return {'train': os.path.join(WIKITEXT_DATA_PATH, 'train.bin'), 'val': os.path.join(WIKITEXT_DATA_PATH, 'val.bin')}
4 changes: 2 additions & 2 deletions src/main.py
@@ -48,8 +48,8 @@ def main(args):
if args.data_in_ram:
data = {'train': np.array(data['train']), 'val': np.array(data['val'])}

print(f"Num training tokens: {len(data['train'])}")
print(f"Num validation tokens: {len(data['val'])}")
print(f"Num training tokens: {len(np.memmap(data['train'], dtype=np.uint16, mode='r'))}")
print(f"Num validation tokens: {len(np.memmap(data['val'], dtype=np.uint16, mode='r'))}")

model = get_model(args).to(args.device) # todo: take care of initializing the model if args.use_pretrained != 'none'

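Not part of this diff, just a possible simplification: because the tokens are stored as `uint16`, the two counts printed above could also be derived from the file size alone, without creating even a transient memmap.

```python
import os
import numpy as np

def num_tokens_from_size(bin_path, dtype=np.uint16):
    # each stored token occupies np.dtype(dtype).itemsize bytes (2 for uint16)
    return os.path.getsize(bin_path) // np.dtype(dtype).itemsize
```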