Skip to content

Commit

Permalink
Refactor BaseChatAgent to register and unregister context providers
Browse files Browse the repository at this point in the history
  • Loading branch information
KennyVaneetvelde committed Jun 19, 2024
1 parent 950ad55 commit 9fc8e49
Showing 1 changed file with 23 additions and 2 deletions.
25 changes: 23 additions & 2 deletions atomic_agents/agents/base_chat_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,6 @@ def _get_and_handle_response(self):
"""
return self.get_response(response_model=self.output_schema)


def _init_run(self, user_input: Type[BaseAgentIO]):
"""
Initializes the run with the given user input.
Expand All @@ -166,7 +165,7 @@ def _post_run(self, response):
response (Type[BaseModel]): The response from the chat agent.
"""
self.memory.add_message('assistant', str(response))

def get_context_provider(self, provider_name: str) -> Type[SystemPromptContextProviderBase]:
"""
Retrieves a context provider by name.
Expand All @@ -183,3 +182,25 @@ def get_context_provider(self, provider_name: str) -> Type[SystemPromptContextPr
if provider_name not in self.system_prompt_generator.system_prompt_info.context_providers:
raise KeyError(f"Context provider '{provider_name}' not found.")
return self.system_prompt_generator.system_prompt_info.context_providers[provider_name]

def register_context_provider(self, provider_name: str, provider: SystemPromptContextProviderBase):
"""
Registers a new context provider.
Args:
provider_name (str): The name of the context provider.
provider (SystemPromptContextProviderBase): The context provider instance.
"""
self.system_prompt_generator.system_prompt_info.context_providers[provider_name] = provider

def unregister_context_provider(self, provider_name: str):
"""
Unregisters an existing context provider.
Args:
provider_name (str): The name of the context provider to remove.
"""
if provider_name in self.system_prompt_generator.system_prompt_info.context_providers:
del self.system_prompt_generator.system_prompt_info.context_providers[provider_name]
else:
raise KeyError(f"Context provider '{provider_name}' not found.")

0 comments on commit 9fc8e49

Please sign in to comment.