[Test] Add end-to-end tests (#14)
chhzh123 authored Jan 25, 2023
1 parent 54af148 commit f6b8608
Showing 3 changed files with 73 additions and 1 deletion.
5 changes: 5 additions & 0 deletions ci/task_unit_test.sh
@@ -12,3 +12,8 @@ echo "Running unit tests..."
 # -r: redirect the output of local rank 1 to None so that
 # only local rank 0's output is printed to the console.
 torchrun --nproc_per_node 2 -r 1:1 -m pytest tests
+
+echo "Downloading test data..."
+bash benchmark/download_benchmark_dataset.sh
+echo "Running end-to-end tests..."
+python3 -m pytest -s tests/end2end.py
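Because the end-to-end tests spawn the benchmark script themselves (see tests/end2end.py below), they run under plain python3 rather than torchrun, and `-s` disables pytest's output capture so the benchmark logs stream to the CI console. The suite can also be invoked programmatically; a minimal sketch, assuming the dataset has already been downloaded:

import pytest

# Equivalent to `python3 -m pytest -s tests/end2end.py`.
pytest.main(["-s", "tests/end2end.py"])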
2 changes: 1 addition & 1 deletion conftest.py
@@ -15,7 +15,7 @@ def pytest_collection_modifyitems(items):
     items.sort(key=lambda item: item.name)


-@pytest.fixture(scope="session", autouse=True)
+@pytest.fixture(scope="session")
 def init_dist(request):
     """Initialize the distributed group once in the entire test session."""
     torch.manual_seed(9999)
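Dropping `autouse=True` means the session-scoped `init_dist` fixture no longer runs for every test session automatically: distributed unit tests must now request it by name, while the new end-to-end tests below, launched with plain python3 rather than torchrun, never trigger distributed initialization. A minimal sketch of how a unit test opts in (hypothetical test name; assumes the fixture initializes the default process group, per its docstring):

import torch.distributed as dist

def test_world_size(init_dist):
    # init_dist has already set up the default process group, so
    # collective metadata is queryable here.
    assert dist.get_world_size() >= 1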
67 changes: 67 additions & 0 deletions tests/end2end.py
@@ -0,0 +1,67 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

"""
End-to-end tests
"""

import os
import pytest

from slapo.model_dialect import get_dialect_cls


def parse_log(impl, log_file):
    with open(log_file, "r", encoding="utf-8") as f:
        text = f.read()

    if impl in {"slapo-megatron", "megatron"}:
        parser = get_dialect_cls("log_parser", "megatron")
        _, samples_per_sec, _, error_code = parser.parse_log(log_file)
    elif impl in {"slapo-deepspeed", "deepspeed"}:
        parser = get_dialect_cls("log_parser", "deepspeed")
        _, samples_per_sec, _, error_code = parser.parse_log(log_file)
    else:
        raise RuntimeError(f"Unknown `impl`: {impl}")
    return (error_code, samples_per_sec, text)
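
# Usage sketch: parse_log("slapo-megatron", "log.txt") returns
# (error_code, samples_per_sec, raw_text); as interpreted by
# test_end2end below, error_code 0 is success, 1 is out-of-memory,
# and 2 is any other failure.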


# fmt: off
@pytest.mark.parametrize("model,impl,n_gpu,batch_size,seq_len,ckpt_ratio", [
    ("wideresnet-250M", "slapo-megatron", "1", "48", "512", "0.34"),
    ("wideresnet-250M", "slapo-deepspeed", "4", "256", "512", "0.67"),
    ("bert-large-uncased", "slapo-megatron", "2", "10", "512", "0"),
    ("bert-large-uncased", "slapo-deepspeed", "2", "28", "512", "0"),
    ("EleutherAI/gpt-neo-125M", "slapo-megatron", "2", "1", "512", "1.0"),
    ("t5-base", "slapo-megatron", "4", "8", "1024", "0.67"),
])
# fmt: on
def test_end2end(model, impl, n_gpu, batch_size, seq_len, ckpt_ratio):
    print(f"Running {impl} on {model} with {n_gpu} GPU(s)", flush=True)
    # The megatron/deepspeed baselines always run with full activation
    # checkpointing.
    if impl == "deepspeed":
        ckpt_ratio = "1.0"
    elif impl == "megatron":
        ckpt_ratio = "full"
    cmd = f"python3 benchmark/bench_single_node.py {impl}"
    cmd += f" --model {model} --gpus {n_gpu} --seq-len {seq_len}"
    if "t5" in model:
        cmd += " --seq-len-dec 512"
    cmd += f" --batch-size {batch_size}"
    cmd += f" --gradient-checkpoint {ckpt_ratio}"
    cmd += " > run_script.log 2>&1"
    print(cmd, flush=True)
    os.system(cmd)
    print("\n", flush=True)
    # The raw console output goes to run_script.log; the benchmark script
    # is expected to write its parsed summary to log.txt.
    error_code, samples_per_sec, text = parse_log(impl, "log.txt")
    print(f"\tThroughput: {samples_per_sec:.2f}")
    if error_code == 1:
        print("oom")
        print(text)
    elif error_code == 2:
        print("fail")
        print(text)
    assert error_code == 0


if __name__ == "__main__":
    pytest.main([__file__])
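A single configuration can be selected with pytest's `-k` keyword filter, since the parameter values appear in each generated test id. A sketch (hypothetical invocation; the substrings follow pytest's default parametrize naming):

import pytest

# Run only the BERT + slapo-megatron case; -s keeps the benchmark output
# visible instead of letting pytest capture it.
pytest.main(["-s", "-k", "bert-large-uncased and slapo-megatron", "tests/end2end.py"])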
