-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy pathmodel.py
47 lines (40 loc) · 1.46 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
"""
Example Neural Network Model for Vector Observation DQN Agent
DQN Model for Unity ML-Agents Environments using PyTorch
Example Developed By:
Michael Richardson, 2018
Project for the Udacity Nanodegree in Deep Reinforcement Learning (DRL)
Code expanded and adapted from code examples provided by Udacity DRL Team, 2018.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class QNetwork(nn.Module):
    """Fully connected Q-network mapping a state vector to one value per action.

    Architecture:
        Linear(state_size, fc1_units) -> ReLU
        -> Linear(fc1_units, fc2_units) -> ReLU
        -> Linear(fc2_units, action_size)

    The output layer has no activation, so the per-action values are unbounded.
    """

    def __init__(self, state_size, action_size, seed, fc1_units=128, fc2_units=128):
        """Initialize parameters and build the model.

        Params
        ======
            state_size (int): Dimension of each state
            action_size (int): Dimension of each action
            seed (int): Random seed
            fc1_units (int): Number of nodes in first hidden layer
            fc2_units (int): Number of nodes in second hidden layer
        """
        super().__init__()
        # Seed torch's global RNG *before* creating the layers so their initial
        # weights are reproducible for a given seed. Note that torch.manual_seed
        # returns the default torch.Generator (not the int seed); it is stored as
        # ``self.seed`` unchanged for backward compatibility.
        self.seed = torch.manual_seed(seed)
        self.fc1 = nn.Linear(state_size, fc1_units)
        self.fc2 = nn.Linear(fc1_units, fc2_units)
        self.fc3 = nn.Linear(fc2_units, action_size)

    def forward(self, state):
        """Map state -> action values.

        Args:
            state (torch.Tensor): floating-point tensor whose last dimension is
                ``state_size`` (presumably a batch of shape (batch, state_size);
                nn.Linear acts on the last dimension, so any leading dims work).

        Returns:
            torch.Tensor: same leading dimensions as ``state``, with last
            dimension ``action_size`` (one value per action).
        """
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        return self.fc3(x)