Skip to content

Commit

Permalink
Add support for bedrock remaining attributes.
Browse files Browse the repository at this point in the history
  • Loading branch information
zzhlogin committed Jun 13, 2024
1 parent bbc9c00 commit d583b59
Show file tree
Hide file tree
Showing 5 changed files with 175 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,4 @@
AWS_BEDROCK_KNOWLEDGEBASE_ID: str = "aws.bedrock.knowledgebase_id"
AWS_BEDROCK_AGENT_ID: str = "aws.bedrock.agent_id"
AWS_BEDROCK_MODEL_ID: str = "aws.bedrock.model_id"
AWS_BEDROCK_GAURDRAIL_ID: str = "aws.bedrock.guardrail_id"

AWS_BEDROCK_GUARDRAIL_ID: str = "aws.bedrock.guardrail_id"
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@
from amazon.opentelemetry.distro._aws_attribute_keys import (
AWS_BEDROCK_AGENT_ID,
AWS_BEDROCK_DATASOURCE_ID,
AWS_BEDROCK_GAURDRAIL_ID,
AWS_BEDROCK_GUARDRAIL_ID,
AWS_BEDROCK_KNOWLEDGEBASE_ID,
AWS_BEDROCK_MODEL_ID,
AWS_LOCAL_OPERATION,
AWS_LOCAL_SERVICE,
AWS_QUEUE_NAME,
Expand Down Expand Up @@ -297,6 +296,8 @@ def _normalize_remote_service_name(span: ReadableSpan, service_name: str) -> str
"""
if is_aws_sdk_span(span):
aws_sdk_service_mapping = {
"Bedrock Agent": _NORMALIZED_BEDROCK_SERVICE_NAME,
"Bedrock Agent Runtime": _NORMALIZED_BEDROCK_SERVICE_NAME,
}
return aws_sdk_service_mapping.get(service_name, "AWS::" + service_name)
return service_name
Expand Down Expand Up @@ -380,18 +381,15 @@ def _set_remote_type_and_identifier(span: ReadableSpan, attributes: BoundedAttri
remote_resource_identifier = _escape_delimiters(
SqsUrlParser.get_queue_name(span.attributes.get(AWS_QUEUE_URL))
)
elif is_key_present(span, AWS_BEDROCK_MODEL_ID):
remote_resource_type = _NORMALIZED_BEDROCK_SERVICE_NAME + "::Model"
remote_resource_identifier = _escape_delimiters(span.attributes.get(AWS_BEDROCK_MODEL_ID))
elif is_key_present(span, AWS_BEDROCK_AGENT_ID):
remote_resource_type = _NORMALIZED_BEDROCK_SERVICE_NAME + "::Agent"
remote_resource_identifier = _escape_delimiters(span.attributes.get(AWS_BEDROCK_AGENT_ID))
elif is_key_present(span, AWS_BEDROCK_DATASOURCE_ID):
remote_resource_type = _NORMALIZED_BEDROCK_SERVICE_NAME + "::DataSource"
remote_resource_identifier = _escape_delimiters(span.attributes.get(AWS_BEDROCK_DATASOURCE_ID))
elif is_key_present(span, AWS_BEDROCK_GAURDRAIL_ID):
elif is_key_present(span, AWS_BEDROCK_GUARDRAIL_ID):
remote_resource_type = _NORMALIZED_BEDROCK_SERVICE_NAME + "::Guardrail"
remote_resource_identifier = _escape_delimiters(span.attributes.get(AWS_BEDROCK_GAURDRAIL_ID))
remote_resource_identifier = _escape_delimiters(span.attributes.get(AWS_BEDROCK_GUARDRAIL_ID))
elif is_key_present(span, AWS_BEDROCK_KNOWLEDGEBASE_ID):
remote_resource_type = _NORMALIZED_BEDROCK_SERVICE_NAME + "::KnowledgeBase"
remote_resource_identifier = _escape_delimiters(span.attributes.get(AWS_BEDROCK_KNOWLEDGEBASE_ID))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Modifications Copyright The OpenTelemetry Authors. Licensed under the Apache License 2.0 License.
import abc
import inspect
from typing import Any, Dict, Optional
from typing import Dict, Optional

from opentelemetry.instrumentation.botocore.extensions.types import (
_AttributeMapT,
Expand Down Expand Up @@ -31,14 +31,13 @@ class _AgentOperation(_BedrockAgentOperation):
# UpdateAgentKnowledgeBase -> KnowledgeBaseId
# GetAgentKnowledgeBase -> KnowledgeBaseId
start_attributes = {
"aws.bedrock.agent_id": "AgentId",
"aws.bedrock.agent_id": "agentId",
}
response_attributes = {
"aws.bedrock.agent_id": "AgentId",
"aws.bedrock.agent_id": "agentId",
}

@classmethod
@abc.abstractmethod
def operation_names(cls):
return [
"CreateAgentActionGroup",
Expand Down Expand Up @@ -71,7 +70,7 @@ class _KnowledgeBaseOperation(_BedrockAgentOperation):
# ListIngestionJobs -> not support
# StartIngestionJob -> not support
start_attributes = {
"aws.bedrock.knowledgebase_id": "KnowledgeBaseId",
"aws.bedrock.knowledgebase_id": "knowledgeBaseId",
}
response_attributes = {}

Expand All @@ -95,10 +94,10 @@ class _DataSourceOperation(_BedrockAgentOperation):
# ListIngestionJobs -> not support
# StartIngestionJob -> not support
start_attributes = {
"aws.bedrock.datasource_id": "DataSourceId",
"aws.bedrock.datasource_id": "dataSourceId",
}
response_attributes = {
"aws.bedrock.datasource_id": "DataSourceId",
"aws.bedrock.datasource_id": "dataSourceId",
}

@classmethod
Expand All @@ -108,7 +107,7 @@ def operation_names(cls):

_OPERATION_MAPPING = {
op_name: op_class
for op_class in [_KnowledgeBaseOperation, _DataSourceOperation]
for op_class in [_KnowledgeBaseOperation, _DataSourceOperation, _AgentOperation]
for op_name in op_class.operation_names()
if inspect.isclass(op_class) and issubclass(op_class, _BedrockAgentOperation) and not inspect.isabstract(op_class)
}
Expand Down Expand Up @@ -136,34 +135,28 @@ def on_success(self, span: Span, result: _BotoResultT):
if response_value:
span.set_attribute(
key,
value,
response_value,
)


class _BedrockAgentRuntimeExtension(_AwsSdkExtension): # -> AgentId, KnowledgebaseId -> no overlap
def extract_attributes(self, attributes: _AttributeMapT):
# AgentId, KnowledgebaseId
agent_id = self._call_context.params.get("AgentId")
agent_id = self._call_context.params.get("agentId")
if agent_id:
attributes["aws.bedrock.agent_id"] = agent_id

knowledgebase_id = self._call_context.params.get("KnowledgeBaseId")
knowledgebase_id = self._call_context.params.get("knowledgeBaseId")
if knowledgebase_id:
attributes["aws.bedrock.knowledgebase_id"] = knowledgebase_id


class _BedrockExtension(_AwsSdkExtension): # -> ModelId, GaurdrailId -> no overlap
def extract_attributes(self, attributes: _AttributeMapT):
# ModelId
model_id = self._call_context.params.get("ModelId")
if model_id:
attributes["aws.bedrock.model_id"] = model_id

def on_success(self, span: Span, result: _BotoResultT):
# GuardrailId
gaurdrail_id = result.get("guardrailId")
if gaurdrail_id:
guardrail_id = result.get("guardrailId")
if guardrail_id:
span.set_attribute(
"aws.bedrock.guardrail_id",
gaurdrail_id,
guardrail_id,
)
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
from unittest.mock import MagicMock

from amazon.opentelemetry.distro._aws_attribute_keys import (
AWS_BEDROCK_AGENT_ID,
AWS_BEDROCK_DATASOURCE_ID,
AWS_BEDROCK_GUARDRAIL_ID,
AWS_BEDROCK_KNOWLEDGEBASE_ID,
AWS_CONSUMER_PARENT_SPAN_KIND,
AWS_LOCAL_OPERATION,
AWS_LOCAL_SERVICE,
Expand Down Expand Up @@ -821,6 +825,9 @@ def test_normalize_remote_service_name_aws_sdk(self):
self.validate_aws_sdk_service_normalization("Kinesis", "AWS::Kinesis")
self.validate_aws_sdk_service_normalization("S3", "AWS::S3")
self.validate_aws_sdk_service_normalization("SQS", "AWS::SQS")
self.validate_aws_sdk_service_normalization("Bedrock", "AWS::Bedrock")
self.validate_aws_sdk_service_normalization("Bedrock Agent", "AWS::Bedrock")
self.validate_aws_sdk_service_normalization("Bedrock Agent Runtime", "AWS::Bedrock")

def validate_aws_sdk_service_normalization(self, service_name: str, expected_remote_service: str):
self._mock_attribute([SpanAttributes.RPC_SYSTEM, SpanAttributes.RPC_SERVICE], ["aws-api", service_name])
Expand Down Expand Up @@ -977,6 +984,26 @@ def test_sdk_client_span_with_remote_resource_attributes(self):
self._validate_remote_resource_attributes("AWS::DynamoDB::Table", "aws_table^^name")
self._mock_attribute([SpanAttributes.AWS_DYNAMODB_TABLE_NAMES], [None])

# Validate behaviour of AWS_BEDROCK_AGENT_ID attribute with special chars(^), then remove it.
self._mock_attribute([AWS_BEDROCK_AGENT_ID], ["test_agent_id"], keys, values)
self._validate_remote_resource_attributes("AWS::Bedrock::Agent", "test_agent_id")
self._mock_attribute([AWS_BEDROCK_AGENT_ID], [None])

# Validate behaviour of AWS_BEDROCK_DATASOURCE_ID attribute with special chars(^), then remove it.
self._mock_attribute([AWS_BEDROCK_DATASOURCE_ID], ["test_datasource_id"], keys, values)
self._validate_remote_resource_attributes("AWS::Bedrock::DataSource", "test_datasource_id")
self._mock_attribute([AWS_BEDROCK_DATASOURCE_ID], [None])

# Validate behaviour of AWS_BEDROCK_GUARDRAIL_ID attribute with special chars(^), then remove it.
self._mock_attribute([AWS_BEDROCK_GUARDRAIL_ID], ["test_guardrail_id"], keys, values)
self._validate_remote_resource_attributes("AWS::Bedrock::Guardrail", "test_guardrail_id")
self._mock_attribute([AWS_BEDROCK_GUARDRAIL_ID], [None])

# Validate behaviour of AWS_BEDROCK_KNOWLEDGEBASE_ID attribute with special chars(^), then remove it.
self._mock_attribute([AWS_BEDROCK_KNOWLEDGEBASE_ID], ["test_knowledgeBase_id"], keys, values)
self._validate_remote_resource_attributes("AWS::Bedrock::KnowledgeBase", "test_knowledgeBase_id")
self._mock_attribute([AWS_BEDROCK_KNOWLEDGEBASE_ID], [None])

self._mock_attribute([SpanAttributes.RPC_SYSTEM], [None])

def test_client_db_span_with_remote_resource_attributes(self):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Dict
from typing import Any, Dict
from unittest import TestCase
from unittest.mock import MagicMock, patch

Expand All @@ -9,11 +9,17 @@
from amazon.opentelemetry.distro.patches._instrumentation_patch import apply_instrumentation_patches
from opentelemetry.instrumentation.botocore.extensions import _KNOWN_EXTENSIONS
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.trace.span import Span

_STREAM_NAME: str = "streamName"
_BUCKET_NAME: str = "bucketName"
_QUEUE_NAME: str = "queueName"
_QUEUE_URL: str = "queueUrl"
_BEDROCK_MODEL_ID: str = "modelId"
_BEDROCK_AGENT_ID: str = "agentId"
_BEDROCK_DATASOURCE_ID: str = "DataSourceId"
_BEDROCK_GUARDRAIL_ID: str = "GuardrailId"
_BEDROCK_KNOWLEDGEBASE_ID: str = "KnowledgeBaseId"


class TestInstrumentationPatch(TestCase):
Expand Down Expand Up @@ -74,6 +80,17 @@ def _validate_unpatched_botocore_instrumentation(self):
self.assertFalse("aws.sqs.queue_url" in attributes)
self.assertFalse("aws.sqs.queue_name" in attributes)

# Bedrock
self.assertFalse("bedrock" in _KNOWN_EXTENSIONS, "Upstream has added a Bedrock extension")

# Bedrock Agent
self.assertFalse("bedrock-agent" in _KNOWN_EXTENSIONS, "Upstream has added a Bedrock Agent extension")

# Bedrock Agent Runtime
self.assertFalse(
"bedrock-agent-runtime" in _KNOWN_EXTENSIONS, "Upstream has added a Bedrock Agent Runtime extension"
)

def _validate_patched_botocore_instrumentation(self):
# Kinesis
self.assertTrue("kinesis" in _KNOWN_EXTENSIONS)
Expand All @@ -96,6 +113,43 @@ def _validate_patched_botocore_instrumentation(self):
self.assertTrue("aws.sqs.queue_name" in sqs_attributes)
self.assertEqual(sqs_attributes["aws.sqs.queue_name"], _QUEUE_NAME)

# Bedrock
bedrock_sucess_attributes: Dict[str, str] = _do_bedrock_on_success()
self.assertTrue("aws.bedrock.guardrail_id" in bedrock_sucess_attributes)
self.assertEqual(bedrock_sucess_attributes["aws.bedrock.guardrail_id"], _BEDROCK_GUARDRAIL_ID)

# Bedrock Agent Operation
self.assertTrue("bedrock-agent" in _KNOWN_EXTENSIONS)
bedrock_agent_op_attributes: Dict[str, str] = _do_extract_bedrock_agent_op_attributes()
self.assertTrue("aws.bedrock.agent_id" in bedrock_agent_op_attributes)
self.assertEqual(bedrock_agent_op_attributes["aws.bedrock.agent_id"], _BEDROCK_AGENT_ID)
bedrock_agent_op_sucess_attributes: Dict[str, str] = _do_bedrock_agent_op_on_success()
self.assertTrue("aws.bedrock.agent_id" in bedrock_agent_op_sucess_attributes)
self.assertEqual(bedrock_agent_op_sucess_attributes["aws.bedrock.agent_id"], _BEDROCK_AGENT_ID)

# Bedrock DataSource Operation
self.assertTrue("bedrock-agent" in _KNOWN_EXTENSIONS)
bedrock_datasource_op_attributes: Dict[str, str] = _do_extract_bedrock_datasource_op_attributes()
self.assertTrue("aws.bedrock.datasource_id" in bedrock_datasource_op_attributes)
self.assertEqual(bedrock_datasource_op_attributes["aws.bedrock.datasource_id"], _BEDROCK_DATASOURCE_ID)
bedrock_datasource_op_sucess_attributes: Dict[str, str] = _do_bedrock_datasource_op_on_success()
self.assertTrue("aws.bedrock.datasource_id" in bedrock_datasource_op_sucess_attributes)
self.assertEqual(bedrock_datasource_op_sucess_attributes["aws.bedrock.datasource_id"], _BEDROCK_DATASOURCE_ID)

# Bedrock KnowledgeBase Operation
self.assertTrue("bedrock-agent" in _KNOWN_EXTENSIONS)
bedrock_knowledgebase_op_attributes: Dict[str, str] = _do_extract_bedrock_knowledgebase_op_attributes()
self.assertTrue("aws.bedrock.knowledgebase_id" in bedrock_knowledgebase_op_attributes)
self.assertEqual(bedrock_knowledgebase_op_attributes["aws.bedrock.knowledgebase_id"], _BEDROCK_KNOWLEDGEBASE_ID)

# Bedrock Agent Runtime
self.assertTrue("bedrock-agent-runtime" in _KNOWN_EXTENSIONS)
bedrock_agent_runtime_attributes: Dict[str, str] = _do_extract_bedrock_agent_runtime_attributes()
self.assertTrue("aws.bedrock.agent_id" in bedrock_agent_runtime_attributes)
self.assertEqual(bedrock_agent_runtime_attributes["aws.bedrock.agent_id"], _BEDROCK_AGENT_ID)
self.assertTrue("aws.bedrock.knowledgebase_id" in bedrock_agent_runtime_attributes)
self.assertEqual(bedrock_agent_runtime_attributes["aws.bedrock.knowledgebase_id"], _BEDROCK_KNOWLEDGEBASE_ID)


def _do_extract_kinesis_attributes() -> Dict[str, str]:
service_name: str = "kinesis"
Expand All @@ -115,10 +169,83 @@ def _do_extract_sqs_attributes() -> Dict[str, str]:
return _do_extract_attributes(service_name, params)


def _do_extract_attributes(service_name: str, params: Dict[str, str]) -> Dict[str, str]:
def _do_bedrock_on_success() -> Dict[str, str]:
service_name: str = "bedrock"
result: Dict[str, Any] = {"guardrailId": _BEDROCK_GUARDRAIL_ID}
return _do_on_success(service_name, result)


def _do_extract_bedrock_agent_op_attributes() -> Dict[str, str]:
service_name: str = "bedrock-agent"
params: Dict[str, str] = {"agentId": _BEDROCK_AGENT_ID}
operation: str = "CreateAgentAlias"
return _do_extract_attributes(service_name, params, operation)


def _do_bedrock_agent_op_on_success() -> Dict[str, str]:
service_name: str = "bedrock-agent"
result: Dict[str, Any] = {"agentId": _BEDROCK_AGENT_ID}
operation: str = "CreateAgentAlias"
return _do_on_success(service_name, result, operation)


def _do_extract_bedrock_datasource_op_attributes() -> Dict[str, str]:
service_name: str = "bedrock-agent"
params: Dict[str, str] = {"dataSourceId": _BEDROCK_DATASOURCE_ID}
operation: str = "UpdateDataSource"
return _do_extract_attributes(service_name, params, operation)


def _do_bedrock_datasource_op_on_success() -> Dict[str, str]:
service_name: str = "bedrock-agent"
result: Dict[str, Any] = {"dataSourceId": _BEDROCK_DATASOURCE_ID}
operation: str = "UpdateDataSource"
return _do_on_success(service_name, result, operation)


def _do_extract_bedrock_knowledgebase_op_attributes() -> Dict[str, str]:
service_name: str = "bedrock-agent"
params: Dict[str, str] = {"knowledgeBaseId": _BEDROCK_KNOWLEDGEBASE_ID}
operation: str = "GetKnowledgeBase"
return _do_extract_attributes(service_name, params, operation)


def _do_extract_bedrock_agent_runtime_attributes() -> Dict[str, str]:
service_name: str = "bedrock-agent-runtime"
params: Dict[str, str] = {"agentId": _BEDROCK_AGENT_ID, "knowledgeBaseId": _BEDROCK_KNOWLEDGEBASE_ID}
return _do_extract_attributes(service_name, params)


def _do_extract_attributes(service_name: str, params: Dict[str, Any], operation: str = None) -> Dict[str, str]:
mock_call_context: MagicMock = MagicMock()
mock_call_context.params = params
if operation:
mock_call_context.operation = operation
attributes: Dict[str, str] = {}
sqs_extension = _KNOWN_EXTENSIONS[service_name]()(mock_call_context)
sqs_extension.extract_attributes(attributes)
return attributes


def _do_on_success(
service_name: str, result: Dict[str, Any], operation: str = None, params: Dict[str, Any] = None
) -> Dict[str, str]:
span_mock: Span = MagicMock()
mock_call_context = MagicMock()
span_attributes: Dict[str, str] = {}

def set_side_effect(set_key, set_value):
span_attributes[set_key] = set_value

span_mock.set_attribute.side_effect = set_side_effect

if operation:
mock_call_context.operation = operation

if params:
mock_call_context.params = params

extension = _KNOWN_EXTENSIONS[service_name]()(mock_call_context)
extension.on_success(span_mock, result)

return span_attributes

0 comments on commit d583b59

Please sign in to comment.