Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add v2 style rewrite #628

Merged
merged 13 commits into from
Nov 29, 2024
1 change: 1 addition & 0 deletions python/core/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ def __init__(
secret_key: Optional[str] = None,
gateway: str = "",
lazy_certification: bool = False,
**kwargs
):
r"""Component初始化方法.
Expand Down
2 changes: 2 additions & 0 deletions python/core/components/v2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@

from .animal_recognize.component import AnimalRecognition
from .image_understand.component import ImageUnderstand
from .llms.style_rewrite.component import StyleRewrite

__V2_COMPONENTS__ = [
"AnimalRecognition",
"ImageUnderstand",
"StyleRewrite"
] # NOQA
15 changes: 15 additions & 0 deletions python/core/components/v2/llms/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright (c) 2023 Baidu, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


2 changes: 2 additions & 0 deletions python/core/components/v2/llms/style_rewrite/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
"""StyleRewrite"""
from .component import StyleRewrite
139 changes: 139 additions & 0 deletions python/core/components/v2/llms/style_rewrite/component.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
# Copyright (c) 2023 Baidu, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from appbuilder.core.components.llms.base import CompletionBaseComponent
from appbuilder.core.message import Message

from typing import Optional
from appbuilder.utils.trace.tracer_wrapper import components_run_trace, components_run_stream_trace
from appbuilder.core.components.llms.style_rewrite.base import StyleRewriteArgs


class StyleRewrite(CompletionBaseComponent):
"""
文本风格转写大模型组件, 基于生成式大模型对文本的风格进行改写,支持有营销、客服、直播、激励及教学五种话术。

Examples:

.. code-block:: python

import os
import appbuilder
os.environ["APPBUILDER_TOKEN"] = '...'

style_rewrite = appbuilder.StyleRewrite(model="Qianfan-Agent-Speed-8k")
answer = style_rewrite(appbuilder.Message("文心大模型发布新版本"), style="激励话术")

"""
name = "style_rewrite"
version = "v1"
meta = StyleRewriteArgs

manifests = [
{
"name": "style_rewrite",
"description": "能够将一段文本转换成不同的风格(营销、客服、直播、激励及教学话术),同时保持原文的基本意义不变。",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "需要改写的文本。"
},
"style": {
"type": "string",
"description": "想要转换的文本风格,目前有营销、客服、直播、激励及教学五种话术可选. 默认是营销话术。",
"enum": ["营销话术", "客服话术", "直播话术", "激励话术", "教学话术"]
}
},
"required": [
"query"
]
}
}
]

def __init__(
self,
model=None,
secret_key: Optional[str] = None,
gateway: str = "",
lazy_certification: bool = False,
):
"""初始化StyleRewrite模型。

Args:
model (str|None): 模型名称,用于指定要使用的千帆模型。
secret_key (str, 可选): 用户鉴权token, 默认从环境变量中获取: os.getenv("APPBUILDER_TOKEN", "").
gateway (str, 可选): 后端网关服务地址,默认从环境变量中获取: os.getenv("GATEWAY_URL", "")
lazy_certification (bool, 可选): 延迟认证,为True时在第一次运行时认证. Defaults to False.

Returns:
None

"""
super().__init__(
StyleRewriteArgs, model=model, secret_key=secret_key, gateway=gateway,
lazy_certification=lazy_certification)

@components_run_trace
def run(self, message, style="营销话术", stream=False, temperature=1e-10, top_p=0.0, request_id=None):
"""
使用给定的输入运行模型并返回结果。

Args:
message (obj:`Message`): 输入消息,用于模型的主要输入内容。这是一个必需的参数。
style (str, optional): 想要转换的文本风格,目前有营销、客服、直播、激励及教学五种话术可选。默认为"营销话术"。
stream (bool, optional): 指定是否以流式形式返回响应。默认为 False。
temperature (float, optional): 模型配置的温度参数,用于调整模型的生成概率。取值范围为 0.0 到 1.0,其中较低的值使生成更确定性,较高的值使生成更多样性。默认值为 1e-10。
top_p (float, optional): 影响输出文本的多样性,取值越大,生成文本的多样性越强。取值范围为 0.0 到 1.0,其中较低的值使生成更确定性,较高的值使生成更多样性。默认值为 0。

