diff --git a/tests/layers/test_packed.py b/tests/layers/test_packed.py
index bf1fd7ac..cfcee746 100644
--- a/tests/layers/test_packed.py
+++ b/tests/layers/test_packed.py
@@ -2,6 +2,10 @@
 import torch
 from einops import repeat

+from torch_uncertainty.layers.functional.packed import (
+    packed_in_projection_packed,
+    packed_multi_head_attention_forward,
+)
 from torch_uncertainty.layers.packed import (
     PackedConv1d,
     PackedConv2d,
@@ -9,6 +13,8 @@
     PackedLayerNorm,
     PackedLinear,
     PackedMultiheadAttention,
+    PackedTransformerDecoderLayer,
+    PackedTransformerEncoderLayer,
 )
@@ -67,6 +73,12 @@ def batched_qkv() -> torch.Tensor:
     return torch.rand((2, 3, 6))


+@pytest.fixture()
+def extended_batched_qkv() -> torch.Tensor:
+    expansion = 2
+    return torch.rand((2, 3, 6 * expansion))
+
+
 @pytest.fixture()
 def batched_q_kv() -> tuple[torch.Tensor, torch.Tensor]:
     return (
@@ -84,6 +96,38 @@ def batched_q_k_v() -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     )


+@pytest.fixture()
+def extended_batched_q_k_v() -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    expansion = 2
+    return (
+        torch.rand((2, 3, 6 * expansion)),
+        torch.rand((2, 4, 2 * expansion)),
+        torch.rand((2, 4, 4 * expansion)),
+    )
+
+
+@pytest.fixture()
+def unbatched_tgt_memory() -> tuple[torch.Tensor, torch.Tensor]:
+    return torch.rand((3, 6)), torch.rand((4, 6))
+
+
+@pytest.fixture()
+def batched_tgt_memory() -> tuple[torch.Tensor, torch.Tensor]:
+    return (
+        torch.rand((2, 3, 6)),
+        torch.rand((2, 4, 6)),
+    )
+
+
+@pytest.fixture()
+def extended_batched_tgt_memory() -> tuple[torch.Tensor, torch.Tensor]:
+    expansion = 2
+    return (
+        torch.rand((2, 3, 6 * expansion)),
+        torch.rand((2, 4, 6 * expansion)),
+    )
+
+
 class TestPackedLinear:
     """Testing the PackedLinear layer class."""
@@ -113,7 +157,7 @@ def test_linear_two_estimator_rearrange_not_divisible(self):
     def test_linear_full_implementation(
         self, feat_input_16_features: torch.Tensor, feat_multi_dim: torch.Tensor
     ):
-        layer = PackedLinear(16, 4, alpha=1, num_estimators=1, implementation="full")
+        layer = PackedLinear(16, 4, alpha=1, num_estimators=1, implementation="full", bias=False)
         out = layer(feat_input_16_features)
         assert out.shape == torch.Size([3, 4])
         layer = PackedLinear(16, 4, alpha=1, num_estimators=2, implementation="full")
@@ -330,9 +374,11 @@ class TestPackedMultiheadAttention:
     """Testing the PackedMultiheadAttention layer class."""

     def test_one_estimator_qkv(self, unbatched_qkv: torch.Tensor, batched_qkv: torch.Tensor):
+        attn_mask = torch.zeros(1, 3, 3, dtype=torch.bool)
+
         layer = PackedMultiheadAttention(
             embed_dim=6,
-            num_heads=2,
+            num_heads=1,
             alpha=1,
             num_estimators=1,
         )
@@ -340,14 +386,18 @@ def test_one_estimator_qkv(self, unbatched_qkv: torch.Tensor, batched_qkv: torch
             query=unbatched_qkv,
             key=unbatched_qkv,
             value=unbatched_qkv,
+            attn_mask=attn_mask,
         )
         assert out.shape == torch.Size([3, 6])

         unbatched_qkv = repeat(unbatched_qkv, "l h -> l b h", b=2)
+        attn_mask = torch.zeros(2, 3, 3, dtype=torch.bool)
         out, _ = layer(
             query=unbatched_qkv,
             key=unbatched_qkv,
             value=unbatched_qkv,
+            attn_mask=attn_mask,
+            is_causal=True,
         )
         assert out.shape == torch.Size([3, 2, 6])
@@ -402,6 +452,7 @@ def test_one_estimator_q_kv(
             kdim=2,
             vdim=2,
             batch_first=True,
+            bias=False,
         )
         out, _ = layer(
             query=batched_q_kv[0],
@@ -417,17 +468,21 @@ def test_one_estimator_q_k_v(
     ):
         layer = PackedMultiheadAttention(
             embed_dim=6,
-            num_heads=2,
+            num_heads=1,
             alpha=1,
             num_estimators=1,
             kdim=2,
             vdim=4,
             add_bias_kv=True,
         )
+
+        key_padding_mask = torch.zeros(4, dtype=torch.bool)
+
         out, _ = layer(
             query=unbatched_q_k_v[0],
             key=unbatched_q_k_v[1],
             value=unbatched_q_k_v[2],
+            key_padding_mask=key_padding_mask,
         )
         assert out.shape == torch.Size([3, 6])
@@ -465,10 +520,459 @@ def test_one_estimator_q_k_v(
         assert out.shape == torch.Size([2, 3, 6])
         assert out.isfinite().all()

+    def test_two_estimators_qkv(self, unbatched_qkv: torch.Tensor, batched_qkv: torch.Tensor):
+        layer = PackedMultiheadAttention(
+            embed_dim=6,
+            num_heads=3,
+            alpha=1,
+            num_estimators=2,
+        )
+        out, _ = layer(
+            query=unbatched_qkv,
+            key=unbatched_qkv,
+            value=unbatched_qkv,
+        )
+        assert out.shape == torch.Size([3, 6])
+
+        unbatched_qkv = repeat(unbatched_qkv, "l h -> l b h", b=2)
+        out, _ = layer(
+            query=unbatched_qkv,
+            key=unbatched_qkv,
+            value=unbatched_qkv,
+        )
+        assert out.shape == torch.Size([3, 2, 6])
+
+        layer = PackedMultiheadAttention(
+            embed_dim=6,
+            num_heads=3,
+            alpha=1,
+            num_estimators=2,
+            batch_first=True,
+        )
+        out, _ = layer(
+            query=batched_qkv,
+            key=batched_qkv,
+            value=batched_qkv,
+        )
+        assert out.shape == torch.Size([2, 3, 6])
+
+    def test_two_estimators_q_kv(
+        self,
+        unbatched_q_kv: tuple[torch.Tensor, torch.Tensor],
+        batched_q_kv: tuple[torch.Tensor, torch.Tensor],
+    ):
+        layer = PackedMultiheadAttention(
+            embed_dim=6,
+            num_heads=3,
+            alpha=1,
+            num_estimators=2,
+            kdim=2,
+            vdim=2,
+            add_zero_attn=True,
+        )
+        out, _ = layer(
+            query=unbatched_q_kv[0],
+            key=unbatched_q_kv[1],
+            value=unbatched_q_kv[1],
+        )
+        assert out.shape == torch.Size([3, 6])
+
+        unbatched_q_kv = tuple(repeat(seq, "l h -> l b h", b=2) for seq in unbatched_q_kv)
+
+        attn_mask = torch.zeros(12, 3, 4, dtype=torch.bool)
+        key_padding_mask = torch.zeros(2, 4, dtype=torch.bool)
+
+        out, _ = layer(
+            query=unbatched_q_kv[0],
+            key=unbatched_q_kv[1],
+            value=unbatched_q_kv[1],
+            attn_mask=attn_mask,
+            key_padding_mask=key_padding_mask,
+        )
+        assert out.shape == torch.Size([3, 2, 6])
+
+        layer = PackedMultiheadAttention(
+            embed_dim=6,
+            num_heads=3,
+            alpha=1,
+            num_estimators=2,
+            kdim=2,
+            vdim=2,
+            batch_first=True,
+        )
+        out, _ = layer(
+            query=batched_q_kv[0],
+            key=batched_q_kv[1],
+            value=batched_q_kv[1],
+        )
+        assert out.shape == torch.Size([2, 3, 6])
+
+    def test_two_estimators_q_k_v(
+        self,
+        unbatched_q_k_v: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+        extended_batched_q_k_v: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+    ):
+        layer = PackedMultiheadAttention(
+            embed_dim=6,
+            num_heads=3,
+            alpha=1,
+            num_estimators=2,
+            kdim=2,
+            vdim=4,
+            add_bias_kv=True,
+        )
+        out, _ = layer(
+            query=unbatched_q_k_v[0],
+            key=unbatched_q_k_v[1],
+            value=unbatched_q_k_v[2],
+        )
+        assert out.shape == torch.Size([3, 6])
+
+        unbatched_q_k_v = tuple(repeat(seq, "l h -> l b h", b=2) for seq in unbatched_q_k_v)
+
+        attn_mask = torch.zeros(3, 4, dtype=torch.bool)
+        key_padding_mask = torch.zeros(2, 4, dtype=torch.bool)
+
+        out, _ = layer(
+            query=unbatched_q_k_v[0],
+            key=unbatched_q_k_v[1],
+            value=unbatched_q_k_v[2],
+            attn_mask=attn_mask,
+            key_padding_mask=key_padding_mask,
+        )
+        assert out.shape == torch.Size([3, 2, 6])
+
+        layer = PackedMultiheadAttention(
+            embed_dim=6,
+            num_heads=3,
+            alpha=2,
+            num_estimators=2,
+            kdim=2,
+            vdim=4,
+            batch_first=True,
+        )
+        out, _ = layer(
+            query=extended_batched_q_k_v[0],
+            key=extended_batched_q_k_v[1],
+            value=extended_batched_q_k_v[2],
+        )
+        assert out.shape == torch.Size([2, 3, 12])
+

 class TestPackedTransformerEncoderLayer:
     """Testing the PackedTransformerEncoderLayer class."""

+    def test_one_estimator(self, unbatched_qkv: torch.Tensor, batched_qkv: torch.Tensor):
+        layer = PackedTransformerEncoderLayer(
+            d_model=6,
+            dim_feedforward=12,
+            nhead=2,
+            alpha=1,
+            num_estimators=1,
+            norm_first=True,
+            first=True,
+        )
+        out = layer(
+            src=unbatched_qkv,
+        )
+        assert out.shape == torch.Size([3, 6])
+
+        unbatched_qkv = repeat(unbatched_qkv, "l h -> l b h", b=2)
+        out = layer(
+            src=unbatched_qkv,
+        )
+        assert out.shape == torch.Size([3, 2, 6])
+
+        layer = PackedTransformerEncoderLayer(
+            d_model=6,
+            dim_feedforward=12,
+            nhead=2,
+            alpha=1,
+            num_estimators=1,
+            batch_first=True,
+            last=True,
+            activation=torch.nn.GELU(),
+        )
+        out = layer(
+            src=batched_qkv,
+        )
+        assert out.shape == torch.Size([2, 3, 6])
+
+    def test_two_estimators(self, unbatched_qkv: torch.Tensor, extended_batched_qkv: torch.Tensor):
+        layer = PackedTransformerEncoderLayer(
+            d_model=6,
+            dim_feedforward=12,
+            nhead=3,
+            alpha=1,
+            num_estimators=2,
+            activation=torch.nn.ELU(),
+        )
+        out = layer(
+            src=unbatched_qkv,
+        )
+        assert out.shape == torch.Size([3, 6])
+
+        unbatched_qkv = repeat(unbatched_qkv, "l h -> l b h", b=2)
+        out = layer(
+            src=unbatched_qkv,
+        )
+        assert out.shape == torch.Size([3, 2, 6])
+
+        layer = PackedTransformerEncoderLayer(
+            d_model=6,
+            dim_feedforward=12,
+            nhead=3,
+            alpha=2,
+            num_estimators=2,
+            batch_first=True,
+        )
+        out = layer(
+            src=extended_batched_qkv,
+        )
+        assert out.shape == torch.Size([2, 3, 12])
+

 class TestPackedTransformerDecoderLayer:
     """Testing the PackedTransformerDecoderLayer class."""
+
+    def test_one_estimator(
+        self,
+        unbatched_tgt_memory: tuple[torch.Tensor, torch.Tensor],
+        batched_tgt_memory: tuple[torch.Tensor, torch.Tensor],
+    ):
+        layer = PackedTransformerDecoderLayer(
+            d_model=6,
+            dim_feedforward=12,
+            nhead=2,
+            alpha=1,
+            num_estimators=1,
+            norm_first=True,
+            first=True,
+        )
+        out = layer(
+            tgt=unbatched_tgt_memory[0],
+            memory=unbatched_tgt_memory[1],
+        )
+        assert out.shape == torch.Size([3, 6])
+
+        unbatched_tgt_memory = tuple(
+            repeat(seq, "l h -> l b h", b=2) for seq in unbatched_tgt_memory
+        )
+        out = layer(
+            tgt=unbatched_tgt_memory[0],
+            memory=unbatched_tgt_memory[1],
+        )
+        assert out.shape == torch.Size([3, 2, 6])
+
+        layer = PackedTransformerDecoderLayer(
+            d_model=6,
+            dim_feedforward=12,
+            nhead=2,
+            alpha=1,
+            num_estimators=1,
+            batch_first=True,
+            last=True,
+            activation=torch.nn.GELU(),
+            bias=False,
+        )
+        out = layer(
+            tgt=batched_tgt_memory[0],
+            memory=batched_tgt_memory[1],
+        )
+        assert out.shape == torch.Size([2, 3, 6])
+
+    def test_two_estimators(
+        self,
+        unbatched_tgt_memory: tuple[torch.Tensor, torch.Tensor],
+        extended_batched_tgt_memory: tuple[torch.Tensor, torch.Tensor],
+    ):
+        layer = PackedTransformerDecoderLayer(
+            d_model=6,
+            dim_feedforward=12,
+            nhead=3,
+            alpha=1,
+            num_estimators=2,
+            activation=torch.nn.ELU(),
+        )
+        out = layer(
+            tgt=unbatched_tgt_memory[0],
+            memory=unbatched_tgt_memory[1],
+        )
+        assert out.shape == torch.Size([3, 6])
+
+        unbatched_tgt_memory = tuple(
+            repeat(seq, "l h -> l b h", b=2) for seq in unbatched_tgt_memory
+        )
+        out = layer(
+            tgt=unbatched_tgt_memory[0],
+            memory=unbatched_tgt_memory[1],
+        )
+        assert out.shape == torch.Size([3, 2, 6])
+
+        layer = PackedTransformerDecoderLayer(
+            d_model=6,
+            dim_feedforward=12,
+            nhead=3,
+            alpha=2,
+            num_estimators=2,
+            batch_first=True,
+        )
+        out = layer(
+            tgt=extended_batched_tgt_memory[0],
+            memory=extended_batched_tgt_memory[1],
+        )
+        assert out.shape == torch.Size([2, 3, 12])
+
+
+class TestPackedFunctional:
+    def test_packed_in_projection_packed(
+        self,
+        batched_qkv: torch.Tensor,
+    ):
+        proj_q, proj_k, proj_v = packed_in_projection_packed(
+            q=batched_qkv,
+            k=batched_qkv,
+            v=batched_qkv,
+            w=torch.rand((1, 18, 6)),
+            num_groups=1,
+        )
+        assert proj_q.shape == torch.Size([2, 3, 6])
+        assert proj_k.shape == torch.Size([2, 3, 6])
+        assert proj_v.shape == torch.Size([2, 3, 6])
+
+        q_kv = torch.rand((2, 3, 6)), torch.rand((2, 4, 6))
+
+        proj_q, proj_k, proj_v = packed_in_projection_packed(
+            q=q_kv[0],
+            k=q_kv[1],
+            v=q_kv[1],
+            w=torch.rand((1, 18, 6)),
+            num_groups=1,
+            b=None,
+        )
+        proj_q, proj_k, proj_v = packed_in_projection_packed(
+            q=q_kv[0],
+            k=q_kv[1],
+            v=q_kv[1],
+            w=torch.rand((1, 18, 6)),
+            num_groups=1,
+            b=torch.rand(18),
+        )
+
+        assert proj_q.shape == torch.Size([2, 3, 6])
+        assert proj_k.shape == torch.Size([2, 4, 6])
+        assert proj_v.shape == torch.Size([2, 4, 6])
+
+        q_k_v = torch.rand((2, 3, 6)), torch.rand((2, 4, 6)), torch.rand((2, 4, 6))
+
+        proj_q, proj_k, proj_v = packed_in_projection_packed(
+            q=q_k_v[0],
+            k=q_k_v[1],
+            v=q_k_v[2],
+            w=torch.rand((1, 18, 6)),
+            num_groups=1,
+            b=None,
+        )
+
+        proj_q, proj_k, proj_v = packed_in_projection_packed(
+            q=q_k_v[0],
+            k=q_k_v[1],
+            v=q_k_v[2],
+            w=torch.rand((1, 18, 6)),
+            num_groups=1,
+            b=torch.rand(18),
+        )
+
+        assert proj_q.shape == torch.Size([2, 3, 6])
+        assert proj_k.shape == torch.Size([2, 4, 6])
+        assert proj_v.shape == torch.Size([2, 4, 6])
+
+    def test_packed_multi_head_attention_forward_failures(
+        self, unbatched_q_k_v: tuple[torch.Tensor, torch.Tensor, torch.Tensor]
+    ):
+        q, k, v = unbatched_q_k_v
+        with pytest.raises(RuntimeError):
+            _ = packed_multi_head_attention_forward(
+                query=q,
+                key=k,
+                value=v,
+                embed_dim_to_check=6,
+                num_heads=2,
+                num_groups=1,
+                in_proj_weight=None,
+                in_proj_bias=torch.rand(18),
+                bias_k=None,
+                bias_v=None,
+                add_zero_attn=False,
+                dropout_p=0.0,
+                out_proj_weight=torch.rand(1, 6, 6),
+                out_proj_bias=None,
+                is_causal=True,
+                attn_mask=None,
+            )
+
+        with pytest.raises(RuntimeError):
+            _ = packed_multi_head_attention_forward(
+                query=q,
+                key=k,
+                value=v,
+                embed_dim_to_check=6,
+                num_heads=2,
+                num_groups=1,
+                in_proj_weight=None,
+                in_proj_bias=torch.rand(18),
+                bias_k=None,
+                bias_v=None,
+                add_zero_attn=False,
+                dropout_p=0.0,
+                out_proj_weight=torch.rand(1, 6, 6),
+                out_proj_bias=None,
+                attn_mask=torch.rand(2, 2),
+                use_separate_proj_weight=True,
+                q_proj_weight=torch.rand(1, 6, 6),
+                k_proj_weight=torch.rand(1, 6, 2),
+                v_proj_weight=torch.rand(1, 6, 4),
+            )
+
+        with pytest.raises(AssertionError):
+            _ = packed_multi_head_attention_forward(
+                query=q,
+                key=k,
+                value=v,
+                embed_dim_to_check=6,
+                num_heads=2,
+                num_groups=1,
+                in_proj_weight=None,
+                in_proj_bias=torch.rand(18),
+                bias_k=None,
+                bias_v=None,
+                add_zero_attn=False,
+                dropout_p=0.0,
+                out_proj_weight=torch.rand(1, 6, 6),
+                out_proj_bias=None,
+                attn_mask=torch.rand(1, 1, 3, 4),
+                use_separate_proj_weight=True,
+                q_proj_weight=torch.rand(1, 6, 6),
+                k_proj_weight=torch.rand(1, 6, 2),
+                v_proj_weight=torch.rand(1, 6, 4),
+            )
+
+        with pytest.raises(AssertionError):
+            _ = packed_multi_head_attention_forward(
+                query=q,
+                key=k,
+                value=v,
+                embed_dim_to_check=6,
+                num_heads=2,
+                num_groups=1,
+                in_proj_weight=None,
+                in_proj_bias=torch.rand(18),
+                bias_k=None,
+                bias_v=None,
+                add_zero_attn=False,
+                dropout_p=0.0,
+                out_proj_weight=torch.rand(1, 6, 6),
+                out_proj_bias=None,
+                attn_mask=torch.rand(1, 2, 2),
+                use_separate_proj_weight=True,
+                q_proj_weight=torch.rand(1, 6, 6),
+                k_proj_weight=torch.rand(1, 6, 2),
+                v_proj_weight=torch.rand(1, 6, 4),
+            )
diff --git a/torch_uncertainty/layers/functional/packed.py b/torch_uncertainty/layers/functional/packed.py
index 38fe0b3b..c962531e 100644
--- a/torch_uncertainty/layers/functional/packed.py
+++ b/torch_uncertainty/layers/functional/packed.py
@@ -397,10 +397,12 @@ def packed_multi_head_attention_forward(  # noqa: D417
         elif attn_mask.dim() == 3:
             correct_3d_size = (bsz * num_heads, tgt_len, src_len)
             if attn_mask.shape != correct_3d_size:
+                # unreachable code due to the check above (F._mha_shape_check, l.274)
                 raise RuntimeError(
                     f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}."
                 )
         else:
+            # unreachable code due to the check above (F._mha_shape_check, l.274)
             raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")

     if bias_k is not None and bias_v is not None:
diff --git a/torch_uncertainty/layers/packed.py b/torch_uncertainty/layers/packed.py
index ae835f83..a4eacab4 100644
--- a/torch_uncertainty/layers/packed.py
+++ b/torch_uncertainty/layers/packed.py
@@ -590,12 +590,16 @@ def __init__(
         alpha: float,
         eps: float = 1e-5,
         affine: bool = True,
+        device=None,
+        dtype=None,
     ) -> None:
         super().__init__(
             num_groups=num_estimators,
             num_channels=int(embed_dim * alpha),
             eps=eps,
             affine=affine,
+            device=device,
+            dtype=dtype,
         )

     def forward(self, inputs: Tensor) -> Tensor:
@@ -865,14 +869,13 @@ def __init__(
         gamma: int = 1,
         dim_feedforward: int = 2048,
         dropout: float = 0.1,
-        activation: str | Callable[[Tensor], Tensor] = F.relu,
+        activation: Callable[[Tensor], Tensor] = F.relu,
         layer_norm_eps: float = 1e-5,
         bias: bool = True,
         batch_first: bool = False,
         norm_first: bool = False,
         first: bool = False,
         last: bool = False,
-        use_gqa: bool = False,
         device=None,
         dtype=None,
     ) -> None:
@@ -889,7 +892,6 @@ def __init__(
             dropout=dropout,
             batch_first=batch_first,
             first=first,
-            use_gqa=use_gqa,
             **factory_kwargs,
         )
@@ -921,23 +923,26 @@ def __init__(
             self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
         else:
             self.norm1 = PackedLayerNorm(
-                num_estimators,
-                int(d_model * alpha),
+                embed_dim=d_model,
+                num_estimators=num_estimators,
+                alpha=alpha,
                 eps=layer_norm_eps,
                 **factory_kwargs,
             )
         if not self.norm_first and last:
             self.norm2 = PackedLayerNorm(
-                num_estimators,
-                int(d_model * num_estimators),
+                embed_dim=d_model,
+                num_estimators=num_estimators,
+                alpha=alpha,
                 eps=layer_norm_eps,
                 **factory_kwargs,
             )
         else:
             self.norm2 = PackedLayerNorm(
-                num_estimators,
-                int(d_model * alpha),
+                embed_dim=d_model,
+                num_estimators=num_estimators,
+                alpha=alpha,
                 eps=layer_norm_eps,
                 **factory_kwargs,
             )
@@ -1030,14 +1035,13 @@ def __init__(
         gamma: int = 1,
         dim_feedforward: int = 2048,
         dropout: float = 0.1,
-        activation: str | Callable[[Tensor], Tensor] = F.relu,
+        activation: Callable[[Tensor], Tensor] = F.relu,
         layer_norm_eps: float = 1e-5,
         batch_first: bool = False,
         norm_first: bool = False,
         first: bool = False,
         last: bool = False,
         bias: bool = True,
-        use_gqa: bool = False,
         device=None,
         dtype=None,
     ) -> None:
@@ -1054,7 +1058,6 @@ def __init__(
             bias=bias,
             batch_first=batch_first,
             first=first,
-            use_gqa=use_gqa,
             **factory_kwargs,
         )
@@ -1067,7 +1070,6 @@ def __init__(
             dropout=dropout,
             bias=bias,
             batch_first=batch_first,
-            use_gqa=use_gqa,
             **factory_kwargs,
         )
@@ -1099,30 +1101,34 @@ def __init__(
             self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
         else:
             self.norm1 = PackedLayerNorm(
-                num_estimators,
-                int(d_model * alpha),
+                embed_dim=d_model,
+                num_estimators=num_estimators,
+                alpha=alpha,
                 eps=layer_norm_eps,
                 **factory_kwargs,
             )
         self.norm2 = PackedLayerNorm(
-            num_estimators,
-            int(d_model * alpha),
+            embed_dim=d_model,
+            num_estimators=num_estimators,
+            alpha=alpha,
             eps=layer_norm_eps,
             **factory_kwargs,
         )
         if not self.norm_first and last:
             self.norm3 = PackedLayerNorm(
-                num_estimators,
-                d_model * num_estimators,
+                embed_dim=d_model,
+                num_estimators=num_estimators,
+                alpha=num_estimators,
                 eps=layer_norm_eps,
                 **factory_kwargs,
             )
         else:
             self.norm3 = PackedLayerNorm(
-                num_estimators,
-                int(d_model * alpha),
+                embed_dim=d_model,
+                num_estimators=num_estimators,
+                alpha=alpha,
                 eps=layer_norm_eps,
                 **factory_kwargs,
             )
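
Usage note (not part of the patch): a minimal sketch of how the layers covered by the new tests are driven, assuming only the constructor arguments and tensor shapes that appear in the diff above; shapes mirror the batched single-estimator test cases, and the refactored PackedLayerNorm calls resolve to a GroupNorm with num_groups=num_estimators and num_channels=int(embed_dim * alpha).

    import torch
    from torch_uncertainty.layers.packed import (
        PackedTransformerDecoderLayer,
        PackedTransformerEncoderLayer,
    )

    # Encoder layer as configured in test_one_estimator:
    # src of shape (batch=2, seq=3, d_model=6) -> output (2, 3, 6).
    encoder_layer = PackedTransformerEncoderLayer(
        d_model=6,
        dim_feedforward=12,
        nhead=2,
        alpha=1,
        num_estimators=1,
        batch_first=True,
        last=True,
        activation=torch.nn.GELU(),
    )
    out = encoder_layer(src=torch.rand(2, 3, 6))  # shape (2, 3, 6)

    # Decoder layer as configured in test_one_estimator:
    # tgt (2, 3, 6) attends to memory (2, 4, 6) -> output (2, 3, 6).
    decoder_layer = PackedTransformerDecoderLayer(
        d_model=6,
        dim_feedforward=12,
        nhead=2,
        alpha=1,
        num_estimators=1,
        batch_first=True,
        last=True,
        activation=torch.nn.GELU(),
        bias=False,
    )
    out = decoder_layer(tgt=torch.rand(2, 3, 6), memory=torch.rand(2, 4, 6))  # shape (2, 3, 6)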