Skip to content

Commit 061a41c

Browse files
authored
feat: Support reasoning content #2135 (#2158)
1 parent 0cc1d00 commit 061a41c

File tree

32 files changed

+816
-166
lines changed

32 files changed

+816
-166
lines changed

Diff for: apps/application/chat_pipeline/step/chat_step/i_chat_step.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -65,16 +65,21 @@ class InstanceSerializer(serializers.Serializer):
6565
post_response_handler = InstanceField(model_type=PostResponseHandler,
6666
error_messages=ErrMessage.base(_("Post-processor")))
6767
# 补全问题
68-
padding_problem_text = serializers.CharField(required=False, error_messages=ErrMessage.base(_("Completion Question")))
68+
padding_problem_text = serializers.CharField(required=False,
69+
error_messages=ErrMessage.base(_("Completion Question")))
6970
# 是否使用流的形式输出
7071
stream = serializers.BooleanField(required=False, error_messages=ErrMessage.base(_("Streaming Output")))
7172
client_id = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Client id")))
7273
client_type = serializers.CharField(required=True, error_messages=ErrMessage.char(_("Client Type")))
7374
# 未查询到引用分段
74-
no_references_setting = NoReferencesSetting(required=True, error_messages=ErrMessage.base(_("No reference segment settings")))
75+
no_references_setting = NoReferencesSetting(required=True,
76+
error_messages=ErrMessage.base(_("No reference segment settings")))
7577

7678
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid(_("User ID")))
7779

80+
model_setting = serializers.DictField(required=True, allow_null=True,
81+
error_messages=ErrMessage.dict(_("Model settings")))
82+
7883
model_params_setting = serializers.DictField(required=False, allow_null=True,
7984
error_messages=ErrMessage.dict(_("Model parameter settings")))
8085

@@ -101,5 +106,5 @@ def execute(self, message_list: List[BaseMessage],
101106
paragraph_list=None,
102107
manage: PipelineManage = None,
103108
padding_problem_text: str = None, stream: bool = True, client_id=None, client_type=None,
104-
no_references_setting=None, model_params_setting=None, **kwargs):
109+
no_references_setting=None, model_params_setting=None, model_setting=None, **kwargs):
105110
pass

Diff for: apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py

+55-14
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
2525
from application.chat_pipeline.pipeline_manage import PipelineManage
2626
from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, PostResponseHandler
27+
from application.flow.tools import Reasoning
2728
from application.models.api_key_model import ApplicationPublicAccessClient
2829
from common.constants.authentication_type import AuthenticationType
2930
from setting.models_provider.tools import get_model_instance_by_model_user_id
@@ -63,17 +64,37 @@ def event_content(response,
6364
problem_text: str,
6465
padding_problem_text: str = None,
6566
client_id=None, client_type=None,
66-
is_ai_chat: bool = None):
67+
is_ai_chat: bool = None,
68+
model_setting=None):
69+
if model_setting is None:
70+
model_setting = {}
71+
reasoning_content_enable = model_setting.get('reasoning_content_enable', False)
72+
reasoning_content_start = model_setting.get('reasoning_content_start', '<think>')
73+
reasoning_content_end = model_setting.get('reasoning_content_end', '</think>')
74+
reasoning = Reasoning(reasoning_content_start,
75+
reasoning_content_end)
6776
all_text = ''
77+
reasoning_content = ''
6878
try:
6979
for chunk in response:
70-
all_text += chunk.content
80+
reasoning_chunk = reasoning.get_reasoning_content(chunk)
81+
content_chunk = reasoning_chunk.get('content')
82+
if 'reasoning_content' in chunk.additional_kwargs:
83+
reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
84+
else:
85+
reasoning_content_chunk = reasoning_chunk.get('reasoning_content')
86+
all_text += content_chunk
87+
if reasoning_content_chunk is None:
88+
reasoning_content_chunk = ''
89+
reasoning_content += reasoning_content_chunk
7190
yield manage.get_base_to_response().to_stream_chunk_response(chat_id, str(chat_record_id), 'ai-chat-node',
72-
[], chunk.content,
91+
[], content_chunk,
7392
False,
7493
0, 0, {'node_is_end': False,
7594
'view_type': 'many_view',
76-
'node_type': 'ai-chat-node'})
95+
'node_type': 'ai-chat-node',
96+
'real_node_id': 'ai-chat-node',
97+
'reasoning_content': reasoning_content_chunk if reasoning_content_enable else ''})
7798
# 获取token
7899
if is_ai_chat:
79100
try:
@@ -87,7 +108,8 @@ def event_content(response,
87108
response_token = 0
88109
write_context(step, manage, request_token, response_token, all_text)
89110
post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
90-
all_text, manage, step, padding_problem_text, client_id)
111+
all_text, manage, step, padding_problem_text, client_id,
112+
reasoning_content=reasoning_content if reasoning_content_enable else '')
91113
yield manage.get_base_to_response().to_stream_chunk_response(chat_id, str(chat_record_id), 'ai-chat-node',
92114
[], '', True,
93115
request_token, response_token,
@@ -122,17 +144,20 @@ def execute(self, message_list: List[BaseMessage],
122144
client_id=None, client_type=None,
123145
no_references_setting=None,
124146
model_params_setting=None,
147+
model_setting=None,
125148
**kwargs):
126149
chat_model = get_model_instance_by_model_user_id(model_id, user_id,
127150
**model_params_setting) if model_id is not None else None
128151
if stream:
129152
return self.execute_stream(message_list, chat_id, problem_text, post_response_handler, chat_model,
130153
paragraph_list,
131-
manage, padding_problem_text, client_id, client_type, no_references_setting)
154+
manage, padding_problem_text, client_id, client_type, no_references_setting,
155+
model_setting)
132156
else:
133157
return self.execute_block(message_list, chat_id, problem_text, post_response_handler, chat_model,
134158
paragraph_list,
135-
manage, padding_problem_text, client_id, client_type, no_references_setting)
159+
manage, padding_problem_text, client_id, client_type, no_references_setting,
160+
model_setting)
136161

