diff --git a/hw2/cs285/infrastructure/utils.py b/hw2/cs285/infrastructure/utils.py
index 35286be5..00e15b36 100644
--- a/hw2/cs285/infrastructure/utils.py
+++ b/hw2/cs285/infrastructure/utils.py
@@ -234,11 +234,19 @@ def sample_trajectories_sequential(
     return paths, timesteps_this_batch
 
 
-import concurrent.futures as cf
+import copy
 import os
 import torch.multiprocessing as mp
 
 
+def mp_worker(result_queue, env, policy, max_path_length, render, render_mode):
+    # Keep sampling rollouts and push each one to the parent through the queue.
+    # Every rollout acts on its own deep copy of the policy.
+    while True:
+        result = sample_trajectory(env, copy.deepcopy(policy), max_path_length, render, render_mode)
+        result_queue.put(result)
+
+
 def sample_trajectories_parallel(
     env,
     policy,
@@ -250,40 +258,31 @@ def sample_trajectories_parallel(
     """Collect rollouts until we have collected min_timesteps_per_batch steps."""
 
-    # Number of tasks to submit to the executor. This should be larger than
-    # the number of workers (i.e. CPU count).
-    n_tasks = os.cpu_count() * 2
-
     timesteps_this_batch = 0
     paths = []
-    mp.set_start_method("spawn")
-    with cf.ProcessPoolExecutor(mp_context=mp) as executor:
-        task_args = (
-            sample_trajectory,
-            env,
-            policy,
-            max_path_length,
-            render,
-            render_mode,
+    # Use a private "spawn" context instead of mutating the global start method.
+    ctx = mp.get_context("spawn")
+
+    def launch_worker():
+        proc = ctx.Process(
+            target=mp_worker,
+            args=(result_queue, env, policy, max_path_length, render, render_mode),
         )
-        tasks = set(executor.submit(*task_args) for _ in range(n_tasks))
-        while True:
-            done_set, rest_set = cf.wait(tasks, return_when=cf.FIRST_COMPLETED)
-            (done,) = done_set
-            if not done.done():
-                raise done.exception()
-            path = done.result()
-            paths.append(path)
-            timesteps_this_batch += get_pathlength(path)
-
-            if timesteps_this_batch >= min_timesteps_per_batch:
-                # We have collected enough. Cancel the rest.
-                for task in rest_set:
-                    task.cancel()
-                break
-            else:
-                # Submit a new sample_trajectory task
-                new_task = executor.submit(*task_args)
-                rest_set.add(new_task)
-        tasks = list(rest_set)
+        proc.start()
+        return proc
+
+    # Bounded queue: at most one unconsumed rollout is buffered at a time.
+    result_queue = ctx.Queue(1)
+    # processes = [launch_worker() for _ in range(os.cpu_count())]
+    processes = [launch_worker() for _ in range(1)]
+
+    while True:
+        path = result_queue.get()
+        paths.append(path)
+        timesteps_this_batch += get_pathlength(path)
+        if timesteps_this_batch >= min_timesteps_per_batch:
+            # We have collected enough. Kill the workers.
+            for proc in processes:
+                proc.kill()
+            break
 
     return paths, timesteps_this_batch
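
The rewritten `sample_trajectories_parallel` is a producer/consumer loop: each `mp_worker` process samples one trajectory at a time on a deep copy of the policy and pushes it into a bounded queue, while the parent drains the queue until `min_timesteps_per_batch` steps have been collected and then kills the workers. The sketch below shows the same pattern in isolation with a cooperative `Event`-based shutdown instead of `proc.kill()`. It is only an illustration of the mechanics, not the repository's code: `_fake_rollout`, `_rollout_worker`, `collect`, `min_timesteps`, and `n_workers` are hypothetical names, and the real workers would call `sample_trajectory(env, copy.deepcopy(policy), ...)` instead of the fake rollout.

```python
import time

import torch.multiprocessing as mp


def _fake_rollout(n_steps=25):
    # Stand-in for sample_trajectory(env, copy.deepcopy(policy), ...):
    # pretend each rollout takes a moment and yields n_steps reward entries.
    time.sleep(0.01)
    return {"reward": [0.0] * n_steps}


def _rollout_worker(result_queue, stop_event):
    # Produce rollouts until the parent signals that enough steps were collected.
    while not stop_event.is_set():
        result_queue.put(_fake_rollout())


def collect(min_timesteps=200, n_workers=4):
    ctx = mp.get_context("spawn")
    result_queue = ctx.Queue()
    stop_event = ctx.Event()
    workers = [
        ctx.Process(target=_rollout_worker, args=(result_queue, stop_event))
        for _ in range(n_workers)
    ]
    for w in workers:
        w.start()

    paths, steps = [], 0
    while steps < min_timesteps:
        path = result_queue.get()  # blocks until some worker finishes a rollout
        paths.append(path)
        steps += len(path["reward"])

    stop_event.set()               # ask workers to exit their loop cooperatively
    for w in workers:
        w.join(timeout=5)
        if w.is_alive():
            w.kill()               # fall back to kill() only if a worker is stuck
    return paths, steps


if __name__ == "__main__":
    paths, steps = collect()
    print(f"collected {len(paths)} rollouts / {steps} timesteps")
```

Whether the cooperative shutdown is worth the extra plumbing over a plain `proc.kill()` depends on whether anything reads the queue afterwards; in the diff the parent returns immediately after killing the workers, so the simpler approach is defensible. The same `collect` skeleton also scales naturally to several workers, which appears to be where the commented-out `os.cpu_count()` line in the diff is headed.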