From 0334b8b71ccf985f90e51ba1608b3eb42d67fe59 Mon Sep 17 00:00:00 2001 From: DavdGao Date: Tue, 5 Nov 2024 19:04:28 +0800 Subject: [PATCH] [HOTFIX] Fix the error when using multiple dashscope api keys (#478) --- src/agentscope/_version.py | 2 +- src/agentscope/models/dashscope_model.py | 7 ++++--- tests/dashscope_test.py | 7 +++++++ 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/agentscope/_version.py b/src/agentscope/_version.py index 7121d7fc7..b850f527e 100644 --- a/src/agentscope/_version.py +++ b/src/agentscope/_version.py @@ -1,4 +1,4 @@ # -*- coding: utf-8 -*- """ Version of AgentScope.""" -__version__ = "0.1.0" +__version__ = "0.1.1.dev" diff --git a/src/agentscope/models/dashscope_model.py b/src/agentscope/models/dashscope_model.py index 02412e4df..6192a5d31 100644 --- a/src/agentscope/models/dashscope_model.py +++ b/src/agentscope/models/dashscope_model.py @@ -68,8 +68,6 @@ def __init__( self.generate_args = generate_args or {} self.api_key = api_key - if self.api_key: - dashscope.api_key = self.api_key self.max_length = None def format( @@ -245,7 +243,7 @@ def __call__( if stream: kwargs["incremental_output"] = True - response = dashscope.Generation.call(**kwargs) + response = dashscope.Generation.call(api_key=self.api_key, **kwargs) # step3: invoke llm api, record the invocation and update the monitor if stream: @@ -490,6 +488,7 @@ def __call__( response = dashscope.ImageSynthesis.call( model=self.model_name, prompt=prompt, + api_key=self.api_key, **kwargs, ) if response.status_code != HTTPStatus.OK: @@ -603,6 +602,7 @@ def __call__( response = dashscope.TextEmbedding.call( input=texts, model=self.model_name, + api_key=self.api_key, **kwargs, ) @@ -735,6 +735,7 @@ def __call__( response = dashscope.MultiModalConversation.call( model=self.model_name, messages=messages, + api_key=self.api_key, **kwargs, ) # Unhandled code path here diff --git a/tests/dashscope_test.py b/tests/dashscope_test.py index d877c6283..74d04e48f 100644 --- a/tests/dashscope_test.py +++ b/tests/dashscope_test.py @@ -67,6 +67,7 @@ def test_call_success(self, mock_generation_call: MagicMock) -> None: messages=messages, result_format="message", stream=False, + api_key="test_api_key", ) @patch("agentscope.models.dashscope_model.dashscope.Generation.call") @@ -102,6 +103,7 @@ def test_call_failure(self, mock_generation_call: MagicMock) -> None: messages=messages, result_format="message", stream=False, + api_key="test_api_key", ) def tearDown(self) -> None: @@ -194,6 +196,7 @@ def test_image_synthesis_wrapper_call_failure( model=self.model_name, prompt=prompt, n=1, # Assuming this is a default value used to call the API + api_key="test_api_key", ) def tearDown(self) -> None: @@ -235,6 +238,7 @@ def test_call_success(self, mock_call: MagicMock) -> None: mock_call.assert_called_once_with( input=texts, model=self.wrapper.model_name, + api_key="test_key", **self.wrapper.generate_args, ) @@ -267,6 +271,7 @@ def test_call_failure(self, mock_call: MagicMock) -> None: mock_call.assert_called_once_with( input=texts, model=self.wrapper.model_name, + api_key="test_key", **self.wrapper.generate_args, ) @@ -327,6 +332,7 @@ def test_call_success(self, mock_call: MagicMock) -> None: mock_call.assert_called_once_with( model=self.wrapper.model_name, messages=messages, + api_key="test_key", ) @patch( @@ -366,6 +372,7 @@ def test_call_failure(self, mock_call: MagicMock) -> None: mock_call.assert_called_once_with( model=self.wrapper.model_name, messages=messages, + api_key="test_key", ) def tearDown(self) -> None: