-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathchat.py
117 lines (98 loc) · 4.22 KB
/
chat.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
from http.server import HTTPServer, BaseHTTPRequestHandler
from zhipuai import ZhipuAI
import json
import os
import traceback
import asyncio
from concurrent.futures import ThreadPoolExecutor
# Fail fast at import time when the API key is unset or empty.
ZHIPUAI_API_KEY = os.environ.get("ZHIPUAI_API_KEY")
if ZHIPUAI_API_KEY is None or ZHIPUAI_API_KEY == "":
    raise ValueError("ZHIPUAI_API_KEY environment variable is not set")

# Module-wide ZhipuAI client; request handling reuses this single instance.
client = ZhipuAI(api_key=ZHIPUAI_API_KEY)

# System instruction placed at the start of every conversation sent to the model.
system_prompt = (
    "You are a helpful assistant in the field of law. "
    "You are designed to provide advice and assistance to users on legal matters."
)

# Create the thread pool that runs the blocking SDK calls off the request path.
executor = ThreadPoolExecutor(max_workers=4)
class RequestHandler(BaseHTTPRequestHandler):
    """HTTP handler that relays chat requests to the ZhipuAI ``glm-4`` model.

    ``POST`` expects a JSON object body with a ``message`` string and an
    optional ``history`` list of ``{"sender", "text", "isComplete"}`` entries.
    It replies with JSON:

    * 200 ``{"response": <answer>}`` on success;
    * 400 ``{"error": ...}`` when the request body is malformed;
    * 500 ``{"error": ...}`` when the model call fails.

    ``OPTIONS`` answers CORS preflight requests.
    """

    # How many of the most recent history entries are forwarded to the model.
    MAX_HISTORY = 10

    async def call_api(self, messages):
        """Run the blocking ZhipuAI completion call on the shared thread pool.

        Args:
            messages: Chat messages as ``{"role": ..., "content": ...}`` dicts.

        Returns:
            Whatever ``client.chat.completions.create`` returns.
        """
        # get_running_loop() is the documented way to reach the loop that is
        # executing this coroutine.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            executor,
            lambda: client.chat.completions.create(
                model="glm-4",
                messages=messages,
                stream=False,
            ),
        )

    @staticmethod
    def _build_messages(query, history, prompt, max_history=10):
        """Build the chat-completion message list for one request.

        Args:
            query: The user's new message; must be a non-empty string.
            history: Prior turns as a list of dicts with a ``text`` string, an
                optional ``sender`` (``"user"`` maps to role ``user``, anything
                else to ``assistant``) and an optional ``isComplete`` flag
                (default True), or ``None`` for no history.
            prompt: System prompt placed first in the list.
            max_history: How many trailing history entries to consider.

        Returns:
            ``[system prompt, *history turns, new user message]`` as
            ``{"role", "content"}`` dicts.

        Raises:
            ValueError: If ``query`` or ``history`` is malformed.
        """
        if not isinstance(query, str) or not query.strip():
            raise ValueError("'message' must be a non-empty string")
        if history is None:
            history = []
        if not isinstance(history, list):
            raise ValueError("'history' must be a list")

        messages = [{"role": "system", "content": prompt}]
        # Window first, then drop incomplete entries inside the window.
        # (history[-0:] would be the whole list, hence the explicit guard.)
        window = history[-max_history:] if max_history > 0 else []
        for entry in window:
            if not isinstance(entry, dict):
                raise ValueError("each 'history' entry must be an object")
            if not entry.get("isComplete", True):
                continue
            text = entry.get("text")
            if not isinstance(text, str):
                raise ValueError("each complete 'history' entry needs a 'text' string")
            role = "user" if entry.get("sender") == "user" else "assistant"
            messages.append({"role": role, "content": text})
        messages.append({"role": "user", "content": query})
        return messages

    def _send_cors_headers(self):
        """Emit the CORS headers shared by every response."""
        self.send_header('Access-Control-Allow-Origin', '*')
        self.send_header('Access-Control-Allow-Methods', 'POST, OPTIONS')
        self.send_header('Access-Control-Allow-Headers', 'Content-Type')

    def _send_json(self, status, response_json):
        """Send a complete JSON response.

        Args:
            status: HTTP status code.
            response_json: Already-serialized JSON text for the body.
        """
        body = response_json.encode('utf-8')
        self.send_response(status)
        self.send_header('Content-Type', 'application/json')
        self.send_header('Content-Length', str(len(body)))
        self._send_cors_headers()
        self.end_headers()
        self.wfile.write(body)

    def _send_error_json(self, status, message):
        """Best-effort ``{"error": message}`` reply; the client may be gone."""
        try:
            self._send_json(status, json.dumps({"error": message}))
        except (OSError, ValueError):
            # OSError: broken/reset socket; ValueError: write to closed file.
            print("Failed to send error response")
            traceback.print_exc()

    def do_POST(self):
        """Answer one chat request with the model's reply as JSON."""
        # 1. Read and validate the request. ValueError covers a bad
        #    Content-Length, undecodable bytes, invalid JSON and bad fields
        #    (JSONDecodeError and UnicodeDecodeError are ValueError subclasses).
        try:
            content_length = int(self.headers.get('Content-Length', 0))
            if content_length < 0:
                # rfile.read(-1) would read until the client closes the socket.
                raise ValueError("invalid Content-Length")
            post_data = self.rfile.read(content_length)
            print(f"Received raw data: {post_data}")
            data = json.loads(post_data.decode('utf-8'))
            print(f"Parsed data: {data}")
            if not isinstance(data, dict):
                raise ValueError("request body must be a JSON object")
            messages = self._build_messages(
                data.get('message'),
                data.get('history'),
                system_prompt,
                self.MAX_HISTORY,
            )
        except ValueError as e:
            print(f"Bad request: {e}")
            self._send_error_json(400, str(e))
            return
        print(f"Final messages: {messages}")

        # 2. Call the model. asyncio.run() creates a fresh event loop and
        #    always closes it, even when the call raises.
        try:
            response = asyncio.run(self.call_api(messages))
            answer = response.choices[0].message.content
        except Exception as e:  # top-level boundary: report any API failure
            print(f"Error occurred: {str(e)}")
            print("Traceback:")
            traceback.print_exc()
            self._send_error_json(500, str(e))
            return
        print(f"Got response: {answer}")

        # 3. Send the answer. Headers go out only once the body is ready, so
        #    an error can never follow an already-sent 200 status line.
        response_json = json.dumps({"response": answer})
        print(f"Sending response: {response_json}")
        try:
            self._send_json(200, response_json)
        except (OSError, ValueError):
            print("Failed to send response")
            traceback.print_exc()
            return
        print("Response sent successfully")

    def do_OPTIONS(self):
        """Answer a CORS preflight request with an empty 200 response."""
        self.send_response(200)
        self._send_cors_headers()
        self.send_header('Content-Length', '0')
        self.end_headers()
