From 36146921366d56397493c2b341bbfb2b129417e7 Mon Sep 17 00:00:00 2001 From: William Blum Date: Tue, 6 Aug 2024 23:48:05 +0000 Subject: [PATCH] . --- cyberbattle/_env/cyberbattle_env.py | 10 +++++----- notebooks/stable-baselines-agent.py | 14 ++++++++++---- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/cyberbattle/_env/cyberbattle_env.py b/cyberbattle/_env/cyberbattle_env.py index 43687f7..751f9da 100644 --- a/cyberbattle/_env/cyberbattle_env.py +++ b/cyberbattle/_env/cyberbattle_env.py @@ -8,7 +8,7 @@ import logging import networkx from networkx import convert_matrix -from typing import NamedTuple, Optional, Tuple, List, Dict, TypeVar, TypedDict, cast +from typing import NamedTuple, Optional, Tuple, List, Dict, TypeVar, TypedDict from gym import spaces, Env from gym.utils import seeding @@ -321,8 +321,8 @@ def __init__(self, bounds: EnvironmentBounds): class CyberBattleSpaceKind(Env[Observation, Action]): - action_space: DiscriminatedUnion # type: ignore - observation_space: ObservationSpaceType # type: ignore + action_space: DiscriminatedUnion # type: ignore + observation_space: ObservationSpaceType # type: ignore class CyberBattleEnv(CyberBattleSpaceKind): @@ -526,7 +526,7 @@ def __init__( maximum_node_count = self.__bounds.maximum_node_count port_count = self.__bounds.port_count - action_spaces = { + action_spaces_dict: dict[str, spaces.Space] = { "local_vulnerability": spaces.MultiDiscrete( # source_node_id, vulnerability_id [maximum_node_count, local_vulnerabilities_count] @@ -547,7 +547,7 @@ def __init__( ), } - self.action_space = DiscriminatedUnion[Action](cast(dict, action_spaces)) # type: ignore + self.action_space = DiscriminatedUnion[Action](spaces=action_spaces_dict) self.observation_space = ObservationSpaceType(self.__bounds) diff --git a/notebooks/stable-baselines-agent.py b/notebooks/stable-baselines-agent.py index 120de05..aea2c1c 100644 --- a/notebooks/stable-baselines-agent.py +++ b/notebooks/stable-baselines-agent.py @@ -9,9 +9,13 @@ import sys from stable_baselines3.a2c.a2c import A2C from stable_baselines3.ppo.ppo import PPO -from cyberbattle._env.flatten_wrapper import FlattenObservationWrapper, FlattenActionWrapper +from cyberbattle._env.flatten_wrapper import ( + FlattenObservationWrapper, + FlattenActionWrapper, +) import os import numpy as np +from stable_baselines3.common.type_aliases import GymEnv os.environ["CUDA_LAUNCH_BLOCKING"] = "1" retrain = ["a2c"] @@ -51,20 +55,22 @@ ] env2 = FlattenObservationWrapper(cast(CyberBattleEnv, env1), ignore_fields=ignore_fields) +env_as_gym = cast(GymEnv, env2) + # %% if "a2c" in retrain: - model_a2c = A2C("MultiInputPolicy", env2).learn(10000) # type: ignore + model_a2c = A2C("MultiInputPolicy", env_as_gym).learn(10000) model_a2c.save("a2c_trained_toyctf") # %% if "ppo" in retrain: - model_ppo = PPO("MultiInputPolicy", env2).learn(100) # type: ignore + model_ppo = PPO("MultiInputPolicy", env_as_gym).learn(100) model_ppo.save("ppo_trained_toyctf") # %% -model = A2C("MultiInputPolicy", env2).load("a2c_trained_toyctf") # type: ignore +model = A2C("MultiInputPolicy", env_as_gym).load("a2c_trained_toyctf") # model = PPO("MultiInputPolicy", env2).load('ppo_trained_toyctf')