Skip to content

Commit

Permalink
refactor: replace attrs with __init__ method in FiniteStateMachine
Browse files Browse the repository at this point in the history
  • Loading branch information
rwnobrega committed Dec 17, 2024
1 parent f86aafa commit fcc28fa
Showing 1 changed file with 14 additions and 18 deletions.
32 changes: 14 additions & 18 deletions src/komm/_finite_state_machine/FiniteStateMachine.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import numpy as np
import numpy.typing as npt
from attrs import field, frozen


class MetricMemory(TypedDict):
Expand All @@ -16,7 +15,6 @@ class MetricMemory(TypedDict):
MetricFunction = Callable[[int, Z], float]


@frozen
class FiniteStateMachine:
r"""
Finite-state machine (Mealy machine). It is defined by a *set of states* $\mathcal{S}$, an *input alphabet* $\mathcal{X}$, an *output alphabet* $\mathcal{Y}$, and a *transition function* $T : \mathcal{S} \times \mathcal{X} \to \mathcal{S} \times \mathcal{Y}$. Here, for simplicity, the set of states, the input alphabet, and the output alphabet are always taken as $\mathcal{S} = \\{ 0, 1, \ldots, |\mathcal{S}| - 1 \\}$, $\mathcal{X} = \\{ 0, 1, \ldots, |\mathcal{X}| - 1 \\}$, and $\mathcal{Y} = \\{ 0, 1, \ldots, |\mathcal{Y}| - 1 \\}$, respectively.
Expand Down Expand Up @@ -49,24 +47,22 @@ class FiniteStateMachine:
>>> fsm = komm.FiniteStateMachine(next_states=[[0,1], [2,3], [0,1], [2,3]], outputs=[[0,3], [1,2], [3,0], [2,1]])
"""

next_states: npt.NDArray[np.integer] = field(
converter=np.asarray, repr=lambda x: x.tolist()
)
outputs: npt.NDArray[np.integer] = field(
converter=np.asarray, repr=lambda x: x.tolist()
)
_input_edges: npt.NDArray[np.integer] = field(init=False, repr=False)
_output_edges: npt.NDArray[np.integer] = field(init=False, repr=False)

def __attrs_post_init__(self) -> None:
input_edges = np.full((self.num_states, self.num_states), fill_value=-1)
output_edges = np.full((self.num_states, self.num_states), fill_value=-1)
def __init__(self, next_states: npt.ArrayLike, outputs: npt.ArrayLike):
self.next_states = np.asarray(next_states)
self.outputs = np.asarray(outputs)
self._input_edges = np.full((self.num_states, self.num_states), fill_value=-1)
self._output_edges = np.full((self.num_states, self.num_states), fill_value=-1)
for state_from in range(self.num_states):
for x, state_to in enumerate(self.next_states[state_from, :]):
input_edges[state_from, state_to] = x
output_edges[state_from, state_to] = self.outputs[state_from, x]
object.__setattr__(self, "_input_edges", input_edges)
object.__setattr__(self, "_output_edges", output_edges)
self._input_edges[state_from, state_to] = x
self._output_edges[state_from, state_to] = self.outputs[state_from, x]

def __repr__(self) -> str:
args = ", ".join([
f"next_states={self.next_states.tolist()}",
f"outputs={self.outputs.tolist()}",
])
return f"{self.__class__.__name__}({args})"

@cached_property
def num_states(self) -> int:
Expand Down

0 comments on commit fcc28fa

Please sign in to comment.