def run_server(port=8233, host=''):
    """Start the chat HTTP server and serve until interrupted.

    A ThreadingHTTPServer handles each connection in its own thread, so one
    slow model call does not block other clients (including CORS preflight
    requests). The module's thread pool still bounds concurrent API calls.

    Args:
        port: TCP port to listen on (default 8233).
        host: Interface to bind; '' (the default) binds all interfaces.
    """
    # Local import: the module header only brings HTTPServer into scope.
    from http.server import ThreadingHTTPServer

    server_address = (host, port)
    httpd = ThreadingHTTPServer(server_address, RequestHandler)
    print(f'Starting server on port {port}...')
    try:
        httpd.serve_forever()
    except KeyboardInterrupt:
        print('Shutting down server...')
    finally:
        # Release the listening socket however serve_forever() exits.
        httpd.server_close()
def mask_secret(secret, visible=4):
    """Return ``secret`` masked for logging, revealing at most ``visible`` trailing chars.

    Secrets of length <= 2 * visible are fully masked, so a short key is
    never printed in full (the old ``"***" + key[-4:]`` leaked keys of
    4 characters or fewer entirely).
    """
    if visible <= 0 or len(secret) <= 2 * visible:
        return "***"
    return "***" + secret[-visible:]


if __name__ == '__main__':
    print("Starting server with API key:", mask_secret(ZHIPUAI_API_KEY))
    run_server()