diff --git a/hw1/cs285/infrastructure/utils.py b/hw1/cs285/infrastructure/utils.py
index d894480b..51bb3ff5 100644
--- a/hw1/cs285/infrastructure/utils.py
+++ b/hw1/cs285/infrastructure/utils.py
@@ -73,8 +73,8 @@ def sample_n_trajectories(env, policy, ntraj, max_path_length, render=False, ren
         Hint1: use sample_trajectory to get each path (i.e. rollout) that goes into paths
     """
     paths = []
-
-    TODO
+    for n in range(ntraj):
+        paths.append(sample_trajectory(env,policy,max_path_length,render,render_mode))
 
     return paths
 
@@ -116,4 +116,4 @@ def convert_listofrollouts(paths, concat_rew=True):
 ############################################
 
 def get_pathlength(path):
-    return len(path["reward"])
\ No newline at end of file
+    return len(path["reward"])