diff --git a/test/test_transformer.py b/test/test_transformer.py
index 899f68c55..cee785ac0 100644
--- a/test/test_transformer.py
+++ b/test/test_transformer.py
@@ -1,6 +1,8 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
 import torch
 from pippy import annotate_split_points, Pipe, SplitPoint
+import torch.distributed.checkpoint as dcp
+import tempfile
 
 d_hid = 16
 
@@ -66,6 +68,49 @@ def get_layers(module):
     return layers
 
 
+def pipe_to_sd(pipe):
+    sd = {}
+    for stage_idx in range(pipe.num_stages):
+        stage_mod = pipe.get_stage_module(stage_idx)
+        sd[f"stage_{stage_idx}"] = stage_mod
+    return sd
+
+with tempfile.TemporaryDirectory() as tmpdir:
+    # Simulate saving the pipe
+    # Option 1:
+    # for stage_idx in range(pipe.num_stages):
+    #     print(f"Saving pipeline stage {stage_idx}")
+    #     stage_mod = pipe.get_stage_module(stage_idx)
+    #     dcp.save(
+    #         {f"stage_{stage_idx}": stage_mod},
+    #         checkpoint_id=f"{tmpdir}_{stage_idx}"
+    #     )
+    # Option 2:
+    sd = pipe_to_sd(pipe)
+    dcp.save(sd, checkpoint_id=tmpdir)
+
+
+    # Simulate loading the pipe
+    # Option 1:
+    # for stage_idx in range(pipe.num_stages):
+    #     print(f"Loading pipeline stage {stage_idx}")
+    #     stage_mod = pipe.get_stage_module(stage_idx)
+    #     dcp.load(
+    #         {f"stage_{stage_idx}": stage_mod},
+    #         checkpoint_id=f"{tmpdir}_{stage_idx}"
+    #     )
+
+    # Option 2:
+    new_pipe = Pipe.from_tracing(
+        transformer,
+        1,
+        (x,),
+    )
+    sd = pipe_to_sd(new_pipe)
+    dcp.load(sd, checkpoint_id=tmpdir)
+
+pipe = new_pipe
+
 # Collect all layers in pipe
 layers = []
 for stage_idx in range(pipe.num_stages):
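
As a sanity check on the save/load round trip, the restored parameters can be compared stage by stage against the originals. A minimal sketch, assuming the `pipe`, `new_pipe`, and `pipe_to_sd` names defined in the patch above; this comparison is not part of the test itself and would need to run before the `pipe = new_pipe` reassignment:

# Hypothetical round-trip check, not part of the patch above.
# Verifies that every stage module loaded from the DCP checkpoint
# carries the same parameter values as the original pipe's stages.
for stage_idx in range(pipe.num_stages):
    saved = pipe.get_stage_module(stage_idx).state_dict()
    loaded = new_pipe.get_stage_module(stage_idx).state_dict()
    assert saved.keys() == loaded.keys()
    for name, tensor in saved.items():
        torch.testing.assert_close(tensor, loaded[name])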