forked from gmh14/data_efficient_grammar
-
Notifications
You must be signed in to change notification settings - Fork 0
/
agent.py
36 lines (30 loc) · 1.23 KB
/
agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
import numpy as np
class Agent(nn.Module):
    """Small two-layer network producing a 2-way probability per input row.

    Each input row is expected to hold ``feat_dim + 2`` values. The network
    applies a linear layer, a ReLU, and a second linear layer with two
    outputs, then normalises those outputs with a softmax along dim 1.

    Attributes:
        affine1: Linear layer mapping ``feat_dim + 2`` inputs to
            ``hidden_size`` units.
        dropout: Dropout module with ``p=0.5``. It is constructed here but
            is not applied inside ``forward``.
        affine2: Linear layer mapping ``hidden_size`` units to 2 scores.
        saved_log_probs: Dictionary that starts empty; ``sample`` in this
            file stores log-probabilities in it, keyed by sample number and
            then by iteration number.
    """

    def __init__(self, feat_dim, hidden_size):
        """Build the layers.

        Args:
            feat_dim: Number of feature entries per row, excluding the two
                additional entries every input row carries.
            hidden_size: Width of the hidden layer.
        """
        super().__init__()
        # Every input row has two entries on top of the feat_dim features.
        input_width = feat_dim + 2
        self.affine1 = nn.Linear(input_width, hidden_size)
        self.dropout = nn.Dropout(p=0.5)
        self.affine2 = nn.Linear(hidden_size, 2)
        self.saved_log_probs = {}

    def forward(self, x):
        """Return softmax probabilities over the two outputs for each row.

        Args:
            x: Tensor whose dim 1 has ``feat_dim + 2`` entries.

        Returns:
            Tensor with two probabilities per row (softmax over dim 1).
        """
        hidden = F.relu(self.affine1(x))
        return F.softmax(self.affine2(hidden), dim=1)
def sample(agent, subgraph_feature, iter_num, sample_number):
    """Sample one binary action per subgraph and record its log-probability.

    The agent is called on ``subgraph_feature`` to obtain per-row
    probabilities, a ``Categorical`` distribution is built from them, and one
    action index is drawn for every row.

    Args:
        agent: Callable returning per-row action probabilities; it must also
            expose a ``saved_log_probs`` dictionary, which this function
            mutates.
        subgraph_feature: N * (2 + feat_dim) tensor, where N is the number of
            subgraphs inside all inputs.
        iter_num: Inner key under which the log-probabilities are stored.
        sample_number: Outer key under which the log-probabilities are stored.

    Returns:
        A tuple ``(actions, take_action)``: ``actions`` is a numpy array with
        the sampled action index of every row, and ``take_action`` is truthy
        when at least one sampled action is non-zero.
    """
    prob = agent(subgraph_feature)
    m = Categorical(prob)
    a = m.sample()
    # Tensor.numpy() only works for CPU tensors, so move the sample to host
    # memory first; for a tensor already on the CPU this is not a copy.
    actions = a.detach().cpu().numpy()
    take_action = actions.any()
    if take_action:
        # Log-probabilities are kept only for rounds in which some action was
        # taken, grouped as saved_log_probs[sample_number][iter_num] -> list.
        per_sample = agent.saved_log_probs.setdefault(sample_number, {})
        per_sample.setdefault(iter_num, []).append(m.log_prob(a))
    return actions, take_action