Skip to content

Commit

Permalink
Parallel sample_trajectories (doesn't work)
Browse files Browse the repository at this point in the history
  • Loading branch information
frank-lsf committed Sep 27, 2020
1 parent ab39493 commit 3109a76
Showing 1 changed file with 86 additions and 22 deletions.
108 changes: 86 additions & 22 deletions hw2/cs285/infrastructure/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,27 +107,6 @@ def sample_trajectory(
return Path(obs, image_obs, acs, rewards, next_obs, terminals)


def sample_trajectories(
    env,
    policy,
    min_timesteps_per_batch,
    max_path_length,
    render=False,
    render_mode=("rgb_array"),
):
    """Gather rollouts one at a time until the timestep budget is met.

    Rollouts are produced by ``sample_trajectory`` and their lengths are
    measured with ``get_pathlength``. Sampling stops as soon as the running
    total reaches ``min_timesteps_per_batch``, so the total may overshoot
    the budget by part of the last rollout.

    Returns:
        A ``(paths, timesteps)`` tuple: the list of collected rollouts and
        the summed length of those rollouts. If the budget is already
        satisfied on entry (e.g. it is zero), returns ``([], 0)``.
    """
    collected = []
    total_steps = 0
    while total_steps < min_timesteps_per_batch:
        # One full rollout per iteration; its length counts toward the budget.
        rollout = sample_trajectory(
            env, policy, max_path_length, render, render_mode
        )
        total_steps += get_pathlength(rollout)
        collected.append(rollout)
    return collected, total_steps


def sample_n_trajectories(
env, policy, ntraj, max_path_length, render=False, render_mode=("rgb_array")
):
Expand Down Expand Up @@ -222,4 +201,89 @@ def add_noise(data_inp, noiseToSignal=0.01):
+ np.random.normal(0, np.absolute(std_of_noise[j]), (data.shape[0],))
)

return data
return data


############################################
# Parallel sampling
############################################


def sample_trajectories(*args, **kwargs):
    """Collect rollouts by delegating to ``sample_trajectories_parallel``.

    All positional and keyword arguments are forwarded unchanged, and the
    delegate's return value is returned as-is.
    """
    result = sample_trajectories_parallel(*args, **kwargs)
    return result


def sample_trajectories_sequential(
    env,
    policy,
    min_timesteps_per_batch,
    max_path_length,
    render=False,
    render_mode=("rgb_array"),
):
    """Gather rollouts in the calling process until the timestep budget is met.

    Each rollout comes from ``sample_trajectory`` and is measured with
    ``get_pathlength``. The loop ends once the running total reaches
    ``min_timesteps_per_batch``, so the total may exceed the budget by part
    of the final rollout.

    Returns:
        A ``(paths, timesteps)`` tuple: the list of collected rollouts and
        the summed length of those rollouts. If the budget is already
        satisfied on entry (e.g. it is zero), returns ``([], 0)``.
    """
    rollouts = []
    steps_so_far = 0
    while steps_so_far < min_timesteps_per_batch:
        # Sample one rollout and credit its length toward the budget.
        rollout = sample_trajectory(
            env, policy, max_path_length, render, render_mode
        )
        steps_so_far += get_pathlength(rollout)
        rollouts.append(rollout)
    return rollouts, steps_so_far


import concurrent.futures as cf
import os
import torch.multiprocessing as mp


def sample_trajectories_parallel(
    env,
    policy,
    min_timesteps_per_batch,
    max_path_length,
    render=False,
    render_mode=("rgb_array"),
):
    """Collect rollouts in worker processes until the timestep budget is met.

    ``sample_trajectory`` calls are submitted to a process pool that uses the
    "spawn" start method. Completed rollouts are gathered as they finish and
    replacement tasks are submitted until the summed rollout length (per
    ``get_pathlength``) reaches ``min_timesteps_per_batch``. Every rollout
    that has already finished when the budget is reached is kept, so the
    total may overshoot the budget.

    NOTE(review): arguments submitted to a process pool must be picklable;
    whether ``env`` and ``policy`` are is not visible from this function.

    Args:
        env: Environment handed to ``sample_trajectory`` in each worker.
        policy: Policy handed to ``sample_trajectory`` in each worker.
        min_timesteps_per_batch: Minimum total number of timesteps to collect.
        max_path_length: Forwarded to ``sample_trajectory``.
        render: Forwarded to ``sample_trajectory``.
        render_mode: Forwarded to ``sample_trajectory``.

    Returns:
        A ``(paths, timesteps)`` tuple: the list of collected rollouts and
        the summed length of those rollouts. Returns ``([], 0)`` without
        starting any worker when the budget is not positive, matching the
        sequential sampler.

    Raises:
        Exception: Whatever a worker's ``sample_trajectory`` call raised is
            re-raised here; still-pending tasks are cancelled first.
    """
    timesteps_this_batch = 0
    paths = []
    if min_timesteps_per_batch <= 0:
        # Nothing to collect; avoid paying the process start-up cost.
        return paths, timesteps_this_batch

    # Keep more tasks in flight than there are workers so no worker idles.
    # os.cpu_count() may return None when the count cannot be determined.
    n_tasks = (os.cpu_count() or 1) * 2

    task_args = (
        sample_trajectory,
        env,
        policy,
        max_path_length,
        render,
        render_mode,
    )

    # Use a dedicated context instead of set_start_method(): the latter may
    # only be called once per process and raises RuntimeError afterwards,
    # which would break every call after the first.
    spawn_ctx = mp.get_context("spawn")
    with cf.ProcessPoolExecutor(mp_context=spawn_ctx) as executor:
        pending = {executor.submit(*task_args) for _ in range(n_tasks)}
        try:
            while timesteps_this_batch < min_timesteps_per_batch:
                finished, pending = cf.wait(
                    pending, return_when=cf.FIRST_COMPLETED
                )
                # Several futures can complete between two wait() calls, so
                # handle all of them rather than assuming exactly one.
                for future in finished:
                    # result() re-raises any exception raised in the worker.
                    path = future.result()
                    paths.append(path)
                    timesteps_this_batch += get_pathlength(path)

                if timesteps_this_batch < min_timesteps_per_batch:
                    # Replace each finished task to keep the pool saturated.
                    pending.update(
                        executor.submit(*task_args) for _ in finished
                    )
        finally:
            # Enough collected (or a worker failed): drop tasks that have not
            # started yet. Tasks already running cannot be cancelled and are
            # waited for when the executor context exits.
            for future in pending:
                future.cancel()

    return paths, timesteps_this_batch

0 comments on commit 3109a76

Please sign in to comment.