-
Notifications
You must be signed in to change notification settings - Fork 50
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add DLRM Model * make model a class rather than a function --------- Co-authored-by: Marc Romeyn <[email protected]>
- Loading branch information
1 parent
593d78d
commit 1782b44
Showing
6 changed files
with
144 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
from typing import Optional | ||
|
||
from torch import nn | ||
|
||
from merlin.models.torch.block import Block | ||
from merlin.models.torch.blocks.dlrm import DLRMBlock | ||
from merlin.models.torch.models.base import Model | ||
from merlin.models.torch.outputs.tabular import TabularOutputBlock | ||
from merlin.schema import Schema | ||
|
||
|
||
class DLRMModel(Model):
    """Deep Learning Recommendation Model (DLRM), following Naumov et al. [1].

    The model is assembled from a ``DLRMBlock`` body followed by an output block.

    Parameters
    ----------
    schema : Schema
        The schema to use for selection.
    dim : int
        The dimensionality of the output vectors.
    bottom_block : Block
        Block that the continuous features are passed to.
        Its output dimensionality must be equal to ``dim``.
    top_block : Block, optional
        An optional upper-level block of the model.
    interaction : nn.Module, optional
        Interaction module for DLRM.
        When omitted, DLRMInteraction is used by default.
    output_block : Block, optional
        The output block of the model, by default None.
        When None, a ``TabularOutputBlock`` built from ``schema`` with
        ``init="defaults"`` is used.

    Example usage
    -------------
    >>> model = mm.DLRMModel(
    ...     schema,
    ...     dim=64,
    ...     bottom_block=mm.MLPBlock([256, 64]),
    ...     output_block=mm.BinaryOutput(ColumnSchema("target")))
    >>> trainer = pl.Trainer()
    >>> model.initialize(dataloader)
    >>> trainer.fit(model, dataloader)

    References
    ----------
    [1] Naumov, Maxim, et al. "Deep learning recommendation model for
        personalization and recommendation systems." arXiv preprint arXiv:1906.00091 (2019).
    """

    def __init__(
        self,
        schema: Schema,
        dim: int,
        bottom_block: Block,
        top_block: Optional[Block] = None,
        interaction: Optional[nn.Module] = None,
        output_block: Optional[Block] = None,
    ) -> None:
        # Resolve the head before building the body so that modules are
        # created in the same order regardless of which branch is taken.
        head = (
            TabularOutputBlock(schema, init="defaults")
            if output_block is None
            else output_block
        )

        super().__init__(
            DLRMBlock(
                schema,
                dim,
                bottom_block,
                top_block=top_block,
                interaction=interaction,
            ),
            head,
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import pytest | ||
import pytorch_lightning as pl | ||
|
||
import merlin.models.torch as mm | ||
from merlin.dataloader.torch import Loader | ||
from merlin.models.torch.batch import sample_batch | ||
from merlin.models.torch.utils import module_utils | ||
from merlin.schema import ColumnSchema | ||
|
||
|
||
@pytest.mark.parametrize("output_block", [None, mm.BinaryOutput(ColumnSchema("click"))])
class TestDLRMModel:
    """Tests for ``mm.DLRMModel``, run once per parametrized ``output_block`` value."""

    def test_train_dlrm_with_lightning_loader(
        self, music_streaming_data, output_block, dim=2, batch_size=16
    ):
        """Fit a small DLRM for one epoch with Lightning, then run ``module_test`` on it."""
        # Restrict the dataset to the columns this test works with.
        columns = ["item_id", "user_id", "user_age", "item_genres", "click"]
        schema = music_streaming_data.schema.select_by_name(columns)
        music_streaming_data.schema = schema

        model = mm.DLRMModel(
            schema,
            dim=dim,
            bottom_block=mm.MLPBlock([4, 2]),
            top_block=mm.MLPBlock([4, 2]),
            output_block=output_block,
        )

        trainer = pl.Trainer(max_epochs=1, devices=1)

        with Loader(music_streaming_data, batch_size=batch_size) as loader:
            model.initialize(loader)
            trainer.fit(model, loader)

        train_loss = trainer.logged_metrics["train_loss"]
        assert train_loss > 0.0

        # Run the trained model through module_test on a freshly sampled batch.
        batch = sample_batch(music_streaming_data, batch_size)
        module_utils.module_test(model, batch)