
[Autoshard] Auto-parallelism solver #96

Open · wants to merge 36 commits into main
Conversation

@chhzh123 chhzh123 commented May 22, 2023

Description

This PR introduces an auto-parallelism solver that finds the optimal sharding scheme for a given model. It models the parallelism scheme of each tensor as a combination of "R" and "S" specs, where "R" denotes "replicated" and "S" denotes "sharded". We can explicitly calculate the computation and resharding cost (#95) of each operator, and sum all the costs to form an optimization problem.
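As a rough illustration of that cost model, here is a toy sketch in plain Python. The spec strings ("RR", "SR", "RS") follow the PR's notation, but the cost formulas, device count, and function names (`resharding_cost`, `total_cost`) are illustrative assumptions, not the actual solver.py API, and the units do not match the log below.

```python
# Toy sketch of the R/S cost model -- formulas and names are assumptions,
# not the actual solver.py implementation.
from math import prod

def resharding_cost(shape, src, dst, num_devices=2):
    """Elements communicated to convert a tensor from spec `src` to `dst`.

    Toy rules: matching specs are free; replicated -> sharded is a local
    slice (free); sharded -> replicated is an all-gather; switching the
    sharded dimension moves the whole tensor (all-to-all-like).
    """
    if src == dst:
        return 0
    n = prod(shape)
    if src == "RR":                       # slice locally, no communication
        return 0
    if dst == "RR":                       # all-gather the missing shards
        return n * (num_devices - 1) // num_devices
    return n                              # e.g. SR -> RS

def total_cost(edges):
    """Sum resharding costs over (shape, src_spec, dst_spec) edges."""
    return sum(resharding_cost(*e) for e in edges)

# Edges loosely modeled on the MLP trace below (units differ from the log).
edges = [
    ((8, 1024, 1024), "RR", "SR"),   # x into fc1: free slice
    ((8, 1024, 4096), "SR", "SR"),   # gelu output into fc2: already aligned
    ((8, 1024, 1024), "SR", "RR"),   # fc2 output back to replicated
]
print(total_cost(edges))             # → 4194304
```

Summing such per-edge costs over the whole graph yields a single objective that the solver minimizes over all spec assignments.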

The problem is then encoded as a program synthesis problem and solved with the z3 solver using counter-example guided synthesis. The detailed process can be found in solver.py.
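The counter-example guided loop can be sketched as follows. This is a pure-Python stand-in, not the z3-backed implementation: with z3 the inner search is a solver query under the constraint `cost < best`, whereas here it is exhaustive enumeration. The function name and toy cost function are hypothetical.

```python
# Pure-Python sketch of the counter-example guided minimization loop
# (the real solver uses z3 queries instead of exhaustive enumeration).
from itertools import product

def cegis_minimize(cost_fn, domains):
    """Iteratively tighten an upper bound on cost, mirroring the
    Iter 0..N progression in the log: each round asks for *any*
    assignment strictly cheaper than the best found so far, and
    terminates when no such counter-example exists.
    """
    best, best_cost = None, float("inf")
    while True:
        for assignment in product(*domains):
            c = cost_fn(assignment)
            if c < best_cost:
                best, best_cost = assignment, c
                break            # cheaper scheme found: tighten the bound
        else:
            return best, best_cost   # "Cannot find better solutions"

# Toy cost over two spec choices; (2, 2) is the cheapest assignment.
cost = lambda a: {(2, 2): 1}.get(a, 10 + a[0] + a[1])
print(cegis_minimize(cost, [range(3), range(3)]))   # → ((2, 2), 1)
```

Each outer iteration corresponds to one "Iter k" block in the log below, and the final `else` branch corresponds to the "Cannot find better solutions" message.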

A sample output for a two-layer MLP is shown below. It shows how the solver finds the optimal scheme step by step, and the cost of each scheme is dumped so users can better reason about the tradeoffs.

$ python3 tests/test_autoshard.py 
[2023-05-22 06:19:31,530][INFO][solver.py:339:dump_fx_node] 
 name      op             target                                    shape            dtype
--------  -------------  ----------------------------------------  ---------------  -------------
x         placeholder    x                                         [8, 1024, 1024]  torch.float32
fc1       call_module    <class 'torch.nn.modules.linear.Linear'>  [8, 1024, 4096]  torch.float32
|-weight                                                           [4096, 1024]     torch.float32
|-bias                                                             [4096]           torch.float32
gelu      call_function  <built-in function gelu>                  [8, 1024, 4096]  torch.float32
fc2       call_module    <class 'torch.nn.modules.linear.Linear'>  [8, 1024, 1024]  torch.float32
|-weight                                                           [1024, 4096]     torch.float32
|-bias                                                             [1024]           torch.float32
output    output         output                                    [8, 1024, 1024]  torch.float32 

[2023-05-22 06:19:31,553][INFO][solver.py:583:solve] =================== Iter 0 ===================
[2023-05-22 06:19:31,556][INFO][solver.py:594:solve] [fc1_1 = 0, fc1_0 = 2, fc2_0 = 1, fc2_1 = 2]
[2023-05-22 06:19:31,563][INFO][solver.py:517:calculate_new_cost] 
 Name    InSpec    OutSpec    Cost
------  --------  ---------  -------
fc1     SRxRR     SR         0
|-x     RR        SR         0
fc2     RSxSR     RR         1048576
|-gelu  SR        RS         458752
output  RR        RR         0
Total                        1507328 

[2023-05-22 06:19:31,563][INFO][solver.py:583:solve] =================== Iter 1 ===================
[2023-05-22 06:19:31,564][INFO][solver.py:594:solve] [fc1_1 = 1, fc1_0 = 0, fc2_0 = 1, fc2_1 = 2]
[2023-05-22 06:19:31,571][INFO][solver.py:517:calculate_new_cost] 
 Name    InSpec    OutSpec    Cost
------  --------  ---------  -------
fc1     RRxRS     RS         0
|-x     RR        RR         0
fc2     RSxSR     RR         1048576
|-gelu  RS        RS         0
output  RR        RR         0
Total                        1048576 

[2023-05-22 06:19:31,571][INFO][solver.py:583:solve] =================== Iter 2 ===================
[2023-05-22 06:19:31,573][INFO][solver.py:594:solve] [fc1_1 = 1, fc1_0 = 0, fc2_0 = 0, fc2_1 = 1]
[2023-05-22 06:19:31,579][INFO][solver.py:517:calculate_new_cost] 
 Name    InSpec    OutSpec    Cost
------  --------  ---------  ------
fc1     RRxRS     RS         0
|-x     RR        RR         0
fc2     RRxRS     RS         0
|-gelu  RS        RR         524288
output  RS        RR         131072
Total                        655360 

[2023-05-22 06:19:31,580][INFO][solver.py:583:solve] =================== Iter 3 ===================
[2023-05-22 06:19:31,581][INFO][solver.py:594:solve] [fc1_1 = 1, fc1_0 = 0, fc2_0 = 2, fc2_1 = 0]
[2023-05-22 06:19:31,588][INFO][solver.py:517:calculate_new_cost] 
 Name    InSpec    OutSpec    Cost
------  --------  ---------  ------
fc1     RRxRS     RS         0
|-x     RR        RR         0
fc2     SRxRR     SR         0
|-gelu  RS        SR         458752
output  SR        RR         131072
Total                        589824 

[2023-05-22 06:19:31,588][INFO][solver.py:583:solve] =================== Iter 4 ===================
[2023-05-22 06:19:31,589][INFO][solver.py:594:solve] [fc1_1 = 0, fc1_0 = 2, fc2_0 = 2, fc2_1 = 0]
[2023-05-22 06:19:31,596][INFO][solver.py:517:calculate_new_cost] 
 Name    InSpec    OutSpec    Cost
------  --------  ---------  ------
fc1     SRxRR     SR         0
|-x     RR        SR         0
fc2     SRxRR     SR         0
|-gelu  SR        SR         0
output  SR        RR         131072
Total                        131072 

[2023-05-22 06:19:31,596][INFO][solver.py:583:solve] =================== Iter 5 ===================
[2023-05-22 06:19:31,597][INFO][solver.py:590:solve] Cannot find better solutions

Best solution:
sch["fc1"].sync(mode="fwd_pre", sync_op_or_fn="RR->SR")
sch["fc2"].sync(mode="fwd_post", sync_op_or_fn="SR->RR")

Checklist

  • Support attention module
  • Add HF model tests

The autosharder is still at an early stage and requires more rigorous testing, but I would like to first gather suggestions on the interface and code organization.

cc @comaniac @zarzen @whbldhwj

@chhzh123
Contributor Author

To enable testing the HF models, #94 needs to be merged first.
