Skip to content

Commit

Permalink
[FIX] Fix broken azure_auth test (#544)
Browse files Browse the repository at this point in the history
  • Loading branch information
jsong468 authored Nov 11, 2024
1 parent 7e9a658 commit 6988eac
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 10 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ dependencies = [
"art==6.1.0",
"azure-cognitiveservices-speech>=1.36.0",
"azure-core>=1.26.1",
"azure-identity>=1.12.0",
"azure-identity>=1.19.0",
"azure-ai-contentsafety>=1.0.0",
"azure-ai-ml==1.13.0",
"azure-storage-blob>=12.19.0",
Expand Down
37 changes: 28 additions & 9 deletions tests/test_azure_auth.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import pytest
import time
from unittest.mock import Mock, patch
from unittest.mock import MagicMock, patch

from pyrit.auth.auth_config import REFRESH_TOKEN_BEFORE_MSEC
from pyrit.auth.azure_auth import AzureAuth, get_token_provider_from_default_azure_credential
Expand All @@ -14,15 +13,17 @@

def test_get_token_on_init():
with patch("azure.identity.AzureCliCredential.get_token") as mock_get_token:
mock_get_token.return_value = Mock(token=mock_token)
mock_get_token.return_value = MagicMock(token=mock_token)
test_instance = AzureAuth(token_scope="https://mocked_endpoint.azure.com")
assert test_instance.token == mock_token


def test_refresh_no_expiration():
# Token not expired so not reset
with patch("azure.identity.AzureCliCredential.get_token") as mock_get_token:
mock_get_token.return_value = Mock(token=mock_token, expires_on=curr_epoch_time + REFRESH_TOKEN_BEFORE_MSEC)
mock_get_token.return_value = MagicMock(
token=mock_token, expires_on=curr_epoch_time + REFRESH_TOKEN_BEFORE_MSEC
)
test_instance = AzureAuth(token_scope="https://mocked_endpoint.azure.com")
token = test_instance.refresh_token()
assert token == mock_token
Expand All @@ -32,16 +33,34 @@ def test_refresh_no_expiration():
def test_refresh_expiration():
# Token expired and reset
with patch("azure.identity.AzureCliCredential.get_token") as mock_get_token:
mock_get_token.return_value = Mock(token=mock_token, expires_on=curr_epoch_time)
mock_get_token.return_value = MagicMock(token=mock_token, expires_on=curr_epoch_time)
test_instance = AzureAuth(token_scope="https://mocked_endpoint.azure.com")
token = test_instance.refresh_token()
assert token
assert mock_get_token.call_count == 2


@pytest.mark.skip(reason="Need to updating mocking logic with new azure.identity version")
def test_get_token_provider_from_default_azure_credential():
with patch("azure.identity.DefaultAzureCredential.get_token") as mock_default_cred:
mock_default_cred.return_value = Mock(token=mock_token, expires_on=curr_epoch_time)
def test_get_token_provider_from_default_azure_credential_get_token():
with (
patch("azure.identity.DefaultAzureCredential.get_token") as mock_default_cred,
patch(
"builtins.hasattr",
side_effect=lambda obj, attr: False if attr == "get_token_info" else getattr(obj, attr, None) is not None,
),
):
mock_default_cred.return_value = MagicMock(token=mock_token, expires_on=curr_epoch_time)
token_provider = get_token_provider_from_default_azure_credential(scope="https://mocked_endpoint.azure.com")
assert token_provider() == mock_token


def test_get_token_provider_from_default_azure_credential_get_token_info():
with (
patch("azure.identity.DefaultAzureCredential.get_token_info") as mock_default_cred,
patch(
"builtins.hasattr",
side_effect=lambda obj, attr: True if attr == "get_token_info" else getattr(obj, attr, None) is not None,
),
):
mock_default_cred.return_value = MagicMock(token=mock_token, expires_on=curr_epoch_time)
token_provider = get_token_provider_from_default_azure_credential(scope="https://mocked_endpoint.azure.com")
assert token_provider() == mock_token

0 comments on commit 6988eac

Please sign in to comment.