add test for line by line distributed dataset
philippguevorguian committed Mar 20, 2024
1 parent 7c9cc13 commit 82ee745
Showing 2 changed files with 134 additions and 0 deletions.
114 changes: 114 additions & 0 deletions tests/dataset/distributed_dataset_iter.py
@@ -0,0 +1,114 @@
import os
import multiprocessing
import glob
import shutil
from datetime import timedelta
import torch
import signal
import random
import numpy
from accelerate import Accelerator, logging, InitProcessGroupKwargs
import hashlib
from datasets.iterable_dataset import IterableDataset
from chemlactica.jsonl_dataset import samples_generator
from chemlactica.utils.utils import (
    signal_handler,
)

from accelerate.state import PartialState

distributed_state = PartialState()
torch.manual_seed(42)
random.seed(42)
numpy.random.seed(42)
logger = logging.get_logger("transformers")

signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)


def run():
    # Long process-group timeout so slower ranks don't drop out while the
    # main process writes the test files.
    kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=7200))
    accelerator = Accelerator(
        kwargs_handlers=[kwargs], log_with="all", project_dir=None
    )

    directory_name = ".tmp"
    if distributed_state.is_main_process:
        if os.path.exists(directory_name):
            print(f"test directory '{directory_name}' already exists. Clearing it now")
            shutil.rmtree(directory_name)
        os.makedirs(directory_name)
        print(f"test directory '{directory_name}' created successfully.")
        num_files = 5
        num_lines = 100

        # Fill each file with unique, deterministic lines (SHA3 of a
        # file/line identifier) so drops or duplicates are detectable later.
        for i in range(num_files):
            file_name = os.path.join(directory_name, f"test_file_{i}.jsonl")
            with open(file_name, "w") as file:
                for j in range(num_lines):
                    sha3_hash = hashlib.sha3_256(
                        str.encode(f"test_file_{i}.jsonl - Line {j}")
                    ).hexdigest()
                    file.write(f"{sha3_hash}\n")
    accelerator.wait_for_everyone()
    with multiprocessing.Manager() as manager:
        # Manager-backed dict lets the data-loading processes share read
        # progress per file.
        shared_jsonl_files = manager.dict()

        training_data_files = glob.glob(directory_name + "/*.jsonl")
        training_data_files = [os.path.abspath(path) for path in training_data_files]

        print(training_data_files)

        accelerator.wait_for_everyone()
        dataset = IterableDataset.from_generator(
            samples_generator,
            gen_kwargs={
                "files": training_data_files,
                "shared_jsonl_files": shared_jsonl_files,
            },
        )

        # Each rank appends every sample it receives to its own output file.
        file_name_mapping = {}
        for process_index in range(distributed_state.num_processes):
            file_name_mapping[process_index] = f"dataload_proc{process_index}.jsonl"

        for example in dataset:
            file_to_write = file_name_mapping[distributed_state.process_index]
            with open(file_to_write, "a") as f:
                f.write(example["text"] + "\n")
        accelerator.wait_for_everyone()
        file_line_sets = []

        if distributed_state.is_main_process:
            for process_index, file_to_check in file_name_mapping.items():
                file_lines = load_file_contents(file_to_check)
                file_line_set = set(file_lines)
                file_line_sets.append(file_line_set)
                print(f"file line set length {len(file_line_set)}")
                print(f"file line length {len(file_lines)}")
                # No rank should have read the same line twice.
                assert len(file_lines) == len(file_line_set)

            # No two ranks should have read the same line.
            num_sets = len(file_line_sets)
            for i in range(num_sets):
                for j in range(i + 1, num_sets):
                    set1 = file_line_sets[i]
                    set2 = file_line_sets[j]
                    assert set1.isdisjoint(set2)

        accelerator.wait_for_everyone()
        if distributed_state.is_main_process:
            for process_index in file_name_mapping:
                file_to_check = file_name_mapping[process_index]
                os.remove(file_to_check)


def load_file_contents(file):
    with open(file, "r") as f:
        lines = [line.strip() for line in f.readlines()]
    return lines


if __name__ == "__main__":
    run()
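
The test exercises samples_generator only through its interface: given a list of .jsonl paths and a manager-backed dict, it must yield every line exactly once across all ranks, as dicts with a "text" key — that is what the uniqueness and disjointness assertions above verify. A minimal sketch of a generator meeting that contract, assuming simple round-robin dealing by rank; this is an illustration, not the actual implementation in chemlactica/jsonl_dataset.py:

# Illustrative sketch only: assumes round-robin sharding by rank; the real
# samples_generator in chemlactica/jsonl_dataset.py may shard differently.
def samples_generator(files, shared_jsonl_files):
    from accelerate.state import PartialState

    state = PartialState()
    line_index = 0
    for path in files:
        with open(path, "r") as f:
            while True:
                line = f.readline()
                if not line:
                    break
                # Deal lines round-robin so each rank sees a disjoint subset.
                if line_index % state.num_processes == state.process_index:
                    yield {"text": line.rstrip("\n")}
                # Record read progress so a restart could resume mid-file.
                shared_jsonl_files[path] = f.tell()
                line_index += 1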
20 changes: 20 additions & 0 deletions tests/dataset/test_line_by_line_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import unittest
import subprocess
from test_utils import create_train_command


class TestLineByLineDataloader(unittest.TestCase):
    def test_line_by_line_dataloader(self):
        command = create_train_command(
            module="accelerate.commands.launch",
            module_args={
                "config_file": "chemlactica/config/test_configs/fsdp_config.yaml"
            },
            script="tests/dataset/distributed_dataset_iter.py",
            script_args={},
        )

        print(f"Running command: {command}")
        # Capture output so stderr can be surfaced on failure; with
        # capture_output=False, out.stderr is None and .decode() would crash.
        out = subprocess.run(command, shell=True, capture_output=True)
        if out.returncode != 0:
            raise Exception(out.stderr.decode())
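
For context, a hedged sketch of what create_train_command plausibly assembles — the real helper lives in the repo's test_utils, and everything below is an assumption shown only to make the launched command concrete:

# Hypothetical reconstruction of create_train_command; the actual helper in
# test_utils may differ.
def create_train_command(module, module_args, script, script_args):
    module_args_str = " ".join(f"--{k} {v}" for k, v in module_args.items())
    script_args_str = " ".join(f"--{k} {v}" for k, v in script_args.items())
    return f"python3 -m {module} {module_args_str} {script} {script_args_str}".strip()

With the arguments used in the test, this would produce roughly: python3 -m accelerate.commands.launch --config_file chemlactica/config/test_configs/fsdp_config.yaml tests/dataset/distributed_dataset_iter.py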
