From 9fc8e4933f3be86ba97b80c5091278caa34fbb1b Mon Sep 17 00:00:00 2001 From: KennyVaneetvelde Date: Wed, 19 Jun 2024 22:02:12 +0200 Subject: [PATCH] Refactor BaseChatAgent to register and unregister context providers --- atomic_agents/agents/base_chat_agent.py | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/atomic_agents/agents/base_chat_agent.py b/atomic_agents/agents/base_chat_agent.py index 822982d..b7a1f03 100644 --- a/atomic_agents/agents/base_chat_agent.py +++ b/atomic_agents/agents/base_chat_agent.py @@ -141,7 +141,6 @@ def _get_and_handle_response(self): """ return self.get_response(response_model=self.output_schema) - def _init_run(self, user_input: Type[BaseAgentIO]): """ Initializes the run with the given user input. @@ -166,7 +165,7 @@ def _post_run(self, response): response (Type[BaseModel]): The response from the chat agent. """ self.memory.add_message('assistant', str(response)) - + def get_context_provider(self, provider_name: str) -> Type[SystemPromptContextProviderBase]: """ Retrieves a context provider by name. @@ -183,3 +182,25 @@ def get_context_provider(self, provider_name: str) -> Type[SystemPromptContextPr if provider_name not in self.system_prompt_generator.system_prompt_info.context_providers: raise KeyError(f"Context provider '{provider_name}' not found.") return self.system_prompt_generator.system_prompt_info.context_providers[provider_name] + + def register_context_provider(self, provider_name: str, provider: SystemPromptContextProviderBase): + """ + Registers a new context provider. + + Args: + provider_name (str): The name of the context provider. + provider (SystemPromptContextProviderBase): The context provider instance. + """ + self.system_prompt_generator.system_prompt_info.context_providers[provider_name] = provider + + def unregister_context_provider(self, provider_name: str): + """ + Unregisters an existing context provider. + + Args: + provider_name (str): The name of the context provider to remove. + """ + if provider_name in self.system_prompt_generator.system_prompt_info.context_providers: + del self.system_prompt_generator.system_prompt_info.context_providers[provider_name] + else: + raise KeyError(f"Context provider '{provider_name}' not found.") \ No newline at end of file