-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain.py
26 lines (23 loc) · 819 Bytes
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import flags
import YRC.core.algorithm as algo_factory
import YRC.core.configs.utils as config_utils
import YRC.core.environment as env_factory
import YRC.core.policy as policy_factory
from YRC.core import Evaluator
def main():
    """Build everything from the CLI config, then train or evaluate.

    Steps:
      1. Parse command-line flags via ``flags.make()`` and load the config
         they point to (``config_utils.load(args.config, flags=args)``).
      2. Build the environments, the policy (from the ``"train"``
         environment), and an ``Evaluator`` from ``config.evaluation``.
      3. If ``config.general.algorithm == "always"``, only evaluate the policy
         on the validation splits; no algorithm object is created.
         Otherwise build the configured algorithm and run its training loop
         on the ``"train"`` split, evaluating on the validation splits.
    """
    args = flags.make()
    config = config_utils.load(args.config, flags=args)

    envs = env_factory.make(config)
    policy = policy_factory.make(config, envs["train"])
    evaluator = Evaluator(config.evaluation)

    # Both branches evaluate on the same validation splits; define them once
    # so the evaluation-only path and the training path cannot drift apart.
    eval_splits = ["val_sim", "val_true"]

    if config.general.algorithm == "always":
        # The "always" setting skips training entirely: the policy returned by
        # policy_factory.make is evaluated as-is.
        evaluator.eval(policy, envs, eval_splits)
        return

    algorithm = algo_factory.make(config, envs["train"])
    algorithm.train(
        policy,
        envs,
        evaluator,
        train_split="train",
        eval_splits=eval_splits,
    )


if __name__ == "__main__":
    main()