Skip to content

Commit

Permalink
[Feature] Composite.separates
Browse files Browse the repository at this point in the history
ghstack-source-id: fbfc4308a81cd96ecc61723df8c0eb870c442def
Pull Request resolved: #2599
  • Loading branch information
vmoens committed Nov 24, 2024
1 parent 8d16c12 commit 83e0b05
Showing 1 changed file with 27 additions and 0 deletions.
27 changes: 27 additions & 0 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4339,6 +4339,33 @@ def pop(self, key: NestedKey, default: Any = NO_DEFAULT) -> Any:
return default
raise KeyError(f"{key} not found in composite spec.")

def separates(self, *keys: NestedKey, default: Any = None) -> Composite:
"""Splits the composite spec by extracting specified keys and their associated values into a new composite spec.
This method iterates over the provided keys, removes them from the current composite spec, and adds them to a new
composite spec. If a key is not found, the specified default value is used. The new composite spec is returned.
Args:
*keys (NestedKey):
One or more keys to be extracted from the composite spec. Each key can be a single key or a nested key.
default (Any, optional):
The value to use if a specified key is not found in the composite spec. Defaults to `None`.
Returns:
Composite: A new composite spec containing the extracted keys and their associated values.
Note:
If none of the specified keys are found, the method returns `None`.
"""
out = None
for key in keys:
result = self.pop(key, default=default)
if result is not None:
if out is None:
out = Composite(batch_size=self.batch_size, device=self.device)
out[key] = result
return out

def set(self, name, spec):
if self.locked:
raise RuntimeError("Cannot modify a locked Composite.")
Expand Down

0 comments on commit 83e0b05

Please sign in to comment.