Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Nov 20, 2024
1 parent 2f5da22 commit e58c5ed
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 21 deletions.
4 changes: 2 additions & 2 deletions test/_utils_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,8 @@ def get_available_devices():
def get_default_devices():
num_cuda = torch.cuda.device_count()
if num_cuda == 0:
if torch.mps.is_available():
return [torch.device("mps:0")]
# if torch.mps.is_available():
# return [torch.device("mps:0")]
return [torch.device("cpu")]
elif num_cuda == 1:
return [torch.device("cuda:0")]
Expand Down
38 changes: 19 additions & 19 deletions torchrl/modules/tensordict_module/exploration.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,10 @@ def __init__(

super().__init__()

self.register_buffer("eps_init", torch.as_tensor([eps_init]))
self.register_buffer("eps_end", torch.as_tensor([eps_end]))
self.register_buffer("eps_init", torch.as_tensor(eps_init))
self.register_buffer("eps_end", torch.as_tensor(eps_end))
self.annealing_num_steps = annealing_num_steps
self.register_buffer("eps", torch.as_tensor([eps_init], dtype=torch.float32))
self.register_buffer("eps", torch.as_tensor(eps_init, dtype=torch.float32))

if spec is not None:
if not isinstance(spec, Composite) and len(self.out_keys) >= 1:
Expand Down Expand Up @@ -275,13 +275,13 @@ def __init__(
super().__init__(policy)
if sigma_end > sigma_init:
raise RuntimeError("sigma should decrease over time or be constant")
self.register_buffer("sigma_init", torch.tensor([sigma_init], device=device))
self.register_buffer("sigma_end", torch.tensor([sigma_end], device=device))
self.register_buffer("sigma_init", torch.tensor(sigma_init, device=device))
self.register_buffer("sigma_end", torch.tensor(sigma_end, device=device))
self.annealing_num_steps = annealing_num_steps
self.register_buffer("mean", torch.tensor([mean], device=device))
self.register_buffer("std", torch.tensor([std], device=device))
self.register_buffer("mean", torch.tensor(mean, device=device))
self.register_buffer("std", torch.tensor(std, device=device))
self.register_buffer(
"sigma", torch.tensor([sigma_init], dtype=torch.float32, device=device)
"sigma", torch.tensor(sigma_init, dtype=torch.float32, device=device)
)
self.action_key = action_key
self.out_keys = list(self.td_module.out_keys)
Expand Down Expand Up @@ -423,13 +423,13 @@ def __init__(

super().__init__()

self.register_buffer("sigma_init", torch.tensor([sigma_init], device=device))
self.register_buffer("sigma_end", torch.tensor([sigma_end], device=device))
self.register_buffer("sigma_init", torch.tensor(sigma_init, device=device))
self.register_buffer("sigma_end", torch.tensor(sigma_end, device=device))
self.annealing_num_steps = annealing_num_steps
self.register_buffer("mean", torch.tensor([mean], device=device))
self.register_buffer("std", torch.tensor([std], device=device))
self.register_buffer("mean", torch.tensor(mean, device=device))
self.register_buffer("std", torch.tensor(std, device=device))
self.register_buffer(
"sigma", torch.tensor([sigma_init], dtype=torch.float32, device=device)
"sigma", torch.tensor(sigma_init, dtype=torch.float32, device=device)
)

if spec is not None:
Expand Down Expand Up @@ -628,16 +628,16 @@ def __init__(
key=action_key,
device=device,
)
self.register_buffer("eps_init", torch.tensor([eps_init], device=device))
self.register_buffer("eps_end", torch.tensor([eps_end], device=device))
self.register_buffer("eps_init", torch.tensor(eps_init, device=device))
self.register_buffer("eps_end", torch.tensor(eps_end, device=device))
if self.eps_end > self.eps_init:
raise ValueError(
"eps should decrease over time or be constant, "
f"got eps_init={eps_init} and eps_end={eps_end}"
)
self.annealing_num_steps = annealing_num_steps
self.register_buffer(
"eps", torch.tensor([eps_init], dtype=torch.float32, device=device)
"eps", torch.tensor(eps_init, dtype=torch.float32, device=device)
)
self.out_keys = list(self.td_module.out_keys) + self.ou.out_keys
self.is_init_key = is_init_key
Expand Down Expand Up @@ -840,16 +840,16 @@ def __init__(
device=device,
)

self.register_buffer("eps_init", torch.tensor([eps_init], device=device))
self.register_buffer("eps_end", torch.tensor([eps_end], device=device))
self.register_buffer("eps_init", torch.tensor(eps_init, device=device))
self.register_buffer("eps_end", torch.tensor(eps_end, device=device))
if self.eps_end > self.eps_init:
raise ValueError(
"eps should decrease over time or be constant, "
f"got eps_init={eps_init} and eps_end={eps_end}"
)
self.annealing_num_steps = annealing_num_steps
self.register_buffer(
"eps", torch.tensor([eps_init], dtype=torch.float32, device=device)
"eps", torch.tensor(eps_init, dtype=torch.float32, device=device)
)

self.in_keys = [self.ou.key]
Expand Down

0 comments on commit e58c5ed

Please sign in to comment.