diff --git a/slapo/schedule.py b/slapo/schedule.py
index be586b6e..67bb7108 100644
--- a/slapo/schedule.py
+++ b/slapo/schedule.py
@@ -193,9 +193,15 @@
     def __getitem__(self, full_path):
         curr_sch = self
         for token in self.tokenize_module_path(full_path):
+            sub_tokens = token.split(".")
+            if len(sub_tokens) == 2 and sub_tokens[0] in curr_sch.child:
+                # If this token has the form "layer.0" and "layer" is a child of curr_sch,
+                # then "layer" is an nn.Sequential, so we have to get the nn.Sequential module first.
+                curr_sch = curr_sch.child[sub_tokens[0]]
+                token = sub_tokens[1]
             if token not in curr_sch.child:
                 raise KeyError(
-                    f"The schedule of '{full_path}' is not a child of {curr_sch.name}"
+                    f"The schedule of '{full_path}' ({token}) is not a child of {curr_sch.name}"
                 )
             curr_sch = curr_sch.child[token]
         if not curr_sch:
diff --git a/tests/test_trace.py b/tests/test_trace.py
index 99b5390c..11ba798c 100644
--- a/tests/test_trace.py
+++ b/tests/test_trace.py
@@ -20,39 +20,37 @@ def generate_concrete_args(model, input_names):
     return concrete_args


-@pytest.mark.skip(reason="Need update")
 def test_hf_bert():
     """Test tracing HF bert model."""
     from transformers import AutoConfig, BertModel

     config = AutoConfig.from_pretrained("bert-base-uncased")
     model = BertModel(config)
-    input_names = list(model.dummy_inputs.keys())  # only has "input_ids"
-    input_names += ["attention_mask", "token_type_ids"]  # "position_ids"
-    concrete_args = generate_concrete_args(model, input_names)
-    sch = slapo.create_schedule(
-        model,
-        world_size=1,
-        rank=0,
-        tracer="huggingface",
-        concrete_args=concrete_args,
-    )
+    sch = slapo.create_schedule(model)

     # The original module list.
-    assert isinstance(sch.get_module("encoder"), torch.nn.Module)
+    assert isinstance(sch["encoder"].mod, torch.nn.Module)
+    assert isinstance(sch["encoder.layer.0"].mod, torch.nn.Module)
+    assert isinstance(sch["encoder.layer.0.attention"].mod, torch.nn.Module)

-    # Traced layers.
-    assert isinstance(sch.get_module("encoder.layer.0"), fx.GraphModule)
-    assert isinstance(sch.get_module("encoder.layer.0.attention"), fx.GraphModule)
-    # self will be renamed to self_m because it is a Python preserved keyword.
-    assert isinstance(
-        sch.get_module("encoder.layer.0.attention.self_m"), fx.GraphModule
-    )
-    assert isinstance(sch.get_module("encoder.layer.0.intermediate"), fx.GraphModule)
-    assert isinstance(sch.get_module("encoder.layer.0.output"), fx.GraphModule)
+    sub_sch = sch["encoder.layer.0.attention"]
+    input_names = ["hidden_states", "attention_mask"]
+    sig = inspect.signature(sub_sch.mod.forward)
+    concrete_args = {
+        p.name: p.default for p in sig.parameters.values() if p.name not in input_names
+    }
+    sub_sch.trace(tracer="pytorch", concrete_args=concrete_args)
+
+    # Only the traced submodules are graph modules.
+    assert isinstance(sch["encoder.layer.0.attention"].mod, fx.GraphModule)
+    assert isinstance(sch["encoder.layer.0.attention.self"].mod, fx.GraphModule)
+    assert isinstance(sch["encoder.layer.0.attention.output"].mod, fx.GraphModule)
+
+    # Other modules remain the same.
+ assert isinstance(sch["encoder.layer.0.intermediate"].mod, torch.nn.Module) + assert isinstance(sch["encoder.layer.0.output"].mod, torch.nn.Module) -@pytest.mark.skip(reason="Need update") def test_hf_gpt_neo(): """Test tracing HF gpt-neo model.""" from transformers import AutoConfig, GPTNeoModel @@ -60,54 +58,48 @@ def test_hf_gpt_neo(): config = AutoConfig.from_pretrained("EleutherAI/gpt-neo-125M") model = GPTNeoModel(config) - input_names = list(model.dummy_inputs.keys()) - input_names += ["attention_mask", "position_ids"] - concrete_args = generate_concrete_args(model, input_names) - sch = slapo.create_schedule( - model, - world_size=1, - rank=0, - tracer="huggingface", - concrete_args=concrete_args, - ) + sch = slapo.create_schedule(model) # The original module list. - assert isinstance(sch.get_module("h"), torch.nn.Module) + assert isinstance(sch["h.0"].mod, torch.nn.Module) # Traced layers. - assert isinstance(sch.get_module("h.0"), fx.GraphModule) - assert isinstance(sch.get_module("h.0.attn"), fx.GraphModule) - assert isinstance(sch.get_module("h.0.attn.attention"), fx.GraphModule) - assert isinstance(sch.get_module("h.0.mlp"), fx.GraphModule) + sub_sch = sch["h.0"] + input_names = ["hidden_states", "attention_mask"] + sig = inspect.signature(sub_sch.mod.forward) + concrete_args = { + p.name: p.default for p in sig.parameters.values() if p.name not in input_names + } + sub_sch.trace(tracer="pytorch", concrete_args=concrete_args) + assert isinstance(sch["h.0"].mod, fx.GraphModule) + assert isinstance(sch["h.0.attn"].mod, fx.GraphModule) + assert isinstance(sch["h.0.mlp"].mod, fx.GraphModule) + + # Attention submodule cannot be traced. + assert isinstance(sch["h.0.attn.attention"].mod, torch.nn.Module) -@pytest.mark.skip(reason="Need update") def test_torchvision_wideresnet(): """Test tracing torchvision wideresnet model.""" from torchvision.models.resnet import Bottleneck, ResNet model = ResNet(Bottleneck, [6, 8, 4, 6], width_per_group=128) concrete_args = generate_concrete_args(model, ["x"]) - sch = slapo.create_schedule( - model, - world_size=1, - rank=0, - tracer="pytorch", - concrete_args=concrete_args, - ) - - assert isinstance(sch.get_module("layer1"), fx.GraphModule) + sch = slapo.create_schedule(model) + sch.trace(tracer="pytorch", concrete_args=concrete_args) + + assert isinstance(sch["layer1"].mod, fx.GraphModule) for idx in range(6): assert isinstance(sch.get_module(f"layer1.{idx}"), fx.GraphModule) - # Should not trace leaf. - assert not isinstance(sch.get_module("layer1.0.conv1"), fx.GraphModule) + # Should not trace leaf. + assert not isinstance(sch[f"layer1.{idx}.conv1"].mod, fx.GraphModule) # Should have "layer1.0.downsample" instead of "layer1.0.downsample.0" - assert isinstance(sch.get_module("layer1.0.downsample"), fx.GraphModule) + assert isinstance(sch["layer1.0.downsample"].mod, fx.GraphModule) # Should not trace leaf. - assert not isinstance(sch.get_module("layer1.0.downsample.0"), fx.GraphModule) + assert not isinstance(sch["layer1.0.downsample.0"].mod, fx.GraphModule) if __name__ == "__main__":