diff --git a/merlin/models/torch/__init__.py b/merlin/models/torch/__init__.py
index 4042784156..32f7e5b9b5 100644
--- a/merlin/models/torch/__init__.py
+++ b/merlin/models/torch/__init__.py
@@ -29,6 +29,7 @@
 from merlin.models.torch.blocks.attention import CrossAttentionBlock
 from merlin.models.torch.blocks.dlrm import DLRMBlock
 from merlin.models.torch.blocks.experts import CGCBlock, MMOEBlock, PLEBlock
+from merlin.models.torch.blocks.llama import LlamaBlock, LlamaConfig
 from merlin.models.torch.blocks.mlp import MLPBlock
 from merlin.models.torch.functional import map, walk
 from merlin.models.torch.inputs.embedding import EmbeddingTable, EmbeddingTables
@@ -107,4 +108,6 @@
     "CrossAttentionBlock",
     "map",
     "walk",
+    "LlamaBlock",
+    "LlamaConfig",
 ]
diff --git a/merlin/models/torch/blocks/attention.py b/merlin/models/torch/blocks/attention.py
index 6a64fa1297..3e6b161fae 100644
--- a/merlin/models/torch/blocks/attention.py
+++ b/merlin/models/torch/blocks/attention.py
@@ -6,6 +6,16 @@
 from merlin.models.torch.batch import Batch
 from merlin.models.torch.block import Block
+from merlin.models.utils.doc_utils import docstring_parameter
+
+_ROPE_REF = """
+    .. [1] Su, et al., "RoFormer: Enhanced Transformer with Rotary Position Embedding".
+       arXiv preprint arXiv:2104.09864 (2021).
+"""
+_TRANSFORMER_REF = """
+    .. [1] Vaswani, et al., "Attention Is All You Need".
+       arXiv preprint arXiv:1706.03762 (2017).
+"""
 
 
 class CrossAttentionBlock(Block):
@@ -166,3 +176,239 @@ def get_seq(self, x: Dict[str, torch.Tensor]) -> torch.Tensor:
             raise RuntimeError(f"Could not find {self.seq_key} in input dictionary, got: {x}.")
 
         return x[self.seq_key]
+
+
+@docstring_parameter(rope_reference=_ROPE_REF)
+class RotaryEmbeddings(nn.Module):
+    """Rotary Position Embedding (RoPE) as proposed in [1].
+
+    References
+    ----------
+    {rope_reference}
+    """
+
+    def __init__(self, dim: int, max_seq_length: int, base: int = 10000) -> None:
+        super().__init__()
+        self.max_seq_length = max_seq_length
+        self.dim = dim
+        self.base = base
+
+        self.cache = None
+
+    def initialize(
+        self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None
+    ) -> None:
+        inverse_freq = 1.0 / (
+            self.base ** (torch.arange(0, self.dim, 2, dtype=dtype, device=device) / self.dim)
+        )
+        self.register_buffer("inverse_freq", inverse_freq, persistent=False)
+
+        position = torch.arange(self.max_seq_length, dtype=dtype, device=device)
+        freq = torch.outer(position, self.inverse_freq).float()
+        cache = torch.stack([torch.cos(freq), torch.sin(freq)], dim=-1)
+
+        # this is to mimic the behaviour of complex32, else we will get different results
+        if dtype in (torch.float16, torch.bfloat16, torch.int8):
+            cache = cache.half()
+
+        self.cache = cache
+        self._is_initialized = True
+
+    def forward(
+        self,
+        inputs: torch.Tensor,
+        positions: Optional[torch.Tensor] = None,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ) -> torch.Tensor:
+        if self.cache is None:
+            self.initialize(device=device, dtype=dtype)
+
+        batch_size, seq_length, width, _ = inputs.size()
+
+        if positions is not None:
+            _cache = self.cache.index_select(0, positions)
+        else:
+            _cache = self.cache[:seq_length]
+
+        _inputs = inputs.float().reshape(batch_size, seq_length, width, -1, 2)
+        _cache = _cache.view(1, _inputs.size(1), 1, _inputs.size(3), 2)
+        outputs = torch.stack(
+            [
+                _inputs[..., 0] * _cache[..., 0] - _inputs[..., 1] * _cache[..., 1],
+                _inputs[..., 1] * _cache[..., 0] + _inputs[..., 0] * _cache[..., 1],
+            ],
+            -1,
+        )
+
+        return outputs.flatten(3).type_as(inputs)
+
+
+@torch.jit.script
+class AttentionMask:
+    def __init__(self, bool_mask: torch.Tensor) -> None:
+        self.bool_mask = bool_mask
+
+    def select(self, seq_length: int) -> torch.Tensor:
+        return self.bool_mask[:, :, :seq_length, :seq_length]
+
+    def select_position(self, position: torch.Tensor) -> torch.Tensor:
+        return self.bool_mask.index_select(2, position)
+
+
+def create_attention_mask(
+    max_seq_length: int, device: Optional[torch.device] = None
+) -> torch.Tensor:
+    ones = torch.ones(
+        (max_seq_length, max_seq_length),
+        device=device,
+        dtype=torch.bool,
+    )
+    return torch.tril(ones).unsqueeze(0).unsqueeze(0)
+
+
+@torch.jit.script
+class KeyValueCache:
+    def __init__(
+        self,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        self.key = key
+        self.value = value
+
+    def cache(
+        self,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        positions: torch.Tensor,
+        dim: int = -2,
+        max_seq_length: Optional[int] = None,
+    ):
+        if max_seq_length is None:
+            max_seq_length = key.size(dim)
+
+        cached_key, cached_value = self.key, self.value
+        # check if reached token limit
+        if positions[-1] >= max_seq_length:
+            positions = torch.tensor(max_seq_length - 1, device=positions.device)
+            # shift 1 position to the left
+            cached_key = torch.roll(cached_key, -1, dims=dim)
+            cached_value = torch.roll(cached_value, -1, dims=dim)
+        key = cached_key.index_copy(dim, positions, key)
+        value = cached_value.index_copy(dim, positions, value)
+
+        self.key, self.value = key, value
+
+        return key, value
+
+
+@docstring_parameter(
+    transformer_reference=_TRANSFORMER_REF, rope_reference=_ROPE_REF.replace("[1]", "[2]")
+)
+class CausalSelfAttention(nn.Module):
+    """Transformer self-attention [1].
+
+    The key difference between this implementation and PyTorch's
+    ``torch.nn.MultiheadAttention`` is that Rotary Position Embedding [2]
+    is applied to the key and query matrices. ``torch.nn.MultiheadAttention``
+    is currently too rigid to support such a variation.
+
+    References
+    ----------
+    {transformer_reference}
+    {rope_reference}
+    """
+
+    def __init__(
+        self,
+        num_heads: int,
+        embedding_dim: int,
+        max_seq_length: int,
+        bias: bool = False,
+        dropout_p: float = 0.0,
+        store_cache: bool = True,
+        kv_cache: Optional[KeyValueCache] = None,
+        rotary_embeds: Optional[RotaryEmbeddings] = None,
+    ) -> None:
+        super().__init__()
+
+        if embedding_dim % num_heads != 0:
+            raise ValueError(
+                "The embedding dimension must be divisible by the number of self-attention heads"
+            )
+
+        self.num_heads = num_heads
+        self.embedding_dim = embedding_dim
+        self.max_seq_length = max_seq_length
+        self.bias = bias
+        self.dropout_p = dropout_p
+        self.store_cache = store_cache
+        self.kv_cache = kv_cache
+        self.rotary_embeds = rotary_embeds
+
+        # query, key, and value projections for all heads, but in a batch.
+        self.qkv_projection = nn.Linear(embedding_dim, 3 * embedding_dim, bias=self.bias)
+        self.output_projection = nn.Linear(embedding_dim, embedding_dim, bias=self.bias)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        positions: Optional[torch.Tensor] = None,
+        mask: Optional[AttentionMask] = None,
+    ) -> torch.Tensor:
+        batch_size, seq_length, embedding_dim = x.size()
+
+        if self.store_cache and self.kv_cache is None:
+            head_size = self.embedding_dim // self.num_heads
+            cache_shape = (batch_size, self.num_heads, self.max_seq_length, head_size)
+            self.kv_cache = KeyValueCache(
+                key=torch.zeros(cache_shape, device=x.device, dtype=x.dtype),
+                value=torch.zeros(cache_shape, device=x.device, dtype=x.dtype),
+            )
+
+        # calculate query, key, values for all heads in batch
+        # and move head forward to be the batch dim
+        q, k, v = self.qkv_projection(x).split(self.embedding_dim, dim=2)
+
+        head_size = embedding_dim // self.num_heads
+        k = k.view(batch_size, seq_length, self.num_heads, head_size)
+        q = q.view(batch_size, seq_length, self.num_heads, head_size)
+        v = v.view(batch_size, seq_length, self.num_heads, head_size)
+
+        if self.rotary_embeds is not None:
+            q = self.rotary_embeds(q, positions)
+            k = self.rotary_embeds(k, positions)
+
+        k = k.transpose(1, 2)
+        q = q.transpose(1, 2)
+        v = v.transpose(1, 2)
+
+        if self.kv_cache is not None and positions is not None:
+            k, v = self.kv_cache.cache(
+                key=k,
+                value=v,
+                positions=positions,
+                max_seq_length=self.max_seq_length,
+            )
+
+        if mask is not None:
+            if positions is not None:
+                attn_mask = mask.select_position(positions)
+            else:
+                attn_mask = mask.select(seq_length)
+        else:
+            attn_mask = None
+
+        # efficient attention using Flash Attention CUDA kernels
+        y = nn.functional.scaled_dot_product_attention(
+            q, k, v, attn_mask=attn_mask, dropout_p=self.dropout_p
+        )
+
+        y = y.transpose(1, 2).contiguous().view(batch_size, seq_length, embedding_dim)
+
+        y = self.output_projection(y)
+
+        return y
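Reviewer note (not part of the diff): a minimal sketch of how the new attention pieces compose. The tensor sizes are arbitrary illustration values; shapes mirror the unit tests added further down.

```python
import torch

from merlin.models.torch.blocks.attention import (
    AttentionMask,
    CausalSelfAttention,
    RotaryEmbeddings,
    create_attention_mask,
)

batch_size, seq_length, num_heads, embedding_dim = 2, 6, 2, 8
head_size = embedding_dim // num_heads

# RoPE rotates (batch, seq, heads, head_size) tensors; its cos/sin cache is built lazily.
rope = RotaryEmbeddings(head_size, max_seq_length=seq_length)
queries = torch.randn(batch_size, seq_length, num_heads, head_size)
assert rope(queries).shape == queries.shape

# Causal self-attention over (batch, seq, embedding_dim), optionally with a causal mask.
attention = CausalSelfAttention(
    num_heads=num_heads,
    embedding_dim=embedding_dim,
    max_seq_length=seq_length,
    rotary_embeds=rope,
)
mask = AttentionMask(create_attention_mask(max_seq_length=seq_length))
hidden = torch.randn(batch_size, seq_length, embedding_dim)
out = attention(hidden, mask=mask)
assert out.shape == hidden.shape
```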
diff --git a/merlin/models/torch/blocks/llama.py b/merlin/models/torch/blocks/llama.py
new file mode 100644
index 0000000000..29caf75881
--- /dev/null
+++ b/merlin/models/torch/blocks/llama.py
@@ -0,0 +1,237 @@
+from dataclasses import dataclass
+from typing import Dict, Optional, TypeVar, Union
+
+import torch
+import torch.nn as nn
+
+from merlin.models.torch.batch import Batch
+from merlin.models.torch.block import Block
+from merlin.models.torch.blocks.attention import (
+    AttentionMask,
+    CausalSelfAttention,
+    RotaryEmbeddings,
+    create_attention_mask,
+)
+from merlin.models.torch.blocks.mlp import PositionwiseFeedForward
+from merlin.models.torch.transforms.regularization import RMSNorm
+from merlin.models.torch.utils.llama_utils import (
+    convert_checkpoint,
+    find_multiple,
+    llama_model_lookup,
+)
+from merlin.models.utils.doc_utils import docstring_parameter
+
+Self = TypeVar("Self", bound="LlamaBlock")
+
+_LLAMA_REF = """
+    .. [1] Touvron, et al., "LLaMA: Open and Efficient Foundation Language Models".
+       arXiv preprint arXiv:2302.13971 (2023).
+"""
+
+
+@dataclass
+class LlamaConfig:
+    max_seq_length: int = 2048
+    vocab_size: int = 32_000
+    padded_vocab_size: Optional[int] = None
+    num_layers: int = 32
+    num_heads: int = 32
+    embedding_dim: int = 4096
+
+    def __post_init__(self):
+        if self.padded_vocab_size is None:
+            self.padded_vocab_size = find_multiple(self.vocab_size, 64)
+
+    @classmethod
+    def from_name(cls, name: str) -> Self:
+        return cls(**LLAMA_CONFIGS[name])
+
+
+LLAMA_CONFIGS = {
+    "7B": dict(num_layers=32, num_heads=32, embedding_dim=4096),
+    "13B": dict(num_layers=40, num_heads=40, embedding_dim=5120),
+    "30B": dict(num_layers=60, num_heads=52, embedding_dim=6656),
+    "65B": dict(num_layers=80, num_heads=64, embedding_dim=8192),
+}
+
+
+class _LlamaBaseBlock(Block):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        token_key: Optional[str] = None,
+        position_key: Optional[str] = None,
+    ) -> None:
+        super().__init__()
+
+        assert config.padded_vocab_size is not None
+
+        self.config = config
+        self.token_key = token_key or "token"
+        self.position_key = position_key or "position"
+
+    @classmethod
+    def from_name(cls, model_size: str) -> Self:
+        return cls(LlamaConfig.from_name(model_size))
+
+    def reset_cache(self) -> None:
+        raise NotImplementedError
+
+    def get_tokens(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
+        return inputs[self.token_key]
+
+    def get_positions(self, inputs: Dict[str, torch.Tensor]) -> Optional[torch.Tensor]:
+        return inputs.get(self.position_key)
+
+
+@docstring_parameter(llama_reference=_LLAMA_REF)
+class LlamaBlock(_LlamaBaseBlock):
+    """LLaMA [1].
+
+    References
+    ----------
+    {llama_reference}
+    """
+
+    def __init__(
+        self,
+        config: LlamaConfig,
+        token_key: Optional[str] = None,
+        position_key: Optional[str] = None,
+    ) -> None:
+        super().__init__(
+            config=config,
+            token_key=token_key,
+            position_key=position_key,
+        )
+        self.transformer = LlamaTransformer(config)
+        self.output_embeddings = nn.Linear(
+            config.embedding_dim, config.padded_vocab_size, bias=False
+        )
+
+    def forward(
+        self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], batch: Optional[Batch] = None
+    ):
+        outputs = self.transformer(inputs)
+        logits = self.output_embeddings(outputs)
+        return logits
+
+    @classmethod
+    def from_checkpoint(
+        cls,
+        checkpoint_dir,
+        model_size: Optional[str] = None,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        state_dict = convert_checkpoint(checkpoint_dir, model_size)
+        model_size = model_size or llama_model_lookup(state_dict)
+        model = cls.from_name(model_size)
+        model.load_state_dict(state_dict)
+        return model
+
+    def reset_cache(self) -> None:
+        for head in self.transformer.heads:
+            head.attention.kv_cache = None
+
+
+class LlamaTransformer(_LlamaBaseBlock):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        token_key: Optional[str] = None,
+        position_key: Optional[str] = None,
+    ) -> None:
+        super().__init__(
+            config=config,
+            token_key=token_key,
+            position_key=position_key,
+        )
+
+        self.rotary_embeds = RotaryEmbeddings(
+            self.config.embedding_dim // self.config.num_heads,
+            self.config.max_seq_length,
+        )
+        self.mask_cache = AttentionMask(
+            create_attention_mask(max_seq_length=self.config.max_seq_length)
+        )
+
+        self.token_embeddings = nn.Embedding(config.padded_vocab_size, config.embedding_dim)
+        self.heads = nn.ModuleList(
+            LlamaAttentionHead(
+                num_heads=self.config.num_heads,
+                embedding_dim=self.config.embedding_dim,
+                max_seq_length=self.config.max_seq_length,
+                rotary_embeds=self.rotary_embeds,
+            )
+            for _ in range(config.num_layers)
+        )
+        self.layernorm = RMSNorm(config.embedding_dim)
+
+    def forward(
+        self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], batch: Optional[Batch] = None
+    ):
+        if isinstance(inputs, torch.Tensor):
+            tokens, positions = inputs, None
+        else:
+            tokens, positions = self.get_tokens(inputs), self.get_positions(inputs)
+
+        batch_size, seq_length = tokens.size()
+
+        x = self.token_embeddings(tokens)
+
+        for head in self.heads:
+            x = head(
+                x,
+                positions=positions,
+                mask=self.mask_cache,
+            )
+
+        x = self.layernorm(x)
+
+        return x
+
+    def reset_cache(self) -> None:
+        for head in self.heads:
+            head.attention.kv_cache = None
+
+
+class LlamaAttentionHead(nn.Module):
+    def __init__(
+        self,
+        num_heads: int,
+        embedding_dim: int,
+        max_seq_length: int,
+        rotary_embeds: Optional[RotaryEmbeddings] = None,
+    ) -> None:
+        super().__init__()
+
+        self.num_heads = num_heads
+        self.embedding_dim = embedding_dim
+        self.max_seq_length = max_seq_length
+        self.rotary_embeds = rotary_embeds
+
+        self.input_layernorm = RMSNorm(self.embedding_dim)
+        self.attention = CausalSelfAttention(
+            num_heads=self.num_heads,
+            embedding_dim=self.embedding_dim,
+            max_seq_length=self.max_seq_length,
+            rotary_embeds=self.rotary_embeds,
+        )
+        self.post_attention_layernorm = RMSNorm(self.embedding_dim)
+
+        self.mlp = PositionwiseFeedForward(self.embedding_dim, bias=False, activation=nn.SiLU)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        positions: Optional[torch.Tensor] = None,
+        mask: Optional[AttentionMask] = None,
+    ) -> torch.Tensor:
+        x = x + self.attention(
+            self.input_layernorm(x),
+            positions=positions,
+            mask=mask,
+        )
+        x = x + self.mlp(self.post_attention_layernorm(x))
+        return x
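Reviewer note (not part of the diff): a short usage sketch for the new public API exported from `merlin.models.torch`. The config values mirror the small test config below; the checkpoint path in the comment is a placeholder.

```python
import torch

import merlin.models.torch as mm

# A small config for illustration; the dataclass defaults correspond to the 7B architecture.
config = mm.LlamaConfig(
    max_seq_length=64,
    vocab_size=100,
    num_layers=1,
    num_heads=1,
    embedding_dim=16,
)
model = mm.LlamaBlock(config)

tokens = torch.tensor([[1, 3, 36, 2, 10]])     # (batch, seq)
positions = torch.tensor([0, 1, 2, 3, 4])      # optional; enables the KV cache

logits = model({"token": tokens, "position": positions})
assert logits.shape == (1, 5, config.padded_vocab_size)  # vocab padded to a multiple of 64

# Pretrained weights can be loaded from a converted checkpoint (placeholder path):
# model = mm.LlamaBlock.from_checkpoint("/path/to/llama/7B", model_size="7B")
```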
diff --git a/merlin/models/torch/blocks/mlp.py b/merlin/models/torch/blocks/mlp.py
index 8038dc89f7..b5784d1d98 100644
--- a/merlin/models/torch/blocks/mlp.py
+++ b/merlin/models/torch/blocks/mlp.py
@@ -4,8 +4,11 @@
 from torch import nn
 
 from merlin.models.torch.block import Block
+from merlin.models.torch.blocks.attention import _TRANSFORMER_REF
 from merlin.models.torch.schema import Schema, output_schema
 from merlin.models.torch.transforms.agg import Concat, MaybeAgg
+from merlin.models.torch.utils.llama_utils import find_multiple
+from merlin.models.utils.doc_utils import docstring_parameter
 
 
 class MLPBlock(Block):
@@ -84,8 +87,42 @@
         super().__init__(*modules)
 
 
+@docstring_parameter(transformer_ref=_TRANSFORMER_REF)
+class PositionwiseFeedForward(nn.Module):
+    """Position-wise Feed-Forward network as proposed in Section 3.3 of [1].
+
+    References
+    ----------
+    {transformer_ref}
+    """
+
+    def __init__(
+        self,
+        embedding_dim: int,
+        intermediate_dim: Optional[int] = None,
+        bias: bool = False,
+        activation=nn.ReLU,
+    ):
+        super().__init__()
+
+        if intermediate_dim is None:
+            hidden_dim = 4 * embedding_dim
+            intermediate_dim = int(2 * hidden_dim / 3)
+            intermediate_dim = find_multiple(intermediate_dim, 256)
+
+        self.weights_1 = nn.Linear(embedding_dim, intermediate_dim, bias=bias)
+        self.weights_2 = nn.Linear(embedding_dim, intermediate_dim, bias=bias)
+        self.projection = nn.Linear(intermediate_dim, embedding_dim, bias=bias)
+        self.activation = activation if isinstance(activation, nn.Module) else activation()
+
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+        outputs = self.projection(self.activation(self.weights_1(inputs)) * self.weights_2(inputs))
+        return outputs
+
+
 @output_schema.register(nn.LazyLinear)
 @output_schema.register(nn.Linear)
 @output_schema.register(MLPBlock)
+@output_schema.register(PositionwiseFeedForward)
 def _output_schema_block(module: nn.LazyLinear, inputs: Schema):
     return output_schema.tensors(torch.ones((1, module.out_features), dtype=float))
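Reviewer note (not part of the diff): `PositionwiseFeedForward` is the gated feed-forward variant (SwiGLU when `activation=nn.SiLU`, which is how `LlamaAttentionHead` instantiates it). It computes `projection(activation(weights_1(x)) * weights_2(x))`, and when `intermediate_dim` is omitted it is derived as `find_multiple(int(2 * 4 * embedding_dim / 3), 256)`. A quick sketch with illustrative sizes:

```python
import torch
from torch import nn

from merlin.models.torch.blocks.mlp import PositionwiseFeedForward

ffn = PositionwiseFeedForward(32, activation=nn.SiLU)
# 4 * 32 = 128 -> int(2 * 128 / 3) = 85 -> rounded up to the next multiple of 256 = 256
assert ffn.weights_1.out_features == 256
# (For embedding_dim=4096: 16384 -> 10922 -> 11008, the LLaMA-7B feed-forward width.)

x = torch.randn(2, 10, 32)
y = ffn(x)  # projection(SiLU(weights_1(x)) * weights_2(x))
assert y.shape == x.shape
```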
+ converted[f"transformer.heads.{layer_idx}.attention.qkv_projection.weight"] = torch.cat( + ( + state_dict[f"layers.{layer_idx}.attention.wq.weight"].to(dtype), + state_dict[f"layers.{layer_idx}.attention.wk.weight"].to(dtype), + state_dict[f"layers.{layer_idx}.attention.wv.weight"].to(dtype), + ) + ) + converted[f"transformer.heads.{layer_idx}.attention.output_projection.weight"] = state_dict[ + f"layers.{layer_idx}.attention.wo.weight" + ].to(dtype) + # mlp + converted[f"transformer.heads.{layer_idx}.mlp.weights_1.weight"] = state_dict[ + f"layers.{layer_idx}.feed_forward.w1.weight" + ].to(dtype) + converted[f"transformer.heads.{layer_idx}.mlp.projection.weight"] = state_dict[ + f"layers.{layer_idx}.feed_forward.w2.weight" + ].to(dtype) + converted[f"transformer.heads.{layer_idx}.mlp.weights_2.weight"] = state_dict[ + f"layers.{layer_idx}.feed_forward.w3.weight" + ].to(dtype) + # rms norm + converted[f"transformer.heads.{layer_idx}.input_layernorm.scale"] = state_dict[ + f"layers.{layer_idx}.attention_norm.weight" + ].to(dtype) + converted[f"transformer.heads.{layer_idx}.post_attention_layernorm.scale"] = state_dict[ + f"layers.{layer_idx}.ffn_norm.weight" + ].to(dtype) + return converted + + +shard_dims = { + "output_embeddings.weight": 0, + "token_embeddings.weight": 1, + "attention.qkv_projection.weight": 0, + "attention.output_projection.weight": 1, + "mlp.weights_1.weight": 0, + "mlp.weights_2.weight": 0, + "mlp.projection.weight": 1, +} + + +def convert_checkpoint( + checkpoint_dir, + model_size: str = "7B", + dtype: str = "float32", +) -> None: + if isinstance(checkpoint_dir, str): + checkpoint_dir = Path(checkpoint_dir) + + dt = getattr(torch, dtype, None) + if not isinstance(dt, torch.dtype): + raise ValueError(f"{dtype} is not a valid dtype.") + dtype = dt + + checkpoint_files = sorted(checkpoint_dir.glob("*.pth")) + checkpoint_files.sort() + n_checkpoints = len(checkpoint_files) + + if n_checkpoints == 0: + raise RuntimeError( + f"No checkpoints were found at checkpoint_dir {checkpoint_dir}." + " `consolidated.0*.pth` files expected at that location." 
diff --git a/merlin/models/torch/utils/llama_utils.py b/merlin/models/torch/utils/llama_utils.py
new file mode 100644
index 0000000000..e020acff9a
--- /dev/null
+++ b/merlin/models/torch/utils/llama_utils.py
@@ -0,0 +1,142 @@
+import gc
+from pathlib import Path
+from typing import Dict
+
+import torch
+from tqdm import tqdm
+
+
+def llama_model_lookup(checkpoint: dict) -> str:
+    """Returns the LLaMA model name from the checkpoint."""
+    from merlin.models.torch.blocks.llama import LLAMA_CONFIGS
+
+    embedding_dim = checkpoint["transformer.token_embeddings.weight"].shape[1]
+    for name, configs in LLAMA_CONFIGS.items():
+        if configs["embedding_dim"] == embedding_dim:
+            return name
+
+    raise RuntimeError("Could not find model name from checkpoint.")
+
+
+def convert_state_dict(
+    state_dict: Dict[str, torch.Tensor], dtype: torch.dtype = torch.float32
+) -> Dict[str, torch.Tensor]:
+    converted = {}
+    converted["transformer.token_embeddings.weight"] = state_dict["tok_embeddings.weight"].to(dtype)
+    converted["output_embeddings.weight"] = state_dict["output.weight"].to(dtype)
+    converted["transformer.layernorm.scale"] = state_dict["norm.weight"].to(dtype)
+
+    for layer_idx in sorted(set([k.split(".")[1] for k in state_dict if k.startswith("layers")])):
+        # attention
+        # the wq, wk, wv from the FB model are stacked in our model.
+        converted[f"transformer.heads.{layer_idx}.attention.qkv_projection.weight"] = torch.cat(
+            (
+                state_dict[f"layers.{layer_idx}.attention.wq.weight"].to(dtype),
+                state_dict[f"layers.{layer_idx}.attention.wk.weight"].to(dtype),
+                state_dict[f"layers.{layer_idx}.attention.wv.weight"].to(dtype),
+            )
+        )
+        converted[f"transformer.heads.{layer_idx}.attention.output_projection.weight"] = state_dict[
+            f"layers.{layer_idx}.attention.wo.weight"
+        ].to(dtype)
+        # mlp
+        converted[f"transformer.heads.{layer_idx}.mlp.weights_1.weight"] = state_dict[
+            f"layers.{layer_idx}.feed_forward.w1.weight"
+        ].to(dtype)
+        converted[f"transformer.heads.{layer_idx}.mlp.projection.weight"] = state_dict[
+            f"layers.{layer_idx}.feed_forward.w2.weight"
+        ].to(dtype)
+        converted[f"transformer.heads.{layer_idx}.mlp.weights_2.weight"] = state_dict[
+            f"layers.{layer_idx}.feed_forward.w3.weight"
+        ].to(dtype)
+        # rms norm
+        converted[f"transformer.heads.{layer_idx}.input_layernorm.scale"] = state_dict[
+            f"layers.{layer_idx}.attention_norm.weight"
+        ].to(dtype)
+        converted[f"transformer.heads.{layer_idx}.post_attention_layernorm.scale"] = state_dict[
+            f"layers.{layer_idx}.ffn_norm.weight"
+        ].to(dtype)
+    return converted
+
+
+shard_dims = {
+    "output_embeddings.weight": 0,
+    "token_embeddings.weight": 1,
+    "attention.qkv_projection.weight": 0,
+    "attention.output_projection.weight": 1,
+    "mlp.weights_1.weight": 0,
+    "mlp.weights_2.weight": 0,
+    "mlp.projection.weight": 1,
+}
+
+
+def convert_checkpoint(
+    checkpoint_dir,
+    model_size: str = "7B",
+    dtype: str = "float32",
+) -> Dict[str, torch.Tensor]:
+    if isinstance(checkpoint_dir, str):
+        checkpoint_dir = Path(checkpoint_dir)
+
+    dt = getattr(torch, dtype, None)
+    if not isinstance(dt, torch.dtype):
+        raise ValueError(f"{dtype} is not a valid dtype.")
+    dtype = dt
+
+    checkpoint_files = sorted(checkpoint_dir.glob("*.pth"))
+    checkpoint_files.sort()
+    n_checkpoints = len(checkpoint_files)
+
+    if n_checkpoints == 0:
+        raise RuntimeError(
+            f"No checkpoints were found at checkpoint_dir {checkpoint_dir}."
+            " `consolidated.0*.pth` files expected at that location."
+        )
+
+    # for the bigger models, there are multiple model-parallel checkpoints
+    # and we combine them into one single file
+    combined = None
+    for file in tqdm(checkpoint_files, total=n_checkpoints):
+        checkpoint = torch.load(file, map_location="cpu")
+        converted = convert_state_dict(checkpoint, dtype=dtype)
+        if combined is None:
+            combined = converted
+            continue
+        for name, param in converted.items():
+            dim = None
+            for k, d in shard_dims.items():
+                if k in name:
+                    dim = d
+                    break
+            if dim is None:
+                continue
+            combined[name] = torch.cat((combined[name], param), dim=dim)
+
+        del checkpoint
+        del converted
+        gc.collect()
+
+    for name, param in combined.items():
+        if "c_attn" not in name:
+            continue
+
+        src_chunk_len = param.shape[0] // n_checkpoints
+        mat_len = src_chunk_len // 3
+        dst_chunk_len = mat_len * n_checkpoints
+        attn = torch.clone(param)
+        for i in range(n_checkpoints):
+            for j in range(3):
+                param[
+                    j * dst_chunk_len + i * mat_len : j * dst_chunk_len + (i + 1) * mat_len
+                ] = attn[i * src_chunk_len + j * mat_len : i * src_chunk_len + (j + 1) * mat_len]
+
+        del attn
+        gc.collect()
+
+    return combined
+
+
+def find_multiple(n: int, k: int) -> int:
+    if n % k == 0:
+        return n
+    return n + k - (n % k)
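Reviewer note (not part of the diff): `convert_checkpoint` merges the sharded `consolidated.*.pth` files from Meta's release, renames the keys to this module's layout, and returns a single state dict; `find_multiple` is the rounding helper used for vocabulary padding. A hedged sketch, where the checkpoint directory is a placeholder path:

```python
from merlin.models.torch.blocks.llama import LlamaBlock
from merlin.models.torch.utils.llama_utils import (
    convert_checkpoint,
    find_multiple,
    llama_model_lookup,
)

# find_multiple rounds up to the next multiple of k, e.g. for padding the vocabulary size.
assert find_multiple(32_000, 64) == 32_000
assert find_multiple(100, 64) == 128

# "/checkpoints/llama/7B" is a placeholder directory containing consolidated.*.pth shards.
state_dict = convert_checkpoint("/checkpoints/llama/7B", model_size="7B")
model_size = llama_model_lookup(state_dict)  # inferred from the embedding width, e.g. "7B"

model = LlamaBlock.from_name(model_size)
model.load_state_dict(state_dict)
# LlamaBlock.from_checkpoint("/checkpoints/llama/7B") wraps these same steps.
```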
diff --git a/tests/unit/torch/blocks/test_attention.py b/tests/unit/torch/blocks/test_attention.py
index 0c8926e285..cc38455a4b 100644
--- a/tests/unit/torch/blocks/test_attention.py
+++ b/tests/unit/torch/blocks/test_attention.py
@@ -2,7 +2,11 @@
 import torch
 from torch import nn
 
-from merlin.models.torch.blocks.attention import CrossAttentionBlock
+from merlin.models.torch.blocks.attention import (
+    CausalSelfAttention,
+    CrossAttentionBlock,
+    RotaryEmbeddings,
+)
 from merlin.models.torch.utils import module_utils
 
 
@@ -48,3 +52,29 @@ def test_get_seq_error(self):
         cross.get_seq(
             {"context": torch.randn(1, 10), "0": torch.randn(1, 10), "1": torch.randn(1, 10)}
         )
+
+
+class TestRotaryEmbeddings:
+    def test_forward(self):
+        batch_size, seq_length, num_heads, embedding_dim = 1, 6, 2, 8
+        rotary_embeds = RotaryEmbeddings(embedding_dim // num_heads, seq_length)
+        inputs = torch.randint(
+            0,
+            10_000,
+            size=(batch_size, seq_length, num_heads, embedding_dim // num_heads),
+        ).float()
+        outputs = rotary_embeds(inputs)
+        assert inputs.size() == outputs.size()
+
+
+class TestCausalSelfAttention:
+    def test_forward(self):
+        batch_size, seq_length, num_heads, embedding_dim = 1, 6, 2, 8
+        attention = CausalSelfAttention(num_heads, embedding_dim, seq_length)
+        inputs = torch.randint(
+            0,
+            10_000,
+            size=(batch_size, seq_length, embedding_dim),
+        ).float()
+        outputs = attention(inputs)
+        assert inputs.size() == outputs.size()
diff --git a/tests/unit/torch/blocks/test_llama.py b/tests/unit/torch/blocks/test_llama.py
new file mode 100644
index 0000000000..b7fb95279d
--- /dev/null
+++ b/tests/unit/torch/blocks/test_llama.py
@@ -0,0 +1,122 @@
+import torch
+
+import merlin.models.torch as mm
+from merlin.models.torch.blocks.llama import LlamaAttentionHead, LlamaTransformer
+from merlin.models.torch.utils import module_utils
+
+
+class TestLlamaBlock:
+    def setup_method(self):
+        self.llama_config = mm.LlamaConfig(
+            max_seq_length=64,
+            vocab_size=100,
+            num_layers=1,
+            num_heads=1,
+            embedding_dim=16,
+        )
+        self.llama = mm.LlamaBlock(self.llama_config)
+        self.input_dict = {
+            "token": torch.tensor([[1, 3, 36, 2, 10]]),
+            "position": torch.tensor([0, 1, 2, 3, 4]),
+        }
+
+    def test_forward(self):
+        assert "position" in self.input_dict
+        out = self.llama(self.input_dict)
+        assert isinstance(out, torch.Tensor)
+        assert out.shape[:-1] == self.input_dict["token"].shape
+        assert out.shape[-1] == self.llama_config.padded_vocab_size
+
+    def test_forward_without_position(self):
+        self.input_dict.pop("position")
+        assert "position" not in self.input_dict
+        out = self.llama(self.input_dict)
+        assert isinstance(out, torch.Tensor)
+        assert out.shape[:-1] == self.input_dict["token"].shape
+        assert out.shape[-1] == self.llama_config.padded_vocab_size
+
+    def test_forward_tensor(self):
+        assert "position" in self.input_dict
+        out = self.llama(self.input_dict["token"])
+        assert isinstance(out, torch.Tensor)
+        assert out.shape[:-1] == self.input_dict["token"].shape
+        assert out.shape[-1] == self.llama_config.padded_vocab_size
+
+    def test_forward_torchscript(self):
+        assert "position" in self.input_dict
+        out = module_utils.module_test(self.llama, self.input_dict)
+        assert isinstance(out, torch.Tensor)
+        assert out.shape[:-1] == self.input_dict["token"].shape
+        assert out.shape[-1] == self.llama_config.padded_vocab_size
+
+    def test_reset_cache(self):
+        _ = self.llama(self.input_dict)
+        assert all(h.attention.kv_cache is not None for h in self.llama.transformer.heads)
+        self.llama.reset_cache()
+        assert all(h.attention.kv_cache is None for h in self.llama.transformer.heads)
+
+
+class TestLlamaTransformer:
+    def setup_method(self):
+        self.llama_config = mm.LlamaConfig(
+            max_seq_length=64,
+            vocab_size=100,
+            num_layers=1,
+            num_heads=1,
+            embedding_dim=16,
+        )
+        self.transformer = LlamaTransformer(self.llama_config)
+        self.input_dict = {
+            "token": torch.tensor([[1, 3, 36, 2, 10]]),
+            "position": torch.tensor([0, 1, 2, 3, 4]),
+        }
+
+    def test_forward(self):
+        assert "position" in self.input_dict
+        out = self.transformer(self.input_dict)
+        assert isinstance(out, torch.Tensor)
+        assert out.shape[:-1] == self.input_dict["token"].shape
+        assert out.shape[-1] == self.llama_config.embedding_dim
+
+    def test_forward_without_position(self):
+        self.input_dict.pop("position")
+        assert "position" not in self.input_dict
+        out = self.transformer(self.input_dict)
+        assert isinstance(out, torch.Tensor)
+        assert out.shape[:-1] == self.input_dict["token"].shape
+        assert out.shape[-1] == self.llama_config.embedding_dim
+
+    def test_forward_tensor(self):
+        assert "position" in self.input_dict
+        out = self.transformer(self.input_dict["token"])
+        assert isinstance(out, torch.Tensor)
+        assert out.shape[:-1] == self.input_dict["token"].shape
+        assert out.shape[-1] == self.llama_config.embedding_dim
+
+    def test_forward_torchscript(self):
+        assert "position" in self.input_dict
+        out = module_utils.module_test(self.transformer, self.input_dict)
+        assert isinstance(out, torch.Tensor)
+        assert out.shape[:-1] == self.input_dict["token"].shape
+        assert out.shape[-1] == self.llama_config.embedding_dim
+
+    def test_reset_cache(self):
+        _ = self.transformer(self.input_dict)
+        assert all(h.attention.kv_cache is not None for h in self.transformer.heads)
+        self.transformer.reset_cache()
+        assert all(h.attention.kv_cache is None for h in self.transformer.heads)
+
+
+class TestLlamaAttentionHead:
+    def test_forward(self):
+        batch_size = 2
+        embedding_dim = 16
+        max_seq_length = 64
+        attn_head = LlamaAttentionHead(
+            num_heads=1,
+            embedding_dim=embedding_dim,
+            max_seq_length=max_seq_length,
+        )
+        inputs = torch.randn(batch_size, max_seq_length, embedding_dim)
+        outputs = attn_head(inputs)
+        assert outputs.size() == inputs.size()
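Reviewer note (not part of the diff): the `position` input and the per-layer `KeyValueCache` exist to support incremental decoding: the prompt is fed once with positions 0..n-1, after which only the newest token and its position need to be fed. A greedy-decoding sketch under that assumption (the loop itself is illustrative, not part of this PR):

```python
import torch

import merlin.models.torch as mm

# Small hypothetical config; a real run would load converted LLaMA weights instead.
model = mm.LlamaBlock(
    mm.LlamaConfig(max_seq_length=64, vocab_size=100, num_layers=2, num_heads=2, embedding_dim=32)
)
model.eval()

prompt = torch.tensor([[1, 3, 36, 2, 10]])  # (batch=1, seq)
positions = torch.arange(prompt.shape[1])

with torch.no_grad():
    # Prime the per-layer KV caches with the full prompt.
    logits = model({"token": prompt, "position": positions})
    next_token = logits[:, -1].argmax(-1, keepdim=True)  # greedy, for illustration

    # Afterwards, only the newest token and its position are needed.
    for step in range(prompt.shape[1], 10):
        position = torch.tensor([step])
        logits = model({"token": next_token, "position": position})
        next_token = logits[:, -1].argmax(-1, keepdim=True)

model.reset_cache()  # drop the KV caches before processing a new sequence
```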
diff --git a/tests/unit/torch/blocks/test_mlp.py b/tests/unit/torch/blocks/test_mlp.py
index 6c69c06fff..66dd8e2176 100644
--- a/tests/unit/torch/blocks/test_mlp.py
+++ b/tests/unit/torch/blocks/test_mlp.py
@@ -3,7 +3,7 @@
 from torch import nn
 
 from merlin.models.torch.block import Block
-from merlin.models.torch.blocks.mlp import MLPBlock
+from merlin.models.torch.blocks.mlp import MLPBlock, PositionwiseFeedForward
 from merlin.models.torch.utils import module_utils
 
 
@@ -68,3 +68,35 @@ def test_forward(self):
         inputs = {"a": torch.randn(32, 2), "b": torch.randn(32, 2)}
         outputs = module_utils.module_test(mlp, inputs)
         assert outputs.shape == torch.Size([32, 32])
+
+
+class TestPositionwiseFeedForward:
+    def test_forward(self):
+        mlp = PositionwiseFeedForward(32)
+        inputs = torch.randn(16, 32)
+        outputs = mlp(inputs)
+        assert inputs.size() == outputs.size()
+
+    def test_hidden_layer(self):
+        hidden_dim = 256
+        mlp = PositionwiseFeedForward(32, intermediate_dim=hidden_dim)
+        inputs = torch.randn(16, 32)
+        assert mlp.weights_1(inputs).size()[-1] == hidden_dim
+        assert mlp.weights_2(inputs).size()[-1] == hidden_dim
+
+    def test_bias(self):
+        mlp = PositionwiseFeedForward(32)
+        assert mlp.weights_1.bias is None
+        assert mlp.weights_2.bias is None
+
+        mlp = PositionwiseFeedForward(32, bias=True)
+        assert isinstance(mlp.weights_1.bias, torch.Tensor)
+        assert isinstance(mlp.weights_2.bias, torch.Tensor)
+
+    def test_activation(self):
+        silu = nn.SiLU()
+        mlp = PositionwiseFeedForward(32, activation=silu)
+        inputs = torch.randn(16, 32)
+        outputs = mlp(inputs)
+        expected = mlp.projection(silu(mlp.weights_1(inputs)) * mlp.weights_2(inputs))
+        assert torch.allclose(outputs, expected)
diff --git a/tests/unit/torch/transforms/test_regularization.py b/tests/unit/torch/transforms/test_regularization.py
new file mode 100644
index 0000000000..1cfdae64e1
--- /dev/null
+++ b/tests/unit/torch/transforms/test_regularization.py
@@ -0,0 +1,17 @@
+import torch
+
+from merlin.models.torch.transforms.regularization import RMSNorm
+
+
+class TestRMSNorm:
+    def test_init(self):
+        eps = 2e-5
+        rms_norm = RMSNorm(8, eps=eps)
+        assert isinstance(rms_norm.scale, torch.nn.Parameter)
+        assert rms_norm.eps == eps
+
+    def test_forward(self):
+        rms_norm = RMSNorm(8)
+        inputs = torch.randn(2, 4, 8)
+        outputs = rms_norm(inputs)
+        assert inputs.size() == outputs.size()