-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy path agent.py
43 lines (33 loc) · 1.24 KB
/
agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import math
import os
import pickle
import numpy as np
import torch
import torch.nn.functional as F
from torch.distributions import Categorical
from your_act import YourActionParser
import your_obs
from discrete_policy import DiscreteFF
# You can get the OBS size from the rlgym-ppo console print-outs when you start your bot
# NOTE(review): template placeholder — `your_obs_size_here` is a NameError as written;
# replace it with the integer length of your observation vector before running.
OBS_SIZE = your_obs_size_here
# If you haven't set these, they are [256, 256, 256] by default
# NOTE(review): template placeholder — replace with the list of hidden-layer sizes
# (ints) that the policy was trained with; it is a NameError as written.
POLICY_LAYER_SIZES = [your, layer, sizes, here]
class Agent:
    """Inference-only wrapper around a trained rlgym-ppo discrete policy.

    Loads the saved policy weights from ``PPO_POLICY.pt`` next to this file
    and turns observations into flat controller-input arrays.
    """

    def __init__(self):
        # The parser's lookup table fixes how many discrete actions the policy emits.
        self.action_parser = YourActionParser()
        self.num_actions = len(self.action_parser._lookup_table)

        model_dir = os.path.dirname(os.path.realpath(__file__))
        cpu = torch.device("cpu")
        self.policy = DiscreteFF(OBS_SIZE, self.num_actions, POLICY_LAYER_SIZES, cpu)
        weights_path = os.path.join(model_dir, "PPO_POLICY.pt")
        self.policy.load_state_dict(torch.load(weights_path, map_location=cpu))
        # Keep torch single-threaded so per-tick inference latency stays predictable.
        torch.set_num_threads(1)

    def act(self, state):
        """Return a single flat action array for the given observation.

        Raises a generic Exception when the parsed action does not reduce
        to a 1-D array.
        """
        with torch.no_grad():
            # Second argument True: deterministic action selection — TODO confirm
            # against DiscreteFF.get_action's signature.
            action_idx, _probs = self.policy.get_action(state, True)
            parsed = self.action_parser.parse_actions([action_idx], None)
            action = np.array(parsed)
        # The parser returns a batch; collapse a (1, n) result to its single row.
        if action.ndim == 2 and action.shape[0] == 1:
            action = action[0]
        if action.ndim != 1:
            raise Exception("Invalid action:", action)
        return action