# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import numpy as np
import time
from typing import Optional
from tritonclient import utils as client_utils
from tritonclient.grpc import InferenceServerClient, InferInput, InferRequestedOutput, service_pb2_grpc, service_pb2

LOGGER = logging.getLogger("run_inference_on_triton")


class SyncGRPCTritonRunner:
    DEFAULT_MAX_RESP_WAIT_S = 120

    def __init__(
            self,
            server_url: str,
            model_name: str,
            model_version: str,
            *,
            verbose=False,
            resp_wait_s: Optional[float]=None, ):
        self._server_url = server_url
        self._model_name = model_name
        self._model_version = model_version
        self._verbose = verbose
        self._response_wait_t = self.DEFAULT_MAX_RESP_WAIT_S if resp_wait_s is None else resp_wait_s

        self._client = InferenceServerClient(
            self._server_url, verbose=self._verbose)
        error = self._verify_triton_state(self._client)
        if error:
            raise RuntimeError(
                f"Could not communicate to Triton Server: {error}")
        LOGGER.debug(
            f"Triton server {self._server_url} and model {self._model_name}:{self._model_version} "
            f"are up and ready!")

        model_config = self._client.get_model_config(self._model_name,
                                                     self._model_version)
        model_metadata = self._client.get_model_metadata(self._model_name,
                                                         self._model_version)
        LOGGER.info(f"Model config {model_config}")
        LOGGER.info(f"Model metadata {model_metadata}")

        self._inputs = {tm.name: tm for tm in model_metadata.inputs}
        self._input_names = list(self._inputs)
        self._outputs = {tm.name: tm for tm in model_metadata.outputs}
        self._output_names = list(self._outputs)
        self._outputs_req = [
            InferRequestedOutput(name) for name in self._outputs
        ]

    def Run(self, inputs):
        """
        Args:
            inputs: list, Each value corresponds to an input name of self._input_names
        Returns:
            results: dict, {name : numpy.array}
        """
        infer_inputs = []
        for idx, data in enumerate(inputs):
            data = np.array(
                [[x.encode('utf-8')] for x in data], dtype=np.object_)
            infer_input = InferInput(self._input_names[idx], [len(data), 1],
                                     "BYTES")
            infer_input.set_data_from_numpy(data)
            infer_inputs.append(infer_input)

        results = self._client.infer(
            model_name=self._model_name,
            model_version=self._model_version,
            inputs=infer_inputs,
            outputs=self._outputs_req,
            client_timeout=self._response_wait_t, )
        results = {name: results.as_numpy(name) for name in self._output_names}
        return results

    def _verify_triton_state(self, triton_client):
        if not triton_client.is_server_live():
            return f"Triton server {self._server_url} is not live"
        elif not triton_client.is_server_ready():
            return f"Triton server {self._server_url} is not ready"
        elif not triton_client.is_model_ready(self._model_name,
                                              self._model_version):
            return f"Model {self._model_name}:{self._model_version} is not ready"
        return None
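
# Illustrative usage (a sketch; assumes a Triton server at localhost:8001
# serving the "ernie_seqcls" sequence classification model, as in the
# __main__ block below):
#
#     runner = SyncGRPCTritonRunner("localhost:8001", "ernie_seqcls", "1")
#     result = runner.Run([["first sentence", "second sentence"]])
#     # `result` maps each output name to a numpy array, e.g. {"label": array([...])}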


def test_tnews_dataset(runner):
    from paddlenlp.datasets import load_dataset

    dev_ds = load_dataset('clue', "tnews", splits='dev')

    # Split the dev set into batches of sentences and keep the gold labels.
    batches = []
    labels = []
    idx = 0
    batch_size = 32
    while idx < len(dev_ds):
        data = []
        label = []
        for i in range(batch_size):
            if idx + i >= len(dev_ds):
                break
            data.append(dev_ds[idx + i]["sentence"])
            label.append(dev_ds[idx + i]["label"])
        batches.append(data)
        labels.append(np.array(label))
        idx += batch_size

    # Count correct predictions over the whole dev set and print the accuracy.
    accuracy = 0
    for i, data in enumerate(batches):
        ret = runner.Run([data])
        # print("ret:", ret)
        accuracy += np.sum(labels[i] == ret["label"])
    print("acc:", 1.0 * accuracy / len(dev_ds))


if __name__ == "__main__":
    from paddlenlp.datasets import load_dataset

    dev_ds = load_dataset('clue', "tnews", splits='dev')
    model_name = "ernie_seqcls"
    model_version = "1"
    url = "localhost:8001"
    runner = SyncGRPCTritonRunner(url, model_name, model_version)

    # A couple of sample Chinese news-title batches for a quick smoke test.
    texts = [["你家拆迁,要钱还是要房?答案一目了然", "军嫂探亲拧包入住,部队家属临时来队房标准有了规定,全面落实!"], [
        "区块链投资心得,能做到就不会亏钱",
    ]]
    for text in texts:
        # input format: [input1, input2, ..., inputn], n = len(self._input_names)
        result = runner.Run([text])
        print(result)

    test_tnews_dataset(runner)