137162
def get_details(self, manage, **kwargs):
138163
return {
@@ -187,14 +212,15 @@ def execute_stream(self, message_list: List[BaseMessage],
187212
manage: PipelineManage = None,
188213
padding_problem_text: str = None,
189214
client_id=None, client_type=None,
190-
no_references_setting=None):
215+
no_references_setting=None,
216+
model_setting=None):
191217
chat_result, is_ai_chat = self.get_stream_result(message_list, chat_model, paragraph_list,
192218
no_references_setting, problem_text)
193219
chat_record_id = uuid.uuid1()
194220
r = StreamingHttpResponse(
195221
streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list,
196222
post_response_handler, manage, self, chat_model, message_list, problem_text,
197-
padding_problem_text, client_id, client_type, is_ai_chat),
223+
padding_problem_text, client_id, client_type, is_ai_chat, model_setting),
198224
content_type='text/event-stream;charset=utf-8')
199225

200226
r['Cache-Control'] = 'no-cache'
@@ -230,7 +256,13 @@ def execute_block(self, message_list: List[BaseMessage],
230256
paragraph_list=None,
231257
manage: PipelineManage = None,
232258
padding_problem_text: str = None,
233-
client_id=None, client_type=None, no_references_setting=None):
259+
client_id=None, client_type=None, no_references_setting=None,
260+
model_setting=None):
261+
reasoning_content_enable = model_setting.get('reasoning_content_enable', False)
262+
reasoning_content_start = model_setting.get('reasoning_content_start', '<think>')
263+
reasoning_content_end = model_setting.get('reasoning_content_end', '</think>')
264+
reasoning = Reasoning(reasoning_content_start,
265+
reasoning_content_end)
234266
chat_record_id = uuid.uuid1()
235267
# 调用模型
236268
try:
@@ -243,14 +275,23 @@ def execute_block(self, message_list: List[BaseMessage],
243275
request_token = 0
244276
response_token = 0
245277
write_context(self, manage, request_token, response_token, chat_result.content)
278+
reasoning.get_reasoning_content(chat_result)
279+
reasoning_result = reasoning.get_reasoning_content(chat_result)
280+
content = reasoning_result.get('content')
281+
if 'reasoning_content' in chat_result.response_metadata:
282+
reasoning_content = chat_result.response_metadata.get('reasoning_content', '')
283+
else:
284+
reasoning_content = reasoning_result.get('reasoning_content')
246285
post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
247-
chat_result.content, manage, self, padding_problem_text, client_id)
286+
chat_result.content, manage, self, padding_problem_text, client_id,
287+
reasoning_content=reasoning_content if reasoning_content_enable else '')
248288
add_access_num(client_id, client_type, manage.context.get('application_id'))
249289
return manage.get_base_to_response().to_block_response(str(chat_id), str(chat_record_id),
250-
chat_result.content, True,
251-
request_token, response_token)
290+
content, True,
291+
request_token, response_token,
292+
{'reasoning_content': reasoning_content})
252293
except Exception as e:
253-
all_text = '异常' + str(e)
294+
all_text = 'Exception:' + str(e)
254295
write_context(self, manage, 0, 0, all_text)
255296
post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
256297
all_text, manage, self, padding_problem_text, client_id)

