Commit 0a265a9

Update on "v0 param server (using collectives not object store)"
[ghstack-poisoned]
1 parent 7bd553d commit 0a265a9
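
For context, "using collectives not object store" means the parameter server pushes updated weights to the vLLM workers over a torch.distributed-style process group (e.g. an NCCL broadcast) instead of shipping serialized tensors through Ray's object store. A minimal sketch of the idea, assuming a generic torch.distributed group (the names below are illustrative and not part of this diff):

    import torch
    import torch.distributed as dist

    def push_weights(state_dict, group, src_rank=0):
        # The parameter server (src_rank) broadcasts each tensor in place;
        # every vLLM worker in `group` receives the same buffer.
        for name, param in state_dict.items():
            dist.broadcast(param, src=src_rank, group=group)

The diff below wires this pattern up per vLLM engine via stateless_init_process_group.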

File tree

2 files changed: +32 −19 lines


param_server_weight_updater.py (+21 −10)
@@ -218,16 +218,18 @@ def _create_trainer_group(
 
 model = "facebook/opt-125m"
 
-ray.init(num_cpus=4, num_gpus=4)
+ray.init(num_cpus=5, num_gpus=5)
 
-vllm_master_address, vllm_update_port = get_ip(), get_open_port()
+vllm_addresses = [get_ip()] * 2
+vllm_ports = [get_open_port() for i in range(2)]
+print(vllm_ports)
 
 trainer_workers, parameter_server = _create_trainer_group(
     TrainerActor,
     vLLMParameterServer,
     3,
-    vllm_master_address,
-    vllm_update_port,
+    vllm_addresses,
+    vllm_ports,
     model,
 )
 
@@ -236,19 +238,28 @@ def _create_trainer_group(
     handles.append(trainer_worker.train.remote())
 
 model_metadata = ray.get(parameter_server.get_model_metadata.remote())
-local_weight_updater = vLLMHFLocalWeightUpdater(vllm_master_address, vllm_update_port, model_metadata)
+local_weight_updaters = [
+    vLLMHFLocalWeightUpdater(vllm_master_address, vllm_update_port, model_metadata) for
+    vllm_master_address, vllm_update_port in zip(vllm_addresses, vllm_ports)
+]
 
 make_env_parsed = partial(make_env, batch_size=args.batch_size, dataset=args.dataset)
 collector = RayCollector(
-    [make_env_parsed],
+    [make_env_parsed, make_env_parsed],
     policy_factory=make_policy,
     frames_per_batch=40,
     total_frames=200,
     remote_configs=remote_configs,
     remote_weight_updater=parameter_server,
-    collector_kwargs={
-        "local_weight_updater": local_weight_updater,
-    },
+    num_collectors=2,
+    collector_kwargs=[
+        {
+            "local_weight_updater": local_weight_updaters[0],
+        },
+        {
+            "local_weight_updater": local_weight_updaters[1],
+        }
+    ],
     update_after_each_batch=True,
 )
 print("done collector init")
@@ -258,6 +269,6 @@ def _create_trainer_group(
 for i, data in enumerate(collector):
     print(tokenizer.decode(data["tokens"][0].squeeze()))
     print(tokenizer.decode(data["tokens_response"][0].squeeze()))
-    if i == 1:
+    if i == 3:
         break
 collector.shutdown()
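
The setup above hardcodes two vLLM engines: two addresses/ports, two env factories, and a two-element collector_kwargs list. A sketch of the same wiring generalized to N engines, assuming the constructor signatures stay exactly as in this diff (num_engines is an illustrative variable, not in the original):

    num_engines = 2  # the diff hardcodes 2; any N should follow the same shape

    vllm_addresses = [get_ip()] * num_engines
    vllm_ports = [get_open_port() for _ in range(num_engines)]

    local_weight_updaters = [
        vLLMHFLocalWeightUpdater(addr, port, model_metadata)
        for addr, port in zip(vllm_addresses, vllm_ports)
    ]

    collector = RayCollector(
        [make_env_parsed] * num_engines,
        policy_factory=make_policy,
        frames_per_batch=40,
        total_frames=200,
        remote_configs=remote_configs,
        remote_weight_updater=parameter_server,
        num_collectors=num_engines,
        collector_kwargs=[{"local_weight_updater": u} for u in local_weight_updaters],
        update_after_each_batch=True,
    )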

torchrl/collectors/vllm_weight_update.py (+11 −9)
@@ -56,7 +56,8 @@ class WorkerExtension(Worker):
     def init_weight_update_group(self, master_address, master_port,
                                  rank_offset, world_size):
         from vllm.distributed.parallel_state import get_world_group
-        rank = get_world_group().rank + rank_offset
+        # rank = get_world_group().rank + rank_offset
+        rank = rank_offset
         self.model_update_group = stateless_init_process_group(
             master_address,
             master_port,
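
The rank change above assumes each vLLM engine runs with tensor_parallel_size=1, so an engine has a single worker whose rank in its weight-sync group is simply rank_offset, with the parameter server holding rank 0. A small sketch of the rank layout (weight_sync_ranks is a hypothetical helper, not part of this diff):

    def weight_sync_ranks(tp_size, rank_offset=1):
        # Parameter server sits at rank 0; the engine's TP workers occupy
        # ranks rank_offset .. rank_offset + tp_size - 1.
        return {
            "parameter_server": 0,
            "vllm_workers": list(range(rank_offset, rank_offset + tp_size)),
            "world_size": tp_size + 1,
        }

    # tp_size == 1 (what this diff targets): the lone worker's rank == rank_offset.
    # tp_size > 1 would need the commented-out get_world_group().rank offset again.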
@@ -91,10 +92,11 @@ class WorkerExtension:
 
 class vLLMHFLocalWeightUpdater(LocalWeightUpdaterBase):
     def __init__(self, master_address, master_port, model_metadata):
+        print(f"{master_address=}, {master_port=}")
         self.master_address = master_address
         self.master_port = master_port
         self.model_metadata = model_metadata
-        self.model_update_group = None
+        self.initialized_group = None
 
     def _get_server_weights(self):
         return None
@@ -110,13 +112,13 @@ def _maybe_map_weights(self, server_weights, local_weights):
 
     def _update_local_weights(self, local_weights, mapped_weights):
         llm = self.collector.policy["generate"].module
-        if self.model_update_group is None:
-            # FIXME: hardcoded
+        if self.initialized_group is None:
             weight_sync_world_size = llm.llm_engine.parallel_config.tensor_parallel_size + 1
             llm.collective_rpc(
                 "init_weight_update_group",
                 args=(self.master_address, self.master_port, 1, weight_sync_world_size)
             )
+            self.initialized_group = True
 
         for k, (dtype, shape) in self.model_metadata.items():
             llm.collective_rpc(
@@ -125,11 +127,11 @@ def _update_local_weights(self, local_weights, mapped_weights):
             )
 
 class vLLMRemoteWeightUpdaterBase(RemoteWeightUpdaterBase):
-    def __init__(self, model, vllm_master_address, vllm_master_port):
+    def __init__(self, model, vllm_master_addresses, vllm_master_ports):
         super().__init__()
         from transformers import AutoModel
-        self.vllm_master_address = vllm_master_address
-        self.vllm_master_port = vllm_master_port
+        self.vllm_master_addresses = vllm_master_addresses
+        self.vllm_master_ports = vllm_master_ports
         self.state_dict = AutoModel.from_pretrained(model).cuda().eval().state_dict()
         self.state_dict_lock = threading.Lock()
         self.vllm_comm_groups = dict()
@@ -160,8 +162,8 @@ def _init_model_update_group(self, worker_id):
         vllm_tp_size = 1
         weight_sync_world_size = vllm_tp_size + 1
         model_update_group = stateless_init_process_group(
-            self.vllm_master_address,
-            self.vllm_master_port,
+            self.vllm_master_addresses[worker_id],
+            self.vllm_master_ports[worker_id],
             0,
             weight_sync_world_size,
             torch.device("cuda:0"),
