Skip to content

Support Reasoning Content #2158

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Feb 8, 2025
10 changes: 8 additions & 2 deletions apps/application/flow/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,22 @@


class Answer:
    """One renderable answer segment produced by a workflow node.

    Carries the answer text plus the model's chain-of-thought
    (``reasoning_content``) and enough node identifiers for the client to
    attribute the segment to the node that produced it.
    """

    def __init__(self, content, view_type, runtime_node_id, chat_record_id, child_node, real_node_id,
                 reasoning_content):
        # view_type: how the client should render this segment
        # (e.g. 'single_view' / 'many_view' — confirm against callers).
        self.view_type = view_type
        # The visible answer text.
        self.content = content
        # Model reasoning ("thinking") text; may be '' or None when the
        # model/provider emits no reasoning channel.
        self.reasoning_content = reasoning_content
        # Id of the runtime node instance that emitted this segment.
        self.runtime_node_id = runtime_node_id
        # Chat record this answer belongs to.
        self.chat_record_id = chat_record_id
        # Nested child-node routing info (dict, possibly empty).
        self.child_node = child_node
        # Actual executed node id (differs from runtime_node_id for
        # sub-application nodes).
        self.real_node_id = real_node_id

    def to_dict(self):
        """Return a JSON-serializable dict with all answer fields."""
        return {'view_type': self.view_type, 'content': self.content, 'runtime_node_id': self.runtime_node_id,
                'chat_record_id': self.chat_record_id,
                'child_node': self.child_node,
                'reasoning_content': self.reasoning_content,
                'real_node_id': self.real_node_id}


class NodeChunk:
Expand Down
7 changes: 5 additions & 2 deletions apps/application/flow/i_step_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ def handler(self, chat_id,
answer_tokens = sum([row.get('answer_tokens') for row in details.values() if
'answer_tokens' in row and row.get('answer_tokens') is not None])
answer_text_list = workflow.get_answer_text_list()
answer_text = '\n\n'.join(answer['content'] for answer in answer_text_list)
answer_text = '\n\n'.join(
'\n\n'.join(["\n\n".join([a.get('reasoning_content'), a.get('content')]) for a in answer]) for answer in
answer_text_list)
if workflow.chat_record is not None:
chat_record = workflow.chat_record
chat_record.answer_text = answer_text
Expand Down Expand Up @@ -158,7 +160,8 @@ def get_answer_list(self) -> List[Answer] | None:
if self.answer_text is None:
return None
return [
Answer(self.answer_text, self.view_type, self.runtime_node_id, self.workflow_params['chat_record_id'], {})]
Answer(self.answer_text, self.view_type, self.runtime_node_id, self.workflow_params['chat_record_id'], {},
self.runtime_node_id, self.context.get('reasoning_content'))]

def __init__(self, node, workflow_params, workflow_manage, up_node_id_list=None,
get_node_params=lambda node: node.properties.get('node_data')):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,16 @@
from langchain.schema import HumanMessage, SystemMessage
from langchain_core.messages import BaseMessage, AIMessage

from application.flow.common import Answer
from application.flow.i_step_node import NodeResult, INode
from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode
from setting.models import Model
from setting.models_provider import get_model_credential
from setting.models_provider.tools import get_model_instance_by_model_user_id


def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str):
def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str,
reasoning_content: str):
chat_model = node_variable.get('chat_model')
message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list'))
answer_tokens = chat_model.get_num_tokens(answer)
Expand All @@ -31,6 +33,7 @@ def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wo
node.context['history_message'] = node_variable['history_message']
node.context['question'] = node_variable['question']
node.context['run_time'] = time.time() - node.context['start_time']
node.context['reasoning_content'] = reasoning_content
if workflow.is_result(node, NodeResult(node_variable, workflow_variable)):
node.answer_text = answer

Expand All @@ -45,10 +48,15 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo
"""
response = node_variable.get('result')
answer = ''
reasoning_content = ''
for chunk in response:
answer += chunk.content
yield chunk.content
_write_context(node_variable, workflow_variable, node, workflow, answer)
reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
if reasoning_content_chunk is None:
reasoning_content_chunk = ''
reasoning_content += reasoning_content_chunk
yield {'content': chunk.content, 'reasoning_content': reasoning_content_chunk}
_write_context(node_variable, workflow_variable, node, workflow, answer, reasoning_content)


def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
Expand All @@ -61,7 +69,8 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor
"""
response = node_variable.get('result')
answer = response.content
_write_context(node_variable, workflow_variable, node, workflow, answer)
reasoning_content = response.response_metadata.get('reasoning_content', '')
_write_context(node_variable, workflow_variable, node, workflow, answer, reasoning_content)


