diff --git a/transformers4rec/config/transformer.py b/transformers4rec/config/transformer.py index 7e2356ed50..72d32332a1 100644 --- a/transformers4rec/config/transformer.py +++ b/transformers4rec/config/transformer.py @@ -15,13 +15,55 @@ # import transformers +from merlin.models.utils.doc_utils import docstring_parameter from merlin.models.utils.registry import Registry transformer_registry: Registry = Registry("transformers") +TRANSFORMER_CONFIG_PARAMETER_DOCSTRING = """ + d_model: int + The hidden dimension of the transformer layer. + n_head: int + The number of attention heads in each transformer layer. + n_layer: int + The number of transformer layers to stack. + total_seq_length: int + The maximum sequence length. + hidden_act: str, optional + The activation function in the hidden layers. + By default 'gelu' + initializer_range: float, optional + The standard deviation of the `truncated_normal_initializer` + for initializing all transformer's weights parameters. + By default 0.01 + layer_norm_eps: float, optional + The epsilon used by the layer normalization layers. + By default 0.03 + dropout: float, optional + The dropout probability. By default 0.3 + pad_token: int, optional + The padding token ID. By default 0 + log_attention_weights: bool, optional + Whether to log attention weights. By default False +""" + + class T4RecConfig: + """A class responsible for setting the configuration of the transformers class + from Hugging Face and returning the corresponding T4Rec model. + """ + def to_huggingface_torch_model(self): + """ + Instantiate a Hugging Face transformer model based on + the configuration parameters of the class. + + Returns + ------- + transformers.PreTrainedModel + The Hugging Face transformer model. + """ model_cls = transformers.MODEL_MAPPING[self.transformers_config_cls] return model_cls(self) @@ -35,6 +77,37 @@ def to_torch_model( loss_reduction="mean", **kwargs ): + """Links the Hugging Face transformer model to the given input block and prediction tasks, + and returns a T4Rec model. + + Parameters + ---------- + input_features: torch4rec.TabularSequenceFeatures + The sequential block that represents the input features and + defines the masking strategy for training and evaluation. + prediction_task: torch4rec.PredictionTask + One or multiple prediction tasks. + task_blocks: list, optional + List of task-specific blocks that we apply on top of the HF transformer's output. + task_weights: list, optional + List of the weights to use for combining the tasks losses. + loss_reduction: str, optional + The reduction to apply to the prediction losses, possible values are: + 'none': no reduction will be applied, + 'mean': the weighted mean of the output is taken, + 'sum': the output will be summed. + By default: 'mean'. + + Returns + ------- + torch4rec.Model + The T4Rec torch model. + + Raises + ------ + ValueError + If input block or prediction task is of the wrong type. + """ from .. import torch as torch4rec if not isinstance(input_features, torch4rec.TabularSequenceFeatures): @@ -68,6 +141,11 @@ def build(cls, *args, **kwargs): @transformer_registry.register("reformer") class ReformerConfig(T4RecConfig, transformers.ReformerConfig): + """Subclass of T4RecConfig and transformers.ReformerConfig from Hugging Face. + It handles configuration for Reformer layers in the context of T4Rec models. 
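+
+    Example usage (an illustrative sketch; the hyper-parameter values below are
+    arbitrary placeholders rather than recommended settings)::
+
+        config = ReformerConfig.build(
+            d_model=64, n_head=4, n_layer=2, total_seq_length=30
+        )
+        model = config.to_huggingface_torch_model()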
+ """ + + @docstring_parameter(transformer_cfg_parameters=TRANSFORMER_CONFIG_PARAMETER_DOCSTRING) @classmethod def build( cls, @@ -84,6 +162,21 @@ def build( axial_pos_shape_first_dim=4, **kwargs ): + """ + Creates an instance of ReformerConfig with the given parameters. + + Parameters + ---------- + {transformer_cfg_parameters} + axial_pos_shape_first_dim: int, optional + The first dimension of the axial position encodings. + During training, the product of the position dims has to be equal to the sequence length. + + Returns + ------- + ReformerConfig + An instance of ReformerConfig. + """ # To account for target positions at inference mode, we extend the maximum sequence length. total_seq_length = total_seq_length + 2 return cls( @@ -115,7 +208,12 @@ def build( @transformer_registry.register("gtp2") +@docstring_parameter(transformer_cfg_parameters=TRANSFORMER_CONFIG_PARAMETER_DOCSTRING) class GPT2Config(T4RecConfig, transformers.GPT2Config): + """Subclass of T4RecConfig and transformers.GPT2Config from Hugging Face. + It handles configuration for GPT2 layers in the context of T4Rec models. + """ + @classmethod def build( cls, @@ -131,6 +229,18 @@ def build( log_attention_weights=False, **kwargs ): + """ + Creates an instance of GPT2Config with the given parameters. + + Parameters + ---------- + {transformer_cfg_parameters} + + Returns + ------- + GPT2Config + An instance of GPT2Config. + """ return cls( n_embd=d_model, n_inner=d_model * 4, @@ -152,6 +262,11 @@ def build( @transformer_registry.register("longformer") class LongformerConfig(T4RecConfig, transformers.LongformerConfig): + """Subclass of T4RecConfig and transformers.LongformerConfig from Hugging Face. + It handles configuration for LongformerConfig layers in the context of T4Rec models. + """ + + @docstring_parameter(transformer_cfg_parameters=TRANSFORMER_CONFIG_PARAMETER_DOCSTRING) @classmethod def build( cls, @@ -167,6 +282,18 @@ def build( log_attention_weights=False, **kwargs ): + """ + Creates an instance of LongformerConfig with the given parameters. + + Parameters + ---------- + {transformer_cfg_parameters} + + Returns + ------- + LongformerConfig + An instance of LongformerConfig. + """ # To account for target positions at inference mode, we extend the maximum sequence length. total_seq_length = total_seq_length + 2 return cls( @@ -187,6 +314,11 @@ def build( @transformer_registry.register("electra") class ElectraConfig(T4RecConfig, transformers.ElectraConfig): + """Subclass of T4RecConfig and transformers.ElectraConfig from Hugging Face. + It handles configuration for ElectraConfig layers in the context of T4Rec models. + """ + + @docstring_parameter(transformer_cfg_parameters=TRANSFORMER_CONFIG_PARAMETER_DOCSTRING) @classmethod def build( cls, @@ -202,6 +334,18 @@ def build( log_attention_weights=False, **kwargs ): + """ + Creates an instance of ElectraConfig with the given parameters. + + Parameters + ---------- + {transformer_cfg_parameters} + + Returns + ------- + ElectraConfig + An instance of ElectraConfig. + """ # To account for target positions at inference mode, we extend the maximum sequence length. total_seq_length = total_seq_length + 2 return cls( @@ -224,6 +368,11 @@ def build( @transformer_registry.register("albert") class AlbertConfig(T4RecConfig, transformers.AlbertConfig): + """Subclass of T4RecConfig and transformers.AlbertConfig from Hugging Face. + It handles configuration for AlbertConfig layers in the context of T4Rec models. 
+ """ + + @docstring_parameter(transformer_cfg_parameters=TRANSFORMER_CONFIG_PARAMETER_DOCSTRING) @classmethod def build( cls, @@ -239,6 +388,18 @@ def build( log_attention_weights=False, **kwargs ): + """ + Creates an instance of AlbertConfig with the given parameters. + + Parameters + ---------- + {transformer_cfg_parameters} + + Returns + ------- + AlbertConfig + An instance of AlbertConfig. + """ # To account for target positions at inference mode, we extend the maximum sequence length. total_seq_length = total_seq_length + 2 return cls( @@ -260,7 +421,13 @@ def build( @transformer_registry.register("xlnet") +@docstring_parameter(transformer_cfg_parameters=TRANSFORMER_CONFIG_PARAMETER_DOCSTRING) class XLNetConfig(T4RecConfig, transformers.XLNetConfig): + """Subclass of T4RecConfig and transformers.XLNetConfig from Hugging Face. + It handles configuration for XLNetConfig layers in the context of T4Rec models. + """ + + @docstring_parameter(transformer_cfg_parameters=TRANSFORMER_CONFIG_PARAMETER_DOCSTRING) @classmethod def build( cls, @@ -278,6 +445,25 @@ def build( mem_len=1, **kwargs ): + """ + Creates an instance of XLNetConfig with the given parameters. + + Parameters + ---------- + {transformer_cfg_parameters} + mem_len: int, + The number of tokens to be cached. Pre-computed key/value pairs + from a previous forward pass are stored and won't be re-computed. + This parameter is especially useful for long sequence modeling where + different batches may truncate the entire sequence. + Tasks like user-aware recommendation could benefit from this feature. + By default, this parameter is set to 1, which means no caching is used. + + Returns + ------- + XLNetConfig + An instance of XLNetConfig. + """ return cls( d_model=d_model, d_inner=d_model * 4, @@ -298,6 +484,11 @@ def build( @transformer_registry.register("bert") class BertConfig(T4RecConfig, transformers.BertConfig): + """Subclass of T4RecConfig and transformers.BertConfig from Hugging Face. + It handles configuration for BertConfig layers in the context of T4Rec models. + """ + + @docstring_parameter(transformer_cfg_parameters=TRANSFORMER_CONFIG_PARAMETER_DOCSTRING) @classmethod def build( cls, @@ -313,6 +504,18 @@ def build( log_attention_weights=False, **kwargs ): + """ + Creates an instance of BertConfig with the given parameters. + + Parameters + ---------- + {transformer_cfg_parameters} + + Returns + ------- + BertConfig + An instance of BertConfig. + """ # To account for target positions at inference mode, we extend the maximum sequence length. total_seq_length = total_seq_length + 2 return cls( @@ -333,6 +536,11 @@ def build( @transformer_registry.register("roberta") class RobertaConfig(T4RecConfig, transformers.RobertaConfig): + """Subclass of T4RecConfig and transformers.RobertaConfig from Hugging Face. + It handles configuration for RobertaConfig layers in the context of T4Rec models. + """ + + @docstring_parameter(transformer_cfg_parameters=TRANSFORMER_CONFIG_PARAMETER_DOCSTRING) @classmethod def build( cls, @@ -348,6 +556,18 @@ def build( log_attention_weights=False, **kwargs ): + """ + Creates an instance of RobertaConfig with the given parameters. + + Parameters + ---------- + {transformer_cfg_parameters} + + Returns + ------- + RobertaConfig + An instance of RobertaConfig. + """ # To account for target positions at inference mode, we extend the maximum sequence length. 
total_seq_length = total_seq_length + 2 return cls( @@ -368,6 +588,11 @@ def build( @transformer_registry.register("transfo-xl") class TransfoXLConfig(T4RecConfig, transformers.TransfoXLConfig): + """Subclass of T4RecConfig and transformers.TransfoXLConfig from Hugging Face. + It handles configuration for TransfoXLConfig layers in the context of T4Rec models. + """ + + @docstring_parameter(transformer_cfg_parameters=TRANSFORMER_CONFIG_PARAMETER_DOCSTRING) @classmethod def build( cls, @@ -383,6 +608,18 @@ def build( log_attention_weights=False, **kwargs ): + """ + Creates an instance of TransfoXLConfig with the given parameters. + + Parameters + ---------- + {transformer_cfg_parameters} + + Returns + ------- + TransfoXLConfig + An instance of TransfoXLConfig. + """ return cls( d_model=d_model, d_embed=d_model, diff --git a/transformers4rec/data/dataset.py b/transformers4rec/data/dataset.py index bb081f0f00..346f15a33d 100644 --- a/transformers4rec/data/dataset.py +++ b/transformers4rec/data/dataset.py @@ -25,6 +25,15 @@ class Dataset: + """Supports creating synthetic data for PyTorch and TensorFlow + based on a provided schema. + + Parameters + ---------- + schema_path : str + Path to the schema file. + """ + def __init__(self, schema_path: str): self.schema_path = schema_path if self.schema_path.endswith(".pb") or self.schema_path.endswith(".pbtxt"): @@ -38,11 +47,43 @@ def schema(self) -> Schema: @property def merlin_schema(self) -> CoreSchema: + """Convert the schema from merlin_standard_lib to merlin-core schema.""" return TensorflowMetadata.from_json(self.schema.to_json()).to_merlin_schema() def torch_synthetic_data( - self, num_rows=100, min_session_length=5, max_session_length=20, device=None, ragged=False + self, + num_rows: Optional[int] = 100, + min_session_length: Optional[int] = 5, + max_session_length: Optional[int] = 20, + device: Optional[str] = None, + ragged: Optional[bool] = False, ): + """ + Generates a dictionary of synthetic tensors based on the schema. + + Parameters + ---------- + num_rows : int, optional + Number of rows, + by default 100. + min_session_length : int, optional + Minimum session length, + by default 5. + max_session_length : int, optional + Maximum session length, + by default 20. + device : torch.device, optional + The device on which tensors should be created, + by default None. + ragged : bool, optional + Whether sequence tensors should be represented with `__values` and `__offsets`, + by default False. + + Returns + ------- + Dict[str, torch.Tensor] + Dictionary of tensors. + """ from transformers4rec.torch.utils import schema_utils return schema_utils.random_data_from_schema( @@ -55,6 +96,32 @@ def torch_synthetic_data( ) def tf_synthetic_data(self, num_rows=100, min_session_length=5, max_session_length=20): + """ + Generates a dictionary of synthetic tensors based on the schema. + + Parameters + ---------- + num_rows : int, optional + Number of rows, + by default 100. + min_session_length : int, optional + Minimum session length, + by default 5. + max_session_length : int, optional + Maximum session length, + by default 20. + + Returns + ------- + Dict[str, tf.Tensor] + Dictionary of tensors.
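+
+        Example usage (an illustrative sketch; ``schema.json`` is a placeholder path)::
+
+            dataset = Dataset("schema.json")
+            tensors = dataset.tf_synthetic_data(num_rows=10)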
+ """ from transformers4rec.tf.utils import schema_utils return schema_utils.random_data_from_schema( @@ -66,11 +133,30 @@ def tf_synthetic_data(self, num_rows=100, min_session_length=5, max_session_leng class ParquetDataset(Dataset): + """ + Class to read data from a Parquet file and load it as a Dataset. + + Parameters + ---------- + dir : str + Path to the directory containing the data and schema files. + parquet_file_name : Optional[str] + Name of the Parquet data file. + By default "data.parquet". + schema_file_name : Optional[str] + Name of the JSON schema file. + By default "schema.json". + schema_path : Optional[str] + Full path to the schema file. + If None, it will be constructed using `dir` and `schema_file_name`. + By default None. + """ + def __init__( self, dir, - parquet_file_name="data.parquet", - schema_file_name="schema.json", + parquet_file_name: Optional[str] = "data.parquet", + schema_file_name: Optional[str] = "schema.json", schema_path: Optional[str] = None, ): super(ParquetDataset, self).__init__(schema_path or os.path.join(dir, schema_file_name)) diff --git a/transformers4rec/torch/block/base.py b/transformers4rec/torch/block/base.py index 8d5a56fec2..b67df61fb2 100644 --- a/transformers4rec/torch/block/base.py +++ b/transformers4rec/torch/block/base.py @@ -30,7 +30,28 @@ class BlockBase(torch_utils.OutputSizeMixin, torch.nn.Module, metaclass=abc.ABCMeta): + """A subclass of PyTorch's torch.nn.Module, providing additional functionality + for dealing with automatic setting of input/output dimensions of neural networks layers. + Specifically, It implements the 'OutputSizeMixin' for managing output sizes. + """ + def to_model(self, prediction_task_or_head, inputs=None, **kwargs): + """Converts the BlockBase instance into a T4Rec model by attaching it to + attaching a 'Head' or 'PredictionTask'. + + Parameters + ---------- + prediction_task_or_head : Union[PredictionTask, Head] + A PredictionTask or Head instance to attach to this block. + inputs :InputBlock, optional + The input block representing input features. + By default None + + Raises + ------ + ValueError + If prediction_task_or_head is neither a Head nor a PredictionTask. + """ from ..model.base import Head, Model, PredictionTask if isinstance(prediction_task_or_head, PredictionTask): @@ -46,6 +67,15 @@ def to_model(self, prediction_task_or_head, inputs=None, **kwargs): return Model(head, **kwargs) def as_tabular(self, name=None): + """Converts the output of the block into a dictionary, keyed by the + provided name + + Parameters + ---------- + name : str, optional + The output name, if not provided, uses the name of the block class. + by default None + """ from ..tabular.base import AsTabular if not name: @@ -55,6 +85,18 @@ def as_tabular(self, name=None): class Block(BlockBase): + """Wraps a PyTorch module, allowing it to be used as a block in a T4Rec model. + It carries the module and its expected output size. + + Parameters + ---------- + module: torch.nn.Module + The PyTorch module to be wrapped in this block. + output_size: Union[List[int], torch.Size] + The expected output size of the module. + + """ + def __init__(self, module: torch.nn.Module, output_size: Union[List[int], torch.Size]): super().__init__() self.module = module @@ -64,6 +106,20 @@ def forward(self, inputs, **kwargs): return self.module(inputs, **kwargs) def forward_output_size(self, input_size): + """ + Calculates the output size of the tensor(s) returned by the forward pass, + given the input size. 
+ + Parameters + ---------- + input_size: Union[List[int], torch.Size] + The size of the input tensor(s) to the module. + + Returns + ------- + Union[List[int], torch.Size] + The size of the output from the module. + """ if self._output_size[0] is None: batch_size = torch_utils.calculate_batch_size_from_input_size(input_size) @@ -73,7 +129,21 @@ def forward_output_size(self, input_size): class SequentialBlock(BlockBase, torch.nn.Sequential): - def __init__(self, *args, output_size=None): + """Extends the module torch.nn.Sequential. It's used for creating + a sequence of layers or blocks in a T4Rec model. The modules + will be applied to inputs in the order they are passed in the constructor. + + Parameters + ---------- + *args: + The list of PyTorch modules. + output_size : Union[List[int], torch.Size], optional + The expected output size from the last layer in the sequential block + By default None + + """ + + def __init__(self, *args, output_size: Union[List[int], torch.Size] = None): from transformers4rec.torch import TabularSequenceFeatures, TransformerBlock if isinstance(args[0], TabularSequenceFeatures) and any( @@ -94,12 +164,12 @@ def __init__(self, *args, output_size=None): if len(args) == 1 and isinstance(args[0], OrderedDict): last = None - for idx, key, module in enumerate(args[0].items()): - self.add_module_and_maybe_build(key, module, last, idx) - last = module + for idx, key, module in enumerate(args[0].items()): # type: ignore + self.add_module_and_maybe_build(key, module, last, idx) # type: ignore + last = module # type: ignore else: if len(args) == 1 and isinstance(args[0], list): - args = args[0] + args = args[0] # type: ignore last = None for idx, module in enumerate(args): last = self.add_module_and_maybe_build(str(idx), module, last, idx) @@ -113,6 +183,18 @@ def inputs(self): return first def add_module(self, name: str, module: Optional[Module]) -> None: + """ + Adds a PyTorch module to the sequential block. If a list of strings is provided, + a `FilterFeatures` block gets added to the sequential block. + + Parameters + ---------- + name : str + The name of the child module. The child module can be accessed + from this module using the given name. + module : Optional[Union[List[str], Module]] + The child module to be added to the module. + """ from ..tabular.base import FilterFeatures if isinstance(module, list): @@ -120,6 +202,19 @@ def add_module(self, name: str, module: Optional[Module]) -> None: super().add_module(name, module) def add_module_and_maybe_build(self, name: str, module, parent, idx) -> torch.nn.Module: + """Checks if a module needs to be built and adds it to the sequential block. + + Parameters + ---------- + name : str + The name of the child module. + module : torch.nn.Module + The child module to be added to the sequential block. + parent : torch.nn.Module + The parent module. + idx : int + The index of the current module in the sequential block. + """ # Check if module needs to be built if getattr(parent, "output_size", None) and getattr(module, "build", None): module = module.build(parent.output_size()) @@ -139,8 +234,18 @@ def __rshift__(self, other): return right_shift_block(other, self) def forward(self, input, training=False, testing=False, **kwargs): - # from transformers4rec.torch import TabularSequenceFeatures - + """Applies the module's layers sequentially to the input block. + + Parameters + ---------- + input : tensor + The input to the block. + training : bool, optional + Whether the block is in training mode. The default is False. 
+ testing : bool, optional + Whether the block is in testing mode. The default is False. + + """ for i, module in enumerate(self): if i == len(self) - 1: filtered_kwargs = filter_kwargs(kwargs, module, cascade_kwargs_if_possible=True) @@ -157,6 +262,20 @@ def forward(self, input, training=False, testing=False, **kwargs): return input def build(self, input_size, schema=None, **kwargs): + """Builds the layers of the sequential block given the input size. + + Parameters + ---------- + input_size : Union[List[int], torch.Size] + The size of the input tensor(s). + schema : Schema, optional + The schema of the input features, by default None + + Returns + ------- + SequentialBlock + The built sequential block. + """ output_size = input_size for module in self: if not hasattr(module, "build"): @@ -167,6 +286,15 @@ def build(self, input_size, schema=None, **kwargs): return super(SequentialBlock, self).build(input_size, schema=None, **kwargs) def as_tabular(self, name=None): + """Converts the output of the block into a dictionary, keyed by the + provided name. + + Parameters + ---------- + name : str, optional + The output name. If not provided, the name of the block class is used. + By default None + """ from transformers4rec.torch import AsTabular if not name: @@ -180,6 +308,20 @@ def __add__(self, other): return merge_tabular(self, other) def forward_output_size(self, input_size): + """ + Calculates the output size of the tensor(s) returned by the forward pass, + given the input size. + + Parameters + ---------- + input_size: Union[List[int], torch.Size] + The size of the input tensor(s) to the module. + + Returns + ------- + Union[List[int], torch.Size] + The size of the output from the module. + """ if self._static_output_size: return self._static_output_size @@ -212,10 +354,26 @@ def add_if_class_name_matches(to_check): def build_blocks(*modules): + """Builds a SequentialBlock from the provided PyTorch modules and returns + its layers as a list. + + Parameters + ---------- + *modules : List[torch.nn.Module] + List containing PyTorch modules. + + Returns + ------- + List[torch.nn.Module] + The modules of the built SequentialBlock, as a list. + """ return list(SequentialBlock(*modules)) class BuildableBlock(abc.ABC): + """ + Abstract base class for buildable blocks. + Subclasses of BuildableBlock must implement the `build` method. + """ + @abc.abstractmethod def build(self, input_size) -> BlockBase: raise NotImplementedError diff --git a/transformers4rec/torch/block/mlp.py b/transformers4rec/torch/block/mlp.py index 4a870adfe6..78453ac9c4 100644 --- a/transformers4rec/torch/block/mlp.py +++ b/transformers4rec/torch/block/mlp.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from typing import List, Optional +from typing import List, Optional, Union import torch @@ -21,13 +21,36 @@ class MLPBlock(BuildableBlock): + """Defines a Multi-Layer Perceptron (MLP) block by stacking + multiple DenseBlock instances. + + Parameters + ---------- + dimensions : int or list of int + The dimensions of the layers in the MLP. + If an integer is provided, a single-layer MLP is created. + If a list is provided, it must contain the size of each layer in order. + activation : optional + The activation function to apply after each layer. + By default `torch.nn.ReLU`. + use_bias : bool, optional + Whether to add a bias term to the dense layers.
+ by default True + dropout : float, optional + The dropout rate to apply after each layer, by default None + normalization : str, optional + The normalization to apply after each layer, by default None + filter_features : List[str], optional + List of features to select from the input, by default None + """ + def __init__( self, dimensions, activation=torch.nn.ReLU, use_bias: bool = True, - dropout=None, - normalization=None, + dropout: float = None, + normalization: str = None, filter_features=None, ) -> None: super().__init__() @@ -65,9 +88,41 @@ def build(self, input_shape) -> SequentialBlock: class DenseBlock(SequentialBlock): + """ + A buildable dense block that represents a fully connected layer. + + Parameters + ---------- + input_shape : Union[List[int], torch.Size] + The shape of the input tensor. + + in_features : int + Number of input features. + + out_features : int + Number of output features. + + activation : torch.nn.Module, optional + The activation function to apply after the linear layer. + By default `torch.nn.ReLU`. + + use_bias : bool, optional + Whether to use bias in the layer. + By default True. + + dropout : float, optional + The dropout rate to apply after the dense layer, if any. + By default None. + + normalization : str, optional + The type of normalization to apply after the dense layer. + Only 'batch_norm' is supported. + By default None. + """ + def __init__( self, - input_shape, + input_shape: Union[List[int], torch.Size], in_features: int, out_features: int, activation=torch.nn.ReLU, diff --git a/transformers4rec/torch/block/transformer.py b/transformers4rec/torch/block/transformer.py index f1a4e2056c..3c5d93346a 100644 --- a/transformers4rec/torch/block/transformer.py +++ b/transformers4rec/torch/block/transformer.py @@ -30,7 +30,20 @@ class TransformerPrepare(torch.nn.Module): - def __init__(self, transformer, masking): + """ + Base class to prepare additional inputs to the forward call of + the HF transformer layer. + + Parameters + ---------- + transformer : TransformerBody + The Transformer module. + masking : Optional[MaskSequence] + Masking block used for masking input sequences. + By default None. + """ + + def __init__(self, transformer: TransformerBody, masking: Optional[MaskSequence] = None): super().__init__() self.transformer = transformer self.masking = masking @@ -40,6 +53,12 @@ def forward(self, inputs_embeds) -> Dict[str, Any]: class GPT2Prepare(TransformerPrepare): + """TransformerPrepare module for GPT-2. + + This class extends the inputs for GPT-2 with a + triangular causal mask. + """ + def forward(self, inputs_embeds) -> Dict[str, Any]: seq_len = inputs_embeds.shape[1] # head_mask has shape n_layer x batch x n_heads x N x N diff --git a/transformers4rec/torch/features/embedding.py b/transformers4rec/torch/features/embedding.py index afa0ec5ca8..e77eef4450 100644 --- a/transformers4rec/torch/features/embedding.py +++ b/transformers4rec/torch/features/embedding.py @@ -256,6 +256,14 @@ def forward_output_size(self, input_sizes): class EmbeddingBagWrapper(torch.nn.EmbeddingBag): + """ + Wrapper class for the PyTorch EmbeddingBag module. + + This class extends the torch.nn.EmbeddingBag class and overrides + the forward method to handle 1D tensor inputs + by reshaping them to 2D as required by the EmbeddingBag.
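+
+    Example (an illustrative sketch; the sizes and ids below are arbitrary)::
+
+        embedding_bag = EmbeddingBagWrapper(num_embeddings=10, embedding_dim=4)
+        # A 1-D tensor of ids is accepted and reshaped to 2-D internally.
+        vectors = embedding_bag(torch.tensor([1, 2, 3]))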
+ """ + def forward(self, input, **kwargs): # EmbeddingBag requires 2D tensors (or offsets) if len(input.shape) == 1: @@ -404,6 +412,29 @@ def table_to_embedding_module(self, table: "TableConfig") -> "SoftEmbedding": class TableConfig: + """ + Class to configure the embeddings lookup table for a categorical feature. + + Attributes + ---------- + vocabulary_size : int + The size of the vocabulary, + i.e., the cardinality of the categorical feature. + dim : int + The dimensionality of the embedding vectors. + initializer : Optional[Callable[[torch.Tensor], None]] + The initializer function for the embedding weights. + If None, the weights are initialized using a normal + distribution with mean 0.0 and standard deviation 0.05. + combiner : Optional[str] + The combiner operation used to aggregate bag of embeddings. + Possible options are "mean", "sum", and "sqrtn". + By default "mean". + name : Optional[str] + The name of the lookup table. + By default None. + """ + def __init__( self, vocabulary_size: int, @@ -448,6 +479,23 @@ def __repr__(self): class FeatureConfig: + """ + Class to set the embeddings table of a categorical feature + with a maximum sequence length. + + Attributes + ---------- + table : TableConfig + Configuration for the lookup table, + which is used for embedding lookup and aggregation. + max_sequence_length : int, optional + Maximum sequence length for sequence features. + By default 0. + name : str, optional + The feature name. + By default None + """ + def __init__( self, table: TableConfig, max_sequence_length: int = 0, name: Optional[Text] = None ): diff --git a/transformers4rec/torch/model/prediction_task.py b/transformers4rec/torch/model/prediction_task.py index 52d6cd7076..9aef9a0b32 100644 --- a/transformers4rec/torch/model/prediction_task.py +++ b/transformers4rec/torch/model/prediction_task.py @@ -33,7 +33,26 @@ class BinaryClassificationPrepareBlock(BuildableBlock): + """Prepares the output layer of the binary classification prediction task. + The output layer is a SequentialBlock of a torch linear + layer followed by a sigmoid activation and a squeeze operation. + """ + def build(self, input_size) -> SequentialBlock: + """Builds the output layer of binary classification based on the input_size. + + Parameters + ---------- + input_size: Tuple[int] + The size of the input tensor, specifically the last dimension is + used for setting the input dimension of the linear layer. + + Returns + ------- + SequentialBlock + A SequentialBlock consisting of a linear layer (with input dimension equal to the last + dimension of input_size), a sigmoid activation, and a squeeze operation. + """ return SequentialBlock( torch.nn.Linear(input_size[-1], 1, bias=False), torch.nn.Sigmoid(), @@ -155,7 +174,26 @@ def __init__( class RegressionPrepareBlock(BuildableBlock): + """Prepares the output layer of the regression prediction task. + The output layer is a SequentialBlock of a torch linear + layer followed by a squeeze operation. + """ + def build(self, input_size) -> SequentialBlock: + """Builds the output layer of regression based on the input_size. + + Parameters + ---------- + input_size: Tuple[int] + The size of the input tensor, specifically the last dimension is + used for setting the input dimension of the linear layer. + + Returns + ------- + SequentialBlock + A SequentialBlock consisting of a linear layer (with input dimension equal to + the last dimension of input_size), and a squeeze operation. 
+ """ return SequentialBlock( torch.nn.Linear(input_size[-1], 1), LambdaModule(lambda x: torch.squeeze(x, -1)), @@ -166,6 +204,81 @@ def build(self, input_size) -> SequentialBlock: class RegressionTask(PredictionTask): + """Returns a ``PredictionTask`` for regression. + + Example usage:: + + # Define the input module to process the tabular input features. + input_module = tr.TabularSequenceFeatures.from_schema( + schema, + max_sequence_length=max_sequence_length, + continuous_projection=d_model, + aggregation="concat", + masking=None, + ) + + # Define XLNetConfig class and set default parameters for HF XLNet config. + transformer_config = tr.XLNetConfig.build( + d_model=d_model, n_head=4, n_layer=2, total_seq_length=max_sequence_length + ) + + # Define the model block including: inputs, projection and transformer block. + body = tr.SequentialBlock( + input_module, + tr.MLPBlock([64]), + tr.TransformerBlock( + transformer_config, + ) + ) + + # Define a head with BinaryClassificationTask. + head = tr.Head( + body, + tr.RegressionTask( + "watch_time", + summary_type="mean", + metrics=[tm.regression.MeanSquaredError()] + ), + inputs=input_module, + ) + + # Get the end-to-end Model class. + model = tr.Model(head) + + Parameters + ---------- + + target_name: Optional[str] + Specifies the variable name that represents the continuous value to predict. + By default None + + task_name: Optional[str] + Specifies the name of the prediction task. If this parameter is not specified, + a name is automatically constructed based on ``target_name`` and the Python + class name of the model. + By default None + + task_block: Optional[BlockType] = None + Specifies a module to transform the input tensor before computing predictions. + + loss: torch.nn.Module + Specifies the loss function for the task. + The default class is ``torch.nn.MSELoss``. + + metrics: Tuple[torch.nn.Module, ...] + Specifies the metrics to calculate during training and evaluation. + The default metric is MeanSquaredError. + + summary_type: str + Summarizes a sequence into a single tensor. Accepted values are: + + - ``last`` -- Take the last token hidden state (like XLNet) + - ``first`` -- Take the first token hidden state (like Bert) + - ``mean`` -- Take the mean of all tokens hidden states + - ``cls_index`` -- Supply a Tensor of classification token position (GPT/GPT-2) + - ``attn`` -- Not implemented now, use multi-head attention + """ + DEFAULT_LOSS = torch.nn.MSELoss() DEFAULT_METRICS = (tm.regression.MeanSquaredError(),) @@ -395,6 +508,32 @@ def compute_metrics(self): class NextItemPredictionPrepareBlock(BuildableBlock): + """Prepares the output layer of the next item prediction task. + The output layer is a an instance of `_NextItemPredictionTask` class. + + Parameters + ---------- + target_dim: int + The output dimension for next-item predictions. + weight_tying: bool, optional + If true, ties the weights of the prediction layer and the item embedding layer. + By default False. + item_embedding_table: torch.nn.Module, optional + The module containing the item embedding table. + By default None. + softmax_temperature: float, optional + The temperature to be applied to the softmax function. Defaults to 0. + sampled_softmax: bool, optional + If true, sampled softmax is used for approximating the full softmax function. + By default False. + max_n_samples: int, optional + The maximum number of samples when using sampled softmax. + By default 100. + min_id: int, optional + The minimum value of the range for the log-uniform sampling. 
+ By default 0. + """ + def __init__( self, target_dim: int, @@ -415,6 +554,18 @@ def __init__( self.min_id = min_id def build(self, input_size) -> Block: + """Builds the output layer of next-item prediction based on the input_size. + + Parameters + ---------- + input_size : Tuple[int] + The size of the input tensor, specifically the last dimension is + used for setting the input dimension of the output layer. + Returns + ------- + Block[_NextItemPredictionTask] + an instance of _NextItemPredictionTask + """ return Block( _NextItemPredictionTask( input_size, @@ -505,7 +656,7 @@ def forward( if self.sampled_softmax and training: logits, targets = self.sampled(inputs, targets, output_weights) else: - logits = inputs @ output_weights.t() + logits = inputs @ output_weights.t() # type: ignore if self.softmax_temperature: # Softmax temperature to reduce model overconfidence @@ -625,7 +776,9 @@ def get_log_uniform_distr(self, max_id: int, min_id: int = 0) -> torch.Tensor: log_indices = torch.arange(1.0, max_id - min_id + 2.0, 1.0).log_() probs = (log_indices[1:] - log_indices[:-1]) / log_indices[-1] if min_id > 0: - probs = torch.cat([torch.zeros([min_id], dtype=probs.dtype), probs], axis=0) + probs = torch.cat( + [torch.zeros([min_id], dtype=probs.dtype), probs], axis=0 + ) # type: ignore return probs def get_unique_sampling_distr(self, dist, n_sample): @@ -682,9 +835,9 @@ def sample(self, labels: torch.Tensor): n_tries = self.n_sample with torch.no_grad(): - neg_samples = torch.multinomial(self.dist, n_tries, replacement=True).unique()[ - : self.max_n_samples - ] + neg_samples = torch.multinomial( + self.dist, n_tries, replacement=True # type: ignore + ).unique()[: self.max_n_samples] device = labels.device neg_samples = neg_samples.to(device) @@ -694,8 +847,8 @@ def sample(self, labels: torch.Tensor): else: dist = self.dist - true_probs = dist[labels] - samples_probs = dist[neg_samples] + true_probs = dist[labels] # type: ignore + samples_probs = dist[neg_samples] # type: ignore return neg_samples, true_probs, samples_probs diff --git a/transformers4rec/torch/tabular/aggregation.py b/transformers4rec/torch/tabular/aggregation.py index dd5729c2eb..1efd85dc44 100644 --- a/transformers4rec/torch/tabular/aggregation.py +++ b/transformers4rec/torch/tabular/aggregation.py @@ -98,7 +98,20 @@ def forward_output_size(self, input_size): class ElementwiseFeatureAggregation(TabularAggregation): - def _check_input_shapes_equal(self, inputs): + """Base class for aggregation methods that aggregates + features element-wise. + It implements two check methods to ensure inputs have the correct shape. + """ + + def _check_input_shapes_equal(self, inputs: TabularData): + """Checks if the shapes of all inputs are equal. + + Parameters + ---------- + inputs : TabularData + Dictionary of tensors. + + """ all_input_shapes_equal = len(set([x.shape for x in inputs.values()])) == 1 if not all_input_shapes_equal: raise ValueError( @@ -107,6 +120,14 @@ def _check_input_shapes_equal(self, inputs): ) def _check_inputs_last_dim_equal(self, inputs_sizes): + """ + Checks if the last dimensions of all inputs are equal. + + Parameters + ---------- + inputs_sizes : dict[str, Union[List[int], torch.Size]] + A dictionary containing the sizes of the inputs. 
+ """ all_input_last_dim_equal = len(set([x[-1] for x in inputs_sizes.values()])) == 1 if not all_input_last_dim_equal: raise ValueError( diff --git a/transformers4rec/torch/utils/schema_utils.py b/transformers4rec/torch/utils/schema_utils.py index 2eed3de403..1666f9901e 100644 --- a/transformers4rec/torch/utils/schema_utils.py +++ b/transformers4rec/torch/utils/schema_utils.py @@ -35,6 +35,38 @@ def random_data_from_schema( ragged=False, seed=0, ) -> TabularData: + """Generates random tabular data based on a given schema. + The generated data can be used for testing + data preprocessing or model training pipelines. + + Parameters + ---------- + schema : Schema + The schema to be used for generating the random tabular data. + num_rows : int + The number of rows. + max_session_length : Optional[int] + The maximum session length. + If None, the session length will not be limited. + By default None + min_session_length : int + The minimum session length. + By default 5 + device : torch.device + The device on which the synthetic data should be created. + If None, the synthetic data will be created on the CPU. + By default None + ragged : bool + If True, the sequence features will be represented with __values and __offsets. + By default False + + Returns + ------- + TabularData + A dictionary where each key is a feature name and each value is the generated + tensor. + + """ data: Dict[str, Any] = {} random.seed(seed)