From 4ea4caaf761a2a41fc38d7757534f68ef138a3c7 Mon Sep 17 00:00:00 2001 From: beardyface Date: Wed, 18 Oct 2023 15:39:00 +1300 Subject: [PATCH] Actually actually fixed pytests --- .../util/NetworkFactory.py | 2 +- .../util/configurations.py | 1 + tests/test_utils.py | 35 ++++++++----------- 3 files changed, 17 insertions(+), 21 deletions(-) diff --git a/cares_reinforcement_learning/util/NetworkFactory.py b/cares_reinforcement_learning/util/NetworkFactory.py index b2bd56fe..7b7f6f5c 100644 --- a/cares_reinforcement_learning/util/NetworkFactory.py +++ b/cares_reinforcement_learning/util/NetworkFactory.py @@ -143,7 +143,7 @@ def create_network(self, observation_size, action_num, config: AlgorithmConfig): algorithm = config.algorithm if algorithm == "DQN": return create_DQN(observation_size, action_num, config) - elif algorithm == "DDQN": + elif algorithm == "DoubleDQN": return create_DDQN(observation_size, action_num, config) elif algorithm == "DuelingDQN": return create_DuelingDQN(observation_size, action_num, config) diff --git a/cares_reinforcement_learning/util/configurations.py b/cares_reinforcement_learning/util/configurations.py index ada70520..82a92055 100644 --- a/cares_reinforcement_learning/util/configurations.py +++ b/cares_reinforcement_learning/util/configurations.py @@ -54,6 +54,7 @@ class DoubleDQNConfig(AlgorithmConfig): algorithm: str = Field("DoubleDQN", Literal=True) lr: Optional[float] = 1e-3 gamma: Optional[float] = 0.99 + tau: Optional[float] = 0.005 memory: Optional[str] = "MemoryBuffer" exploration_min: Optional[float] = 1e-3 diff --git a/tests/test_utils.py b/tests/test_utils.py index 2e2afabb..af6a9cbb 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -7,58 +7,53 @@ def test_create_agents(): - args = { - } - - agent = create_DQN(10, 5, DQNConfig(**args)) + agent = create_DQN(10, 5, DQNConfig()) assert isinstance(agent, DQN), "Failed to create DQN agent" - agent = create_DuelingDQN(10, 5, DuelingDQNConfig(**args)) + agent = 
create_DuelingDQN(10, 5, DuelingDQNConfig()) assert isinstance(agent, DQN), "Failed to create DuelingDQN agent" - agent = create_DDQN(10, 5, DoubleDQN(**args)) + agent = create_DDQN(10, 5, DoubleDQNConfig()) assert isinstance(agent, DoubleDQN), "Failed to create DDQN agent" - agent = create_PPO(10, 5,PPOConfig(**args)) + agent = create_PPO(10, 5,PPOConfig()) assert isinstance(agent, PPO), "Failed to create PPO agent" - agent = create_SAC(10, 5, SACConfig(**args)) + agent = create_SAC(10, 5, SACConfig()) assert isinstance(agent, SAC), "Failed to create SAC agent" - agent = create_DDPG(10, 5, DDPGConfig(**args)) + agent = create_DDPG(10, 5, DDPGConfig()) assert isinstance(agent, DDPG), "Failed to create DDPG agent" - agent = create_TD3(10, 5, TD3Config(**args)) + agent = create_TD3(10, 5, TD3Config()) assert isinstance(agent, TD3), "Failed to create TD3 agent" def test_create_network(): factory = NetworkFactory() - args = { - } - agent = factory.create_network(10, 5, DQNConfig(**args)) + agent = factory.create_network(10, 5, DQNConfig()) assert isinstance(agent, DQN), "Failed to create DQN agent" - agent = factory.create_network(10, 5, DoubleDQN(**args)) + agent = factory.create_network(10, 5, DoubleDQNConfig()) assert isinstance(agent, DoubleDQN), "Failed to create DDQN agent" - agent = factory.create_network(10, 5, DuelingDQNConfig(**args)) + agent = factory.create_network(10, 5, DuelingDQNConfig()) assert isinstance(agent, DQN), "Failed to create DuelingDQN agent" - agent = factory.create_network(10, 5,PPOConfig(**args)) + agent = factory.create_network(10, 5,PPOConfig()) assert isinstance(agent, PPO), "Failed to create PPO agent" - agent = factory.create_network(10, 5, SACConfig(**args)) + agent = factory.create_network(10, 5, SACConfig()) assert isinstance(agent, SAC), "Failed to create SAC agent" - agent = factory.create_network(10, 5, DDPGConfig(**args)) + agent = factory.create_network(10, 5, DDPGConfig()) assert isinstance(agent, DDPG), "Failed to create 
DDPG agent" - agent = factory.create_network(10, 5, TD3Config(**args)) + agent = factory.create_network(10, 5, TD3Config()) assert isinstance(agent, TD3), "Failed to create TD3 agent" - agent = factory.create_network("Unknown", AlgorithmConfig(**args)) + agent = factory.create_network(10, 5, AlgorithmConfig(algorithm="unknown")) assert agent is None, f"Unkown failed to return None: returned {agent}"