Skip to content

Commit

Permalink
Updated demo
Browse files Browse the repository at this point in the history
  • Loading branch information
iPieter committed Sep 17, 2022
1 parent 428990f commit 8a4f930
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 44 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ models/
src/__pycache__/

venv/
.env/
6 changes: 3 additions & 3 deletions examples/die_vs_data_rest_api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ def create_parser():
description="Create a REST endpoint for for 'die' vs 'dat' disambiguation."
)

parser.add_argument("--model-path", help="Path to the finetuned RobBERT folder.", required=True)

parser.add_argument("--model-path", help="Path to the finetuned RobBERT identifier.", required=False)
parser.add_argument("--fast-model-path", help="Path to the mlm RobBERT identifier.", required=False)

return parser

Expand All @@ -18,4 +18,4 @@ def create_parser():
args = arg_parser.parse_args()

create_parser()
create_app(args.model_path).run()
create_app(args.model_path, args.fast_model_path).run()
126 changes: 85 additions & 41 deletions examples/die_vs_data_rest_api/app/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from flask import Flask, request
import os
from transformers import RobertaForSequenceClassification, RobertaTokenizer
from transformers import RobertaForSequenceClassification, RobertaForMaskedLM, RobertaTokenizer
import torch
import nltk
from nltk.tokenize.treebank import TreebankWordDetokenizer
Expand Down Expand Up @@ -39,7 +39,7 @@ def replace_query_token(sentence):
raise ValueError("'die' or 'dat' should be surrounded by underscores.")


def create_app(model_path: str, fast_model_path: str, device="cpu"):
    """
    Create the Flask app serving 'die' vs 'dat' disambiguation.

    Two optional models back two routes:
      * ``/``     — sequence-classification head (``model_path``), rates whether the
                    marked word is used correctly.
      * ``/fast`` — masked-LM head (``fast_model_path``), predicts the most likely
                    token at the marked position directly.

    :param model_path: path or hub identifier of the finetuned RobBERT
        sequence-classification model; falsy to skip registering ``/``.
    :param fast_model_path: path or hub identifier of the MLM RobBERT model;
        falsy to skip registering ``/fast``.
    :param device: torch device string for the classification route, default "cpu".
    :return: the configured :class:`flask.Flask` application.
    """
    app = Flask(__name__, instance_relative_config=True)

    print("initializing tokenizer and RobBERT.")
    if model_path:
        # NOTE(review): use_auth_token is deprecated in newer transformers
        # releases in favour of `token=` — confirm the pinned version.
        tokenizer: RobertaTokenizer = RobertaTokenizer.from_pretrained(model_path, use_auth_token=True)
        robbert = RobertaForSequenceClassification.from_pretrained(model_path, use_auth_token=True)
        robbert.eval()  # disable dropout for inference
        print("Loaded finetuned model")

    if fast_model_path:
        fast_tokenizer: RobertaTokenizer = RobertaTokenizer.from_pretrained(fast_model_path, use_auth_token=True)
        fast_robbert = RobertaForMaskedLM.from_pretrained(fast_model_path, use_auth_token=True)
        fast_robbert.eval()  # disable dropout for inference
        print("Loaded MLM model")

        # Candidate fillers for the masked position; `ids` are their vocab ids,
        # in the same order, so an argmax over the sliced logits indexes back
        # into this list.
        possible_tokens = ['die', 'dat', 'Die', 'Dat']
        ids = fast_tokenizer.convert_tokens_to_ids(possible_tokens)

    # Attention-mask convention: real tokens -> 1, padding -> 0.
    mask_padding_with_zero = True
    block_size = 512

    nltk.download('punkt')

    if fast_model_path:
        @app.route('/fast', methods=["POST"])
        def fast():
            """Disambiguate via the masked-LM head: mask the marked word and compare candidate logits."""
            sentence = request.form['sentence']

            # Find the `_die_`-style placeholder and swap it for the MLM mask
            # token. Initialize explicitly: previously an input without any
            # placeholder raised NameError on `query` below.
            masked_id = None
            query = None
            for i, candidate in enumerate(possible_tokens):
                if f"_{candidate}_" in sentence:
                    masked_id = i
                    query = sentence.replace(f"_{candidate}_", fast_tokenizer.mask_token)
            if query is None:
                return "'die' or 'dat' should be surrounded by underscores.", 400

            inputs = fast_tokenizer.encode_plus(query, return_tensors="pt")

            masked_position = torch.where(inputs['input_ids'] == fast_tokenizer.mask_token_id)[1]
            if len(masked_position) > 1:
                return "No two queries allowed in one sentence.", 400

            with torch.no_grad():
                outputs = fast_robbert(**inputs)

            # Restrict the vocabulary logits at the masked slot to the four
            # candidates and pick the strongest one.
            candidate_logits = outputs.logits[0, masked_position, ids]
            token = candidate_logits.argmax()
            confidence = float(candidate_logits.max())

            response = {"rating": possible_tokens[token], "interpretation": "correct" if token == masked_id else "incorrect",
                        "confidence": confidence, "sentence": sentence}

            # This would be a good place for logging/storing queries + results
            print(response)

            return json.dumps(response)

    if model_path:
        @app.route('/', methods=["POST"])
        def main():
            """Rate the marked word with the finetuned sequence-classification head."""
            sentence = request.form['sentence']
            query = replace_query_token(sentence)

            # Keep at most block_size - 4 subword tokens (room for specials).
            tokenized_text = tokenizer.encode(tokenizer.tokenize(query)[- block_size + 3: -1])

            input_mask = [1 if mask_padding_with_zero else 0] * len(tokenized_text)

            # Right-pad to a fixed block_size.
            pad_token = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
            while len(tokenized_text) < block_size:
                tokenized_text.append(pad_token)
                input_mask.append(0 if mask_padding_with_zero else 1)

            # batch = (input_ids, attention_mask, unused placeholder, dummy label 1)
            batch = tuple(torch.tensor(t).to(torch.device(device)) for t in
                          [tokenized_text[0: block_size - 3], input_mask[0: block_size - 3], [0], [1][0]])
            inputs = {"input_ids": batch[0].unsqueeze(0), "attention_mask": batch[1].unsqueeze(0),
                      "labels": batch[3].unsqueeze(0)}
            with torch.no_grad():
                outputs = robbert(**inputs)

            # outputs[1] holds the classification logits; class 1 == incorrect usage.
            rating = outputs[1].argmax().item()
            confidence = outputs[1][0, rating].item()

            response = {"rating": rating, "interpretation": "incorrect" if rating == 1 else "correct",
                        "confidence": confidence, "sentence": sentence}

            # This would be a good place for logging/storing queries + results
            print(response)

            return json.dumps(response)

    return app
Binary file not shown.

0 comments on commit 8a4f930

Please sign in to comment.