* First pass over DLRM-related blocks
* Add a test for the DLRM block and fixes to make the test pass
* Add unit tests
* Initialize default metrics/loss inside ModelOutput instead
* Update merlin/models/torch/blocks/dlrm.py (co-authored by Radek Osmulski)
* Fixes and changes for failing tests
* Pass batch to model test

Co-authored-by: Marc Romeyn <[email protected]>
Co-authored-by: Radek Osmulski <[email protected]>
Commit 86d0a34 (parent 25f98f1): 7 changed files with 268 additions and 5 deletions.
merlin/models/torch/blocks/dlrm.py (new file, 141 additions)
from typing import Dict, Optional

import torch
from torch import nn

from merlin.models.torch.block import Block
from merlin.models.torch.inputs.embedding import EmbeddingTables
from merlin.models.torch.inputs.tabular import TabularInputBlock
from merlin.models.torch.link import Link
from merlin.models.torch.transforms.agg import MaybeAgg, Stack
from merlin.models.utils.doc_utils import docstring_parameter
from merlin.schema import Schema, Tags

_DLRM_REF = """
References
----------
.. [1] Naumov, Maxim, et al. "Deep learning recommendation model for
   personalization and recommendation systems." arXiv preprint arXiv:1906.00091 (2019).
"""

@docstring_parameter(dlrm_reference=_DLRM_REF)
class DLRMInputBlock(TabularInputBlock):
    """Input block for the DLRM model.

    Parameters
    ----------
    schema : Schema
        The schema to use for selection.
    dim : int
        The dimensionality of the output vectors.
    bottom_block : Block
        Block to pass the continuous features to.
        Note that the output dimensionality of this block must be equal to ``dim``.

    {dlrm_reference}

    Raises
    ------
    ValueError
        If no categorical input is provided in the schema.
    """

    def __init__(self, schema: Schema, dim: int, bottom_block: Block):
        super().__init__(schema)
        self.add_route(Tags.CATEGORICAL, EmbeddingTables(dim, seq_combiner="mean"))
        self.add_route(Tags.CONTINUOUS, bottom_block)

        if "categorical" not in self:
            raise ValueError("DLRMInputBlock must have a categorical input")
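A minimal usage sketch (not part of the commit; it assumes a merlin Schema named `schema` with both categorical and continuous columns, and uses mm.MLPBlock as the bottom block, as the unit tests below do):

import merlin.models.torch as mm
from merlin.models.torch.blocks.dlrm import DLRMInputBlock

embedding_dim = 64
# `schema` is assumed to come from a merlin Dataset and to contain both
# categorical and continuous columns.
input_block = DLRMInputBlock(
    schema,
    dim=embedding_dim,
    bottom_block=mm.MLPBlock([embedding_dim]),  # output width must equal `dim`
)
# Categorical features are embedded to `embedding_dim`; continuous features
# are routed through the bottom MLP, which also outputs `embedding_dim`.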

@docstring_parameter(dlrm_reference=_DLRM_REF)
class DLRMInteraction(nn.Module):
    """
    This class defines the forward interaction operation as proposed
    in the DLRM `paper <https://arxiv.org/pdf/1906.00091.pdf>`_ [1]_.

    This forward operation performs elementwise multiplication
    followed by a reduction sum (equivalent to a dot product) of all embedding pairs.

    {dlrm_reference}
    """

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        if not hasattr(self, "triu_indices"):
            # Lazily register the upper-triangular indices (above the diagonal)
            # on the first forward call, once the number of features is known.
            self.register_buffer(
                "triu_indices", torch.triu_indices(inputs.shape[1], inputs.shape[1], offset=1)
            )

        # Pairwise dot products between all feature embeddings: (batch, features, features)
        interactions = torch.bmm(inputs, torch.transpose(inputs, 1, 2))
        # Keep each unordered pair exactly once: (batch, features * (features - 1) / 2)
        interactions_flat = interactions[:, self.triu_indices[0], self.triu_indices[1]]

        return interactions_flat
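A self-contained sketch of the same computation in plain torch (hypothetical shapes, not part of the commit):

import torch

batch_size, num_features, dim = 2, 4, 8
inputs = torch.rand(batch_size, num_features, dim)

# Same arithmetic as DLRMInteraction.forward:
interactions = torch.bmm(inputs, inputs.transpose(1, 2))  # (2, 4, 4) pairwise dot products
triu = torch.triu_indices(num_features, num_features, offset=1)
flat = interactions[:, triu[0], triu[1]]                  # keep each unordered pair once

# 4 features -> 4 * 3 / 2 = 6 unique pairs
assert flat.shape == (batch_size, num_features * (num_features - 1) // 2)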

class ShortcutConcatContinuous(Link):
    """
    A shortcut connection that concatenates
    continuous input features and intermediate outputs.

    When there is no continuous input, the intermediate output is returned.
    """

    def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        intermediate_output = self.output(inputs)

        if "continuous" in inputs:
            return torch.cat((inputs["continuous"], intermediate_output), dim=1)

        return intermediate_output
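A tiny shape sketch of the shortcut concatenation (hypothetical tensors, not part of the commit):

import torch

continuous = torch.rand(16, 64)    # bottom-MLP output for the batch
interactions = torch.rand(16, 10)  # flattened pairwise dot products
shortcut = torch.cat((continuous, interactions), dim=1)
assert shortcut.shape == (16, 74)  # feature widths add along dim=1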

@docstring_parameter(dlrm_reference=_DLRM_REF)
class DLRMBlock(Block):
    """Builds the DLRM architecture, as proposed in the following
    `paper <https://arxiv.org/pdf/1906.00091.pdf>`_ [1]_.

    Parameters
    ----------
    schema : Schema
        The schema to use for selection.
    dim : int
        The dimensionality of the output vectors.
    bottom_block : Block
        Block to pass the continuous features to.
        Note that the output dimensionality of this block must be equal to ``dim``.
    top_block : Block, optional
        An optional upper-level block of the model.
    interaction : nn.Module, optional
        Interaction module for DLRM.
        If not provided, DLRMInteraction will be used by default.

    {dlrm_reference}

    Raises
    ------
    ValueError
        If no categorical input is provided in the schema.
    """

    def __init__(
        self,
        schema: Schema,
        dim: int,
        bottom_block: Block,
        top_block: Optional[Block] = None,
        interaction: Optional[nn.Module] = None,
    ):
        super().__init__(DLRMInputBlock(schema, dim, bottom_block))

        self.append(
            Block(MaybeAgg(Stack(dim=1)), interaction or DLRMInteraction()),
            link=ShortcutConcatContinuous(),
        )

        if top_block:
            self.append(top_block)
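An end-to-end sketch (not part of the commit; `schema` is an illustrative assumption, and the output-width arithmetic mirrors the unit tests below):

import merlin.models.torch as mm

embedding_dim, top_dim = 64, 8

block = mm.DLRMBlock(
    schema,
    dim=embedding_dim,
    bottom_block=mm.MLPBlock([embedding_dim]),
    top_block=mm.MLPBlock([top_dim]),  # optional; omit to expose raw interactions
)

# Without a top block, the output width is
#   num_features * (num_features - 1) // 2   (pairwise dot products)
# + embedding_dim                            (continuous shortcut),
# where num_features = number of categorical columns + 1, because the
# continuous features are stacked into a single `embedding_dim` vector.
# With the top block above, the output width is simply `top_dim`.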
Unit tests for the DLRM block (new file, 112 additions)
import math

import pytest
import torch

import merlin.models.torch as mm
from merlin.models.torch.batch import sample_batch
from merlin.models.torch.blocks.dlrm import DLRMInputBlock, DLRMInteraction
from merlin.models.torch.utils import module_utils
from merlin.schema import Tags

class TestDLRMInputBlock:
    def test_routes_and_output_shapes(self, testing_data):
        schema = testing_data.schema
        embedding_dim = 64
        block = DLRMInputBlock(schema, embedding_dim, mm.MLPBlock([embedding_dim]))

        assert isinstance(block["categorical"], mm.EmbeddingTables)
        assert len(block["categorical"]) == len(schema.select_by_tag(Tags.CATEGORICAL))

        assert isinstance(block["continuous"][0], mm.SelectKeys)
        assert isinstance(block["continuous"][1], mm.MLPBlock)

        batch_size = 16
        batch = sample_batch(testing_data, batch_size=batch_size)

        outputs = module_utils.module_test(block, batch)

        for col in schema.select_by_tag(Tags.CATEGORICAL):
            assert outputs[col.name].shape == (batch_size, embedding_dim)
        assert outputs["continuous"].shape == (batch_size, embedding_dim)

class TestDLRMInteraction:
    @pytest.mark.parametrize(
        "batch_size,num_features,dim",
        [(16, 3, 3), (32, 5, 8), (64, 5, 4)],
    )
    def test_output_shape(self, batch_size, num_features, dim):
        module = DLRMInteraction()
        inputs = torch.rand((batch_size, num_features, dim))
        outputs = module_utils.module_test(module, inputs)

        # (num_features - 1) + C(num_features - 1, 2) == C(num_features, 2),
        # i.e. one output per unordered pair of features.
        assert outputs.shape == (batch_size, num_features - 1 + math.comb(num_features - 1, 2))

class TestDLRMBlock:
    @pytest.fixture(autouse=True)
    def setup_method(self, testing_data):
        self.schema = testing_data.schema
        self.batch_size = 16
        self.batch = sample_batch(testing_data, batch_size=self.batch_size)

    def test_dlrm_output_shape(self):
        embedding_dim = 64
        block = mm.DLRMBlock(
            self.schema,
            dim=embedding_dim,
            bottom_block=mm.MLPBlock([embedding_dim]),
        )

        outputs = module_utils.module_test(block, self.batch)

        # The stacked continuous vector counts as one extra interacting feature.
        num_features = len(self.schema.select_by_tag(Tags.CATEGORICAL)) + 1
        dot_product_dim = (num_features - 1) * num_features // 2
        assert list(outputs.shape) == [self.batch_size, dot_product_dim + embedding_dim]

    def test_dlrm_with_top_block(self):
        embedding_dim = 32
        top_block_dim = 8
        block = mm.DLRMBlock(
            self.schema,
            dim=embedding_dim,
            bottom_block=mm.MLPBlock([embedding_dim]),
            top_block=mm.MLPBlock([top_block_dim]),
        )

        outputs = module_utils.module_test(block, self.batch)

        assert list(outputs.shape) == [self.batch_size, top_block_dim]

    def test_dlrm_block_no_categorical_features(self):
        schema = self.schema.remove_by_tag(Tags.CATEGORICAL)
        embedding_dim = 32

        with pytest.raises(ValueError, match="must have a categorical input"):
            _ = mm.DLRMBlock(
                schema,
                dim=embedding_dim,
                bottom_block=mm.MLPBlock([embedding_dim]),
            )

    def test_dlrm_block_no_continuous_features(self, testing_data):
        schema = testing_data.schema.remove_by_tag(Tags.CONTINUOUS)
        testing_data.schema = schema

        embedding_dim = 32
        block = mm.DLRMBlock(
            schema,
            dim=embedding_dim,
            bottom_block=mm.MLPBlock([embedding_dim]),
        )

        batch_size = 16
        batch = sample_batch(testing_data, batch_size=batch_size)

        outputs = module_utils.module_test(block, batch)

        # No continuous shortcut here, so the output is just the pairwise dot products.
        num_features = len(schema.select_by_tag(Tags.CATEGORICAL))
        dot_product_dim = (num_features - 1) * num_features // 2
        assert list(outputs.shape) == [batch_size, dot_product_dim]
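As a sketch, the output-width arithmetic the assertions above rely on could be factored into a small helper (hypothetical; not part of the commit):

import math
from typing import Optional

def expected_dlrm_width(
    num_categorical: int,
    embedding_dim: int,
    has_continuous: bool,
    top_dim: Optional[int] = None,
) -> int:
    """Mirror the shape assertions in the tests above."""
    if top_dim is not None:
        return top_dim  # the top MLP defines the final width
    # The stacked continuous features count as one extra interacting feature.
    num_features = num_categorical + (1 if has_continuous else 0)
    width = math.comb(num_features, 2)  # pairwise dot products
    if has_continuous:
        width += embedding_dim  # shortcut concatenation of the continuous vector
    return width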