From e93df9ed22061f4bc6f6d7324a26638a1d13548d Mon Sep 17 00:00:00 2001 From: Kristoffer Andersson Date: Tue, 28 May 2024 13:47:45 +0200 Subject: [PATCH] feat: add num_decimals config --- .../__init__.py | 10 ++++ .../annotations.py | 9 +++- .../sentiment_analyzer.py | 52 +++++++++++-------- .../tests/__snapshots__/test_annotations.ambr | 4 +- .../tests/test_annotations.py | 9 +++- 5 files changed, 57 insertions(+), 27 deletions(-) diff --git a/sparv-sbx-sentence-sentiment-kb-sent/src/sbx_sentence_sentiment_kb_sent/__init__.py b/sparv-sbx-sentence-sentiment-kb-sent/src/sbx_sentence_sentiment_kb_sent/__init__.py index 3567771..928c8b2 100644 --- a/sparv-sbx-sentence-sentiment-kb-sent/src/sbx_sentence_sentiment_kb_sent/__init__.py +++ b/sparv-sbx-sentence-sentiment-kb-sent/src/sbx_sentence_sentiment_kb_sent/__init__.py @@ -1,8 +1,18 @@ """Sparv plugin for annotating sentences with sentiment analysis.""" +from sparv import api as sparv_api # type: ignore [import-untyped] + from sbx_sentence_sentiment_kb_sent.annotations import annotate_sentence_sentiment __all__ = ["annotate_sentence_sentiment"] __description__ = "Annotate sentence with sentiment analysis." __version__ = "0.1.0" + +__config__ = [ + sparv_api.Config( + "sbx_sentence_sentiment_kb_bert.num_decimals", + description="The number of decimals to round the score to", + default=3, + ), +] diff --git a/sparv-sbx-sentence-sentiment-kb-sent/src/sbx_sentence_sentiment_kb_sent/annotations.py b/sparv-sbx-sentence-sentiment-kb-sent/src/sbx_sentence_sentiment_kb_sent/annotations.py index 1cbd28c..e9f09d1 100644 --- a/sparv-sbx-sentence-sentiment-kb-sent/src/sbx_sentence_sentiment_kb_sent/annotations.py +++ b/sparv-sbx-sentence-sentiment-kb-sent/src/sbx_sentence_sentiment_kb_sent/annotations.py @@ -16,13 +16,20 @@ def annotate_sentence_sentiment( ), word: sparv_api.Annotation = sparv_api.Annotation(""), sentence: sparv_api.Annotation = sparv_api.Annotation(""), + num_decimals_str: str = sparv_api.Config("sbx_sentence_sentiment_kb_sent.num_decimals"), ) -> None: """Sentiment analysis of sentence with KBLab/robust-swedish-sentiment-multiclass.""" + try: + num_decimals = int(num_decimals_str) + except ValueError as exc: + raise sparv_api.SparvErrorMessage( + f"'sbx_word_prediction_kb_bert.num_decimals' must contain an 'int' got: '{num_decimals_str}'" # noqa: E501 + ) from exc sentences, _orphans = sentence.get_children(word) token_word = list(word.read()) out_sentence_sentiment_annotation = sentence.create_empty_attribute() - analyzer = SentimentAnalyzer.default() + analyzer = SentimentAnalyzer(num_decimals=num_decimals) logger.progress(total=len(sentences)) # type: ignore for sent_i, sent in enumerate(sentences): diff --git a/sparv-sbx-sentence-sentiment-kb-sent/src/sbx_sentence_sentiment_kb_sent/sentiment_analyzer.py b/sparv-sbx-sentence-sentiment-kb-sent/src/sbx_sentence_sentiment_kb_sent/sentiment_analyzer.py index dbf9633..999e10f 100644 --- a/sparv-sbx-sentence-sentiment-kb-sent/src/sbx_sentence_sentiment_kb_sent/sentiment_analyzer.py +++ b/sparv-sbx-sentence-sentiment-kb-sent/src/sbx_sentence_sentiment_kb_sent/sentiment_analyzer.py @@ -27,8 +27,8 @@ class SentimentAnalyzer: def __init__( self, *, - tokenizer: PreTrainedTokenizerFast, - model: MegatronBertForSequenceClassification, + tokenizer: Optional[PreTrainedTokenizerFast] = None, + model: Optional[MegatronBertForSequenceClassification] = None, num_decimals: int = 3, ) -> None: """Create a SentimentAnalyzer using the given tokenizer and model. @@ -40,12 +40,22 @@ def __init__( model (MegatronBertForSequenceClassification): the model to use num_decimals (int): number of decimals to use (defaults to 3) """ - logger.debug("type(tokenizer)=%s", type(tokenizer)) - logger.debug("type(model)=%s", type(model)) - self.tokenizer = tokenizer - self.model = model + self.tokenizer = self._default_tokenizer() if tokenizer is None else tokenizer + self.model = self._default_model() if model is None else model self.num_decimals = num_decimals - self.classifier = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer) + self.classifier = pipeline( + "sentiment-analysis", model=self.model, tokenizer=self.tokenizer + ) + + @classmethod + def _default_tokenizer(cls) -> PreTrainedTokenizerFast: + return AutoTokenizer.from_pretrained(TOKENIZER_NAME, revision=TOKENIZER_REVISION) + + @classmethod + def _default_model(cls) -> MegatronBertForSequenceClassification: + return AutoModelForSequenceClassification.from_pretrained( + MODEL_NAME, revision=MODEL_REVISION + ) @classmethod def default(cls) -> "SentimentAnalyzer": @@ -54,10 +64,8 @@ def default(cls) -> "SentimentAnalyzer": Returns: SentimentAnalyzer: the create SentimentAnalyzer """ - tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME, revision=TOKENIZER_REVISION) - model = AutoModelForSequenceClassification.from_pretrained( - MODEL_NAME, revision=MODEL_REVISION - ) + tokenizer = cls._default_tokenizer() + model = cls._default_model() return cls(model=model, tokenizer=tokenizer) def analyze_sentence(self, text: List[str]) -> Optional[str]: @@ -70,9 +78,7 @@ def analyze_sentence(self, text: List[str]) -> Optional[str]: List[Optional[str]]: the sentence annotations. """ sentence = TOK_SEP.join(text) - classifications = self.classifier(sentence) - logger.debug("classifications=%s", classifications) collect_label_and_score = ((clss["label"], clss["score"]) for clss in classifications) score_format, score_pred = SCORE_FORMAT_AND_PREDICATE[self.num_decimals] @@ -89,14 +95,14 @@ def analyze_sentence(self, text: List[str]) -> Optional[str]: SCORE_FORMAT_AND_PREDICATE = { - 1: ("{:.1f}", lambda s: s.endswith(".0")), - 2: ("{:.2f}", lambda s: s.endswith(".00")), - 3: ("{:.3f}", lambda s: s.endswith(".000")), - 4: ("{:.4f}", lambda s: s.endswith(".0000")), - 5: ("{:.5f}", lambda s: s.endswith(".00000")), - 6: ("{:.6f}", lambda s: s.endswith(".000000")), - 7: ("{:.7f}", lambda s: s.endswith(".0000000")), - 8: ("{:.8f}", lambda s: s.endswith(".00000000")), - 9: ("{:.9f}", lambda s: s.endswith(".000000000")), - 10: ("{:.10f}", lambda s: s.endswith(".0000000000")), + 1: ("{:.1f}", lambda s: s.startswith("0") and s.endswith(".0")), + 2: ("{:.2f}", lambda s: s.startswith("0") and s.endswith(".00")), + 3: ("{:.3f}", lambda s: s.startswith("0") and s.endswith(".000")), + 4: ("{:.4f}", lambda s: s.startswith("0") and s.endswith(".0000")), + 5: ("{:.5f}", lambda s: s.startswith("0") and s.endswith(".00000")), + 6: ("{:.6f}", lambda s: s.startswith("0") and s.endswith(".000000")), + 7: ("{:.7f}", lambda s: s.startswith("0") and s.endswith(".0000000")), + 8: ("{:.8f}", lambda s: s.startswith("0") and s.endswith(".00000000")), + 9: ("{:.9f}", lambda s: s.startswith("0") and s.endswith(".000000000")), + 10: ("{:.10f}", lambda s: s.startswith("0") and s.endswith(".0000000000")), } diff --git a/sparv-sbx-sentence-sentiment-kb-sent/tests/__snapshots__/test_annotations.ambr b/sparv-sbx-sentence-sentiment-kb-sent/tests/__snapshots__/test_annotations.ambr index b0e948b..3888b7f 100644 --- a/sparv-sbx-sentence-sentiment-kb-sent/tests/__snapshots__/test_annotations.ambr +++ b/sparv-sbx-sentence-sentiment-kb-sent/tests/__snapshots__/test_annotations.ambr @@ -1,7 +1,7 @@ # serializer version: 1 # name: test_annotate_sentence_sentiment list([ - '|POSITIVE:0.866|', - '|NEUTRAL:0.963|', + '|POSITIVE:0.9|', + '|NEUTRAL:1.0|', ]) # --- diff --git a/sparv-sbx-sentence-sentiment-kb-sent/tests/test_annotations.py b/sparv-sbx-sentence-sentiment-kb-sent/tests/test_annotations.py index b39a102..773febb 100644 --- a/sparv-sbx-sentence-sentiment-kb-sent/tests/test_annotations.py +++ b/sparv-sbx-sentence-sentiment-kb-sent/tests/test_annotations.py @@ -1,4 +1,6 @@ +import pytest from sbx_sentence_sentiment_kb_sent.annotations import annotate_sentence_sentiment +from sparv import api as sparv_api # type: ignore [import-untyped] from tests.testing import MemoryOutput, MockAnnotation @@ -12,6 +14,11 @@ def test_annotate_sentence_sentiment(snapshot) -> None: # noqa: ANN001 name="", children={"": [[0, 1, 2, 3], [4, 5, 6, 7]]} ) - annotate_sentence_sentiment(output, word, sentence) + annotate_sentence_sentiment(output, word, sentence, num_decimals_str="1") assert output.values == snapshot + + +def test_annotate_sentence_sentiment_raises_on_bad_config() -> None: + with pytest.raises(sparv_api.SparvErrorMessage): + annotate_sentence_sentiment(None, None, None, num_decimals_str="not an int")