Skip to content

Commit

Permalink
Fixing action and observation space in PettingZooParallelWrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
MitchellKiely committed Nov 18, 2022
1 parent 9943f85 commit 94ab9d2
Show file tree
Hide file tree
Showing 7 changed files with 269 additions and 72 deletions.
88 changes: 59 additions & 29 deletions CybORG/Agents/Wrappers/PettingZooParallelWrapper.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,16 @@
from typing import Optional

from CybORG import CybORG
from CybORG.Agents.Wrappers import BaseWrapper
import warnings
from gym import spaces
import numpy as np
from gym import spaces

from pettingzoo import ParallelEnv
from pettingzoo.utils import wrappers
from CybORG.Agents.Wrappers import BaseWrapper, OpenAIGymWrapper, BlueTableWrapper, RedTableWrapper, EnumActionWrapper
from CybORG.Shared.CommsRewardCalculator import CommsAvailabilityRewardCalculator
from CybORG import CybORG
from CybORG.Agents.Wrappers import BaseWrapper
from CybORG.Simulator.Actions import Sleep


class PettingZooParallelWrapper(BaseWrapper):
def __init__(self, env: CybORG):

def __init__(self, env: CybORG,):
super().__init__(env)
self._agent_ids = self.possible_agents
# assuming that the final value in the agent name indicates which drone that agent is on
Expand All @@ -28,8 +25,7 @@ def __init__(self, env: CybORG):
[3] + [2 for i in range(num_drones)] + [2] + [3 for i in range(num_drones)] + [101, 101] + (
num_drones - 1) * [num_drones, 101, 101, 2]) for agent_name in self.possible_agents}
self.metadata = {"render_modes": ["human", "rgb_array"], "name": "Cage_Challenge_3"}
self.seed = 117

self.agent_actions = self.int_to_cyborg_action()
self.dones = {agent: False for agent in self.possible_agents}
self.rewards = {agent: 0. for agent in self.possible_agents}
self.infos = {}
Expand All @@ -39,20 +35,25 @@ def reset(self,
return_info: bool = False,
options: Optional[dict] = None) -> dict:
res = self.env.reset()
self.agent_actions = self.int_to_cyborg_action()
self.dones = {agent: False for agent in self.possible_agents}
self.rewards = {agent: 0. for agent in self.possible_agents}
self.infos = {}
# assuming that the final value in the agent name indicates which drone that agent is on
self.int_to_action = self.int_to_cyborg_action()
self.agent_host_map = {agent_name: f'drone_{agent_name.split("_")[-1]}' for agent_name in self.possible_agents}
self.ip_addresses = list(self.env.get_ip_map().values())
return {agent: self.observation_change(agent, obs=self.env.get_observation(agent)) for agent in self.agents}

def step(self, actions: dict) -> (dict, dict, dict, dict):
actions, msgs = self.select_messages(actions)
actions_dict = {}

for agent, act in actions.items():
assert self.action_space(agent).contains(act)
actions_dict[agent] = self.agent_actions[agent][act]

