diff --git a/merlin/models/torch/__init__.py b/merlin/models/torch/__init__.py index eef6f9ece6..d2326af5e9 100644 --- a/merlin/models/torch/__init__.py +++ b/merlin/models/torch/__init__.py @@ -17,6 +17,7 @@ from merlin.models.torch import schema from merlin.models.torch.batch import Batch, Sequence from merlin.models.torch.block import Block, ParallelBlock +from merlin.models.torch.blocks.dlrm import DLRMBlock from merlin.models.torch.blocks.mlp import MLPBlock from merlin.models.torch.inputs.embedding import EmbeddingTable, EmbeddingTables from merlin.models.torch.inputs.select import SelectFeatures, SelectKeys @@ -51,4 +52,5 @@ "Concat", "Stack", "schema", + "DLRMBlock", ] diff --git a/merlin/models/torch/blocks/dlrm.py b/merlin/models/torch/blocks/dlrm.py new file mode 100644 index 0000000000..a24e4d1f71 --- /dev/null +++ b/merlin/models/torch/blocks/dlrm.py @@ -0,0 +1,141 @@ +from typing import Dict, Optional + +import torch +from torch import nn + +from merlin.models.torch.block import Block +from merlin.models.torch.inputs.embedding import EmbeddingTables +from merlin.models.torch.inputs.tabular import TabularInputBlock +from merlin.models.torch.link import Link +from merlin.models.torch.transforms.agg import MaybeAgg, Stack +from merlin.models.utils.doc_utils import docstring_parameter +from merlin.schema import Schema, Tags + +_DLRM_REF = """ + References + ---------- + .. [1] Naumov, Maxim, et al. "Deep learning recommendation model for + personalization and recommendation systems." arXiv preprint arXiv:1906.00091 (2019). +""" + + +@docstring_parameter(dlrm_reference=_DLRM_REF) +class DLRMInputBlock(TabularInputBlock): + """Input block for DLRM model. + + Parameters + ---------- + schema : Schema, optional + The schema to use for selection. Default is None. + dim : int + The dimensionality of the output vectors. + bottom_block : Block + Block to pass the continuous features to. + Note that, the output dimensionality of this block must be equal to ``dim``. + + {dlrm_reference} + + Raises + ------ + ValueError + If no categorical input is provided in the schema. + + """ + + def __init__(self, schema: Schema, dim: int, bottom_block: Block): + super().__init__(schema) + self.add_route(Tags.CATEGORICAL, EmbeddingTables(dim, seq_combiner="mean")) + self.add_route(Tags.CONTINUOUS, bottom_block) + + if "categorical" not in self: + raise ValueError("DLRMInputBlock must have a categorical input") + + +@docstring_parameter(dlrm_reference=_DLRM_REF) +class DLRMInteraction(nn.Module): + """ + This class defines the forward interaction operation as proposed + in the DLRM + `paper https://arxiv.org/pdf/1906.00091.pdf`_ [1]_. + + This forward operation performs elementwise multiplication + followed by a reduction sum (equivalent to a dot product) of all embedding pairs. + + {dlrm_reference} + + """ + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + if not hasattr(self, "triu_indices"): + self.register_buffer( + "triu_indices", torch.triu_indices(inputs.shape[1], inputs.shape[1], offset=1) + ) + + interactions = torch.bmm(inputs, torch.transpose(inputs, 1, 2)) + interactions_flat = interactions[:, self.triu_indices[0], self.triu_indices[1]] + + return interactions_flat + + +class ShortcutConcatContinuous(Link): + """ + A shortcut connection that concatenates + continuous input features and intermediate outputs. + + When there's no continuous input, the intermediate output is returned. + """ + + def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor: + intermediate_output = self.output(inputs) + + if "continuous" in inputs: + return torch.cat((inputs["continuous"], intermediate_output), dim=1) + + return intermediate_output + + +@docstring_parameter(dlrm_reference=_DLRM_REF) +class DLRMBlock(Block): + """Builds the DLRM architecture, as proposed in the following + `paper https://arxiv.org/pdf/1906.00091.pdf`_ [1]_. + + Parameters + ---------- + schema : Schema, optional + The schema to use for selection. Default is None. + dim : int + The dimensionality of the output vectors. + bottom_block : Block + Block to pass the continuous features to. + Note that, the output dimensionality of this block must be equal to ``dim``. + top_block : Block, optional + An optional upper-level block of the model. + interaction : nn.Module, optional + Interaction module for DLRM. + If not provided, DLRMInteraction will be used by default. + + {dlrm_reference} + + Raises + ------ + ValueError + If no categorical input is provided in the schema. + """ + + def __init__( + self, + schema: Schema, + dim: int, + bottom_block: Block, + top_block: Optional[Block] = None, + interaction: Optional[nn.Module] = None, + ): + super().__init__(DLRMInputBlock(schema, dim, bottom_block)) + + self.append( + Block(MaybeAgg(Stack(dim=1)), interaction or DLRMInteraction()), + link=ShortcutConcatContinuous(), + ) + + if top_block: + self.append(top_block) diff --git a/merlin/models/torch/router.py b/merlin/models/torch/router.py index 29126e0f91..326064c68c 100644 --- a/merlin/models/torch/router.py +++ b/merlin/models/torch/router.py @@ -88,6 +88,9 @@ def add_route( """ routing_module = schema.select(self.selectable, selection) + if not routing_module: + return self + if module is not None: schema.setup_schema(module, routing_module.schema) diff --git a/merlin/models/torch/transforms/agg.py b/merlin/models/torch/transforms/agg.py index f6dcf457d6..552fcf1d32 100644 --- a/merlin/models/torch/transforms/agg.py +++ b/merlin/models/torch/transforms/agg.py @@ -104,6 +104,9 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor: if self.align_dims: max_dims = max(tensor.dim() for tensor in sorted_tensors) + max_dims = max( + max_dims, 2 + ) # assume first dimension is batch size + at least one feature _sorted_tensors = [] for tensor in sorted_tensors: if tensor.dim() < max_dims: diff --git a/tests/unit/torch/blocks/test_dlrm.py b/tests/unit/torch/blocks/test_dlrm.py new file mode 100644 index 0000000000..21d65e8561 --- /dev/null +++ b/tests/unit/torch/blocks/test_dlrm.py @@ -0,0 +1,112 @@ +import math + +import pytest +import torch + +import merlin.models.torch as mm +from merlin.models.torch.batch import sample_batch +from merlin.models.torch.blocks.dlrm import DLRMInputBlock, DLRMInteraction +from merlin.models.torch.utils import module_utils +from merlin.schema import Tags + + +class TestDLRMInputBlock: + def test_routes_and_output_shapes(self, testing_data): + schema = testing_data.schema + embedding_dim = 64 + block = DLRMInputBlock(schema, embedding_dim, mm.MLPBlock([embedding_dim])) + + assert isinstance(block["categorical"], mm.EmbeddingTables) + assert len(block["categorical"]) == len(schema.select_by_tag(Tags.CATEGORICAL)) + + assert isinstance(block["continuous"][0], mm.SelectKeys) + assert isinstance(block["continuous"][1], mm.MLPBlock) + + batch_size = 16 + batch = sample_batch(testing_data, batch_size=batch_size) + + outputs = module_utils.module_test(block, batch) + + for col in schema.select_by_tag(Tags.CATEGORICAL): + assert outputs[col.name].shape == (batch_size, embedding_dim) + assert outputs["continuous"].shape == (batch_size, embedding_dim) + + +class TestDLRMInteraction: + @pytest.mark.parametrize( + "batch_size,num_features,dim", + [(16, 3, 3), (32, 5, 8), (64, 5, 4)], + ) + def test_output_shape(self, batch_size, num_features, dim): + module = DLRMInteraction() + inputs = torch.rand((batch_size, num_features, dim)) + outputs = module_utils.module_test(module, inputs) + + assert outputs.shape == (batch_size, num_features - 1 + math.comb(num_features - 1, 2)) + + +class TestDLRMBlock: + @pytest.fixture(autouse=True) + def setup_method(self, testing_data): + self.schema = testing_data.schema + self.batch_size = 16 + self.batch = sample_batch(testing_data, batch_size=self.batch_size) + + def test_dlrm_output_shape(self): + embedding_dim = 64 + block = mm.DLRMBlock( + self.schema, + dim=embedding_dim, + bottom_block=mm.MLPBlock([embedding_dim]), + ) + + outputs = module_utils.module_test(block, self.batch) + + num_features = len(self.schema.select_by_tag(Tags.CATEGORICAL)) + 1 + dot_product_dim = (num_features - 1) * num_features // 2 + assert list(outputs.shape) == [self.batch_size, dot_product_dim + embedding_dim] + + def test_dlrm_with_top_block(self): + embedding_dim = 32 + top_block_dim = 8 + block = mm.DLRMBlock( + self.schema, + dim=embedding_dim, + bottom_block=mm.MLPBlock([embedding_dim]), + top_block=mm.MLPBlock([top_block_dim]), + ) + + outputs = module_utils.module_test(block, self.batch) + + assert list(outputs.shape) == [self.batch_size, top_block_dim] + + def test_dlrm_block_no_categorical_features(self): + schema = self.schema.remove_by_tag(Tags.CATEGORICAL) + embedding_dim = 32 + + with pytest.raises(ValueError, match="must have a categorical input"): + _ = mm.DLRMBlock( + schema, + dim=embedding_dim, + bottom_block=mm.MLPBlock([embedding_dim]), + ) + + def test_dlrm_block_no_continuous_features(self, testing_data): + schema = testing_data.schema.remove_by_tag(Tags.CONTINUOUS) + testing_data.schema = schema + + embedding_dim = 32 + block = mm.DLRMBlock( + schema, + dim=embedding_dim, + bottom_block=mm.MLPBlock([embedding_dim]), + ) + + batch_size = 16 + batch = sample_batch(testing_data, batch_size=batch_size) + + outputs = module_utils.module_test(block, batch) + + num_features = len(schema.select_by_tag(Tags.CATEGORICAL)) + dot_product_dim = (num_features - 1) * num_features // 2 + assert list(outputs.shape) == [batch_size, dot_product_dim] diff --git a/tests/unit/torch/models/test_base.py b/tests/unit/torch/models/test_base.py index cedd3e6ff6..ab329b8ca1 100644 --- a/tests/unit/torch/models/test_base.py +++ b/tests/unit/torch/models/test_base.py @@ -133,11 +133,11 @@ def test_training_step_with_dataloader(self): mm.BinaryOutput(ColumnSchema("target")), ) - feature = [[1.0, 2.0], [3.0, 4.0]] - target = [[0.0], [1.0]] + feature = [2.0, 3.0] + target = [0.0, 1.0] dataset = Dataset(pd.DataFrame({"feature": feature, "target": target})) - with Loader(dataset, batch_size=1) as loader: + with Loader(dataset, batch_size=2) as loader: model.initialize(loader) batch = loader.peek() diff --git a/tests/unit/torch/outputs/test_tabular.py b/tests/unit/torch/outputs/test_tabular.py index b3dd2abe4b..22ae132735 100644 --- a/tests/unit/torch/outputs/test_tabular.py +++ b/tests/unit/torch/outputs/test_tabular.py @@ -29,6 +29,8 @@ def test_exceptions(self): with pytest.raises(ValueError, match="not found"): mm.TabularOutputBlock(self.schema, init="not_found") + def test_no_route_for_non_existent_tag(self): outputs = mm.TabularOutputBlock(self.schema) - with pytest.raises(ValueError): - outputs.add_route(Tags.CATEGORICAL) + outputs.add_route(Tags.CATEGORICAL) + + assert not outputs