Skip to content

Commit

Permalink
Add fit and distance flags to TextDistance class
Browse files Browse the repository at this point in the history
  • Loading branch information
PhilipMay committed Dec 23, 2023
1 parent 3937f18 commit 4cb033a
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 6 deletions.
23 changes: 17 additions & 6 deletions mltb2/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,12 @@ class TextDistance:
# set of all counted characters - see _normalize_char_counter
_counted_char_set: Optional[Set[str]] = field(default=None, init=False)

# flag if fit was called
_fit_called: bool = field(default=False, init=False)

# flag if distance was called
_distance_called: bool = field(default=False, init=False)

def __post_init__(self) -> None:
"""Do post init."""
if not self.max_dimensions > 0:
Expand All @@ -209,24 +215,27 @@ def fit(self, text: Union[str, Iterable[str]]) -> None:
ValueError: If :func:`~TextDistance.fit` is called after
:func:`~TextDistance.distance`.
"""
if self._char_counter is None:
raise ValueError("Fit mut not be called after distance calculation!")
if self._distance_called:
raise ValueError("fit must not be called after distance calculation!")

if isinstance(text, str):
self._char_counter.update(text)
self._char_counter.update(text) # type: ignore
else:
for t in tqdm(text, disable=not self.show_progress_bar):
self._char_counter.update(t)
self._char_counter.update(t) # type: ignore

self._fit_called = True

def _normalize_char_counter(self) -> None:
    """Normalize the char counter to a defaultdict.

    This supports lazy postprocessing of the char counter: the normalization
    runs only on the first call, every later call returns immediately.

    On the first call this method

    - stores the result of ``_normalize_counter_to_defaultdict`` (called with
      the raw counter and ``max_dimensions``) in ``_normalized_char_counts``,
    - sets ``_char_counter`` to ``None``,
    - stores the keys of the normalized counts in ``_counted_char_set``,
    - sets ``_distance_called`` to ``True``.
    """
    if self._distance_called:
        # Normalization already happened; _char_counter is None by now.
        return

    # _char_counter is declared Optional, hence the type: ignore.
    self._normalized_char_counts = _normalize_counter_to_defaultdict(  # type: ignore
        self._char_counter, self.max_dimensions
    )
    self._char_counter = None
    self._counted_char_set = set(self._normalized_char_counts)
    self._distance_called = True

def distance(self, text) -> float:
"""Calculate the distance between the fitted text and the given text.
Expand All @@ -237,6 +246,8 @@ def distance(self, text) -> float:
Args:
text: The text to calculate the Manhattan distance to.
"""
if not self._fit_called:
raise ValueError("fit must be called before distance!")
self._normalize_char_counter()
all_vector = []
text_vector = []
Expand Down
7 changes: 7 additions & 0 deletions tests/test_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,13 @@ def test_text_distance_fit_not_allowed_after_distance():
td.fit("Hello World")


def test_text_distance_distance_not_allowed_before_fit():
    """Check that distance() on a never-fitted TextDistance raises ValueError."""
    unfitted = TextDistance()
    sample = "Hello World!"
    with pytest.raises(ValueError):
        unfitted.distance(sample)


def test_text_distance_max_dimensions_must_be_greater_zero():
with pytest.raises(ValueError):
_ = TextDistance(max_dimensions=0)
Expand Down

0 comments on commit 4cb033a

Please sign in to comment.