raw_obs, rews, dones, infos = self.env.parallel_step(actions, messages=msgs)
raw_obs, rews, dones, infos = self.env.parallel_step(actions_dict, messages=msgs)
# green_agents = {agent: if }
# rews = GreenAvailabilityRewardCalculator(raw_obs, ['green_agent_0','green_agent_1', 'green_agent_2' ]).calculate_reward()
obs = {agent: self.observation_change(agent, raw_obs[agent]) for agent in self.env.active_agents}
Expand Down Expand Up @@ -159,20 +160,50 @@ def get_done(self, agent):
'''
return self.dones[agent]

def int_to_cyborg_action(self):
    '''
    Builds, for every blue agent, a lookup table that maps the integer selected by the agent
    to a concrete CybORG action object.

    Agents whose name does not contain 'blue' are skipped and receive no entry.
    Indices are assigned consecutively from 0, following the iteration order of the
    agent's action space:
        Sleep               -> one entry, built without parameters
        RemoveOtherSessions -> one entry, built with agent=<agent name> and session=0
        any other action    -> one entry per (ip_address, session) pair of the action space

    Returns:
        dict -> {agent name: {int: action object}}
    '''
    cyborg_agent_actions = {}
    for agent in self.possible_agents:
        if 'blue' not in agent:
            continue
        # The action space is queried once per agent instead of once per loop iteration.
        # Its 'ip_address' and 'session' entries are still only read when an action needs them.
        action_space = self.env.get_action_space(agent)
        int_to_action = {}
        for action in action_space['action']:
            if action.__name__ == 'Sleep':
                int_to_action[len(int_to_action)] = Sleep()
            elif action.__name__ == 'RemoveOtherSessions':
                int_to_action[len(int_to_action)] = action(agent=agent, session=0)
            else:
                params_dict = {}
                for ip in action_space['ip_address']:
                    for sess in action_space['session']:
                        # 'session' is only (re)assigned for session 0; params_dict carries over
                        # between iterations, so an entry is still emitted for every session key.
                        if sess == 0:
                            params_dict['session'] = 0
                        params_dict['ip_address'] = ip
                        params_dict['agent'] = agent
                        # len() of the table is the next free index because keys are 0..n-1
                        int_to_action[len(int_to_action)] = action(**params_dict)
        cyborg_agent_actions[agent] = int_to_action
    return cyborg_agent_actions

def get_action_space(self, agent):
'''
Obtains the action_space of the specified agent
Parameters:
agent -> str
'''
this_agent = agent
initial = self.env.get_action_space(agent)

this_agent = agent
unmasked_as = []
agent_actions = []
action_list = ['ExploitDroneVulnerability', 'SeizeControl', 'FloodBandwidth', 'BlockTraffic', 'AllowTraffic',
'RetakeControl', 'SendData']

for key in initial.copy():
if key != 'action':
Expand All @@ -182,14 +213,12 @@ def get_action_space(self, agent):
for i in range(len(init_list)):
agent_actions.append(init_list[i][0].__name__)

if ('Sleep' in agent_actions):
unmasked_as.append('Sleep')

if ('RemoveOtherSessions' in agent_actions):
unmasked_as.append(f'Remove {this_agent}')

for act in action_list:
if (act in agent_actions):
for act in agent_actions:
if act == 'Sleep':
unmasked_as.append('Sleep')
elif act == 'RemoveOtherSessions':
unmasked_as.append(f'RemoveOtherSessions {this_agent}')
else:
for agent in self.possible_agents:
unmasked_as.append(f"{act} {agent}")

Expand Down Expand Up @@ -219,15 +248,15 @@ def observation_change(self, agent: str, obs: dict):
# element location --> [0, 1,...,num_drones, 1+num_drones, 2+num_drones, ..., 2+2*num_drones, 3+2*num_drones, 4+2*num_drones,...,4+4*num_drones]
index = 0
# success
new_obs[index] = obs['success'].value
new_obs[index] = obs['success'].value - 1
index += 1

if agent in self.env.active_agents:
# Add blocked IPs
for i, ip in enumerate(self.ip_addresses):
new_obs[index + i] = 1 if ip in [interface['blocked_ips'] for interface in
new_obs[index + i] = 1 if ip in [blocked_ip for interface in
obs[own_host_name]['Interface'] if
'blocked_ips' in interface] else 0
'blocked_ips' in interface for blocked_ip in interface['blocked_ips']] else 0
index += len(self.ip_addresses)

# add flagged malicious processes
Expand All @@ -236,9 +265,9 @@ def observation_change(self, agent: str, obs: dict):
# add flagged messages
for i, ip in enumerate(self.ip_addresses):
# TODO add in check for network connections
new_obs[i] = 1 if ip in [interface['Network Connections'] for interface in
new_obs[index + i] = 1 if ip in [network_conn for interface in
obs[own_host_name]['Interface'] if
'Network Connections' in interface] else 0
'Network Connections' in interface for network_conn in interface['Network Connections']] else 0
index += len(self.ip_addresses)

pos = obs[own_host_name]['System info'].get('position', (0, 0))
Expand Down Expand Up @@ -291,3 +320,4 @@ def agents(self):
@property
def possible_agents(self):
    '''
    Returns the `agents` attribute of the wrapped environment.
    The wrapper iterates over this value as the collection of agent names.
    '''
    return self.env.agents

1 change: 1 addition & 0 deletions CybORG/Agents/Wrappers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@
from .ChallengeWrapper import ChallengeWrapper
from .IntFixedFlatWrapper import IntFixedFlatWrapper
from .SimpleRedWrapper import SimpleRedWrapper
from .PettingZooParallelWrapper import PettingZooParallelWrapper
35 changes: 24 additions & 11 deletions CybORG/Shared/EnvironmentController.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def reset(self, agent: str = None, np_random=None) -> Results:
action_space=self.agent_interfaces[agent].action_space.get_action_space())

def step(self, actions: dict = None, skip_valid_action_check=False):

"""Updates the environment based on the joint actions of all agents
Save the
Expand All @@ -147,9 +148,9 @@ def step(self, actions: dict = None, skip_valid_action_check=False):
agent_object.messages = []
if agent_name not in actions:
actions[agent_name] = agent_object.get_action(self.get_last_observation(agent_name))
if not self.test_valid_action(actions[agent_name], agent_object) and not skip_valid_action_check:
actions[agent_name] = InvalidAction(action=actions[agent_name])
self._log_debug(f"{__class__.__name__} have Invalid Action agent_action={actions[agent_name]} for agent={agent_name}")
if not skip_valid_action_check:
actions[agent_name] = self.replace_action_if_invalid(actions[agent_name], agent_object)

self.action = actions
actions = self.sort_action_order(actions)

Expand Down Expand Up @@ -423,20 +424,32 @@ def _filter_obs(self, obs: Observation, agent_name=None):
)
return obs

def test_valid_action(self, action: Action, agent: AgentInterface):
# returns true if the parameters in the action are in and true in the action set else return false
def replace_action_if_invalid(self, action: Action, agent: AgentInterface):
# returns action if the parameters in the action are in and true in the action set else return InvalidAction imbued with bug report.
action_space = agent.action_space.get_action_space()
# first check that the action class is allowed
if type(action) not in action_space['action'] or not action_space['action'][type(action)]:
return False

if type(action) not in action_space['action']:
message = f'Action {action} not in action space for agent {agent.agent_name}.'
return InvalidAction(action=action, error=message)

if not action_space['action'][type(action)]:
message = f'Action {action} is not valid for agent {agent.agent_name} at the moment. This usually means it is trying to access a host it has not discovered yet.'
return InvalidAction(action=action, error=message)

# next for each parameter in the action
for parameter_name, parameter_value in action.get_params().items():
if parameter_name not in action_space:
continue

if (parameter_value not in action_space[parameter_name]) or (not action_space[parameter_name][parameter_value]):
return False
return True
if parameter_value not in action_space[parameter_name]:
message = f'Action {action} has parameter {parameter_name} valued at {parameter_value}. However, {parameter_value} is not in the action space for agent {agent.agent_name}.'
return InvalidAction(action=action, error=message)

if not action_space[parameter_name][parameter_value]:
message = f'Action {action} has parameter {parameter_name} valued at the invalid value of {parameter_value}. This usually means an agent is trying to utilise information it has not discovered yet such as an ip_address or port number.'
return InvalidAction(action=action, error=message)

return action

def get_reward_breakdown(self, agent:str):
return self.agent_interfaces[agent].reward_calculator.host_scores
Expand Down
3 changes: 2 additions & 1 deletion CybORG/Simulator/Actions/Action.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,10 @@ def execute(self, state):

class InvalidAction(Action):

def __init__(self, action: Action = None):
def __init__(self, action: Action = None, error: str =None):
super().__init__()
self.action = action
self.error = error

def execute(self, state):
return Observation(success=False)
Expand Down
Loading

0 comments on commit 94ab9d2

Please sign in to comment.