Diff for: apps/application/flow/common.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,22 @@
99

1010

1111
class Answer:
12-
def __init__(self, content, view_type, runtime_node_id, chat_record_id, child_node):
12+
def __init__(self, content, view_type, runtime_node_id, chat_record_id, child_node, real_node_id,
13+
reasoning_content):
1314
self.view_type = view_type
1415
self.content = content
16+
self.reasoning_content = reasoning_content
1517
self.runtime_node_id = runtime_node_id
1618
self.chat_record_id = chat_record_id
1719
self.child_node = child_node
20+
self.real_node_id = real_node_id
1821

1922
def to_dict(self):
2023
return {'view_type': self.view_type, 'content': self.content, 'runtime_node_id': self.runtime_node_id,
21-
'chat_record_id': self.chat_record_id, 'child_node': self.child_node}
24+
'chat_record_id': self.chat_record_id,
25+
'child_node': self.child_node,
26+
'reasoning_content': self.reasoning_content,
27+
'real_node_id': self.real_node_id}
2228

2329

2430
class NodeChunk:

Diff for: apps/application/flow/i_step_node.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,9 @@ def handler(self, chat_id,
6262
answer_tokens = sum([row.get('answer_tokens') for row in details.values() if
6363
'answer_tokens' in row and row.get('answer_tokens') is not None])
6464
answer_text_list = workflow.get_answer_text_list()
65-
answer_text = '\n\n'.join(answer['content'] for answer in answer_text_list)
65+
answer_text = '\n\n'.join(
66+
'\n\n'.join([a.get('content') for a in answer]) for answer in
67+
answer_text_list)
6668
if workflow.chat_record is not None:
6769
chat_record = workflow.chat_record
6870
chat_record.answer_text = answer_text
@@ -157,8 +159,10 @@ def save_context(self, details, workflow_manage):
157159
def get_answer_list(self) -> List[Answer] | None:
158160
if self.answer_text is None:
159161
return None
162+
reasoning_content_enable = self.context.get('model_setting', {}).get('reasoning_content_enable', False)
160163
return [
161-
Answer(self.answer_text, self.view_type, self.runtime_node_id, self.workflow_params['chat_record_id'], {})]
164+
Answer(self.answer_text, self.view_type, self.runtime_node_id, self.workflow_params['chat_record_id'], {},
165+
self.runtime_node_id, self.context.get('reasoning_content', '') if reasoning_content_enable else '')]
162166

163167
def __init__(self, node, workflow_params, workflow_manage, up_node_id_list=None,
164168
get_node_params=lambda node: node.properties.get('node_data')):

Diff for: apps/application/flow/step_node/ai_chat_step_node/i_chat_node.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,9 @@ class ChatNodeSerializer(serializers.Serializer):
2828
error_messages=ErrMessage.boolean(_('Whether to return content')))
2929

3030
model_params_setting = serializers.DictField(required=False,
31-
error_messages=ErrMessage.integer(_("Model parameter settings")))
32-
31+
error_messages=ErrMessage.dict(_("Model parameter settings")))
32+
model_setting = serializers.DictField(required=False,
33+
error_messages=ErrMessage.dict('Model settings'))
3334
dialogue_type = serializers.CharField(required=False, allow_blank=True, allow_null=True,
3435
error_messages=ErrMessage.char(_("Context Type")))
3536

@@ -47,5 +48,6 @@ def execute(self, model_id, system, prompt, dialogue_number, history_chat_record
4748
chat_record_id,
4849
model_params_setting=None,
4950
dialogue_type=None,
51+
model_setting=None,
5052
**kwargs) -> NodeResult:
5153
pass

0 commit comments

Comments
 (0)