Skip to content

Commit

Permalink
cleaned up test_units
Browse files Browse the repository at this point in the history
  • Loading branch information
nfearnley committed Jun 8, 2024
1 parent b48cca8 commit 79ae942
Showing 1 changed file with 19 additions and 19 deletions.
38 changes: 19 additions & 19 deletions tests/test_units.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# pyright: reportUnnecessaryIsInstance=false
# ruff: noqa: UP018
from typing import Any
import pytest
from contextlib import AbstractContextManager, nullcontext as does_not_raise
Expand All @@ -15,8 +16,7 @@
def idfn(arg: Any) -> Any:
if isinstance(arg, BaseDecimal):
return repr(arg)
else:
return arg
return arg


def expect_raise(expected: Any | type[BaseException]) -> AbstractContextManager[Any]:
Expand All @@ -26,7 +26,7 @@ def expect_raise(expected: Any | type[BaseException]) -> AbstractContextManager[

# a + b = ?
@pytest.mark.parametrize(
"a,b,expected",
("a", "b", "expected"),
[
(int(4), int(2), int(6)),
(int(4), Decimal(2), Decimal(6)),
Expand Down Expand Up @@ -95,7 +95,7 @@ def expect_raise(expected: Any | type[BaseException]) -> AbstractContextManager[
],
ids=idfn
)
def test_add(a: Any, b: Any, expected: UnitType | type[NotImplementedError]):
def test_add(a: Any, b: Any, expected: UnitType | type[NotImplementedError]) -> None:
with expect_raise(expected):
result = a + b
assert type(result) is type(expected)
Expand All @@ -107,7 +107,7 @@ def test_add(a: Any, b: Any, expected: UnitType | type[NotImplementedError]):

# a - b = ?
@pytest.mark.parametrize(
"a,b,expected",
("a", "b", "expected"),
[
(int(4), int(2), int(2)),
(int(4), Decimal(2), Decimal(2)),
Expand Down Expand Up @@ -176,15 +176,15 @@ def test_add(a: Any, b: Any, expected: UnitType | type[NotImplementedError]):
],
ids=idfn
)
def test_sub(a: Any, b: Any, expected: UnitType | type[NotImplementedError]):
def test_sub(a: Any, b: Any, expected: UnitType | type[NotImplementedError]) -> None:
with expect_raise(expected):
result = a - b
assert type(result) is type(expected)
assert result == expected

# a * b = ?
@pytest.mark.parametrize(
"a,b,expected",
("a","b","expected"),
[
(int(4), int(2), int(8)),
(int(4), Decimal(2), Decimal(8)),
Expand Down Expand Up @@ -253,15 +253,15 @@ def test_sub(a: Any, b: Any, expected: UnitType | type[NotImplementedError]):
],
ids=idfn
)
def test_mul(a: Any, b: Any, expected: UnitType | type[NotImplementedError]):
def test_mul(a: Any, b: Any, expected: UnitType | type[NotImplementedError]) -> None:
with expect_raise(expected):
result = a * b
assert type(result) is type(expected)
assert result == expected

# a / b = ?
@pytest.mark.parametrize(
"a,b,expected",
("a", "b", "expected"),
[
(int(4), int(2), float(2)),
(int(4), Decimal(2), Decimal(2)),
Expand Down Expand Up @@ -330,15 +330,15 @@ def test_mul(a: Any, b: Any, expected: UnitType | type[NotImplementedError]):
],
ids=idfn
)
def test_truediv(a: Any, b: Any, expected: UnitType | type[NotImplementedError]):
def test_truediv(a: Any, b: Any, expected: UnitType | type[NotImplementedError]) -> None:
with expect_raise(expected):
result = a / b
assert type(result) is type(expected)
assert result == expected

# a // b = ?
@pytest.mark.parametrize(
"a,b,expected",
("a", "b", "expected"),
[
(int(4), int(2), int(2)),
(int(4), Decimal(2), Decimal(2)),
Expand Down Expand Up @@ -407,15 +407,15 @@ def test_truediv(a: Any, b: Any, expected: UnitType | type[NotImplementedError])
],
ids=idfn
)
def test_floordiv(a: Any, b: Any, expected: UnitType | type[NotImplementedError]):
def test_floordiv(a: Any, b: Any, expected: UnitType | type[NotImplementedError]) -> None:
with expect_raise(expected):
result = a // b
assert type(result) is type(expected)
assert result == expected

# a % b = ?
@pytest.mark.parametrize(
"a,b,expected",
("a", "b", "expected"),
[
(int(4), int(2), int(0)),
(int(4), Decimal(2), Decimal(0)),
Expand Down Expand Up @@ -484,15 +484,15 @@ def test_floordiv(a: Any, b: Any, expected: UnitType | type[NotImplementedError]
],
ids=idfn
)
def test_mod(a: Any, b: Any, expected: UnitType | type[NotImplementedError]):
def test_mod(a: Any, b: Any, expected: UnitType | type[NotImplementedError]) -> None:
with expect_raise(expected):
result = a % b
assert type(result) is type(expected)
assert result == expected

# divmod(a, b) = ?
@pytest.mark.parametrize(
"a,b,expected",
("a", "b", "expected"),
[
(int(4), int(2), (int(2), int(0))),
(int(4), Decimal(2), (Decimal(2), Decimal(0))),
Expand Down Expand Up @@ -561,7 +561,7 @@ def test_mod(a: Any, b: Any, expected: UnitType | type[NotImplementedError]):
],
ids=idfn
)
def test_divmod(a: Any, b: Any, expected: tuple[UnitType, UnitType] | type[NotImplementedError]):
def test_divmod(a: Any, b: Any, expected: tuple[UnitType, UnitType] | type[NotImplementedError]) -> None:
with expect_raise(expected):
result_a, result_b = divmod(a, b)
if isinstance(expected, tuple):
Expand All @@ -573,7 +573,7 @@ def test_divmod(a: Any, b: Any, expected: tuple[UnitType, UnitType] | type[NotIm

# a ** b = ?
@pytest.mark.parametrize(
"a,b,expected",
("a", "b", "expected"),
[
(int(4), int(2), int(16)),
(int(4), Decimal(2), Decimal(16)),
Expand Down Expand Up @@ -672,8 +672,8 @@ def test_divmod(a: Any, b: Any, expected: tuple[UnitType, UnitType] | type[NotIm
],
ids=idfn
)
def test_pow(a: Any, b: Any, expected: UnitType | type[NotImplementedError]):
def test_pow(a: Any, b: Any, expected: UnitType | type[NotImplementedError]) -> None:
with expect_raise(expected):
result = a ** b
assert type(result) is type(expected)
assert result == expected
assert result == expected

0 comments on commit 79ae942

Please sign in to comment.