diff --git a/tests/core/test_common_utils.py b/tests/core/test_common_utils.py index 2d1aff0c..e4355fc5 100644 --- a/tests/core/test_common_utils.py +++ b/tests/core/test_common_utils.py @@ -1,3 +1,4 @@ +import numpy as np import omegaconf import pytest @@ -130,3 +131,58 @@ def _check_shapes(train_cap): assert isinstance(val, replay_buffer.IterableReplayBuffer) _check_shapes(1500) + + +class MockModelEnv: + def __init__(self): + self.obs = None + + def reset(self, obs0, propagation_method=None, return_as_np=None): + self.obs = obs0 + return obs0 + + def step(self, action, sample=None): + next_obs = self.obs + action[:, :1] + reward = np.ones(next_obs.shape[0]) + done = np.zeros(next_obs.shape[0]) + self.obs = next_obs + return next_obs, reward, done, {} + + +class MockAgent: + def __init__(self, length): + self.actions = np.ones((length, 1)) + + def plan(self, obs): + return self.actions + + +def test_rollout_model_env(): + obs_size = 10 + plan_length = 20 + num_samples = 5 + model_env = MockModelEnv() + obs0 = np.zeros(obs_size) + agent = MockAgent(plan_length) + plan = 0 * agent.plan(obs0) # this should be ignored + + # Check rolling out with an agent + obs, rewards, actions = utils.rollout_model_env( + model_env, obs0, plan, agent, num_samples=num_samples + ) + + assert obs.shape == (plan_length + 1, num_samples, obs_size) + assert rewards.shape == (plan_length, num_samples) + assert actions.shape == (plan_length, 1) + + for i, o in enumerate(obs): + assert o.min() == i + + # Check rolling out with a given plan + plan = 2 * agent.plan(obs0) + obs, rewards, actions = utils.rollout_model_env( + model_env, obs0, plan, None, num_samples=num_samples + ) + + for i, o in enumerate(obs): + assert o.min() == 2 * i