From 8a4f9304a6ccb912f4f852a0d56f8ec28b4c1940 Mon Sep 17 00:00:00 2001 From: Pieter Delobelle Date: Sat, 17 Sep 2022 16:18:14 +0200 Subject: [PATCH] Updated demo --- .gitignore | 1 + examples/die_vs_data_rest_api/app.py | 6 +- examples/die_vs_data_rest_api/app/__init__.py | 126 ++++++++++++------ .../app/__pycache__/__init__.cpython-37.pyc | Bin 2786 -> 3969 bytes 4 files changed, 89 insertions(+), 44 deletions(-) diff --git a/.gitignore b/.gitignore index 3d9f61c..67de45e 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ models/ src/__pycache__/ venv/ +.env/ diff --git a/examples/die_vs_data_rest_api/app.py b/examples/die_vs_data_rest_api/app.py index ec36d7a..87c1801 100644 --- a/examples/die_vs_data_rest_api/app.py +++ b/examples/die_vs_data_rest_api/app.py @@ -8,8 +8,8 @@ def create_parser(): description="Create a REST endpoint for for 'die' vs 'dat' disambiguation." ) - parser.add_argument("--model-path", help="Path to the finetuned RobBERT folder.", required=True) - + parser.add_argument("--model-path", help="Path to the finetuned RobBERT identifier.", required=False) + parser.add_argument("--fast-model-path", help="Path to the mlm RobBERT identifier.", required=False) return parser @@ -18,4 +18,4 @@ def create_parser(): args = arg_parser.parse_args() create_parser() - create_app(args.model_path).run() \ No newline at end of file + create_app(args.model_path, args.fast_model_path).run() \ No newline at end of file diff --git a/examples/die_vs_data_rest_api/app/__init__.py b/examples/die_vs_data_rest_api/app/__init__.py index acb637d..a50a498 100644 --- a/examples/die_vs_data_rest_api/app/__init__.py +++ b/examples/die_vs_data_rest_api/app/__init__.py @@ -1,6 +1,6 @@ from flask import Flask, request import os -from transformers import RobertaForSequenceClassification, RobertaTokenizer +from transformers import RobertaForSequenceClassification, RobertaForMaskedLM, RobertaTokenizer import torch import nltk from nltk.tokenize.treebank import TreebankWordDetokenizer @@ -39,7 +39,7 @@ def replace_query_token(sentence): raise ValueError("'die' or 'dat' should be surrounded by underscores.") -def create_app(model_path: str, device="cpu"): +def create_app(model_path: str, fast_model_path:str, device="cpu"): """ Create the flask app. @@ -50,53 +50,97 @@ def create_app(model_path: str, device="cpu"): app = Flask(__name__, instance_relative_config=True) print("initializing tokenizer and RobBERT.") - tokenizer = RobertaTokenizer.from_pretrained(model_path) + if model_path: + tokenizer: RobertaTokenizer = RobertaTokenizer.from_pretrained(model_path, use_auth_token=True) + robbert = RobertaForSequenceClassification.from_pretrained(model_path, use_auth_token=True) + robbert.eval() + print("Loaded finetuned model") - robbert = RobertaForSequenceClassification.from_pretrained(model_path) + if fast_model_path: + fast_tokenizer: RobertaTokenizer = RobertaTokenizer.from_pretrained(fast_model_path, use_auth_token=True) + fast_robbert = RobertaForMaskedLM.from_pretrained(fast_model_path, use_auth_token=True) + fast_robbert.eval() - print(robbert) + print("Loaded MLM model") + + possible_tokens = ['die', 'dat', 'Die', 'Dat'] + + ids = fast_tokenizer.convert_tokens_to_ids(possible_tokens) mask_padding_with_zero = True block_size = 512 # Disable dropout - robbert.eval() nltk.download('punkt') - - @app.route('/', methods=["POST"]) - def hello_world(): - sentence = request.form['sentence'] - query = replace_query_token(sentence) - - tokenized_text = tokenizer.encode(tokenizer.tokenize(query)[- block_size + 3: -1]) - - input_mask = [1 if mask_padding_with_zero else 0] * len(tokenized_text) - - pad_token = tokenizer.convert_tokens_to_ids(tokenizer.pad_token) - while len(tokenized_text) < block_size: - tokenized_text.append(pad_token) - input_mask.append(0 if mask_padding_with_zero else 1) - # segment_ids.append(pad_token_segment_id) - # p_mask.append(1) - - # self.examples.append([tokenizer.build_inputs_with_special_tokens(tokenized_text[0 : block_size]), [0], [0]]) - batch = tuple(torch.tensor(t).to(torch.device(device)) for t in - [tokenized_text[0: block_size - 3], input_mask[0: block_size - 3], [0], [1][0]]) - inputs = {"input_ids": batch[0].unsqueeze(0), "attention_mask": batch[1].unsqueeze(0), - "labels": batch[3].unsqueeze(0)} - with torch.no_grad(): - outputs = robbert(**inputs) - - rating = outputs[1].argmax().item() - confidence = outputs[1][0, rating].item() - - response = {"rating": rating, "interpretation": "incorrect" if rating == 1 else "correct", - "confidence": confidence, "sentence": sentence} - - # This would be a good place for logging/storing queries + results - print(response) - - return json.dumps(response) + + if fast_model_path: + @app.route('/fast', methods=["POST"]) + def fast(): + sentence = request.form['sentence'] + for i, x in enumerate(possible_tokens): + if f"_{x}_" in sentence: + masked_id = i + query = sentence.replace(f"_{x}_" , fast_tokenizer.mask_token) + + inputs = fast_tokenizer.encode_plus(query, return_tensors="pt") + + masked_position = torch.where(inputs['input_ids'] == fast_tokenizer.mask_token_id)[1] + if len(masked_position) > 1: + return "No two queries allowed in one sentence.", 400 + + # self.examples.append([tokenizer.build_inputs_with_special_tokens(tokenized_text[0 : block_size]), [0], [0]]) + with torch.no_grad(): + outputs = fast_robbert(**inputs) + + print(outputs.logits[0,masked_position,ids] ) + token = outputs.logits[0,masked_position,ids].argmax() + + confidence = float(outputs.logits[0,masked_position,ids].max()) + + response = {"rating": possible_tokens[token], "interpretation": "correct" if token == masked_id else "incorrect", + "confidence": confidence, "sentence": sentence} + + # This would be a good place for logging/storing queries + results + print(response) + + return json.dumps(response) + + + if model_path: + @app.route('/', methods=["POST"]) + def main(): + sentence = request.form['sentence'] + query = replace_query_token(sentence) + + tokenized_text = tokenizer.encode(tokenizer.tokenize(query)[- block_size + 3: -1]) + + input_mask = [1 if mask_padding_with_zero else 0] * len(tokenized_text) + + pad_token = tokenizer.convert_tokens_to_ids(tokenizer.pad_token) + while len(tokenized_text) < block_size: + tokenized_text.append(pad_token) + input_mask.append(0 if mask_padding_with_zero else 1) + # segment_ids.append(pad_token_segment_id) + # p_mask.append(1) + + # self.examples.append([tokenizer.build_inputs_with_special_tokens(tokenized_text[0 : block_size]), [0], [0]]) + batch = tuple(torch.tensor(t).to(torch.device(device)) for t in + [tokenized_text[0: block_size - 3], input_mask[0: block_size - 3], [0], [1][0]]) + inputs = {"input_ids": batch[0].unsqueeze(0), "attention_mask": batch[1].unsqueeze(0), + "labels": batch[3].unsqueeze(0)} + with torch.no_grad(): + outputs = robbert(**inputs) + + rating = outputs[1].argmax().item() + confidence = outputs[1][0, rating].item() + + response = {"rating": rating, "interpretation": "incorrect" if rating == 1 else "correct", + "confidence": confidence, "sentence": sentence} + + # This would be a good place for logging/storing queries + results + print(response) + + return json.dumps(response) return app diff --git a/examples/die_vs_data_rest_api/app/__pycache__/__init__.cpython-37.pyc b/examples/die_vs_data_rest_api/app/__pycache__/__init__.cpython-37.pyc index 1e5031a1e5ed0bdb5330e3168f63103f8fb8aacd..6305fc2a565717f436529ef3df8c701cab76b296 100644 GIT binary patch literal 3969 zcmai1TaO$^74GWW^jvm!?QERgfTAWBHbateZ~z5ktSEMZK(^y}O&qldo!+UM-QMl# z9#_|R*OML;VWS-*0(mD$GZHWG1NZ?*{DerTUXhT@13w@y@SX12T^oX!QJ=1=bL!OX zd|#b@b7`qz;P?6GFZA9%V;Fy<%F)k8;SD@_3c(G|5@WyuHZTK|QQb_ez{1-~?12+F z12=F7wV<}pd4X5j)Pp+ScG4I$gXTil3R>LZ?tLRz;!V*OZSH+$@;YzacYRG&VvxQgC_ zU%9GR7Hv1PJ(0#^Av;VzbyEtl6{UN(GRZFsb)@Ps8XdnH9tY1QJo#TClrdw4F=rEY zTrv;-q|AcdW0#E&{`eDfVomJA=4{WDZ*a61)&*=)VbVLj0*wt#fR;*QE;Q$$aiCG$ zU1(|x%^EZwG{54tg~qKkg*~(8IF>y#|NWv&zMoO{J!9q+46}3=WfNtW*Jp0w6z&`w zeEXSOhN|(#J@#|Po3!EM-sT`m5`U!PBvyC*{wVFyX8J1gr5GkrPxz6ai&W8u`}?ul z@jncCEW!`{Oj0pYD3 z?9{Y(v&4@x6t1ueIXW{N_7{+O=!mvBMtDjgLM2M)DcdNA!e)AV0470%+>@@T6ZmPabPj8WUf)$?;zrK7N;^K)Co6 z-F;#uALr9;l&bT?NJazyWb@AZ*RiJrp3*!{g&L&@&(fm%Xdf=_#61D~yDF2t9lxr- zjCVgm@c5Um|JZ*q2dRh$EAf?~2$d6gJRrsp4@W~sW_VZmCQ z<|+b$5oHP3@{S05S=x`c$Inrm#ZeNE<8&J%Eojn@QhtQuE*!K+xd@|?+Nr{CeCA3P z5$^hrFLAuGXuo>pYPD$TA#!A$M3V#sZfuyqoHHDydur^Qr2y3S^>;QgQ(YSfwUhC@ z2fOk|y{P*Wc=9$VS{b_tmvv)e99&f96k!MoJ_ibB6BE>H{uX!C@2tYubLDEmCMeCU zllCdwwddx&NO?Or+9PX@BJE(WgxQkI-RWA^V+}*R~$Go;_z^*Yj zep4`GYV?i5`8Au=RI{k&2m*T@%5{Fru3v5ZNq*Iw~QL>`jW$hSM+QKR5Wq8ZutwdC-&dWcHuy75f z{CNoC_S0;I+3abvT{5|Cf$hyzwrZ|HHZ0E^uk}iP!p$-4UP!WDl;jt?H0&4b?3aV+ z{l`FrEQUgmO_5jN0T<8N)Mx@kPa%W(g}F|74|rL=&TWLHv(1!KnD^MKF(G?*YVX!y zVIRC)wNH&bL;kEVXI^2K%uK9p$~)Xa@3Q6gO=KbUL)HhVc_PTLghKH(*noU%4Z=kL5f@X0tte~qv)3#)KI1&ZiRF09V6o<==77IjXQwN0vC7@Yw7whT4ZoQ zHd3t=FjXVuPeFZ@<^Zgqyq=uDOrlG!>{y9`e1+=YBykSn2=?DXd6B*0ivX_0feIQG zOVYaJk6@sktq6Co{5JLe4vBJQ<)U6kwd3f<7T~rQ=JZ)Xd#=w=Z)8=RgA_x`i9@I| zB!mk(s2{KQ@(LOVSx@ zFAURYAi}WP21@6G<-W`YArZBNrQnjj_OvbTM2Vy?9{EEOEi(0Zwx5DOdEg+)R-)s| zr%Cx066JC&9ESWpRX0gc#C7WOog=;cE|t!c_#TPtBrcG6m&6ScCFwOzLh&1D%;^*j zuZ3jR#ul_GwKW$dB(yEFftqcTeAc(w^Uuj!ut(gl(`$uTpAM1spt^O;EDu^(aGKNg zgKva#)zx}wSE|@47k!4l8M;TO)K!ZQ#quMe3ZyEH&Xp`jt&Xw8+A-j?cLC|%C=nOw U=FC?hRf7ytkO delta 1611 zcmah}TW=dh6rP#AdVNVs(_Gr55m2ePGG+=U$Kwy6oL2w;Q`L9la@Eeo^zWyXV3Z0*?lzs+l8yE z#iD~?ef0eMuU{1q`duk!O9SCk*ir4p)$%=z7-Gb?sFiV>+Q4-`*K(+HdX}eoW-#*! zq6OwKo8_KhT4ZDsmGj3cXqlwOE#HfpsU`S>4v*w|nY{S*{a1fH#h0nZ8XSlxso?^3 zfvqW~0j0^$WFMPJNvmR-tY{Z4s(%{_MiW_{d{9>tC~F0R&g$Qtu6J&!9CUL?M@x5EMck2%~Ku;5G0bA zVMd;jE|!F8k1QOlJv7PhK~p{k5FG5m_r12t%1#~A$Y}jFQ=f{ z!OVF^4$&c=IV6YLp}uXfJS#l0ALA~GF<@3)PyRAidk0f(#{ICyeVgZqJ<@IJeYy@)Q=Hg-l5rnAG?gUMFf~!=&b9-~Ej8m({OI5wC* zg9c|ZIg^?FQtO(sFenbuj})@ewoz>UfV-9~#a8AsHbjH*p~~5o=SDVW#P+s|F|z06 zJA?DF9otNAlIK`#upH#!#JL!;{FRrG&?dTzHqbqod+0`CUBLT+j@4)!{{!n7Ct?1J zUTVlt)OV>Vc@Twy5*el@(CSWJVPP;)BAfSlx(cX+_}+hS!M)Cl^8 zbaZf|F1RPT>$TgJWk0NYeza2Ag~o?247<i8n{!_6FV($pBh1@)qa>7+kWWDRNsrj zAT=1+qOz0WNe%c8h=;Uv7MHn_?@Q{8YEngr_f&kN=1B!}W{;>9b|e%XiCF}X7jr6k zBXw$kTGNdnKQWzb6&HJNAW!rE0dg6xM=C5LCG21uewWE4DPW~d8Y}o3aEDCbX?!sM zKTUx`RQpofbpx-(T~}NNRTx0%Sn_xA61I|eOBZ@tHUp^!%@qskhDxn$h_aHi@wld> z4;8tth#HjCdHeJaK#$Z#gbT!m^mJ@$Hc%a?L#|vBOK=9Wc22Pp1<)0GDmiFp@@HxO zmg