# NOTE: the patch metadata that shared these physical lines is preserved
# below as comments so that no information from the original diff is lost.
#
# diff --git a/parl/utils/__init__.py b/parl/utils/__init__.py
# index 4e3e09c6f..27aaeff71 100644
# --- a/parl/utils/__init__.py
# +++ b/parl/utils/__init__.py
# @@ -21,3 +21,4 @@
#  from parl.utils.rl_utils import *
#  from parl.utils.scheduler import *
#  from parl.utils.path_utils import *
# +from parl.utils.env_utils import *
#
# diff --git a/parl/utils/env_utils.py b/parl/utils/env_utils.py
# new file mode 100644
# index 000000000..6bc1ccf82
# --- /dev/null
# +++ b/parl/utils/env_utils.py
# @@ -0,0 +1,108 @@

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from parl.utils import logger
from parl.remote.remote_decorator import remote_class
try:
    import gym
    gym_installed = True
except ImportError:
    gym_installed = False
if gym_installed:
    from gym.spaces import Box, Discrete

__all__ = ['RemoteGymEnv']


@remote_class
class RemoteGymEnv(object):
    """Wrapper that creates a gym environment and forwards calls to it.

    The class is decorated with ``remote_class``; the wrapped environment is
    created with ``gym.make(env_name)`` inside ``__init__``.

    Example:
        .. code-block:: python

            import parl
            from parl.utils import RemoteGymEnv

            parl.connect('localhost')
            env = RemoteGymEnv(env_name='HalfCheetah-v1')

    Attributes:
        env: the environment returned by ``gym.make(env_name)``.
        observation_space: lightweight object exposing ``observation_space``
            (the wrapped gym space), ``low``, ``high`` and ``shape``.
        action_space: lightweight object exposing ``action_space`` (the
            wrapped gym space), ``low``, ``high``, ``shape``, ``n`` and a
            ``sample()`` method.

    Public Functions: the same as gym (mainly action_space,
    observation_space, reset, step, seed ...)

    Note:
        ``RemoteGymEnv`` defines a remote environment wrapper for running the
        environment remotely and enables large-scale parallel collection of
        environmental data.

        Support both Continuous action space and Discrete action space
        environments.
    """

    def __init__(self, env_name=None):
        """Create the wrapped gym environment and mirror its spaces.

        Args:
            env_name (str): id passed to ``gym.make``.

        Raises:
            AssertionError: if ``env_name`` is not a ``str``.
            ImportError: if ``gym`` could not be imported.
        """
        assert isinstance(env_name, str), 'env_name must be a str'
        if not gym_installed:
            # Fail with an actionable message instead of a NameError on `gym`.
            raise ImportError(
                'gym is required by RemoteGymEnv but it is not installed.')

        class ActionSpace(object):
            """Plain container mirroring the fields of a gym action space."""

            def __init__(self,
                         action_space=None,
                         low=None,
                         high=None,
                         shape=None,
                         n=None):
                self.action_space = action_space
                self.low = low
                self.high = high
                self.shape = shape
                self.n = n

            def sample(self):
                """Draw a random action from the wrapped gym space."""
                return self.action_space.sample()

        class ObservationSpace(object):
            """Plain container mirroring a gym observation space."""

            def __init__(self, observation_space, low, high, shape=None):
                self.observation_space = observation_space
                self.low = low
                self.high = high
                self.shape = shape

        self.env = gym.make(env_name)
        self._max_episode_steps = int(self.env._max_episode_steps)
        try:
            self._elapsed_steps = int(self.env._elapsed_steps)
        except (AttributeError, TypeError, ValueError):
            # Attribute missing or not convertible to int: report and go on,
            # leaving ``_elapsed_steps`` unset as before.
            logger.error('object has no attribute _elapsed_steps')

        obs_space = self.env.observation_space
        # getattr keeps construction working for observation spaces that do
        # not define bounds (the original raised AttributeError there).
        self.observation_space = ObservationSpace(
            obs_space, getattr(obs_space, 'low', None),
            getattr(obs_space, 'high', None),
            getattr(obs_space, 'shape', None))

        act_space = self.env.action_space
        if isinstance(act_space, Discrete):
            # Pass the wrapped space as well, otherwise sample() would be
            # called on None.
            self.action_space = ActionSpace(
                action_space=act_space, n=act_space.n)
        elif isinstance(act_space, Box):
            self.action_space = ActionSpace(act_space, act_space.low,
                                            act_space.high, act_space.shape)
        else:
            # Other space types: expose whatever fields the space provides so
            # that ``action_space`` is always defined and sample() works.
            self.action_space = ActionSpace(
                action_space=act_space,
                low=getattr(act_space, 'low', None),
                high=getattr(act_space, 'high', None),
                shape=getattr(act_space, 'shape', None),
                n=getattr(act_space, 'n', None))

    def reset(self):
        """Reset the wrapped environment and return its result."""
        return self.env.reset()

    def step(self, action):
        """Apply ``action`` to the wrapped environment and return its result."""
        return self.env.step(action)

    def seed(self, seed):
        """Seed the wrapped environment and return its result."""
        return self.env.seed(seed)

    def render(self):
        """Log a warning instead of rendering; the environment is not drawn."""
        return logger.warning(
            'Can not render in remote environment, render() have been skipped.'
        )


