From 6bb8e1dc414471b7f72d03dc2b6d312dc507f812 Mon Sep 17 00:00:00 2001
From: hghandri
Date: Fri, 17 Jan 2025 16:57:55 +0100
Subject: [PATCH] Add sources coming from the Knowledge Base and metadata
 capability for the Bedrock LLM and OpenAI agents

---
 .../multi_agent_orchestrator/agents/agent.py  |  3 +
 .../agents/anthropic_agent.py                 |  2 +-
 .../agents/bedrock_llm_agent.py               | 75 +++++++++++++------
 .../agents/openai_agent.py                    | 63 +++++++++++-----
 .../retrievers/amazon_kb_retriever.py         | 17 ++++-
 .../multi_agent_orchestrator/types/types.py   | 34 +++++----
 6 files changed, 134 insertions(+), 60 deletions(-)

diff --git a/python/src/multi_agent_orchestrator/agents/agent.py b/python/src/multi_agent_orchestrator/agents/agent.py
index 39428927..e4bb46b0 100644
--- a/python/src/multi_agent_orchestrator/agents/agent.py
+++ b/python/src/multi_agent_orchestrator/agents/agent.py
@@ -27,6 +27,9 @@ def on_llm_new_token(self, message: ConversationMessage) -> None:
         # Default implementation
         pass
 
+    def on_llm_end(self, message: ConversationMessage) -> None:
+        # Default implementation
+        pass
 
 @dataclass
 class AgentOptions:
diff --git a/python/src/multi_agent_orchestrator/agents/anthropic_agent.py b/python/src/multi_agent_orchestrator/agents/anthropic_agent.py
index e2ca2c14..20125a19 100644
--- a/python/src/multi_agent_orchestrator/agents/anthropic_agent.py
+++ b/python/src/multi_agent_orchestrator/agents/anthropic_agent.py
@@ -118,7 +118,7 @@ async def process_request(
 
         if self.retriever:
             response = await self.retriever.retrieve_and_combine_results(input_text)
-            context_prompt = f"\nHere is the context to use to answer the user's question:\n{response}"
+            context_prompt = f"\nHere is the context to use to answer the user's question:\n{response['text']}"
             system_prompt += context_prompt
 
         input = {
diff --git a/python/src/multi_agent_orchestrator/agents/bedrock_llm_agent.py b/python/src/multi_agent_orchestrator/agents/bedrock_llm_agent.py
index fc797977..4627ac92 100644
--- a/python/src/multi_agent_orchestrator/agents/bedrock_llm_agent.py
+++ b/python/src/multi_agent_orchestrator/agents/bedrock_llm_agent.py
@@ -5,8 +5,7 @@
 import os
 import boto3
 from multi_agent_orchestrator.agents import Agent, AgentOptions
-from multi_agent_orchestrator.types import (ConversationMessage,
-                                            ConversationMessageMetadata,
+from multi_agent_orchestrator.types import (ConversationMessage, ConversationMessageMetadata,
                                             ParticipantRole,
                                             BEDROCK_MODEL_ID_CLAUDE_3_HAIKU,
                                             TemplateVariables,
@@ -14,6 +13,7 @@
 from multi_agent_orchestrator.utils import conversation_to_dict, Logger, AgentTools
 from multi_agent_orchestrator.retrievers import Retriever
 
+import traceback
 
 @dataclass
 class BedrockLLMAgentOptions(AgentOptions):
@@ -115,11 +115,13 @@ async def process_request(
 
         self.update_system_prompt()
         system_prompt = self.system_prompt
+        citations = []
 
         if self.retriever:
            response = await self.retriever.retrieve_and_combine_results(input_text)
-            context_prompt = "\nHere is the context to use to answer the user's question:\n" + response
+            context_prompt = "\nHere is the context to use to answer the user's question:\n" + response['text']
             system_prompt += context_prompt
+            citations = response['sources']
 
         converse_cmd = {
             'modelId': self.model_id,
@@ -152,6 +154,17 @@ async def process_request(
         else:
             bedrock_response = await self.handle_single_response(converse_cmd)
 
+        if citations:
+            if not bedrock_response.metadata:
+                bedrock_response.metadata = ConversationMessageMetadata()
+
+            bedrock_response.metadata.citations.extend(citations)
+
+        if self.streaming:
+            self.callbacks.on_llm_end(
+                bedrock_response
+            )
+
         conversation.append(bedrock_response)
 
         if any('toolUse' in content for content in bedrock_response.content):
@@ -173,17 +186,31 @@ async def process_request(
             return final_message
 
         if self.streaming:
-            return await self.handle_streaming_response(converse_cmd)
+            converse_message = await self.handle_streaming_response(converse_cmd)
+        else:
+            converse_message = await self.handle_single_response(converse_cmd)
+
+        if citations:
+            if not converse_message.metadata:
+                converse_message.metadata = ConversationMessageMetadata()
+
+            converse_message.metadata.citations.extend(citations)
+
+        if self.streaming:
+            self.callbacks.on_llm_end(
+                converse_message
+            )
 
-        return await self.handle_single_response(converse_cmd)
+        return converse_message
 
     async def handle_single_response(self, converse_input: dict[str, Any]) -> ConversationMessage:
         try:
             response = self.client.converse(**converse_input)
             if 'output' not in response:
                 raise ValueError("No output received from Bedrock model")
+
             return ConversationMessage(
-                role=response['output']['message']['role'],
+                role=ParticipantRole.ASSISTANT.value,
                 content=response['output']['message']['content'],
                 metadata=ConversationMessageMetadata({
                     'usage': response['usage'],
@@ -201,31 +228,36 @@ async def handle_streaming_response(self, converse_input: dict[str, Any]) -> ConversationMessage:
             message = {}
             content = []
             message['content'] = content
+            message['metadata'] = None
             text = ''
             tool_use = {}
-            metadata: Optional[ConversationMessageMetadata] = None
 
             #stream the response into a message.
             for chunk in response['stream']:
+
                 if 'messageStart' in chunk:
                     message['role'] = chunk['messageStart']['role']
+
                 elif 'contentBlockStart' in chunk:
                     tool = chunk['contentBlockStart']['start']['toolUse']
                     tool_use['toolUseId'] = tool['toolUseId']
                     tool_use['name'] = tool['name']
+
                 elif 'contentBlockDelta' in chunk:
                     delta = chunk['contentBlockDelta']['delta']
+
                     if 'toolUse' in delta:
                         if 'input' not in tool_use:
                             tool_use['input'] = ''
                         tool_use['input'] += delta['toolUse']['input']
+
                     elif 'text' in delta:
                         text += delta['text']
                         self.callbacks.on_llm_new_token(
-                            ConversationMessage(
-                                role=ParticipantRole.ASSISTANT.value,
-                                content=[{'text': delta['text']}]
-                            )
+                            ConversationMessage(
+                                role=ParticipantRole.ASSISTANT.value,
+                                content=delta['text']
+                            )
                         )
                 elif 'contentBlockStop' in chunk:
                     if 'input' in tool_use:
@@ -237,24 +269,17 @@ async def handle_streaming_response(self, converse_input: dict[str, Any]) -> ConversationMessage:
                         text = ''
 
                 elif 'metadata' in chunk:
-                    metadata = {
-                        'usage': chunk['metadata']['usage'],
-                        'metrics': chunk['metadata']['metrics']
-                    }
-
-                    self.callbacks.on_llm_new_token(
-                        ConversationMessage(
-                            role=ParticipantRole.ASSISTANT.value,
-                            metadata=ConversationMessageMetadata(**metadata)
-                        )
+
+                    message['metadata'] = ConversationMessageMetadata(
+                        usage=chunk['metadata']['usage'],
+                        metrics=chunk['metadata']['metrics']
                     )
 
             return ConversationMessage(
-                role=ParticipantRole.ASSISTANT.value,
-                content=message['content'],
-                metadata=ConversationMessageMetadata(**metadata)
+                **message
             )
 
         except Exception as error:
+            Logger.error(traceback.format_exc())
             Logger.error(f"Error getting stream from Bedrock model: {str(error)}")
             raise error
diff --git a/python/src/multi_agent_orchestrator/agents/openai_agent.py b/python/src/multi_agent_orchestrator/agents/openai_agent.py
index be72d1c4..704a7c5f 100644
--- a/python/src/multi_agent_orchestrator/agents/openai_agent.py
+++ b/python/src/multi_agent_orchestrator/agents/openai_agent.py
@@ -4,6 +4,7 @@
 from multi_agent_orchestrator.agents import Agent, AgentOptions
 from multi_agent_orchestrator.types import (
     ConversationMessage,
+    ConversationMessageMetadata,
     ParticipantRole,
     OPENAI_MODEL_ID_GPT_O_MINI,
     TemplateVariables
@@ -28,15 +29,15 @@ class OpenAIAgentOptions(AgentOptions):
 class OpenAIAgent(Agent):
     def __init__(self, options: OpenAIAgentOptions):
         super().__init__(options)
-        if not options.api_key:
-            raise ValueError("OpenAI API key is required")
-
+
         if options.client:
             self.client = options.client
         else:
+            if not options.api_key:
+                raise ValueError("OpenAI API key is required")
+
             self.client = OpenAI(api_key=options.api_key)
 
-
         self.model = options.model or OPENAI_MODEL_ID_GPT_O_MINI
         self.streaming = options.streaming or False
         self.retriever: Optional[Retriever] = options.retriever
@@ -83,7 +84,7 @@ def __init__(self, options: OpenAIAgentOptions):
                 options.custom_system_prompt.get('template'),
                 options.custom_system_prompt.get('variables')
             )
-
+
 
 
     def is_streaming_enabled(self) -> bool:
@@ -102,11 +103,13 @@ async def process_request(
 
             self.update_system_prompt()
             system_prompt = self.system_prompt
+            citations = None
 
             if self.retriever:
                 response = await self.retriever.retrieve_and_combine_results(input_text)
-                context_prompt = "\nHere is the context to use to answer the user's question:\n" + response
+                context_prompt = "\nHere is the context to use to answer the user's question:\n" + response['text']
                 system_prompt += context_prompt
+                citations = response['sources']
 
             messages = [
@@ -118,7 +121,6 @@ async def process_request(
                 {"role": "user", "content": input_text}
             ]
 
-
             request_options = {
                 "model": self.model,
                 "messages": messages,
@@ -128,10 +130,24 @@ async def process_request(
                 "stop": self.inference_config.get('stopSequences'),
                 "stream": self.streaming
             }
+
             if self.streaming:
-                return await self.handle_streaming_response(request_options)
+                converse_message = await self.handle_streaming_response(request_options)
             else:
-                return await self.handle_single_response(request_options)
+                converse_message = await self.handle_single_response(request_options)
+
+            if citations:
+                if not converse_message.metadata:
+                    converse_message.metadata = ConversationMessageMetadata()
+
+                converse_message.metadata.citations.extend(citations)
+
+            if self.streaming:
+                self.callbacks.on_llm_end(
+                    converse_message
+                )
+
+            return converse_message
 
         except Exception as error:
             Logger.error(f"Error in OpenAI API call: {str(error)}")
@@ -152,7 +168,11 @@ async def handle_single_response(self, request_options: Dict[str, Any]) -> ConversationMessage:
 
             return ConversationMessage(
                 role=ParticipantRole.ASSISTANT.value,
-                content=[{"text": assistant_message}]
+                content=[{"text": assistant_message}],
+                metadata=ConversationMessageMetadata(
+                    citations=getattr(chat_completion, 'citations', None) or [],
+                    usage=chat_completion.usage
+                )
             )
 
         except Exception as error:
@@ -163,24 +183,35 @@ async def handle_streaming_response(self, request_options: Dict[str, Any]) -> ConversationMessage:
         try:
             stream = self.client.chat.completions.create(**request_options)
             accumulated_message = []
-
+            # Default metadata in case the stream yields no content chunks
+            metadata = {'citations': [], 'usage': None}
+
             for chunk in stream:
                 if chunk.choices[0].delta.content:
+
+                    metadata = {
+                        'citations': getattr(chunk, 'citations', None) or [],
+                        'usage': getattr(chunk, 'usage', None)
+                    }
+
                     chunk_content = chunk.choices[0].delta.content
                     accumulated_message.append(chunk_content)
+
                     if self.callbacks:
                         self.callbacks.on_llm_new_token(
-                            ConversationMessage(
-                                role=ParticipantRole.ASSISTANT.value,
-                                content=[{'text': chunk_content}]
-                            )
+                            ConversationMessage(
+                                role=ParticipantRole.ASSISTANT.value,
+                                content=chunk_content,
+                                metadata=ConversationMessageMetadata(**metadata)
+                            )
                         )
                     #yield chunk_content
 
             # Store the complete message in the instance for later access if needed
             return ConversationMessage(
-                role=ParticipantRole.ASSISTANT.value,
-                content=[{"text": ''.join(accumulated_message)}]
+                role=ParticipantRole.ASSISTANT.value,
+                content=[{"text": ''.join(accumulated_message)}],
+                metadata=ConversationMessageMetadata(**metadata)
             )
 
         except Exception as error:
diff --git a/python/src/multi_agent_orchestrator/retrievers/amazon_kb_retriever.py b/python/src/multi_agent_orchestrator/retrievers/amazon_kb_retriever.py
index f7530730..29dfa42d 100644
--- a/python/src/multi_agent_orchestrator/retrievers/amazon_kb_retriever.py
+++ b/python/src/multi_agent_orchestrator/retrievers/amazon_kb_retriever.py
@@ -48,8 +48,21 @@ async def retrieve_and_combine_results(self, text, knowledge_base_id=None, retri
 
     @staticmethod
     def combine_retrieval_results(retrieval_results):
-        return "\n".join(
+        sources = []
+
+        sources.extend(
+            set(result['metadata']['x-amz-bedrock-kb-source-uri']
+                for result in retrieval_results
+                if result and result.get('metadata') and isinstance(result['metadata'].get('x-amz-bedrock-kb-source-uri'), str))
+        )
+
+        text = "\n".join(
             result['content']['text']
             for result in retrieval_results
             if result and result.get('content') and isinstance(result['content'].get('text'), str)
-        )
\ No newline at end of file
+        )
+
+        return {
+            'text': text,
+            'sources': sources
+        }
\ No newline at end of file
diff --git a/python/src/multi_agent_orchestrator/types/types.py b/python/src/multi_agent_orchestrator/types/types.py
index 15d98fb3..3d38b139 100644
--- a/python/src/multi_agent_orchestrator/types/types.py
+++ b/python/src/multi_agent_orchestrator/types/types.py
@@ -1,6 +1,6 @@
 from enum import Enum
 from typing import List, Dict, Union, TypedDict, Optional, Any
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 import time
 
 # Constants
@@ -39,6 +39,7 @@ class ParticipantRole(Enum):
     ASSISTANT = "assistant"
     USER = "user"
 
+
 class UsageMetrics(TypedDict):
     inputTokens: int
     outputTokens: int
@@ -47,27 +48,28 @@
 class PerformanceMetrics(TypedDict):
     latencyMs: int
 
-class ConversationMessageMetadata(TypedDict):
-    citations: List[str]
-    usage: Optional[UsageMetrics]
-    metrics: Optional[PerformanceMetrics]
+@dataclass
+class ConversationMessageMetadata:
+    citations: List[str] = field(default_factory=list)
+    usage: Optional[UsageMetrics] = None
+    metrics: Optional[PerformanceMetrics] = None
+
 
 class ConversationMessage:
     role: ParticipantRole
     content: List[Any]
     metadata: ConversationMessageMetadata
 
-    def __init__(self,
-                 role: ParticipantRole,
-                 content: Optional[List[Any]] = None,
-                 metadata: Optional[ConversationMessageMetadata] = None):
+    def __init__(
+        self,
+        role: ParticipantRole,
+        content: Optional[List[Any]] = None,
+        metadata: Optional[ConversationMessageMetadata] = None
+    ):
         self.role = role
         self.content = content or []
-        self.metadata = metadata or {
-            'citations': [],
-            'usage': None,
-            'metrics': None
-        }
+        self.metadata = metadata
+
 
 class TimestampedMessage(ConversationMessage):
     def __init__(self,
@@ -92,4 +94,4 @@ class OrchestratorConfig:
     NO_SELECTED_AGENT_MESSAGE: str = "I'm sorry, I couldn't determine how to handle your request.\
         Could you please rephrase it?"  # pylint: disable=invalid-name
     GENERAL_ROUTING_ERROR_MSG_MESSAGE: str = None
-    MAX_MESSAGE_PAIRS_PER_AGENT: int = 100  # pylint: disable=invalid-name
+    MAX_MESSAGE_PAIRS_PER_AGENT: int = 100  # pylint: disable=invalid-name
\ No newline at end of file
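
A minimal usage sketch of the new citation metadata, assuming the diff above applies as written. It only uses names introduced or changed by this patch (ConversationMessage, ConversationMessageMetadata, ParticipantRole, and the {'text', 'sources'} dict now returned by combine_retrieval_results); the retrieval contents and the S3 URI are invented for illustration, and nothing here calls Bedrock or OpenAI.

    from multi_agent_orchestrator.types import (
        ConversationMessage,
        ConversationMessageMetadata,
        ParticipantRole,
    )

    # Shape returned by combine_retrieval_results() after this patch (values made up).
    retrieval = {
        'text': "Paris is the capital of France.",
        'sources': ["s3://example-kb-bucket/geography/france.md"],
    }

    # An assistant reply, as handle_single_response() / handle_streaming_response() build it.
    response = ConversationMessage(
        role=ParticipantRole.ASSISTANT.value,
        content=[{'text': "The capital of France is Paris."}],
    )

    # The same merging logic the agents now apply at the end of process_request().
    if retrieval['sources']:
        if not response.metadata:
            response.metadata = ConversationMessageMetadata()
        response.metadata.citations.extend(retrieval['sources'])

    print(response.metadata.citations)  # ['s3://example-kb-bucket/geography/france.md']

When streaming, the same merged message is also handed to the new on_llm_end callback, so a caller can render citations once the stream has finished.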