Skip to content

Commit

Permalink
Updating render function on OpenAIGymWrapper.py
Browse files Browse the repository at this point in the history
  • Loading branch information
maxstanden committed Dec 15, 2022
1 parent d0d7fa1 commit f34ad9a
Showing 1 changed file with 6 additions and 41 deletions.
47 changes: 6 additions & 41 deletions CybORG/Agents/Wrappers/OpenAIGymWrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@

import numpy as np
from gym import spaces, Env
from typing import Union, List, Optional

from typing import Union, List, Optional, Tuple

from prettytable import PrettyTable

from CybORG.Agents.SimpleAgents.BaseAgent import BaseAgent
from CybORG.Agents.Wrappers.BaseWrapper import BaseWrapper



class OpenAIGymWrapper(Env, BaseWrapper):
def __init__(self, env: BaseWrapper, agent_name: str):
super().__init__(env)
Expand All @@ -27,7 +27,7 @@ def __init__(self, env: BaseWrapper, agent_name: str):
self.metadata = {}
self.action = None

def step(self, action: Union[int, List[int]] = None) -> (object, float, bool, dict):
def step(self, action: Union[int, List[int]] = None) -> Tuple[object, float, bool, dict]:
if action is not None:
action = self.possible_actions[action]
self.action = action
Expand All @@ -50,41 +50,8 @@ def reset(self, *, seed: Optional[int] = None, return_info: bool = False, option
else:
return np.array(result.observation, dtype=np.float32)

def render(self, mode):
# TODO: If FixedFlatWrapper it will error out!
if mode == 'human':
self.env.render(mode)
else:
if self.agent_name == 'Red':
table = PrettyTable({
'Subnet',
'IP Address',
'Hostname',
'Scanned',
'Access',
})
for ip in self.get_attr('red_info'):
table.add_row(self.get_attr('red_info')[ip])
table.sortby = 'IP Address'
if self.action is not None:
_action = self.get_attr('possible_actions')[self.action]
return print(f'\nRed Action: {_action}\n{table}')
elif self.agent_name == 'Blue':
table = PrettyTable({
'Subnet',
'IP Address',
'Hostname',
'Activity',
'Compromised',
})
for hostid in self.get_attr('info'):
table.add_row(self.get_attr('info')[hostid])
table.sortby = 'Hostname'
if self.action is not None:
_action = self.get_attr('possible_actions')[self.action]
red_action = self.get_last_action(agent=self.agent_name)
return print(f'\nBlue Action: {_action}\nRed Action: {red_action}\n{table}')
return print(table)
def render(self, mode='human'):
return self.env.render(mode)

def get_attr(self,attribute:str):
return self.env.get_attr(attribute)
Expand All @@ -109,8 +76,6 @@ def get_ip_map(self):
def get_rewards(self):
return self.get_attr('get_rewards')()



def action_space_change(self, action_space: dict) -> int:
assert type(action_space) is dict, \
f"Wrapper required a dictionary action space. " \
Expand Down Expand Up @@ -145,4 +110,4 @@ def action_space_change(self, action_space: dict) -> int:
possible_actions.append(action(**p_dict))

self.possible_actions = possible_actions
return len(possible_actions)
return len(possible_actions)

0 comments on commit f34ad9a

Please sign in to comment.