diff --git a/minatar/environment.py b/minatar/environment.py index 119a90c..e7390ba 100644 --- a/minatar/environment.py +++ b/minatar/environment.py @@ -17,8 +17,9 @@ class Environment: def __init__(self, env_name, sticky_action_prob = 0.1, difficulty_ramping = True, random_seed = None): env_module = import_module('minatar.environments.'+env_name) + self.random = np.random.RandomState(random_seed) self.env_name = env_name - self.env = env_module.Env(ramping = difficulty_ramping, seed = random_seed) + self.env = env_module.Env(ramping = difficulty_ramping, random_state = self.random) self.n_channels = self.env.state_shape()[2] self.sticky_action_prob = sticky_action_prob self.last_action = 0 @@ -27,7 +28,7 @@ def __init__(self, env_name, sticky_action_prob = 0.1, difficulty_ramping = True # Wrapper for env.act def act(self, a): - if(np.random.rand()