Commit

Merge branch 'main' into move_and_remove_examples
rnyak authored Jul 3, 2023
2 parents 139597f + 1782b44 commit 3ad4363
Showing 15 changed files with 765 additions and 360 deletions.
6 changes: 5 additions & 1 deletion merlin/models/torch/__init__.py
@@ -16,13 +16,14 @@

from merlin.models.torch import schema
from merlin.models.torch.batch import Batch, Sequence
from merlin.models.torch.block import Block, ParallelBlock
from merlin.models.torch.block import Block, ParallelBlock, ResidualBlock, ShortcutBlock
from merlin.models.torch.blocks.dlrm import DLRMBlock
from merlin.models.torch.blocks.mlp import MLPBlock
from merlin.models.torch.inputs.embedding import EmbeddingTable, EmbeddingTables
from merlin.models.torch.inputs.select import SelectFeatures, SelectKeys
from merlin.models.torch.inputs.tabular import TabularInputBlock
from merlin.models.torch.models.base import Model
from merlin.models.torch.models.ranking import DLRMModel
from merlin.models.torch.outputs.base import ModelOutput
from merlin.models.torch.outputs.classification import BinaryOutput
from merlin.models.torch.outputs.regression import RegressionOutput
@@ -45,12 +46,15 @@
"ParallelBlock",
"Sequence",
"RegressionOutput",
"ResidualBlock",
"RouterBlock",
"SelectKeys",
"SelectFeatures",
"ShortcutBlock",
"TabularInputBlock",
"Concat",
"Stack",
"schema",
"DLRMBlock",
"DLRMModel",
]
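
For orientation, the newly exported names become usable straight from the package root; a minimal, hedged sketch (assumes a merlin-models install with the PyTorch backend, shown here only to illustrate the new exports):

>>> import torch.nn as nn
>>> import merlin.models.torch as mm
>>> # ResidualBlock and ShortcutBlock now sit next to Block and ParallelBlock
>>> res = mm.ResidualBlock(nn.Identity())
>>> skip = mm.ShortcutBlock(nn.Identity())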
196 changes: 174 additions & 22 deletions merlin/models/torch/block.py
@@ -25,9 +25,8 @@
from merlin.models.torch import schema
from merlin.models.torch.batch import Batch
from merlin.models.torch.container import BlockContainer, BlockContainerDict
from merlin.models.torch.link import Link, LinkType
from merlin.models.torch.registry import registry
from merlin.models.torch.utils.traversal_utils import TraversableMixin, leaf
from merlin.models.torch.utils.traversal_utils import TraversableMixin
from merlin.models.utils.registry import RegistryMixin
from merlin.schema import Schema

@@ -41,8 +40,6 @@ class Block(BlockContainer, RegistryMixin, TraversableMixin):
Variable length argument list of PyTorch modules to be contained in the block.
name : Optional[str], default = None
The name of the block. If None, no name is assigned.
track_schema : bool, default = True
If True, the schema of the output tensors are tracked.
"""

registry = registry
@@ -73,7 +70,7 @@ def forward(

return inputs

def repeat(self, n: int = 1, link: Optional[LinkType] = None, name=None) -> "Block":
def repeat(self, n: int = 1, name=None) -> "Block":
"""
Creates a new block by repeating the current block `n` times.
Each repetition is a deep copy of the current block.
@@ -97,9 +94,6 @@ def repeat(self, n: int = 1, link: Optional[LinkType] = None, name=None) -> "Block":
raise ValueError("n must be greater than 0")

repeats = [self.copy() for _ in range(n - 1)]
if link:
parsed_link = Link.parse(link)
repeats = [parsed_link.copy().setup_link(repeat) for repeat in repeats]

return Block(self, *repeats, name=name)

@@ -221,7 +215,7 @@ def forward(

return outputs

def append(self, module: nn.Module, link: Optional[LinkType] = None):
def append(self, module: nn.Module):
"""Appends a module to the post-processing stage.
Parameters
@@ -235,7 +229,7 @@ def append(self, module: nn.Module, link: Optional[LinkType] = None):
The current object itself.
"""

self.post.append(module, link=link)
self.post.append(module)

return self

@@ -244,7 +238,7 @@ def prepend(self, module: nn.Module):

return self

def append_to(self, name: str, module: nn.Module, link: Optional[LinkType] = None):
def append_to(self, name: str, module: nn.Module):
"""Appends a module to a specified branch.
Parameters
@@ -260,11 +254,11 @@ def append_to(self, name: str, module: nn.Module, link: Optional[LinkType] = None):
The current object itself.
"""

self.branches[name].append(module, link=link)
self.branches[name].append(module)

return self

def prepend_to(self, name: str, module: nn.Module, link: Optional[LinkType] = None):
def prepend_to(self, name: str, module: nn.Module):
"""Prepends a module to a specified branch.
Parameters
@@ -279,11 +273,11 @@ def prepend_to(self, name: str, module: nn.Module, link: Optional[LinkType] = None):
ParallelBlock
The current object itself.
"""
self.branches[name].prepend(module, link=link)
self.branches[name].prepend(module)

return self

def append_for_each(self, module: nn.Module, shared=False, link: Optional[LinkType] = None):
def append_for_each(self, module: nn.Module, shared=False):
"""Appends a module to each branch.
Parameters
@@ -300,11 +294,11 @@ def append_for_each(self, module: nn.Module, shared=False, link: Optional[LinkType] = None):
The current object itself.
"""

self.branches.append_for_each(module, shared=shared, link=link)
self.branches.append_for_each(module, shared=shared)

return self

def prepend_for_each(self, module: nn.Module, shared=False, link: Optional[LinkType] = None):
def prepend_for_each(self, module: nn.Module, shared=False):
"""Prepends a module to each branch.
Parameters
@@ -321,7 +315,7 @@ def prepend_for_each(self, module: nn.Module, shared=False, link: Optional[LinkType] = None):
The current object itself.
"""

self.branches.prepend_for_each(module, shared=shared, link=link)
self.branches.prepend_for_each(module, shared=shared)

return self
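
The same `link` removal applies to all of the branch-editing helpers on ParallelBlock. A hedged sketch of the updated calls (branch names and layer sizes are made up, and this assumes ParallelBlock accepts a dict of named branches as elsewhere in the library):

>>> import torch.nn as nn
>>> import merlin.models.torch as mm
>>> towers = mm.ParallelBlock({"user": mm.MLPBlock([8]), "item": mm.MLPBlock([8])})
>>> towers = towers.append_for_each(nn.LazyLinear(4))  # applied to every branch, no link=
>>> towers = towers.append_to("user", nn.ReLU())       # applied only to the named branch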

@@ -356,10 +350,7 @@ def leaf(self) -> nn.Module:
raise ValueError("Cannot call leaf() on a ParallelBlock with multiple branches")

first = list(self.branches.values())[0]
if hasattr(first, "leaf"):
return first.leaf()

return leaf(first)
return first.leaf()

def __getitem__(self, idx: Union[slice, int, str]):
if isinstance(idx, str) and idx in self.branches:
@@ -415,6 +406,167 @@ def __repr__(self) -> str:
return self._get_name() + branches


class ResidualBlock(Block):
"""
A block that applies each contained module sequentially on the input
and performs a residual connection after each module.
Parameters
----------
*module : nn.Module
Variable length argument list of PyTorch modules to be contained in the block.
name : Optional[str], default = None
The name of the block. If None, no name is assigned.
"""

def forward(self, inputs: torch.Tensor, batch: Optional[Batch] = None):
"""
Forward pass through the block. Applies each contained module sequentially on the input.
Parameters
----------
inputs : Union[torch.Tensor, Dict[str, torch.Tensor]]
The input data as a tensor or a dictionary of tensors.
batch : Optional[Batch], default = None
Optional batch of data. If provided, it is used by the `module`s.
Returns
-------
torch.Tensor or Dict[str, torch.Tensor]
The output of the block after processing the input.
"""
shortcut, outputs = inputs, inputs
for module in self.values:
outputs = shortcut + module(outputs, batch=batch)

return outputs
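
To make the residual arithmetic concrete: with a single identity layer the block returns `inputs + identity(inputs)`. A hedged sketch, where the expected value follows from the forward definition above rather than a verified run:

>>> import torch
>>> import torch.nn as nn
>>> import merlin.models.torch as mm
>>> residual = mm.ResidualBlock(nn.Identity())
>>> residual(torch.ones(1, 1))
tensor([[2.]])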


class ShortcutBlock(Block):
"""
A block with a 'shortcut' or a 'skip connection'.
The shortcut tensor can be propagated through the layers of the module or not,
depending on the value of `propagate_shortcut` argument:
If `propagate_shortcut` is True, the shortcut tensor is passed through
each layer of the module.
If `propagate_shortcut` is False, the shortcut tensor is only used as part of
the final output dictionary.
Example usage::
>>> shortcut = mm.ShortcutBlock(nn.Identity())
>>> shortcut(torch.ones(1, 1))
{'shortcut': tensor([[1.]]), 'output': tensor([[1.]])}
Parameters
----------
*module : nn.Module
Variable length argument list of PyTorch modules to be contained in the block.
name : str, optional
The name of the module, by default None.
propagate_shortcut : bool, optional
If True, propagates the shortcut tensor through the layers of this block, by default False.
shortcut_name : str, optional
The name to use for the shortcut tensor, by default "shortcut".
output_name : str, optional
The name to use for the output tensor, by default "output".
"""

def __init__(
self,
*module: nn.Module,
name: Optional[str] = None,
propagate_shortcut: bool = False,
shortcut_name: str = "shortcut",
output_name: str = "output",
):
super().__init__(*module, name=name)
self.shortcut_name = shortcut_name
self.output_name = output_name
self.propagate_shortcut = propagate_shortcut

def forward(
self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], batch: Optional[Batch] = None
) -> Dict[str, torch.Tensor]:
"""
Defines the forward propagation of the module.
Parameters
----------
inputs : Union[torch.Tensor, Dict[str, torch.Tensor]]
The input tensor or a dictionary of tensors.
batch : Batch, optional
A batch of inputs, by default None.
Returns
-------
Dict[str, torch.Tensor]
The output tensor as a dictionary.
Raises
------
RuntimeError
If the shortcut name is not found in the input dictionary, or
if the module does not return a tensor or a dictionary with a key 'output_name'.
"""

if torch.jit.isinstance(inputs, Dict[str, torch.Tensor]):
if self.shortcut_name not in inputs:
raise RuntimeError(
f"Shortcut name {self.shortcut_name} not found in inputs {inputs}"
)
shortcut = inputs[self.shortcut_name]
else:
shortcut = inputs

output = inputs
for module in self.values:
if self.propagate_shortcut:
if torch.jit.isinstance(output, Dict[str, torch.Tensor]):
module_output = module(output, batch=batch)
else:
to_pass: Dict[str, torch.Tensor] = {
self.shortcut_name: shortcut,
self.output_name: torch.jit.annotate(torch.Tensor, output),
}

module_output = module(to_pass, batch=batch)

if torch.jit.isinstance(module_output, torch.Tensor):
output = module_output
elif torch.jit.isinstance(module_output, Dict[str, torch.Tensor]):
output = module_output[self.output_name]
else:
raise RuntimeError(
f"Module {module} must return a tensor or a dict ",
f"with key {self.output_name}",
)
else:
if torch.jit.isinstance(inputs, Dict[str, torch.Tensor]) and torch.jit.isinstance(
output, Dict[str, torch.Tensor]
):
output = output[self.output_name]
_output = module(output, batch=batch)
if torch.jit.isinstance(_output, torch.Tensor) or torch.jit.isinstance(
_output, Dict[str, torch.Tensor]
):
output = _output
else:
raise RuntimeError(
f"Module {module} must return a tensor or a dict ",
f"with key {self.output_name}",
)

to_return = {self.shortcut_name: shortcut}
if torch.jit.isinstance(output, Dict[str, torch.Tensor]):
to_return.update(output)
else:
to_return[self.output_name] = output

return to_return
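
As a small follow-up to the docstring example, the keys of the output dictionary can be renamed via `shortcut_name` and `output_name`; a hedged sketch, with the key names "skip" and "hidden" chosen purely for illustration:

>>> import torch
>>> import torch.nn as nn
>>> import merlin.models.torch as mm
>>> block = mm.ShortcutBlock(nn.Identity(), shortcut_name="skip", output_name="hidden")
>>> block(torch.ones(1, 1))
{'skip': tensor([[1.]]), 'hidden': tensor([[1.]])}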


def get_pre(module: nn.Module) -> BlockContainer:
if hasattr(module, "pre"):
return module.pre
