Python: Update model diagnostics (#8346)
### Motivation and Context

We have a module that performs tracing for model invocations, but it currently works only with the OpenAI connectors. We need it to work with all AI connectors, so this PR does the refactoring necessary to make the module generic.

### Description

1. Rename the tracing module to `model_diagnostics`.
2. Restructure and optimize the code.
3. Add tracing to all AI connectors except the HuggingFace connector, which will need some refactoring in the near future (see the sketch below for the pattern each connector now follows).
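For reference, here is a minimal sketch of the pattern the refactored module enables. The decorator import and the `MODEL_PROVIDER_NAME` class-variable convention are taken from the diffs below; the connector class and provider name are hypothetical, and exactly which span attributes the decorator records is an assumption based on the module's purpose.

```python
from typing import Any, ClassVar

from semantic_kernel.connectors.ai.chat_completion_client_base import ChatCompletionClientBase
from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings
from semantic_kernel.contents.chat_history import ChatHistory
from semantic_kernel.contents.chat_message_content import ChatMessageContent
from semantic_kernel.utils.telemetry.model_diagnostics.decorators import trace_chat_completion


class MyChatCompletion(ChatCompletionClientBase):
    """Hypothetical connector, shown only to illustrate the pattern."""

    MODEL_PROVIDER_NAME: ClassVar[str] = "myprovider"  # e.g. "anthropic", "googleai", "vertexai"

    @trace_chat_completion(MODEL_PROVIDER_NAME)
    async def get_chat_message_contents(
        self,
        chat_history: ChatHistory,
        settings: PromptExecutionSettings,
        **kwargs: Any,
    ) -> list[ChatMessageContent]:
        # The decorator wraps this call in an OpenTelemetry span tagged with
        # the provider name, so each connector only declares who it is.
        ...
```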

### Contribution Checklist


- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄
TaoChenOSU authored Aug 27, 2024
1 parent b58180a commit 5476ac3
Showing 31 changed files with 854 additions and 613 deletions.
4 changes: 3 additions & 1 deletion python/.cspell.json
@@ -55,6 +55,8 @@
"huggingface",
"pytestmark",
"contoso",
"opentelemetry"
"opentelemetry",
"SEMANTICKERNEL",
"OTEL"
]
}
@@ -1 +1,3 @@
TELEMETRY_SAMPLE_CONNECTION_STRING="..."
SEMANTICKERNEL_EXPERIMENTAL_GENAI_ENABLE_OTEL_DIAGNOSTICS=true
SEMANTICKERNEL_EXPERIMENTAL_GENAI_ENABLE_OTEL_DIAGNOSTICS_SENSITIVE=true
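These two new switches gate the diagnostics: judging by their names, the first enables the OpenTelemetry model-diagnostics spans, and the `_SENSITIVE` variant additionally records prompt and completion content, which you would likely only want outside production. A minimal sketch of enabling them in-process, assuming they are read from the environment like any other setting:

```python
import os

# Variable names come from the .env diff above; setting them before the kernel
# and connectors are created is an assumption about when they are read.
os.environ["SEMANTICKERNEL_EXPERIMENTAL_GENAI_ENABLE_OTEL_DIAGNOSTICS"] = "true"
# Also record prompt/completion content on spans (sensitive data).
os.environ["SEMANTICKERNEL_EXPERIMENTAL_GENAI_ENABLE_OTEL_DIAGNOSTICS_SENSITIVE"] = "true"
```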
@@ -1,8 +1,14 @@
# Copyright (c) Microsoft. All rights reserved.

import logging
import sys
from collections.abc import AsyncGenerator
from typing import Any
from typing import Any, ClassVar

if sys.version_info >= (3, 12):
from typing import override # pragma: no cover
else:
from typing_extensions import override # pragma: no cover

from anthropic import AsyncAnthropic
from anthropic.types import (
@@ -29,11 +35,9 @@
from semantic_kernel.contents.text_content import TextContent
from semantic_kernel.contents.utils.author_role import AuthorRole
from semantic_kernel.contents.utils.finish_reason import FinishReason as SemanticKernelFinishReason
from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError, ServiceResponseException
from semantic_kernel.utils.experimental_decorator import experimental_class
from semantic_kernel.utils.telemetry.model_diagnostics.decorators import trace_chat_completion

# map finish reasons from Anthropic to Semantic Kernel
ANTHROPIC_TO_SEMANTIC_KERNEL_FINISH_REASON_MAP = {
@@ -49,8 +53,10 @@
class AnthropicChatCompletion(ChatCompletionClientBase):
"""Antropic ChatCompletion class."""

MODEL_PROVIDER_NAME: ClassVar[str] = "anthropic"

async_client: AsyncAnthropic

def __init__(
self,
ai_model_id: str | None = None,
@@ -68,10 +74,10 @@ def __init__(
service_id: Service ID tied to the execution settings.
api_key: The optional API key to use. If provided, it will override
the env vars or .env file value.
async_client: An existing client to use.
env_file_path: Use the environment settings file as a fallback
to environment variables.
env_file_encoding: The encoding of the environment settings file.
"""
try:
anthropic_settings = AnthropicSettings.create(
@@ -82,7 +88,7 @@
)
except ValidationError as ex:
raise ServiceInitializationError("Failed to create Anthropic settings.", ex) from ex

if not anthropic_settings.chat_model_id:
raise ServiceInitializationError("The Anthropic chat model ID is required.")

@@ -97,12 +103,14 @@ def __init__(
ai_model_id=anthropic_settings.chat_model_id,
)

@override
@trace_chat_completion(MODEL_PROVIDER_NAME)
async def get_chat_message_contents(
self,
chat_history: "ChatHistory",
settings: "PromptExecutionSettings",
**kwargs: Any,
) -> list["ChatMessageContent"]:
"""Executes a chat completion request and returns the result.
Args:
@@ -127,22 +135,23 @@ async def get_chat_message_contents(
raise ServiceResponseException(
f"{type(self)} service failed to complete the prompt",
ex,
) from ex

metadata: dict[str, Any] = {"id": response.id}
# Check if usage exists and has a value, then add it to the metadata
if hasattr(response, "usage") and response.usage is not None:
metadata["usage"] = response.usage

return [
self._create_chat_message_content(response, content_block, metadata) for content_block in response.content
]

async def get_streaming_chat_message_contents(
self,
chat_history: ChatHistory,
settings: PromptExecutionSettings,
**kwargs: Any,
) -> AsyncGenerator[list[StreamingChatMessageContent], Any]:
"""Executes a streaming chat completion request and returns the result.
Args:
@@ -166,17 +175,18 @@ async def get_streaming_chat_message_contents(
author_role = None
metadata: dict[str, Any] = {"usage": {}, "id": None}
content_block_idx = 0

async for stream_event in stream:
if isinstance(stream_event, RawMessageStartEvent):
author_role = stream_event.message.role
metadata["usage"]["input_tokens"] = stream_event.message.usage.input_tokens
metadata["id"] = stream_event.message.id
elif isinstance(stream_event, (RawContentBlockDeltaEvent, RawMessageDeltaEvent)):
yield [
self._create_streaming_chat_message_content(
stream_event, content_block_idx, author_role, metadata
)
]
elif isinstance(stream_event, ContentBlockStopEvent):
content_block_idx += 1

@@ -187,21 +197,18 @@
) from ex

def _create_chat_message_content(
self, response: Message, content: TextBlock, response_metadata: dict[str, Any]
) -> "ChatMessageContent":
"""Create a chat message content object."""
items: list[ITEM_TYPES] = []

if content.text:
items.append(TextContent(text=content.text))

finish_reason = None
if response.stop_reason:
finish_reason = ANTHROPIC_TO_SEMANTIC_KERNEL_FINISH_REASON_MAP[response.stop_reason]

return ChatMessageContent(
inner_content=response,
ai_model_id=self.ai_model_id,
@@ -212,20 +219,20 @@ def _create_chat_message_content(
)

def _create_streaming_chat_message_content(
self,
stream_event: RawContentBlockDeltaEvent | RawMessageDeltaEvent,
content_block_idx: int,
role: str | None = None,
metadata: dict[str, Any] = {},
) -> StreamingChatMessageContent:
"""Create a streaming chat message content object from a choice."""
text_content = ""

if stream_event.delta and hasattr(stream_event.delta, "text"):
text_content = stream_event.delta.text

items: list[STREAMING_ITEM_TYPES] = [StreamingTextContent(choice_index=content_block_idx, text=text_content)]

finish_reason = None
if isinstance(stream_event, RawMessageDeltaEvent):
if stream_event.delta.stop_reason:
@@ -246,4 +253,3 @@ def _create_streaming_chat_message_content(
def get_prompt_execution_settings_class(self) -> "type[AnthropicChatPromptExecutionSettings]":
"""Create a request settings object."""
return AnthropicChatPromptExecutionSettings

@@ -3,6 +3,7 @@
import asyncio
import contextlib
from abc import ABC
from typing import ClassVar

from azure.ai.inference.aio import ChatCompletionsClient, EmbeddingsClient

@@ -14,6 +15,8 @@
class AzureAIInferenceBase(KernelBaseModel, ABC):
"""Azure AI Inference Chat Completion Service."""

MODEL_PROVIDER_NAME: ClassVar[str] = "azureai"

client: ChatCompletionsClient | EmbeddingsClient

def __del__(self) -> None:
@@ -7,6 +7,7 @@
from functools import reduce
from typing import TYPE_CHECKING, Any

from semantic_kernel.utils.telemetry.model_diagnostics.decorators import trace_chat_completion
from semantic_kernel.utils.telemetry.user_agent import SEMANTIC_KERNEL_USER_AGENT

if sys.version_info >= (3, 12):
@@ -119,6 +120,8 @@ def __init__(
)

# region Non-streaming
@override
@trace_chat_completion(AzureAIInferenceBase.MODEL_PROVIDER_NAME)
async def get_chat_message_contents(
self,
chat_history: ChatHistory,
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft. All rights reserved.

from abc import ABC
from typing import ClassVar

from semantic_kernel.connectors.ai.google.google_ai.google_ai_settings import GoogleAISettings
from semantic_kernel.kernel_pydantic import KernelBaseModel
@@ -9,4 +10,6 @@
class GoogleAIBase(KernelBaseModel, ABC):
"""Google AI Service."""

MODEL_PROVIDER_NAME: ClassVar[str] = "googleai"

service_settings: GoogleAISettings
@@ -40,6 +40,7 @@
from semantic_kernel.contents.utils.finish_reason import FinishReason
from semantic_kernel.functions.kernel_arguments import KernelArguments
from semantic_kernel.kernel import Kernel
from semantic_kernel.utils.telemetry.model_diagnostics.decorators import trace_chat_completion

if sys.version_info >= (3, 12):
from typing import override # pragma: no cover
@@ -109,6 +110,7 @@ def __init__(

# region Non-streaming
@override
@trace_chat_completion(GoogleAIBase.MODEL_PROVIDER_NAME)
async def get_chat_message_contents(
self,
chat_history: ChatHistory,
@@ -15,6 +15,7 @@
)
from semantic_kernel.connectors.ai.google.google_ai.services.google_ai_base import GoogleAIBase
from semantic_kernel.connectors.ai.text_completion_client_base import TextCompletionClientBase
from semantic_kernel.utils.telemetry.model_diagnostics.decorators import trace_text_completion

if sys.version_info >= (3, 12):
from typing import override # pragma: no cover
@@ -78,6 +79,7 @@ def __init__(

# region Non-streaming
@override
@trace_text_completion(GoogleAIBase.MODEL_PROVIDER_NAME)
async def get_text_contents(
self,
prompt: str,
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft. All rights reserved.

from abc import ABC
from typing import ClassVar

from semantic_kernel.connectors.ai.google.vertex_ai.vertex_ai_settings import VertexAISettings
from semantic_kernel.kernel_pydantic import KernelBaseModel
@@ -9,4 +10,6 @@
class VertexAIBase(KernelBaseModel, ABC):
"""Vertex AI Service."""

MODEL_PROVIDER_NAME: ClassVar[str] = "vertexai"

service_settings: VertexAISettings
@@ -45,6 +45,7 @@
)
from semantic_kernel.functions.kernel_arguments import KernelArguments
from semantic_kernel.kernel import Kernel
from semantic_kernel.utils.telemetry.model_diagnostics.decorators import trace_chat_completion

if sys.version_info >= (3, 12):
from typing import override # pragma: no cover
@@ -103,6 +104,7 @@ def __init__(

# region Non-streaming
@override
@trace_chat_completion(VertexAIBase.MODEL_PROVIDER_NAME)
async def get_chat_message_contents(
self,
chat_history: ChatHistory,
@@ -19,6 +19,7 @@
from semantic_kernel.contents.streaming_text_content import StreamingTextContent
from semantic_kernel.contents.text_content import TextContent
from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError
from semantic_kernel.utils.telemetry.model_diagnostics.decorators import trace_text_completion

if sys.version_info >= (3, 12):
from typing import override # pragma: no cover
@@ -74,6 +75,7 @@ def __init__(

# region Non-streaming
@override
@trace_text_completion(VertexAIBase.MODEL_PROVIDER_NAME)
async def get_text_contents(
self,
prompt: str,
@@ -0,0 +1,16 @@
# Copyright (c) Microsoft. All rights reserved.

from abc import ABC
from typing import ClassVar

from mistralai.async_client import MistralAsyncClient

from semantic_kernel.kernel_pydantic import KernelBaseModel


class MistralAIBase(KernelBaseModel, ABC):
"""Mistral AI service base."""

MODEL_PROVIDER_NAME: ClassVar[str] = "mistralai"

async_client: MistralAsyncClient
