Skip to content

Commit

Permalink
feat: add num_decimals config
Browse files Browse the repository at this point in the history
  • Loading branch information
kod-kristoff committed May 28, 2024
1 parent 7ba1114 commit e93df9e
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 27 deletions.
Original file line number Diff line number Diff line change
@@ -1,8 +1,18 @@
"""Sparv plugin for annotating sentences with sentiment analysis."""

from sparv import api as sparv_api # type: ignore [import-untyped]

from sbx_sentence_sentiment_kb_sent.annotations import annotate_sentence_sentiment

__all__ = ["annotate_sentence_sentiment"]

__description__ = "Annotate sentence with sentiment analysis."
__version__ = "0.1.0"

__config__ = [
sparv_api.Config(
"sbx_sentence_sentiment_kb_bert.num_decimals",
description="The number of decimals to round the score to",
default=3,
),
]
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,20 @@ def annotate_sentence_sentiment(
),
word: sparv_api.Annotation = sparv_api.Annotation("<token:word>"),
sentence: sparv_api.Annotation = sparv_api.Annotation("<sentence>"),
num_decimals_str: str = sparv_api.Config("sbx_sentence_sentiment_kb_sent.num_decimals"),
) -> None:
"""Sentiment analysis of sentence with KBLab/robust-swedish-sentiment-multiclass."""
try:
num_decimals = int(num_decimals_str)
except ValueError as exc:
raise sparv_api.SparvErrorMessage(
f"'sbx_word_prediction_kb_bert.num_decimals' must contain an 'int' got: '{num_decimals_str}'" # noqa: E501
) from exc
sentences, _orphans = sentence.get_children(word)
token_word = list(word.read())
out_sentence_sentiment_annotation = sentence.create_empty_attribute()

analyzer = SentimentAnalyzer.default()
analyzer = SentimentAnalyzer(num_decimals=num_decimals)

logger.progress(total=len(sentences)) # type: ignore
for sent_i, sent in enumerate(sentences):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ class SentimentAnalyzer:
def __init__(
self,
*,
tokenizer: PreTrainedTokenizerFast,
model: MegatronBertForSequenceClassification,
tokenizer: Optional[PreTrainedTokenizerFast] = None,
model: Optional[MegatronBertForSequenceClassification] = None,
num_decimals: int = 3,
) -> None:
"""Create a SentimentAnalyzer using the given tokenizer and model.
Expand All @@ -40,12 +40,22 @@ def __init__(
model (MegatronBertForSequenceClassification): the model to use
num_decimals (int): number of decimals to use (defaults to 3)
"""
logger.debug("type(tokenizer)=%s", type(tokenizer))
logger.debug("type(model)=%s", type(model))
self.tokenizer = tokenizer
self.model = model
self.tokenizer = self._default_tokenizer() if tokenizer is None else tokenizer
self.model = self._default_model() if model is None else model
self.num_decimals = num_decimals
self.classifier = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer)
self.classifier = pipeline(
"sentiment-analysis", model=self.model, tokenizer=self.tokenizer
)

@classmethod
def _default_tokenizer(cls) -> PreTrainedTokenizerFast:
return AutoTokenizer.from_pretrained(TOKENIZER_NAME, revision=TOKENIZER_REVISION)

@classmethod
def _default_model(cls) -> MegatronBertForSequenceClassification:
return AutoModelForSequenceClassification.from_pretrained(
MODEL_NAME, revision=MODEL_REVISION
)

@classmethod
def default(cls) -> "SentimentAnalyzer":
Expand All @@ -54,10 +64,8 @@ def default(cls) -> "SentimentAnalyzer":
Returns:
SentimentAnalyzer: the create SentimentAnalyzer
"""
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME, revision=TOKENIZER_REVISION)
model = AutoModelForSequenceClassification.from_pretrained(
MODEL_NAME, revision=MODEL_REVISION
)
tokenizer = cls._default_tokenizer()
model = cls._default_model()
return cls(model=model, tokenizer=tokenizer)

def analyze_sentence(self, text: List[str]) -> Optional[str]:
Expand All @@ -70,9 +78,7 @@ def analyze_sentence(self, text: List[str]) -> Optional[str]:
List[Optional[str]]: the sentence annotations.
"""
sentence = TOK_SEP.join(text)

classifications = self.classifier(sentence)
logger.debug("classifications=%s", classifications)
collect_label_and_score = ((clss["label"], clss["score"]) for clss in classifications)
score_format, score_pred = SCORE_FORMAT_AND_PREDICATE[self.num_decimals]

Expand All @@ -89,14 +95,14 @@ def analyze_sentence(self, text: List[str]) -> Optional[str]:


SCORE_FORMAT_AND_PREDICATE = {
1: ("{:.1f}", lambda s: s.endswith(".0")),
2: ("{:.2f}", lambda s: s.endswith(".00")),
3: ("{:.3f}", lambda s: s.endswith(".000")),
4: ("{:.4f}", lambda s: s.endswith(".0000")),
5: ("{:.5f}", lambda s: s.endswith(".00000")),
6: ("{:.6f}", lambda s: s.endswith(".000000")),
7: ("{:.7f}", lambda s: s.endswith(".0000000")),
8: ("{:.8f}", lambda s: s.endswith(".00000000")),
9: ("{:.9f}", lambda s: s.endswith(".000000000")),
10: ("{:.10f}", lambda s: s.endswith(".0000000000")),
1: ("{:.1f}", lambda s: s.startswith("0") and s.endswith(".0")),
2: ("{:.2f}", lambda s: s.startswith("0") and s.endswith(".00")),
3: ("{:.3f}", lambda s: s.startswith("0") and s.endswith(".000")),
4: ("{:.4f}", lambda s: s.startswith("0") and s.endswith(".0000")),
5: ("{:.5f}", lambda s: s.startswith("0") and s.endswith(".00000")),
6: ("{:.6f}", lambda s: s.startswith("0") and s.endswith(".000000")),
7: ("{:.7f}", lambda s: s.startswith("0") and s.endswith(".0000000")),
8: ("{:.8f}", lambda s: s.startswith("0") and s.endswith(".00000000")),
9: ("{:.9f}", lambda s: s.startswith("0") and s.endswith(".000000000")),
10: ("{:.10f}", lambda s: s.startswith("0") and s.endswith(".0000000000")),
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# serializer version: 1
# name: test_annotate_sentence_sentiment
list([
'|POSITIVE:0.866|',
'|NEUTRAL:0.963|',
'|POSITIVE:0.9|',
'|NEUTRAL:1.0|',
])
# ---
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import pytest
from sbx_sentence_sentiment_kb_sent.annotations import annotate_sentence_sentiment
from sparv import api as sparv_api # type: ignore [import-untyped]

from tests.testing import MemoryOutput, MockAnnotation

Expand All @@ -12,6 +14,11 @@ def test_annotate_sentence_sentiment(snapshot) -> None: # noqa: ANN001
name="<sentence>", children={"<token>": [[0, 1, 2, 3], [4, 5, 6, 7]]}
)

annotate_sentence_sentiment(output, word, sentence)
annotate_sentence_sentiment(output, word, sentence, num_decimals_str="1")

assert output.values == snapshot


def test_annotate_sentence_sentiment_raises_on_bad_config() -> None:
with pytest.raises(sparv_api.SparvErrorMessage):
annotate_sentence_sentiment(None, None, None, num_decimals_str="not an int")

0 comments on commit e93df9e

Please sign in to comment.