From 862f91bac17a1793e416f2d2442aed33cca81cf2 Mon Sep 17 00:00:00 2001
From: zindzigriffin
Date: Fri, 1 Jul 2022 19:06:29 -0700
Subject: [PATCH] Update rl_trainer.py

---
 hw1/cs285/infrastructure/rl_trainer.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/hw1/cs285/infrastructure/rl_trainer.py b/hw1/cs285/infrastructure/rl_trainer.py
index bb27972e..0f88d0cc 100644
--- a/hw1/cs285/infrastructure/rl_trainer.py
+++ b/hw1/cs285/infrastructure/rl_trainer.py
@@ -161,6 +161,10 @@ def collect_training_trajectories(
         # ``` return loaded_paths, 0, None ```
 
         # (2) collect `self.params['batch_size']` transitions
+        if itr == 0:
+            with open(load_initial_expertdata, "rb") as f:
+                loaded_paths = pickle.load(f)
+            return loaded_paths, 0, None
 
         # TODO collect `batch_size` samples to be used for training
         # HINT1: use sample_trajectories from utils
@@ -187,12 +191,12 @@ def train_agent(self):
             # TODO sample some data from the data buffer
             # HINT1: use the agent's sample function
             # HINT2: how much data = self.params['train_batch_size']
-            ob_batch, ac_batch, re_batch, next_ob_batch, terminal_batch = TODO
+            ob_batch, ac_batch, re_batch, next_ob_batch, terminal_batch = self.agent.sample(self.params['train_batch_size'])
 
             # TODO use the sampled data to train an agent
             # HINT: use the agent's train function
             # HINT: keep the agent's training log for debugging
-            train_log = TODO
+            train_log = self.agent.train(ob_batch, ac_batch, re_batch, next_ob_batch, terminal_batch)
             all_logs.append(train_log)
 
         return all_logs
@@ -202,7 +206,8 @@ def do_relabel_with_expert(self, expert_policy, paths):
         # TODO relabel collected obsevations (from our policy) with labels from an expert policy
         # HINT: query the policy (using the get_action function) with paths[i]["observation"]
         # and replace paths[i]["action"] with these expert labels
-
+        for i in range(len(paths)):
+            paths[i]["action"] = expert_policy.get_action(paths[i]["observation"]).detach().numpy()
        return paths
 
 ####################################
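
The first hunk short-circuits iteration 0 by unpickling the expert rollouts named by `load_initial_expertdata` (this relies on `pickle` already being imported in rl_trainer.py). The `# TODO collect` left just below it still needs the `sample_trajectories` call from HINT1. Here is a minimal sketch of the full load-or-collect branch, written as a free function for clarity; the `(paths, envsteps_this_batch)` return shape of `sample_trajectories` is an assumption based on the hints, not something the patch confirms:

```python
import pickle

def collect_or_load(itr, load_initial_expertdata, env, collect_policy,
                    batch_size, max_path_length, sample_trajectories):
    # On iteration 0, return the pickled expert demonstrations instead
    # of rolling out the untrained policy; report 0 env steps used and
    # no video rollouts (matching `return loaded_paths, 0, None` above).
    if itr == 0 and load_initial_expertdata is not None:
        with open(load_initial_expertdata, "rb") as f:
            loaded_paths = pickle.load(f)
        return loaded_paths, 0, None

    # Otherwise collect `batch_size` transitions with the current policy.
    # ASSUMED (per HINT1) to return (paths, envsteps_this_batch).
    paths, envsteps_this_batch = sample_trajectories(
        env, collect_policy, batch_size, max_path_length)
    return paths, envsteps_this_batch, None
```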
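The second hunk fills in the two `train_agent` TODOs: `agent.sample` is expected to hand back a 5-tuple of (observations, actions, rewards, next observations, terminals) aligned by index, which `agent.train` then consumes in the same order. A toy buffer illustrating that assumed contract — the class and field names here are hypothetical, not taken from the repo:

```python
import numpy as np

class ToyReplayBuffer:
    """Illustrates the 5-tuple contract assumed by
    self.agent.sample(self.params['train_batch_size'])."""

    def __init__(self, obs, acs, rews, next_obs, terminals):
        self.obs, self.acs, self.rews = obs, acs, rews
        self.next_obs, self.terminals = next_obs, terminals

    def sample(self, batch_size):
        # Draw one random index set so all five arrays stay aligned.
        idx = np.random.permutation(len(self.obs))[:batch_size]
        return (self.obs[idx], self.acs[idx], self.rews[idx],
                self.next_obs[idx], self.terminals[idx])
```

Appending each per-step `train_log` to `all_logs` keeps one entry per gradient step, which is handy when debugging behavior cloning loss curves.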
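The last hunk is the DAgger relabeling step: keep the observations the learner's policy visited, but overwrite each `paths[i]["action"]` with the expert's action at those observations. The `.detach().numpy()` suffix assumes `expert_policy.get_action` returns a torch tensor still attached to a graph; if the expert already returns NumPy arrays, that conversion can be dropped. A self-contained sketch with a stand-in linear expert (the lambda is purely illustrative):

```python
import numpy as np

def relabel_with_expert(expert_get_action, paths):
    # DAgger: states come from our policy's rollouts, labels from the expert.
    for path in paths:
        path["action"] = expert_get_action(path["observation"])
    return paths

# Toy usage: one path of 5 steps, 3-dim observations, 2-dim actions,
# with a hypothetical linear "expert" mapping obs -> action.
paths = [{"observation": np.random.randn(5, 3),
          "action": np.zeros((5, 2))}]
K = np.ones((3, 2))
relabeled = relabel_with_expert(lambda obs: obs @ K, paths)
assert relabeled[0]["action"].shape == (5, 2)
```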