diff --git a/mltb2/text.py b/mltb2/text.py index d721e87..9858d6e 100644 --- a/mltb2/text.py +++ b/mltb2/text.py @@ -12,7 +12,12 @@ """ import re -from typing import Dict, Final, Pattern, Tuple +from collections import Counter, defaultdict +from dataclasses import dataclass, field +from typing import Dict, Final, Iterable, Optional, Pattern, Set, Tuple, Union + +from scipy.spatial.distance import cityblock +from tqdm import tqdm INVISIBLE_CHARACTERS: Final[Tuple[str, ...]] = ( "\u200b", # Zero Width Space (ZWSP) https://www.compart.com/en/unicode/U+200b @@ -138,3 +143,108 @@ def clean_all_invisible_chars_and_whitespaces(text: str) -> str: text = replace_multiple_whitespaces(text) text = text.strip() return text + + +def _normalize_counter_to_defaultdict(counter: Counter, max_dimensions: int) -> defaultdict: + """Normalize a counter to to ``max_dimensions``. + + The number of dimensions is limited to ``max_dimensions`` + of the most commen characters. + The counter values are normalized by deviding them by the total count. + + Args: + counter: The counter to normalize. + max_dimensions: The maximum number of dimensions to use for the normalization. + Must be greater than 0. + Returns: + The normalized counter with a maximum of ``max_dimensions`` dimensions. + """ + total_count = sum(counter.values()) + normalized_counter = defaultdict(float) + for char, count in counter.most_common(max_dimensions): + normalized_counter[char] = count / total_count + return normalized_counter + + +@dataclass +class TextDistance: + """Calculate the distance between two texts. + + One text (or multiple texts) must first be fitted with :func:`~TextDistance.fit`. + After that the distance to other given texts can be calculated with :func:`~TextDistance.distance`. + After the distance was calculated the first time, the class can + not be fitted again. + + Args: + show_progress_bar: Show a progressbar during processing. + max_dimensions: The maximum number of dimensions to use for the distance calculation. + Must be greater than 0. + Raises: + ValueError: If ``max_dimensions`` is not greater than 0. + """ + + show_progress_bar: bool = False + max_dimensions: int = 100 + + # counter for the text we fit + _char_counter: Optional[Counter] = field(default_factory=Counter, init=False) + + # normalized counter for the text we fit - see _normalize_char_counter + _normalized_char_counts: Optional[defaultdict] = field(default=None, init=False) + + # set of all counted characters - see _normalize_char_counter + _counted_char_set: Optional[Set[str]] = field(default=None, init=False) + + def __post_init__(self) -> None: + """Do post init.""" + if not self.max_dimensions > 0: + raise ValueError("'max_dimensions' must be > 0!") + + def fit(self, text: Union[str, Iterable[str]]) -> None: + """Fit the text. + + Args: + text: The text to fit. + Raises: + ValueError: If :func:`~TextDistance.fit` is called after + :func:`~TextDistance.distance`. + """ + if self._char_counter is None: + raise ValueError("Fit mut not be called after distance calculation!") + + if isinstance(text, str): + self._char_counter.update(text) + else: + for t in tqdm(text, disable=not self.show_progress_bar): + self._char_counter.update(t) + + def _normalize_char_counter(self) -> None: + """Normalize the char counter to a defaultdict. + + This supports lazy postprocessing of the char counter. + """ + if self._char_counter is not None: + self._normalized_char_counts = _normalize_counter_to_defaultdict(self._char_counter, self.max_dimensions) + self._char_counter = None + self._counted_char_set = set(self._normalized_char_counts) + + def distance(self, text) -> float: + """Calculate the distance between the fitted text and the given text. + + This implementation uses the Manhattan distance (:func:`scipy.spatial.distance.cityblock`). + The distance is only calculated for ``max_dimensions`` most commen characters. + + Args: + text: The text to calculate the Manhattan distance to. + """ + self._normalize_char_counter() + all_vector = [] + text_vector = [] + text_count = Counter(text) + text_count_defaultdict = _normalize_counter_to_defaultdict(text_count, self.max_dimensions) + for c in self._counted_char_set.union(text_count_defaultdict): # type: ignore + all_vector.append( + self._normalized_char_counts[c] # type: ignore + ) # if c is not in defaultdict, it will return 0 + text_vector.append(text_count_defaultdict[c]) # if c is not in defaultdict, it will return 0 + return cityblock(all_vector, text_vector) diff --git a/tests/test_data.py b/tests/test_data.py index 3ea9165..dd8db57 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -3,6 +3,7 @@ # which is available at https://opensource.org/licenses/MIT import pandas as pd +import pytest from numpy.testing import assert_almost_equal from mltb2.data import _load_colon_data, _load_colon_label, load_colon, load_leukemia_big, load_prostate @@ -44,6 +45,7 @@ def test_load_colon_compare_original(): assert_almost_equal(result[1].to_numpy(), ori_result[1].to_numpy()) +@pytest.mark.skip(reason="see https://github.com/telekom/mltb2/issues/118") def test_load_prostate(): result = load_prostate() assert result is not None @@ -55,6 +57,7 @@ def test_load_prostate(): assert result[1].shape == (102, 6033) +@pytest.mark.skip(reason="see https://github.com/telekom/mltb2/issues/118") def test_load_prostate_compare_original(): result = load_prostate() ori_result = load_prostate_data() @@ -64,6 +67,7 @@ def test_load_prostate_compare_original(): assert_almost_equal(result[1].to_numpy(), ori_result[1].to_numpy()) +@pytest.mark.skip(reason="see https://github.com/telekom/mltb2/issues/118") def test_load_leukemia_big(): result = load_leukemia_big() assert result is not None @@ -75,6 +79,7 @@ def test_load_leukemia_big(): assert result[1].shape == (72, 7128) +@pytest.mark.skip(reason="see https://github.com/telekom/mltb2/issues/118") def test_load_leukemia_big_compare_original(): result = load_leukemia_big() ori_result = load_leukemia_data() diff --git a/tests/test_text.py b/tests/test_text.py index 708e987..d3e2b49 100644 --- a/tests/test_text.py +++ b/tests/test_text.py @@ -2,11 +2,16 @@ # This software is distributed under the terms of the MIT license # which is available at https://opensource.org/licenses/MIT +from collections import Counter, defaultdict +from math import isclose + import pytest from mltb2.text import ( INVISIBLE_CHARACTERS, SPECIAL_WHITESPACES, + TextDistance, + _normalize_counter_to_defaultdict, clean_all_invisible_chars_and_whitespaces, has_invisible_characters, has_special_whitespaces, @@ -112,3 +117,72 @@ def test_clean_all_invisible_chars_and_whitespaces_empty_result(): text = " \u200b\u00ad\u2007 " result = clean_all_invisible_chars_and_whitespaces(text) assert result == "" + + +def test_text_distance_distance_same(): + text = "Hello World!" + td = TextDistance() + td.fit(text) + assert len(td._char_counter) == 9 + assert td._normalized_char_counts is None + assert td._counted_char_set is None + distance = td.distance(text) + assert td._char_counter is None # none after fit + assert td._normalized_char_counts is not None + assert td._counted_char_set is not None + + assert isclose(distance, 0.0), distance + + +def test_text_distance_orthogonal(): + text = "ab" + td = TextDistance() + td.fit(text) + distance = td.distance("xy") + assert distance > 0.0, distance + assert isclose(distance, 2.0), distance + + +def test_text_distance_extended(): + text = "aabbbb" # a:1/3, b:2/3 + td = TextDistance() + td.fit(text) + distance = td.distance("bbcccc") # b:1/3, c:2/3 + assert distance > 0.0, distance + assert isclose(distance, 1 / 3 + 1 / 3 + 2 / 3), distance + + +def test_text_distance_fit_not_allowed_after_distance(): + text = "Hello World!" + td = TextDistance() + td.fit(text) + _ = td.distance(text) + with pytest.raises(ValueError): + td.fit("Hello World") + + +def test_text_distance_max_dimensions_must_be_greater_zero(): + with pytest.raises(ValueError): + _ = TextDistance(max_dimensions=0) + + +def test_normalize_counter_to_defaultdict(): + counter = Counter("aaaabbbcc") + max_dimensions = 2 + normalized_counter = _normalize_counter_to_defaultdict(counter, max_dimensions) + + assert isinstance(normalized_counter, defaultdict) + assert len(normalized_counter) == max_dimensions + assert isclose(normalized_counter["a"], 4 / 9) + assert isclose(normalized_counter["b"], 3 / 9) + assert "c" not in normalized_counter + assert len(normalized_counter) == max_dimensions + + +def test_normalize_counter_to_defaultdict_empty_counter(): + counter = Counter() + max_dimensions = 2 + normalized_counter = _normalize_counter_to_defaultdict(counter, max_dimensions) + + assert isinstance(normalized_counter, defaultdict) + assert len(normalized_counter) == 0