From 94ab9d219bc6dbcf0ebd344c23b0a1701766c5cd Mon Sep 17 00:00:00 2001 From: Mitchell Kiely Date: Fri, 18 Nov 2022 12:53:10 +1030 Subject: [PATCH] Fixing action and observation space in PettingZooParallelWrapper --- .../Wrappers/PettingZooParallelWrapper.py | 88 ++++++---- CybORG/Agents/Wrappers/__init__.py | 1 + CybORG/Shared/EnvironmentController.py | 35 ++-- CybORG/Simulator/Actions/Action.py | 3 +- .../test_DroneActions/test_drone_actions.py | 50 +++--- .../test_PettingZooParallelWrapper.py | 150 ++++++++++++++++-- CybORG/env.py | 14 +- 7 files changed, 269 insertions(+), 72 deletions(-) diff --git a/CybORG/Agents/Wrappers/PettingZooParallelWrapper.py b/CybORG/Agents/Wrappers/PettingZooParallelWrapper.py index 0483ab28..d7f39ce0 100644 --- a/CybORG/Agents/Wrappers/PettingZooParallelWrapper.py +++ b/CybORG/Agents/Wrappers/PettingZooParallelWrapper.py @@ -1,19 +1,16 @@ from typing import Optional -from CybORG import CybORG -from CybORG.Agents.Wrappers import BaseWrapper -import warnings -from gym import spaces import numpy as np +from gym import spaces -from pettingzoo import ParallelEnv -from pettingzoo.utils import wrappers -from CybORG.Agents.Wrappers import BaseWrapper, OpenAIGymWrapper, BlueTableWrapper, RedTableWrapper, EnumActionWrapper -from CybORG.Shared.CommsRewardCalculator import CommsAvailabilityRewardCalculator +from CybORG import CybORG +from CybORG.Agents.Wrappers import BaseWrapper +from CybORG.Simulator.Actions import Sleep class PettingZooParallelWrapper(BaseWrapper): - def __init__(self, env: CybORG): + + def __init__(self, env: CybORG,): super().__init__(env) self._agent_ids = self.possible_agents # assuming that the final value in the agent name indicates which drone that agent is on @@ -28,8 +25,7 @@ def __init__(self, env: CybORG): [3] + [2 for i in range(num_drones)] + [2] + [3 for i in range(num_drones)] + [101, 101] + ( num_drones - 1) * [num_drones, 101, 101, 2]) for agent_name in self.possible_agents} self.metadata = {"render_modes": ["human", "rgb_array"], "name": "Cage_Challenge_3"} - self.seed = 117 - + self.agent_actions = self.int_to_cyborg_action() self.dones = {agent: False for agent in self.possible_agents} self.rewards = {agent: 0. for agent in self.possible_agents} self.infos = {} @@ -39,20 +35,25 @@ def reset(self, return_info: bool = False, options: Optional[dict] = None) -> dict: res = self.env.reset() + self.agent_actions = self.int_to_cyborg_action() self.dones = {agent: False for agent in self.possible_agents} self.rewards = {agent: 0. for agent in self.possible_agents} self.infos = {} # assuming that the final value in the agent name indicates which drone that agent is on + self.int_to_action = self.int_to_cyborg_action() self.agent_host_map = {agent_name: f'drone_{agent_name.split("_")[-1]}' for agent_name in self.possible_agents} self.ip_addresses = list(self.env.get_ip_map().values()) return {agent: self.observation_change(agent, obs=self.env.get_observation(agent)) for agent in self.agents} def step(self, actions: dict) -> (dict, dict, dict, dict): actions, msgs = self.select_messages(actions) + actions_dict = {} + for agent, act in actions.items(): assert self.action_space(agent).contains(act) + actions_dict[agent] = self.agent_actions[agent][act] - raw_obs, rews, dones, infos = self.env.parallel_step(actions, messages=msgs) + raw_obs, rews, dones, infos = self.env.parallel_step(actions_dict, messages=msgs) # green_agents = {agent: if } # rews = GreenAvailabilityRewardCalculator(raw_obs, ['green_agent_0','green_agent_1', 'green_agent_2' ]).calculate_reward() obs = {agent: self.observation_change(agent, raw_obs[agent]) for agent in self.env.active_agents} @@ -159,6 +160,39 @@ def get_done(self, agent): ''' return self.dones[agent] + def int_to_cyborg_action(self): + ''' + Returns a dictionary containing dictionaries that maps the number selected by the agent to a specific CybORG action only for blue agent + + ''' + cyborg_agent_actions = {} + for agent in self.possible_agents: + if 'blue' not in agent: + continue + cyborg_action_to_int = {} + act_count = 0 + for action in self.env.get_action_space(agent)['action'].keys(): + params_dict = {} + if action.__name__ == 'Sleep': + cyborg_action_to_int[act_count] = Sleep() + act_count+=1 + elif action.__name__ == 'RemoveOtherSessions': + params_dict['agent'] = agent + params_dict['session'] = 0 + cyborg_action_to_int[act_count] = action(**params_dict) + act_count+=1 + else: + for ip in self.env.get_action_space(agent)['ip_address'].keys(): + for sess in self.env.get_action_space(agent)['session'].keys(): + if sess == 0: + params_dict['session'] = 0 + params_dict['ip_address'] = ip + params_dict['agent'] = agent + cyborg_action_to_int[act_count] = action(**params_dict) + act_count+=1 + cyborg_agent_actions[agent] = cyborg_action_to_int + return cyborg_agent_actions + def get_action_space(self, agent): ''' Obtains the action_space of the specified agent @@ -166,13 +200,10 @@ def get_action_space(self, agent): Parameters: agent -> str ''' - this_agent = agent initial = self.env.get_action_space(agent) - + this_agent = agent unmasked_as = [] agent_actions = [] - action_list = ['ExploitDroneVulnerability', 'SeizeControl', 'FloodBandwidth', 'BlockTraffic', 'AllowTraffic', - 'RetakeControl', 'SendData'] for key in initial.copy(): if key != 'action': @@ -182,14 +213,12 @@ def get_action_space(self, agent): for i in range(len(init_list)): agent_actions.append(init_list[i][0].__name__) - if ('Sleep' in agent_actions): - unmasked_as.append('Sleep') - - if ('RemoveOtherSessions' in agent_actions): - unmasked_as.append(f'Remove {this_agent}') - - for act in action_list: - if (act in agent_actions): + for act in agent_actions: + if act == 'Sleep': + unmasked_as.append('Sleep') + elif act == 'RemoveOtherSessions': + unmasked_as.append(f'RemoveOtherSessions {this_agent}') + else: for agent in self.possible_agents: unmasked_as.append(f"{act} {agent}") @@ -219,15 +248,15 @@ def observation_change(self, agent: str, obs: dict): # element location --> [0, 1,...,num_drones, 1+num_drones, 2+num_drones, ..., 2+2*num_drones, 3+2*num_drones, 4+2*num_drones,...,4+4*num_drones] index = 0 # success - new_obs[index] = obs['success'].value + new_obs[index] = obs['success'].value - 1 index += 1 if agent in self.env.active_agents: # Add blocked IPs for i, ip in enumerate(self.ip_addresses): - new_obs[index + i] = 1 if ip in [interface['blocked_ips'] for interface in + new_obs[index + i] = 1 if ip in [blocked_ip for interface in obs[own_host_name]['Interface'] if - 'blocked_ips' in interface] else 0 + 'blocked_ips' in interface for blocked_ip in interface['blocked_ips']] else 0 index += len(self.ip_addresses) # add flagged malicious processes @@ -236,9 +265,9 @@ def observation_change(self, agent: str, obs: dict): # add flagged messages for i, ip in enumerate(self.ip_addresses): # TODO add in check for network connections - new_obs[i] = 1 if ip in [interface['Network Connections'] for interface in + new_obs[index + i] = 1 if ip in [network_conn for interface in obs[own_host_name]['Interface'] if - 'Network Connections' in interface] else 0 + 'Network Connections' in interface for network_conn in interface['Network Connections']] else 0 index += len(self.ip_addresses) pos = obs[own_host_name]['System info'].get('position', (0, 0)) @@ -291,3 +320,4 @@ def agents(self): @property def possible_agents(self): return self.env.agents + diff --git a/CybORG/Agents/Wrappers/__init__.py b/CybORG/Agents/Wrappers/__init__.py index d3758abe..cc4155da 100644 --- a/CybORG/Agents/Wrappers/__init__.py +++ b/CybORG/Agents/Wrappers/__init__.py @@ -8,3 +8,4 @@ from .ChallengeWrapper import ChallengeWrapper from .IntFixedFlatWrapper import IntFixedFlatWrapper from .SimpleRedWrapper import SimpleRedWrapper +from .PettingZooParallelWrapper import PettingZooParallelWrapper diff --git a/CybORG/Shared/EnvironmentController.py b/CybORG/Shared/EnvironmentController.py index 5ff8515d..a9249c1f 100644 --- a/CybORG/Shared/EnvironmentController.py +++ b/CybORG/Shared/EnvironmentController.py @@ -125,6 +125,7 @@ def reset(self, agent: str = None, np_random=None) -> Results: action_space=self.agent_interfaces[agent].action_space.get_action_space()) def step(self, actions: dict = None, skip_valid_action_check=False): + """Updates the environment based on the joint actions of all agents Save the @@ -147,9 +148,9 @@ def step(self, actions: dict = None, skip_valid_action_check=False): agent_object.messages = [] if agent_name not in actions: actions[agent_name] = agent_object.get_action(self.get_last_observation(agent_name)) - if not self.test_valid_action(actions[agent_name], agent_object) and not skip_valid_action_check: - actions[agent_name] = InvalidAction(action=actions[agent_name]) - self._log_debug(f"{__class__.__name__} have Invalid Action agent_action={actions[agent_name]} for agent={agent_name}") + if not skip_valid_action_check: + actions[agent_name] = self.replace_action_if_invalid(actions[agent_name], agent_object) + self.action = actions actions = self.sort_action_order(actions) @@ -423,20 +424,32 @@ def _filter_obs(self, obs: Observation, agent_name=None): ) return obs - def test_valid_action(self, action: Action, agent: AgentInterface): - # returns true if the parameters in the action are in and true in the action set else return false + def replace_action_if_invalid(self, action: Action, agent: AgentInterface): + # returns action if the parameters in the action are in and true in the action set else return InvalidAction imbued with bug report. action_space = agent.action_space.get_action_space() - # first check that the action class is allowed - if type(action) not in action_space['action'] or not action_space['action'][type(action)]: - return False + + if type(action) not in action_space['action']: + message = f'Action {action} not in action space for agent {agent.agent_name}.' + return InvalidAction(action=action, error=message) + + if not action_space['action'][type(action)]: + message = f'Action {action} is not valid for agent {agent.agent_name} at the moment. This usually means it is trying to access a host it has not discovered yet.' + return InvalidAction(action=action, error=message) + # next for each parameter in the action for parameter_name, parameter_value in action.get_params().items(): if parameter_name not in action_space: continue - if (parameter_value not in action_space[parameter_name]) or (not action_space[parameter_name][parameter_value]): - return False - return True + if parameter_value not in action_space[parameter_name]: + message = f'Action {action} has parameter {parameter_name} valued at {parameter_value}. However, {parameter_value} is not in the action space for agent {agent.agent_name}.' + return InvalidAction(action=action, error=message) + + if not action_space[parameter_name][parameter_value]: + message = f'Action {action} has parameter {parameter_name} valued at the invalid value of {parameter_value}. This usually means an agent is trying to utilise information it has not discovered yet such as an ip_address or port number.' + return InvalidAction(action=action, error=message) + + return action def get_reward_breakdown(self, agent:str): return self.agent_interfaces[agent].reward_calculator.host_scores diff --git a/CybORG/Simulator/Actions/Action.py b/CybORG/Simulator/Actions/Action.py index 93b09702..f3403728 100644 --- a/CybORG/Simulator/Actions/Action.py +++ b/CybORG/Simulator/Actions/Action.py @@ -43,9 +43,10 @@ def execute(self, state): class InvalidAction(Action): - def __init__(self, action: Action = None): + def __init__(self, action: Action = None, error: str =None): super().__init__() self.action = action + self.error = error def execute(self, state): return Observation(success=False) diff --git a/CybORG/Tests/test_sim/test_Actions/test_DroneActions/test_drone_actions.py b/CybORG/Tests/test_sim/test_Actions/test_DroneActions/test_drone_actions.py index 2973a065..325db53b 100644 --- a/CybORG/Tests/test_sim/test_Actions/test_DroneActions/test_drone_actions.py +++ b/CybORG/Tests/test_sim/test_Actions/test_DroneActions/test_drone_actions.py @@ -65,6 +65,7 @@ def test_red_exploit_on_attacked_red(attacked_red): """Testing SeizeControl""" + def test_red_seize_on_unattacked_blue(unattacked_blue): cyborg, red_agent, blue_agent, target_agent, target_host, target_ip = unattacked_blue assert 'Sessions' not in cyborg.get_observation(red_agent).get(target_host, {}), cyborg.get_observation(red_agent) @@ -72,8 +73,9 @@ def test_red_seize_on_unattacked_blue(unattacked_blue): assert 'red_agent_' + target_agent.split('_')[-1] not in cyborg.active_agents action = SeizeControl(agent=red_agent, session=0, ip_address=target_ip) results = cyborg.step(agent=red_agent, action=action) - assert cyborg.environment_controller.test_valid_action(action, - cyborg.environment_controller.agent_interfaces[red_agent]) + assert cyborg.environment_controller.replace_action_if_invalid(action, + cyborg.environment_controller.agent_interfaces[ + red_agent]) == action assert results.observation['success'] == False, cyborg.environment_controller.observation[red_agent].raw assert target_agent in cyborg.active_agents assert 'red_agent_' + target_agent.split('_')[-1] not in cyborg.active_agents @@ -87,8 +89,9 @@ def test_red_seize_on_attacked_blue(attacked_blue): assert 'red_agent_' + target_agent.split('_')[-1] not in cyborg.active_agents action = SeizeControl(agent=red_agent, session=0, ip_address=target_ip) results = cyborg.step(agent=red_agent, action=action) - assert cyborg.environment_controller.test_valid_action(action, - cyborg.environment_controller.agent_interfaces[red_agent]) + assert cyborg.environment_controller.replace_action_if_invalid(action, + cyborg.environment_controller.agent_interfaces[ + red_agent]) == action assert results.observation['success'] == True, cyborg.environment_controller.observation[red_agent].raw assert target_agent not in cyborg.active_agents assert 'red_agent_' + target_agent.split('_')[-1] in cyborg.active_agents @@ -104,8 +107,9 @@ def test_red_seize_on_unattacked_red(unattacked_red): assert 'red_agent_' + target_agent.split('_')[-1] in cyborg.active_agents action = SeizeControl(agent=red_agent, session=0, ip_address=target_ip) results = cyborg.step(agent=red_agent, action=action) - assert cyborg.environment_controller.test_valid_action(action, - cyborg.environment_controller.agent_interfaces[red_agent]) + assert cyborg.environment_controller.replace_action_if_invalid(action, + cyborg.environment_controller.agent_interfaces[ + red_agent]) == action assert results.observation['success'] == False, cyborg.environment_controller.observation[red_agent].raw assert target_agent not in cyborg.active_agents assert 'red_agent_' + target_agent.split('_')[-1] in cyborg.active_agents @@ -113,6 +117,7 @@ def test_red_seize_on_unattacked_red(unattacked_red): assert len(cyborg.get_action_space('red_agent_' + target_agent.split('_')[-1])[ 'session']) > 0, f"{cyborg.get_action_space('red_agent_' + target_agent.split('_')[-1])['session']}" + def test_red_seize_on_attacked_red(attacked_red): cyborg, red_agent, blue_agent, target_agent, target_host, target_ip, session_id = attacked_red assert 'Sessions' in cyborg.get_observation(red_agent).get(target_host, {}), cyborg.get_observation(red_agent) @@ -120,8 +125,9 @@ def test_red_seize_on_attacked_red(attacked_red): assert 'red_agent_' + target_agent.split('_')[-1] in cyborg.active_agents action = SeizeControl(agent=red_agent, session=0, ip_address=target_ip) results = cyborg.step(agent=red_agent, action=action) - assert cyborg.environment_controller.test_valid_action(action, - cyborg.environment_controller.agent_interfaces[red_agent]) + assert cyborg.environment_controller.replace_action_if_invalid(action, + cyborg.environment_controller.agent_interfaces[ + red_agent]) == action assert results.observation['success'] == True, cyborg.environment_controller.observation[red_agent].raw assert target_agent not in cyborg.active_agents assert 'red_agent_' + target_agent.split('_')[-1] in cyborg.active_agents @@ -129,6 +135,7 @@ def test_red_seize_on_attacked_red(attacked_red): assert len(cyborg.get_action_space('red_agent_' + target_agent.split('_')[-1])[ 'session']) > 0, f"{cyborg.get_action_space('red_agent_' + target_agent.split('_')[-1])['session']}" + """Testing Remove""" @@ -157,8 +164,9 @@ def test_blue_retake_on_unattacked_blue(unattacked_blue): assert 'Sessions' not in cyborg.get_observation(red_agent).get(target_host, {}), cyborg.get_observation(red_agent) action = RetakeControl(agent=blue_agent, session=0, ip_address=target_ip) results = cyborg.step(agent=blue_agent, action=action) - assert cyborg.environment_controller.test_valid_action(action, - cyborg.environment_controller.agent_interfaces[blue_agent]) + assert cyborg.environment_controller.replace_action_if_invalid(action, + cyborg.environment_controller.agent_interfaces[ + blue_agent]) == action assert results.observation['success'] == False assert target_agent in cyborg.active_agents assert 'Sessions' not in cyborg.get_observation(red_agent).get(target_host, {}), cyborg.get_observation(red_agent) @@ -171,8 +179,9 @@ def test_blue_retake_on_attacked_blue(attacked_blue): assert target_agent in cyborg.active_agents action = RetakeControl(agent=blue_agent, session=0, ip_address=target_ip) results = cyborg.step(agent=blue_agent, action=action) - assert cyborg.environment_controller.test_valid_action(action, - cyborg.environment_controller.agent_interfaces[blue_agent]) + assert cyborg.environment_controller.replace_action_if_invalid(action, + cyborg.environment_controller.agent_interfaces[ + blue_agent]) == action assert results.observation['success'] == False assert target_agent in cyborg.active_agents assert 'Sessions' in cyborg.get_observation(red_agent).get(target_host, {}), cyborg.get_observation(red_agent) @@ -184,8 +193,9 @@ def test_blue_retake_on_unattacked_red(unattacked_red): assert target_agent not in cyborg.active_agents action = RetakeControl(agent=blue_agent, session=0, ip_address=target_ip) results = cyborg.step(agent=blue_agent, action=action) - assert cyborg.environment_controller.test_valid_action(action, - cyborg.environment_controller.agent_interfaces[blue_agent]) + assert cyborg.environment_controller.replace_action_if_invalid(action, + cyborg.environment_controller.agent_interfaces[ + blue_agent]) == action assert results.observation['success'] == True assert target_agent in cyborg.active_agents assert 'Sessions' not in cyborg.get_observation(red_agent).get(target_host, {}), cyborg.get_observation(red_agent) @@ -198,13 +208,15 @@ def test_blue_retake_on_attacked_red(attacked_red): assert target_agent not in cyborg.active_agents action = RetakeControl(agent=blue_agent, session=0, ip_address=target_ip) results = cyborg.step(agent=blue_agent, action=action) - assert cyborg.environment_controller.test_valid_action(action, - cyborg.environment_controller.agent_interfaces[blue_agent]) + assert cyborg.environment_controller.replace_action_if_invalid(action, + cyborg.environment_controller.agent_interfaces[ + blue_agent]) == action assert results.observation['success'] == True assert target_agent in cyborg.active_agents assert 'Sessions' not in cyborg.get_observation(red_agent).get(target_host, {}), cyborg.get_observation(red_agent) assert len(cyborg.get_action_space(target_agent)['session']) > 0 + def test_remove_always(): sg = DroneSwarmScenarioGenerator(max_length_data_links=28, num_drones=15, red_spawn_rate=0, starting_num_red=1) cyborg = CybORG(sg, 'sim') @@ -222,6 +234,7 @@ def test_remove_always(): # breakpoint() cyborg.parallel_step(actions) + def test_restore_always(): sg = DroneSwarmScenarioGenerator(max_length_data_links=28, num_drones=15, red_spawn_rate=0, starting_num_red=1) cyborg = CybORG(sg, 'sim') @@ -233,7 +246,8 @@ def test_restore_always(): for agent in cyborg.active_agents: if 'blue' in agent: - actions[agent] = RetakeControl(agent=agent, session=0, ip_address=cyborg.get_ip_map()[cyborg.np_random.choice(agent_list)]) + actions[agent] = RetakeControl(agent=agent, session=0, + ip_address=cyborg.get_ip_map()[cyborg.np_random.choice(agent_list)]) # breakpoint() - cyborg.parallel_step(actions) \ No newline at end of file + cyborg.parallel_step(actions) diff --git a/CybORG/Tests/test_sim/test_wrappers/test_PettingZooParallelWrapper.py b/CybORG/Tests/test_sim/test_wrappers/test_PettingZooParallelWrapper.py index cefb77f7..1d291670 100644 --- a/CybORG/Tests/test_sim/test_wrappers/test_PettingZooParallelWrapper.py +++ b/CybORG/Tests/test_sim/test_wrappers/test_PettingZooParallelWrapper.py @@ -23,6 +23,117 @@ def create_wrapped_cyborg(request): def test_petting_zoo_parallel_wrapper(create_wrapped_cyborg): parallel_api_test(create_wrapped_cyborg, num_cycles=1000) +#Test if actions inputted are valid +def test_valid_actions(): + sg = DroneSwarmScenarioGenerator(num_drones=2, max_length_data_links=10000, starting_num_red=0) + cyborg_raw = CybORG(scenario_generator=sg, seed=123) + cyborg = PettingZooParallelWrapper(env=cyborg_raw) + cyborg.reset() + + for i in range(50): + actions = {} + for agent in cyborg.active_agents: + actions[agent] = random.randint(0, len(cyborg.get_action_space(agent))-1) + + obs, rews, dones, infos = cyborg.step(actions) + for agent in cyborg.active_agents: + assert cyborg.get_last_actions(agent) != 'InvalidAction' + +#test reward bug +def test_equal_reward(): + sg = DroneSwarmScenarioGenerator(num_drones=17, max_length_data_links=1000, starting_num_red=0) + cyborg_raw = CybORG(scenario_generator=sg, seed=123) + cyborg = PettingZooParallelWrapper(env=cyborg_raw) + cyborg.reset() + + rews_tt = {} + for i in range(10): + actions = {} + for agent in cyborg.agents: + actions[agent] = random.randint(0,len(cyborg.get_action_space(agent))-1) + + obs, rews, dones, infos = cyborg.step(actions) + rews_tt[i] = rews + + for i in rews_tt.keys(): + assert len(set(rews_tt[1].values())) == 1 + +def test_blue_retake_on_red(): + sg = DroneSwarmScenarioGenerator(num_drones=2, max_length_data_links=100000, starting_num_red=1, red_spawn_rate=0, + starting_positions=[np.array([0, 0]), np.array([1,1])]) + cyborg_raw = CybORG(scenario_generator=sg, seed=110) + cyborg = PettingZooParallelWrapper(env=cyborg_raw) + cyborg.reset() + actions = {} + + if cyborg.active_agents[0] == 'blue_agent_0': + actions[cyborg.active_agents[0]]=1 + else: + actions[cyborg.active_agents[0]]=0 + + assert len(cyborg.active_agents) == 1 + + obs, rews, dones, infos = cyborg.step(actions) + + assert obs[cyborg.active_agents[0]][0] == 0 or 1 + assert len(cyborg.active_agents) == 2 + +def test_blue_remove_on_red(): + sg = DroneSwarmScenarioGenerator(num_drones=2, max_length_data_links=100000, starting_num_red=1, red_spawn_rate=0, + starting_positions=[np.array([0, 0]), np.array([1,1])]) + cyborg_raw = CybORG(scenario_generator=sg, seed=110) + cyborg = PettingZooParallelWrapper(env=cyborg_raw) + cyborg.reset() + actions = {} + actions[cyborg.active_agents[0]]=2 + assert len(cyborg.active_agents) == 1 + + obs, rews, dones, infos = cyborg.step(actions) + + assert obs[cyborg.active_agents[0]][0] == 2 + assert len(cyborg.active_agents) == 1 + + + +def test_blue_retake_on_blue(): + sg = DroneSwarmScenarioGenerator(num_drones=2, max_length_data_links=100000, starting_num_red=0, red_spawn_rate=0, + starting_positions=[np.array([0, 0]), np.array([1,1])]) + cyborg_raw = CybORG(scenario_generator=sg, seed=110) + cyborg = PettingZooParallelWrapper(env=cyborg_raw) + cyborg.reset() + actions = {} + actions['blue_agent_0']=1 + actions['blue_agent_1']=0 + + assert len(cyborg.active_agents) == 2 + + obs, rews, dones, infos = cyborg.step(actions) + + assert obs['blue_agent_0'][0] == 2 + assert obs['blue_agent_1'][0] == 2 + + assert len(cyborg.active_agents) == 2 + + +#test blocked IP bug +def test_block_and_check_IP(): + sg = DroneSwarmScenarioGenerator(num_drones=2, max_length_data_links=100000, starting_num_red=0, red_spawn_rate=0, + starting_positions=[np.array([0, 0]), np.array([1,1])]) + cyborg_raw = CybORG(scenario_generator=sg, seed=110) + cyborg = PettingZooParallelWrapper(env=cyborg_raw) + cyborg.reset() + + actions = {} + for i in range(2): + count = 0 + for agent in cyborg.active_agents: + actions[agent] = 4 -count + count += 1 + obs, rews, dones, infos = cyborg.step(actions) + + assert obs['blue_agent_0'][2] == 1 + assert obs['blue_agent_1'][1] == 1 + def test_attributes(create_wrapped_cyborg): # Create cyborg and reset it @@ -58,6 +169,7 @@ def test_agent_data_change(create_wrapped_cyborg): for agent in create_wrapped_cyborg.agents: assert isinstance(obs[agent], np.ndarray) + assert isinstance(create_wrapped_cyborg.action_space(agent), spaces.Discrete) assert isinstance(rews[agent], float) assert isinstance(dones[agent], bool) assert isinstance(infos, dict) @@ -103,12 +215,12 @@ def test_observation_change(create_wrapped_cyborg): actions = {} for agent in create_wrapped_cyborg.agents: actions[agent] = create_wrapped_cyborg.action_spaces[agent].sample() + obs, rews, dones, infos = create_wrapped_cyborg.step(actions) - for agent in create_wrapped_cyborg.agents: - assert isinstance(obs[agent], np.ndarray) - assert isinstance(rews, dict) - assert isinstance(dones, dict) - assert isinstance(infos, dict) + assert isinstance(obs[agent], np.ndarray) + assert isinstance(rews, dict) + assert isinstance(dones, dict) + assert isinstance(infos, dict) final_obs = create_wrapped_cyborg.observation_spaces assert (initial_obs == final_obs) @@ -177,8 +289,26 @@ def test_active_agent_in_observation(): if any(dones.values()): break -def test_observation(): - sg = DroneSwarmScenarioGenerator(num_drones=20, max_length_data_links=10, starting_num_red=0) - cyborg_raw = CybORG(scenario_generator=sg, seed=123) - - cyborg = PettingZooParallelWrapper(env=cyborg_raw) +@pytest.mark.parametrize('num_drones', [2,10,18,25]) +@pytest.mark.parametrize('wrapper', [PettingZooParallelWrapper, AgentCommsPettingZooParallelWrapper, ActionsCommsPettingZooParallelWrapper, ObsCommsPettingZooParallelWrapper]) +def test_observation(num_drones, wrapper): + sg = DroneSwarmScenarioGenerator(num_drones=num_drones) + cyborg = wrapper(CybORG(scenario_generator=sg, seed=123)) + cyborg.reset() + for i in range(10): + for j in range(600): + obs, rew, dones, infos = cyborg.step({agent: cyborg.action_space(agent).sample() for agent in cyborg.agents}) + for agent in cyborg.agents: + if type(cyborg) == PettingZooParallelWrapper: + assert len(obs[agent]) == (num_drones*6) + elif type(cyborg) == ObsCommsPettingZooParallelWrapper: + assert len(obs[agent]) == (num_drones*22) + else: + assert len(obs[agent]) == (num_drones*7) + if any(dones.values()) or len(cyborg.agents) == 0: + assert all(dones) + break + if j > 499: + breakpoint() + assert j <= 500 + cyborg.reset() diff --git a/CybORG/env.py b/CybORG/env.py index 376f1fbe..e79c8044 100644 --- a/CybORG/env.py +++ b/CybORG/env.py @@ -9,7 +9,7 @@ from CybORG.Shared import Observation, Results, CybORGLogger from CybORG.Simulator.Actions import DiscoverNetworkServices, DiscoverRemoteSystems, ExploitRemoteService, \ InvalidAction, \ - Sleep, PrivilegeEscalate, Impact, Remove, Restore, SeizeControl, RetakeControl, RemoveOtherSessions + Sleep, PrivilegeEscalate, Impact, Remove, Restore, SeizeControl, RetakeControl, RemoveOtherSessions, FloodBandwidth from CybORG.Simulator.Actions.ConcreteActions.ActivateTrojan import ActivateTrojan from CybORG.Simulator.Actions.ConcreteActions.ControlTraffic import BlockTraffic, AllowTraffic from CybORG.Simulator.Actions.ConcreteActions.ExploitActions.ExploitAction import ExploitAction @@ -452,6 +452,8 @@ def render(self, mode='human'): red_hosts = [] red_low_hosts = [] for agent in self.environment_controller.team['Red']: + if not self.environment_controller.is_active(agent): + continue red_hosts += [i.hostname for i in self.environment_controller.state.sessions[agent].values() if i.username == 'SYSTEM' or i.username == 'root'] red_low_hosts += [i.hostname for i in self.environment_controller.state.sessions[agent].values()] @@ -484,14 +486,20 @@ def render(self, mode='human'): red_action_type = 'port scan' elif isinstance(red_action, Impact): red_action_type = 'impact' + elif isinstance(red_action, (AllowTraffic, BlockTraffic, FloodBandwidth)): + red_action_type = None else: red_action_type = type(red_action) - data['actions'].append( - {"agent": red_from, "destination": red_target, "source": red_source, "type": red_action_type}) + + if red_action_type is not None: + data['actions'].append( + {"agent": red_from, "destination": red_target, "source": red_source, "type": red_action_type}) blue_hosts = [] blue_protected_hosts = [] for agent in self.environment_controller.team['Blue']: + if not self.environment_controller.is_active(agent): + continue blue_hosts += [i.hostname for i in self.environment_controller.state.sessions[agent].values()] blue_protected_hosts += [blue_session.hostname for blue_session in self.environment_controller.state.sessions[agent].values() if