Skip to content

Commit

Permalink
use bert to score tweets in elasticsearch
Browse files Browse the repository at this point in the history
  • Loading branch information
shwehtom89 committed Dec 7, 2020
1 parent 7a6f038 commit 546b32b
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 4 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,6 @@ twitter_monitor/tmlog.txt
embedder/embedderlog.txt
twitter_monitor/jdllog.txt
sentiment/sentimentlog.txt
sentiment/semeval*
sentiment/sentiment.pt
analysis/snapshots/*.Rdata
60 changes: 60 additions & 0 deletions sentiment/bert_eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import torch
import requests
from transformers import AutoTokenizer
from torch.nn.functional import softmax
from typing import List

# global constants
MODEL_NAME = 'digitalepidemiologylab/covid-twitter-bert-v2'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
mapping = { 0: 'negative', 1: 'neutral', 2: 'positive' }

class BertSentiment():
"""
Initializes a bert model used for evaluation
@param path: local relative path of the bert model
@param remote: defaults to empty, if specified will download model from url
"""
def __init__(self, path: str, remote: str=""):
if len(remote) != 0:
self.download(remote)
self.tokenizer = tokenizer
self.load(path)

"""
Downloads bert model from remote
@param remote: url location of bert model
@param dest: destination path where model will be downloaded to
"""
def download(self, remote: str, dest: str) -> str:
try:
res = requests.get(remote, allow_redirects=True)
with open(dest, "wb") as f:
f.write(res.content)
return dest
except:
print("Could not download model")
return None

"""
Loads pytorch model in for inference
@patam path: local path to the bert model
"""
def load(self, path:str):
self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
self.model = torch.load(path)
self.model.to(self.device)
self.model.eval()

"""
Takes in a tweet and calculates a sentiment prediction confidences
"""
def score(self, text):
encoding = self.tokenizer(text, return_tensors="pt", padding=True)
inputs = encoding["input_ids"].to(self.device)
logits = self.model(inputs, labels=None)[0]
temp = torch.flatten(logits.cpu())
preds = softmax(temp, dim=0)
sentiment = mapping[torch.argmax(preds).item()]
return preds.tolist(), sentiment

21 changes: 18 additions & 3 deletions sentiment/sentiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import sentiment_helpers
import time
import logging
from bert_eval import BertSentiment
from vaderSentiment.vaderSentiment import SentimentIntensityAnalyzer
from elasticsearch import Elasticsearch
from elasticsearch.helpers import bulk
Expand Down Expand Up @@ -31,7 +32,10 @@
#Load vader sentiment intensity analyzer
vader = SentimentIntensityAnalyzer()

bert = BertSentiment(config.model_path)

#Initialize elasticsearch settings
print(config.elasticsearch_verify_certs)
es = Elasticsearch(hosts=[config.elasticsearch_host],
verify_certs=config.elasticsearch_verify_certs,
timeout=config.elasticsearch_timeout_secs)
Expand All @@ -56,27 +60,38 @@
continue

#Run sentiment analysis on the batch
logging.info("Found {0} unscored docs. Calculating sentiment scores with Vader...".format(len(hits)))
logging.info("Found {0} unscored docs. Calculating sentiment scores with Vader and Bert...".format(len(hits)))
updates = []
for hit in hits:
text, quoted_text = sentiment_helpers.get_tweet_text(hit)
text = sentiment_helpers.clean_text_for_vader(text)
scores, result = bert.score(text)
action = {
"_op_type": "update",
"_id": hit.meta["id"],
"doc": {
"sentiment": {
"vader": {
"primary": vader.polarity_scores(text)["compound"]
}
},
"bert" : {
"scores": scores,
"class": result
}
}
}
}
if quoted_text is not None:
quoted_text = sentiment_helpers.clean_text_for_vader(quoted_text)
quoted_concat_text = "{0} {1}".format(quoted_text, text)
quoted_scores, quoted_class = bert.score(quoted_text)
quoted_concat_scores, quoted_concat_class = bert.score(quoted_concat_text)
action["doc"]["sentiment"]["vader"]["quoted"] = vader.polarity_scores(quoted_text)["compound"]
action["doc"]["sentiment"]["vader"]["quoted_concat"] = vader.polarity_scores(quoted_concat_text)["compound"]
action["doc"]["sentiment"]["bert"]["quoted_scores"] = quoted_scores
action["doc"]["sentiment"]["bert"]["quoted_class"] = quoted_class
action["doc"]["sentiment"]["bert"]["quoted_concat_scores"] = quoted_concat_scores
action["doc"]["sentiment"]["bert"]["quoted_concat_class"] = quoted_concat_class

updates.append(action)

Expand All @@ -89,4 +104,4 @@
time.sleep(config.sleep_not_idle_secs)

except Exception as ex:
logging.exception("Exception occurred while polling or processing a batch.")
logging.exception("Exception occurred while polling or processing a batch.")
43 changes: 42 additions & 1 deletion sentiment/sentiment_helpers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,47 @@
import re

def get_query():
# query = {
# "_source": [
# "text",
# "full_text",
# "extended_tweet.full_text",
# "quoted_status.text",
# "quoted_status.full_text",
# "quoted_status.extended_tweet.full_text"
# ],
# "query": {
# "bool": {
# "filter": [
# {
# "bool": {
# "must_not": [
# {
# "exists": {
# "field": "sentiment.vader.primary"
# }
# },
# {
# "exists": {
# "field": "sentiment.bert.scores"
# }
# }
# ]
# }
# },
# {
# "bool": {
# "must_not": {
# "exists": {
# "field": "retweeted_status.id"
# }
# }
# }
# }
# ]
# }
# }
# }
query = {
"_source": [
"text",
Expand Down Expand Up @@ -55,4 +96,4 @@ def clean_text_for_vader(text):
text = re.sub(r"http\S+", "", text)
text = re.sub(r" +", " ", text)
text = text.strip()
return text
return text

0 comments on commit 546b32b

Please sign in to comment.