Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
William Blum committed Aug 6, 2024
1 parent 4acaa04 commit 3614692
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 9 deletions.
10 changes: 5 additions & 5 deletions cyberbattle/_env/cyberbattle_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import logging
import networkx
from networkx import convert_matrix
-from typing import NamedTuple, Optional, Tuple, List, Dict, TypeVar, TypedDict, cast
+from typing import NamedTuple, Optional, Tuple, List, Dict, TypeVar, TypedDict

from gym import spaces, Env
from gym.utils import seeding
Expand Down Expand Up @@ -321,8 +321,8 @@ def __init__(self, bounds: EnvironmentBounds):


class CyberBattleSpaceKind(Env[Observation, Action]):
-action_space: DiscriminatedUnion # type: ignore
-observation_space: ObservationSpaceType # type: ignore
+action_space: DiscriminatedUnion # type: ignore
+observation_space: ObservationSpaceType # type: ignore


class CyberBattleEnv(CyberBattleSpaceKind):
Expand Down Expand Up @@ -526,7 +526,7 @@ def __init__(
maximum_node_count = self.__bounds.maximum_node_count
port_count = self.__bounds.port_count

-action_spaces = {
+action_spaces_dict: dict[str, spaces.Space] = {
"local_vulnerability": spaces.MultiDiscrete(
# source_node_id, vulnerability_id
[maximum_node_count, local_vulnerabilities_count]
Expand All @@ -547,7 +547,7 @@ def __init__(
),
}

-self.action_space = DiscriminatedUnion[Action](cast(dict, action_spaces)) # type: ignore
+self.action_space = DiscriminatedUnion[Action](spaces=action_spaces_dict)

self.observation_space = ObservationSpaceType(self.__bounds)

Expand Down
14 changes: 10 additions & 4 deletions notebooks/stable-baselines-agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,13 @@
import sys
from stable_baselines3.a2c.a2c import A2C
from stable_baselines3.ppo.ppo import PPO
-from cyberbattle._env.flatten_wrapper import FlattenObservationWrapper, FlattenActionWrapper
+from cyberbattle._env.flatten_wrapper import (
+FlattenObservationWrapper,
+FlattenActionWrapper,
+)
import os
import numpy as np
+from stable_baselines3.common.type_aliases import GymEnv

os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
retrain = ["a2c"]
Expand Down Expand Up @@ -51,20 +55,22 @@
]
env2 = FlattenObservationWrapper(cast(CyberBattleEnv, env1), ignore_fields=ignore_fields)

+env_as_gym = cast(GymEnv, env2)

# %%
if "a2c" in retrain:
-model_a2c = A2C("MultiInputPolicy", env2).learn(10000) # type: ignore
+model_a2c = A2C("MultiInputPolicy", env_as_gym).learn(10000)
model_a2c.save("a2c_trained_toyctf")


# %%
if "ppo" in retrain:
-model_ppo = PPO("MultiInputPolicy", env2).learn(100) # type: ignore
+model_ppo = PPO("MultiInputPolicy", env_as_gym).learn(100)
model_ppo.save("ppo_trained_toyctf")


# %%
-model = A2C("MultiInputPolicy", env2).load("a2c_trained_toyctf") # type: ignore
+model = A2C("MultiInputPolicy", env_as_gym).load("a2c_trained_toyctf")
# model = PPO("MultiInputPolicy", env2).load('ppo_trained_toyctf')


Expand Down

0 comments on commit 3614692

Please sign in to comment.