Skip to content

Commit

Permalink
Adding simple aggregations: concat & stack (NVIDIA-Merlin#1092)
Browse files Browse the repository at this point in the history
* Adding improved doc-strings

* Adding torch github-action + add copyright

* Adding simple aggregations: Concat & Stack

* Adding MaybeAgg for use in places like MLPBlock

* Add sorting to doc-string as pointed out in PR review

* Fixing linting issues
  • Loading branch information
marcromeyn authored May 29, 2023
1 parent 7a6086d commit b5dff16
Show file tree
Hide file tree
Showing 5 changed files with 323 additions and 1 deletion.
6 changes: 5 additions & 1 deletion merlin/models/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,20 @@
#

from merlin.models.torch.batch import Batch, Sequence
from merlin.models.torch.block import Block
from merlin.models.torch.block import Block, ParallelBlock
from merlin.models.torch.outputs.base import ModelOutput
from merlin.models.torch.outputs.classification import BinaryOutput
from merlin.models.torch.outputs.regression import RegressionOutput
from merlin.models.torch.transforms.agg import Concat, Stack

__all__ = [
"Batch",
"Concat",
"BinaryOutput",
"Block",
"ModelOutput",
"ParallelBlock",
"RegressionOutput",
"Sequence",
"Stack",
]
Empty file.
197 changes: 197 additions & 0 deletions merlin/models/torch/transforms/agg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
from typing import Dict, Union

import torch
from torch import nn

from merlin.models.torch.registry import registry


@registry.register("concat")
class Concat(nn.Module):
    """Concatenate tensors along a specified dimension.

    Parameters
    ----------
    dim : int
        The dimension along which the tensors will be concatenated.
        Default is -1.

    Examples
    --------
    >>> concat = Concat()
    >>> feature1 = torch.tensor([[1, 2], [3, 4]])  # Shape: [batch_size, feature_dim]
    >>> feature2 = torch.tensor([[5, 6], [7, 8]])  # Shape: [batch_size, feature_dim]
    >>> input_dict = {"feature1": feature1, "feature2": feature2}
    >>> output = concat(input_dict)
    >>> print(output)
    tensor([[1, 2, 5, 6],
            [3, 4, 7, 8]])  # Shape: [batch_size, feature_dim*number_of_features]
    """

    def __init__(self, dim: int = -1):
        super().__init__()
        self.dim = dim

    def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Concatenates input tensors along the specified dimension.

        The input dictionary will be sorted by name before concatenation,
        so the output does not depend on the insertion order of the dict.

        Parameters
        ----------
        inputs : Dict[str, torch.Tensor]
            A dictionary where keys are the names of the tensors
            and values are the tensors to be concatenated.

        Returns
        -------
        torch.Tensor
            A tensor that is the result of concatenating
            the input tensors along the specified dimension.

        Raises
        ------
        RuntimeError
            If the input tensor shapes don't match for concatenation
            along the specified dimension.
        """
        sorted_tensors = [inputs[name] for name in sorted(inputs.keys())]

        # Normalize a negative dim (e.g. the default -1) against the rank of
        # the first tensor so the shape check below covers every case, not
        # just dim > 0 as before.
        dim = self.dim
        if dim < 0:
            dim = sorted_tensors[0].dim() + dim

        # All dimensions except `dim` must agree for torch.cat to succeed;
        # validate up front so we can raise a consistent, descriptive error.
        first = sorted_tensors[0]
        for t in sorted_tensors:
            if t.shape[:dim] != first.shape[:dim] or t.shape[dim + 1 :] != first.shape[dim + 1 :]:
                raise RuntimeError(
                    "Input tensor shapes don't match for concatenation "
                    "along the specified dimension."
                )

        return torch.cat(sorted_tensors, dim=self.dim)


@registry.register("stack")
class Stack(nn.Module):
    """Stack tensors along a specified dimension.

    The input dictionary will be sorted by name before stacking.

    Parameters
    ----------
    dim : int
        The dimension along which the tensors will be stacked.
        Default is 0.

    Examples
    --------
    >>> stack = Stack()
    >>> feature1 = torch.tensor([[1, 2], [3, 4]])  # Shape: [batch_size, feature_dim]
    >>> feature2 = torch.tensor([[5, 6], [7, 8]])  # Shape: [batch_size, feature_dim]
    >>> input_dict = {"feature1": feature1, "feature2": feature2}
    >>> output = stack(input_dict)
    >>> print(output)
    tensor([[[1, 2],
             [5, 6]],
            [[3, 4],
             [7, 8]]])  # Shape: [batch_size, number_of_features, feature_dim]
    """

    def __init__(self, dim: int = 0):
        super().__init__()
        self.dim = dim

    def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Stacks input tensors along the specified dimension.

        Parameters
        ----------
        inputs : Dict[str, torch.Tensor]
            A dictionary where keys are the names of the tensors
            and values are the tensors to be stacked.

        Returns
        -------
        torch.Tensor
            A tensor that is the result of stacking
            the input tensors along the specified dimension.

        Raises
        ------
        RuntimeError
            If the input tensor shapes don't match for stacking.
        """
        # Sort by key so the result is independent of dict insertion order.
        tensors = [inputs[key] for key in sorted(inputs.keys())]

        # torch.stack requires every tensor to have exactly the same shape.
        reference_shape = tensors[0].shape
        for tensor in tensors:
            if tensor.shape != reference_shape:
                raise RuntimeError("Input tensor shapes don't match for stacking.")

        return torch.stack(tensors, dim=self.dim)


class MaybeAgg(nn.Module):
    """
    Conditionally apply an aggregation operation (e.g., Stack or Concat):
    dictionaries of tensors are aggregated, single tensors pass through.

    Parameters
    ----------
    agg : nn.Module
        The aggregation operation to be applied.

    Examples
    --------
    >>> stack = Stack(dim=0)
    >>> maybe_agg = MaybeAgg(agg=stack)
    >>> tensor1 = torch.tensor([[1, 2], [3, 4]])
    >>> tensor2 = torch.tensor([[5, 6], [7, 8]])
    >>> input_dict = {"tensor1": tensor1, "tensor2": tensor2}
    >>> output = maybe_agg(input_dict)
    >>> print(output)
    tensor([[[1, 2],
             [3, 4]],
            [[5, 6],
             [7, 8]]])
    >>> tensor = torch.tensor([1, 2, 3])
    >>> output = maybe_agg(tensor)
    >>> print(output)
    tensor([1, 2, 3])
    """

    def __init__(self, agg: nn.Module):
        super().__init__()
        self.agg = agg

    def forward(self, inputs: Union[Dict[str, torch.Tensor], torch.Tensor]) -> torch.Tensor:
        """
        Conditionally applies the aggregation operation on the inputs.

        Parameters
        ----------
        inputs : Union[Dict[str, torch.Tensor], torch.Tensor]
            Inputs to be aggregated. A dictionary of tensors is passed
            through ``self.agg``; a single tensor is returned unchanged.

        Returns
        -------
        torch.Tensor
            Aggregated tensor if inputs is a dictionary, otherwise the input tensor.

        Raises
        ------
        RuntimeError
            If inputs is neither a tensor nor a dictionary of tensors.
        """
        # torch.jit.isinstance is used (rather than plain isinstance) so this
        # type-narrowing also works when the module is scripted.
        if torch.jit.isinstance(inputs, torch.Tensor):
            return inputs

        if torch.jit.isinstance(inputs, Dict[str, torch.Tensor]):
            return self.agg(inputs)

        raise RuntimeError("Inputs must be either a dictionary of tensors or a single tensor.")
Empty file.
121 changes: 121 additions & 0 deletions tests/unit/torch/transforms/test_agg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import pytest
import torch

from merlin.models.torch.block import Block
from merlin.models.torch.transforms.agg import Concat, MaybeAgg, Stack
from merlin.models.torch.utils import module_utils


class TestConcat:
    def test_valid_input(self):
        """Concatenating along dim=1 sums the feature dimensions."""
        module = Concat(dim=1)
        inputs = {
            "a": torch.randn(2, 3),
            "b": torch.randn(2, 4),
        }
        result = module_utils.module_test(module, inputs)
        assert result.shape == (2, 7)

    @pytest.mark.parametrize("dim", [2, -1])
    def test_same_order(self, dim):
        """The output is independent of dict insertion order (keys are sorted)."""
        module = Concat(dim=dim)
        first = torch.randn(2, 3, 4)
        second = torch.randn(2, 3, 5)

        result_ab = module_utils.module_test(module, {"a": first, "b": second})
        result_ba = module_utils.module_test(module, {"b": second, "a": first})

        assert torch.all(torch.eq(result_ab, result_ba))
        assert result_ab.shape == (2, 3, 9)

    def test_invalid_input(self):
        """Mismatched non-concat dimensions raise a descriptive error."""
        module = Concat(dim=1)
        inputs = {
            "a": torch.randn(2, 3),
            "b": torch.randn(3, 3),
        }
        with pytest.raises(RuntimeError, match="Input tensor shapes don't match"):
            module(inputs)

    def test_from_registry(self):
        """Concat can be constructed through the block registry by name."""
        block = Block.parse("concat")

        inputs = {
            "a": torch.randn(2, 3),
            "b": torch.randn(2, 4),
        }
        result = module_utils.module_test(block, inputs)
        assert result.shape == (2, 7)


class TestStack:
    def test_2d_input(self):
        """Stacking two [2, 3] tensors along dim=0 yields shape [2, 2, 3]."""
        module = Stack(dim=0)
        inputs = {
            "a": torch.randn(2, 3),
            "b": torch.randn(2, 3),
        }
        result = module_utils.module_test(module, inputs)
        assert result.shape == (2, 2, 3)

    def test_same_order(self):
        """The output is independent of dict insertion order (keys are sorted)."""
        module = Stack(dim=0)
        first = torch.randn(2, 3)
        second = torch.randn(2, 3)

        result_ab = module_utils.module_test(module, {"a": first, "b": second})
        result_ba = module_utils.module_test(module, {"b": second, "a": first})

        assert torch.all(torch.eq(result_ab, result_ba))

    def test_invalid_input(self):
        """Tensors with different shapes cannot be stacked."""
        module = Stack(dim=0)
        inputs = {
            "a": torch.randn(2, 3),
            "b": torch.randn(3, 3),
        }
        with pytest.raises(RuntimeError, match="Input tensor shapes don't match"):
            module(inputs)

    def test_from_registry(self):
        """Stack can be constructed through the block registry by name."""
        block = Block.parse("stack")

        inputs = {
            "a": torch.randn(2, 3),
            "b": torch.randn(2, 3),
        }
        result = block(inputs)
        assert result.shape == (2, 2, 3)


class TestMaybeAgg:
    def test_with_single_tensor(self):
        """A plain tensor passes through MaybeAgg unchanged."""
        data = torch.tensor([1, 2, 3])
        wrapper = MaybeAgg(agg=Stack(dim=0))

        result = module_utils.module_test(wrapper, data)
        assert torch.equal(result, data)

    def test_with_dict(self):
        """A dictionary of tensors is routed through the wrapped aggregation."""
        wrapper = MaybeAgg(agg=Stack(dim=0))

        first = torch.tensor([[1, 2], [3, 4]])
        second = torch.tensor([[5, 6], [7, 8]])
        inputs = {"tensor1": first, "tensor2": second}
        expected = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])

        result = module_utils.module_test(wrapper, inputs)
        assert torch.equal(result, expected)

    def test_with_incompatible_dict(self):
        """Values that are neither tensors nor dicts of tensors are rejected."""
        wrapper = MaybeAgg(agg=Concat(dim=0))

        first = torch.tensor([1, 2, 3])
        second = torch.tensor([4, 5])
        inputs = {"tensor1": (first, second)}

        with pytest.raises(
            RuntimeError, match="Inputs must be either a dictionary of tensors or a single tensor"
        ):
            wrapper(inputs)

0 comments on commit b5dff16

Please sign in to comment.