@@ -56,7 +56,8 @@ class WorkerExtension(Worker):
     def init_weight_update_group(self, master_address, master_port,
                                  rank_offset, world_size):
         from vllm.distributed.parallel_state import get_world_group
-        rank = get_world_group().rank + rank_offset
+        # rank = get_world_group().rank + rank_offset
+        rank = rank_offset
         self.model_update_group = stateless_init_process_group(
             master_address,
             master_port,
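
For reference, stateless_init_process_group is called above but not defined in this diff. A minimal sketch of what it typically looks like, modeled on vLLM's RLHF example (the import paths are an assumption and may differ across vLLM versions):

def stateless_init_process_group(master_address, master_port, rank, world_size, device):
    # Assumed helper: create a process group that does not touch torch.distributed's
    # default global state, then wrap it in a NCCL communicator used to broadcast
    # weights from the trainer (rank 0) to the vLLM worker(s).
    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
    from vllm.distributed.utils import StatelessProcessGroup

    pg = StatelessProcessGroup.create(
        host=master_address, port=master_port, rank=rank, world_size=world_size
    )
    return PyNcclCommunicator(pg, device=device)
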
@@ -91,10 +92,11 @@ class WorkerExtension:
 
 class vLLMHFLocalWeightUpdater(LocalWeightUpdaterBase):
     def __init__(self, master_address, master_port, model_metadata):
+        print(f"{master_address=}, {master_port=}")
         self.master_address = master_address
         self.master_port = master_port
         self.model_metadata = model_metadata
-        self.model_update_group = None
+        self.initialized_group = None
 
     def _get_server_weights(self):
         return None
@@ -110,13 +112,13 @@ def _maybe_map_weights(self, server_weights, local_weights):
 
     def _update_local_weights(self, local_weights, mapped_weights):
         llm = self.collector.policy["generate"].module
-        if self.model_update_group is None:
-            # FIXME: hardcoded
+        if self.initialized_group is None:
             weight_sync_world_size = llm.llm_engine.parallel_config.tensor_parallel_size + 1
             llm.collective_rpc(
                 "init_weight_update_group",
                 args=(self.master_address, self.master_port, 1, weight_sync_world_size)
             )
+            self.initialized_group = True
 
         for k, (dtype, shape) in self.model_metadata.items():
             llm.collective_rpc(
@@ -125,11 +127,11 @@ def _update_local_weights(self, local_weights, mapped_weights):
             )
 
 class vLLMRemoteWeightUpdaterBase(RemoteWeightUpdaterBase):
-    def __init__(self, model, vllm_master_address, vllm_master_port):
+    def __init__(self, model, vllm_master_addresses, vllm_master_ports):
         super().__init__()
         from transformers import AutoModel
-        self.vllm_master_address = vllm_master_address
-        self.vllm_master_port = vllm_master_port
+        self.vllm_master_addresses = vllm_master_addresses
+        self.vllm_master_ports = vllm_master_ports
         self.state_dict = AutoModel.from_pretrained(model).cuda().eval().state_dict()
         self.state_dict_lock = threading.Lock()
         self.vllm_comm_groups = dict()
@@ -160,8 +162,8 @@ def _init_model_update_group(self, worker_id):
         vllm_tp_size = 1
         weight_sync_world_size = vllm_tp_size + 1
         model_update_group = stateless_init_process_group(
-            self.vllm_master_address,
-            self.vllm_master_port,
+            self.vllm_master_addresses[worker_id],
+            self.vllm_master_ports[worker_id],
             0,
             weight_sync_world_size,
             torch.device("cuda:0"),
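
A hedged usage sketch of the updated vLLMRemoteWeightUpdaterBase constructor, assuming one vLLM engine per remote worker; the model name, addresses, and ports below are illustrative, not taken from this diff:

# Illustrative values only: one (address, port) pair per vLLM engine, so each
# worker_id can get its own weight-sync process group in self.vllm_comm_groups.
updater = vLLMRemoteWeightUpdaterBase(
    model="facebook/opt-125m",
    vllm_master_addresses=["127.0.0.1", "127.0.0.1"],
    vllm_master_ports=[29600, 29601],
)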