add test for the test model ShrinkedResNet
wenh06 committed Aug 1, 2024
1 parent 7f9a3e1 commit 432a88d
Showing 3 changed files with 40 additions and 16 deletions.
21 changes: 6 additions & 15 deletions fl_sim/utils/torch_compat.py
@@ -126,34 +126,25 @@ def torch_norm(
     if input.layout == torch.strided and input.device.type in ("cpu", "cuda", "meta"):
         if dim is not None:
             if isinstance(dim, int):
-                _dim = [dim]
+                _dim = (dim,)
             else:
-                _dim = dim
+                _dim = tuple(dim)
         else:
             _dim = None  # type: ignore[assignment]
 
         if isinstance(p, str):
             if p == "fro" and (dim is None or isinstance(dim, int) or len(dim) <= 2):
-                if out is None:
-                    return torch.linalg.vector_norm(input, 2, _dim, keepdim, dtype=dtype)
-                else:
-                    return torch.linalg.vector_norm(input, 2, _dim, keepdim, dtype=dtype, out=out)
+                return torch.linalg.vector_norm(input, 2, _dim, keepdim, dtype=dtype, out=out)
 
             # Here we either call the nuclear norm, or we call matrix_norm with some arguments
             # that will throw an error
             if _dim is None:
-                _dim = list(range(input.ndim))
-            if out is None:
-                return torch.linalg.matrix_norm(input, p, _dim, keepdim, dtype=dtype)
-            else:
-                return torch.linalg.matrix_norm(input, p, _dim, keepdim, dtype=dtype, out=out)
+                _dim = tuple(range(input.ndim))
+            return torch.linalg.matrix_norm(input, p, _dim, keepdim, dtype=dtype, out=out)
         else:
             # NB. p should be Union[str, number], not Optional!
             _p = 2.0 if p is None else p
-            if out is None:
-                return torch.linalg.vector_norm(input, _p, _dim, keepdim, dtype=dtype)
-            else:
-                return torch.linalg.vector_norm(input, _p, _dim, keepdim, dtype=dtype, out=out)
+            return torch.linalg.vector_norm(input, _p, _dim, keepdim, dtype=dtype, out=out)
 
     ndim = input.dim()
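
The collapsed branches rely on `out=None` being equivalent to omitting `out` in the `torch.linalg` calls, and on both norm functions accepting a tuple for `dim`. A minimal sketch of the `"fro"` fast path the shim dispatches to (plain PyTorch; assumes a build where `torch.linalg.vector_norm` and `torch.linalg.matrix_norm` exist):

import torch

# A Frobenius norm over at most 2 dims is just the L2 vector norm of the
# same elements, and passing out=None behaves like passing no out at all.
m = (torch.arange(9, dtype=torch.float) - 4).reshape(3, 3)
assert torch.allclose(
    torch.linalg.vector_norm(m, 2, (0, 1), False, out=None),
    torch.linalg.matrix_norm(m, "fro"),
)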
11 changes: 10 additions & 1 deletion test/test_models.py
@@ -28,6 +28,7 @@
     RNN_OriginalFedAvg,
     RNN_Sent140,
     RNN_StackOverFlow,
+    ShrinkedResNet,
     reset_parameters,
 )
 from fl_sim.models.tokenizers import init_nltk, tokenize, words_from_text
@@ -38,7 +39,6 @@
 
 @torch.no_grad()
 def test_models():
-    """ """
     model = CNNFEMnist_Tiny().eval()
     inp = torch.rand(2, 1, 28, 28)
     out = model(inp)
@@ -262,6 +262,15 @@ def test_models():
     pred = model.pipeline("ew. getting ready for work")
     assert isinstance(pred, int) and pred in [0, 1]
 
+    model = ShrinkedResNet(layers=[1, 1, 1], num_classes=10).eval()
+    inp = torch.rand(2, 3, 32, 32)
+    out = model(inp)
+    assert out.shape == (2, 10)
+    pred = model.predict(inp, batched=True)
+    assert len(pred) == 2
+    prob = model.predict_proba(inp, batched=True)
+    assert prob.shape == (2, 10)
+
 
 def test_GloveEmbedding():
     import tokenizers as hf_tokenizers
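
The new block gives `ShrinkedResNet` the same smoke test as the other models: a forward pass plus the `predict`/`predict_proba` helpers on CIFAR-10-shaped input. A standalone sketch of those calls (assuming the class is importable from `fl_sim.models`, like the tokenizers module imported above):

import torch

from fl_sim.models import ShrinkedResNet  # assumed import path

# layers=[1, 1, 1] presumably counts residual blocks per stage, as in
# torchvision's ResNet; num_classes=10 matches CIFAR-10.
model = ShrinkedResNet(layers=[1, 1, 1], num_classes=10).eval()
inp = torch.rand(2, 3, 32, 32)  # batch of 2 RGB 32x32 images
with torch.no_grad():
    logits = model(inp)                             # shape (2, 10)
    labels = model.predict(inp, batched=True)       # one label per sample
    probs = model.predict_proba(inp, batched=True)  # shape (2, 10)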
24 changes: 24 additions & 0 deletions test/test_utils.py
@@ -8,6 +8,7 @@
 sys.path.append(str(Path(__file__).parents[1].resolve()))
 
 import numpy as np
+import pytest
 import torch
 from torch_ecg.utils import get_kwargs
 
@@ -234,11 +235,22 @@ def test_url_is_reachable():
 def test_torch_norm():
     a = torch.arange(9, dtype=torch.float) - 4
     b = a.reshape((3, 3))
+    b_sp = b.clone().to_sparse()
     assert torch.allclose(torch_norm(a, dtype=torch.float64), torch.tensor(7.7460, dtype=torch.float64), atol=1e-4)
     assert torch.allclose(torch_norm(b, dtype=torch.float64), torch.tensor(7.7460, dtype=torch.float64), atol=1e-4)
     assert torch.allclose(torch_norm(a, float("inf"), dtype=torch.float64), torch.tensor(4.0, dtype=torch.float64), atol=1e-4)
     assert torch.allclose(torch_norm(b, float("inf"), dtype=torch.float64), torch.tensor(4.0, dtype=torch.float64), atol=1e-4)
     assert torch.allclose(torch_norm(b, p="nuc", dtype=torch.float64), torch.tensor(9.7980, dtype=torch.float64), atol=1e-4)
+    b_sp = b.clone().to_sparse()
+    assert torch.allclose(torch_norm(b_sp, float("inf")), torch.tensor(4.0), atol=1e-4)
+    with pytest.raises(ValueError):
+        torch_norm(b_sp, dtype=torch.float64)
+    with pytest.raises(ValueError):
+        torch_norm(b_sp, p="nuc", dtype=torch.float64)
+    with pytest.raises(RuntimeError):
+        torch_norm(b_sp, p="xxx")
+    with pytest.raises(RuntimeError):
+        torch_norm(a, p="xxx")
 
     c = torch.tensor([[1, 2, 3], [-1, 1, 4]], dtype=torch.float)
     assert torch.allclose(
@@ -250,6 +262,14 @@ def test_torch_norm():
     assert torch.allclose(
         torch_norm(c, p=1, dim=1, dtype=torch.float64), torch.tensor([6.0, 6.0], dtype=torch.float64), atol=1e-4
     )
+    c_sp = c.clone().to_sparse()
+    assert torch.allclose(torch_norm(c_sp, p=1), torch.tensor(12.0), atol=1e-4)
+    with pytest.raises(RuntimeError):
+        torch_norm(c_sp, p=1, dim=1)
+    with pytest.raises(ValueError):
+        torch_norm(c_sp, p="fro", dtype=torch.float64)
+    with pytest.raises(ValueError):
+        torch_norm(c_sp, p="nuc", dtype=torch.float64)
 
     d = torch.arange(8, dtype=torch.float).reshape(2, 2, 2)
     assert torch.allclose(
@@ -260,3 +280,7 @@ def test_torch_norm():
     assert torch.allclose(
         torch_norm(d, p="nuc", dim=[1, 2], dtype=torch.float64), torch.tensor([4.2426, 11.4018], dtype=torch.float64), atol=1e-4
     )
+    d_sp = d.clone().to_sparse()
+    assert torch.allclose(torch_norm(d_sp), torch.tensor(11.8322), atol=1e-4)
+    with pytest.raises(RuntimeError):
+        torch_norm(d_sp, p="xxx")

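Taken together, the sparse cases pin down the shim's contract for sparse layouts: a scalar `p` with no `dim` or `dtype` works, while `dim`, `dtype`, and string `p` values raise. A standalone sketch of that contract (assuming `torch_norm` comes from `fl_sim.utils.torch_compat`, the module changed above):

import pytest
import torch

from fl_sim.utils.torch_compat import torch_norm  # assumed import path

c = torch.tensor([[1, 2, 3], [-1, 1, 4]], dtype=torch.float)
c_sp = c.to_sparse()

# Scalar p on a sparse tensor works: the L1 norm over all entries is 12.
assert torch.allclose(torch_norm(c_sp, p=1), torch.tensor(12.0))

# dim, and string p combined with dtype, are rejected for sparse layouts.
with pytest.raises(RuntimeError):
    torch_norm(c_sp, p=1, dim=1)
with pytest.raises(ValueError):
    torch_norm(c_sp, p="fro", dtype=torch.float64)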