Skip to content

Commit 244ab98

Browse files
committed
[Lint] pyupgrade
ghstack-source-id: 015ce04cccdeac16b8592f4a17377b2649b12c52 Pull Request resolved: #2819
1 parent 36b6b9c commit 244ab98

File tree

131 files changed

+1026
-1418
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

131 files changed

+1026
-1418
lines changed

Diff for: .github/unittest/helpers/coverage_run_parallel.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ def write_config(config_path: Path, argv: List[str]) -> None:
2828
argv: Arguments passed to this script, which need to be converted to config file entries
2929
"""
3030
assert not config_path.exists(), "Temporary coverage config exists already"
31-
cmdline = " ".join(shlex.quote(arg) for arg in argv[1:])
32-
with open(str(config_path), "wt", encoding="utf-8") as fh:
31+
cmdline = shlex.join(argv[1:])
32+
with open(str(config_path), "w", encoding="utf-8") as fh:
3333
fh.write(
3434
f"""# .coveragerc to control coverage.py
3535
[run]

Diff for: .pre-commit-config.yaml

+6
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,9 @@ repos:
3535
hooks:
3636
- id: pydocstyle
3737
files: ^torchrl/
38+
39+
- repo: https://github.com/asottile/pyupgrade
40+
rev: v3.9.0
41+
hooks:
42+
- id: pyupgrade
43+
args: [--py38-plus]

Diff for: build_tools/setup_helpers/extension.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,12 @@
88
import platform
99
import subprocess
1010
from pathlib import Path
11-
from subprocess import CalledProcessError, check_output, STDOUT
11+
from subprocess import CalledProcessError, STDOUT, check_output
1212

1313
import torch
1414
from setuptools import Extension
1515
from setuptools.command.build_ext import build_ext
1616

17-
1817
_THIS_DIR = Path(__file__).parent.resolve()
1918
_ROOT_DIR = _THIS_DIR.parent.parent.resolve()
2019
_TORCHRL_DIR = _ROOT_DIR / "torchrl"
@@ -130,7 +129,7 @@ def build_extension(self, ext):
130129
# using -j in the build_ext call, not supported by pip or PyPA-build.
131130
if hasattr(self, "parallel") and self.parallel:
132131
# CMake 3.12+ only.
133-
build_args += ["-j{}".format(self.parallel)]
132+
build_args += [f"-j{self.parallel}"]
134133

135134
if not os.path.exists(self.build_temp):
136135
os.makedirs(self.build_temp)

Diff for: setup.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030

3131
def get_version():
3232
version_txt = os.path.join(cwd, "version.txt")
33-
with open(version_txt, "r") as f:
33+
with open(version_txt) as f:
3434
version = f.readline().strip()
3535
if os.getenv("TORCHRL_BUILD_VERSION"):
3636
version = os.getenv("TORCHRL_BUILD_VERSION")
@@ -64,8 +64,8 @@ def parse_args(argv: List[str]) -> argparse.Namespace:
6464
def write_version_file(version):
6565
version_path = os.path.join(cwd, "torchrl", "version.py")
6666
with open(version_path, "w") as f:
67-
f.write("__version__ = '{}'\n".format(version))
68-
f.write("git_version = {}\n".format(repr(sha)))
67+
f.write(f"__version__ = '{version}'\n")
68+
f.write(f"git_version = {repr(sha)}\n")
6969

7070

7171
def _get_pytorch_version(is_nightly, is_local):
@@ -185,7 +185,7 @@ def _main(argv):
185185
version = get_version()
186186
write_version_file(version)
187187
TORCHRL_BUILD_VERSION = os.getenv("TORCHRL_BUILD_VERSION")
188-
logging.info("Building wheel {}-{}".format(package_name, version))
188+
logging.info(f"Building wheel {package_name}-{version}")
189189
logging.info(f"TORCHRL_BUILD_VERSION is {TORCHRL_BUILD_VERSION}")
190190

191191
is_local = TORCHRL_BUILD_VERSION is None

Diff for: sota-implementations/a2c/a2c_atari.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414

