-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
26 lines (22 loc) · 1.11 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# Standard library
import os

# Import environment libraries
import gym_super_mario_bros
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT
from nes_py.wrappers import JoypadSpace
from stable_baselines3 import PPO

# Import preprocessing wrappers
from gym.wrappers import GrayScaleObservation
from stable_baselines3.common.vec_env import VecFrameStack, DummyVecEnv, VecMonitor

# Plotting
from matplotlib import pyplot as plt

# Local helpers
from utils import startGameRand, startGameModel
from utils import SaveOnBestTrainingRewardCallback
# Where the trained PPO agent is loaded from. The original, machine-specific
# path stays the default; set MARIO_MODEL_PATH to point at another checkpoint
# without editing this file. An empty value falls back to the default.
DEFAULT_MODEL_PATH = '/home/ercihan/Desktop/EE449/HW3/best_model'
MODEL_PATH = os.environ.get("MARIO_MODEL_PATH") or DEFAULT_MODEL_PATH

# Start the environment
base_env = gym_super_mario_bros.make('SuperMarioBros-v0')  # Create the Super Mario Bros environment
base_env = JoypadSpace(base_env, SIMPLE_MOVEMENT)  # Restrict the action space to the SIMPLE_MOVEMENT button combinations
# Apply the preprocessing
base_env = GrayScaleObservation(base_env, keep_dim=True)  # Convert to grayscale to reduce dimensionality; keep_dim keeps a channel axis
# Wrap in a single-environment vectorized env. The wrapped env is bound as a
# default argument so the factory does not depend on the module-level name
# `env`, which is rebound below (avoids a late-binding closure).
env = DummyVecEnv([lambda wrapped=base_env: wrapped])
# Alternatively, you may use SubprocVecEnv for multiple CPU processors
env = VecFrameStack(env, 4, channels_order='last')  # Stack the last 4 frames along the channel axis
env = VecMonitor(env, "./train/TestMonitor")  # Log episode statistics under ./train/TestMonitor
# Load the trained agent and let it play in the preprocessed environment.
model = PPO.load(MODEL_PATH)
startGameModel(env, model)