diff --git a/.vscode/settings.json b/.vscode/settings.json index 0c3da44..e281264 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -28,9 +28,6 @@ "jupyter.jupyterServerType": "local", "jupyter.notebookFileRoot": "${workspaceFolder}", "files.trimFinalNewlines": true, - "editor.defaultFormatter": "ms-python.flake8", - "flake8.args": [ - ], + "editor.defaultFormatter": "ms-python.black-formatter", "files.trimTrailingWhitespace": true, - } diff --git a/cyberbattle/_env/cyberbattle_env.py b/cyberbattle/_env/cyberbattle_env.py index ee7478c..6896bff 100644 --- a/cyberbattle/_env/cyberbattle_env.py +++ b/cyberbattle/_env/cyberbattle_env.py @@ -1317,7 +1317,7 @@ def get_explored_network_node_properties_bitmap_as_numpy( ] ) - def step(self, action: Action) -> Tuple[Observation, float, bool, bool, StepInfo]: + def step(self, action: Action) -> Tuple[Observation, float, bool, bool, StepInfo]: # type: ignore if self.__done: raise RuntimeError("new episode must be started with env.reset()") diff --git a/cyberbattle/_env/discriminatedunion.py b/cyberbattle/_env/discriminatedunion.py index bb72450..149233a 100644 --- a/cyberbattle/_env/discriminatedunion.py +++ b/cyberbattle/_env/discriminatedunion.py @@ -14,7 +14,7 @@ T_cov = TypeVar("T_cov", covariant=True) -class DiscriminatedUnion(spaces.Dict, Generic[T_cov]): # type: ignore +class DiscriminatedUnion(spaces.Dict, Generic[T_cov]): """ A discriminated union of simpler spaces. @@ -23,6 +23,9 @@ class DiscriminatedUnion(spaces.Dict, Generic[T_cov]): # type: ignore self.observation_space = discriminatedunion.DiscriminatedUnion( {"foo": spaces.Discrete(2), "Bar": spaces.Discrete(3)}) + Generic type T_cov is the type of the contained discriminated values. 
+ It should be defined as a typed dictionary, e.g.: TypedDict('Choices', {'foo': int, 'Bar': int}) + """ def __init__( @@ -47,7 +50,7 @@ def __init__( def seed(self, seed: Union[dict, None, int] = None): return super().seed(seed) - def sample(self, mask=None) -> T_cov: # dict[str, object]: + def sample(self, mask=None) -> T_cov: # type: ignore space_count = len(self.spaces.items()) index_k = self.union_np_random.integers(0, space_count) kth_key, kth_space = list(self.spaces.items())[index_k] diff --git a/cyberbattle/agents/baseline/agent_wrapper.py b/cyberbattle/agents/baseline/agent_wrapper.py index 63471af..bf6dafc 100644 --- a/cyberbattle/agents/baseline/agent_wrapper.py +++ b/cyberbattle/agents/baseline/agent_wrapper.py @@ -4,8 +4,9 @@ """Agent wrapper for CyberBattle envrionments exposing additional features extracted from the environment observations""" +from abc import abstractmethod from cyberbattle._env.cyberbattle_env import EnvironmentBounds -from typing import Optional, List, Tuple +from typing import Optional, List, Tuple, overload import enum import numpy as np from gym import spaces, Wrapper @@ -25,7 +26,7 @@ def on_step( self, action: cyberbattle_env.Action, reward: float, - truncated, + truncated: bool, done: bool, observation: cyberbattle_env.Observation, ): @@ -35,6 +36,7 @@ def on_reset(self, observation: cyberbattle_env.Observation): self.observation = observation +# Abstract class for a feature (either global or node-specific) class Feature(spaces.MultiDiscrete): """ Feature consisting of multiple discrete dimensions. 
@@ -55,22 +57,67 @@ def name(self): p = len(type(Feature(self.env_properties, [])).__name__) + 1 return type(self).__name__[p:] - def get(self, a: StateAugmentation, node: Optional[int]) -> np.ndarray: + def pretty_print(self, v): + return v + + @abstractmethod + def get(self, a: StateAugmentation, node: Optional[int] = None) -> np.ndarray: """Compute the current value of a feature value at the current observation and specific node""" raise NotImplementedError - def pretty_print(self, v): - return v + +class NodeFeature(Feature): + """ + Feature consisting of multiple discrete dimensions at a specific node. + """ + + @abstractmethod + def get_at(self, a: StateAugmentation, node: int) -> np.ndarray: + """Compute the current value of a feature value at + the current observation and specific node""" + raise NotImplementedError + + def get(self, a: StateAugmentation, node: Optional[int] = None) -> np.ndarray: + assert node is not None, "feature only valid in the context of a node" + return self.get_at(a, node) + + +class GlobalFeature(Feature): + """ + Feature consisting of multiple discrete dimensions at the global level. 
+ """ + + @abstractmethod + def get_global(self, a: StateAugmentation) -> np.ndarray: + """Compute the current value of a feature value at + the current observation""" + raise NotImplementedError + + def get(self, a: StateAugmentation, node: Optional[int] = None) -> np.ndarray: + assert node is None, "feature only valid in a global context (no node)" + return self.get_global(a) + + # @staticmethod + # def get_feature_value( + # f: Union[NodeFeature, GlobalFeature], a: SA_T, node: Optional[int] + # ): + # """Return the feature value at the current observation and specific node""" + # if isinstance(f, NodeFeature): + # assert node is not None, "feature only valid in the context of a node" + # return f.get(a, node) + # elif isinstance(f, GlobalFeature): + # assert node is None, "feature only valid in a global context (no node)" + # return f.get(a) -class Feature_active_node_properties(Feature): +class Feature_active_node_properties(NodeFeature): """Bitmask of all properties set for the active node""" def __init__(self, p: EnvironmentBounds): super().__init__(p, [2] * p.property_count) - def get(self, a: StateAugmentation, node) -> ndarray: + def get_at(self, a: StateAugmentation, node) -> ndarray: assert node is not None, "feature only valid in the context of a node" node_prop = a.observation["discovered_nodes_properties"] @@ -84,14 +131,14 @@ def get(self, a: StateAugmentation, node) -> ndarray: return remapped -class Feature_active_node_age(Feature): +class Feature_active_node_age(NodeFeature): """How recently was this node discovered? 
(measured by reverse position in the list of discovered nodes)""" def __init__(self, p: EnvironmentBounds): super().__init__(p, [p.maximum_node_count]) - def get(self, a: StateAugmentation, node) -> ndarray: + def get_at(self, a: StateAugmentation, node) -> ndarray: assert node is not None, "feature only valid in the context of a node" discovered_node_count = a.observation["discovered_node_count"] @@ -103,17 +150,17 @@ def get(self, a: StateAugmentation, node) -> ndarray: return np.array([discovered_node_count - node - 1], dtype=np.int_) -class Feature_active_node_id(Feature): +class Feature_active_node_id(NodeFeature): """Return the node id itself""" def __init__(self, p: EnvironmentBounds): super().__init__(p, [p.maximum_node_count] * 1) - def get(self, a: StateAugmentation, node) -> ndarray: + def get_at(self, a: StateAugmentation, node) -> ndarray: return np.array([node], dtype=np.int_) -class Feature_discovered_nodeproperties_sliding(Feature): +class Feature_discovered_nodeproperties_sliding(GlobalFeature): """Bitmask indicating node properties seen in last few cache entries""" window_size = 3 @@ -121,12 +168,12 @@ class Feature_discovered_nodeproperties_sliding(Feature): def __init__(self, p: EnvironmentBounds): super().__init__(p, [2] * p.property_count) - def get(self, a: StateAugmentation, node) -> ndarray: + def get_global(self, a: StateAugmentation) -> ndarray: n = a.observation["discovered_node_count"] node_prop = np.array(a.observation["discovered_nodes_properties"])[:n] # keep last window of entries - node_prop_window = node_prop[-self.window_size:, :] + node_prop_window = node_prop[-self.window_size :, :] # Remap to get rid of the unknown value (2) node_prop_window_remapped = np.int32(node_prop_window % 2) @@ -137,13 +184,13 @@ def get(self, a: StateAugmentation, node) -> ndarray: return bitmask -class Feature_discovered_ports(Feature): +class Feature_discovered_ports(GlobalFeature): """Bitmask vector indicating each port seen so far in discovered 
credentials""" def __init__(self, p: EnvironmentBounds): super().__init__(p, [2] * p.port_count) - def get(self, a: StateAugmentation, node): + def get_global(self, a: StateAugmentation): n = a.observation["credential_cache_length"] known_credports = np.zeros(self.env_properties.port_count, dtype=np.int32) if n > 0: @@ -152,7 +199,7 @@ def get(self, a: StateAugmentation, node): return known_credports -class Feature_discovered_ports_sliding(Feature): +class Feature_discovered_ports_sliding(GlobalFeature): """Bitmask indicating port seen in last few cache entries""" window_size = 3 @@ -160,22 +207,22 @@ class Feature_discovered_ports_sliding(Feature): def __init__(self, p: EnvironmentBounds): super().__init__(p, [2] * p.port_count) - def get(self, a: StateAugmentation, node): + def get_global(self, a: StateAugmentation): known_credports = np.zeros(self.env_properties.port_count, dtype=np.int32) n = a.observation["credential_cache_length"] if n > 0: ccm = np.array(a.observation["credential_cache_matrix"])[:n] - known_credports[np.int32(ccm[-self.window_size:, 1])] = 1 + known_credports[np.int32(ccm[-self.window_size :, 1])] = 1 return known_credports -class Feature_discovered_ports_counts(Feature): +class Feature_discovered_ports_counts(GlobalFeature): """Count of each port seen so far in discovered credentials""" def __init__(self, p: EnvironmentBounds): super().__init__(p, [p.maximum_total_credentials + 1] * p.port_count) - def get(self, a: StateAugmentation, node): + def get_global(self, a: StateAugmentation): n = a.observation["credential_cache_length"] if n > 0: ccm = np.array(a.observation["credential_cache_matrix"])[:n] @@ -185,35 +232,35 @@ def get(self, a: StateAugmentation, node): return np.bincount(ports, minlength=self.env_properties.port_count) -class Feature_discovered_credential_count(Feature): +class Feature_discovered_credential_count(GlobalFeature): """number of credentials discovered so far""" def __init__(self, p: EnvironmentBounds): 
super().__init__(p, [p.maximum_total_credentials + 1]) - def get(self, a: StateAugmentation, node): + def get_global(self, a: StateAugmentation): n = a.observation["credential_cache_length"] return np.array([n], dtype=np.int_) -class Feature_discovered_node_count(Feature): +class Feature_discovered_node_count(GlobalFeature): """number of nodes discovered so far""" def __init__(self, p: EnvironmentBounds): super().__init__(p, [p.maximum_node_count + 1]) - def get(self, a: StateAugmentation, node): + def get_global(self, a: StateAugmentation): return np.array([a.observation["discovered_node_count"]], dtype=np.int_) -class Feature_discovered_notowned_node_count(Feature): +class Feature_discovered_notowned_node_count(GlobalFeature): """number of nodes discovered that are not owned yet (optionally clipped)""" def __init__(self, p: EnvironmentBounds, clip: Optional[int]): self.clip = p.maximum_node_count if clip is None else clip super().__init__(p, [self.clip + 1]) - def get(self, a: StateAugmentation, node): + def get_global(self, a: StateAugmentation): discovered = a.observation["discovered_node_count"] node_props = np.array(a.observation["discovered_nodes_properties"][:discovered]) # here we assume that a node is owned just if all its properties are known @@ -222,13 +269,13 @@ def get(self, a: StateAugmentation, node): return np.array([min(diff, self.clip)], dtype=np.int_) -class Feature_owned_node_count(Feature): +class Feature_owned_node_count(GlobalFeature): """number of owned nodes so far""" def __init__(self, p: EnvironmentBounds): super().__init__(p, [p.maximum_node_count + 1]) - def get(self, a: StateAugmentation, node): + def get_global(self, a: StateAugmentation): levels = a.observation["nodes_privilegelevel"] owned_nodes_indices = np.where(levels > 0)[0] return np.array([len(owned_nodes_indices)], dtype=np.int_) @@ -240,7 +287,11 @@ class ConcatFeatures(Feature): feature_selection - a selection of features to combine """ - def __init__(self, p: 
EnvironmentBounds, feature_selection: List[Feature]): + def __init__( + self, + p: EnvironmentBounds, + feature_selection: List[Feature], + ): self.feature_selection = feature_selection self.dim_sizes = np.concatenate([f.nvec for f in feature_selection]) super().__init__(p, [self.dim_sizes]) @@ -248,14 +299,15 @@ def __init__(self, p: EnvironmentBounds, feature_selection: List[Feature]): def pretty_print(self, v): return v - def get(self, a: StateAugmentation, node=None) -> np.ndarray: + def get(self, a: StateAugmentation, node: Optional[int] = None) -> np.ndarray: """Return the feature vector""" feature_vector = [f.get(a, node) for f in self.feature_selection] + return np.concatenate(feature_vector) class FeatureEncoder(Feature): - """Encode a list of featues as a unique index""" + """Encode a list of features as a unique index""" feature_selection: List[Feature] @@ -278,15 +330,11 @@ def encode(self, a: StateAugmentation, node=None) -> int: feature_vector_concat = self.feature_vector_of_observation_at(a, node) return self.vector_to_index(feature_vector_concat) - def encode_at(self, a: StateAugmentation, node) -> int: + def encode_at(self, a: StateAugmentation, node: int) -> int: """Return the current feature vector encoding with a node context""" feature_vector_concat = self.feature_vector_of_observation_at(a, node) return self.vector_to_index(feature_vector_concat) - def get(self, a: StateAugmentation, node=None) -> np.ndarray: - """Return the feature vector""" - return np.array([self.encode(a, node)]) - def name(self): """Return a name for the feature encoding""" n = ", ".join([f.name() for f in self.feature_selection]) @@ -301,7 +349,10 @@ class HashEncoding(FeatureEncoder): """ def __init__( - self, p: EnvironmentBounds, feature_selection: List[Feature], hash_size: int + self, + p: EnvironmentBounds, + feature_selection: List[Feature], + hash_size: int, ): self.feature_selection = feature_selection self.hash_size = hash_size @@ -325,21 +376,24 @@ class 
RavelEncoding(FeatureEncoder): feature_selection - a selection of features to combine """ - def __init__(self, p: EnvironmentBounds, feature_selection: List[Feature]): + def __init__( + self, + p: EnvironmentBounds, + feature_selection: List[Feature], + ): self.feature_selection = feature_selection self.dim_sizes = np.concatenate([f.nvec for f in feature_selection]) self.ravelled_size: np.int64 = np.prod(self.dim_sizes) assert np.shape(self.ravelled_size) == (), f"! {np.shape(self.ravelled_size)}" super().__init__(p, [self.ravelled_size]) - def vector_to_index(self, feature_vector): + def vector_to_index(self, feature_vector) -> int: assert len(self.dim_sizes) == len(feature_vector), ( f"feature vector of size {len(feature_vector)}, " f"expecting {len(self.dim_sizes)}: {feature_vector} -- {self.dim_sizes}" ) - index: np.int32 = np.ravel_multi_index( - list(feature_vector), list(self.dim_sizes) - ) + index_intp = np.ravel_multi_index(list(feature_vector), list(self.dim_sizes)) + index = index_intp.item() assert index < self.ravelled_size, ( f"feature vector out of bound ({feature_vector}, dim={self.dim_sizes}) " f"-> index={index}, max_index={self.ravelled_size-1})" @@ -522,14 +576,22 @@ def on_reset(self, observation: cyberbattle_env.Observation): super().on_reset(observation) -class Feature_actions_tried_at_node(Feature): +class Feature_actions_tried_at_node(NodeFeature): """A bit mask indicating which actions were already tried a the current node: 0 no tried, 1 tried""" def __init__(self, p: EnvironmentBounds): super().__init__(p, [2] * AbstractAction(p).n_actions) - def get(self, a: ActionTrackingStateAugmentation, node: int): + @overload + def get_at(self, a: ActionTrackingStateAugmentation, node: int): ... + + @overload + def get_at(self, a: StateAugmentation, node: int): ... 
+ + def get_at(self, a: StateAugmentation, node: int): + assert node is not None, "feature only valid in the context of a node" + assert isinstance(a, ActionTrackingStateAugmentation), "invalid state type" return np.array( ((a.failed_action_count[node, :] + a.success_action_count[node, :]) != 0) * 1, @@ -537,7 +599,7 @@ def get(self, a: ActionTrackingStateAugmentation, node: int): ) -class Feature_success_actions_at_node(Feature): +class Feature_success_actions_at_node(NodeFeature): """number of time each action succeeded at a given node""" max_action_count = 100 @@ -545,11 +607,19 @@ class Feature_success_actions_at_node(Feature): def __init__(self, p: EnvironmentBounds): super().__init__(p, [self.max_action_count] * AbstractAction(p).n_actions) - def get(self, a: ActionTrackingStateAugmentation, node: int): + @overload + def get_at(self, a: ActionTrackingStateAugmentation, node: int): ... + + @overload + def get_at(self, a: StateAugmentation, node: int): ... + + def get_at(self, a: StateAugmentation, node: int): + assert isinstance(a, ActionTrackingStateAugmentation), "invalid state type" + return np.minimum(a.success_action_count[node, :], self.max_action_count - 1) -class Feature_failed_actions_at_node(Feature): +class Feature_failed_actions_at_node(NodeFeature): """number of time each action failed at a given node""" max_action_count = 100 @@ -557,7 +627,8 @@ class Feature_failed_actions_at_node(Feature): def __init__(self, p: EnvironmentBounds): super().__init__(p, [self.max_action_count] * AbstractAction(p).n_actions) - def get(self, a: ActionTrackingStateAugmentation, node: int): + def get_at(self, a: StateAugmentation, node: int): + assert isinstance(a, ActionTrackingStateAugmentation), "invalid state type" return np.minimum(a.failed_action_count[node, :], self.max_action_count - 1) @@ -577,7 +648,7 @@ def __init__(self, env: cyberbattle_env.CyberBattleEnv, state: StateAugmentation self.env = env self.state = state - def step(self, action: 
cyberbattle_env.Action): + def step(self, action: cyberbattle_env.Action): # type: ignore observation, reward, done, truncated, info = self.env.step(action) self.state.on_step(action, reward, done, truncated, observation) return observation, reward, done, truncated, info diff --git a/cyberbattle/agents/baseline/notebooks/notebook_benchmark.py b/cyberbattle/agents/baseline/notebooks/notebook_benchmark.py index 0f53194..b00b939 100644 --- a/cyberbattle/agents/baseline/notebooks/notebook_benchmark.py +++ b/cyberbattle/agents/baseline/notebooks/notebook_benchmark.py @@ -88,7 +88,7 @@ ep, [w.Feature_active_node_properties(ep), w.Feature_discovered_node_count(ep)] ) a = w.StateAugmentation(o0) - w.Feature_discovered_ports(ep).get(a, None) + w.Feature_discovered_ports(ep).get(a) fe_example.encode_at(a, 0) # %% diff --git a/notebooks/notebook_benchmark-chain.ipynb b/notebooks/notebook_benchmark-chain.ipynb index c88bab8..eb8d0e8 100644 --- a/notebooks/notebook_benchmark-chain.ipynb +++ b/notebooks/notebook_benchmark-chain.ipynb @@ -83,7 +83,7 @@ "from cyberbattle.agents.baseline.agent_wrapper import Verbosity\n", "\n", "logging.basicConfig(stream=sys.stdout, level=logging.ERROR, format=\"%(levelname)s: %(message)s\")\n", - "%matplotlib inline " + "%matplotlib inline" ] }, { @@ -248,7 +248,7 @@ "\n", " fe_example = w.RavelEncoding(ep, [w.Feature_active_node_properties(ep), w.Feature_discovered_node_count(ep)])\n", " a = w.StateAugmentation(o0)\n", - " w.Feature_discovered_ports(ep).get(a, None)\n", + " w.Feature_discovered_ports(ep).get(a)\n", " fe_example.encode_at(a, 0)" ] }, @@ -144988,7 +144988,7 @@ } ], "source": [ - "%matplotlib inline \n", + "%matplotlib inline\n", "# Compare and plot results for all the agents\n", "all_runs = [\n", " random_run,\n", diff --git a/notebooks/notebook_benchmark-tiny.ipynb b/notebooks/notebook_benchmark-tiny.ipynb index 2f5bcb0..bc89c96 100644 --- a/notebooks/notebook_benchmark-tiny.ipynb +++ b/notebooks/notebook_benchmark-tiny.ipynb @@ 
-199,7 +199,7 @@ "\n", " fe_example = w.RavelEncoding(ep, [w.Feature_active_node_properties(ep), w.Feature_discovered_node_count(ep)])\n", " a = w.StateAugmentation(o0)\n", - " w.Feature_discovered_ports(ep).get(a, None)\n", + " w.Feature_discovered_ports(ep).get(a)\n", " fe_example.encode_at(a, 0)" ] }, diff --git a/notebooks/notebook_benchmark-toyctf.ipynb b/notebooks/notebook_benchmark-toyctf.ipynb index 29129a1..2e1533a 100644 --- a/notebooks/notebook_benchmark-toyctf.ipynb +++ b/notebooks/notebook_benchmark-toyctf.ipynb @@ -35,7 +35,7 @@ "\"\"\"\n", "\n", "# pylint: disable=invalid-name\n", - "%matplotlib inline " + "%matplotlib inline" ] }, { @@ -97,7 +97,7 @@ }, "outputs": [], "source": [ - "%matplotlib inline " + "%matplotlib inline" ] }, { @@ -259,7 +259,7 @@ "\n", " fe_example = w.RavelEncoding(ep, [w.Feature_active_node_properties(ep), w.Feature_discovered_node_count(ep)])\n", " a = w.StateAugmentation(o0)\n", - " w.Feature_discovered_ports(ep).get(a, None)\n", + " w.Feature_discovered_ports(ep).get(a)\n", " fe_example.encode_at(a, 0)" ] }, diff --git a/setup.cfg b/setup.cfg index 831335b..4db379e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [flake8] -ignore = W504,W503,E501,N813,N812,E741 +ignore = W504,W503,E501,N813,N812,E741,E203,E704 max-line-length = 200 max-doc-length = 200 exclude = typings, venv