Skip to content

Commit

Permalink
gym 0.26 upgrade with new API contract - preparing for migration to g…
Browse files Browse the repository at this point in the history
…ymnasium (#84)

* upgrade to gym 0.26 and fix new typing errors

* Update notebooks with gym 0.26

* end of line

---------

Co-authored-by: William Blum <william.blum@microsoft.com>
  • Loading branch information
blumu and William Blum authored Aug 5, 2024
1 parent ed86a70 commit baf4f6e
Show file tree
Hide file tree
Showing 64 changed files with 1,305,699 additions and 235,857 deletions.
15 changes: 0 additions & 15 deletions createstubs.sh
Original file line number Diff line number Diff line change
Expand Up @@ -36,21 +36,6 @@ createstub asciichartpy
createstub networkx
createstub boolean
createstub IPython


if [ ! -d "typings/gym" ]; then
pyright --createstub gym
# Patch gym stubs
echo ' spaces = ...' >> typings/gym/spaces/dict.pyi
echo ' nvec = ...' >> typings/gym/spaces/space.pyi
echo ' spaces = ...' >> typings/gym/spaces/space.pyi
echo ' spaces = ...' >> typings/gym/spaces/tuple.pyi
echo ' n = ...' >> typings/gym/spaces/multi_binary.pyi
else
echo stub gym already created
fi


echo 'Typing stub generation completed'

popd
28 changes: 14 additions & 14 deletions cyberbattle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

def register(id: str, cyberbattle_env_identifiers: model.Identifiers, **kwargs):
""" same as gym.envs.registry.register, but adds CyberBattle specs to env.spec """
if id in registry.env_specs:
if id in registry:
raise Error('Cannot re-register id: {}'.format(id))
spec = EnvSpec(id, **kwargs)
# Map from port number to port names : List[model.PortName]
Expand All @@ -33,11 +33,11 @@ def register(id: str, cyberbattle_env_identifiers: model.Identifiers, **kwargs):
# Array defining an index for every possible remote vulnerability name : List[model.VulnerabilityID]
spec.remote_vulnerabilities = cyberbattle_env_identifiers.remote_vulnerabilities

registry.env_specs[id] = spec
registry[id] = spec


if 'CyberBattleToyCtf-v0' in registry.env_specs:
del registry.env_specs['CyberBattleToyCtf-v0']
if 'CyberBattleToyCtf-v0' in registry:
del registry['CyberBattleToyCtf-v0']

register(
id='CyberBattleToyCtf-v0',
Expand All @@ -50,8 +50,8 @@ def register(id: str, cyberbattle_env_identifiers: model.Identifiers, **kwargs):
# max_episode_steps=2600,
)

if 'CyberBattleTiny-v0' in registry.env_specs:
del registry.env_specs['CyberBattleTiny-v0']
if 'CyberBattleTiny-v0' in registry:
del registry['CyberBattleTiny-v0']

register(
id='CyberBattleTiny-v0',
Expand All @@ -67,17 +67,17 @@ def register(id: str, cyberbattle_env_identifiers: model.Identifiers, **kwargs):
)


if 'CyberBattleRandom-v0' in registry.env_specs:
del registry.env_specs['CyberBattleRandom-v0']
if 'CyberBattleRandom-v0' in registry:
del registry['CyberBattleRandom-v0']

register(
id='CyberBattleRandom-v0',
cyberbattle_env_identifiers=generate_network.ENV_IDENTIFIERS,
entry_point='cyberbattle._env.cyberbattle_random:CyberBattleRandom',
)

if 'CyberBattleChain-v0' in registry.env_specs:
del registry.env_specs['CyberBattleChain-v0']
if 'CyberBattleChain-v0' in registry:
del registry['CyberBattleChain-v0']

register(
id='CyberBattleChain-v0',
Expand All @@ -95,8 +95,8 @@ def register(id: str, cyberbattle_env_identifiers: model.Identifiers, **kwargs):

ad_envs = [f"ActiveDirectory-v{i}" for i in range(0, 10)]
for (index, env) in enumerate(ad_envs):
if env in registry.env_specs:
del registry.env_specs[env]
if env in registry:
del registry[env]

register(
id=env,
Expand All @@ -110,8 +110,8 @@ def register(id: str, cyberbattle_env_identifiers: model.Identifiers, **kwargs):
}
)

if 'ActiveDirectoryTiny-v0' in registry.env_specs:
del registry.env_specs['ActiveDirectoryTiny-v0']
if 'ActiveDirectoryTiny-v0' in registry:
del registry['ActiveDirectoryTiny-v0']
register(
id='ActiveDirectoryTiny-v0',
cyberbattle_env_identifiers=chainpattern.ENV_IDENTIFIERS,
Expand Down
223 changes: 111 additions & 112 deletions cyberbattle/_env/cyberbattle_env.py

Large diffs are not rendered by default.

48 changes: 21 additions & 27 deletions cyberbattle/_env/cyberbattle_env_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,17 @@

"""Test the CyberBattle Gym environment"""

from cyberbattle._env.option_wrapper import ContextWrapper, random_options
from cyberbattle._env.cyberbattle_env import AttackerGoal, CyberBattleEnv
import pytest
import gym
import numpy as np

from .cyberbattle_env import AttackerGoal
from typing import cast


def test_few_gym_iterations() -> None:
"""Run a few iterations of the gym environment"""
env = gym.make('CyberBattleToyCtf-v0')
env = cast(CyberBattleEnv, gym.make('CyberBattleToyCtf-v0'))

for _ in range(2):
env.reset()
Expand All @@ -24,7 +25,11 @@ def test_few_gym_iterations() -> None:
# sample a valid action
action = env.sample_valid_action()

observation, reward, done, info = env.step(action)
observation, reward, done, truncated, info = env.step(action)
if truncated:
print("Episode truncated after {} timesteps".format(t + 1))
break

if done:
print("Episode finished after {} timesteps".format(t + 1))
break
Expand Down Expand Up @@ -99,32 +104,21 @@ def test_step_after_done() -> None:
env = gym.make('CyberBattleChain-v0', size=10, attacker_goal=AttackerGoal(own_atleast_percent=1.0))
env.reset()
for a in actions[:-1]:
observation, reward, done, info = env.step(a)
print(f"{a}, # done={done} r={reward}")
observation, reward, done, truncated, info = env.step(a)
print(f"{a}, # done={done} truncated={truncated} r={reward}")

with pytest.raises(RuntimeError, match=r'new episode must be started with env\.reset\(\)'):
env.step(actions[-1])


@pytest.mark.parametrize('env_name', ['CyberBattleToyCtf-v0', 'CyberBattleRandom-v0', 'CyberBattleChain-v0'])
def test_wrap_spec(env_name) -> None:
env = gym.make(env_name)

class DummyWrapper(gym.Wrapper):
def __init__(self, env):
super().__init__(env)
assert hasattr(self, 'spec')
self.spec.dummy = 7

assert hasattr(env.spec, 'properties')
assert hasattr(env.spec, 'ports')
assert hasattr(env.spec, 'local_vulnerabilities')
assert hasattr(env.spec, 'remote_vulnerabilities')

env = DummyWrapper(env)
def test_option_wrapper():
env = gym.make('CyberBattleChain-v0', size=10, attacker_goal=AttackerGoal(reward=4000))
env = ContextWrapper(cast(CyberBattleEnv, env), options=random_options)

assert hasattr(env.spec, 'properties')
assert hasattr(env.spec, 'ports')
assert hasattr(env.spec, 'local_vulnerabilities')
assert hasattr(env.spec, 'remote_vulnerabilities')
assert hasattr(env.spec, 'dummy')
s = env.reset()
for t in range(4):
s, r, done, truncated, info = env.step()
if r > 0:
print(r, done, info['action'])
if done:
break
21 changes: 12 additions & 9 deletions cyberbattle/_env/discriminatedunion.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@
"""A discriminated union space for Gym"""

from collections import OrderedDict
from typing import Mapping, Union, List
from typing import Mapping, TypeVar, Union
from typing import Dict as TypingDict, Generic, cast

from gym import spaces
from gym.utils import seeding

T_cov = TypeVar("T_cov", covariant=True)

class DiscriminatedUnion(spaces.Dict): # type: ignore

class DiscriminatedUnion(spaces.Dict, Generic[T_cov]): # type: ignore
"""
A discriminated union of simpler spaces.
Expand All @@ -22,7 +25,7 @@ class DiscriminatedUnion(spaces.Dict): # type: ignore
"""

def __init__(self,
spaces: Union[None, List[spaces.Space], Mapping[str, spaces.Space]] = None,
spaces: Union[None, TypingDict[str, spaces.Space]] = None,
**spaces_kwargs: spaces.Space) -> None:
"""Create a discriminated union space"""
if spaces is None:
Expand All @@ -34,13 +37,13 @@ def seed(self, seed: Union[None, int] = None) -> None:
self._np_random, seed = seeding.np_random(seed)
super().seed(seed)

def sample(self) -> object:
def sample(self) -> T_cov: # dict[str, object]:
space_count = len(self.spaces.items())
index_k = self.np_random.randint(space_count)
index_k = self.np_random.integers(0, space_count)
kth_key, kth_space = list(self.spaces.items())[index_k]
return OrderedDict([(kth_key, kth_space.sample())])
return cast(T_cov, OrderedDict([(kth_key, kth_space.sample())]))

def contains(self, candidate: object) -> bool:
def contains(self, candidate) -> bool:
if not isinstance(candidate, dict) or len(candidate) != 1:
return False
k, space = list(candidate)[0]
Expand All @@ -64,10 +67,10 @@ def __getitem__(self, key: str) -> spaces.Space:
def __repr__(self) -> str:
return self.__class__.__name__ + "(" + ", ". join([str(k) + ":" + str(s) for k, s in self.spaces.items()]) + ")"

def to_jsonable(self, sample_n: object) -> object:
def to_jsonable(self, sample_n: list) -> object:
return super().to_jsonable(sample_n)

def from_jsonable(self, sample_n: object) -> object:
def from_jsonable(self, sample_n: TypingDict[str, list]) -> object:
ret = super().from_jsonable(sample_n)
assert len(ret) == 1
return ret
Expand Down
20 changes: 13 additions & 7 deletions cyberbattle/_env/flatten_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
"""Space flattening wrappers fro the CyberBattleEnv gym environment.
"""Wrappers used to flatten action and observation spaces
for CyberBattleEnv gym environment.
"""
from collections import OrderedDict
from sqlite3 import NotSupportedError
from gym import spaces
import numpy as np
from cyberbattle._env.cyberbattle_env import DummySpace, CyberBattleEnv, Action
from cyberbattle._env.cyberbattle_env import DummySpace, CyberBattleEnv, Action, CyberBattleSpaceKind
from gym.core import ObservationWrapper, ActionWrapper


class FlattenObservationWrapper(ObservationWrapper):
"""
Flatten all nested dictionaries and tuples from the
observation space of a CyberBattleSim environment`CyberBattleEnv`.
observation space of a CyberBattleSim environment.
The resulting observation space is a dictionary containing only
subspaces of types: `Discrete`, `MultiBinary`, and `MultiDiscrete`.
"""
Expand All @@ -28,7 +29,7 @@ def flatten_multibinary_space(self, space: spaces.Space):
else:
return space

def __init__(self, env: CyberBattleEnv, ignore_fields=['action_mask']):
def __init__(self, env: CyberBattleSpaceKind, ignore_fields=['action_mask']):
ObservationWrapper.__init__(self, env)
self.env = env
self.ignore_fields = ignore_fields
Expand Down Expand Up @@ -56,14 +57,14 @@ def __init__(self, env: CyberBattleEnv, ignore_fields=['action_mask']):

def flatten_multibinary_observation(self, space, o):
if isinstance(space, spaces.MultiBinary) and \
type(space.n) in [tuple, list, np.ndarray] and \
isinstance(space.n, tuple) and \
len(space.n) > 1:
flatten_dim = np.multiply.reduce(space.n)
return tuple(o.reshape(flatten_dim))
else:
return o

def observation(self, observation: dict):
def observation(self, observation):
o = OrderedDict({})
for key, space in self.env.observation_space.spaces.items():
value = observation[key]
Expand All @@ -86,11 +87,16 @@ def observation(self, observation: dict):

return o

def step(self, action):
"""Returns a modified observation using :meth:`self.observation` after calling :meth:`env.step`."""
observation, reward, terminated, truncated, info = self.env.step(action)
return self.observation(observation), reward, terminated, truncated, info


class FlattenActionWrapper(ActionWrapper):
"""
Flatten all nested dictionaries and tuples from the
action space of a CyberBattleSim environment`CyberBattleEnv`.
action space of a CyberBattleSim environment.
The resulting action space is a dictionary containing only
subspaces of types: `Discrete`, `MultiBinary`, and `MultiDiscrete`.
"""
Expand Down
2 changes: 1 addition & 1 deletion cyberbattle/_env/graph_spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __init__(
super().__init__()

def sample(self):
num_nodes = self.np_random.randint(self.max_num_nodes + 1)
num_nodes = self.np_random.integers(0, self.max_num_nodes + 1)
graph = self._nx_class()

# add nodes with properties
Expand Down
8 changes: 4 additions & 4 deletions cyberbattle/_env/graph_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,9 @@ class CyberBattleGraph(gym.Wrapper):

def __init__(self, env, maximum_total_credentials=22, maximum_node_count=22):
super().__init__(env)
self._bounds = self.env.bounds
self._bounds = env.bounds
self.__graph = nx.DiGraph()
self.observation_space = DiGraph(self.bounds.maximum_node_count)
self.observation_space = DiGraph(self._bounds.maximum_node_count)

def reset(self):
observation = self.env.reset()
Expand All @@ -168,13 +168,13 @@ def step(self, action: Action):
"""
kind_id, *indicators = action
observation, reward, done, info = self.env.step({self.__kinds[kind_id]: indicators})
observation, reward, done, truncated, info = self.env.step({self.__kinds[kind_id]: indicators})
for _ in range(observation['newly_discovered_nodes_count']):
self.__add_node(observation)
if True: # TODO: do we need to update edges and nodes every time?
self.__update_edges(observation)
self.__update_nodes(observation)
return self.__graph, reward, done, info
return self.__graph, reward, done, truncated, info

def __add_node(self, observation):
while self.__graph.number_of_nodes() < observation['discovered_node_count']:
Expand Down
13 changes: 7 additions & 6 deletions cyberbattle/_env/option_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import gym
from gym.spaces import Space, Discrete, Tuple
import numpy as onp
from cyberbattle._env.cyberbattle_env import Action, CyberBattleEnv


class Env(NamedTuple):
Expand All @@ -31,7 +32,7 @@ def context_spaces(observation_space, action_space):
class ContextWrapper(gym.Wrapper):
__kinds = ('local_vulnerability', 'remote_vulnerability', 'connect')

def __init__(self, env, options):
def __init__(self, env: CyberBattleEnv, options):

super().__init__(env)
self.env = env
Expand All @@ -52,21 +53,21 @@ def step(self, dummy=None):
local_node_id = self._options['local_node_id']((obs, kind))
if kind == 0:
local_vuln_id = self._options['local_vuln_id']((obs, local_node_id))
a = {self.__kinds[kind]: onp.array([local_node_id, local_vuln_id])}
a: Action = {"local_vulnerability": onp.array([local_node_id, local_vuln_id])}
else:
remote_node_id = self._options['remote_node_id']((obs, kind, local_node_id))
if kind == 1:
remote_vuln_id = \
self._options['remote_vuln_id']((obs, local_node_id, remote_node_id))
a = {self.__kinds[kind]: onp.array([local_node_id, remote_node_id, remote_vuln_id])}
a = {"remote_vulnerability": onp.array([local_node_id, remote_node_id, remote_vuln_id])}
else:
cred_id = self._options['cred_id'](obs)
assert cred_id < obs['credential_cache_length']
node_id, port_id = obs['credential_cache_matrix'][cred_id].astype('int32')
a = {self.__kinds[kind]: onp.array([local_node_id, node_id, port_id, cred_id])}
a = {"connect": onp.array([local_node_id, node_id, port_id, cred_id])}

self._observation, reward, done, info = self.env.step(a)
return self._observation, reward, done, {**info, 'action': a}
self._observation, reward, done, truncated, info = self.env.step(a)
return self._observation, reward, done, truncated, {**info, 'action': a}


# --- random option policies --------------------------------------------------------------------- #
Expand Down
2 changes: 1 addition & 1 deletion cyberbattle/agents/baseline/agent_dql.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def update_q_function(self,
self.optimize_model()

def on_step(self, wrapped_env: w.AgentWrapper,
observation, reward: float, done: bool, info, action_metadata):
observation, reward: float, done: bool, truncated: bool, info, action_metadata):
agent_state = wrapped_env.state
if done:
self.update_q_function(reward,
Expand Down
Loading

0 comments on commit baf4f6e

Please sign in to comment.