Commit 27d3680

[Feature] Local and Remote WeightUpdaters

ghstack-source-id: 2962530f87b596d038e3a13a934ea09064af2964
Pull Request resolved: #2848

1 parent 49a8a42

File tree

11 files changed: +911 −185 lines changed

Diff for: docs/source/reference/collectors.rst (+70)
@@ -117,6 +117,76 @@ try to limit the cases where a deepcopy will be executed. The following chart sh
 
     Policy copy decision tree in Collectors.
 
+Weight Synchronization in Distributed Environments
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+In distributed and multiprocessed environments, ensuring that all instances of a policy are synchronized with the
+latest trained weights is crucial for consistent performance. The API introduces a flexible and extensible
+mechanism for updating policy weights across different devices and processes, accommodating various deployment scenarios.
+
+Local and Remote Weight Updaters
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+The weight synchronization process is facilitated by two main components: :class:`~torchrl.collectors.LocalWeightUpdaterBase`
+and :class:`~torchrl.collectors.RemoteWeightUpdaterBase`. These base classes provide a structured interface for
+implementing custom weight update logic, allowing users to tailor the synchronization process to their specific needs.
+
+- :class:`~torchrl.collectors.LocalWeightUpdaterBase`: This component is responsible for updating the policy weights on
+  the local inference worker. It is particularly useful when training and inference occur on the same machine but on
+  different devices. Users can extend this class to define how weights are fetched from a server and applied locally.
+  It is also the extension point for collectors where the workers need to ask for weight updates (in contrast with
+  situations where the server decides when to update the worker policies).
+- :class:`~torchrl.collectors.RemoteWeightUpdaterBase`: This component handles the distribution of policy weights to
+  remote inference workers. It is essential in distributed systems where multiple workers need to be kept in sync with
+  the central policy. Users can extend this class to implement custom logic for synchronizing weights across a network of
+  devices or processes.
+
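The pull/push split described in the new docs can be sketched in plain Python without any torchrl dependency. Everything below is illustrative: the class names, method names, and dict-based "weights" are assumptions for exposition, not torchrl's actual interface.

```python
from abc import ABC, abstractmethod


# Hypothetical sketch of the local/remote updater split: a local updater
# PULLS weights on the worker's initiative; a remote updater PUSHES weights
# on the server's initiative. Names are illustrative, not torchrl's API.
class LocalUpdaterSketch(ABC):
    """Fetches fresh weights from a server and applies them locally."""

    @abstractmethod
    def fetch_weights(self) -> dict: ...

    def update(self, policy_weights: dict) -> None:
        # The worker asks for weights (pull model).
        policy_weights.update(self.fetch_weights())


class RemoteUpdaterSketch(ABC):
    """Distributes the central policy's weights to remote workers."""

    @abstractmethod
    def push_weights(self, worker_id: int, weights: dict) -> None: ...

    def update(self, weights: dict, worker_ids: list) -> None:
        # The server decides when to broadcast (push model).
        for worker_id in worker_ids:
            self.push_weights(worker_id, weights)


class DictLocalUpdater(LocalUpdaterSketch):
    """Toy local updater backed by an in-memory dict acting as the 'server'."""

    def __init__(self, server_weights: dict) -> None:
        self.server_weights = server_weights

    def fetch_weights(self) -> dict:
        return dict(self.server_weights)


server = {"weight": 1.0}   # weights held by the trainer
policy = {"weight": 0.0}   # stale weights on the inference worker
DictLocalUpdater(server).update(policy)
print(policy["weight"])  # the local policy now holds the server's weight
```

The same dict-update idea generalizes to `state_dict`-style synchronization between a trained module and its inference copies.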
+Extending the Updater Classes
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+To accommodate diverse use cases, the API allows users to extend the updater classes with custom implementations.
+This flexibility is particularly beneficial in scenarios involving complex network architectures or specialized hardware
+setups. By implementing the abstract methods in these base classes, users can define how weights are retrieved,
+transformed, and applied, ensuring seamless integration with their existing infrastructure.
+
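The retrieve/transform/apply pipeline mentioned above can be illustrated with a small stand-alone class. This is a sketch only: `TransformingRemoteUpdater` and its hook names (`_get_weights`, `_apply`) are hypothetical, and plain dicts stand in for parameter tensors.

```python
# Illustrative custom updater: retrieves weights from a source, transforms
# them (e.g. casting or quantizing), then applies them to each worker.
# Class and hook names are assumptions, not torchrl's actual interface.
class TransformingRemoteUpdater:
    def __init__(self, server_weights, workers, transform=None):
        self.server_weights = server_weights   # retrieval source
        self.workers = workers                 # worker_id -> weight dict
        self.transform = transform or (lambda w: w)

    def _get_weights(self):
        # Retrieve: snapshot the current server weights.
        return dict(self.server_weights)

    def _apply(self, worker_id, weights):
        # Apply: overwrite the worker's copy in place.
        self.workers[worker_id].update(weights)

    def update(self, worker_ids=None):
        # Transform once, then fan out to all (or selected) workers.
        weights = self.transform(self._get_weights())
        for worker_id in worker_ids or list(self.workers):
            self._apply(worker_id, weights)


workers = {0: {"w": 0.0}, 1: {"w": 0.0}}
updater = TransformingRemoteUpdater(
    {"w": 2.6},
    workers,
    transform=lambda w: {k: round(v) for k, v in w.items()},
)
updater.update()
print(workers)  # both workers now hold the transformed server weight
```

In a real deployment the `_apply` step would be a network call or an inter-process pipe rather than a dict update, but the control flow is the same.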
+Default Implementations
+~~~~~~~~~~~~~~~~~~~~~~~
+
+For common scenarios, the API provides default implementations of these updaters, such as
+:class:`~torchrl.collectors.VanillaLocalWeightUpdater`, :class:`~torchrl.collectors.MultiProcessedRemoteWeightUpdate`,
+:class:`~torchrl.collectors.RayRemoteWeightUpdater`, :class:`~torchrl.collectors.RPCRemoteWeightUpdater`, and
+:class:`~torchrl.collectors.DistributedRemoteWeightUpdater`.
+These implementations cover a range of typical deployment configurations, from single-device setups to large-scale
+distributed systems.
+
+Practical Considerations
+~~~~~~~~~~~~~~~~~~~~~~~~
+
+When designing a system that leverages this API, consider the following:
+
+- Network Latency: In distributed environments, network latency can impact the speed of weight updates. Ensure that your
+  implementation accounts for potential delays and optimizes data transfer where possible.
+- Consistency: Ensure that all workers receive the updated weights in a timely manner to maintain consistency across
+  the system. This is particularly important in reinforcement learning scenarios where stale weights can lead to
+  suboptimal policy performance.
+- Scalability: As your system grows, the weight synchronization mechanism should scale efficiently. Consider the
+  overhead of broadcasting weights to a large number of workers and optimize the process to minimize bottlenecks.
+
+By leveraging the API, users can achieve robust and efficient weight synchronization across a variety of deployment
+scenarios, ensuring that their policies remain up-to-date and performant.
+
+.. currentmodule:: torchrl.collectors
+
+.. autosummary::
+    :toctree: generated/
+    :template: rl_template.rst
+
+    LocalWeightUpdaterBase
+    RemoteWeightUpdaterBase
+    VanillaLocalWeightUpdater
+    MultiProcessedRemoteWeightUpdate
+    RayRemoteWeightUpdater
+    DistributedRemoteWeightUpdater
+    RPCRemoteWeightUpdater
 
 Collectors and replay buffers interoperability
 ----------------------------------------------

Diff for: test/test_distributed.py (+39 −31)
@@ -390,25 +390,29 @@ def _test_distributed_collector_updatepolicy(
             update_interval=update_interval,
             **cls.distributed_kwargs(),
         )
-        total = 0
-        first_batch = None
-        last_batch = None
-        for i, data in enumerate(collector):
-            total += data.numel()
-            assert data.numel() == frames_per_batch
-            if i == 0:
-                first_batch = data
-                policy.weight.data += 1
-            elif total == total_frames - frames_per_batch:
-                last_batch = data
-        assert (first_batch["action"] == 1).all(), first_batch["action"]
-        if update_interval == 1:
-            assert (last_batch["action"] == 2).all(), last_batch["action"]
-        else:
-            assert (last_batch["action"] == 1).all(), last_batch["action"]
-        collector.shutdown()
-        assert total == total_frames
-        queue.put("passed")
+        try:
+
+            total = 0
+            first_batch = None
+            last_batch = None
+            for i, data in enumerate(collector):
+                total += data.numel()
+                assert data.numel() == frames_per_batch
+                if i == 0:
+                    first_batch = data
+                    policy.weight.data += 1
+                elif total == total_frames - frames_per_batch:
+                    last_batch = data
+            assert (first_batch["action"] == 1).all(), first_batch["action"]
+            if update_interval == 1:
+                assert (last_batch["action"] == 2).all(), last_batch["action"]
+            else:
+                assert (last_batch["action"] == 1).all(), last_batch["action"]
+            assert total == total_frames
+            queue.put("passed")
+        finally:
+            collector.shutdown()
+            queue.put("not passed")
 
     @pytest.mark.parametrize(
         "collector_class",
@@ -490,12 +494,14 @@ def test_distributed_collector_sync(self, sync, frames_per_batch=200):
             sync=sync,
             **self.distributed_kwargs(),
         )
-        total = 0
-        for data in collector:
-            total += data.numel()
-            assert data.numel() == frames_per_batch
-        collector.shutdown()
-        assert total == 200
+        try:
+            total = 0
+            for data in collector:
+                total += data.numel()
+                assert data.numel() == frames_per_batch
+            assert total == 200
+        finally:
+            collector.shutdown()
 
     @pytest.mark.parametrize(
         "collector_class",
@@ -517,12 +523,14 @@ def test_distributed_collector_class(self, collector_class):
             frames_per_batch=frames_per_batch,
             **self.distributed_kwargs(),
         )
-        total = 0
-        for data in collector:
-            total += data.numel()
-            assert data.numel() == frames_per_batch
-        collector.shutdown()
-        assert total == 200
+        try:
+            total = 0
+            for data in collector:
+                total += data.numel()
+                assert data.numel() == frames_per_batch
+            assert total == 200
+        finally:
+            collector.shutdown()
 
     @pytest.mark.parametrize(
         "collector_class",

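The common thread in the test changes above is moving `collector.shutdown()` into a `finally` block so worker cleanup runs even when an assertion fails mid-loop. The pattern can be demonstrated with a dummy collector (`DummyCollector` is a stand-in defined here only for illustration, not a torchrl class):

```python
# Sketch of the try/finally shutdown pattern adopted by the tests: the
# collector is always shut down, whether or not the body raises.
class DummyCollector:
    def __init__(self, frames_per_batch, total_frames):
        self.frames_per_batch = frames_per_batch
        self.total_frames = total_frames
        self.closed = False

    def __iter__(self):
        # Yield fixed-size "batches" until total_frames is reached.
        for _ in range(self.total_frames // self.frames_per_batch):
            yield [0] * self.frames_per_batch

    def shutdown(self):
        self.closed = True


collector = DummyCollector(frames_per_batch=50, total_frames=200)
try:
    total = 0
    for batch in collector:
        total += len(batch)
        assert len(batch) == 50
    assert total == 200
finally:
    # Runs even on assertion failure, so (in the real collectors)
    # background worker processes are never leaked.
    collector.shutdown()

assert collector.closed
```

Without the `finally`, a failing assertion would skip `shutdown()` and leave distributed workers hanging, which is exactly the flakiness these test edits guard against.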
Diff for: torchrl/collectors/__init__.py (+12)
@@ -12,9 +12,21 @@
     MultiSyncDataCollector,
     SyncDataCollector,
 )
+from .weight_update import (
+    LocalWeightUpdaterBase,
+    MultiProcessedRemoteWeightUpdate,
+    RayRemoteWeightUpdater,
+    RemoteWeightUpdaterBase,
+    VanillaLocalWeightUpdater,
+)
 
 __all__ = [
     "RandomPolicy",
+    "LocalWeightUpdaterBase",
+    "RemoteWeightUpdaterBase",
+    "VanillaLocalWeightUpdater",
+    "RayRemoteWeightUpdater",
+    "MultiProcessedRemoteWeightUpdate",
     "aSyncDataCollector",
     "DataCollectorBase",
     "MultiaSyncDataCollector",
