-
Notifications
You must be signed in to change notification settings - Fork 0
/
resume_sac_heli.py
32 lines (28 loc) · 974 Bytes
/
resume_sac_heli.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
#!/usr/bin/env python3
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
import optuna
from garage import wrap_experiment
from garage.envs import GymEnv, normalize
from garage.experiment import deterministic
from garage.replay_buffer import PathBuffer
from garage.sampler import FragmentWorker, RaySampler
from garage.torch import set_gpu_mode
from garage.torch.algos import SAC
from garage.torch.policies import TanhGaussianMLPPolicy
from garage.torch.q_functions import ContinuousMLPQFunction
from garage.trainer import Trainer
import csv
import logging
import sys
import garage
from garage.experiment.deterministic import set_seed
@wrap_experiment
def sac_helicopter_resume(ctxt=None,
                          snapshot_dir="data/local/experiment/sac_helicopter_251",
                          seed=1,
                          n_epochs=2,
                          batch_size=128):
    """Restore a saved SAC helicopter experiment and continue training it.

    Args:
        ctxt: Experiment context injected by ``wrap_experiment``; it is
            passed to ``Trainer`` as the snapshot configuration for the
            resumed run.
        snapshot_dir: Directory containing the snapshot handed to
            ``Trainer.restore``.
        seed: Seed passed to ``set_seed`` before the snapshot is restored.
        n_epochs: Forwarded to ``Trainer.resume`` (default 2, the value the
            script previously hard-coded).
        batch_size: Forwarded to ``Trainer.resume`` (default 128, the value
            the script previously hard-coded).
    """
    # Seed first so anything created while restoring is seeded as well.
    set_seed(seed)
    trainer = Trainer(snapshot_config=ctxt)
    trainer.restore(snapshot_dir)
    # NOTE(review): in garage, n_epochs appears to be the total epoch target
    # rather than a count of additional epochs (training continues from the
    # restored epoch), so a snapshot already at or past n_epochs may train
    # nothing -- confirm against the installed garage version.
    trainer.resume(n_epochs=n_epochs, batch_size=batch_size)
if __name__ == "__main__":
    # Only launch the resumed training run when executed as a script, so that
    # importing this module does not start an experiment as a side effect.
    sac_helicopter_resume(seed=521)