Skip to content

Commit

Permalink
TestFeature
Browse files Browse the repository at this point in the history
  • Loading branch information
0x00b1 committed May 19, 2024
1 parent d2093da commit 1f5a82d
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 26 deletions.
8 changes: 0 additions & 8 deletions tests/beignet/features/conftest.py

This file was deleted.

44 changes: 26 additions & 18 deletions tests/beignet/features/test__feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,15 @@


class TestFeature:
def test__to_tensor(self, feature: Feature):
result = Feature._to_tensor([1, 2, 3])
def test___deepcopy__(self):
feature = Feature(torch.tensor([1, 2, 3]))

assert torch.is_tensor(result)

assert not result.requires_grad

def test_wrap_like(self):
with pytest.raises(NotImplementedError):
Feature.wrap_like(None, None)
copy.deepcopy(feature)

def test___torch_function__(self):
feature = Feature(torch.tensor([1, 2, 3]))

def test___torch_function__(self, feature: Feature):
result = feature.__torch_function__(
torch.add,
(Feature, torch.Tensor),
Expand All @@ -26,18 +23,29 @@ def test___torch_function__(self, feature: Feature):

assert not isinstance(result, Feature)

def test_device(self, feature: Feature):
def test__to_tensor(self):
feature = Feature._to_tensor([1, 2, 3])

assert torch.is_tensor(feature)

assert not feature.requires_grad

def test_device(self):
feature = Feature(torch.tensor([1, 2, 3]))

assert feature.device == feature.device

def test_ndim(self, feature: Feature):
assert feature.ndim == 1
def test_dtype(self):
assert Feature(torch.tensor([1, 2, 3])).dtype == torch.int64

def test_dtype(self, feature: Feature):
assert feature.dtype == torch.int64
def test_ndim(self):
assert Feature(torch.tensor([1, 2, 3])).ndim == 1

def test_shape(self, feature: Feature):
assert feature.shape == (3,)
def test_shape(self):
assert Feature(torch.tensor([1, 2, 3])).shape == (3,)

def test___deepcopy__(self, feature: Feature):
def test_wrap_like(self):
with pytest.raises(NotImplementedError):
copy.deepcopy(feature)
feature = Feature(torch.tensor([1, 2, 3]))

Feature.wrap_like(feature, torch.tensor([1, 2, 3]))

0 comments on commit 1f5a82d

Please sign in to comment.