Add DLRM block (#1162)
* First pass over DLRM-related blocks

* add a test for dlrm block and fixes to make test pass

* Add unit tests

* Initialize default metrics/loss inside ModelOutput instead

* Update merlin/models/torch/blocks/dlrm.py

Co-authored-by: Radek Osmulski <[email protected]>

* fixes and changes for failing tests

* pass batch to model test

---------

Co-authored-by: Marc Romeyn <[email protected]>
Co-authored-by: Radek Osmulski <[email protected]>
3 people authored Jul 1, 2023
1 parent 25f98f1 commit 86d0a34
Showing 7 changed files with 268 additions and 5 deletions.
2 changes: 2 additions & 0 deletions merlin/models/torch/__init__.py
@@ -17,6 +17,7 @@
 from merlin.models.torch import schema
 from merlin.models.torch.batch import Batch, Sequence
 from merlin.models.torch.block import Block, ParallelBlock
+from merlin.models.torch.blocks.dlrm import DLRMBlock
 from merlin.models.torch.blocks.mlp import MLPBlock
 from merlin.models.torch.inputs.embedding import EmbeddingTable, EmbeddingTables
 from merlin.models.torch.inputs.select import SelectFeatures, SelectKeys
@@ -51,4 +52,5 @@
"Concat",
"Stack",
"schema",
"DLRMBlock",
]
141 changes: 141 additions & 0 deletions merlin/models/torch/blocks/dlrm.py
@@ -0,0 +1,141 @@
from typing import Dict, Optional

import torch
from torch import nn

from merlin.models.torch.block import Block
from merlin.models.torch.inputs.embedding import EmbeddingTables
from merlin.models.torch.inputs.tabular import TabularInputBlock
from merlin.models.torch.link import Link
from merlin.models.torch.transforms.agg import MaybeAgg, Stack
from merlin.models.utils.doc_utils import docstring_parameter
from merlin.schema import Schema, Tags

_DLRM_REF = """
References
----------
.. [1] Naumov, Maxim, et al. "Deep learning recommendation model for
personalization and recommendation systems." arXiv preprint arXiv:1906.00091 (2019).
"""


@docstring_parameter(dlrm_reference=_DLRM_REF)
class DLRMInputBlock(TabularInputBlock):
"""Input block for DLRM model.
Parameters
----------
schema : Schema
    The schema to use for selection.
dim : int
The dimensionality of the output vectors.
bottom_block : Block
Block to pass the continuous features to.
    Note that the output dimensionality of this block must be equal to ``dim``.
{dlrm_reference}
Raises
------
ValueError
If no categorical input is provided in the schema.
"""

def __init__(self, schema: Schema, dim: int, bottom_block: Block):
super().__init__(schema)
self.add_route(Tags.CATEGORICAL, EmbeddingTables(dim, seq_combiner="mean"))
self.add_route(Tags.CONTINUOUS, bottom_block)

if "categorical" not in self:
raise ValueError("DLRMInputBlock must have a categorical input")


@docstring_parameter(dlrm_reference=_DLRM_REF)
class DLRMInteraction(nn.Module):
"""
This class defines the forward interaction operation as proposed
in the DLRM
`paper <https://arxiv.org/pdf/1906.00091.pdf>`_ [1]_.
This forward operation performs elementwise multiplication
followed by a reduction sum (equivalent to a dot product) of all embedding pairs.
{dlrm_reference}
"""

def forward(self, inputs: torch.Tensor) -> torch.Tensor:
if not hasattr(self, "triu_indices"):
self.register_buffer(
"triu_indices", torch.triu_indices(inputs.shape[1], inputs.shape[1], offset=1)
)

interactions = torch.bmm(inputs, torch.transpose(inputs, 1, 2))
interactions_flat = interactions[:, self.triu_indices[0], self.triu_indices[1]]

return interactions_flat


class ShortcutConcatContinuous(Link):
"""
A shortcut connection that concatenates
continuous input features and intermediate outputs.
When there's no continuous input, the intermediate output is returned.
"""

def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
intermediate_output = self.output(inputs)

if "continuous" in inputs:
return torch.cat((inputs["continuous"], intermediate_output), dim=1)

return intermediate_output


@docstring_parameter(dlrm_reference=_DLRM_REF)
class DLRMBlock(Block):
"""Builds the DLRM architecture, as proposed in the following
`paper <https://arxiv.org/pdf/1906.00091.pdf>`_ [1]_.
Parameters
----------
schema : Schema
    The schema to use for selection.
dim : int
The dimensionality of the output vectors.
bottom_block : Block
Block to pass the continuous features to.
    Note that the output dimensionality of this block must be equal to ``dim``.
top_block : Block, optional
An optional upper-level block of the model.
interaction : nn.Module, optional
Interaction module for DLRM.
If not provided, DLRMInteraction will be used by default.
{dlrm_reference}
Raises
------
ValueError
If no categorical input is provided in the schema.
"""

def __init__(
self,
schema: Schema,
dim: int,
bottom_block: Block,
top_block: Optional[Block] = None,
interaction: Optional[nn.Module] = None,
):
super().__init__(DLRMInputBlock(schema, dim, bottom_block))

self.append(
Block(MaybeAgg(Stack(dim=1)), interaction or DLRMInteraction()),
link=ShortcutConcatContinuous(),
)

if top_block:
self.append(top_block)
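
To make the pairwise-interaction arithmetic concrete, here is a minimal usage sketch of DLRMInteraction (the shapes are illustrative, not taken from the commit): stacking num_features embeddings of size dim yields one scalar per unordered pair of embeddings, i.e. C(num_features, 2) outputs.

import math

import torch

from merlin.models.torch.blocks.dlrm import DLRMInteraction

batch_size, num_features, dim = 4, 5, 8  # illustrative sizes
inputs = torch.rand(batch_size, num_features, dim)

interaction = DLRMInteraction()
outputs = interaction(inputs)

# One scalar per unordered pair of embeddings: C(5, 2) = 10
assert outputs.shape == (batch_size, math.comb(num_features, 2))
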
3 changes: 3 additions & 0 deletions merlin/models/torch/router.py
@@ -88,6 +88,9 @@ def add_route(
"""

routing_module = schema.select(self.selectable, selection)
if not routing_module:
return self

if module is not None:
schema.setup_schema(module, routing_module.schema)

3 changes: 3 additions & 0 deletions merlin/models/torch/transforms/agg.py
@@ -104,6 +104,9 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:

         if self.align_dims:
             max_dims = max(tensor.dim() for tensor in sorted_tensors)
+            max_dims = max(
+                max_dims, 2
+            )  # assume first dimension is batch size + at least one feature
             _sorted_tensors = []
             for tensor in sorted_tensors:
                 if tensor.dim() < max_dims:
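
A small sketch of the alignment rule this floor enforces (the tensor shapes are hypothetical, and the padding loop itself is elided in the hunk above): a 1-D (batch,) tensor is expanded with trailing singleton dims so it can be aggregated next to (batch, dim) features.

import torch

batch = 16
scalar_feature = torch.rand(batch)     # 1-D: batch dimension only
vector_feature = torch.rand(batch, 8)  # 2-D: batch x feature dim

max_dims = max(scalar_feature.dim(), vector_feature.dim())
max_dims = max(max_dims, 2)            # the new floor of 2

aligned = scalar_feature
while aligned.dim() < max_dims:
    aligned = aligned.unsqueeze(-1)    # assumed padding strategy

assert aligned.shape == (batch, 1)
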
112 changes: 112 additions & 0 deletions tests/unit/torch/blocks/test_dlrm.py
@@ -0,0 +1,112 @@
import math

import pytest
import torch

import merlin.models.torch as mm
from merlin.models.torch.batch import sample_batch
from merlin.models.torch.blocks.dlrm import DLRMInputBlock, DLRMInteraction
from merlin.models.torch.utils import module_utils
from merlin.schema import Tags


class TestDLRMInputBlock:
def test_routes_and_output_shapes(self, testing_data):
schema = testing_data.schema
embedding_dim = 64
block = DLRMInputBlock(schema, embedding_dim, mm.MLPBlock([embedding_dim]))

assert isinstance(block["categorical"], mm.EmbeddingTables)
assert len(block["categorical"]) == len(schema.select_by_tag(Tags.CATEGORICAL))

assert isinstance(block["continuous"][0], mm.SelectKeys)
assert isinstance(block["continuous"][1], mm.MLPBlock)

batch_size = 16
batch = sample_batch(testing_data, batch_size=batch_size)

outputs = module_utils.module_test(block, batch)

for col in schema.select_by_tag(Tags.CATEGORICAL):
assert outputs[col.name].shape == (batch_size, embedding_dim)
assert outputs["continuous"].shape == (batch_size, embedding_dim)


class TestDLRMInteraction:
@pytest.mark.parametrize(
"batch_size,num_features,dim",
[(16, 3, 3), (32, 5, 8), (64, 5, 4)],
)
def test_output_shape(self, batch_size, num_features, dim):
module = DLRMInteraction()
inputs = torch.rand((batch_size, num_features, dim))
outputs = module_utils.module_test(module, inputs)

assert outputs.shape == (batch_size, num_features - 1 + math.comb(num_features - 1, 2))


class TestDLRMBlock:
@pytest.fixture(autouse=True)
def setup_method(self, testing_data):
self.schema = testing_data.schema
self.batch_size = 16
self.batch = sample_batch(testing_data, batch_size=self.batch_size)

def test_dlrm_output_shape(self):
embedding_dim = 64
block = mm.DLRMBlock(
self.schema,
dim=embedding_dim,
bottom_block=mm.MLPBlock([embedding_dim]),
)

outputs = module_utils.module_test(block, self.batch)

num_features = len(self.schema.select_by_tag(Tags.CATEGORICAL)) + 1
dot_product_dim = (num_features - 1) * num_features // 2
assert list(outputs.shape) == [self.batch_size, dot_product_dim + embedding_dim]

def test_dlrm_with_top_block(self):
embedding_dim = 32
top_block_dim = 8
block = mm.DLRMBlock(
self.schema,
dim=embedding_dim,
bottom_block=mm.MLPBlock([embedding_dim]),
top_block=mm.MLPBlock([top_block_dim]),
)

outputs = module_utils.module_test(block, self.batch)

assert list(outputs.shape) == [self.batch_size, top_block_dim]

def test_dlrm_block_no_categorical_features(self):
schema = self.schema.remove_by_tag(Tags.CATEGORICAL)
embedding_dim = 32

with pytest.raises(ValueError, match="must have a categorical input"):
_ = mm.DLRMBlock(
schema,
dim=embedding_dim,
bottom_block=mm.MLPBlock([embedding_dim]),
)

def test_dlrm_block_no_continuous_features(self, testing_data):
schema = testing_data.schema.remove_by_tag(Tags.CONTINUOUS)
testing_data.schema = schema

embedding_dim = 32
block = mm.DLRMBlock(
schema,
dim=embedding_dim,
bottom_block=mm.MLPBlock([embedding_dim]),
)

batch_size = 16
batch = sample_batch(testing_data, batch_size=batch_size)

outputs = module_utils.module_test(block, batch)

num_features = len(schema.select_by_tag(Tags.CATEGORICAL))
dot_product_dim = (num_features - 1) * num_features // 2
assert list(outputs.shape) == [batch_size, dot_product_dim]
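
A worked instance of the shape arithmetic used in test_dlrm_output_shape above (the categorical-feature count is hypothetical, since it depends on testing_data):

import math

num_categorical = 3                    # hypothetical feature count
num_features = num_categorical + 1     # + one bottom_block output vector
dot_product_dim = (num_features - 1) * num_features // 2  # C(4, 2) = 6
embedding_dim = 64

# ShortcutConcatContinuous appends the continuous bottom-block output,
# so the final width is dot_product_dim + embedding_dim = 70.
assert dot_product_dim == math.comb(num_features, 2)
assert dot_product_dim + embedding_dim == 70
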
6 changes: 3 additions & 3 deletions tests/unit/torch/models/test_base.py
@@ -133,11 +133,11 @@ def test_training_step_with_dataloader(self):
             mm.BinaryOutput(ColumnSchema("target")),
         )
 
-        feature = [[1.0, 2.0], [3.0, 4.0]]
-        target = [[0.0], [1.0]]
+        feature = [2.0, 3.0]
+        target = [0.0, 1.0]
         dataset = Dataset(pd.DataFrame({"feature": feature, "target": target}))
 
-        with Loader(dataset, batch_size=1) as loader:
+        with Loader(dataset, batch_size=2) as loader:
             model.initialize(loader)
             batch = loader.peek()
 
6 changes: 4 additions & 2 deletions tests/unit/torch/outputs/test_tabular.py
@@ -29,6 +29,8 @@ def test_exceptions(self):
         with pytest.raises(ValueError, match="not found"):
             mm.TabularOutputBlock(self.schema, init="not_found")
 
-        with pytest.raises(ValueError):
-            outputs.add_route(Tags.CATEGORICAL)
+    def test_no_route_for_non_existent_tag(self):
+        outputs = mm.TabularOutputBlock(self.schema)
+        outputs.add_route(Tags.CATEGORICAL)
 
+        assert not outputs
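
Taken together with the router change above, this test pins down the new contract: routing on a tag with no matching columns is now a silent no-op instead of a ValueError. A minimal sketch, assuming a schema whose only column is continuous (the column here is invented for illustration):

import merlin.models.torch as mm
from merlin.schema import ColumnSchema, Schema, Tags

schema = Schema([ColumnSchema("price", tags=[Tags.CONTINUOUS])])  # hypothetical

outputs = mm.TabularOutputBlock(schema)
outputs.add_route(Tags.CATEGORICAL)  # nothing matches, so nothing is added

assert not outputs  # the block stays empty
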
