From 452df39a524747652994190cd9afbb16246f9aa0 Mon Sep 17 00:00:00 2001 From: Reinier van der Leer Date: Thu, 18 Apr 2024 21:35:11 +0200 Subject: [PATCH] feat(agent/core): Allow zero-argument instantiation of `OpenAIProvider` --- .../autogpt/core/resource/model_providers/openai.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py b/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py index cd01b496a0b8..b974e6e036d6 100644 --- a/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py +++ b/autogpts/autogpt/autogpt/core/resource/model_providers/openai.py @@ -319,15 +319,20 @@ class OpenAIProvider( _budget: ModelProviderBudget _configuration: OpenAIConfiguration + _credentials: OpenAICredentials def __init__( self, - settings: OpenAISettings, - logger: logging.Logger, + settings: Optional[OpenAISettings] = None, + logger: Optional[logging.Logger] = None, ): + if not settings: + settings = self.default_settings.copy(deep=True) + if not settings.credentials: + settings.credentials = OpenAICredentials.from_env() + self._settings = settings - assert settings.credentials, "Cannot create OpenAIProvider without credentials" self._configuration = settings.configuration self._credentials = settings.credentials self._budget = settings.budget @@ -343,7 +348,7 @@ def __init__( self._client = AsyncOpenAI(**self._credentials.get_api_access_kwargs()) - self._logger = logger + self._logger = logger or logging.getLogger(__name__) def get_token_limit(self, model_name: str) -> int: """Get the token limit for a given model."""