Skip to content

Commit cbef21c

Browse files
committed
array is now precomputed when defined for caching.
1 parent 6a4ba89 commit cbef21c

File tree

1 file changed

+11
-10
lines changed

1 file changed

+11
-10
lines changed

src/fuzzylogic/classes.py

+11-10
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ def __setattr__(self, name: str, value: Set | Membership) -> None:
152152
self._sets[name] = value
153153
value.domain = self
154154
value.name = name
155+
value.array() # force the array to be calculated for caching
155156

156157
def __delattr__(self, name: str) -> None:
157158
"""Delete a fuzzy set from the domain."""
@@ -223,7 +224,9 @@ def __init__(
223224
self.func: Membership = func
224225
self.domain: Domain | None = domain
225226
self.name: str | None = name
226-
self.__center_of_gravity: float | None = None
227+
228+
self._center_of_gravity: float | None = None
229+
self._cached_array: np.ndarray | None = None
227230

228231
@overload
229232
def __call__(self, x: float, /) -> float: ...
@@ -385,11 +388,7 @@ def multiplied(self, n: float) -> Set:
385388
def plot(self) -> None:
386389
"""Graph the set in the given domain."""
387390
assert self.domain is not None, NO_DOMAIN
388-
R = self.domain.range
389-
V = [self.func(x) for x in R]
390-
if plt:
391-
plt.plot(R, V) # type: ignore
392-
else:
391+
if not plt:
393392
raise ImportError(
394393
"matplotlib not available. Please re-install with 'pip install fuzzylogic[plotting]'"
395394
)
@@ -418,7 +417,9 @@ def plot(self) -> None:
418417
def array(self) -> Array:
419418
"""Return an array of all values for this set within the given domain."""
420419
assert self.domain is not None, NO_DOMAIN
421-
return np.fromiter((self.func(x) for x in self.domain.range), float)
420+
if self._cached_array is None:
421+
self._cached_array = np.fromiter((self.func(x) for x in self.domain.range), float)
422+
return self._cached_array
422423

423424
def range(self) -> Array:
424425
"""Return the range of the domain."""
@@ -427,14 +428,14 @@ def range(self) -> Array:
427428

428429
def center_of_gravity(self) -> float:
429430
"""Return the center of gravity for this distribution, within the given domain."""
430-
if self.__center_of_gravity is not None:
431-
return self.__center_of_gravity
431+
if self._center_of_gravity is not None:
432+
return self._center_of_gravity
432433
assert self.domain is not None, NO_DOMAIN
433434
weights = self.array()
434435
if sum(weights) == 0:
435436
return 0
436437
cog = float(np.average(self.domain.range, weights=weights))
437-
self.__center_of_gravity = cog
438+
self._center_of_gravity = cog
438439
return cog
439440

440441
def __repr__(self) -> str:

0 commit comments

Comments
 (0)