[Fix] Fix BailingAPI model (#1707)
* [fix] keep results in sequence when handling multiple samples

* resolve the lint problems

* change the parameter name

* add another error code for retry

* log invalid responses

* format correction

* update

* update

* update

* update

* add two model Python files

* update the default parameter

* use random for delay

* update the Bailing API example

* remove the unnecessary parameter
cuauty authored Nov 26, 2024
1 parent ef695e2 commit bcb707d
Showing 6 changed files with 60 additions and 62 deletions.
14 changes: 10 additions & 4 deletions configs/api_examples/eval_api_bailing.py
@@ -15,13 +15,19 @@
 models = [
     dict(
-        path='Bailing-Lite-0830',
+        path='Bailing-Lite-1116',
         token='xxxxxx',  # set your key here or in environment variable BAILING_API_KEY
         url='https://bailingchat.alipay.com/chat/completions',
         type=BailingAPI,
-        generation_kwargs={},
         query_per_second=1,
-        max_seq_len=4096,
+        max_out_len=11264,
         batch_size=1,
+        generation_kwargs={
+            'temperature': 0.01,
+            'top_p': 1.0,
+            'top_k': -1,
+            'n': 1,
+            'logprobs': 1,
+        },
     ),
 ]
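
Since the model code merges generation_kwargs verbatim into the POST body via request.update(self.generation_kwargs) (see the bailing_api_oc.py diff below), the request this config produces looks roughly like the following sketch; the message content is illustrative, not captured traffic:

import json

# Hypothetical payload assembled from the config above plus one user turn.
request = {
    'model': 'Bailing-Lite-1116',
    'messages': [{'role': 'user', 'content': 'Hello'}],
    'max_tokens': 11264,
    # merged from generation_kwargs:
    'temperature': 0.01,
    'top_p': 1.0,
    'top_k': -1,
    'n': 1,
    'logprobs': 1,
}
print(json.dumps(request, indent=2))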

[model config file; name not shown]

@@ -10,21 +10,19 @@
 models = [
     dict(
-        path='Bailing-Pro-0920',
+        path='Bailing-Lite-1116',
         token='',  # set your key here or in environment variable BAILING_API_KEY
         url='https://bailingchat.alipay.com/chat/completions',
         type=BailingAPI,
         meta_template=api_meta_template,
         query_per_second=1,
-        max_seq_len=4096,
+        max_out_len=11264,
         batch_size=1,
         generation_kwargs={
-            'temperature': 0.4,
+            'temperature': 0.01,
             'top_p': 1.0,
             'top_k': -1,
             'n': 1,
             'logprobs': 1,
-            'use_beam_search': False,
         },
     ),
 ]
[model config file; name not shown]

@@ -10,21 +10,19 @@
 models = [
     dict(
-        path='Bailing-Pro-0920',
+        path='Bailing-Pro-1120',
         token='',  # set your key here or in environment variable BAILING_API_KEY
         url='https://bailingchat.alipay.com/chat/completions',
         type=BailingAPI,
         meta_template=api_meta_template,
         query_per_second=1,
-        max_seq_len=4096,
+        max_out_len=11264,
         batch_size=1,
         generation_kwargs={
-            'temperature': 0.4,
+            'temperature': 0.01,
             'top_p': 1.0,
             'top_k': -1,
             'n': 1,
             'logprobs': 1,
-            'use_beam_search': False,
         },
     ),
 ]
[model config file; name not shown]

@@ -10,21 +10,19 @@
 models = [
     dict(
-        path='Bailing-Lite-0830',
+        path='Bailing-Lite-1116',
         token='',  # set your key here or in environment variable BAILING_API_KEY
         url='https://bailingchat.alipay.com/chat/completions',
         type=BailingAPI,
         meta_template=api_meta_template,
         query_per_second=1,
-        max_seq_len=4096,
+        max_out_len=11264,
         batch_size=1,
         generation_kwargs={
-            'temperature': 0.4,
+            'temperature': 0.01,
             'top_p': 1.0,
             'top_k': -1,
             'n': 1,
             'logprobs': 1,
-            'use_beam_search': False,
         },
     ),
 ]
[model config file; name not shown]

@@ -10,21 +10,19 @@
 models = [
     dict(
-        path='Bailing-Lite-0830',
+        path='Bailing-Pro-1120',
         token='',  # set your key here or in environment variable BAILING_API_KEY
         url='https://bailingchat.alipay.com/chat/completions',
         type=BailingAPI,
         meta_template=api_meta_template,
         query_per_second=1,
-        max_seq_len=4096,
+        max_out_len=11264,
         batch_size=1,
         generation_kwargs={
-            'temperature': 0.4,
+            'temperature': 0.01,
             'top_p': 1.0,
             'top_k': -1,
             'n': 1,
             'logprobs': 1,
-            'use_beam_search': False,
         },
     ),
 ]
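
All four model configs leave token empty and point at the same endpoint, deferring the key to the BAILING_API_KEY environment variable named in the comments. A minimal sketch of that fallback, assuming the wrapper prefers an explicit token and only then consults the environment (the actual lookup lives in bailing_api_oc.py and is outside this diff):

import os


def resolve_token(config_token: str) -> str:
    # Prefer an explicit token from the config; otherwise fall back to the
    # environment variable named in the config comments above.
    token = config_token or os.environ.get('BAILING_API_KEY', '')
    if not token:
        raise ValueError('set token in the config or export BAILING_API_KEY')
    return token


# resolve_token('') raises unless BAILING_API_KEY is exported.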
76 changes: 38 additions & 38 deletions opencompass/models/bailing_api_oc.py
@@ -1,13 +1,14 @@
 import concurrent
 import concurrent.futures
 import os
+import random
 import socket
 import time
 import traceback
 from typing import Dict, List, Optional, Union

 import requests
 from requests.adapters import HTTPAdapter
 from requests.exceptions import ConnectionError
 from urllib3.connection import HTTPConnection

 try:
@@ -21,8 +22,6 @@
 PromptType = Union[PromptList, str]

-BAILING_RETRY_DELAY: int = 30
-

 class HTTPAdapterWithSocketOptions(HTTPAdapter):
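
The adapter body is collapsed in this view; the class name and the socket/HTTPConnection imports match the standard requests recipe for forwarding TCP socket options (such as keepalive) to urllib3. A sketch of that recipe, which may differ in detail from the repository's version:

import socket

import requests
from requests.adapters import HTTPAdapter
from urllib3.connection import HTTPConnection


class HTTPAdapterWithSocketOptions(HTTPAdapter):

    def __init__(self, *args, **kwargs):
        self.socket_options = kwargs.pop('socket_options', None)
        super().__init__(*args, **kwargs)

    def init_poolmanager(self, *args, **kwargs):
        # Forward the options to urllib3's pool so every new connection
        # applies them at the socket level.
        if self.socket_options is not None:
            kwargs['socket_options'] = self.socket_options
        super().init_poolmanager(*args, **kwargs)


# Usage: keep long-lived eval connections alive between slow responses.
adapter = HTTPAdapterWithSocketOptions(
    socket_options=HTTPConnection.default_socket_options +
    [(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)])
session = requests.Session()
session.mount('https://', adapter)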

@@ -104,7 +103,7 @@ def __init__(
     def generate(
         self,
         inputs: Union[List[str], PromptList],
-        max_out_len: int = 4096,
+        max_out_len: int = 11264,
     ) -> List[str]:
         """Generate results given a list of inputs.
@@ -128,24 +127,33 @@ def generate(
                 ): i
                 for i, input in enumerate(inputs)
             }
-            results = []
+            results = [''] * len(inputs)
             for future in concurrent.futures.as_completed(future_to_m):
                 m = future_to_m[future]  # noqa F841
                 resp = future.result()
                 if resp and resp.status_code == 200:
                     try:
                         result = resp.json()
                     except Exception as e:  # noqa F841
-                        results.append('')
+                        self.logger.error(f'Fail to inference; '
+                                          f'model_name={self.path}; '
+                                          f'error={e}, '
+                                          f'request={inputs[m]}')
                     else:
                         if (result.get('choices')
                                 and result['choices'][0].get('message') and
                                 result['choices'][0]['message'].get('content')
                                 is not None):
-                            results.append(
-                                result['choices'][0]['message']['content'])
+                            results[m] = \
+                                result['choices'][0]['message']['content']
+                        else:
+                            self.logger.error(f'Receive invalid result. '
+                                              f'result={result}; '
+                                              f'request={inputs[m]}')
                 else:
-                    results.append('')
+                    self.logger.error(f'Receive invalid response. '
+                                      f'response={resp}; '
+                                      f'request={inputs[m]}')
         self.flush()
         return results
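
The core of this fix: results is preallocated and written by input index rather than appended in completion order, since as_completed yields futures in whatever order the HTTP calls finish. A self-contained sketch of the pattern (toy function, not the Bailing client):

import concurrent.futures
import random
import time


def slow_echo(text: str) -> str:
    # Simulate an API call that finishes after a random delay.
    time.sleep(random.random() * 0.1)
    return text.upper()


inputs = ['a', 'b', 'c', 'd']
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
    future_to_index = {
        executor.submit(slow_echo, text): i
        for i, text in enumerate(inputs)
    }
    results = [''] * len(inputs)  # preallocate so order is fixed up front
    for future in concurrent.futures.as_completed(future_to_index):
        # as_completed yields in *completion* order; the recorded index
        # puts each result back into *input* order.
        results[future_to_index[future]] = future.result()

assert results == ['A', 'B', 'C', 'D']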

@@ -184,39 +192,31 @@ def _generate(
             message['role'] = item['role']
             messages.append(message)
         request = {
-            'model':
-            self._model,
-            'messages':
-            messages,
-            'max_seq_len':
-            max(
-                max_out_len if max_out_len else 4096,
-                self.max_seq_len if self.max_seq_len else 4096,
-            ),
+            'model': self._model,
+            'messages': messages,
+            'max_tokens': max_out_len,
         }
         request.update(self.generation_kwargs)
-        try:
-            retry_num = 0
-            while retry_num < self.retry:
+        retry_num = 0
+        while retry_num < self.retry:
+            try:
                 response = self._infer_result(request, sess)
-                if response.status_code == 200:
-                    break  # success
-                elif response.status_code == 426:
-                    retry_num += 1  # retry
-                elif response.status_code in [429, 500, 504]:
-                    time.sleep(BAILING_RETRY_DELAY)
-                    retry_num += 1  # retry
-                else:
-                    raise ValueError(f'Status code = {response.status_code}')
-            else:
-                raise ValueError(
-                    f'Exceed the maximal retry times. Last status code '
-                    f'= {response.status_code}')
-        except Exception as e:
-            self.logger.error(f'Fail to inference request={request}; '
-                              f'model_name={self.path}; error={e}, '
-                              f'stack:{traceback.format_exc()}')
-            raise e
+            except ConnectionError:
+                time.sleep(random.randint(10, 30))
+                retry_num += 1  # retry
+                continue
+            if response.status_code == 200:
+                break  # success
+            elif response.status_code == 426:
+                retry_num += 1  # retry
+            elif response.status_code in [302, 429, 500, 504]:
+                time.sleep(random.randint(10, 30))
+                retry_num += 1  # retry
+            else:
+                raise ValueError(f'Status code = {response.status_code}')
+        else:
+            # Exceed the maximal retry times.
+            return ''
         return response
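
The rewritten loop leans on Python's while-else: the else branch runs only when the condition goes false without a break, i.e. when every retry was used up, which is what turns retry exhaustion into an empty-string result instead of an exception. A toy sketch of the idiom with the same jittered-backoff shape:

import random
import time

MAX_RETRIES = 3


def flaky_call(attempt: int) -> int:
    # Stand-in for self._infer_result(); succeeds on the last attempt.
    return 200 if attempt == MAX_RETRIES - 1 else 429


retry_num = 0
while retry_num < MAX_RETRIES:
    status = flaky_call(retry_num)
    if status == 200:
        break  # success: the loop's else clause is skipped
    time.sleep(random.uniform(0, 0.01))  # jittered backoff, as in the fix
    retry_num += 1
else:
    # Runs only when the loop exhausts without `break`,
    # mirroring the `return ''` fallback above.
    status = None

print(status)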

     # @retry(stop_max_attempt_number=3, wait_fixed=16000)  # ms
