
[Feature] Add Stack transform #2567

Open · wants to merge 1 commit into base: main

Conversation

@kurtamohler (Collaborator)
Description

Adds a transform that stacks tensors and specs from different keys of a tensordict into a common key.
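To illustrate the semantics described above, here is a minimal plain-Python sketch of stacking entries from several keys into one common key, and the inverse operation. A dict stands in for a tensordict and lists stand in for tensors; the helper names `stack_keys` and `unstack_keys` are hypothetical and are not the transform's actual API.

```python
# Sketch of the stack/unstack semantics. A plain dict stands in for a
# tensordict; helper names are hypothetical, not the actual Stack API.

def stack_keys(td, in_keys, out_key):
    """Move the values under in_keys into a single list under out_key."""
    td[out_key] = [td.pop(k) for k in in_keys]
    return td

def unstack_keys(td, out_key, in_keys):
    """Inverse: unbind the entry under out_key back onto the original keys."""
    values = td.pop(out_key)
    for k, v in zip(in_keys, values):
        td[k] = v
    return td

td = {"agent_0": [1.0, 2.0], "agent_1": [3.0, 4.0]}
stack_keys(td, ["agent_0", "agent_1"], "agents")
# td == {"agents": [[1.0, 2.0], [3.0, 4.0]]}
unstack_keys(td, "agents", ["agent_0", "agent_1"])
# td == {"agent_0": [1.0, 2.0], "agent_1": [3.0, 4.0]}
```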

Motivation and Context

Closes #2566

  • I have raised an issue to propose this change (required for new features and bug fixes)

Types of changes

What types of changes does your code introduce? Remove all that do not apply:

  • New feature (non-breaking change which adds core functionality)
  • Documentation (update in the documentation)

Checklist

Go over all the following points, and put an x in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!

  • I have read the CONTRIBUTION guide (required)
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.

pytorch-bot (bot) commented Nov 14, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/2567

Note: Links to docs will display an error until the docs builds have been completed.

❌ 3 New Failures, 2 Pending, 18 Unrelated Failures

As of commit cec32d2 with merge base b4b5944:


👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Nov 14, 2024
@kurtamohler (Collaborator, Author)

Looks like there's a minor bug when I try to use this on UnityMLAgentsEnv and then do a rollout. I'll fix that and add a test.

@vmoens vmoens added the enhancement New feature or request label Nov 15, 2024
@vmoens (Contributor) left a comment

Thanks for this, a long-awaited feature!
I just left a couple of comments on the default dim and the test set.

Review threads: torchrl/envs/transforms/transforms.py (outdated, resolved); test/test_transforms.py (resolved)
@kurtamohler force-pushed the Stack-Transform-0 branch 3 times, most recently from 23f7e1b to f443812, on November 19, 2024 05:37
@vmoens (Contributor) left a comment

Thanks, this is superb!
I'd like to discuss the inverse transform:
Would it make sense in the inverse to get an entry (from the input_spec) and unbind it?
Like: you have a single action with leading dim of 2, and map it to ("agent0", "action"), ("agent1", "action"). The spec seen from the outside is the stack of the 2 specs (as it is for stuff processed in forward).
Would that make sense?

Review threads: torchrl/envs/transforms/transforms.py (outdated, resolved)
@kurtamohler (Collaborator, Author) commented Nov 20, 2024

Would it make sense in the inverse to get an entry (from the input_spec) and unbind it?
Like: you have a single action with leading dim of 2, and map it to ("agent0", "action"), ("agent1", "action"). The spec seen from the outside is the stack of the 2 specs (as it is for stuff processed in forward).
Would that make sense?

Yes, I think that does make sense, and that is exactly what happens. For instance, if I add a line to print the output of Stack._inv_call at the end of the function and then run the following script:

```python
from torchrl.envs import Stack, TransformedEnv
from torchrl.envs import UnityMLAgentsEnv

base_env = UnityMLAgentsEnv(registered_name='3DBall')

try:
    t = Stack(
        in_keys=[("group_0", f"agent_{idx}") for idx in range(12)],
        out_key="agents",
    )
    env = TransformedEnv(base_env, t)
    action = env.full_action_spec.rand()
    print("-------------------------")
    print(action)
    print("-------------------------")
    env.step(action)

finally:
    base_env.close()
```

I get the following output:

-------------------------
TensorDict(
    fields={
        agents: TensorDict(
            fields={
                continuous_action: Tensor(shape=torch.Size([12, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([12]),
            device=None,
            is_shared=False),
        group_0: TensorDict(
            fields={
            },
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
-------------------------
TensorDict(
    fields={
        group_0: TensorDict(
            fields={
                agent_0: TensorDict(
                    fields={
                        continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False),
                agent_10: TensorDict(
                    fields={
                        continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False),
                agent_11: TensorDict(
                    fields={
                        continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False),
                agent_1: TensorDict(
                    fields={
                        continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False),
                agent_2: TensorDict(
                    fields={
                        continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False),
                agent_3: TensorDict(
                    fields={
                        continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False),
                agent_4: TensorDict(
                    fields={
                        continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False),
                agent_5: TensorDict(
                    fields={
                        continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False),
                agent_6: TensorDict(
                    fields={
                        continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False),
                agent_7: TensorDict(
                    fields={
                        continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False),
                agent_8: TensorDict(
                    fields={
                        continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False),
                agent_9: TensorDict(
                    fields={
                        continuous_action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

This shows that the inverse of Stack correctly unbinds the stacked actions into the format that the Unity env expects. BTW, this functionality is tested in the environment unit test I added.

But one thing from the above example that I'd like to fix is that env.full_action_spec still contains the "group_0" key with an empty Composite spec, left over from excluding all the keys under that group. I'm not sure of the most efficient way to prune empty specs/tensordicts like those.

EDIT: I found a solution that I think is good: during __init__, build up a list of all the parent keys of in_keys. In the above example, that list would just be [("group_0",)]. Then in _transform_spec and _call, after deleting the in_keys from a spec/tensordict, check whether we need to delete any of the parent keys as well. I'll push an update when I have it working and tested.
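The pruning idea above can be sketched in plain Python: collect the parent keys of in_keys up front, and after deleting the in_keys, drop any parent container that is left empty. Nested dicts stand in for specs/tensordicts, and all helper names (`parent_keys`, `get_nested`, `prune_empty_parents`) are hypothetical, not the actual implementation.

```python
# Hedged sketch of the empty-parent pruning described above. Nested dicts
# stand in for specs/tensordicts; helper names are hypothetical.

def parent_keys(in_keys):
    """Unique parents of nested keys, e.g. ("group_0", "agent_0") -> ("group_0",)."""
    parents = []
    for key in in_keys:
        if isinstance(key, tuple) and len(key) > 1 and key[:-1] not in parents:
            parents.append(key[:-1])
    return parents

def get_nested(d, key):
    """Walk a tuple key into a nested dict; an empty tuple returns d itself."""
    for k in key:
        d = d[k]
    return d

def prune_empty_parents(d, parents):
    """Delete any parent container that was left empty after removing in_keys."""
    for parent in parents:
        if not get_nested(d, parent):
            del get_nested(d, parent[:-1])[parent[-1]]

in_keys = [("group_0", "agent_0"), ("group_0", "agent_1")]
data = {"group_0": {"agent_0": 1, "agent_1": 2}, "other": 3}
for key in in_keys:                       # mimic deleting the in_keys
    del get_nested(data, key[:-1])[key[-1]]
prune_empty_parents(data, parent_keys(in_keys))
# data == {"other": 3}  — the empty "group_0" container is gone
```

Since the parent list is computed once in __init__, the per-step cost is proportional to the number of parents rather than to the size of the whole spec/tensordict, which is the performance argument made against RemoveEmptySpecs below.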

@kurtamohler kurtamohler added the Environments Adds or modifies an environment wrapper label Nov 20, 2024
@vmoens (Contributor) commented Nov 21, 2024

Ah, got it; sorry, I was expecting something like in_keys or such. Maybe let's make it clear in the docstrings that in_keys can be part of the input (usually it's reserved for output keys).

Maybe this? If that's what you need, you can mention it in the doc of the transform:

https://pytorch.org/rl/main/reference/generated/torchrl.envs.transforms.RemoveEmptySpecs.html#torchrl.envs.transforms.RemoveEmptySpecs

@kurtamohler (Collaborator, Author)

Ah got it, sorry I was expecting something like in_keys or such. Maybe let's make clear in the doc strings that in_keys can be part of the input (usually it's reserved to output keys)

I'm sorry, I'm not sure what you mean by this. in_keys is always supposed to be part of the input, isn't it? Are you talking about the inverse operation here, or the leftover empty specs/tensordicts, or something else?

RemoveEmptySpecs would do what I want, but it's not very efficient, since it has to check through the entire spec/tensordict. Wouldn't it be better if Stack had the responsibility of cleaning up any empty keys it leaves behind? I don't think the way I implemented it is overly complicated, and it should give significantly better performance than RemoveEmptySpecs, since it knows ahead of time which keys to check. But I suppose one could argue that in many cases it isn't really necessary to remove the empty specs/tensordicts, and not removing them would be more efficient anyway.

@vmoens (Contributor) commented Nov 22, 2024

I'm sorry, I'm not sure what you mean by this. in_keys is always supposed to be part of the input, isn't it? Are you talking about the inverse operation here, or the leftover empty specs/tensordicts, or something else?

Yeah, usually anything you iterate over during the inverse pass is passed through the inv_keys, and anything you process during forward is in the in_keys.
It's useful to keep them separated, because it might well be the case that a key in the input is named the same as an output, but you only want to process it one way and not the other.

@vmoens (Contributor) commented Nov 22, 2024

I don't think the way I implemented it is overly complicated, and it should give significantly better performance than RemoveEmptySpecs, since it knows which keys to check ahead of time.

Ok that sounds good!

@kurtamohler (Collaborator, Author)

Yeah usually anything you iterate over during inverse pass is passed through the inv_keys and anything you process during forward is in the in_keys.

What I think you're saying is that most of the existing transforms that have an inverse allow the user to explicitly set the inverse keys in __init__, but with Stack, the inverse keys are automatically assigned as in_key_inv = out_key and out_keys_inv = in_keys. So we should document that difference. Is that what you're saying?

I guess we also might as well allow the user to override the default in_key_inv and out_keys_inv.
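The default-with-override behavior described in the two sentences above could be sketched as follows. This is a hypothetical standalone helper for illustration, not the actual Stack __init__ signature.

```python
# Hypothetical sketch of the default/override logic discussed above
# (illustrative only; not the actual Stack signature).
def resolve_inverse_keys(in_keys, out_key, in_keys_inv=None, out_keys_inv=None):
    # By default, the inverse reads the stacked entry under out_key
    # and writes back onto the original in_keys; the user may override both.
    if in_keys_inv is None:
        in_keys_inv = [out_key]
    if out_keys_inv is None:
        out_keys_inv = list(in_keys)
    return in_keys_inv, out_keys_inv

inv_in, inv_out = resolve_inverse_keys(["agent_0", "agent_1"], "agents")
# inv_in == ["agents"], inv_out == ["agent_0", "agent_1"]
```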

Successfully merging this pull request may close these issues.

[Feature Request] Transform that stacks data for agents with identical specs
3 participants