diff --git a/merlin/models/torch/block.py b/merlin/models/torch/block.py index e9e42579b9..f1bd188453 100644 --- a/merlin/models/torch/block.py +++ b/merlin/models/torch/block.py @@ -24,10 +24,11 @@ from merlin.models.torch.container import BlockContainer, BlockContainerDict from merlin.models.torch.link import Link, LinkType from merlin.models.torch.registry import registry +from merlin.models.torch.utils.schema_utils import SchemaTrackingMixin from merlin.models.utils.registry import RegistryMixin -class Block(BlockContainer, RegistryMixin): +class Block(BlockContainer, SchemaTrackingMixin, RegistryMixin): """A base-class that calls it's modules sequentially. Parameters @@ -36,12 +37,16 @@ class Block(BlockContainer, RegistryMixin): Variable length argument list of PyTorch modules to be contained in the block. name : Optional[str], default = None The name of the block. If None, no name is assigned. + track_schema : bool, default = True + If True, the schema of the output tensors are tracked. """ registry = registry - def __init__(self, *module: nn.Module, name: Optional[str] = None): + def __init__(self, *module: nn.Module, name: Optional[str] = None, track_schema: bool = True): super().__init__(*module, name=name) + if track_schema: + self._register_schema_tracking_hook() def forward( self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], batch: Optional[Batch] = None @@ -138,17 +143,16 @@ class ParallelBlock(Block): Variable length argument list of PyTorch modules to be contained in the block. name : Optional[str], default = None The name of the block. If None, no name is assigned. + track_schema : bool, default = True + If True, the schema of the output tensors are tracked. """ - def __init__( - self, - *inputs: Union[nn.Module, Dict[str, nn.Module]], - ): + def __init__(self, *inputs: Union[nn.Module, Dict[str, nn.Module]], track_schema: bool = True): pre = BlockContainer(name="pre") branches = BlockContainerDict(*inputs) post = BlockContainer(name="post") - super().__init__() + super().__init__(track_schema=track_schema) self.pre = pre self.branches = branches diff --git a/merlin/models/torch/utils/schema_utils.py b/merlin/models/torch/utils/schema_utils.py new file mode 100644 index 0000000000..9feb7cbd8f --- /dev/null +++ b/merlin/models/torch/utils/schema_utils.py @@ -0,0 +1,98 @@ +import torch + +from merlin.schema import ColumnSchema, Schema, Tags + + +class SchemaTrackingMixin: + """ + A mixin class for PyTorch modules to track the output shapes and dtypes + of the forward pass. This is used in order to automatically generate + the output-schema. + + It registers a hook to capture this information and + provides methods to access the output schema, as well as to set the module + in training or evaluation mode. + """ + + def __init__(self): + super().__init__() + self._register_schema_tracking_hook() + + def _post_forward_hook(self, module, input, output): + """Hook function to be called after the forward pass of the module. + + Parameters + ---------- + module : torch.nn.Module + The module for which the forward pass was called. + input : tuple + The input arguments passed to the forward method. + output : torch.Tensor or dict + The output of the forward method. + """ + if not module._forward_called: + if isinstance(output, dict): + for key, value in output.items(): + module._output_shapes[key] = value.shape + module._output_dtypes[key] = value.dtype + else: + module._output_shapes["output"] = output.shape + module._output_dtypes["output"] = output.dtype + module._forward_called = True + module._handle.remove() + + def _register_schema_tracking_hook(self): + """ + Register the post forward hook to the module. + """ + self._forward_called = False + self._handle = None + self._output_shapes = {} + self._output_dtypes = {} + + if self._handle is None: + self._handle = self.register_forward_hook(self._post_forward_hook) + + def output_schema(self) -> Schema: + """Get the output schema of the module. + + Returns + ------- + Schema + The output schema of the module. + + Raises + ------ + RuntimeError + If forward() has not been called before calling this method. + """ + + if not hasattr(self, "_output_shapes"): + raise RuntimeError( + "Schema-tracking hook not registered, use `_register_schema_tracking_hook`." + ) + + if not self._forward_called: + raise RuntimeError("forward() must be called before output_schema() can be called.") + + columns = [] + + for name, shape in self._output_shapes.items(): + dtype = self._output_dtypes[name] + dims = (None,) + tuple(shape) + tags = None + + if len(shape) > 1 and dtype != torch.int32: + tags = [Tags.EMBEDDING] + + columns.append(ColumnSchema(name, dims=dims, tags=tags, dtype=dtype)) + + return Schema(columns) + + def train(self, mode=True): + self._register_schema_tracking_hook() + return super().train(mode) + + def eval(self): + self._register_schema_tracking_hook() + return super().eval() diff --git a/tests/unit/torch/test_block.py b/tests/unit/torch/test_block.py index d863d54ce9..f564ec9fcd 100644 --- a/tests/unit/torch/test_block.py +++ b/tests/unit/torch/test_block.py @@ -24,6 +24,7 @@ from merlin.models.torch.block import Block, ParallelBlock from merlin.models.torch.container import BlockContainer, BlockContainerDict from merlin.models.torch.utils import module_utils +from merlin.schema import Tags class PlusOne(nn.Module): @@ -50,6 +51,14 @@ def test_identity(self): assert torch.equal(inputs, outputs) + schema = block.output_schema() + assert schema.first.dtype.name == str(outputs.dtype).split(".")[-1] + + def test_no_schema_tracking(self): + block = Block(track_schema=False) + with pytest.raises(RuntimeError, match="Schema-tracking hook not registered"): + block.output_schema() + def test_insertion(self): block = Block() block.prepend(PlusOne()) @@ -148,6 +157,20 @@ def test_forward_dict_duplicate(self): with pytest.raises(RuntimeError): pb(inputs) + def test_schema_tracking(self): + pb = ParallelBlock({"a": PlusOne(), "b": PlusOne()}) + + inputs = torch.randn(1, 3) + outputs = pb(inputs) + + schema = pb.output_schema() + + for name in outputs: + assert name in schema.column_names + assert schema[name].dtype.name == str(outputs[name].dtype).split(".")[-1] + + assert len(schema.select_by_tag(Tags.EMBEDDING)) == 2 + def test_forward_tuple(self): inputs = torch.randn(1, 3) pb = ParallelBlock({"test": PlusOneTuple()}) diff --git a/tests/unit/torch/utils/test_schema_utils.py b/tests/unit/torch/utils/test_schema_utils.py new file mode 100644 index 0000000000..76b703ea2c --- /dev/null +++ b/tests/unit/torch/utils/test_schema_utils.py @@ -0,0 +1,70 @@ +import pytest +import torch +from torch import nn + +from merlin.models.torch.utils.module_utils import module_test +from merlin.models.torch.utils.schema_utils import SchemaTrackingMixin +from merlin.schema import Schema, Tags + + +class TrackedModule(SchemaTrackingMixin, nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.LazyLinear(10) + + def forward(self, x: torch.Tensor): + return self.linear(x) + + +class TrackedDictModule(SchemaTrackingMixin, nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.LazyLinear(10) + + def forward(self, x: torch.Tensor): + return {"a": self.linear(x), "b": self.linear(x)} + + +class TestSchemaTrackingMixin: + def test_tensor(self): + inputs = torch.randn(1, 5) + tracked_module = TrackedModule() + module_test(tracked_module, inputs) + + schema = tracked_module.output_schema() + assert isinstance(schema, Schema) + assert len(schema) == 1 + assert len(schema.select_by_tag(Tags.EMBEDDING)) == 1 + + def test_dict(self): + inputs = torch.randn(1, 5) + tracked_module = TrackedDictModule() + + outputs = tracked_module(inputs) + traced_outputs = module_test(tracked_module, inputs) + assert torch.equal(outputs["a"], traced_outputs["a"]) + assert torch.equal(outputs["b"], traced_outputs["b"]) + + schema = tracked_module.output_schema() + assert isinstance(schema, Schema) + assert len(schema) == 2 + assert len(schema.select_by_tag(Tags.EMBEDDING)) == 2 + + def test_exception(self): + tracked_module = TrackedModule() + with pytest.raises(RuntimeError): + tracked_module.output_schema() + + def test_train(self): + tracked_module = TrackedModule() + tracked_module(torch.randn(1, 5)) + + tracked_module.train() + assert not tracked_module._forward_called + + def test_eval(self): + tracked_module = TrackedModule() + tracked_module(torch.randn(1, 5)) + + tracked_module.eval() + assert not tracked_module._forward_called