Skip to content

Commit a2c7366

Browse files
committed
refactor and cleanup
1 parent cbef21c commit a2c7366

File tree

2 files changed

+16
-14
lines changed

2 files changed

+16
-14
lines changed

src/fuzzylogic/functions.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@
3939

4040

4141
try:
42-
# from numba import njit # still not ready for prime time :(
42+
from numba import njit as njit # ready for prime time?
43+
4344
raise ImportError
4445

4546
except ImportError:
@@ -48,6 +49,8 @@ def njit(func: Membership) -> Membership:
4849
return func
4950

5051

52+
LOW_HIGH = "low must be less than high"
53+
5154
#####################
5255
# SPECIAL FUNCTIONS #
5356
#####################
@@ -111,7 +114,6 @@ def alpha(
111114

112115
floor_clip = floor if floor_clip is None else floor_clip
113116
ceiling_clip = ceiling if ceiling_clip is None else ceiling_clip
114-
# assert 0 <= floor_clip <= ceiling_clip <= 1, "%s <= %s"%(floor_clip, ceiling_clip)
115117

116118
def f(x: float) -> float:
117119
m = func(x)
@@ -254,7 +256,7 @@ def bounded_linear(
254256
>>> f(4)
255257
1.0
256258
"""
257-
assert low < high, "low must be less than high"
259+
assert low < high, LOW_HIGH
258260
assert c_m > no_m, "core_m must be greater than unsupported_m"
259261

260262
if inverse:
@@ -414,7 +416,7 @@ def sigmoid(L: float, k: float, x0: float = 0) -> Membership:
414416

415417
def f(x: float) -> float:
416418
if isnan(k * x):
417-
# e^(0*inf) = 1
419+
# e^(0*inf) == 1
418420
o = 1.0
419421
else:
420422
try:
@@ -464,7 +466,7 @@ def bounded_sigmoid(low: float, high: float, inverse: bool = False) -> Membershi
464466
>>> round(f(-100000), 2)
465467
0.0
466468
"""
467-
assert low < high, "low must be less than high"
469+
assert low < high, LOW_HIGH
468470

469471
if inverse:
470472
low, high = high, low
@@ -489,7 +491,7 @@ def f(x: float) -> float:
489491
except OverflowError:
490492
q = float("inf")
491493

492-
# e^(inf)*e^(-inf) = 1
494+
# e^(inf)*e^(-inf) == 1
493495
r = p * q
494496
if isnan(r):
495497
r = 1
@@ -560,7 +562,7 @@ def triangular_sigmoid(low: float, high: float, c: float | None = None) -> Membe
560562
>>> round(g(3), 2)
561563
0.9
562564
"""
563-
assert low < high, "low must be less than high"
565+
assert low < high, LOW_HIGH
564566
c = c if c is not None else (low + high) / 2.0
565567
assert low < c < high, "c must be inbetween"
566568

tests/test_units.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
common_settings = settings(deadline=None, suppress_health_check=cast(list[HealthCheck], list(HealthCheck)))
1919

2020

21-
class Test_Functions(TestCase):
21+
class TestFunctions(TestCase):
2222
@common_settings
2323
@given(st.floats(allow_nan=False))
2424
def test_noop(self, x: float) -> None:
@@ -250,7 +250,7 @@ def test_bounded_exponential(self, k: float, limit: float, x: float) -> None:
250250
assert 0 <= f(x) <= limit
251251

252252

253-
class Test_Hedges(TestCase):
253+
class TestHedges(TestCase):
254254
@common_settings
255255
@given(st.floats(min_value=0, max_value=1))
256256
def test_very(self, x: float) -> None:
@@ -273,7 +273,7 @@ def test_plus(self, x: float) -> None:
273273
assert 0 <= f(x) <= 1
274274

275275

276-
class Test_Combinators(TestCase):
276+
class TestCombinators(TestCase):
277277
@common_settings
278278
@given(st.floats(min_value=0, max_value=1))
279279
def test_MIN(self, x: float) -> None:
@@ -381,7 +381,7 @@ def test_simple_disjoint_sum(self, x: float) -> None:
381381
assert 0 <= f(x) <= 1
382382

383383

384-
class Test_Domain(TestCase):
384+
class TestDomain(TestCase):
385385
def test_basics(self) -> None:
386386
D = Domain("d", 0, 10)
387387
assert D._name == "d" # type: ignore
@@ -399,7 +399,7 @@ def test_basics(self) -> None:
399399
# assert d == D
400400

401401

402-
class Test_Set(TestCase):
402+
class TestSet(TestCase):
403403
@common_settings
404404
@given(
405405
st.floats(allow_nan=False, allow_infinity=False),
@@ -441,7 +441,7 @@ def test_complement(self) -> None:
441441
assert all(np.isclose(D.s1.array(), D.s2.array()))
442442

443443

444-
class Test_Rules(TestCase):
444+
class TestRules(TestCase):
445445
@common_settings
446446
@given(
447447
st.floats(min_value=0, max_value=1),
@@ -462,7 +462,7 @@ def round_partial(self, x: float, res: float) -> None:
462462
assert isclose(x, ru.round_partial(x, res))
463463

464464

465-
class Test_Truth(TestCase):
465+
class TestTruth(TestCase):
466466
@common_settings
467467
@given(st.floats(min_value=0, max_value=1))
468468
def test_true_and_false(self, m: float) -> None:

0 commit comments

Comments
 (0)