-
Notifications
You must be signed in to change notification settings - Fork 9
/
tedvae_ihdp.py
86 lines (67 loc) · 3.6 KB
/
tedvae_ihdp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import argparse
import logging
import pandas as pd
import torch
import pyro
from tedvae_gpu import TEDVAE
from datasets import IHDP
import numpy as np
# Enable verbose (DEBUG) output from Pyro's logger and from every handler
# attached directly to it. The "pyro" logger is not guaranteed to own a
# handler (presumably none is attached when logging was configured elsewhere
# first -- verify for the installed Pyro version), so iterate over the handler
# list instead of indexing handlers[0], which raised IndexError when it was empty.
logging.getLogger("pyro").setLevel(logging.DEBUG)
for _handler in logging.getLogger("pyro").handlers:
    _handler.setLevel(logging.DEBUG)
def _pehe(true_ite, est_ite):
    """Return the PEHE (root-mean-squared ITE error) of an estimate.

    Args:
        true_ite: ground-truth individual treatment effects. ``.squeeze()`` is
            applied to drop singleton axes. Presumably a NumPy array as
            returned by ``IHDP`` -- TODO confirm.
        est_ite: estimated effects as a torch tensor on any device. It is
            detached and copied to host memory exactly once.

    Returns:
        ``sqrt(mean((true_ite - est_ite) ** 2))`` as a NumPy scalar.
    """
    # NOTE(review): assumes tedvae.ite returns a 1-D tensor. A trailing
    # singleton axis would make this subtraction broadcast to an (n, n)
    # matrix -- confirm against TEDVAE.ite.
    diff = true_ite.squeeze() - est_ite.detach().cpu().numpy()
    return np.sqrt(np.mean(diff * diff))


def main(args, reptition=1, path="./IHDP/"):
    """Train TEDVAE on one IHDP replication and report PEHE on both splits.

    Args:
        args: parsed command-line namespace. Reads ``seed``, ``feature_dim``,
            ``latent_dim``, ``latent_dim_t``, ``latent_dim_y``, ``hidden_dim``,
            ``num_layers``, ``num_epochs``, ``batch_size``, ``learning_rate``,
            ``learning_rate_decay`` and ``weight_decay``.
        reptition: replication index forwarded to ``IHDP`` as ``reps``. The
            misspelled name is kept for backward compatibility.
        path: directory passed to ``IHDP`` as the data location.

    Returns:
        Tuple ``(pehe, pehe_train)``: PEHE on the test split and on the
        training split.
    """
    pyro.enable_validation(__debug__)
    # NOTE(review): ``args.cuda`` is parsed but ignored. The guard that used
    # it is commented out, so tensors always default to CUDA floats.
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

    # Load the IHDP replication (train/test splits plus the continuous and
    # binary feature descriptors consumed by the model).
    pyro.set_rng_seed(args.seed)
    train, test, contfeats, binfeats = IHDP(path=path, reps=reptition, cuda=True)
    (x_train, t_train, y_train), true_ite_train = train
    (x_test, t_test, y_test), true_ite_test = test

    # Standardise training outcomes. ym/ys are passed to tedvae.ite below,
    # presumably so it can undo this scaling -- confirm in TEDVAE.ite.
    ym, ys = y_train.mean(), y_train.std()
    y_train = (y_train - ym) / ys

    # Train. The seed is set again here, in addition to before data loading.
    pyro.set_rng_seed(args.seed)
    pyro.clear_param_store()
    tedvae = TEDVAE(feature_dim=args.feature_dim, continuous_dim=contfeats, binary_dim=binfeats,
                    latent_dim=args.latent_dim, latent_dim_t=args.latent_dim_t,
                    latent_dim_y=args.latent_dim_y,
                    hidden_dim=args.hidden_dim,
                    num_layers=args.num_layers,
                    num_samples=10)
    tedvae.fit(x_train, t_train, y_train,
               num_epochs=args.num_epochs,
               batch_size=args.batch_size,
               learning_rate=args.learning_rate,
               learning_rate_decay=args.learning_rate_decay,
               weight_decay=args.weight_decay)

    # Evaluate on both splits. Each estimate is moved to the host only once.
    est_ite = tedvae.ite(x_test, ym, ys)
    est_ite_train = tedvae.ite(x_train, ym, ys)
    pehe = _pehe(true_ite_test, est_ite)
    pehe_train = _pehe(true_ite_train, est_ite_train)
    print("PEHE_train = {:0.3g}".format(pehe_train))
    print("PEHE = {:0.3g}".format(pehe))
    return pehe, pehe_train
if __name__ == "__main__":
    # NOTE: a (disabled) version assert indicates this was developed against
    # Pyro 1.3.0.
    parser = argparse.ArgumentParser(description="TEDVAE")
    parser.add_argument("--feature-dim", default=25, type=int)
    parser.add_argument("--latent-dim", default=20, type=int)
    parser.add_argument("--latent-dim-t", default=10, type=int)
    parser.add_argument("--latent-dim-y", default=10, type=int)
    parser.add_argument("--hidden-dim", default=500, type=int)
    parser.add_argument("--num-layers", default=4, type=int)
    parser.add_argument("-n", "--num-epochs", default=200, type=int)
    parser.add_argument("-b", "--batch-size", default=1000, type=int)
    parser.add_argument("-lr", "--learning-rate", default=1e-3, type=float)
    parser.add_argument("-lrd", "--learning-rate-decay", default=0.01, type=float)
    parser.add_argument("--weight-decay", default=1e-4, type=float)
    parser.add_argument("--seed", default=1234567890, type=int)
    # NOTE(review): --jit and --cuda are parsed but not read anywhere in this
    # script (main() always uses CUDA).
    parser.add_argument("--jit", action="store_true")
    parser.add_argument("--cuda", action="store_true")
    # Previously hard-coded; the defaults keep the old behaviour.
    parser.add_argument("--reps", default=100, type=int,
                        help="number of IHDP replications to evaluate")
    parser.add_argument("--data-path", default="./IHDP_b/",
                        help="directory holding the IHDP replication files")
    args = parser.parse_args()

    num_reps = args.reps
    path = args.data_path
    # One row per replication: test-split and train-split PEHE.
    tedvae_pehe = np.zeros((num_reps, 1))
    tedvae_pehe_train = np.zeros((num_reps, 1))
    for i in range(num_reps):
        print("Dataset {:d}".format(i + 1))
        # Replications are 1-indexed when passed to main().
        tedvae_pehe[i, 0], tedvae_pehe_train[i, 0] = main(args, i + 1, path)

    # Report the aggregate, which was previously collected but never shown.
    if num_reps > 0:
        print("PEHE_train over {:d} replications: mean = {:0.3g}, std = {:0.3g}".format(
            num_reps, tedvae_pehe_train.mean(), tedvae_pehe_train.std()))
        print("PEHE over {:d} replications: mean = {:0.3g}, std = {:0.3g}".format(
            num_reps, tedvae_pehe.mean(), tedvae_pehe.std()))