[Test] Add tracer unit tests (#16)
chhzh123 authored Jan 25, 2023
1 parent 681f2b8 commit 36c0f46
Showing 2 changed files with 50 additions and 52 deletions.
8 changes: 7 additions & 1 deletion slapo/schedule.py
@@ -193,9 +193,15 @@ def __getitem__(self, full_path):
 
         curr_sch = self
         for token in self.tokenize_module_path(full_path):
+            sub_tokens = token.split(".")
+            if len(sub_tokens) == 2 and sub_tokens[0] in curr_sch.child:
+                # If this token is of the form "layer.0" and "layer" is a child of curr_sch,
+                # then "layer" is an nn.Sequential, so we have to get that module first.
+                curr_sch = curr_sch.child[sub_tokens[0]]
+                token = sub_tokens[1]
             if token not in curr_sch.child:
                 raise KeyError(
-                    f"The schedule of '{full_path}' is not a child of {curr_sch.name}"
+                    f"The schedule of '{full_path}' ({token}) is not a child of {curr_sch.name}"
                 )
             curr_sch = curr_sch.child[token]
         if not curr_sch:
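Note: a minimal sketch of the indexing behavior this change enables. The Net class and its "blocks" child are hypothetical names for illustration, not part of this commit:

import torch.nn as nn
import slapo

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # An nn.Sequential child: its submodules get names like "blocks.0".
        self.blocks = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))

    def forward(self, x):
        return self.blocks(x)

sch = slapo.create_schedule(Net())
# "blocks.0" arrives as a single path token; __getitem__ splits it, resolves
# the nn.Sequential child "blocks" first, and then looks up its child "0".
sub_sch = sch["blocks.0"]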
94 changes: 43 additions & 51 deletions tests/test_trace.py
@@ -20,94 +20,86 @@ def generate_concrete_args(model, input_names):
     return concrete_args
 
 
-@pytest.mark.skip(reason="Need update")
 def test_hf_bert():
     """Test tracing HF bert model."""
     from transformers import AutoConfig, BertModel
 
     config = AutoConfig.from_pretrained("bert-base-uncased")
     model = BertModel(config)
-    input_names = list(model.dummy_inputs.keys())  # only has "input_ids"
-    input_names += ["attention_mask", "token_type_ids"]  # "position_ids"
-    concrete_args = generate_concrete_args(model, input_names)
-    sch = slapo.create_schedule(
-        model,
-        world_size=1,
-        rank=0,
-        tracer="huggingface",
-        concrete_args=concrete_args,
-    )
+    sch = slapo.create_schedule(model)
 
     # The original module list.
-    assert isinstance(sch.get_module("encoder"), torch.nn.Module)
+    assert isinstance(sch["encoder"].mod, torch.nn.Module)
+    assert isinstance(sch["encoder.layer.0"].mod, torch.nn.Module)
+    assert isinstance(sch["encoder.layer.0.attention"].mod, torch.nn.Module)
 
-    # Traced layers.
-    assert isinstance(sch.get_module("encoder.layer.0"), fx.GraphModule)
-    assert isinstance(sch.get_module("encoder.layer.0.attention"), fx.GraphModule)
-    # self will be renamed to self_m because it is a Python reserved keyword.
-    assert isinstance(
-        sch.get_module("encoder.layer.0.attention.self_m"), fx.GraphModule
-    )
-    assert isinstance(sch.get_module("encoder.layer.0.intermediate"), fx.GraphModule)
-    assert isinstance(sch.get_module("encoder.layer.0.output"), fx.GraphModule)
+    sub_sch = sch["encoder.layer.0.attention"]
+    input_names = ["hidden_states", "attention_mask"]
+    sig = inspect.signature(sub_sch.mod.forward)
+    concrete_args = {
+        p.name: p.default for p in sig.parameters.values() if p.name not in input_names
+    }
+    sub_sch.trace(tracer="pytorch", concrete_args=concrete_args)
+
+    # Only the traced submodules are graph modules.
+    assert isinstance(sch["encoder.layer.0.attention"].mod, fx.GraphModule)
+    assert isinstance(sch["encoder.layer.0.attention.self"].mod, fx.GraphModule)
+    assert isinstance(sch["encoder.layer.0.attention.output"].mod, fx.GraphModule)
+
+    # Other modules remain the same.
+    assert isinstance(sch["encoder.layer.0.intermediate"].mod, torch.nn.Module)
+    assert isinstance(sch["encoder.layer.0.output"].mod, torch.nn.Module)
 
 
-@pytest.mark.skip(reason="Need update")
 def test_hf_gpt_neo():
     """Test tracing HF gpt-neo model."""
     from transformers import AutoConfig, GPTNeoModel
 
     config = AutoConfig.from_pretrained("EleutherAI/gpt-neo-125M")
     model = GPTNeoModel(config)
 
-    input_names = list(model.dummy_inputs.keys())
-    input_names += ["attention_mask", "position_ids"]
-    concrete_args = generate_concrete_args(model, input_names)
-    sch = slapo.create_schedule(
-        model,
-        world_size=1,
-        rank=0,
-        tracer="huggingface",
-        concrete_args=concrete_args,
-    )
+    sch = slapo.create_schedule(model)
 
     # The original module list.
-    assert isinstance(sch.get_module("h"), torch.nn.Module)
+    assert isinstance(sch["h.0"].mod, torch.nn.Module)
 
-    # Traced layers.
-    assert isinstance(sch.get_module("h.0"), fx.GraphModule)
-    assert isinstance(sch.get_module("h.0.attn"), fx.GraphModule)
-    assert isinstance(sch.get_module("h.0.attn.attention"), fx.GraphModule)
-    assert isinstance(sch.get_module("h.0.mlp"), fx.GraphModule)
+    sub_sch = sch["h.0"]
+    input_names = ["hidden_states", "attention_mask"]
+    sig = inspect.signature(sub_sch.mod.forward)
+    concrete_args = {
+        p.name: p.default for p in sig.parameters.values() if p.name not in input_names
+    }
+    sub_sch.trace(tracer="pytorch", concrete_args=concrete_args)
+    assert isinstance(sch["h.0"].mod, fx.GraphModule)
+    assert isinstance(sch["h.0.attn"].mod, fx.GraphModule)
+    assert isinstance(sch["h.0.mlp"].mod, fx.GraphModule)
 
+    # Attention submodule cannot be traced.
+    assert isinstance(sch["h.0.attn.attention"].mod, torch.nn.Module)
 
 
-@pytest.mark.skip(reason="Need update")
 def test_torchvision_wideresnet():
     """Test tracing torchvision wideresnet model."""
     from torchvision.models.resnet import Bottleneck, ResNet
 
     model = ResNet(Bottleneck, [6, 8, 4, 6], width_per_group=128)
     concrete_args = generate_concrete_args(model, ["x"])
-    sch = slapo.create_schedule(
-        model,
-        world_size=1,
-        rank=0,
-        tracer="pytorch",
-        concrete_args=concrete_args,
-    )
-
-    assert isinstance(sch.get_module("layer1"), fx.GraphModule)
+    sch = slapo.create_schedule(model)
+    sch.trace(tracer="pytorch", concrete_args=concrete_args)
 
+    assert isinstance(sch["layer1"].mod, fx.GraphModule)
     for idx in range(6):
-        assert isinstance(sch.get_module(f"layer1.{idx}"), fx.GraphModule)
-
-    # Should not trace leaf.
-    assert not isinstance(sch.get_module("layer1.0.conv1"), fx.GraphModule)
+        # Should not trace leaf.
+        assert not isinstance(sch[f"layer1.{idx}.conv1"].mod, fx.GraphModule)
 
     # Should have "layer1.0.downsample" instead of "layer1.0.downsample.0"
-    assert isinstance(sch.get_module("layer1.0.downsample"), fx.GraphModule)
+    assert isinstance(sch["layer1.0.downsample"].mod, fx.GraphModule)
 
     # Should not trace leaf.
-    assert not isinstance(sch.get_module("layer1.0.downsample.0"), fx.GraphModule)
+    assert not isinstance(sch["layer1.0.downsample.0"].mod, fx.GraphModule)
 
 
 if __name__ == "__main__":
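Note: the body of the generate_concrete_args helper is collapsed above. Based on the identical inline pattern used in test_hf_bert and test_hf_gpt_neo, a plausible sketch (an assumption, not the commit's verbatim code) is:

import inspect

def generate_concrete_args(model, input_names):
    # Sketch reconstructed from the tests' inline pattern: bind every forward()
    # parameter that is not a real input to its default value, so the fx tracer
    # treats it as a constant instead of a graph placeholder.
    sig = inspect.signature(model.forward)
    concrete_args = {
        p.name: p.default for p in sig.parameters.values() if p.name not in input_names
    }
    return concrete_args

Passing these concrete_args to trace lets torch.fx skip the optional arguments the tests do not exercise.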
