Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remote gym env wrapper #522

Open
wants to merge 16 commits into
base: develop
Choose a base branch
from
13 changes: 9 additions & 4 deletions parl/utils/env_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,15 @@ class RemoteGymEnv(object):
env = RemoteGymEnv(env_name='HalfCheetah-v1')

Attributes:
env_name: Mujoco gym environment name
env_name: gym environment name

Public Functions: the same as gym (mainly action_space, observation_space, reset, step, seed ...)

Note:
``RemoteGymEnv`` defines a remote environment wrapper for machines that do not support mujoco,
support both Continuous action space and Discrete action space environments
``RemoteGymEnv`` defines a remote environment wrapper for running the environment remotely and
enables large-scale parallel collection of environmental data.

Support both Continuous action space and Discrete action space environments.

"""

Expand Down Expand Up @@ -71,7 +73,10 @@ def __init__(self, observation_space, low, high, shape=None):

self.env = gym.make(env_name)
self._max_episode_steps = int(self.env._max_episode_steps)
self._elapsed_steps = int(self.env._elapsed_steps)
try:
self._elapsed_steps = int(self.env._elapsed_steps)
except:
logger.info('object has no attribute _elspaed_steps')
rical730 marked this conversation as resolved.
Show resolved Hide resolved

self.observation_space = ObservationSpace(
self.env.observation_space, self.env.observation_space.low,
Expand Down
33 changes: 21 additions & 12 deletions parl/utils/remote_gym_env_wrapper_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,27 +20,35 @@
from parl.remote.master import Master
from parl.remote.worker import Worker
from parl.remote.client import disconnect
from parl.utils import logger
from parl.utils import logger, get_free_tcp_port
from env_utils import RemoteGymEnv
import gym
from gym.spaces import Box, Discrete


def float_equal(x1, x2):
if np.abs(x1 - x2) < 1e-6:
return True
else:
return False


# Test RemoteGymEnv
# for both discrete and continuous action space
# for both discrete and continuous action space environment
class TestRemoteEnv(unittest.TestCase):
def tearDown(self):
disconnect()

def test_discrete_env_wrapper(self):
logger.info("Running: test discrete_env_wrapper")
master = Master(port=8267)
port = get_free_tcp_port()
rical730 marked this conversation as resolved.
Show resolved Hide resolved
master = Master(port=port)
th = threading.Thread(target=master.run)
th.start()
time.sleep(3)
woker1 = Worker('localhost:8267', 1)
woker1 = Worker('localhost:{}'.format(port), 1)

parl.connect('localhost:8267')
parl.connect('localhost:{}'.format(port))
logger.info("Running: test discrete_env_wrapper: 1")

env = RemoteGymEnv(env_name='MountainCar-v0')
Expand All @@ -51,8 +59,8 @@ def test_discrete_env_wrapper(self):
observation_space = env.observation_space
obs_space_high = observation_space.high
obs_space_low = observation_space.low
self.assertEqual(obs_space_high[0], 0.6)
self.assertEqual(obs_space_low[0], -1.2)
self.assertTrue(float_equal(obs_space_high[1], 0.07))
self.assertTrue(float_equal(obs_space_low[0], -1.2))

action_space = env.action_space
act_dim = action_space.n
Expand All @@ -72,13 +80,14 @@ def test_discrete_env_wrapper(self):

def test_continuous_env_wrapper(self):
logger.info("Running: test continuous_env_wrapper")
master = Master(port=8268)
port = get_free_tcp_port()
master = Master(port=port)
th = threading.Thread(target=master.run)
th.start()
time.sleep(3)
woker1 = Worker('localhost:8268', 1)
woker1 = Worker('localhost:{}'.format(port), 1)

parl.connect('localhost:8268')
parl.connect('localhost:{}'.format(port))
logger.info("Running: test continuous_env_wrapper: 1")

env = RemoteGymEnv(env_name='Pendulum-v0')
Expand All @@ -89,8 +98,8 @@ def test_continuous_env_wrapper(self):
observation_space = env.observation_space
obs_space_high = observation_space.high
obs_space_low = observation_space.low
self.assertEqual(obs_space_high[1], 1.)
self.assertEqual(obs_space_low[1], -1.)
self.assertTrue(float_equal(obs_space_high[1], 1.))
self.assertTrue(float_equal(obs_space_low[1], -1.))

action_space = env.action_space
action_space_high = action_space.high
Expand Down