1515
@hydra.main(config_path="", config_name="config_atari", version_base="1.1")
16-
def main(cfg: "DictConfig"): # noqa: F821
16+
def main(cfg: DictConfig): # noqa: F821
1717

1818
from copy import deepcopy
1919

Diff for: sota-implementations/a2c/a2c_mujoco.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414

1515
@hydra.main(config_path="", config_name="config_mujoco", version_base="1.1")
16-
def main(cfg: "DictConfig"): # noqa: F821
16+
def main(cfg: DictConfig): # noqa: F821
1717

1818
from copy import deepcopy
1919

Diff for: sota-implementations/cql/cql_offline.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,13 @@
1515

1616
import hydra
1717
import numpy as np
18-
1918
import torch
2019
import tqdm
2120
from tensordict.nn import CudaGraphModule
22-
2321
from torchrl._utils import timeit
2422
from torchrl.envs.utils import ExplorationType, set_exploration_type
2523
from torchrl.objectives import group_optimizers
2624
from torchrl.record.loggers import generate_exp_name, get_logger
27-
2825
from utils import (
2926
dump_video,
3027
log_metrics,
@@ -39,7 +36,7 @@
3936

4037

4138
@hydra.main(config_path="", config_name="offline_config", version_base="1.1")
42-
def main(cfg: "DictConfig"): # noqa: F821
39+
def main(cfg: DictConfig): # noqa: F821
4340
# Create logger
4441
exp_name = generate_exp_name("CQL-offline", cfg.logger.exp_name)
4542
logger = None

Diff for: sota-implementations/cql/cql_online.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,10 @@
2121
import tqdm
2222
from tensordict import TensorDict
2323
from tensordict.nn import CudaGraphModule
24-
2524
from torchrl._utils import timeit
2625
from torchrl.envs.utils import ExplorationType, set_exploration_type
2726
from torchrl.objectives import group_optimizers
2827
from torchrl.record.loggers import generate_exp_name, get_logger
29-
3028
from utils import (
3129
dump_video,
3230
log_metrics,
@@ -42,7 +40,7 @@
4240

4341

4442
@hydra.main(version_base="1.1", config_path="", config_name="online_config")
45-
def main(cfg: "DictConfig"): # noqa: F821
43+
def main(cfg: DictConfig): # noqa: F821
4644
# Create logger
4745
exp_name = generate_exp_name("CQL-online", cfg.logger.exp_name)
4846
logger = None

Diff for: sota-implementations/cql/discrete_cql_online.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,12 @@
1616

1717
import hydra
1818
import numpy as np
19-
2019
import torch
2120
import torch.cuda
2221
import tqdm
2322
from tensordict.nn import CudaGraphModule
24-
2523
from torchrl._utils import timeit
26-
2724
from torchrl.envs.utils import ExplorationType, set_exploration_type
28-
2925
from torchrl.record.loggers import generate_exp_name, get_logger
3026
from utils import (
3127
log_metrics,
@@ -41,7 +37,7 @@
4137

4238

4339
@hydra.main(version_base="1.1", config_path="", config_name="discrete_cql_config")
44-
def main(cfg: "DictConfig"): # noqa: F821
40+
def main(cfg: DictConfig): # noqa: F821
4541
device = cfg.optim.device
4642
if device in ("", None):
4743
if torch.cuda.is_available():

Diff for: sota-implementations/crossq/crossq.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,15 @@
1515
import warnings
1616

1717
import hydra
18-
1918
import numpy as np
20-
2119
import torch
2220
import torch.cuda
2321
import tqdm
2422
from tensordict import TensorDict
2523
from tensordict.nn import CudaGraphModule
26-
2724
from torchrl._utils import timeit
2825
from torchrl.envs.utils import ExplorationType, set_exploration_type
2926
from torchrl.objectives import group_optimizers
30-
3127
from torchrl.record.loggers import generate_exp_name, get_logger
3228
from utils import (
3329
log_metrics,
@@ -43,7 +39,7 @@
4339

4440

4541
@hydra.main(version_base="1.1", config_path=".", config_name="config")
46-
def main(cfg: "DictConfig"): # noqa: F821
42+
def main(cfg: DictConfig): # noqa: F821
4743
device = cfg.network.device
4844
if device in ("", None):
4945
if torch.cuda.is_available():

Diff for: sota-implementations/ddpg/ddpg.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,13 @@
1515
import warnings
1616

1717
import hydra
18-
1918
import numpy as np
2019
import torch
2120
import torch.cuda
2221
import tqdm
2322
from tensordict import TensorDict
2423
from tensordict.nn import CudaGraphModule
25-
2624
from torchrl._utils import timeit
27-
2825
from torchrl.envs.utils import ExplorationType, set_exploration_type
2926
from torchrl.objectives import group_optimizers
3027
from torchrl.record.loggers import generate_exp_name, get_logger
@@ -41,7 +38,7 @@
4138

4239

4340
@hydra.main(version_base="1.1", config_path="", config_name="config")
44-
def main(cfg: "DictConfig"): # noqa: F821
41+
def main(cfg: DictConfig): # noqa: F821
4542
device = cfg.optim.device
4643
if device in ("", None):
4744
if torch.cuda.is_available():

Diff for: sota-implementations/decision_transformer/dt.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,9 @@
1919
from tensordict.nn import CudaGraphModule
2020
from torchrl._utils import logger as torchrl_logger, timeit
2121
from torchrl.envs.libs.gym import set_gym_backend
22-
2322
from torchrl.envs.utils import ExplorationType, set_exploration_type
2423
from torchrl.modules.tensordict_module import DecisionTransformerInferenceWrapper
2524
from torchrl.record import VideoRecorder
26-
2725
from utils import (
2826
dump_video,
2927
log_metrics,
@@ -37,7 +35,7 @@
3735

3836

3937
@hydra.main(config_path="", config_name="dt_config", version_base="1.1")
40-
def main(cfg: "DictConfig"): # noqa: F821
38+
def main(cfg: DictConfig): # noqa: F821
4139
set_gym_backend(cfg.env.backend).set()
4240

4341
model_device = cfg.optim.device

Diff for: sota-implementations/decision_transformer/online_dt.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from torchrl.envs.utils import ExplorationType, set_exploration_type
2121
from torchrl.modules.tensordict_module import DecisionTransformerInferenceWrapper
2222
from torchrl.record import VideoRecorder
23-
2423
from utils import (
2524
dump_video,
2625
log_metrics,
@@ -34,7 +33,7 @@
3433

3534

3635
@hydra.main(config_path="", config_name="odt_config", version_base="1.1")
37-
def main(cfg: "DictConfig"): # noqa: F821
36+
def main(cfg: DictConfig): # noqa: F821
3837
set_gym_backend(cfg.env.backend).set()
3938

4039
model_device = cfg.optim.device

Diff for: sota-implementations/discrete_sac/discrete_sac.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838

3939

4040
@hydra.main(version_base="1.1", config_path="", config_name="config")
41-
def main(cfg: "DictConfig"): # noqa: F821
41+
def main(cfg: DictConfig): # noqa: F821
4242
device = cfg.network.device
4343
if device in ("", None):
4444
if torch.cuda.is_available():

Diff for: sota-implementations/dqn/dqn_atari.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import tqdm
1919
from tensordict.nn import CudaGraphModule, TensorDictSequential
2020
from torchrl._utils import timeit
21-
2221
from torchrl.collectors import SyncDataCollector
2322
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
2423
from torchrl.envs import ExplorationType, set_exploration_type
@@ -32,7 +31,7 @@
3231

3332

3433
@hydra.main(config_path="", config_name="config_atari", version_base="1.1")
35-
def main(cfg: "DictConfig"): # noqa: F821
34+
def main(cfg: DictConfig): # noqa: F821
3635

3736
device = cfg.device
3837
if device in ("", None):

Diff for: sota-implementations/dqn/dqn_cartpole.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import torch.nn
1212
import torch.optim
1313
import tqdm
14-
1514
from tensordict.nn import CudaGraphModule, TensorDictSequential
1615
from torchrl._utils import timeit
1716
from torchrl.collectors import SyncDataCollector
@@ -27,7 +26,7 @@
2726

2827

2928
@hydra.main(config_path="", config_name="config_cartpole", version_base="1.1")
30-
def main(cfg: "DictConfig"): # noqa: F821
29+
def main(cfg: DictConfig): # noqa: F821
3130

3231
device = cfg.device
3332
if device in ("", None):

Diff for: sota-implementations/dreamer/dreamer.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,12 @@
2020
make_environments,
2121
make_replay_buffer,
2222
)
23-
2423
# mixed precision training
2524
from torch.amp import GradScaler
2625
from torch.nn.utils import clip_grad_norm_
2726
from torchrl._utils import logger as torchrl_logger, timeit
2827
from torchrl.envs.utils import ExplorationType, set_exploration_type
2928
from torchrl.modules import RSSMRollout
30-
3129
from torchrl.objectives.dreamer import (
3230
DreamerActorLoss,
3331
DreamerModelLoss,
@@ -37,7 +35,7 @@
3735

3836

3937
@hydra.main(version_base="1.1", config_path="", config_name="config")
40-
def main(cfg: "DictConfig"): # noqa: F821
38+
def main(cfg: DictConfig): # noqa: F821
4139
# cfg = correct_for_frame_skip(cfg)
4240

4341
device = _default_device(cfg.networks.device)

Diff for: sota-implementations/gail/gail.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -17,29 +17,25 @@
1717
import numpy as np
1818
import torch
1919
import tqdm
20-
2120
from gail_utils import log_metrics, make_gail_discriminator, make_offline_replay_buffer
2221
from ppo_utils import eval_model, make_env, make_ppo_models
2322
from tensordict.nn import CudaGraphModule
24-
2523
from torchrl._utils import compile_with_warmup, timeit
2624
from torchrl.collectors import SyncDataCollector
2725
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
2826
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
29-
3027
from torchrl.envs import set_gym_backend
3128
from torchrl.envs.utils import ExplorationType, set_exploration_type
3229
from torchrl.objectives import ClipPPOLoss, GAILLoss, group_optimizers
3330
from torchrl.objectives.value.advantages import GAE
3431
from torchrl.record import VideoRecorder
3532
from torchrl.record.loggers import generate_exp_name, get_logger
3633

37-
3834
torch.set_float32_matmul_precision("high")
3935

4036

4137
@hydra.main(config_path="", config_name="config")
42-
def main(cfg: "DictConfig"): # noqa: F821
38+
def main(cfg: DictConfig): # noqa: F821
4339
set_gym_backend(cfg.env.backend).set()
4440

4541
device = cfg.gail.device

Diff for: sota-implementations/gail/gail_utils.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,9 @@
66

77
import torch.nn as nn
88
import torch.optim
9-
109
from torchrl.data.datasets.d4rl import D4RLExperienceReplay
1110
from torchrl.data.replay_buffers import SamplerWithoutReplacement
1211
from torchrl.envs import DoubleToFloat
13-
1412
from torchrl.modules import SafeModule
1513

1614

@@ -45,7 +43,7 @@ def make_gail_discriminator(cfg, train_env, device="cpu"):
4543
# Define Discriminator Network
4644
class Discriminator(nn.Module):
4745
def __init__(self, state_dim, action_dim):
48-
super(Discriminator, self).__init__()
46+
super().__init__()
4947
self.fc1 = nn.Linear(state_dim + action_dim, hidden_dim)
5048
self.fc2 = nn.Linear(hidden_dim, hidden_dim)
5149
self.fc3 = nn.Linear(hidden_dim, 1)

Diff for: sota-implementations/impala/impala_multi_node_ray.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515

1616
@hydra.main(config_path="", config_name="config_multi_node_ray", version_base="1.1")
17-
def main(cfg: "DictConfig"): # noqa: F821
17+
def main(cfg: DictConfig): # noqa: F821
1818

1919
import time
2020

Diff for: sota-implementations/impala/impala_multi_node_submitit.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
@hydra.main(
1717
config_path="", config_name="config_multi_node_submitit", version_base="1.1"
1818
)
19-
def main(cfg: "DictConfig"): # noqa: F821
19+
def main(cfg: DictConfig): # noqa: F821
2020

2121
import time
2222

0 commit comments

Comments
 (0)