Commit

FEA Add Chinese README
pku-wuwei committed Aug 3, 2019
1 parent b7df77c commit 00552d4
Showing 5 changed files with 224 additions and 83 deletions.
97 changes: 67 additions & 30 deletions example/bert_model.py
@@ -1,51 +1,58 @@
 # coding=utf-8
 # Created by Meteorix at 2019/7/30
+import logging
+import multiprocessing
+import time
+from typing import List
+
 import torch
 from pytorch_transformers import *
-from service_streamer import ManagedModel
+
+from service_streamer import ManagedModel, Streamer, ThreadedStreamer
+
+logging.basicConfig(level=logging.ERROR)
+
+multiprocessing.set_start_method("spawn", force=True)
 
 SEED = 0
 torch.manual_seed(SEED)
 torch.cuda.manual_seed(SEED)
 
-class Model(object):
-    def __init__(self):
-        self.model_path = "bert-base-uncased"
+
+class TextInfillingModel(object):
+    def __init__(self, max_sent_len=64):
+        # self.model_path = "bert-base-uncased"
+        self.model_path = "/data/nfsdata/nlp/BERT_BASE_DIR/uncased_L-24_H-1024_A-16"
         self.tokenizer = BertTokenizer.from_pretrained(self.model_path)
         self.bert = BertForMaskedLM.from_pretrained(self.model_path)
         self.bert.eval()
         self.bert.to("cuda")
+        self.max_sent_len = max_sent_len
 
-    def predict(self, batch):
-        """predict next word"""
+    def predict(self, batch: List[str]) -> List[str]:
+        """predict masked word"""
         batch_inputs = []
         masked_indexes = []
 
-        # add token cls & mask
         for text in batch:
             tokenized_text = self.tokenizer.tokenize(text)
-            tokenized_text.insert(0, "[CLS]")
-            tokenized_text.append("[MASK]")
-            length = len(tokenized_text)
-            masked_indexes.append(length-1)
+            if len(tokenized_text) > self.max_sent_len - 2:
+                tokenized_text = tokenized_text[: self.max_sent_len - 2]
+            tokenized_text = ['[CLS]'] + tokenized_text + ['[SEP]']
+            tokenized_text += ['[PAD]'] * (self.max_sent_len - len(tokenized_text))
             indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokenized_text)
+            # print(tokenized_text, indexed_tokens)
             batch_inputs.append(indexed_tokens)
-
-        # padding to same length
-        max_len = max([len(tmp) for tmp in batch_inputs])
-        pad_inputs = []
-        for tmp_sent in batch_inputs:
-            tmp_sent.extend([0] * (max_len - len(tmp_sent)))
-            pad_inputs.append(tmp_sent)
-
-        tokens_tensor = torch.tensor(pad_inputs).to("cuda")
+            masked_indexes.append(tokenized_text.index('[MASK]'))
+        tokens_tensor = torch.tensor(batch_inputs).to("cuda")
 
         with torch.no_grad():
-            predictions = self.bert(tokens_tensor)[0]
+            # prediction_scores: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
+            prediction_scores = self.bert(tokens_tensor)[0]
 
         batch_outputs = []
         for i in range(len(batch_inputs)):
-            predicted_index = torch.argmax(predictions[i, masked_indexes[i]]).item()
-            predicted_token = self.tokenizer.convert_ids_to_tokens([predicted_index])[0]
+            predicted_index = torch.argmax(prediction_scores[i, masked_indexes[i]]).item()
+            predicted_token = self.tokenizer.convert_ids_to_tokens(predicted_index)
             batch_outputs.append(predicted_token)
 
         return batch_outputs
@@ -54,15 +61,45 @@ def predict(self, batch):
 class ManagedBertModel(ManagedModel):
 
     def init_model(self):
-        self.model = Model()
+        self.model = TextInfillingModel(max_sent_len=64)
 
     def predict(self, batch):
         return self.model.predict(batch)
 
 
+def main():
+    batch = ["twinkle twinkle [MASK] star",
+             "Happy birthday to [MASK]",
+             'the answer to life, the [MASK], and everything']
+    m = TextInfillingModel()
+    start_time = time.time()
+    outputs = m.predict(batch)
+    print('original model', time.time() - start_time, outputs)
+
+    threaded_streamer = ThreadedStreamer(m.predict, 64, 0.1)
+    start_time = time.time()
+    outputs = threaded_streamer.predict(batch)
+    print('threaded model', time.time() - start_time, outputs)
+
+    streamer = Streamer(m.predict, 64, 0.1, worker_num=4)
+    start_time = time.time()
+    outputs = streamer.predict(batch)
+    print('single-gpu multiprocessing', time.time() - start_time, outputs)
+
+    managed_streamer = Streamer(ManagedBertModel, 64, 0.1, worker_num=4, cuda_devices=[0, 3])
+    start_time = time.time()
+    outputs = managed_streamer.predict(batch)
+    print('multi-gpu multiprocessing', time.time() - start_time, outputs)
+
+    start_time = time.time()
+    xs = []
+    for i in range(1):
+        future = threaded_streamer.submit(batch)
+        xs.append(future)
+    for future in xs:
+        outputs = future.result()
+    print('Future API', time.time() - start_time, outputs)
+
+
 if __name__ == "__main__":
-    import logging
-    logging.basicConfig(level=logging.INFO)
-    m = Model()
-    outputs = m.predict(["Today is your lucky", "Happy birthday to"])
-    print(outputs)
+    main()
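
Note (not part of the commit): the reworked predict() pads every sentence to a fixed max_sent_len and reads the mask position from the input text itself, instead of appending [MASK] and padding to the longest sentence in the batch, so every request the streamer groups together yields tensors of the same shape. A minimal sketch of the per-sentence preprocessing, assuming the pytorch_transformers BertTokenizer keeps "[MASK]" as a single token:

    # illustrative sketch only; mirrors the loop body of TextInfillingModel.predict above
    tokens = tokenizer.tokenize("Happy birthday to [MASK]")       # input already carries the mask
    tokens = ["[CLS]"] + tokens[:max_sent_len - 2] + ["[SEP]"]    # truncate and add special tokens
    tokens += ["[PAD]"] * (max_sent_len - len(tokens))            # pad to the fixed length
    masked_index = tokens.index("[MASK]")                         # position whose logits are read out
    token_ids = tokenizer.convert_tokens_to_ids(tokens)
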
5 changes: 3 additions & 2 deletions example/flask_example.py
@@ -2,10 +2,11 @@
 # Created by Meteorix at 2019/7/30
 
 import multiprocessing as mp
-from flask import Flask, request, jsonify
-from service_streamer import ThreadedStreamer, Streamer, RedisStreamer
 
 from bert_model import Model
+from flask import Flask, request, jsonify
+
+from service_streamer import ThreadedStreamer
 
 app = Flask(__name__)
 model = None
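
Note: only the import section of flask_example.py is shown above. Based on the naive_predict route visible in flask_multigpu_example.py below and the stream_predict hunk headers, a handler wired to the streamer typically looks like the following sketch (route path and handler body are assumptions, not the literal file contents):

    @app.route("/stream", methods=["POST"])
    def stream_predict():
        inputs = request.form.getlist("s")   # repeated form field "s" carries the batch
        outputs = streamer.predict(inputs)   # the streamer batches concurrent requests
        return jsonify(outputs)
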
10 changes: 6 additions & 4 deletions example/flask_multigpu_example.py
@@ -1,17 +1,17 @@
 # coding=utf-8
 # Created by Meteorix at 2019/7/30
-from gevent import monkey; monkey.patch_all()
 from flask import Flask, request, jsonify
-from service_streamer import Streamer, ThreadedStreamer
-from bert_model import ManagedBertModel
+from gevent import monkey
 
+from service_streamer import Streamer
+from .bert_model import ManagedBertModel
 
+monkey.patch_all()
 app = Flask(__name__)
 model = None
 streamer = None
 
 
-
 @app.route("/naive", methods=["POST"])
 def naive_predict():
     inputs = request.form.getlist("s")
@@ -28,6 +28,7 @@ def stream_predict():

 if __name__ == "__main__":
     from multiprocessing import freeze_support
+
     freeze_support()
     streamer = Streamer(ManagedBertModel, batch_size=64, max_latency=0.1, worker_num=4, cuda_devices=(0, 1, 2, 3))
 
@@ -37,4 +38,5 @@ def stream_predict():
     # streamer = ThreadedStreamer(model.predict, batch_size=64, max_latency=0.1)
 
     from gevent.pywsgi import WSGIServer
+
     WSGIServer(("0.0.0.0", 5005), app).serve_forever()
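
Note: once the server above is listening on 0.0.0.0:5005, it can be exercised with a small client. A hypothetical example using the requests library (it assumes the stream_predict handler is routed at /stream; the repeated "s" fields match request.form.getlist("s") in the handlers):

    import requests

    # hypothetical client call; assumes the multi-GPU server above is running locally
    resp = requests.post(
        "http://localhost:5005/stream",
        data={"s": ["Happy birthday to [MASK]", "twinkle twinkle [MASK] star"]},
    )
    print(resp.json())
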
28 changes: 13 additions & 15 deletions example/future_example.py
@@ -5,19 +5,19 @@
 import multiprocessing as mp
 from tqdm import tqdm
 from service_streamer import ThreadedStreamer, Streamer, RedisStreamer
-from bert_model import Model, ManagedBertModel
+from example.bert_model import TextInfillingModel, ManagedBertModel
 
 
 def main():
-    max_batch = 64
-    model = Model()
-    streamer = ThreadedStreamer(model.predict, batch_size=max_batch, max_latency=0.1)
-    # streamer = Streamer(ManagedBertModel, batch_size=max_batch, max_latency=0.1, worker_num=4, cuda_devices=(0, 1, 2, 3))
+    batch_size = 64
+    model = TextInfillingModel()
+    # streamer = ThreadedStreamer(model.predict, batch_size=max_batch, max_latency=0.1)
+    streamer = Streamer(ManagedBertModel, batch_size=batch_size, max_latency=0.1, worker_num=4, cuda_devices=(0, 1, 2, 3))
     # streamer = RedisStreamer()
 
-    text = "Happy birthday to"
-    num_times = 8000
-
+    text = "Happy birthday to [MASK]"
+    num_epochs = 100
+    total_steps = batch_size * num_epochs
 
     """
     t_start = time.time()
@@ -28,23 +28,21 @@ def main():
"""

t_start = time.time()
inputs = [text] * num_times
for i in tqdm(range(num_times // max_batch + 1)):
output = model.predict(inputs[i*max_batch:(i+1)*max_batch])
print(len(output))
for i in tqdm(range(num_epochs)):
output = model.predict([text] * batch_size)
t_end = time.time()
print('[batched]sentences per second', num_times / (t_end - t_start))
print('[batched]sentences per second', total_steps / (t_end - t_start))

t_start = time.time()
xs = []
for i in range(num_times):
for i in range(total_steps):
future = streamer.submit([text])
xs.append(future)

for future in tqdm(xs): # 先拿到所有future对象,再等待异步返回
output = future.result(timeout=20)
t_end = time.time()
print('[streamed]sentences per second', num_times / (t_end - t_start))
print('[streamed]sentences per second', total_steps / (t_end - t_start))


if __name__ == '__main__':
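
Note: both benchmark paths above push the same workload through the model: with batch_size = 64 and num_epochs = 100, total_steps = 6400 sentences, and each throughput line divides that count by wall-clock time. The Future API path reduces to the following pattern (a sketch reusing the streamer, text and total_steps defined above): queue all requests first, then block on the results, so the streamer can batch them behind the scenes.

    futures = [streamer.submit([text]) for _ in range(total_steps)]   # enqueue everything up front
    outputs = [future.result(timeout=20) for future in futures]       # then collect the results
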