This repository has been archived by the owner on Apr 15, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdiscriminator.py
58 lines (49 loc) · 1.77 KB
/
discriminator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import torch
import argparse
from numpy import genfromtxt
class Discriminator(torch.nn.Module):
    """Single-layer binary discriminator mapping 12 features to one probability.

    The network is one ``Linear(12, 1)`` layer followed by a sigmoid, so every
    output lies in [0, 1].
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        self.fc = torch.nn.Linear(12, 1)
        # The script in this file scores predictions with BCELoss, which only
        # accepts inputs in [0, 1]; a sigmoid guarantees that range whereas
        # the previous ReLU output was unbounded above.
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        """Return the predicted probability for each row of ``x``.

        Args:
            x: float tensor whose last dimension has size 12.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size 1, with values in [0, 1].
        """
        # Feed the linear layer's result into the activation; previously the
        # activation was applied to the raw input, discarding ``self.fc``.
        logits = self.fc(x)
        return self.sigmoid(logits)
class Dataset(torch.utils.data.Dataset):
    """Row-indexed access to a ';'-delimited numeric text file.

    The file is parsed once in ``__init__`` and kept as a NumPy array in
    ``self.data``; rows are converted to ``float32`` tensors on demand.
    """
    # TODO: the four-value row layout returned by ``__getitem__`` is carried
    # over from the original stub and still needs to be finalised.

    def __init__(self, dir_path):
        """Parse the file at ``dir_path`` using ``genfromtxt`` with ';' as delimiter."""
        self.data = genfromtxt(dir_path, delimiter=';')

    def __len__(self):
        """Return the length of the first axis of ``self.data``."""
        # Previously a bare ``return`` (i.e. None), which made ``len(ds)``
        # raise TypeError.
        return len(self.data)

    def __getitem__(self, idx):
        """Return entry ``idx`` unpacked into four ``float32`` tensors.

        Raises:
            ValueError: if the selected entry does not hold exactly four values.
        """
        # Convert only the requested row; converting the whole array on every
        # lookup copied the full dataset each time.
        row = torch.as_tensor(self.data[idx], dtype=torch.float32)
        x_train, y_train, x_test, y_test = row
        return x_train, y_train, x_test, y_test
if __name__ == "__main__":
    # Script entry point: either train the discriminator on a ';'-delimited
    # numeric file, or report the loss of the freshly initialised model on
    # the held-out rows of that file.
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', type=str, default="train",
                        choices=("train", "test"),
                        help="One of 'train' or 'test'.")
    # Dataset requires a path; the original call passed none and raised
    # TypeError before any training could start.
    parser.add_argument('--data_path', type=str, required=True,
                        help="Path to the ';'-delimited numeric data file.")
    parser.add_argument('--epochs', type=int, default=20,
                        help="Number of full-batch training epochs.")
    parser.add_argument('--test_fraction', type=float, default=0.2,
                        help="Fraction of rows, taken from the end of the "
                             "file, held out for testing.")
    flags, _unparsed = parser.parse_known_args()
    if not 0.0 <= flags.test_fraction < 1.0:
        parser.error("--test_fraction must be in [0, 1)")

    model = Discriminator()
    criterion = torch.nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # Only the ``data`` array loaded by Dataset.__init__ is used here; the
    # Dataset object itself cannot be tuple-unpacked into four tensors.
    table = torch.as_tensor(Dataset(flags.data_path).data, dtype=torch.float32)
    num_features = model.fc.in_features
    if table.ndim != 2 or table.shape[1] < num_features + 1:
        parser.error("expected a 2-D table with at least {} columns".format(
            num_features + 1))
    # genfromtxt turns cells it cannot parse as numbers (for example a header
    # line) into NaN; such rows would poison the loss, so drop them.
    table = table[~torch.isnan(table).any(dim=1)]
    if len(table) == 0:
        parser.error("no usable numeric rows found in --data_path")

    # NOTE(review): assumes the first ``num_features`` columns are the model
    # inputs and the last column is a 0/1 label -- confirm against the real
    # data file before relying on these splits.
    num_test = int(len(table) * flags.test_fraction)
    num_train = len(table) - num_test
    x_train, y_train = table[:num_train, :num_features], table[:num_train, -1]
    x_test, y_test = table[num_train:, :num_features], table[num_train:, -1]

    if flags.mode == "train":
        model.train()
        # A separate name for the epoch count; the original reused ``epoch``
        # for both the count and the loop variable.
        for epoch in range(flags.epochs):
            optimizer.zero_grad()
            # Forward pass
            y_pred = model(x_train)
            # Compute loss; squeeze only the trailing unit dimension so a
            # single-row batch keeps its batch dimension.
            loss = criterion(y_pred.squeeze(-1), y_train)
            print('Epoch {}: train loss: {}'.format(epoch, loss.item()))
            # Backward pass
            loss.backward()
            optimizer.step()
    elif flags.mode == "test":
        model.eval()
        # No weights are loaded, so this measures the untrained model.
        with torch.no_grad():
            y_pred = model(x_test)
            before_train = criterion(y_pred.squeeze(-1), y_test)
        print('Test loss before training' , before_train.item())