def get_default_model_params_setting(model_id):
Expand Down Expand Up @@ -92,6 +101,7 @@ class BaseChatNode(IChatNode):
def save_context(self, details, workflow_manage):
self.context['answer'] = details.get('answer')
self.context['question'] = details.get('question')
self.context['reasoning_content'] = details.get('reasoning_content')
self.answer_text = details.get('answer')

def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
Expand Down Expand Up @@ -164,6 +174,7 @@ def get_details(self, index: int, **kwargs):
'history_message') is not None else [])],
'question': self.context.get('question'),
'answer': self.context.get('answer'),
'reasoning_content': self.context.get('reasoning_content'),
'type': self.node.type,
'message_tokens': self.context.get('message_tokens'),
'answer_tokens': self.context.get('answer_tokens'),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ def _is_interrupt_exec(node, node_variable: Dict, workflow_variable: Dict):
return node_variable.get('is_interrupt_exec', False)


def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str):
def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str,
reasoning_content: str):
result = node_variable.get('result')
node.context['application_node_dict'] = node_variable.get('application_node_dict')
node.context['node_dict'] = node_variable.get('node_dict', {})
Expand All @@ -28,6 +29,7 @@ def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wo
node.context['answer_tokens'] = result.get('usage', {}).get('completion_tokens', 0)
node.context['answer'] = answer
node.context['result'] = answer
node.context['reasoning_content'] = reasoning_content
node.context['question'] = node_variable['question']
node.context['run_time'] = time.time() - node.context['start_time']
if workflow.is_result(node, NodeResult(node_variable, workflow_variable)):
Expand All @@ -44,6 +46,7 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo
"""
response = node_variable.get('result')
answer = ''
reasoning_content = ''
usage = {}
node_child_node = {}
application_node_dict = node.context.get('application_node_dict', {})
Expand All @@ -60,9 +63,11 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo
node_type = response_content.get('node_type')
real_node_id = response_content.get('real_node_id')
node_is_end = response_content.get('node_is_end', False)
_reasoning_content = response_content.get('reasoning_content', '')
if node_type == 'form-node':
is_interrupt_exec = True
answer += content
reasoning_content += _reasoning_content
node_child_node = {'runtime_node_id': runtime_node_id, 'chat_record_id': chat_record_id,
'child_node': child_node}

Expand All @@ -75,13 +80,16 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo
'chat_record_id': chat_record_id,
'child_node': child_node,
'index': len(application_node_dict),
'view_type': view_type}
'view_type': view_type,
'reasoning_content': _reasoning_content}
else:
application_node['content'] += content
application_node['reasoning_content'] += _reasoning_content

yield {'content': content,
'node_type': node_type,
'runtime_node_id': runtime_node_id, 'chat_record_id': chat_record_id,
'reasoning_content': _reasoning_content,
'child_node': child_node,
'real_node_id': real_node_id,
'node_is_end': node_is_end,
Expand All @@ -91,7 +99,7 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo
node_variable['is_interrupt_exec'] = is_interrupt_exec
node_variable['child_node'] = node_child_node
node_variable['application_node_dict'] = application_node_dict
_write_context(node_variable, workflow_variable, node, workflow, answer)
_write_context(node_variable, workflow_variable, node, workflow, answer, reasoning_content)


def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
Expand All @@ -106,7 +114,8 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor
node_variable['result'] = {'usage': {'completion_tokens': response.get('completion_tokens'),
'prompt_tokens': response.get('prompt_tokens')}}
answer = response.get('content', '') or "抱歉,没有查找到相关内容,请重新描述您的问题或提供更多信息。"
_write_context(node_variable, workflow_variable, node, workflow, answer)
reasoning_content = response.get('reasoning_content', '')
_write_context(node_variable, workflow_variable, node, workflow, answer, reasoning_content)


def reset_application_node_dict(application_node_dict, runtime_node_id, node_data):
Expand Down Expand Up @@ -139,18 +148,21 @@ def get_answer_list(self) -> List[Answer] | None:
if application_node_dict is None or len(application_node_dict) == 0:
return [
Answer(self.answer_text, self.view_type, self.runtime_node_id, self.workflow_params['chat_record_id'],
self.context.get('child_node'))]
self.context.get('child_node'), self.runtime_node_id, '')]
else:
return [Answer(n.get('content'), n.get('view_type'), self.runtime_node_id,
self.workflow_params['chat_record_id'], {'runtime_node_id': n.get('runtime_node_id'),
'chat_record_id': n.get('chat_record_id')
, 'child_node': n.get('child_node')}) for n in
, 'child_node': n.get('child_node')}, n.get('real_node_id'), n.get('reasoning_content'))
for n in
sorted(application_node_dict.values(), key=lambda item: item.get('index'))]

def save_context(self, details, workflow_manage):
self.context['answer'] = details.get('answer')
self.context['result'] = details.get('answer')
self.context['question'] = details.get('question')
self.context['type'] = details.get('type')
self.context['reasoning_content'] = details.get('reasoning_content')
self.answer_text = details.get('answer')

def execute(self, application_id, message, chat_id, chat_record_id, stream, re_chat, client_id, client_type,
Expand Down Expand Up @@ -229,6 +241,7 @@ def get_details(self, index: int, **kwargs):
'run_time': self.context.get('run_time'),
'question': self.context.get('question'),
'answer': self.context.get('answer'),
'reasoning_content': self.context.get('reasoning_content'),
'type': self.node.type,
'message_tokens': self.context.get('message_tokens'),
'answer_tokens': self.context.get('answer_tokens'),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ def get_answer_list(self) -> List[Answer] | None:
form_content_format = self.workflow_manage.reset_prompt(form_content_format)
prompt_template = PromptTemplate.from_template(form_content_format, template_format='jinja2')
value = prompt_template.format(form=form, context=context)
return [Answer(value, self.view_type, self.runtime_node_id, self.workflow_params['chat_record_id'], None)]
return [Answer(value, self.view_type, self.runtime_node_id, self.workflow_params['chat_record_id'], None,
self.runtime_node_id, '')]

def get_details(self, index: int, **kwargs):
form_content_format = self.context.get('form_content_format')
Expand Down
27 changes: 16 additions & 11 deletions apps/application/flow/workflow_manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,7 @@ def hand_event_node_result(self, current_node, node_result_future):
if result is not None:
if self.is_result(current_node, current_result):
for r in result:
reasoning_content = ''
content = r
child_node = {}
node_is_end = False
Expand All @@ -479,9 +480,12 @@ def hand_event_node_result(self, current_node, node_result_future):
child_node = {'runtime_node_id': r.get('runtime_node_id'),
'chat_record_id': r.get('chat_record_id')
, 'child_node': r.get('child_node')}
real_node_id = r.get('real_node_id')
node_is_end = r.get('node_is_end')
if r.__contains__('real_node_id'):
real_node_id = r.get('real_node_id')
if r.__contains__('node_is_end'):
node_is_end = r.get('node_is_end')
view_type = r.get('view_type')
reasoning_content = r.get('reasoning_content')
chunk = self.base_to_response.to_stream_chunk_response(self.params['chat_id'],
self.params['chat_record_id'],
current_node.id,
Expand All @@ -492,7 +496,8 @@ def hand_event_node_result(self, current_node, node_result_future):
'view_type': view_type,
'child_node': child_node,
'node_is_end': node_is_end,
'real_node_id': real_node_id})
'real_node_id': real_node_id,
'reasoning_content': reasoning_content})
current_node.node_chunk.add_chunk(chunk)
chunk = (self.base_to_response
.to_stream_chunk_response(self.params['chat_id'],
Expand All @@ -504,7 +509,8 @@ def hand_event_node_result(self, current_node, node_result_future):
'node_type': current_node.type,
'view_type': view_type,
'child_node': child_node,
'real_node_id': real_node_id}))
'real_node_id': real_node_id,
'reasoning_content': ''}))
current_node.node_chunk.add_chunk(chunk)
else:
list(result)
Expand Down Expand Up @@ -603,20 +609,19 @@ def get_answer_text_list(self):
if len(current_answer.content) > 0:
if up_node is None or current_answer.view_type == 'single_view' or (
current_answer.view_type == 'many_view' and up_node.view_type == 'single_view'):
result.append(current_answer)
result.append([current_answer])
else:
if len(result) > 0:
exec_index = len(result) - 1
content = result[exec_index].content
result[exec_index].content += current_answer.content if len(
content) == 0 else ('\n\n' + current_answer.content)
if isinstance(result[exec_index], list):
result[exec_index].append(current_answer)
else:
result.insert(0, current_answer)
result.insert(0, [current_answer])
up_node = current_answer
if len(result) == 0:
# 如果没有响应 就响应一个空数据
return [Answer('', '', '', '', {}).to_dict()]
return [r.to_dict() for r in result]
return [[]]
return [[item.to_dict() for item in r] for r in result]

def get_next_node(self):
"""
Expand Down
8 changes: 7 additions & 1 deletion apps/application/models/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,13 @@ def get_dataset_setting_dict():


def get_model_setting_dict():
    """Default model settings for a new application.

    ``reasoning_content_start`` / ``reasoning_content_end`` are the delimiters
    used to extract inline reasoning ("thinking") text from model output;
    extraction is off by default (``reasoning_content_enable`` is False).
    """
    return {
        'prompt': Application.get_default_model_prompt(),
        'no_references_prompt': '{question}',
        'reasoning_content_start': '<think>',
        'reasoning_content_end': '</think>',
        'reasoning_content_enable': False,
    }


class Application(AppModelMixin):
Expand Down
Loading