Returns:
obj:`Message`: 模型运行后的输出消息。

"""
return super().run(message=message, style=style, stream=stream, temperature=temperature, top_p=top_p, request_id=request_id)

@components_run_stream_trace
def tool_eval(self, query: str, style: str = "营销话术", **kwargs):
"""
执行工具评估函数

Args:
name (str): 函数名称,本函数不使用该参数,但保留以符合某些框架的要求。
streaming (bool, optional): 是否以流的形式返回结果。默认为 False,即一次性返回结果。如果设置为 True,则以生成器形式逐个返回结果。
**kwargs: 其他参数,包含但不限于:
_sys_traceid (str): 请求的跟踪ID,用于日志记录和跟踪。
model_configs (dict, optional): 模型配置参数,可选的键包括:
temperature (float, optional): 温度参数,用于控制生成文本的随机性。默认为 1e-10。
top_p (float, optional): top_p 采样参数,用于控制生成文本的多样性。默认为 0.0。

Returns:
如果 streaming 为 False,则直接返回评估结果字符串。
如果 streaming 为 True,则以生成器形式逐个返回评估结果字符串。

Raises:
ValueError: 如果缺少参数 'query'。
"""
traceid = kwargs.get("_sys_traceid")
if not query:
raise ValueError("param `query` is required")
msg = Message(query)
if style not in ["营销话术", "客服话术", "直播话术", "激励话术", "教学话术"]:
style = "营销话术"
model_configs = kwargs.get('model_configs', {})
temperature = model_configs.get("temperature", 1e-10)
top_p = model_configs.get("top_p", 0.0)
message = super().run(message=msg, style=style, stream=False, temperature=temperature, top_p=top_p, request_id=traceid)

yield self.create_output(type="text", text=str(message.content), name="text", usage=message.token_usage)
14 changes: 13 additions & 1 deletion python/tests/component_tool_eval_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,21 @@ def inputs(self):
def schemas(self):
return [url_schema]

class StypeRewriteCase(Case):
def init_args(self):
return {"model": "Qianfan-Agent-Speed-8k"}

def inputs(self):
return {"query": "文心大模型发布新版"}

def schemas(self):
return [text_schema]


component_tool_eval_cases = {
"AnimalRecognition": AnimalRecognitionCase,
"ImageUnderstand": ImageUnderstandCase,
"ASR": ASRCase,
"TreeMind": TreeMindCase
"TreeMind": TreeMindCase,
"StyleRewrite": StypeRewriteCase
}
57 changes: 57 additions & 0 deletions python/tests/test_v2_style_rewrite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright (c) 2024 Baidu, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
import time
import appbuilder
from appbuilder.core.components.v2 import StyleRewrite
from appbuilder.core.component import ComponentOutput

@unittest.skipUnless(os.getenv("TEST_CASE", "UNKNOWN") == "CPU_PARALLEL", "")
class TestStyleRewrite(unittest.TestCase):
def setUp(self) -> None:
self.com = StyleRewrite(model="Qianfan-Agent-Speed-8k")

def test_normal_case(self):
time.sleep(2)
text = "文心大模型发布新版"
style = "激励话术"
msg = appbuilder.Message(content=text)
out = self.com(msg, style=style)
self.assertIn("文心大模型", out.content)

def test_tool_eval(self):
time.sleep(2)
text = "文心大模型发布新版"
style = "营销话术"
out = self.com.tool_eval(query=text, style=style)
for item in out:
self.assertIsInstance(item, ComponentOutput)

def test_non_stream_tool_eval(self):
text = "成都是个包容的城市"
style = "直播话术"
out = self.com.non_stream_tool_eval(query=text, style=style)
print(out)
self.assertIsInstance(out, ComponentOutput)

def test_tool_eval_invalid(self):
"""测试 tool 方法对无效请求的处理。"""
with self.assertRaises(TypeError):
result = self.com.tool_eval(name="image_understand", streaming=True,
origin_query="")


if __name__ == '__main__':
unittest.main()