Merge branch 'main' into feature/qianfan
better629 authored Aug 7, 2024
2 parents 76a0271 + 48de3f9 commit 42907e9
Showing 17 changed files with 142 additions and 28 deletions.
20 changes: 18 additions & 2 deletions metagpt/actions/action_node.py
@@ -237,12 +237,19 @@ def create_model_class(cls, class_name: str, mapping: Dict[str, Tuple[Type, Any]
         """Dynamically generate a pydantic v2 model class to validate result types."""

         def check_fields(cls, values):
-            required_fields = set(mapping.keys())
+            all_fields = set(mapping.keys())
+            required_fields = set()
+            for k, v in mapping.items():
+                type_v, field_info = v
+                if ActionNode.is_optional_type(type_v):
+                    continue
+                required_fields.add(k)
+
             missing_fields = required_fields - set(values.keys())
             if missing_fields:
                 raise ValueError(f"Missing fields: {missing_fields}")

-            unrecognized_fields = set(values.keys()) - required_fields
+            unrecognized_fields = set(values.keys()) - all_fields
             if unrecognized_fields:
                 logger.warning(f"Unrecognized fields: {unrecognized_fields}")
             return values
@@ -717,3 +724,12 @@ def from_pydantic(cls, model: Type[BaseModel], key: str = None):
             root_node.add_child(child_node)

         return root_node
+
+    @staticmethod
+    def is_optional_type(tp) -> bool:
+        """Return True if `tp` is `typing.Optional[...]`"""
+        if typing.get_origin(tp) is Union:
+            args = typing.get_args(tp)
+            non_none_types = [arg for arg in args if arg is not type(None)]
+            return len(non_none_types) == 1 and len(args) == 2
+        return False
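
Standalone, the new helper behaves as follows (a quick sanity check, not part of the commit):

    import typing
    from typing import Optional, Union

    def is_optional_type(tp) -> bool:
        """Return True if `tp` is `typing.Optional[...]` (i.e. Union[X, None])."""
        if typing.get_origin(tp) is Union:
            args = typing.get_args(tp)
            non_none_types = [arg for arg in args if arg is not type(None)]
            return len(non_none_types) == 1 and len(args) == 2
        return False

    assert is_optional_type(Optional[str])              # Optional[str] is Union[str, None]
    assert is_optional_type(Union[int, None])           # written-out form of the same thing
    assert not is_optional_type(Union[str, int, None])  # more than one non-None member
    assert not is_optional_type(str)                    # plain type, no Union origin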
2 changes: 1 addition & 1 deletion metagpt/actions/project_management_an.py
@@ -12,7 +12,7 @@
 REQUIRED_PACKAGES = ActionNode(
     key="Required packages",
     expected_type=Optional[List[str]],
-    instruction="Provide required packages in requirements.txt format.",
+    instruction="Provide required third-party packages in requirements.txt format.",
     example=["flask==1.1.2", "bcrypt==3.2.0"],
 )

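
Because "Required packages" is typed Optional, the action_node.py change above lets a generated model omit the field instead of raising a "Missing fields" error. A minimal sketch (single-field mapping chosen for illustration):

    from typing import List, Optional

    from metagpt.actions.action_node import ActionNode

    cls = ActionNode.create_model_class(
        "Tasks", {"Required packages": (Optional[List[str]], None)}
    )
    instance = cls()  # the Optional field may be absent; it simply defaults to None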
2 changes: 1 addition & 1 deletion metagpt/actions/write_code_an_draft.py
@@ -139,7 +139,7 @@ class Main {\
 end", "Anything UNCLEAR": "目前项目要求明确,没有不清楚的地方。"}
 ## Tasks
-{"Required packages": ["无需Python包"], "Required Other language third-party packages": ["vue.js"], "Logic Analysis": [["index.html", "作为游戏的入口文件和主要的HTML结构"], ["styles.css", "包含所有的CSS样式,确保游戏界面美观"], ["main.js", "包含Main类,负责初始化游戏和绑定事件"], ["game.js", "包含Game类,负责游戏逻辑,如开始游戏、移动方块等"], ["storage.js", "包含Storage类,用于获取和设置玩家的最高分"]], "Task list": ["index.html", "styles.css", "storage.js", "game.js", "main.js"], "Full API spec": "", "Shared Knowledge": "\'game.js\' 包含游戏逻辑相关的函数,被 \'main.js\' 调用。", "Anything UNCLEAR": "目前项目要求明确,没有不清楚的地方。"}
+{"Required packages": ["无需第三方包"], "Required Other language third-party packages": ["vue.js"], "Logic Analysis": [["index.html", "作为游戏的入口文件和主要的HTML结构"], ["styles.css", "包含所有的CSS样式,确保游戏界面美观"], ["main.js", "包含Main类,负责初始化游戏和绑定事件"], ["game.js", "包含Game类,负责游戏逻辑,如开始游戏、移动方块等"], ["storage.js", "包含Storage类,用于获取和设置玩家的最高分"]], "Task list": ["index.html", "styles.css", "storage.js", "game.js", "main.js"], "Full API spec": "", "Shared Knowledge": "\'game.js\' 包含游戏逻辑相关的函数,被 \'main.js\' 调用。", "Anything UNCLEAR": "目前项目要求明确,没有不清楚的地方。"}
 ## Code Files
 ----- index.html

(The example output now reads "无需第三方包" ("no third-party packages needed") instead of "无需Python包" ("no Python packages needed"), matching the revised instruction in project_management_an.py.)
2 changes: 1 addition & 1 deletion metagpt/configs/llm_config.py
@@ -33,7 +33,7 @@ class LLMType(Enum):
     YI = "yi"  # lingyiwanwu
     OPENROUTER = "openrouter"
     BEDROCK = "bedrock"
-    ARK = "ark"
+    ARK = "ark"  # https://www.volcengine.com/docs/82379/1263482#python-sdk

     def __missing__(self, key):
         return self.OPENAI
65 changes: 60 additions & 5 deletions metagpt/provider/ark_api.py
@@ -1,12 +1,33 @@
-from openai import AsyncStream
-from openai.types import CompletionUsage
-from openai.types.chat import ChatCompletion, ChatCompletionChunk
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+Provider for volcengine.
+See Also: https://console.volcengine.com/ark/region:ark+cn-beijing/model
+config2.yaml example:
+```yaml
+llm:
+  base_url: "https://ark.cn-beijing.volces.com/api/v3"
+  api_type: "ark"
+  endpoint: "ep-2024080514****-d****"
+  api_key: "d47****b-****-****-****-d6e****0fd77"
+  pricing_plan: "doubao-lite"
+```
+"""
+from typing import Optional, Union
+
+from pydantic import BaseModel
+from volcenginesdkarkruntime import AsyncArk
+from volcenginesdkarkruntime._base_client import AsyncHttpxClientWrapper
+from volcenginesdkarkruntime._streaming import AsyncStream
+from volcenginesdkarkruntime.types.chat import ChatCompletion, ChatCompletionChunk

 from metagpt.configs.llm_config import LLMType
 from metagpt.const import USE_CONFIG_TIMEOUT
 from metagpt.logs import log_llm_stream
 from metagpt.provider.llm_provider_registry import register_provider
 from metagpt.provider.openai_api import OpenAILLM
+from metagpt.utils.token_counter import DOUBAO_TOKEN_COSTS


 @register_provider(LLMType.ARK)
@@ -16,11 +37,45 @@ class ArkLLM(OpenAILLM):
     See: https://www.volcengine.com/docs/82379/1263482
     """

+    aclient: Optional[AsyncArk] = None
+
+    def _init_client(self):
+        """SDK: https://github.com/openai/openai-python#async-usage"""
+        self.model = (
+            self.config.endpoint or self.config.model
+        )  # endpoint name; see https://console.volcengine.com/ark/region:ark+cn-beijing/endpoint
+        self.pricing_plan = self.config.pricing_plan or self.model
+        kwargs = self._make_client_kwargs()
+        self.aclient = AsyncArk(**kwargs)
+
+    def _make_client_kwargs(self) -> dict:
+        kvs = {
+            "ak": self.config.access_key,
+            "sk": self.config.secret_key,
+            "api_key": self.config.api_key,
+            "base_url": self.config.base_url,
+        }
+        kwargs = {k: v for k, v in kvs.items() if v}
+
+        # to use a proxy, openai v1 needs an http_client
+        if proxy_params := self._get_proxy_params():
+            kwargs["http_client"] = AsyncHttpxClientWrapper(**proxy_params)
+
+        return kwargs
+
+    def _update_costs(self, usage: Union[dict, BaseModel], model: str = None, local_calc_usage: bool = True):
+        if next(iter(DOUBAO_TOKEN_COSTS)) not in self.cost_manager.token_costs:
+            self.cost_manager.token_costs.update(DOUBAO_TOKEN_COSTS)
+        if model in self.cost_manager.token_costs:
+            self.pricing_plan = model
+        if self.pricing_plan in self.cost_manager.token_costs:
+            super()._update_costs(usage, self.pricing_plan, local_calc_usage)
+
     async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str:
         response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
             **self._cons_kwargs(messages, timeout=self.get_timeout(timeout)),
             stream=True,
-            extra_body={"stream_options": {"include_usage": True}}  # usage is only returned at the end of the stream when this parameter is set
+            extra_body={"stream_options": {"include_usage": True}},  # usage is only returned at the end of the stream when this parameter is set
         )
         usage = None
         collected_messages = []
@@ -30,7 +85,7 @@ async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str:
             collected_messages.append(chunk_message)
             if chunk.usage:
                 # Volcengine Ark streaming returns usage in the final chunk, whose choices list is []
-                usage = CompletionUsage(**chunk.usage)
+                usage = chunk.usage

         log_llm_stream("\n")
         full_reply_content = "".join(collected_messages)
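
For reference, a minimal usage sketch for the new provider (assuming a config2.yaml like the one in the module docstring; Config.default() and aask() are MetaGPT's existing entry points):

    import asyncio

    from metagpt.config2 import Config
    from metagpt.provider.ark_api import ArkLLM

    async def main():
        config = Config.default()       # reads config2.yaml, including the llm: block above
        llm = ArkLLM(config.llm)        # _init_client() resolves endpoint and pricing_plan
        print(await llm.aask("Hello"))  # streams; usage arrives in the final chunk

    asyncio.run(main())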
3 changes: 2 additions & 1 deletion metagpt/roles/role.py
@@ -170,7 +170,8 @@ def _process_role_extra(self):
         self._check_actions()
         self.llm.system_prompt = self._get_prefix()
         self.llm.cost_manager = self.context.cost_manager
-        self._watch(kwargs.pop("watch", [UserRequirement]))
+        if not self.rc.watch:
+            self._watch(kwargs.pop("watch", [UserRequirement]))

         if self.latest_observed_msg:
             self.recovered = True
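
The new guard means a role whose watch set was restored from serialized state no longer has it overwritten by the [UserRequirement] default. A minimal sketch (assuming Role accepts watch as an extra kwarg, which the kwargs.pop above implies):

    from metagpt.actions import UserRequirement
    from metagpt.actions.fix_bug import FixBug
    from metagpt.roles.role import Role

    role = Role(watch=[FixBug, UserRequirement])
    restored = Role(**role.model_dump())       # rc.watch round-trips through serialization
    assert restored.rc.watch == role.rc.watch  # previously reset to the default on re-init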
4 changes: 2 additions & 2 deletions metagpt/utils/make_sk_kernel.py
@@ -21,12 +21,12 @@ def make_sk_kernel():
     if llm := config.get_azure_llm():
         kernel.add_chat_service(
             "chat_completion",
-            AzureChatCompletion(llm.model, llm.base_url, llm.api_key),
+            AzureChatCompletion(deployment_name=llm.model, base_url=llm.base_url, api_key=llm.api_key),
         )
     elif llm := config.get_openai_llm():
         kernel.add_chat_service(
             "chat_completion",
-            OpenAIChatCompletion(llm.model, llm.api_key),
+            OpenAIChatCompletion(ai_model_id=llm.model, api_key=llm.api_key),
         )

     return kernel
13 changes: 12 additions & 1 deletion metagpt/utils/token_counter.py
@@ -40,6 +40,7 @@
     "gpt-4-vision-preview": {"prompt": 0.01, "completion": 0.03},  # TODO add extra image price calculator
     "gpt-4-1106-vision-preview": {"prompt": 0.01, "completion": 0.03},
     "gpt-4o": {"prompt": 0.005, "completion": 0.015},
+    "gpt-4o-mini": {"prompt": 0.00015, "completion": 0.0006},
     "gpt-4o-2024-05-13": {"prompt": 0.005, "completion": 0.015},
     "text-embedding-ada-002": {"prompt": 0.0004, "completion": 0.0},
     "glm-3-turbo": {"prompt": 0.0007, "completion": 0.0007},  # 128k version, prompt + completion tokens=0.005¥/k-tokens
@@ -187,6 +188,14 @@
     "mixtral-8x7b": {"prompt": 0.4, "completion": 1.6},
 }

+# https://console.volcengine.com/ark/region:ark+cn-beijing/model
+DOUBAO_TOKEN_COSTS = {
+    "doubao-lite": {"prompt": 0.0003, "completion": 0.0006},
+    "doubao-lite-128k": {"prompt": 0.0008, "completion": 0.0010},
+    "doubao-pro": {"prompt": 0.0008, "completion": 0.0020},
+    "doubao-pro-128k": {"prompt": 0.0050, "completion": 0.0090},
+}
+
 # https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo
 TOKEN_MAX = {
     "gpt-4o-2024-05-13": 128000,
@@ -202,6 +211,7 @@
     "gpt-4-0613": 8192,
     "gpt-4-32k": 32768,
     "gpt-4-32k-0613": 32768,
+    "gpt-4o-mini": 128000,
     "gpt-3.5-turbo-0125": 16385,
     "gpt-3.5-turbo": 16385,
     "gpt-3.5-turbo-1106": 16385,
@@ -347,8 +357,9 @@ def count_input_tokens(messages, model="gpt-3.5-turbo-0125"):
         "gpt-4-turbo",
         "gpt-4-vision-preview",
         "gpt-4-1106-vision-preview",
-        "gpt-4o-2024-05-13",
         "gpt-4o",
+        "gpt-4o-2024-05-13",
+        "gpt-4o-mini",
     }:
         tokens_per_message = 3  # every reply is primed with <|start|>assistant<|message|>
         tokens_per_name = 1
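
MetaGPT's cost manager treats these as prices per 1,000 tokens, so a Doubao call's cost works out as in this illustrative sketch (token counts are made up; rates come from the table above):

    from metagpt.utils.token_counter import DOUBAO_TOKEN_COSTS

    prompt_tokens, completion_tokens = 1200, 300
    rates = DOUBAO_TOKEN_COSTS["doubao-lite"]  # {"prompt": 0.0003, "completion": 0.0006}
    cost = (prompt_tokens * rates["prompt"] + completion_tokens * rates["completion"]) / 1000
    print(round(cost, 6))  # 0.00054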
3 changes: 3 additions & 0 deletions requirements.txt
@@ -72,6 +72,9 @@ qianfan~=0.4.4
 dashscope~=1.19.3
 rank-bm25==0.2.2 # for tool recommendation
 jieba==0.42.1 # for tool recommendation
+volcengine-python-sdk[ark]~=1.0.94
+# llama-index-vector-stores-elasticsearch~=0.2.5 # Used by `metagpt/memory/longterm_memory.py`
+# llama-index-vector-stores-chroma~=0.1.10 # Used by `metagpt/memory/longterm_memory.py`
 gymnasium==0.29.1
 boto3~=1.34.69
 spark_ai_python~=0.3.30
19 changes: 16 additions & 3 deletions tests/metagpt/actions/test_action_node.py
@@ -6,7 +6,7 @@
 @File : test_action_node.py
 """
 from pathlib import Path
-from typing import List, Tuple
+from typing import List, Optional, Tuple

 import pytest
 from pydantic import BaseModel, Field, ValidationError
@@ -302,6 +302,19 @@ def test_action_node_from_pydantic_and_print_everything():
     assert "tasks" in code, "tasks should be in code"


+def test_optional():
+    mapping = {
+        "Logic Analysis": (Optional[List[Tuple[str, str]]], Field(default=None)),
+        "Task list": (Optional[List[str]], None),
+        "Plan": (Optional[str], ""),
+        "Anything UNCLEAR": (Optional[str], None),
+    }
+    m = {"Anything UNCLEAR": "a"}
+    t = ActionNode.create_model_class("test_class_1", mapping)
+
+    t1 = t(**m)
+    assert t1
+
+
 if __name__ == "__main__":
     test_create_model_class()
     test_create_model_class_with_mapping()
     pytest.main([__file__, "-s"])
9 changes: 8 additions & 1 deletion tests/metagpt/serialize_deserialize/test_environment.py
@@ -1,7 +1,7 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 # @Desc :

+import pytest

 from metagpt.actions.action_node import ActionNode
 from metagpt.actions.add_requirement import UserRequirement
@@ -55,6 +55,7 @@ def test_environment_serdeser(context):
     assert isinstance(list(environment.roles.values())[0].actions[0], ActionOK)
     assert type(list(new_env.roles.values())[0].actions[0]) == ActionOK
     assert type(list(new_env.roles.values())[0].actions[1]) == ActionRaise
+    assert list(new_env.roles.values())[0].rc.watch == role_c.rc.watch


 def test_environment_serdeser_v2(context):
@@ -69,6 +70,7 @@ def test_environment_serdeser_v2(context):
     assert isinstance(role, ProjectManager)
     assert isinstance(role.actions[0], WriteTasks)
     assert isinstance(list(new_env.roles.values())[0].actions[0], WriteTasks)
+    assert list(new_env.roles.values())[0].rc.watch == pm.rc.watch


 def test_environment_serdeser_save(context):
@@ -85,3 +87,8 @@ def test_environment_serdeser_save(context):
     new_env: Environment = Environment(**env_dict, context=context)
     assert len(new_env.roles) == 1
     assert type(list(new_env.roles.values())[0].actions[0]) == ActionOK
+    assert list(new_env.roles.values())[0].rc.watch == role_c.rc.watch
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-s"])
4 changes: 2 additions & 2 deletions tests/metagpt/serialize_deserialize/test_role.py
@@ -28,9 +28,9 @@

 def test_roles(context):
     role_a = RoleA()
-    assert len(role_a.rc.watch) == 1
+    assert len(role_a.rc.watch) == 2
     role_b = RoleB()
-    assert len(role_a.rc.watch) == 1
+    assert len(role_a.rc.watch) == 2
     assert len(role_b.rc.watch) == 1

     role_d = RoleD(actions=[ActionOK()])
8 changes: 4 additions & 4 deletions tests/metagpt/serialize_deserialize/test_serdeser_base.py
@@ -8,9 +8,9 @@

 from pydantic import BaseModel, Field

-from metagpt.actions import Action, ActionOutput
+from metagpt.actions import Action, ActionOutput, UserRequirement
 from metagpt.actions.action_node import ActionNode
-from metagpt.actions.add_requirement import UserRequirement
+from metagpt.actions.fix_bug import FixBug
 from metagpt.roles.role import Role, RoleReactMode

 serdeser_path = Path(__file__).absolute().parent.joinpath("..", "..", "data", "serdeser_storage")
@@ -68,7 +68,7 @@ class RoleA(Role):
     def __init__(self, **kwargs):
         super(RoleA, self).__init__(**kwargs)
         self.set_actions([ActionPass])
-        self._watch([UserRequirement])
+        self._watch([FixBug, UserRequirement])


 class RoleB(Role):
@@ -93,7 +93,7 @@ class RoleC(Role):
     def __init__(self, **kwargs):
         super(RoleC, self).__init__(**kwargs)
         self.set_actions([ActionOK, ActionRaise])
-        self._watch([UserRequirement])
+        self._watch([FixBug, UserRequirement])
         self.rc.react_mode = RoleReactMode.BY_ORDER
         self.rc.memory.ignore_id = True

4 changes: 4 additions & 0 deletions tests/metagpt/serialize_deserialize/test_sk_agent.py
@@ -15,3 +15,7 @@ async def test_sk_agent_serdeser():
     new_role = SkAgent(**ser_role_dict)
     assert new_role.name == "Sunshine"
     assert len(new_role.actions) == 1
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-s"])
4 changes: 4 additions & 0 deletions tests/metagpt/serialize_deserialize/test_write_code_review.py
@@ -29,3 +29,7 @@ def div(a: int, b: int = 0):

     assert new_action.name == "WriteCodeReview"
     await new_action.run()
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-s"])
4 changes: 2 additions & 2 deletions tests/metagpt/test_config.py
@@ -14,8 +14,8 @@
 def test_config_1():
     cfg = Config.default()
     llm = cfg.get_openai_llm()
-    assert llm is not None
-    assert llm.api_type == LLMType.OPENAI
+    if cfg.llm.api_type == LLMType.OPENAI:
+        assert llm is not None


 def test_config_from_dict():
4 changes: 2 additions & 2 deletions tests/metagpt/test_context.py
@@ -53,8 +53,8 @@ def test_context_1():
 def test_context_2():
     ctx = Context()
     llm = ctx.config.get_openai_llm()
-    assert llm is not None
-    assert llm.api_type == LLMType.OPENAI
+    if ctx.config.llm.api_type == LLMType.OPENAI:
+        assert llm is not None

     kwargs = ctx.kwargs
     assert kwargs is not None
