Skip to content
This repository has been archived by the owner on Sep 1, 2024. It is now read-only.

Commit

Permalink
added test for rollout_model_env
Browse files Browse the repository at this point in the history
  • Loading branch information
luisenp committed Mar 8, 2021
1 parent 7a0f3dd commit 52c1da4
Showing 1 changed file with 56 additions and 0 deletions.
56 changes: 56 additions & 0 deletions tests/core/test_common_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
import omegaconf
import pytest

Expand Down Expand Up @@ -130,3 +131,58 @@ def _check_shapes(train_cap):
assert isinstance(val, replay_buffer.IterableReplayBuffer)

_check_shapes(1500)


class MockModelEnv:
    """Minimal stand-in for a model environment with additive dynamics.

    The environment keeps the latest batch of observations in ``self.obs``.
    Each step adds the first action column to every observation dimension
    (via broadcasting), so rollouts are fully predictable.
    """

    def __init__(self):
        # No observation exists until ``reset`` is called.
        self.obs = None

    def reset(self, obs0, propagation_method=None, return_as_np=None):
        """Store ``obs0`` as the current observation and return it unchanged.

        ``propagation_method`` and ``return_as_np`` are accepted but ignored.
        """
        self.obs = obs0
        return self.obs

    def step(self, action, sample=None):
        """Advance one step and return ``(next_obs, reward, done, info)``.

        The next observation is the current one plus ``action[:, :1]``.
        Rewards are all ones and done flags all zeros, one entry per row of
        the observation batch. ``sample`` is accepted but ignored.
        """
        # Build a new array (not in-place) so earlier observations that the
        # caller still holds are never modified.
        self.obs = self.obs + action[:, :1]
        batch_size = self.obs.shape[0]
        return self.obs, np.ones(batch_size), np.zeros(batch_size), {}


class MockAgent:
    """Agent stub whose plan is a fixed column of ones."""

    def __init__(self, length):
        # One action per planning step, each a single-element row equal to 1.
        self.actions = np.ones(shape=(length, 1))

    def plan(self, obs):
        """Return the stored action sequence; ``obs`` is not used."""
        return self.actions


def test_rollout_model_env():
    """Exercise ``utils.rollout_model_env`` with the mock env and agent.

    Two scenarios are covered:

    * an agent is supplied, so its all-ones plan is expected to drive the
      rollout even though an all-zeros plan is also passed in;
    * no agent is supplied, so the explicitly given plan (all twos) is
      expected to drive the rollout.

    With the mock's additive dynamics and a zero initial observation, the
    observation at step ``t`` should equal ``t`` times the action value.
    """
    obs_size = 10
    plan_length = 20
    num_samples = 5
    obs0 = np.zeros(obs_size)
    model_env = MockModelEnv()
    agent = MockAgent(plan_length)

    # Scenario 1: roll out with an agent.
    ignored_plan = 0 * agent.plan(obs0)  # this should be ignored
    obs, rewards, actions = utils.rollout_model_env(
        model_env, obs0, ignored_plan, agent, num_samples=num_samples
    )

    assert obs.shape == (plan_length + 1, num_samples, obs_size)
    assert rewards.shape == (plan_length, num_samples)
    assert actions.shape == (plan_length, 1)

    # The agent's unit actions raise the observation by one per step.
    for step, step_obs in enumerate(obs):
        assert step_obs.min() == step

    # Scenario 2: roll out with a given plan and no agent.
    given_plan = 2 * agent.plan(obs0)
    obs, _, _ = utils.rollout_model_env(
        model_env, obs0, given_plan, None, num_samples=num_samples
    )

    # Actions of value two raise the observation by two per step.
    for step, step_obs in enumerate(obs):
        assert step_obs.min() == 2 * step

0 comments on commit 52c1da4

Please sign in to comment.