Skip to content

Commit a2e24a0

Browse files
committed
fixed tests
1 parent dc6f407 commit a2e24a0

File tree

5 files changed

+138
-121
lines changed

5 files changed

+138
-121
lines changed

src/fuzzylogic/estimate.py

+22-11
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,23 @@
1818
from itertools import permutations
1919
from random import choice, randint
2020
from statistics import median
21-
from typing import Callable
2221

2322
import numpy as np
2423

25-
from .functions import R, S, constant, gauss, rectangular, sigmoid, singleton, step, trapezoid, triangular
24+
from .classes import Array
25+
from .functions import (
26+
Membership,
27+
R,
28+
S,
29+
constant,
30+
gauss,
31+
rectangular,
32+
sigmoid,
33+
singleton,
34+
step,
35+
trapezoid,
36+
triangular,
37+
)
2638

2739
np.seterr(all="raise")
2840
functions = [step, rectangular]
@@ -33,13 +45,13 @@
3345
argument4_functions = [trapezoid]
3446

3547

36-
def normalize(target: np.ndarray, output_length: int = 16) -> np.ndarray:
48+
def normalize(target: Array, output_length: int = 16) -> Array:
3749
"""Normalize and interpolate a numpy array.
3850
3951
Return an array of output_length and normalized values.
4052
"""
41-
min_val = np.min(target)
42-
max_val = np.max(target)
53+
min_val = float(np.min(target))
54+
max_val = float(np.max(target))
4355
if min_val == max_val:
4456
return np.ones(output_length)
4557
normalized_array = (target - min_val) / (max_val - min_val)
@@ -49,13 +61,12 @@ def normalize(target: np.ndarray, output_length: int = 16) -> np.ndarray:
4961
return normalized_array
5062

5163

52-
def guess_function(target: np.ndarray) -> Callable:
64+
def guess_function(target: Array) -> Membership:
5365
normalized = normalize(target)
54-
# trivial case
5566
return constant if np.all(normalized == 1) else singleton
5667

5768

58-
def fitness(func: Callable, target: np.ndarray, certainty: int | None = None) -> float:
69+
def fitness(func: Membership, target: Array, certainty: int | None = None) -> float:
5970
"""Compute the difference between the array and the function evaluated at the parameters.
6071
6172
if the error is 0, we have a perfect match: fitness -> 1
@@ -66,7 +77,7 @@ def fitness(func: Callable, target: np.ndarray, certainty: int | None = None) ->
6677
return result if certainty is None else round(result, certainty)
6778

6879

69-
def seed_population(func: Callable, target: np.ndarray) -> dict[tuple, float]:
80+
def seed_population(func: Membership, target: Array) -> dict[tuple, float]:
7081
# create a random population of parameters
7182
params = [p for p in inspect.signature(func).parameters.values() if p.kind == p.POSITIONAL_OR_KEYWORD]
7283
seed_population = {}
@@ -106,7 +117,7 @@ def reproduce(parent1: tuple, parent2: tuple) -> tuple:
106117

107118

108119
def guess_parameters(
109-
func: Callable, target: np.ndarray, precision: int | None = None, certainty: int | None = None
120+
func: Membership, target: Array, precision: int | None = None, certainty: int | None = None
110121
) -> tuple:
111122
"""Find the best fitting parameters for a function, targetting an array.
112123
@@ -188,7 +199,7 @@ def best() -> tuple:
188199
return best()
189200

190201

191-
def shave(target: np.ndarray, components: dict[Callable, tuple]) -> np.ndarray:
202+
def shave(target: Array, components: dict[Membership, tuple]) -> Array:
192203
"""Remove the membership functions from the target array."""
193204
result = np.zeros_like(target)
194205
for func, params in components.items():

src/fuzzylogic/neural_network.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import numpy as np
44

5+
from .classes import Array
56
from .functions import R, S, constant, gauss, rectangular, sigmoid, singleton, step, trapezoid, triangular
67

78
functions = [step, rectangular]
@@ -12,7 +13,7 @@
1213
argument4_functions = [trapezoid]
1314

1415

15-
def generate_examples() -> dict[str, list[np.ndarray]]:
16+
def generate_examples() -> dict[str, list[Array]]:
1617
examples = defaultdict(lambda: [])
1718
examples["constant"] = [np.ones(16)]
1819
for x in range(16):

tests/test_caro.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
2-
from fuzzylogic.classes import Domain, Rule
2+
3+
from fuzzylogic.classes import Domain, Rule, rule_from_table
34
from fuzzylogic.functions import R, S, trapezoid
45

56
temp = Domain("Temperatur", -30, 100, res=0.0001) # ,res=0.1)
@@ -38,7 +39,6 @@
3839
temp.heiß gef.klein gef.groß gef.groß
3940
"""
4041

41-
from fuzzylogic.classes import rule_from_table
4242

4343
table_rules = rule_from_table(table, globals())
4444

@@ -47,4 +47,4 @@
4747
value = {temp: 20, tan: 0.55}
4848
result = rules(value)
4949
assert isinstance(result, float)
50-
assert np.isclose(result, 0.45, atol=0.0001)
50+
assert np.isclose(result, -0.5, atol=0.0001)

tests/test_functionality.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414

1515
@fixture
16-
def temp():
16+
def temp() -> Domain:
1717
d = Domain("temperature", -100, 100, res=0.1) # in Celsius
1818
d.cold = S(0, 15) # sic
1919
d.hot = Set(R(10, 30)) # sic
@@ -22,24 +22,24 @@ def temp():
2222

2323

2424
@fixture
25-
def simple():
25+
def simple() -> Domain:
2626
d = Domain("simple", 0, 10)
2727
d.low = S(0, 1)
2828
d.high = R(8, 10)
2929
return d
3030

3131

32-
def test_array(simple):
32+
def test_array(simple: Domain) -> None:
3333
assert array_equal(simple.low.array(), [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
3434
assert array_equal(simple.high.array(), [0, 0, 0, 0, 0, 0, 0, 0, 0, 0.5, 1.0])
3535
assert len(simple.low.array()) == 11 # unlike arrays and lists, upper boundary is INCLUDED
3636

3737

38-
def test_value(temp):
38+
def test_value(temp: Domain) -> None:
3939
assert temp(6) == {temp.cold: 0.6, temp.hot: 0, temp.warm: 0.4}
4040

4141

42-
def test_rating():
42+
def test_rating() -> None:
4343
"""Tom is surveying restaurants.
4444
He doesn't need fancy logic but rather uses a simple approach
4545
with weights.
@@ -61,7 +61,7 @@ def test_rating():
6161
weights = {"beverage": 0.3, "atmosphere": 0.2, "looks": 0.2, "taste": 0.3}
6262
w_func = weighted_sum(weights=weights, target_d=R)
6363

64-
ratings = {
64+
ratings: dict[str, float] = {
6565
"beverage": R.min(9),
6666
"atmosphere": R.min(5),
6767
"looks": R.min(4),

0 commit comments

Comments
 (0)