Skip to content

Commit

Permalink
[BugFix] Fix failing tests
Browse files Browse the repository at this point in the history
ghstack-source-id: 1094f462427ce7661d879851e42d24d193c7a20b
Pull Request resolved: #2582
  • Loading branch information
vmoens committed Nov 19, 2024
1 parent 408cf7d commit aec773d
Show file tree
Hide file tree
Showing 18 changed files with 185 additions and 100 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ jobs:
REF_TYPE=${{ github.ref_type }}
REF_NAME=${{ github.ref_name }}
apt-get update
apt-get install rsync -y
if [[ "${REF_TYPE}" == branch ]]; then
if [[ "${REF_NAME}" == main ]]; then
Expand Down
6 changes: 4 additions & 2 deletions sota-implementations/ddpg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,8 @@ def make_ddpg_agent(cfg, train_env, eval_env, device):
OrnsteinUhlenbeckProcessModule(
spec=action_spec,
annealing_num_steps=1_000_000,
).to(device),
device=device,
),
)
elif cfg.network.noise_type == "gaussian":
actor_model_explore = TensorDictSequential(
Expand All @@ -245,7 +246,8 @@ def make_ddpg_agent(cfg, train_env, eval_env, device):
sigma_init=1.0,
mean=0.0,
std=0.1,
).to(device),
device=device,
),
)
else:
raise NotImplementedError
Expand Down
1 change: 1 addition & 0 deletions sota-implementations/dreamer/dreamer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,7 @@ def make_dreamer(
annealing_num_steps=1,
mean=0.0,
std=cfg.networks.exploration_noise,
device=device,
),
)

Expand Down
1 change: 1 addition & 0 deletions sota-implementations/multiagent/maddpg_iddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def train(cfg: "DictConfig"): # noqa: F821
spec=env.unbatched_action_spec,
annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)),
action_key=env.action_key,
device=cfg.train.device,
),
)

Expand Down
3 changes: 2 additions & 1 deletion sota-implementations/redq/redq.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@ def main(cfg: "DictConfig"): # noqa: F821
annealing_num_steps=cfg.exploration.annealing_frames,
sigma=cfg.exploration.ou_sigma,
theta=cfg.exploration.ou_theta,
).to(device),
device=device,
),
)
if device == torch.device("cpu"):
# mostly for debugging
Expand Down
50 changes: 27 additions & 23 deletions sota-implementations/redq/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,55 +21,59 @@
from torchrl._utils import logger as torchrl_logger, VERBOSE
from torchrl.collectors.collectors import DataCollectorBase

from torchrl.data import ReplayBuffer, TensorDictReplayBuffer
from torchrl.data.postprocs import MultiStep
from torchrl.data.replay_buffers.samplers import PrioritizedSampler, RandomSampler
from torchrl.data.replay_buffers.storages import LazyMemmapStorage
from torchrl.data import (
LazyMemmapStorage,
MultiStep,
PrioritizedSampler,
RandomSampler,
ReplayBuffer,
TensorDictReplayBuffer,
)
from torchrl.data.utils import DEVICE_TYPING
from torchrl.envs import ParallelEnv
from torchrl.envs.common import EnvBase
from torchrl.envs.env_creator import env_creator, EnvCreator
from torchrl.envs.libs.dm_control import DMControlEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.transforms import (
from torchrl.envs import (
CatFrames,
CatTensors,
CenterCrop,
Compose,
DMControlEnv,
DoubleToFloat,
env_creator,
EnvBase,
EnvCreator,
FlattenObservation,
GrayScale,
gSDENoise,
GymEnv,
InitTracker,
NoopResetEnv,
ObservationNorm,
ParallelEnv,
Resize,
RewardScaling,
StepCounter,
ToTensorImage,
TransformedEnv,
VecNorm,
)
from torchrl.envs.transforms.transforms import (
FlattenObservation,
gSDENoise,
InitTracker,
StepCounter,
)
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import (
ActorCriticOperator,
ActorValueOperator,
DdpgCnnActor,
DdpgCnnQNet,
MLP,
NoisyLinear,
NormalParamExtractor,
ProbabilisticActor,
SafeModule,
SafeSequential,
TanhNormal,
ValueOperator,
)
from torchrl.modules.distributions import TanhNormal
from torchrl.modules.distributions.continuous import SafeTanhTransform
from torchrl.modules.models.exploration import LazygSDEModule
from torchrl.modules.models.models import DdpgCnnActor, DdpgCnnQNet, MLP
from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
from torchrl.objectives import HardUpdate, SoftUpdate
from torchrl.objectives.common import LossModule
from torchrl.objectives import HardUpdate, LossModule, SoftUpdate, TargetNetUpdater
from torchrl.objectives.deprecated import REDQLoss_deprecated
from torchrl.objectives.utils import TargetNetUpdater
from torchrl.record.loggers import Logger
from torchrl.record.recorder import VideoRecorder
from torchrl.trainers.helpers import sync_async_collector, sync_sync_collector
Expand Down Expand Up @@ -518,7 +522,7 @@ def make_redq_model(
actor_module = SafeSequential(
actor_module,
SafeModule(
LazygSDEModule(transform=transform),
LazygSDEModule(transform=transform, device=device),
in_keys=["action", gSDE_state_key, "_eps_gSDE"],
out_keys=["loc", "scale", "action", "_eps_gSDE"],
),
Expand Down
3 changes: 2 additions & 1 deletion sota-implementations/td3/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,8 @@ def make_td3_agent(cfg, train_env, eval_env, device):
mean=0,
std=0.1,
spec=action_spec,
).to(device),
device=device,
),
)
return model, actor_model_explore

