Skip to content

Commit

Permalink
Allow multi-proc env to be rendered.
Browse files Browse the repository at this point in the history
  • Loading branch information
notadamking committed Jul 6, 2019
1 parent 9d4ad1c commit 20baac2
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
13 changes: 9 additions & 4 deletions lib/RLTrader.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,23 +213,28 @@ def test(self, model_epoch: int = 0, should_render: bool = True):

del train_provider

test_env = SubprocVecEnv([make_env(test_provider, i) for i in range(self.n_envs)])
test_env = DummyVecEnv([make_env(test_provider, i) for i in range(1)])

model_path = path.join('data', 'agents', f'{self.study_name}__{model_epoch}.pkl')
model = self.Model.load(model_path, env=test_env)

self.logger.info(f'Testing model ({self.study_name}__{model_epoch})')

zero_completed_obs = np.zeros((self.n_envs,) + test_env.observation_space.shape)
zero_completed_obs[0, :] = test_env.reset()

state = None
obs, rewards = test_env.reset(), []
rewards = []

for _ in range(len(test_provider.data_frame)):
action, state = model.predict(obs, state=state)
action, state = model.predict(zero_completed_obs, state=state)
obs, reward, _, __ = test_env.step(action)

zero_completed_obs[0, :] = obs

rewards.append(reward)

if should_render and self.n_envs == 1:
if should_render:
test_env.render(mode='human')

self.logger.info(
Expand Down
2 changes: 1 addition & 1 deletion optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def optimize_code(params):


if __name__ == '__main__':
n_process = multiprocessing.cpu_count() - 4
n_process = multiprocessing.cpu_count()
params = {}

processes = []
Expand Down

0 comments on commit 20baac2

Please sign in to comment.