From b5a674fb75633e8110a3883c23908e1416165d71 Mon Sep 17 00:00:00 2001 From: ShuaibinLi Date: Mon, 28 Dec 2020 10:35:53 +0800 Subject: [PATCH 01/16] remote gym env wrapper --- parl/utils/__init__.py | 1 + parl/utils/remote_gym_env_wrapper.py | 92 +++++++++++++++++++ .../tests/remote_gym_env_wrapper_test.py | 57 ++++++++++++ 3 files changed, 150 insertions(+) create mode 100644 parl/utils/remote_gym_env_wrapper.py create mode 100644 parl/utils/tests/remote_gym_env_wrapper_test.py diff --git a/parl/utils/__init__.py b/parl/utils/__init__.py index 4e3e09c6f..f2e2a4750 100644 --- a/parl/utils/__init__.py +++ b/parl/utils/__init__.py @@ -21,3 +21,4 @@ from parl.utils.rl_utils import * from parl.utils.scheduler import * from parl.utils.path_utils import * +from parl.utils.remote_gym_env_wrapper import * diff --git a/parl/utils/remote_gym_env_wrapper.py b/parl/utils/remote_gym_env_wrapper.py new file mode 100644 index 000000000..f20428b63 --- /dev/null +++ b/parl/utils/remote_gym_env_wrapper.py @@ -0,0 +1,92 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gym +import parl +from parl.utils import logger +from gym.spaces import Box, Discrete + +__all__ = ['RemoteGymEnv'] + + +@parl.remote_class +class RemoteGymEnv(object): + """ + + Example: + .. code-block:: python + import parl + from parl.utils import RemoteEnv + + parl.connect('localhost') + env = RemoteGymEnv(env_name='HalfCheetah-v1') + + Attributes: + env_name: Mujoco gym environment name + + Public Functions: the same as gym (mainly action_space, observation_space, reset, step, seed ...) + + Note: + ``RemoteGymEnv`` defines a remote environment wrapper for machines that do not support mujoco, + support both Continuous action space and Discrete action space environments + + """ + def __init__(self, env_name=None): + assert isinstance(env_name, str) + + class ActionSpace(object): + def __init__(self, action_space=None, low=None, high=None, shape=None, n=None): + self.action_space = action_space + self.low = low + self.high = high + self.shape = shape + self.n = n + + def sample(self): + return self.action_space.sample() + + class ObservationSpace(object): + def __init__(self, observation_space, low, high, shape=None): + self.observation_space = observation_space + self.low = low + self.high = high + self.shape = shape + + self.env = gym.make(env_name) + self._max_episode_steps = int(self.env._max_episode_steps) + self._elapsed_steps = int(self.env._elapsed_steps) + + self.observation_space = ObservationSpace(self.env.observation_space, + self.env.observation_space.low, + self.env.observation_space.high, + self.env.observation_space.shape) + if isinstance(self.env.action_space, Discrete): + self.action_space = ActionSpace(n=self.env.action_space.n) + elif isinstance(self.env.action_space, Box): + self.action_space = ActionSpace(self.env.action_space, + self.env.action_space.low, + self.env.action_space.high, + self.env.action_space.shape) + + def reset(self): + return self.env.reset() + + def step(self, action): + return self.env.step(action) + + def seed(self, seed): + return self.env.seed(seed) + + def render(self): + return logger.warning('Using remote env, no need to render') diff --git a/parl/utils/tests/remote_gym_env_wrapper_test.py b/parl/utils/tests/remote_gym_env_wrapper_test.py new file mode 100644 index 000000000..1eecbe590 --- /dev/null +++ b/parl/utils/tests/remote_gym_env_wrapper_test.py @@ -0,0 +1,57 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import parl +import numpy as np +from parl.utils import logger, RemoteGymEnv + + +# Example 1, Continuous action space +def main(): + """ + Get your localhost: + run "xparl start --port ****" on env server + """ + parl.connect('localhost') + env = RemoteGymEnv(env_name='HalfCheetah-v1') + + # Run an episode with a random policy + obs, done = env.reset(), False + total_steps, episode_reward = 0, 0 + while not done: + total_steps += 1 + action = env.action_space.sample() + next_obs, reward, done, info = env.step(action) + episode_reward += reward + logger.info('Episode done, total_steps {}, episode_reward {}'.format(total_steps, episode_reward)) + + +# # Example 2, Discrete action space +# def main(): +# parl.connect('localhost') +# env = RemoteGymEnv(env_name='MountainCar-v0') +# +# # Run an episode with a random policy +# obs, done = env.reset(), False +# total_steps, episode_reward = 0, 0 +# while not done: +# total_steps += 1 +# action = np.random.choice(env.action_space.n) +# next_obs, reward, done, info = env.step(action) +# episode_reward += reward +# logger.info('Episode done, total_steps {}, episode_reward {}'.format(total_steps, episode_reward)) + + +if __name__ == '__main__': + main() From 72bde1ee12fd5174cf25b7b63d783569b1a9d758 Mon Sep 17 00:00:00 2001 From: ShuaibinLi Date: Mon, 28 Dec 2020 10:37:54 +0800 Subject: [PATCH 02/16] yapf .py --- parl/utils/remote_gym_env_wrapper.py | 22 +++++++++++-------- .../tests/remote_gym_env_wrapper_test.py | 4 ++-- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/parl/utils/remote_gym_env_wrapper.py b/parl/utils/remote_gym_env_wrapper.py index f20428b63..f97f592b3 100644 --- a/parl/utils/remote_gym_env_wrapper.py +++ b/parl/utils/remote_gym_env_wrapper.py @@ -42,11 +42,17 @@ class RemoteGymEnv(object): support both Continuous action space and Discrete action space environments """ + def __init__(self, env_name=None): assert isinstance(env_name, str) class ActionSpace(object): - def __init__(self, action_space=None, low=None, high=None, shape=None, n=None): + def __init__(self, + action_space=None, + low=None, + high=None, + shape=None, + n=None): self.action_space = action_space self.low = low self.high = high @@ -67,17 +73,15 @@ def __init__(self, observation_space, low, high, shape=None): self._max_episode_steps = int(self.env._max_episode_steps) self._elapsed_steps = int(self.env._elapsed_steps) - self.observation_space = ObservationSpace(self.env.observation_space, - self.env.observation_space.low, - self.env.observation_space.high, - self.env.observation_space.shape) + self.observation_space = ObservationSpace( + self.env.observation_space, self.env.observation_space.low, + self.env.observation_space.high, self.env.observation_space.shape) if isinstance(self.env.action_space, Discrete): self.action_space = ActionSpace(n=self.env.action_space.n) elif isinstance(self.env.action_space, Box): - self.action_space = ActionSpace(self.env.action_space, - self.env.action_space.low, - self.env.action_space.high, - self.env.action_space.shape) + self.action_space = ActionSpace( + self.env.action_space, self.env.action_space.low, + self.env.action_space.high, self.env.action_space.shape) def reset(self): return self.env.reset() diff --git a/parl/utils/tests/remote_gym_env_wrapper_test.py b/parl/utils/tests/remote_gym_env_wrapper_test.py index 1eecbe590..9a420de31 100644 --- a/parl/utils/tests/remote_gym_env_wrapper_test.py +++ b/parl/utils/tests/remote_gym_env_wrapper_test.py @@ -34,7 +34,8 @@ def main(): action = env.action_space.sample() next_obs, reward, done, info = env.step(action) episode_reward += reward - logger.info('Episode done, total_steps {}, episode_reward {}'.format(total_steps, episode_reward)) + logger.info('Episode done, total_steps {}, episode_reward {}'.format( + total_steps, episode_reward)) # # Example 2, Discrete action space @@ -52,6 +53,5 @@ def main(): # episode_reward += reward # logger.info('Episode done, total_steps {}, episode_reward {}'.format(total_steps, episode_reward)) - if __name__ == '__main__': main() From 58982082d87f9c8d3dd84d49c924c15a7673376f Mon Sep 17 00:00:00 2001 From: ShuaibinLi Date: Mon, 28 Dec 2020 15:55:03 +0800 Subject: [PATCH 03/16] modify test --- .../tests/remote_gym_env_wrapper_test.py | 189 ++++++++++++++---- 1 file changed, 150 insertions(+), 39 deletions(-) diff --git a/parl/utils/tests/remote_gym_env_wrapper_test.py b/parl/utils/tests/remote_gym_env_wrapper_test.py index 9a420de31..1e8c854a9 100644 --- a/parl/utils/tests/remote_gym_env_wrapper_test.py +++ b/parl/utils/tests/remote_gym_env_wrapper_test.py @@ -12,46 +12,157 @@ # See the License for the specific language governing permissions and # limitations under the License. +import unittest +import threading +import time import parl import numpy as np -from parl.utils import logger, RemoteGymEnv - - -# Example 1, Continuous action space -def main(): - """ - Get your localhost: - run "xparl start --port ****" on env server - """ - parl.connect('localhost') - env = RemoteGymEnv(env_name='HalfCheetah-v1') - - # Run an episode with a random policy - obs, done = env.reset(), False - total_steps, episode_reward = 0, 0 - while not done: - total_steps += 1 - action = env.action_space.sample() - next_obs, reward, done, info = env.step(action) - episode_reward += reward - logger.info('Episode done, total_steps {}, episode_reward {}'.format( - total_steps, episode_reward)) - - -# # Example 2, Discrete action space -# def main(): -# parl.connect('localhost') -# env = RemoteGymEnv(env_name='MountainCar-v0') -# -# # Run an episode with a random policy -# obs, done = env.reset(), False -# total_steps, episode_reward = 0, 0 -# while not done: -# total_steps += 1 -# action = np.random.choice(env.action_space.n) -# next_obs, reward, done, info = env.step(action) -# episode_reward += reward -# logger.info('Episode done, total_steps {}, episode_reward {}'.format(total_steps, episode_reward)) +from parl.remote.master import Master +from parl.remote.worker import Worker +from parl.remote.client import disconnect +from parl.utils import logger +import gym +from gym.spaces import Box, Discrete + + +@parl.remote_class +class RemoteGymEnv(object): + def __init__(self, env_name=None): + assert isinstance(env_name, str) + + class ActionSpace(object): + def __init__(self, + action_space=None, + low=None, + high=None, + shape=None, + n=None): + self.action_space = action_space + self.low = low + self.high = high + self.shape = shape + self.n = n + + def sample(self): + return self.action_space.sample() + + class ObservationSpace(object): + def __init__(self, observation_space, low, high, shape=None): + self.observation_space = observation_space + self.low = low + self.high = high + self.shape = shape + + self.env = gym.make(env_name) + self._max_episode_steps = int(self.env._max_episode_steps) + self._elapsed_steps = int(self.env._elapsed_steps) + + self.observation_space = ObservationSpace( + self.env.observation_space, self.env.observation_space.low, + self.env.observation_space.high, self.env.observation_space.shape) + if isinstance(self.env.action_space, Discrete): + self.action_space = ActionSpace(n=self.env.action_space.n) + elif isinstance(self.env.action_space, Box): + self.action_space = ActionSpace( + self.env.action_space, self.env.action_space.low, + self.env.action_space.high, self.env.action_space.shape) + + def reset(self): + return self.env.reset() + + def step(self, action): + return self.env.step(action) + + def seed(self, seed): + return self.env.seed(seed) + + def render(self): + return logger.warning('Using remote env, no need to render') + + +class TestRemoteEnv(unittest.TestCase): + def tearDown(self): + disconnect() + + def test_discrete_env_wrapper(self): + logger.info("Running: test discrete_env_wrapper") + master = Master(port=8267) + th = threading.Thread(target=master.run) + th.start() + time.sleep(3) + woker1 = Worker('localhost:8267', 1) + + parl.connect('localhost:8267') + logger.info("Running: test discrete_env_wrapper: 1") + + env = RemoteGymEnv(env_name='MountainCar-v0') + env.seed(1) + env.render() + + obs, done = env.reset(), False + observation_space = env.observation_space + obs_space_high = observation_space.high + obs_space_low = observation_space.low + self.assertEqual(obs_space_high[0], 0.6) + self.assertEqual(obs_space_low[0], -1.2) + + action_space = env.action_space + act_dim = action_space.n + self.assertEqual(act_dim, 3) + + # Run an episode with a random policy + total_steps, episode_reward = 0, 0 + while not done: + action = np.random.choice(act_dim) + next_obs, reward, done, _ = env.step(action) + episode_reward += reward + logger.info('Episode done, total_steps {}, episode_reward {}'.format( + total_steps, episode_reward)) + + master.exit() + woker1.exit() + + def test_continuous_env_wrapper(self): + logger.info("Running: test continuous_env_wrapper") + master = Master(port=8268) + th = threading.Thread(target=master.run) + th.start() + time.sleep(3) + woker1 = Worker('localhost:8268', 1) + + parl.connect('localhost:8268') + logger.info("Running: test continuous_env_wrapper: 1") + + env = RemoteGymEnv(env_name='Pendulum-v0') + env.seed(0) + env.render() + + obs, done = env.reset(), False + observation_space = env.observation_space + obs_space_high = observation_space.high + obs_space_low = observation_space.low + self.assertEqual(obs_space_high[1], 1.) + self.assertEqual(obs_space_low[1], -1.) + + action_space = env.action_space + action_space_high = action_space.high + action_space_low = action_space.low + self.assertEqual(action_space_high, [2.]) + self.assertEqual(action_space_low, [-2.]) + + # Run an episode with a random policy + total_steps, episode_reward = 0, 0 + while not done: + total_steps += 1 + action = env.action_space.sample() + next_obs, reward, done, info = env.step(action) + episode_reward += reward + logger.info('Episode done, total_steps {}, episode_reward {}'.format( + total_steps, episode_reward)) + + master.exit() + woker1.exit() + if __name__ == '__main__': - main() + unittest.main() From 0bf0f070877fc69ae52b0c6f0597add37dd7acdd Mon Sep 17 00:00:00 2001 From: ShuaibinLi Date: Tue, 29 Dec 2020 10:58:55 +0800 Subject: [PATCH 04/16] modify env_utils --- parl/utils/__init__.py | 2 +- ...remote_gym_env_wrapper.py => env_utils.py} | 4 +- .../tests/remote_gym_env_wrapper_test.py | 59 +------------------ 3 files changed, 7 insertions(+), 58 deletions(-) rename parl/utils/{remote_gym_env_wrapper.py => env_utils.py} (96%) diff --git a/parl/utils/__init__.py b/parl/utils/__init__.py index f2e2a4750..27aaeff71 100644 --- a/parl/utils/__init__.py +++ b/parl/utils/__init__.py @@ -21,4 +21,4 @@ from parl.utils.rl_utils import * from parl.utils.scheduler import * from parl.utils.path_utils import * -from parl.utils.remote_gym_env_wrapper import * +from parl.utils.env_utils import * diff --git a/parl/utils/remote_gym_env_wrapper.py b/parl/utils/env_utils.py similarity index 96% rename from parl/utils/remote_gym_env_wrapper.py rename to parl/utils/env_utils.py index f97f592b3..0a6e795d3 100644 --- a/parl/utils/remote_gym_env_wrapper.py +++ b/parl/utils/env_utils.py @@ -93,4 +93,6 @@ def seed(self, seed): return self.env.seed(seed) def render(self): - return logger.warning('Using remote env, no need to render') + return logger.warning( + 'Can not render in remote environment, render() have been skipped.' + ) diff --git a/parl/utils/tests/remote_gym_env_wrapper_test.py b/parl/utils/tests/remote_gym_env_wrapper_test.py index 1e8c854a9..e00f1060d 100644 --- a/parl/utils/tests/remote_gym_env_wrapper_test.py +++ b/parl/utils/tests/remote_gym_env_wrapper_test.py @@ -20,66 +20,13 @@ from parl.remote.master import Master from parl.remote.worker import Worker from parl.remote.client import disconnect -from parl.utils import logger +from parl.utils import logger, RemoteGymEnv import gym from gym.spaces import Box, Discrete -@parl.remote_class -class RemoteGymEnv(object): - def __init__(self, env_name=None): - assert isinstance(env_name, str) - - class ActionSpace(object): - def __init__(self, - action_space=None, - low=None, - high=None, - shape=None, - n=None): - self.action_space = action_space - self.low = low - self.high = high - self.shape = shape - self.n = n - - def sample(self): - return self.action_space.sample() - - class ObservationSpace(object): - def __init__(self, observation_space, low, high, shape=None): - self.observation_space = observation_space - self.low = low - self.high = high - self.shape = shape - - self.env = gym.make(env_name) - self._max_episode_steps = int(self.env._max_episode_steps) - self._elapsed_steps = int(self.env._elapsed_steps) - - self.observation_space = ObservationSpace( - self.env.observation_space, self.env.observation_space.low, - self.env.observation_space.high, self.env.observation_space.shape) - if isinstance(self.env.action_space, Discrete): - self.action_space = ActionSpace(n=self.env.action_space.n) - elif isinstance(self.env.action_space, Box): - self.action_space = ActionSpace( - self.env.action_space, self.env.action_space.low, - self.env.action_space.high, self.env.action_space.shape) - - def reset(self): - return self.env.reset() - - def step(self, action): - return self.env.step(action) - - def seed(self, seed): - return self.env.seed(seed) - - def render(self): - return logger.warning('Using remote env, no need to render') - - +# Test RemoteGymEnv +# for both discrete and continuous action space class TestRemoteEnv(unittest.TestCase): def tearDown(self): disconnect() From c3cdb84b3398e15edbd566613da5382afcca49d6 Mon Sep 17 00:00:00 2001 From: ShuaibinLi Date: Tue, 29 Dec 2020 19:03:26 +0800 Subject: [PATCH 05/16] modify env_utils.py --- parl/utils/env_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/parl/utils/env_utils.py b/parl/utils/env_utils.py index 0a6e795d3..dc37d916d 100644 --- a/parl/utils/env_utils.py +++ b/parl/utils/env_utils.py @@ -13,14 +13,14 @@ # limitations under the License. import gym -import parl from parl.utils import logger +from parl.remote.remote_decorator import remote_class from gym.spaces import Box, Discrete __all__ = ['RemoteGymEnv'] -@parl.remote_class +@remote_class class RemoteGymEnv(object): """ From 4262f8ccdff6e330f0dcb8dd14521195586bbb11 Mon Sep 17 00:00:00 2001 From: ShuaibinLi Date: Wed, 30 Dec 2020 16:51:34 +0800 Subject: [PATCH 06/16] modify test --- parl/utils/{tests => }/remote_gym_env_wrapper_test.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) rename parl/utils/{tests => }/remote_gym_env_wrapper_test.py (96%) diff --git a/parl/utils/tests/remote_gym_env_wrapper_test.py b/parl/utils/remote_gym_env_wrapper_test.py similarity index 96% rename from parl/utils/tests/remote_gym_env_wrapper_test.py rename to parl/utils/remote_gym_env_wrapper_test.py index e00f1060d..ad9aeaa87 100644 --- a/parl/utils/tests/remote_gym_env_wrapper_test.py +++ b/parl/utils/remote_gym_env_wrapper_test.py @@ -20,7 +20,8 @@ from parl.remote.master import Master from parl.remote.worker import Worker from parl.remote.client import disconnect -from parl.utils import logger, RemoteGymEnv +from parl.utils import logger +from env_utils import RemoteGymEnv import gym from gym.spaces import Box, Discrete @@ -39,7 +40,7 @@ def test_discrete_env_wrapper(self): time.sleep(3) woker1 = Worker('localhost:8267', 1) - parl.connect('localhost:8267') + parl.connect('localhost:8267', distributed_files=['']) logger.info("Running: test discrete_env_wrapper: 1") env = RemoteGymEnv(env_name='MountainCar-v0') From e2f30e9131875db89bc6ecd39295a71016891970 Mon Sep 17 00:00:00 2001 From: ShuaibinLi Date: Wed, 30 Dec 2020 17:52:20 +0800 Subject: [PATCH 07/16] parl.connect, del file --- parl/utils/remote_gym_env_wrapper_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/parl/utils/remote_gym_env_wrapper_test.py b/parl/utils/remote_gym_env_wrapper_test.py index ad9aeaa87..0f3f9b128 100644 --- a/parl/utils/remote_gym_env_wrapper_test.py +++ b/parl/utils/remote_gym_env_wrapper_test.py @@ -40,7 +40,7 @@ def test_discrete_env_wrapper(self): time.sleep(3) woker1 = Worker('localhost:8267', 1) - parl.connect('localhost:8267', distributed_files=['']) + parl.connect('localhost:8267') logger.info("Running: test discrete_env_wrapper: 1") env = RemoteGymEnv(env_name='MountainCar-v0') From 7f163abf311cfc81fe4058a224401e7ea147be09 Mon Sep 17 00:00:00 2001 From: ShuaibinLi Date: Thu, 31 Dec 2020 10:03:54 +0800 Subject: [PATCH 08/16] try: env._elapsed_steps --- parl/utils/env_utils.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/parl/utils/env_utils.py b/parl/utils/env_utils.py index dc37d916d..b0fcd1f13 100644 --- a/parl/utils/env_utils.py +++ b/parl/utils/env_utils.py @@ -33,13 +33,15 @@ class RemoteGymEnv(object): env = RemoteGymEnv(env_name='HalfCheetah-v1') Attributes: - env_name: Mujoco gym environment name + env_name: gym environment name Public Functions: the same as gym (mainly action_space, observation_space, reset, step, seed ...) Note: - ``RemoteGymEnv`` defines a remote environment wrapper for machines that do not support mujoco, - support both Continuous action space and Discrete action space environments + ``RemoteGymEnv`` defines a remote environment wrapper for running the environment remotely and + enables large-scale parallel collection of environmental data. + + Support both Continuous action space and Discrete action space environments. """ @@ -71,7 +73,8 @@ def __init__(self, observation_space, low, high, shape=None): self.env = gym.make(env_name) self._max_episode_steps = int(self.env._max_episode_steps) - self._elapsed_steps = int(self.env._elapsed_steps) + try: + self._elapsed_steps = int(self.env._elapsed_steps) self.observation_space = ObservationSpace( self.env.observation_space, self.env.observation_space.low, From 5018ec5145a03752edc67bfd1a123c209f4fe099 Mon Sep 17 00:00:00 2001 From: ShuaibinLi Date: Thu, 31 Dec 2020 10:38:11 +0800 Subject: [PATCH 09/16] try except _elspaed_steps --- parl/utils/env_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/parl/utils/env_utils.py b/parl/utils/env_utils.py index b0fcd1f13..22a0759c0 100644 --- a/parl/utils/env_utils.py +++ b/parl/utils/env_utils.py @@ -75,6 +75,8 @@ def __init__(self, observation_space, low, high, shape=None): self._max_episode_steps = int(self.env._max_episode_steps) try: self._elapsed_steps = int(self.env._elapsed_steps) + except: + logger.info('object has no attribute _elspaed_steps') self.observation_space = ObservationSpace( self.env.observation_space, self.env.observation_space.low, From b859687ae30b56cb5633ab14c25c0a262078bed1 Mon Sep 17 00:00:00 2001 From: ShuaibinLi Date: Thu, 31 Dec 2020 14:40:58 +0800 Subject: [PATCH 10/16] obs_space.high --- parl/utils/remote_gym_env_wrapper_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/parl/utils/remote_gym_env_wrapper_test.py b/parl/utils/remote_gym_env_wrapper_test.py index 0f3f9b128..6915b95c7 100644 --- a/parl/utils/remote_gym_env_wrapper_test.py +++ b/parl/utils/remote_gym_env_wrapper_test.py @@ -51,7 +51,7 @@ def test_discrete_env_wrapper(self): observation_space = env.observation_space obs_space_high = observation_space.high obs_space_low = observation_space.low - self.assertEqual(obs_space_high[0], 0.6) + self.assertEqual(obs_space_high[1], 0.07) self.assertEqual(obs_space_low[0], -1.2) action_space = env.action_space From ac93617c0f1eb6a2881bddbdaea309bc70abf79e Mon Sep 17 00:00:00 2001 From: ShuaibinLi Date: Thu, 31 Dec 2020 15:21:08 +0800 Subject: [PATCH 11/16] float equal --- .../{ => tests}/remote_gym_env_wrapper_test.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) rename parl/utils/{ => tests}/remote_gym_env_wrapper_test.py (90%) diff --git a/parl/utils/remote_gym_env_wrapper_test.py b/parl/utils/tests/remote_gym_env_wrapper_test.py similarity index 90% rename from parl/utils/remote_gym_env_wrapper_test.py rename to parl/utils/tests/remote_gym_env_wrapper_test.py index 6915b95c7..98f424edf 100644 --- a/parl/utils/remote_gym_env_wrapper_test.py +++ b/parl/utils/tests/remote_gym_env_wrapper_test.py @@ -21,11 +21,18 @@ from parl.remote.worker import Worker from parl.remote.client import disconnect from parl.utils import logger -from env_utils import RemoteGymEnv +from parl.utils.env_utils import * import gym from gym.spaces import Box, Discrete +def float_equal(x1, x2): + if np.abs(x1 - x2) < 1e-6: + return True + else: + return False + + # Test RemoteGymEnv # for both discrete and continuous action space class TestRemoteEnv(unittest.TestCase): @@ -51,8 +58,8 @@ def test_discrete_env_wrapper(self): observation_space = env.observation_space obs_space_high = observation_space.high obs_space_low = observation_space.low - self.assertEqual(obs_space_high[1], 0.07) - self.assertEqual(obs_space_low[0], -1.2) + self.assertTrue(float_equal(obs_space_high[1], 0.07)) + self.assertTrue(float_equal(obs_space_low[0], -1.2)) action_space = env.action_space act_dim = action_space.n @@ -89,8 +96,8 @@ def test_continuous_env_wrapper(self): observation_space = env.observation_space obs_space_high = observation_space.high obs_space_low = observation_space.low - self.assertEqual(obs_space_high[1], 1.) - self.assertEqual(obs_space_low[1], -1.) + self.assertTrue(float_equal(obs_space_high[1], 1.)) + self.assertTrue(float_equal(obs_space_low[1], -1.)) action_space = env.action_space action_space_high = action_space.high From c88dc39d3b41d1322da1bc679ffd29d3f56a0117 Mon Sep 17 00:00:00 2001 From: ShuaibinLi Date: Thu, 31 Dec 2020 16:20:43 +0800 Subject: [PATCH 12/16] change location of env_test --- parl/utils/{tests => }/remote_gym_env_wrapper_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename parl/utils/{tests => }/remote_gym_env_wrapper_test.py (99%) diff --git a/parl/utils/tests/remote_gym_env_wrapper_test.py b/parl/utils/remote_gym_env_wrapper_test.py similarity index 99% rename from parl/utils/tests/remote_gym_env_wrapper_test.py rename to parl/utils/remote_gym_env_wrapper_test.py index 98f424edf..edd129473 100644 --- a/parl/utils/tests/remote_gym_env_wrapper_test.py +++ b/parl/utils/remote_gym_env_wrapper_test.py @@ -21,7 +21,7 @@ from parl.remote.worker import Worker from parl.remote.client import disconnect from parl.utils import logger -from parl.utils.env_utils import * +from env_utils import RemoteGymEnv import gym from gym.spaces import Box, Discrete From 21bb3c8bff054db25317db4a64b2ed6ce759a09e Mon Sep 17 00:00:00 2001 From: ShuaibinLi Date: Thu, 31 Dec 2020 18:32:41 +0800 Subject: [PATCH 13/16] port = get_free_tcp_port() --- parl/utils/remote_gym_env_wrapper_test.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/parl/utils/remote_gym_env_wrapper_test.py b/parl/utils/remote_gym_env_wrapper_test.py index edd129473..b47b603d0 100644 --- a/parl/utils/remote_gym_env_wrapper_test.py +++ b/parl/utils/remote_gym_env_wrapper_test.py @@ -20,7 +20,7 @@ from parl.remote.master import Master from parl.remote.worker import Worker from parl.remote.client import disconnect -from parl.utils import logger +from parl.utils import logger, get_free_tcp_port from env_utils import RemoteGymEnv import gym from gym.spaces import Box, Discrete @@ -41,13 +41,14 @@ def tearDown(self): def test_discrete_env_wrapper(self): logger.info("Running: test discrete_env_wrapper") - master = Master(port=8267) + port = get_free_tcp_port() + master = Master(port=port) th = threading.Thread(target=master.run) th.start() time.sleep(3) - woker1 = Worker('localhost:8267', 1) + woker1 = Worker('localhost:{}'.format(port), 1) - parl.connect('localhost:8267') + parl.connect('localhost:{}'.format(port)) logger.info("Running: test discrete_env_wrapper: 1") env = RemoteGymEnv(env_name='MountainCar-v0') @@ -79,13 +80,14 @@ def test_discrete_env_wrapper(self): def test_continuous_env_wrapper(self): logger.info("Running: test continuous_env_wrapper") - master = Master(port=8268) + port = get_free_tcp_port() + master = Master(port=port) th = threading.Thread(target=master.run) th.start() time.sleep(3) - woker1 = Worker('localhost:8268', 1) + woker1 = Worker('localhost:{}'.format(port), 1) - parl.connect('localhost:8268') + parl.connect('localhost:{}'.format(port)) logger.info("Running: test continuous_env_wrapper: 1") env = RemoteGymEnv(env_name='Pendulum-v0') From 176e077c434beb7548cfaf8e79216e1ee8f1f4b1 Mon Sep 17 00:00:00 2001 From: ShuaibinLi Date: Mon, 4 Jan 2021 10:11:11 +0800 Subject: [PATCH 14/16] repush for unit test --- parl/utils/remote_gym_env_wrapper_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/parl/utils/remote_gym_env_wrapper_test.py b/parl/utils/remote_gym_env_wrapper_test.py index b47b603d0..7d337d83b 100644 --- a/parl/utils/remote_gym_env_wrapper_test.py +++ b/parl/utils/remote_gym_env_wrapper_test.py @@ -34,7 +34,7 @@ def float_equal(x1, x2): # Test RemoteGymEnv -# for both discrete and continuous action space +# for both discrete and continuous action space environment class TestRemoteEnv(unittest.TestCase): def tearDown(self): disconnect() From 830e57156a9beda55f04b997ab20b9926f7a88f1 Mon Sep 17 00:00:00 2001 From: ShuaibinLi Date: Mon, 4 Jan 2021 13:06:52 +0800 Subject: [PATCH 15/16] try: import gym --- parl/utils/env_utils.py | 12 +++++++++--- parl/utils/remote_gym_env_wrapper_test.py | 2 -- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/parl/utils/env_utils.py b/parl/utils/env_utils.py index 22a0759c0..457a8b5d0 100644 --- a/parl/utils/env_utils.py +++ b/parl/utils/env_utils.py @@ -12,10 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -import gym from parl.utils import logger from parl.remote.remote_decorator import remote_class -from gym.spaces import Box, Discrete +try: + import gym + gym_installed = True +except ImportError: + gym_installed = False + logger.error('ImportError: No module named gym') +if gym_installed: + from gym.spaces import Box, Discrete __all__ = ['RemoteGymEnv'] @@ -76,7 +82,7 @@ def __init__(self, observation_space, low, high, shape=None): try: self._elapsed_steps = int(self.env._elapsed_steps) except: - logger.info('object has no attribute _elspaed_steps') + logger.error('object has no attribute _elspaed_steps') self.observation_space = ObservationSpace( self.env.observation_space, self.env.observation_space.low, diff --git a/parl/utils/remote_gym_env_wrapper_test.py b/parl/utils/remote_gym_env_wrapper_test.py index 7d337d83b..8b491b0b9 100644 --- a/parl/utils/remote_gym_env_wrapper_test.py +++ b/parl/utils/remote_gym_env_wrapper_test.py @@ -22,8 +22,6 @@ from parl.remote.client import disconnect from parl.utils import logger, get_free_tcp_port from env_utils import RemoteGymEnv -import gym -from gym.spaces import Box, Discrete def float_equal(x1, x2): From 332395fb2911149164a56c369614b8b936952634 Mon Sep 17 00:00:00 2001 From: ShuaibinLi Date: Fri, 15 Jan 2021 19:10:30 +0800 Subject: [PATCH 16/16] del logger.error & move test file --- parl/utils/env_utils.py | 1 - parl/utils/{ => tests}/remote_gym_env_wrapper_test.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) rename parl/utils/{ => tests}/remote_gym_env_wrapper_test.py (98%) diff --git a/parl/utils/env_utils.py b/parl/utils/env_utils.py index 457a8b5d0..6bc1ccf82 100644 --- a/parl/utils/env_utils.py +++ b/parl/utils/env_utils.py @@ -19,7 +19,6 @@ gym_installed = True except ImportError: gym_installed = False - logger.error('ImportError: No module named gym') if gym_installed: from gym.spaces import Box, Discrete diff --git a/parl/utils/remote_gym_env_wrapper_test.py b/parl/utils/tests/remote_gym_env_wrapper_test.py similarity index 98% rename from parl/utils/remote_gym_env_wrapper_test.py rename to parl/utils/tests/remote_gym_env_wrapper_test.py index 8b491b0b9..5dbd3ae37 100644 --- a/parl/utils/remote_gym_env_wrapper_test.py +++ b/parl/utils/tests/remote_gym_env_wrapper_test.py @@ -21,7 +21,7 @@ from parl.remote.worker import Worker from parl.remote.client import disconnect from parl.utils import logger, get_free_tcp_port -from env_utils import RemoteGymEnv +from parl.utils.env_utils import RemoteGymEnv def float_equal(x1, x2):