Expand Down
3 changes: 2 additions & 1 deletion sota-implementations/td3_bc/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,8 @@ def make_td3_agent(cfg, train_env, device):
mean=0,
std=0.1,
spec=action_spec,
).to(device),
device=device,
),
)
return model, actor_model_explore

Expand Down
4 changes: 2 additions & 2 deletions test/_utils_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,11 +167,11 @@ def get_available_devices():
def get_default_devices():
num_cuda = torch.cuda.device_count()
if num_cuda == 0:
if torch.mps.is_available():
return [torch.device("mps:0")]
return [torch.device("cpu")]
elif num_cuda == 1:
return [torch.device("cuda:0")]
elif torch.mps.is_available():
return [torch.device("mps:0")]
else:
# then run on all devices
return get_available_devices()
Expand Down
2 changes: 1 addition & 1 deletion test/test_actors.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
("data", "sample_log_prob"),
],
)
def test_probabilistic_actor_nested_delta(log_prob_key, nested_dim=5, n_actions=3):
def test_probabilistic_actor_nested_delta(log_prob_key, nested_dim=5, n_actions=1):
env = NestedCountingEnv(nested_dim=nested_dim)
action_spec = Bounded(shape=torch.Size((nested_dim, n_actions)), high=1, low=-1)
policy_module = TensorDictModule(
Expand Down
31 changes: 14 additions & 17 deletions test/test_exploration.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,8 @@ def test_ou(
self, device, interface, d_obs=4, d_act=6, batch=32, n_steps=100, seed=0
):
torch.manual_seed(seed)
net = nn.Sequential(nn.Linear(d_obs, 2 * d_act), NormalParamExtractor()).to(
device
net = nn.Sequential(
nn.Linear(d_obs, 2 * d_act, device=device), NormalParamExtractor()
)
module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
action_spec = Bounded(-torch.ones(d_act), torch.ones(d_act), (d_act,))
Expand All @@ -252,13 +252,13 @@ def test_ou(
in_keys=["loc", "scale"],
distribution_class=TanhNormal,
default_interaction_type=InteractionType.RANDOM,
).to(device)
)

if interface == "module":
ou = OrnsteinUhlenbeckProcessModule(spec=action_spec).to(device)
ou = OrnsteinUhlenbeckProcessModule(spec=action_spec, device=device)
exploratory_policy = TensorDictSequential(policy, ou)
else:
exploratory_policy = OrnsteinUhlenbeckProcessWrapper(policy)
exploratory_policy = OrnsteinUhlenbeckProcessWrapper(policy, device=device)
ou = exploratory_policy

tensordict = TensorDict(
Expand Down Expand Up @@ -338,10 +338,10 @@ def test_collector(self, device, parallel_spec, probabilistic, interface, seed=0

if interface == "module":
exploratory_policy = TensorDictSequential(
policy, OrnsteinUhlenbeckProcessModule(spec=action_spec).to(device)
policy, OrnsteinUhlenbeckProcessModule(spec=action_spec, device=device)
)
else:
exploratory_policy = OrnsteinUhlenbeckProcessWrapper(policy)
exploratory_policy = OrnsteinUhlenbeckProcessWrapper(policy, device=device)
exploratory_policy(env.reset())
collector = SyncDataCollector(
create_env_fn=env,
Expand Down Expand Up @@ -456,10 +456,10 @@ def test_additivegaussian_sd(
device=device,
)
if interface == "module":
exploratory_policy = AdditiveGaussianModule(action_spec).to(device)
exploratory_policy = AdditiveGaussianModule(action_spec, device=device)
else:
net = nn.Sequential(nn.Linear(d_obs, 2 * d_act), NormalParamExtractor()).to(
device
net = nn.Sequential(
nn.Linear(d_obs, 2 * d_act, device=device), NormalParamExtractor()
)
module = SafeModule(
net,
Expand All @@ -473,10 +473,10 @@ def test_additivegaussian_sd(
in_keys=["loc", "scale"],
distribution_class=TanhNormal,
default_interaction_type=InteractionType.RANDOM,
).to(device)
)
given_spec = action_spec if spec_origin == "spec" else None
exploratory_policy = AdditiveGaussianWrapper(policy, spec=given_spec).to(
device
exploratory_policy = AdditiveGaussianWrapper(
policy, spec=given_spec, device=device
)
if spec_origin is not None:
sigma_init = (
Expand Down Expand Up @@ -727,10 +727,7 @@ def test_gsde(
@pytest.mark.parametrize("std", [1, 2])
@pytest.mark.parametrize("sigma_init", [None, 1.5, 3])
@pytest.mark.parametrize("learn_sigma", [False, True])
@pytest.mark.parametrize(
"device",
[torch.device("cuda:0") if torch.cuda.device_count() else torch.device("cpu")],
)
@pytest.mark.parametrize("device", get_default_devices())
def test_gsde_init(sigma_init, state_dim, action_dim, mean, std, device, learn_sigma):
torch.manual_seed(0)
state = torch.randn(10000, *state_dim, device=device) * std + mean
Expand Down
3 changes: 3 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2076,7 +2076,10 @@ def test_transform_rb(self, rbclass):
):
td = rb.sample(10)

@retry(AssertionError, tries=10, delay=0)
def test_collector_match(self):
torch.manual_seed(0)

# The counter in the collector should match the one from the transform
t = TrajCounter()

Expand Down
5 changes: 3 additions & 2 deletions torchrl/collectors/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ class DataCollectorBase(IterableDataset, metaclass=abc.ABCMeta):

_iterator = None
total_frames: int
requested_frames_per_batch: int
frames_per_batch: int
trust_policy: bool
compiled_policy: bool
Expand Down Expand Up @@ -305,7 +306,7 @@ def __class_getitem__(self, index):

def __len__(self) -> int:
if self.total_frames > 0:
return -(self.total_frames // -self.frames_per_batch)
return -(self.total_frames // -self.requested_frames_per_batch)
raise RuntimeError("Non-terminating collectors do not have a length")


Expand Down Expand Up @@ -700,7 +701,7 @@ def __init__(
remainder = total_frames % frames_per_batch
if remainder != 0 and RL_WARNINGS:
warnings.warn(
f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({frames_per_batch})."
f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({frames_per_batch}). "
f"This means {frames_per_batch - remainder} additional frames will be collected."
"To silence this message, set the environment variable RL_WARNINGS to False."
)
Expand Down
2 changes: 1 addition & 1 deletion torchrl/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .batched_envs import ParallelEnv, SerialEnv
from .common import EnvBase, EnvMetaData, make_tensordict
from .custom import PendulumEnv, TicTacToeEnv
from .env_creator import EnvCreator, get_env_metadata
from .env_creator import env_creator, EnvCreator, get_env_metadata
from .gym_like import default_info_dict_reader, GymLikeEnv
from .libs import (
BraxEnv,
Expand Down
9 changes: 9 additions & 0 deletions torchrl/modules/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,15 @@ def __init__(
event_shape = param.shape[-1:]
super().__init__(batch_shape=batch_shape, event_shape=event_shape)

def expand(self, batch_shape: torch.Size, _instance=None):
if self.batch_shape != tuple(batch_shape):
return type(self)(
self.param.expand((*batch_shape, *self.event_shape)),
atol=self.atol,
rtol=self.rtol,
)
return self

def update(self, param):
self.param = param

Expand Down
Loading

0 comments on commit aec773d

Please sign in to comment.