# diff --git a/parl/utils/tests/remote_gym_env_wrapper_test.py b/parl/utils/tests/remote_gym_env_wrapper_test.py
# new file mode 100644
# index 000000000..5dbd3ae37
# --- /dev/null
# +++ b/parl/utils/tests/remote_gym_env_wrapper_test.py
# @@ -0,0 +1,123 @@

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import threading
import time
import parl
import numpy as np
from parl.remote.master import Master
from parl.remote.worker import Worker
from parl.remote.client import disconnect
from parl.utils import logger, get_free_tcp_port
from parl.utils.env_utils import RemoteGymEnv


def float_equal(x1, x2):
    """Return True when ``x1`` and ``x2`` differ by less than 1e-6."""
    return bool(np.abs(x1 - x2) < 1e-6)


class TestRemoteEnv(unittest.TestCase):
    """Test RemoteGymEnv for both discrete and continuous action space
    environments."""

    def tearDown(self):
        disconnect()

    def _start_cluster(self):
        """Start a master on a free port plus one worker attached to it.

        Returns:
            tuple: ``(master, worker, address)`` where ``address`` is the
            ``'localhost:<port>'`` string to pass to ``parl.connect``.
        """
        port = get_free_tcp_port()
        address = 'localhost:{}'.format(port)
        master = Master(port=port)
        th = threading.Thread(target=master.run)
        th.start()
        # presumably gives the master time to come up before the worker
        # connects -- TODO confirm
        time.sleep(3)
        try:
            worker = Worker(address, 1)
        except Exception:
            # Do not leave the master thread running if the worker fails.
            master.exit()
            raise
        return master, worker, address

    def test_discrete_env_wrapper(self):
        logger.info("Running: test discrete_env_wrapper")
        master, worker, address = self._start_cluster()
        try:
            parl.connect(address)
            logger.info("Running: test discrete_env_wrapper: 1")

            env = RemoteGymEnv(env_name='MountainCar-v0')
            env.seed(1)
            env.render()

            obs, done = env.reset(), False
            observation_space = env.observation_space
            obs_space_high = observation_space.high
            obs_space_low = observation_space.low
            self.assertTrue(float_equal(obs_space_high[1], 0.07))
            self.assertTrue(float_equal(obs_space_low[0], -1.2))

            action_space = env.action_space
            act_dim = action_space.n
            self.assertEqual(act_dim, 3)

            # Run an episode with a random policy
            total_steps, episode_reward = 0, 0
            while not done:
                # The original never incremented this counter, so the log
                # line below always reported 0 steps.
                total_steps += 1
                action = np.random.choice(act_dim)
                next_obs, reward, done, _ = env.step(action)
                episode_reward += reward
            logger.info(
                'Episode done, total_steps {}, episode_reward {}'.format(
                    total_steps, episode_reward))
        finally:
            # Always shut the cluster down, even when an assertion fails.
            master.exit()
            worker.exit()

    def test_continuous_env_wrapper(self):
        logger.info("Running: test continuous_env_wrapper")
        master, worker, address = self._start_cluster()
        try:
            parl.connect(address)
            logger.info("Running: test continuous_env_wrapper: 1")

            env = RemoteGymEnv(env_name='Pendulum-v0')
            env.seed(0)
            env.render()

            obs, done = env.reset(), False
            observation_space = env.observation_space
            obs_space_high = observation_space.high
            obs_space_low = observation_space.low
            self.assertTrue(float_equal(obs_space_high[1], 1.))
            self.assertTrue(float_equal(obs_space_low[1], -1.))

            action_space = env.action_space
            action_space_high = action_space.high
            action_space_low = action_space.low
            # Element-wise comparison; assertEqual on an ndarray only works
            # by accident when the array has a single element.
            np.testing.assert_allclose(action_space_high, [2.])
            np.testing.assert_allclose(action_space_low, [-2.])

            # Run an episode with a random policy
            total_steps, episode_reward = 0, 0
            while not done:
                total_steps += 1
                action = env.action_space.sample()
                next_obs, reward, done, info = env.step(action)
                episode_reward += reward
            logger.info(
                'Episode done, total_steps {}, episode_reward {}'.format(
                    total_steps, episode_reward))
        finally:
            # Always shut the cluster down, even when an assertion fails.
            master.exit()
            worker.exit()


if __name__ == '__main__':
    unittest.main()