Skip to content

Commit

Permalink
ACtually actually fixed pytests
Browse files Browse the repository at this point in the history
  • Loading branch information
beardyFace committed Oct 18, 2023
1 parent f531095 commit 4ea4caa
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 21 deletions.
2 changes: 1 addition & 1 deletion cares_reinforcement_learning/util/NetworkFactory.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def create_network(self, observation_size, action_num, config: AlgorithmConfig):
algorithm = config.algorithm
if algorithm == "DQN":
return create_DQN(observation_size, action_num, config)
elif algorithm == "DDQN":
elif algorithm == "DoubleDQN":
return create_DDQN(observation_size, action_num, config)
elif algorithm == "DuelingDQN":
return create_DuelingDQN(observation_size, action_num, config)
Expand Down
1 change: 1 addition & 0 deletions cares_reinforcement_learning/util/configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class DoubleDQNConfig(AlgorithmConfig):
algorithm: str = Field("DoubleDQN", Literal=True)
lr: Optional[float] = 1e-3
gamma: Optional[float] = 0.99
tau: Optional[float] = 0.005
memory: Optional[str] = "MemoryBuffer"

exploration_min: Optional[float] = 1e-3
Expand Down
35 changes: 15 additions & 20 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,58 +7,53 @@


def test_create_agents():
args = {
}

agent = create_DQN(10, 5, DQNConfig(**args))
agent = create_DQN(10, 5, DQNConfig())
assert isinstance(agent, DQN), "Failed to create DQN agent"

agent = create_DuelingDQN(10, 5, DuelingDQNConfig(**args))
agent = create_DuelingDQN(10, 5, DuelingDQNConfig())
assert isinstance(agent, DQN), "Failed to create DuelingDQN agent"

agent = create_DDQN(10, 5, DoubleDQN(**args))
agent = create_DDQN(10, 5, DoubleDQNConfig())
assert isinstance(agent, DoubleDQN), "Failed to create DDQN agent"

agent = create_PPO(10, 5,PPOConfig(**args))
agent = create_PPO(10, 5,PPOConfig())
assert isinstance(agent, PPO), "Failed to create PPO agent"

agent = create_SAC(10, 5, SACConfig(**args))
agent = create_SAC(10, 5, SACConfig())
assert isinstance(agent, SAC), "Failed to create SAC agent"

agent = create_DDPG(10, 5, DDPGConfig(**args))
agent = create_DDPG(10, 5, DDPGConfig())
assert isinstance(agent, DDPG), "Failed to create DDPG agent"

agent = create_TD3(10, 5, TD3Config(**args))
agent = create_TD3(10, 5, TD3Config())
assert isinstance(agent, TD3), "Failed to create TD3 agent"


def test_create_network():
factory = NetworkFactory()
args = {
}

agent = factory.create_network(10, 5, DQNConfig(**args))
agent = factory.create_network(10, 5, DQNConfig())
assert isinstance(agent, DQN), "Failed to create DQN agent"

agent = factory.create_network(10, 5, DoubleDQN(**args))
agent = factory.create_network(10, 5, DoubleDQNConfig())
assert isinstance(agent, DoubleDQN), "Failed to create DDQN agent"

agent = factory.create_network(10, 5, DuelingDQNConfig(**args))
agent = factory.create_network(10, 5, DuelingDQNConfig())
assert isinstance(agent, DQN), "Failed to create DuelingDQN agent"

agent = factory.create_network(10, 5,PPOConfig(**args))
agent = factory.create_network(10, 5,PPOConfig())
assert isinstance(agent, PPO), "Failed to create PPO agent"

agent = factory.create_network(10, 5, SACConfig(**args))
agent = factory.create_network(10, 5, SACConfig())
assert isinstance(agent, SAC), "Failed to create SAC agent"

agent = factory.create_network(10, 5, DDPGConfig(**args))
agent = factory.create_network(10, 5, DDPGConfig())
assert isinstance(agent, DDPG), "Failed to create DDPG agent"

agent = factory.create_network(10, 5, TD3Config(**args))
agent = factory.create_network(10, 5, TD3Config())
assert isinstance(agent, TD3), "Failed to create TD3 agent"

agent = factory.create_network("Unknown", AlgorithmConfig(**args))
agent = factory.create_network(10, 5, AlgorithmConfig(algorithm="unknown"))
assert agent is None, f"Unkown failed to return None: returned {agent}"


Expand Down

0 comments on commit 4ea4caa

Please sign in to comment.