-
Notifications
You must be signed in to change notification settings - Fork 3
/
main.py
127 lines (108 loc) · 3.65 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
"""
Runnable script with hydra capabilities
"""
# This is a hotfix for tblite (used for the conformer generation) not
# importing correctly unless it is being imported first.
try:
from tblite import interface
except:
pass
import os
import pickle
import random
import sys
import hydra
import pandas as pd
from gflownet.utils.common import chdir_random_subdir
from gflownet.utils.policy import parse_policy_config
@hydra.main(config_path="./config", config_name="main", version_base="1.1")
def main(config):
    """Hydra entry point: build the environment, policies and GFlowNet from
    the config, train, then optionally sample and dump samples/conformers.

    Parameters
    ----------
    config : omegaconf config assembled by hydra from ./config/main.yaml.
    """
    # TODO: fix race condition in a more elegant way
    chdir_random_subdir()

    # Get current directory and set it as root log dir for Logger
    cwd = os.getcwd()
    config.logger.logdir.root = cwd
    print(f"\nLogging directory of this run: {cwd}\n")

    # Reset seed for job-name generation in multirun jobs
    random.seed(None)
    # Set other random seeds
    set_seeds(config.seed)

    # Logger
    logger = hydra.utils.instantiate(config.logger, config, _recursive_=False)
    # The proxy is required in the env for scoring: might be an oracle or a model
    proxy = hydra.utils.instantiate(
        config.proxy,
        device=config.device,
        float_precision=config.float_precision,
    )
    # The proxy is passed to env and used for computing rewards
    env = hydra.utils.instantiate(
        config.env,
        proxy=proxy,
        device=config.device,
        float_precision=config.float_precision,
    )
    # The policy is used to model the probability of a forward/backward action
    forward_config = parse_policy_config(config, kind="forward")
    backward_config = parse_policy_config(config, kind="backward")
    forward_policy = hydra.utils.instantiate(
        forward_config,
        env=env,
        device=config.device,
        float_precision=config.float_precision,
    )
    # The backward policy can share parameters with the forward one via `base`
    backward_policy = hydra.utils.instantiate(
        backward_config,
        env=env,
        device=config.device,
        float_precision=config.float_precision,
        base=forward_policy,
    )
    gflownet = hydra.utils.instantiate(
        config.gflownet,
        device=config.device,
        float_precision=config.float_precision,
        env=env,
        forward_policy=forward_policy,
        backward_policy=backward_policy,
        buffer=config.env.buffer,
        logger=logger,
    )
    gflownet.train()

    # Sample from trained GFlowNet (guard against absurdly large requests)
    if config.n_samples > 0 and config.n_samples <= 1e5:
        batch, times = gflownet.sample_batch(n_forward=config.n_samples, train=False)
        x_sampled = batch.get_terminating_states(proxy=True)
        energies = env.oracle(x_sampled)
        x_sampled = batch.get_terminating_states()
        df = pd.DataFrame(
            {
                "readable": [env.state2readable(x) for x in x_sampled],
                "energies": energies.tolist(),
            }
        )
        df.to_csv("gfn_samples.csv")
        dct = {"x": x_sampled, "energy": energies}
        # Use context managers so the pickle files are flushed and closed
        # (the original `pickle.dump(dct, open(...))` leaked the handles).
        with open("gfn_samples.pkl", "wb") as handle:
            pickle.dump(dct, handle)
        # TODO: refactor before merging
        dct["conformer"] = [env.set_conformer(state).rdk_mol for state in x_sampled]
        with open(
            f"conformers_{env.smiles}_{type(env.proxy).__name__}.pkl", "wb"
        ) as handle:
            pickle.dump(dct, handle)

    # Print replay buffer
    if len(gflownet.buffer.replay) > 0:
        print("\nReplay buffer:")
        print(gflownet.buffer.replay)

    # Close logger
    gflownet.logger.end()
def set_seeds(seed):
    """Seed Python's, NumPy's and PyTorch's RNGs with *seed* and force
    deterministic cuDNN kernels for reproducibility."""
    import numpy as np
    import torch

    # Deterministic cuDNN algorithms; disabling benchmark avoids
    # non-deterministic kernel auto-selection.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Same seeding order as before: stdlib, then NumPy, then torch.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
if __name__ == "__main__":
main()
sys.exit()