
Commit

✅ Improve coverage (again)
alafage committed Dec 28, 2024
1 parent 357cd46 commit df3fb52
Showing 1 changed file with 156 additions and 1 deletion.
157 changes: 156 additions & 1 deletion tests/layers/test_packed.py
@@ -2,6 +2,10 @@
import torch
from einops import repeat

from torch_uncertainty.layers.functional.packed import (
    packed_in_projection_packed,
    packed_multi_head_attention_forward,
)
from torch_uncertainty.layers.packed import (
    PackedConv1d,
    PackedConv2d,
@@ -820,4 +824,155 @@ def test_two_estimators(


class TestPackedFunctional:
    def test_packed_multi_head_attention_forward(self): ...
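    # The tests below exercise the functional packed API imported at the top
    # of the file: packed_in_projection_packed and
    # packed_multi_head_attention_forward.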
    def test_packed_in_projection_packed(
        self,
        batched_qkv: torch.Tensor,
    ):
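        # Self-attention case: q, k, and v all reuse the batched_qkv fixture,
        # and the packed in-projection weight w is presumably laid out as
        # (num_groups, 3 * embed_dim, embed_dim) = (1, 18, 6).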
        proj_q, proj_k, proj_v = packed_in_projection_packed(
            q=batched_qkv,
            k=batched_qkv,
            v=batched_qkv,
            w=torch.rand((1, 18, 6)),
            num_groups=1,
        )
        assert proj_q.shape == torch.Size([2, 3, 6])
        assert proj_k.shape == torch.Size([2, 3, 6])
        assert proj_v.shape == torch.Size([2, 3, 6])

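        # Cross-attention case with a shared key/value tensor: the query has
        # sequence length 3 while key and value share a length-4 tensor; the
        # projection is run once without and once with an in-projection bias.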
        q_kv = torch.rand((2, 3, 6)), torch.rand((2, 4, 6))

        proj_q, proj_k, proj_v = packed_in_projection_packed(
            q=q_kv[0],
            k=q_kv[1],
            v=q_kv[1],
            w=torch.rand((1, 18, 6)),
            num_groups=1,
            b=None,
        )
        proj_q, proj_k, proj_v = packed_in_projection_packed(
            q=q_kv[0],
            k=q_kv[1],
            v=q_kv[1],
            w=torch.rand((1, 18, 6)),
            num_groups=1,
            b=torch.rand(18),
        )

        assert proj_q.shape == torch.Size([2, 3, 6])
        assert proj_k.shape == torch.Size([2, 4, 6])
        assert proj_v.shape == torch.Size([2, 4, 6])

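        # General case: query, key, and value are three distinct tensors,
        # again exercised with and without an in-projection bias.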
        q_k_v = torch.rand((2, 3, 6)), torch.rand((2, 4, 6)), torch.rand((2, 4, 6))

        proj_q, proj_k, proj_v = packed_in_projection_packed(
            q=q_k_v[0],
            k=q_k_v[1],
            v=q_k_v[2],
            w=torch.rand((1, 18, 6)),
            num_groups=1,
            b=None,
        )

        proj_q, proj_k, proj_v = packed_in_projection_packed(
            q=q_k_v[0],
            k=q_k_v[1],
            v=q_k_v[2],
            w=torch.rand((1, 18, 6)),
            num_groups=1,
            b=torch.rand(18),
        )

        assert proj_q.shape == torch.Size([2, 3, 6])
        assert proj_k.shape == torch.Size([2, 4, 6])
        assert proj_v.shape == torch.Size([2, 4, 6])

    def test_packed_multi_head_attention_forward_failures(
        self, batched_q_k_v: tuple[torch.Tensor, torch.Tensor, torch.Tensor]
    ):
        q, k, v = batched_q_k_v
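        # Passing is_causal=True without an attn_mask should raise a
        # RuntimeError, presumably mirroring the behavior of
        # torch.nn.functional.multi_head_attention_forward.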
        with pytest.raises(RuntimeError):
            _ = packed_multi_head_attention_forward(
                query=q,
                key=k,
                value=v,
                embed_dim_to_check=6,
                num_heads=2,
                num_groups=1,
                in_proj_weight=None,
                in_proj_bias=torch.rand(18),
                bias_k=None,
                bias_v=None,
                add_zero_attn=False,
                dropout_p=0.0,
                out_proj_weight=torch.rand(6, 6),
                out_proj_bias=None,
                is_causal=True,
                attn_mask=None,
            )

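        # A 2-D attn_mask whose shape does not match (target_len, source_len)
        # should be rejected with a RuntimeError.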
        with pytest.raises(RuntimeError):
            _ = packed_multi_head_attention_forward(
                query=q,
                key=k,
                value=v,
                embed_dim_to_check=6,
                num_heads=2,
                num_groups=1,
                in_proj_weight=None,
                in_proj_bias=torch.rand(18),
                bias_k=None,
                bias_v=None,
                add_zero_attn=False,
                dropout_p=0.0,
                out_proj_weight=torch.rand(6, 6),
                out_proj_bias=None,
                attn_mask=torch.rand(2, 2),
                use_separate_proj_weight=True,
                q_proj_weight=torch.rand(1, 6, 6),
                k_proj_weight=torch.rand(1, 6, 2),
                v_proj_weight=torch.rand(1, 6, 4),
            )

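        # A 4-D attn_mask is unsupported and is expected to trip an internal
        # assertion, hence the AssertionError.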
        with pytest.raises(AssertionError):
            _ = packed_multi_head_attention_forward(
                query=q,
                key=k,
                value=v,
                embed_dim_to_check=6,
                num_heads=2,
                num_groups=1,
                in_proj_weight=None,
                in_proj_bias=torch.rand(18),
                bias_k=None,
                bias_v=None,
                add_zero_attn=False,
                dropout_p=0.0,
                out_proj_weight=torch.rand(6, 6),
                out_proj_bias=None,
                attn_mask=torch.rand(1, 1, 3, 4),
                use_separate_proj_weight=True,
                q_proj_weight=torch.rand(1, 6, 6),
                k_proj_weight=torch.rand(1, 6, 2),
                v_proj_weight=torch.rand(1, 6, 4),
            )

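        # A mis-shaped 3-D attn_mask should likewise raise a RuntimeError.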
        with pytest.raises(RuntimeError):
            _ = packed_multi_head_attention_forward(
                query=q,
                key=k,
                value=v,
                embed_dim_to_check=6,
                num_heads=2,
                num_groups=1,
                in_proj_weight=None,
                in_proj_bias=torch.rand(18),
                bias_k=None,
                bias_v=None,
                add_zero_attn=False,
                dropout_p=0.0,
                out_proj_weight=torch.rand(6, 6),
                out_proj_bias=None,
                attn_mask=torch.rand(1, 2, 2),
                use_separate_proj_weight=True,
                q_proj_weight=torch.rand(1, 6, 6),
                k_proj_weight=torch.rand(1, 6, 2),
                v_proj_weight=torch.rand(1, 6, 4),
            )
