Adding SchemaTrackingMixin (NVIDIA-Merlin#1109)
* Adding SchemaTrackingMixin

* Small fix in test_tensor

* Add schema-tracking to Block
marcromeyn authored May 29, 2023
1 parent b5dff16 commit b6d6645
Showing 4 changed files with 202 additions and 7 deletions.
18 changes: 11 additions & 7 deletions merlin/models/torch/block.py
@@ -24,10 +24,11 @@
from merlin.models.torch.container import BlockContainer, BlockContainerDict
from merlin.models.torch.link import Link, LinkType
from merlin.models.torch.registry import registry
from merlin.models.torch.utils.schema_utils import SchemaTrackingMixin
from merlin.models.utils.registry import RegistryMixin


class Block(BlockContainer, RegistryMixin):
class Block(BlockContainer, SchemaTrackingMixin, RegistryMixin):
"""A base-class that calls it's modules sequentially.
Parameters
@@ -36,12 +37,16 @@ class Block(BlockContainer, RegistryMixin):
Variable length argument list of PyTorch modules to be contained in the block.
name : Optional[str], default = None
The name of the block. If None, no name is assigned.
track_schema : bool, default = True
If True, the schema of the output tensors is tracked.
"""

registry = registry

def __init__(self, *module: nn.Module, name: Optional[str] = None):
def __init__(self, *module: nn.Module, name: Optional[str] = None, track_schema: bool = True):
super().__init__(*module, name=name)
if track_schema:
self._register_schema_tracking_hook()

def forward(
self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], batch: Optional[Batch] = None
@@ -138,17 +143,16 @@ class ParallelBlock(Block):
Variable length argument list of PyTorch modules to be contained in the block.
name : Optional[str], default = None
The name of the block. If None, no name is assigned.
track_schema : bool, default = True
If True, the schema of the output tensors is tracked.
"""

def __init__(
self,
*inputs: Union[nn.Module, Dict[str, nn.Module]],
):
def __init__(self, *inputs: Union[nn.Module, Dict[str, nn.Module]], track_schema: bool = True):
pre = BlockContainer(name="pre")
branches = BlockContainerDict(*inputs)
post = BlockContainer(name="post")

super().__init__()
super().__init__(track_schema=track_schema)

self.pre = pre
self.branches = branches
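
A minimal usage sketch (not part of the commit) of the new track_schema flag: a Block now records the shapes and dtypes of its outputs on the first forward pass. PlusOne here is just a stand-in module, like the one defined in the tests further down.

import torch
from merlin.models.torch.block import Block

class PlusOne(torch.nn.Module):
    """Stand-in module used only for illustration."""
    def forward(self, x):
        return x + 1

block = Block(PlusOne())                    # track_schema defaults to True
outputs = block(torch.randn(1, 3))

schema = block.output_schema()              # available once forward() has run
assert schema.first.dtype.name == str(outputs.dtype).split(".")[-1]

untracked = Block(PlusOne(), track_schema=False)
untracked(torch.randn(1, 3))
# untracked.output_schema() raises RuntimeError: the tracking hook was never registered
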
98 changes: 98 additions & 0 deletions merlin/models/torch/utils/schema_utils.py
@@ -0,0 +1,98 @@
import torch

from merlin.schema import ColumnSchema, Schema, Tags


class SchemaTrackingMixin:
"""
A mixin class for PyTorch modules that tracks the output shapes and dtypes
of the forward pass, which is used to automatically generate
the output schema.
It registers a forward hook to capture this information and provides
methods to access the output schema, as well as to set the module
in training or evaluation mode.
"""

def __init__(self):
super().__init__()
self._register_schema_tracking_hook()

def _post_forward_hook(self, module, input, output):
"""Hook function to be called after the forward pass of the module.
Parameters
----------
module : torch.nn.Module
The module for which the forward pass was called.
input : tuple
The input arguments passed to the forward method.
output : torch.Tensor or dict
The output of the forward method.
"""
if not module._forward_called:
if isinstance(output, dict):
for key, value in output.items():
module._output_shapes[key] = value.shape
module._output_dtypes[key] = value.dtype
else:
module._output_shapes["output"] = output.shape
module._output_dtypes["output"] = output.dtype
module._forward_called = True
module._handle.remove()

def _register_schema_tracking_hook(self):
"""
Register the post forward hook to the module.
"""
self._forward_called = False
self._handle = None
self._output_shapes = {}
self._output_dtypes = {}

if self._handle is None:
self._handle = self.register_forward_hook(self._post_forward_hook)

def output_schema(self) -> Schema:
"""Get the output schema of the module.
Returns
-------
Schema
The output schema of the module.
Raises
------
RuntimeError
If forward() has not been called before calling this method.
"""

if not hasattr(self, "_output_shapes"):
raise RuntimeError(
"Schema-tracking hook not registered, use `_register_schema_tracking_hook`."
)

if not self._forward_called:
raise RuntimeError("forward() must be called before output_schema() can be called.")

columns = []

for name, shape in self._output_shapes.items():
dtype = self._output_dtypes[name]
dims = (None,) + tuple(shape)
tags = None

if len(shape) > 1 and dtype != torch.int32:
tags = [Tags.EMBEDDING]

columns.append(ColumnSchema(name, dims=dims, tags=tags, dtype=dtype))

return Schema(columns)

def train(self, mode=True):
self._register_schema_tracking_hook()
return super().train(mode)

def eval(self):
self._register_schema_tracking_hook()
return super().eval()
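
The mixin also works on plain nn.Module subclasses outside of Block. A rough sketch under that assumption, using a hypothetical two-branch module: each key of a dict output becomes a schema column, and float outputs with more than one dimension are tagged as embeddings per the heuristic in output_schema().

import torch
from torch import nn

from merlin.models.torch.utils.schema_utils import SchemaTrackingMixin
from merlin.schema import Tags

class TwoBranch(SchemaTrackingMixin, nn.Module):
    """Hypothetical module returning a dict of tensors."""
    def __init__(self):
        super().__init__()
        self.user = nn.LazyLinear(8)
        self.item = nn.LazyLinear(8)

    def forward(self, x: torch.Tensor):
        return {"user": self.user(x), "item": self.item(x)}

module = TwoBranch()
module(torch.randn(2, 4))                        # first forward pass populates the hook state

schema = module.output_schema()
assert set(schema.column_names) == {"user", "item"}      # one column per dict key
assert len(schema.select_by_tag(Tags.EMBEDDING)) == 2    # 2-D float outputs get tagged
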
23 changes: 23 additions & 0 deletions tests/unit/torch/test_block.py
@@ -24,6 +24,7 @@
from merlin.models.torch.block import Block, ParallelBlock
from merlin.models.torch.container import BlockContainer, BlockContainerDict
from merlin.models.torch.utils import module_utils
from merlin.schema import Tags


class PlusOne(nn.Module):
@@ -50,6 +51,14 @@ def test_identity(self):

assert torch.equal(inputs, outputs)

schema = block.output_schema()
assert schema.first.dtype.name == str(outputs.dtype).split(".")[-1]

def test_no_schema_tracking(self):
block = Block(track_schema=False)
with pytest.raises(RuntimeError, match="Schema-tracking hook not registered"):
block.output_schema()

def test_insertion(self):
block = Block()
block.prepend(PlusOne())
@@ -148,6 +157,20 @@ def test_forward_dict_duplicate(self):
with pytest.raises(RuntimeError):
pb(inputs)

def test_schema_tracking(self):
pb = ParallelBlock({"a": PlusOne(), "b": PlusOne()})

inputs = torch.randn(1, 3)
outputs = pb(inputs)

schema = pb.output_schema()

for name in outputs:
assert name in schema.column_names
assert schema[name].dtype.name == str(outputs[name].dtype).split(".")[-1]

assert len(schema.select_by_tag(Tags.EMBEDDING)) == 2

def test_forward_tuple(self):
inputs = torch.randn(1, 3)
pb = ParallelBlock({"test": PlusOneTuple()})
70 changes: 70 additions & 0 deletions tests/unit/torch/utils/test_schema_utils.py
@@ -0,0 +1,70 @@
import pytest
import torch
from torch import nn

from merlin.models.torch.utils.module_utils import module_test
from merlin.models.torch.utils.schema_utils import SchemaTrackingMixin
from merlin.schema import Schema, Tags


class TrackedModule(SchemaTrackingMixin, nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.LazyLinear(10)

def forward(self, x: torch.Tensor):
return self.linear(x)


class TrackedDictModule(SchemaTrackingMixin, nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.LazyLinear(10)

def forward(self, x: torch.Tensor):
return {"a": self.linear(x), "b": self.linear(x)}


class TestSchemaTrackingMixin:
def test_tensor(self):
inputs = torch.randn(1, 5)
tracked_module = TrackedModule()
module_test(tracked_module, inputs)

schema = tracked_module.output_schema()
assert isinstance(schema, Schema)
assert len(schema) == 1
assert len(schema.select_by_tag(Tags.EMBEDDING)) == 1

def test_dict(self):
inputs = torch.randn(1, 5)
tracked_module = TrackedDictModule()

outputs = tracked_module(inputs)
traced_outputs = module_test(tracked_module, inputs)
assert torch.equal(outputs["a"], traced_outputs["a"])
assert torch.equal(outputs["b"], traced_outputs["b"])

schema = tracked_module.output_schema()
assert isinstance(schema, Schema)
assert len(schema) == 2
assert len(schema.select_by_tag(Tags.EMBEDDING)) == 2

def test_exception(self):
tracked_module = TrackedModule()
with pytest.raises(RuntimeError):
tracked_module.output_schema()

def test_train(self):
tracked_module = TrackedModule()
tracked_module(torch.randn(1, 5))

tracked_module.train()
assert not tracked_module._forward_called

def test_eval(self):
tracked_module = TrackedModule()
tracked_module(torch.randn(1, 5))

tracked_module.eval()
assert not tracked_module._forward_called
