From be8825ac05935436a2f4b368df50d89c06da832a Mon Sep 17 00:00:00 2001 From: Chaoyi Yuan Date: Fri, 5 Jan 2024 01:54:43 +0800 Subject: [PATCH] [C#] feat: add unit tests for Teams SSO auth related classes (#1112) ## Linked issues closes: #1053 #1110 ## Details 1. Create an adapter for MSAL library to make the classes testable with mock 2. Update existing class implementation to use the MSAL library adapter 3. Do not check token cache in TeamsSsoPrompt because the Authentication handler already checks token cache before invoking the prompt 4. Add unit tests ## Attestation Checklist - [x] My code follows the style guidelines of this project - I have checked for/fixed spelling, linting, and other errors - I have commented my code for clarity - I have made corresponding changes to the documentation (we use [TypeDoc](https://typedoc.org/) to document our code) - My changes generate no new warnings - I have added tests that validates my changes, and provides sufficient test coverage. I have tested with: - Local testing - E2E testing in Teams - New and existing unit tests pass locally with my changes --------- Co-authored-by: Alex Acebo --- .../Application/Authentication/AppConfig.cs | 55 ++++ .../AuthenticationManagerTests.cs | 1 - .../Bot/TeamsSsoBotAuthenticationTests.cs | 204 +++++++++++++ .../Authentication/Bot/TeamsSsoPromptTests.cs | 272 ++++++++++++++++++ ...SsoMessageExtensionsAuthenticationTests.cs | 193 +++++++++++++ .../TeamsSsoAuthenticationTests.cs | 172 +++++++++++ .../Bot/BotAuthenticationBase.cs | 4 +- .../Bot/TeamsSsoBotAuthentication.cs | 5 +- .../Authentication/Bot/TeamsSsoPrompt.cs | 41 +-- .../ConfidentialClientApplicationAdapter.cs | 50 ++++ .../IConfidentialClientApplicationAdapter.cs | 15 + ...TeamsSsoMessageExtensionsAuthentication.cs | 11 +- .../Authentication/TeamsSsoAuthentication.cs | 17 +- 13 files changed, 985 insertions(+), 55 deletions(-) create mode 100644 dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Application/Authentication/AppConfig.cs create mode 100644 dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Application/Authentication/Bot/TeamsSsoBotAuthenticationTests.cs create mode 100644 dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Application/Authentication/Bot/TeamsSsoPromptTests.cs create mode 100644 dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Application/Authentication/MessageExtensions/TeamsSsoMessageExtensionsAuthenticationTests.cs create mode 100644 dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Application/Authentication/TeamsSsoAuthenticationTests.cs create mode 100644 dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/ConfidentialClientApplicationAdapter.cs create mode 100644 dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/IConfidentialClientApplicationAdapter.cs diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Application/Authentication/AppConfig.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Application/Authentication/AppConfig.cs new file mode 100644 index 000000000..c2a924598 --- /dev/null +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Application/Authentication/AppConfig.cs @@ -0,0 +1,55 @@ +using Microsoft.Identity.Client; +using System.Security.Cryptography.X509Certificates; + +namespace Microsoft.Teams.AI.Tests.Application.Authentication +{ + internal class AppConfig : IAppConfig + { +#pragma warning disable CS8618 // This class is for test purpose only + public AppConfig(string clientId, string tenantId) +#pragma warning restore CS8618 + { + ClientId = clientId; + TenantId = tenantId; + } + + public string ClientId { get; } + + public bool EnablePiiLogging { get; } + + public IMsalHttpClientFactory HttpClientFactory { get; } + + public LogLevel LogLevel { get; } + + public bool IsDefaultPlatformLoggingEnabled { get; } + + public string RedirectUri { get; } + + public string TenantId { get; } + + public LogCallback LoggingCallback { get; } + + public IDictionary ExtraQueryParameters { get; } + + public bool IsBrokerEnabled { get; } + + public string ClientName { get; } + + public string ClientVersion { get; } + + [Obsolete] + public ITelemetryConfig TelemetryConfig { get; } + + public bool ExperimentalFeaturesEnabled { get; } + + public IEnumerable ClientCapabilities { get; } + + public bool LegacyCacheCompatibilityEnabled { get; } + + public string ClientSecret { get; } + + public X509Certificate2 ClientCredentialCertificate { get; } + + public Func ParentActivityOrWindowFunc { get; } + } +} diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Application/Authentication/AuthenticationManagerTests.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Application/Authentication/AuthenticationManagerTests.cs index f8ba1ab2e..e983b3084 100644 --- a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Application/Authentication/AuthenticationManagerTests.cs +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Application/Authentication/AuthenticationManagerTests.cs @@ -110,7 +110,6 @@ public async void Test_SignOut_DefaultHandler() public async void Test_SignOut_SpecificHandler() { // arrange - var graphToken = "graph token"; var app = new TestApplication(new TestApplicationOptions()); var options = new AuthenticationOptions(); options._authenticationSettings = new Dictionary() diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Application/Authentication/Bot/TeamsSsoBotAuthenticationTests.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Application/Authentication/Bot/TeamsSsoBotAuthenticationTests.cs new file mode 100644 index 000000000..abae878b2 --- /dev/null +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Application/Authentication/Bot/TeamsSsoBotAuthenticationTests.cs @@ -0,0 +1,204 @@ +using Microsoft.Bot.Builder; +using Microsoft.Bot.Builder.Dialogs; +using Microsoft.Bot.Schema; +using Microsoft.Identity.Client; +using Microsoft.Teams.AI.State; +using Microsoft.Teams.AI.Tests.TestUtils; +using Moq; +using Newtonsoft.Json.Linq; + +namespace Microsoft.Teams.AI.Tests.Application.Authentication.Bot +{ + public class TeamsSsoBotAuthenticationTests + { + internal class MockTeamsSsoBotAuthentication : TeamsSsoBotAuthentication + where TState : TurnState, new() + { + public MockTeamsSsoBotAuthentication(Application app, string name, TeamsSsoSettings settings, TeamsSsoPrompt? mockPrompt = null) : base(app, name, settings, null) + { + if (mockPrompt != null) + { + _prompt = mockPrompt; + } + } + + public async Task TokenExchangeRouteSelectorPublic(ITurnContext context, CancellationToken cancellationToken) + { + return await base.TokenExchangeRouteSelector(context, cancellationToken); + } + } + + + [Fact] + public async void Test_RunDialog_BeginNew() + { + // arrange + var app = new Application(new ApplicationOptions()); + var msal = ConfidentialClientApplicationBuilder.Create("clientId").WithClientSecret("clientSecret").Build(); + var settings = new TeamsSsoSettings(new string[] { "User.Read" }, "https://localhost/auth-start.html", msal); + var mockedPrompt = CreateTeamsSsoPromptMock(settings); + var botAuthentication = new MockTeamsSsoBotAuthentication(app, "TokenName", settings, mockedPrompt.Object); + var messageContext = MockTurnContext(); + var turnState = await TurnStateConfig.GetTurnStateWithConversationStateAsync(messageContext); + + // act + var result = await botAuthentication.RunDialog(messageContext, turnState, "dialogStateProperty"); + + // assert + Assert.Equal(DialogTurnStatus.Waiting, result.Status); + } + + [Fact] + public async void Test_RunDialog_ContinueExisting() + { + // arrange + var app = new Application(new ApplicationOptions()); + var msal = ConfidentialClientApplicationBuilder.Create("clientId").WithClientSecret("clientSecret").Build(); + var settings = new TeamsSsoSettings(new string[] { "User.Read" }, "https://localhost/auth-start.html", msal); + var mockedPrompt = CreateTeamsSsoPromptMock(settings); + var botAuthentication = new MockTeamsSsoBotAuthentication(app, "TokenName", settings, mockedPrompt.Object); + var messageContext = MockTurnContext(); + var turnState = await TurnStateConfig.GetTurnStateWithConversationStateAsync(messageContext); + await botAuthentication.RunDialog(messageContext, turnState, "dialogStateProperty"); // Begin new dialog first + + // act + var tokenExchangeContext = MockTokenExchangeContext(); + var result = await botAuthentication.RunDialog(tokenExchangeContext, turnState, "dialogStateProperty"); + + // assert + Assert.Equal(DialogTurnStatus.Complete, result.Status); + } + + + [Fact] + public async void Test_ContinueDialog() + { + // arrange + var app = new Application(new ApplicationOptions()); + var msal = ConfidentialClientApplicationBuilder.Create("clientId").WithClientSecret("clientSecret").Build(); + var settings = new TeamsSsoSettings(new string[] { "User.Read" }, "https://localhost/auth-start.html", msal); + var mockedPrompt = CreateTeamsSsoPromptMock(settings); + var botAuthentication = new MockTeamsSsoBotAuthentication(app, "TokenName", settings, mockedPrompt.Object); + var messageContext = MockTurnContext(); + var turnState = await TurnStateConfig.GetTurnStateWithConversationStateAsync(messageContext); + await botAuthentication.RunDialog(messageContext, turnState, "dialogStateProperty"); // Begin new dialog first + + // act + var tokenExchangeContext = MockTokenExchangeContext(); + var result = await botAuthentication.ContinueDialog(tokenExchangeContext, turnState, "dialogStateProperty"); + + // assert + Assert.Equal(DialogTurnStatus.Complete, result.Status); + } + + [Fact] + public async void Test_TokenExchangeRouteSelector_NameMatched() + { + // arrange + var app = new Application(new ApplicationOptions()); + var msal = ConfidentialClientApplicationBuilder.Create("clientId").WithClientSecret("clientSecret").Build(); + var settings = new TeamsSsoSettings(new string[] { "User.Read" }, "https://localhost/auth-start.html", msal); + var turnContext = MockTokenExchangeContext("test"); + + var botAuthentication = new MockTeamsSsoBotAuthentication(app, "test", settings); + + // act + var result = await botAuthentication.TokenExchangeRouteSelectorPublic(turnContext, CancellationToken.None); + + // assert + Assert.True(result); + } + + [Fact] + public async void Test_TokenExchangeRouteSelector_NameNotMatch() + { + // arrange + var app = new Application(new ApplicationOptions()); + var msal = ConfidentialClientApplicationBuilder.Create("clientId").WithClientSecret("clientSecret").Build(); + var settings = new TeamsSsoSettings(new string[] { "User.Read" }, "https://localhost/auth-start.html", msal); + var turnContext = MockTokenExchangeContext("AnotherTokenName"); + + var botAuthentication = new MockTeamsSsoBotAuthentication(app, "test", settings); + + // act + var result = await botAuthentication.TokenExchangeRouteSelectorPublic(turnContext, CancellationToken.None); + + // assert + Assert.False(result); + } + + [Fact] + public async void Test_Dedupe() + { + // arrange + var app = new Application(new ApplicationOptions()); + var msal = ConfidentialClientApplicationBuilder.Create("clientId").WithClientSecret("clientSecret").Build(); + var settings = new TeamsSsoSettings(new string[] { "User.Read" }, "https://localhost/auth-start.html", msal); + var mockedPrompt = CreateTeamsSsoPromptMock(settings); + var botAuthentication = new MockTeamsSsoBotAuthentication(app, "TokenName", settings, mockedPrompt.Object); + + // act + var messageContext = MockTurnContext(); + var turnState = await TurnStateConfig.GetTurnStateWithConversationStateAsync(messageContext); + await botAuthentication.RunDialog(messageContext, turnState, "dialogStateProperty"); + var tokenExchangeContext = MockTokenExchangeContext(); + var tokenExchangeResult = await botAuthentication.ContinueDialog(tokenExchangeContext, turnState, "dialogStateProperty"); + + // assert + Assert.NotNull(tokenExchangeResult.Result); + Assert.Equal("test token", ((TokenResponse)tokenExchangeResult.Result).Token); + + // act - simulate processing duplicate request + await botAuthentication.RunDialog(messageContext, turnState, "dialogStateProperty"); + tokenExchangeResult = await botAuthentication.ContinueDialog(tokenExchangeContext, turnState, "dialogStateProperty"); + + // assert + Assert.Equal(DialogTurnStatus.Waiting, tokenExchangeResult.Status); + } + + private static Mock CreateTeamsSsoPromptMock(TeamsSsoSettings settings) + { + var mockedPrompt = new Mock("TeamsSsoPrompt", "TokenName", settings); + mockedPrompt + .Setup(mock => mock.BeginDialogAsync(It.IsAny(), It.IsAny(), It.IsAny())) + .ReturnsAsync(new DialogTurnResult(DialogTurnStatus.Waiting)); + mockedPrompt + .Setup(mock => mock.ContinueDialogAsync(It.IsAny(), It.IsAny())) + .Returns(async (DialogContext dc, CancellationToken cancellationToken) => + { + return await dc.EndDialogAsync(new TokenResponse(token: "test token")); + }); + return mockedPrompt; + } + + private static TurnContext MockTurnContext(string type = ActivityTypes.Message, string? name = null) + { + return new TurnContext(new SimpleAdapter(), new Activity() + { + Type = type, + Recipient = new() { Id = "recipientId" }, + Conversation = new() { Id = "conversationId" }, + From = new() { Id = "fromId" }, + ChannelId = "channelId", + Name = name + }); + } + + private static TurnContext MockTokenExchangeContext(string settingName = "test") + { + JObject activityValue = new(); + activityValue["id"] = $"{Guid.NewGuid()}-{settingName}"; + + return new TurnContext(new SimpleAdapter(), new Activity() + { + Type = ActivityTypes.Invoke, + Name = SignInConstants.TokenExchangeOperationName, + Recipient = new() { Id = "recipientId" }, + Conversation = new() { Id = "conversationId" }, + From = new() { Id = "fromId" }, + ChannelId = "channelId", + Value = activityValue + }); + } + } +} diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Application/Authentication/Bot/TeamsSsoPromptTests.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Application/Authentication/Bot/TeamsSsoPromptTests.cs new file mode 100644 index 000000000..065194b22 --- /dev/null +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Application/Authentication/Bot/TeamsSsoPromptTests.cs @@ -0,0 +1,272 @@ +using Microsoft.Bot.Builder.Dialogs; +using Microsoft.Bot.Builder; +using Microsoft.Bot.Schema; +using Microsoft.Identity.Client; +using Moq; +using Microsoft.Bot.Builder.Adapters; +using Microsoft.Bot.Connector; +using System.Text.Json; +using Newtonsoft.Json.Linq; + +namespace Microsoft.Teams.AI.Tests.Application.Authentication.Bot +{ + public class TeamsSsoPromptTests + { + private const string TokenExchangeSuccess = "TokenExchangeSuccess"; + private const string TokenExchangeFail = "TokenExchangeFail"; + private const string DialogId = "DialogId"; + private const string PromptName = "PromptName"; + private const string ClientId = "ClientId"; + private const string TenantId = "TenantId"; + private const string UserReadScope = "User.Read"; + private const string AuthStartPage = "https://localhost/auth-start.html"; + private const string AccessToken = "test token"; + + private class TeamsSsoPromptMock : TeamsSsoPrompt + { + public TeamsSsoPromptMock(string dialogId, string name, TeamsSsoSettings settings, IConfidentialClientApplicationAdapter msalAdapterMock) : base(dialogId, name, settings) + { + _msalAdapter = msalAdapterMock; + } + } + + [Fact] + public async Task BeginDialogAsync_SendOAuthCard() + { + // Arrange + var msalAdapterMock = MockMsalAdapter(); + var testFlow = InitTestFlow(msalAdapterMock.Object); + + // Act and Assert + await testFlow + .Send(new Activity() + { + ChannelId = Channels.Msteams, + Text = "hello", + Conversation = new ConversationAccount() { Id = "testUserId" } + }) + .AssertReply(activity => + { + Assert.Equal(1, ((Activity)activity).Attachments.Count); + Assert.Equal(OAuthCard.ContentType, ((Activity)activity).Attachments[0].ContentType); + OAuthCard? card = ((Activity)activity).Attachments[0].Content as OAuthCard; + Assert.NotNull(card); + Assert.Equal(1, card.Buttons.Count); + Assert.Equal(ActionTypes.Signin, card!.Buttons[0].Type); + Assert.Equal($"{AuthStartPage}?scope={UserReadScope}&clientId={ClientId}&tenantId={TenantId}", card!.Buttons[0].Value); + }) + .StartTestAsync(); + } + + [Fact] + public async Task ContinueDialogAsync_TokenExchangeSuccess() + { + // Arrange + var msalAdapterMock = MockMsalAdapter(); + var authenticationResult = MockAuthenticationResult(); + msalAdapterMock.Setup(m => m.InitiateLongRunningProcessInWebApi(It.IsAny>(), It.IsAny(), ref It.Ref.IsAny)).ReturnsAsync(authenticationResult); + + var testFlow = InitTestFlow(msalAdapterMock.Object); + + // Act and Assert + await testFlow + .Send(new Activity() + { + ChannelId = Channels.Msteams, + Text = "hello", + Conversation = new ConversationAccount() { Id = "testUserId" } + }) + .AssertReply(activity => + { + Assert.Equal(1, ((Activity)activity).Attachments.Count); + Assert.Equal(OAuthCard.ContentType, ((Activity)activity).Attachments[0].ContentType); + OAuthCard? card = ((Activity)activity).Attachments[0].Content as OAuthCard; + Assert.NotNull(card); + Assert.Equal(1, card.Buttons.Count); + Assert.Equal(ActionTypes.Signin, card!.Buttons[0].Type); + Assert.Equal($"{AuthStartPage}?scope={UserReadScope}&clientId={ClientId}&tenantId={TenantId}", card!.Buttons[0].Value); + }) + .Send(new Activity() + { + ChannelId = Channels.Msteams, + Type = ActivityTypes.Invoke, + Name = SignInConstants.TokenExchangeOperationName, + Value = JObject.FromObject(new TokenExchangeInvokeRequest() + { + Id = "fake_id", + Token = "fake_token" + }) + }) + .AssertReply(a => + { + Assert.Equal(ActivityTypesEx.InvokeResponse, a.Type); + var response = ((Activity)a).Value as InvokeResponse; + Assert.NotNull(response); + Assert.Equal(200, response!.Status); + }) + .AssertReply(TokenExchangeSuccess) + .AssertReply(activity => + { + var response = JsonSerializer.Deserialize(((Activity)activity).Text); + Assert.Equal(authenticationResult.AccessToken, response!.Token); + Assert.Equal(authenticationResult.ExpiresOn.ToString("O"), response!.Expiration); + }) + .StartTestAsync(); + } + + [Fact] + public async Task ContinueDialogAsync_TokenExchangeFail() + { + // Arrange + var msalAdapterMock = MockMsalAdapter(); + msalAdapterMock.Setup(m => m.InitiateLongRunningProcessInWebApi(It.IsAny>(), It.IsAny(), ref It.Ref.IsAny)).Throws(new MsalUiRequiredException("error code", "error message")); + + var testFlow = InitTestFlow(msalAdapterMock.Object); + + // Act and Assert + await testFlow + .Send(new Activity() + { + ChannelId = Channels.Msteams, + Text = "hello", + Conversation = new ConversationAccount() { Id = "testUserId" } + }) + .AssertReply(activity => + { + Assert.Equal(1, ((Activity)activity).Attachments.Count); + Assert.Equal(OAuthCard.ContentType, ((Activity)activity).Attachments[0].ContentType); + OAuthCard? card = ((Activity)activity).Attachments[0].Content as OAuthCard; + Assert.NotNull(card); + Assert.Equal(1, card.Buttons.Count); + Assert.Equal(ActionTypes.Signin, card!.Buttons[0].Type); + Assert.Equal($"{AuthStartPage}?scope={UserReadScope}&clientId={ClientId}&tenantId={TenantId}", card!.Buttons[0].Value); + }) + .Send(new Activity() + { + ChannelId = Channels.Msteams, + Type = ActivityTypes.Invoke, + Name = SignInConstants.TokenExchangeOperationName, + Value = JObject.FromObject(new TokenExchangeInvokeRequest() + { + Id = "fake_id", + Token = "fake_token" + }) + }) + .AssertReply(a => + { + Assert.Equal(ActivityTypesEx.InvokeResponse, a.Type); + var response = ((Activity)a).Value as InvokeResponse; + Assert.NotNull(response); + Assert.Equal(412, response!.Status); + }) + .StartTestAsync(); + } + + [Fact] + public async Task ContinueDialogAsync_SignInVerify() + { + // Arrange + var msalAdapterMock = MockMsalAdapter(); + var testFlow = InitTestFlow(msalAdapterMock.Object); + + // Act and Assert + await testFlow + .Send(new Activity() + { + ChannelId = Channels.Msteams, + Text = "hello", + Conversation = new ConversationAccount() { Id = "testUserId" } + }) + .AssertReply(activity => + { + Assert.Equal(1, ((Activity)activity).Attachments.Count); + Assert.Equal(OAuthCard.ContentType, ((Activity)activity).Attachments[0].ContentType); + OAuthCard? card = ((Activity)activity).Attachments[0].Content as OAuthCard; + Assert.NotNull(card); + Assert.Equal(1, card.Buttons.Count); + Assert.Equal(ActionTypes.Signin, card!.Buttons[0].Type); + Assert.Equal($"{AuthStartPage}?scope={UserReadScope}&clientId={ClientId}&tenantId={TenantId}", card!.Buttons[0].Value); + }) + .Send(new Activity() + { + ChannelId = Channels.Msteams, + Type = ActivityTypes.Invoke, + Name = SignInConstants.VerifyStateOperationName + }) + .AssertReply(activity => + { + Assert.Equal(1, ((Activity)activity).Attachments.Count); + Assert.Equal(OAuthCard.ContentType, ((Activity)activity).Attachments[0].ContentType); + OAuthCard? card = ((Activity)activity).Attachments[0].Content as OAuthCard; + Assert.NotNull(card); + Assert.Equal(1, card.Buttons.Count); + Assert.Equal(ActionTypes.Signin, card!.Buttons[0].Type); + Assert.Equal($"{AuthStartPage}?scope={UserReadScope}&clientId={ClientId}&tenantId={TenantId}", card!.Buttons[0].Value); + }) + .AssertReply(a => + { + Assert.Equal(ActivityTypesEx.InvokeResponse, a.Type); + var response = ((Activity)a).Value as InvokeResponse; + Assert.NotNull(response); + Assert.Equal(200, response!.Status); + }) + .StartTestAsync(); + } + + private static AuthenticationResult MockAuthenticationResult(string token = AccessToken, string scope = UserReadScope) + { + return new AuthenticationResult(token, false, "", DateTimeOffset.Now, DateTimeOffset.Now, "", null, "", new string[] { scope }, Guid.NewGuid()); + } + + private static Mock MockMsalAdapter() + { + var msalAdapterMock = new Mock(); + msalAdapterMock.Setup(m => m.AppConfig).Returns(new AppConfig(ClientId, TenantId)); + return msalAdapterMock; + } + + private static TeamsSsoPrompt CreateTeamsSsoPrompt(IConfidentialClientApplicationAdapter msalAdapterMock) + { + var settings = new TeamsSsoSettings(new string[] { UserReadScope }, AuthStartPage, It.IsAny()); + var teamsSsoPrompt = new TeamsSsoPromptMock(DialogId, PromptName, settings, msalAdapterMock); + return teamsSsoPrompt; + } + + private static TestFlow InitTestFlow(IConfidentialClientApplicationAdapter msalAdapterMock) + { + var teamsSsoPrompt = CreateTeamsSsoPrompt(msalAdapterMock); + var conversationState = new ConversationState(new MemoryStorage()); + var dialogState = conversationState.CreateProperty("dialogState"); + var dialogs = new DialogSet(dialogState); + dialogs.Add(teamsSsoPrompt); + + var adapter = new TestAdapter() + .Use(new AutoSaveStateMiddleware(conversationState)); + + BotCallbackHandler botCallbackHandler = async (turnContext, cancellationToken) => + { + var dc = await dialogs.CreateContextAsync(turnContext, cancellationToken); + + var results = await dc.ContinueDialogAsync(cancellationToken); + if (results.Status == DialogTurnStatus.Empty) + { + await dc.PromptAsync(DialogId, new PromptOptions(), cancellationToken); + } + else if (results.Status == DialogTurnStatus.Complete) + { + if (results.Result is TokenResponse) + { + await turnContext.SendActivityAsync(MessageFactory.Text(TokenExchangeSuccess), cancellationToken); + await turnContext.SendActivityAsync(MessageFactory.Text(JsonSerializer.Serialize(results.Result)), cancellationToken); + } + else + { + await turnContext.SendActivityAsync(MessageFactory.Text(TokenExchangeFail), cancellationToken); + } + } + }; + + return new TestFlow(adapter, botCallbackHandler); + } + } +} diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Application/Authentication/MessageExtensions/TeamsSsoMessageExtensionsAuthenticationTests.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Application/Authentication/MessageExtensions/TeamsSsoMessageExtensionsAuthenticationTests.cs new file mode 100644 index 000000000..f8022eaa7 --- /dev/null +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Application/Authentication/MessageExtensions/TeamsSsoMessageExtensionsAuthenticationTests.cs @@ -0,0 +1,193 @@ +using Microsoft.Bot.Builder; +using Microsoft.Identity.Client; +using Moq; +using Microsoft.Bot.Schema; +using Microsoft.Teams.AI.Tests.TestUtils; +using Newtonsoft.Json.Linq; +using Microsoft.Teams.AI.Exceptions; + +namespace Microsoft.Teams.AI.Tests.Application.Authentication.MessageExtensions +{ + public class TeamsSsoMessageExtensionsAuthenticationTests + { + private const string ClientId = "ClientId"; + private const string TenantId = "TenantId"; + private const string UserReadScope = "User.Read"; + private const string AuthStartPage = "https://localhost/auth-start.html"; + private const string AccessToken = "test token"; + + private class TeamsSsoMessageExtensionsAuthenticationMock : TeamsSsoMessageExtensionsAuthentication + { + public TeamsSsoMessageExtensionsAuthenticationMock(TeamsSsoSettings settings, IConfidentialClientApplicationAdapter msalAdapterMock) : base(settings) + { + _msalAdapter = msalAdapterMock; + } + } + + [Fact] + public async Task GetSignInLink() + { + // Arrange + var msalAdapterMock = MockMsalAdapter(); + var messageExtensionAuth = CreateTestClass(msalAdapterMock.Object); + var turnContext = MockTurnContext(); + + // Act + var signInLink = await messageExtensionAuth.GetSignInLink(turnContext); + + // Assert + Assert.Equal($"{AuthStartPage}?scope={UserReadScope}&clientId={ClientId}&tenantId={TenantId}", signInLink); + } + + [Fact] + public async Task HandleUserSignIn() + { + // Arrange + var msalAdapterMock = MockMsalAdapter(); + var messageExtensionAuth = CreateTestClass(msalAdapterMock.Object); + var turnContext = MockTurnContext(); + + // Act + var tokenResponse = await messageExtensionAuth.HandleUserSignIn(turnContext, "123456"); + + // Assert + Assert.Null(tokenResponse.Token); + Assert.Null(tokenResponse.Expiration); + } + + [Fact] + public void IsValidActivity_Valid() + { + // Arrange + var msalAdapterMock = MockMsalAdapter(); + var messageExtensionAuth = CreateTestClass(msalAdapterMock.Object); + var turnContext = MockTurnContext(); + + // Act + var isValidActivity = messageExtensionAuth.IsValidActivity(turnContext); + + // Assert + Assert.True(isValidActivity); + } + + [Fact] + public void IsValidActivity_InValid() + { + // Arrange + var msalAdapterMock = MockMsalAdapter(); + var messageExtensionAuth = CreateTestClass(msalAdapterMock.Object); + + // Act and Assert + Assert.False(messageExtensionAuth.IsValidActivity(MockTurnContext(MessageExtensionsInvokeNames.QUERY_LINK_INVOKE_NAME))); + Assert.False(messageExtensionAuth.IsValidActivity(MockTurnContext(MessageExtensionsInvokeNames.FETCH_TASK_INVOKE_NAME))); + Assert.False(messageExtensionAuth.IsValidActivity(MockTurnContext(MessageExtensionsInvokeNames.ANONYMOUS_QUERY_LINK_INVOKE_NAME))); + } + + [Fact] + public async Task HandleSsoTokenExchange_NoTokenInRequest() + { + // Arrange + var msalAdapterMock = MockMsalAdapter(); + var authenticationResult = MockAuthenticationResult(); + msalAdapterMock.Setup(m => m.InitiateLongRunningProcessInWebApi(It.IsAny>(), It.IsAny(), ref It.Ref.IsAny)).ReturnsAsync(authenticationResult); + var messageExtensionAuth = CreateTestClass(msalAdapterMock.Object); + var turnContext = MockTurnContext(); + + // Act + var result = await messageExtensionAuth.HandleSsoTokenExchange(turnContext); + + // Assert + Assert.Null(result.Token); + Assert.Null(result.Expiration); + } + + [Fact] + public async Task HandleSsoTokenExchange_TokenExchangeSuccess() + { + // Arrange + var msalAdapterMock = MockMsalAdapter(); + var authenticationResult = MockAuthenticationResult(); + msalAdapterMock.Setup(m => m.InitiateLongRunningProcessInWebApi(It.IsAny>(), It.IsAny(), ref It.Ref.IsAny)).ReturnsAsync(authenticationResult); + var messageExtensionAuth = CreateTestClass(msalAdapterMock.Object); + JObject activityValue = new(); + activityValue["authentication"] = new JObject(); + activityValue["authentication"]!["token"] = "sso token"; + var turnContext = MockTurnContext(activityValue: activityValue); + + // Act + var result = await messageExtensionAuth.HandleSsoTokenExchange(turnContext); + + // Assert + Assert.Equal(authenticationResult.AccessToken, result.Token); + Assert.Equal(authenticationResult.ExpiresOn.ToString("O"), result.Expiration); + } + + [Fact] + public async Task HandleSsoTokenExchange_TokenExchangeFail() + { + // Arrange + var msalAdapterMock = MockMsalAdapter(); + msalAdapterMock.Setup(m => m.InitiateLongRunningProcessInWebApi(It.IsAny>(), It.IsAny(), ref It.Ref.IsAny)).Throws(new MsalUiRequiredException("error code", "error message")); + var messageExtensionAuth = CreateTestClass(msalAdapterMock.Object); + JObject activityValue = new(); + activityValue["authentication"] = new JObject(); + activityValue["authentication"]!["token"] = "sso token"; + var turnContext = MockTurnContext(activityValue: activityValue); + + // Act + var result = await messageExtensionAuth.HandleSsoTokenExchange(turnContext); + + // Assert + Assert.Null(result.Token); + Assert.Null(result.Expiration); + } + + [Fact] + public async Task HandleSsoTokenExchange_UnexpectedException() + { + // Arrange + var msalAdapterMock = MockMsalAdapter(); + msalAdapterMock.Setup(m => m.InitiateLongRunningProcessInWebApi(It.IsAny>(), It.IsAny(), ref It.Ref.IsAny)).Throws(new MsalServiceException("error code", "error message")); + var messageExtensionAuth = CreateTestClass(msalAdapterMock.Object); + JObject activityValue = new(); + activityValue["authentication"] = new JObject(); + activityValue["authentication"]!["token"] = "sso token"; + var turnContext = MockTurnContext(activityValue: activityValue); + + // Act and Assert + await Assert.ThrowsAsync(async () => { await messageExtensionAuth.HandleSsoTokenExchange(turnContext); }); + } + + private static Mock MockMsalAdapter() + { + var msalAdapterMock = new Mock(); + msalAdapterMock.Setup(m => m.AppConfig).Returns(new AppConfig(ClientId, TenantId)); + return msalAdapterMock; + } + + private static TeamsSsoMessageExtensionsAuthentication CreateTestClass(IConfidentialClientApplicationAdapter msalAdapterMock) + { + var settings = new TeamsSsoSettings(new string[] { UserReadScope }, AuthStartPage, It.IsAny()); + return new TeamsSsoMessageExtensionsAuthenticationMock(settings, msalAdapterMock); + } + + private static AuthenticationResult MockAuthenticationResult(string token = AccessToken, string scope = UserReadScope) + { + return new AuthenticationResult(token, false, "", DateTimeOffset.Now, DateTimeOffset.Now, "", null, "", new string[] { scope }, Guid.NewGuid()); + } + + private static TurnContext MockTurnContext(string? name = null, JObject? activityValue = null) + { + return new TurnContext(new SimpleAdapter(), new Activity() + { + Type = ActivityTypes.Invoke, + Recipient = new() { Id = "recipientId" }, + Conversation = new() { Id = "conversationId", TenantId = "tenantId" }, + From = new() { Id = "fromId", AadObjectId = "aadObjectId" }, + ChannelId = "channelId", + Name = name ?? MessageExtensionsInvokeNames.QUERY_INVOKE_NAME, + Value = activityValue ?? new JObject() + }); + } + } +} diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Application/Authentication/TeamsSsoAuthenticationTests.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Application/Authentication/TeamsSsoAuthenticationTests.cs new file mode 100644 index 000000000..91054fb07 --- /dev/null +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI.Tests/Application/Authentication/TeamsSsoAuthenticationTests.cs @@ -0,0 +1,172 @@ +using Microsoft.Bot.Builder; +using Microsoft.Bot.Schema; +using Microsoft.Identity.Client; +using Microsoft.Teams.AI.State; +using Microsoft.Teams.AI.Exceptions; +using Microsoft.Teams.AI.Tests.TestUtils; +using Moq; +using Newtonsoft.Json.Linq; + +namespace Microsoft.Teams.AI.Tests.Application.Authentication +{ + public class TeamsSsoAuthenticationTests + { + private const string ClientId = "ClientId"; + private const string TenantId = "TenantId"; + private const string UserReadScope = "User.Read"; + private const string AuthStartPage = "https://localhost/auth-start.html"; + private const string AccessToken = "test token"; + + private class TeamsSsoAuthenticationMock : TeamsSsoAuthentication + where TState : TurnState, new() + { + public TeamsSsoAuthenticationMock(Application app, string name, TeamsSsoSettings settings, IConfidentialClientApplicationAdapter msalAdapterMock) : base(app, name, settings, null) + { + _msalAdapter = msalAdapterMock; + } + + public Func? GetSignInSuccessHandler() + { + return _botAuth?._userSignInSuccessHandler; + } + + public Func? GetSignInFailureHandler() + { + return _botAuth?._userSignInFailureHandler; + } + } + + [Fact] + public async Task SignInUserAsync_GetTokenFromCache() + { + // Arrange + var authenticationResult = MockAuthenticationResult(); + var msalAdapterMock = MockMsalAdapter(); + msalAdapterMock.Setup(m => m.AcquireTokenInLongRunningProcess(It.IsAny>(), It.IsAny())).ReturnsAsync(authenticationResult); + var turnContext = MockTurnContext(); + var turnState = await TurnStateConfig.GetTurnStateWithConversationStateAsync(turnContext); + var teamsSsoAuthentication = CreateTestClass(msalAdapterMock.Object); + + // Act + var result = await teamsSsoAuthentication.SignInUserAsync(turnContext, turnState); + + // Assert + Assert.Equal(authenticationResult.AccessToken, result); + } + + [Fact] + public async Task SignOutUserAsync() + { + // Arrange + var authenticationResult = MockAuthenticationResult(); + var msalAdapterMock = MockMsalAdapter(); + msalAdapterMock.Setup(m => m.StopLongRunningProcessInWebApiAsync(It.IsAny(), It.IsAny())).ReturnsAsync(true); + var turnContext = MockTurnContext(); + var turnState = await TurnStateConfig.GetTurnStateWithConversationStateAsync(turnContext); + var teamsSsoAuthentication = CreateTestClass(msalAdapterMock.Object); + + // Act + await teamsSsoAuthentication.SignOutUserAsync(turnContext, turnState); + + // Assert + msalAdapterMock.Verify(m => m.StopLongRunningProcessInWebApiAsync(It.IsAny(), It.IsAny()), Times.Once); + } + + [Fact] + public void OnUserSignInSuccess() + { + // Arrange + var msalAdapterMock = MockMsalAdapter(); + var teamsSsoAuthentication = CreateTestClass(msalAdapterMock.Object); + + // Act + teamsSsoAuthentication.OnUserSignInSuccess((turnContext, turnState) => { return Task.CompletedTask; }); + + // Assert + Assert.NotNull(teamsSsoAuthentication.GetSignInSuccessHandler()); + } + + [Fact] + public void OnUserSignInFailure() + { + // Arrange + var msalAdapterMock = MockMsalAdapter(); + var teamsSsoAuthentication = CreateTestClass(msalAdapterMock.Object); + + // Act + teamsSsoAuthentication.OnUserSignInFailure((turnContext, turnState, exception) => { return Task.CompletedTask; }); + + // Assert + Assert.NotNull(teamsSsoAuthentication.GetSignInFailureHandler()); + } + + [Fact] + public async Task IsUserSignedInAsync_UserSignedIn() + { + // Arrange + var authenticationResult = MockAuthenticationResult(); + var msalAdapterMock = MockMsalAdapter(); + msalAdapterMock.Setup(m => m.AcquireTokenInLongRunningProcess(It.IsAny>(), It.IsAny())).ReturnsAsync(authenticationResult); + var turnContext = MockTurnContext(); + var turnState = await TurnStateConfig.GetTurnStateWithConversationStateAsync(turnContext); + var teamsSsoAuthentication = CreateTestClass(msalAdapterMock.Object); + + // Act + var result = await teamsSsoAuthentication.IsUserSignedInAsync(turnContext); + + // Assert + Assert.Equal(authenticationResult.AccessToken, result); + } + + [Fact] + public async Task IsUserSignedInAsync_UserNotSignedIn() + { + // Arrange + var authenticationResult = MockAuthenticationResult(); + var msalAdapterMock = MockMsalAdapter(); + msalAdapterMock.Setup(m => m.AcquireTokenInLongRunningProcess(It.IsAny>(), It.IsAny())).Throws(new MsalClientException("error code", "error message")); + var turnContext = MockTurnContext(); + var turnState = await TurnStateConfig.GetTurnStateWithConversationStateAsync(turnContext); + var teamsSsoAuthentication = CreateTestClass(msalAdapterMock.Object); + + // Act + var result = await teamsSsoAuthentication.IsUserSignedInAsync(turnContext); + + // Assert + Assert.Null(result); + } + + private static AuthenticationResult MockAuthenticationResult(string token = AccessToken, string scope = UserReadScope) + { + return new AuthenticationResult(token, false, "", DateTimeOffset.Now, DateTimeOffset.Now, "", null, "", new string[] { scope }, Guid.NewGuid()); + } + + private static Mock MockMsalAdapter() + { + var msalAdapterMock = new Mock(); + msalAdapterMock.Setup(m => m.AppConfig).Returns(new AppConfig(ClientId, TenantId)); + return msalAdapterMock; + } + + private static TeamsSsoAuthenticationMock CreateTestClass(IConfidentialClientApplicationAdapter msalAdapterMock) + { + var app = new Application(new ApplicationOptions()); + var settings = new TeamsSsoSettings(new string[] { UserReadScope }, AuthStartPage, It.IsAny()); + return new TeamsSsoAuthenticationMock(app, "test", settings, msalAdapterMock); + } + + private static TurnContext MockTurnContext(string? name = null, JObject? activityValue = null) + { + return new TurnContext(new SimpleAdapter(), new Activity() + { + Type = ActivityTypes.Invoke, + Recipient = new() { Id = "recipientId" }, + Conversation = new() { Id = "conversationId", TenantId = "tenantId" }, + From = new() { Id = "fromId", AadObjectId = "aadObjectId" }, + ChannelId = "channelId", + Name = name ?? MessageExtensionsInvokeNames.QUERY_INVOKE_NAME, + Value = activityValue ?? new JObject() + }); + } + } +} diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/Bot/BotAuthenticationBase.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/Bot/BotAuthenticationBase.cs index 719f37d29..17e752ff3 100644 --- a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/Bot/BotAuthenticationBase.cs +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/Bot/BotAuthenticationBase.cs @@ -25,12 +25,12 @@ internal abstract class BotAuthenticationBase /// /// Callback when user sign in success /// - protected Func? _userSignInSuccessHandler; + internal Func? _userSignInSuccessHandler; /// /// Callback when user sign in fail /// - protected Func? _userSignInFailureHandler; + internal Func? _userSignInFailureHandler; /// /// Initializes the class diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/Bot/TeamsSsoBotAuthentication.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/Bot/TeamsSsoBotAuthentication.cs index cb61a95e1..7a1c27645 100644 --- a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/Bot/TeamsSsoBotAuthentication.cs +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/Bot/TeamsSsoBotAuthentication.cs @@ -16,7 +16,7 @@ internal class TeamsSsoBotAuthentication : BotAuthenticationBase { private const string SSO_DIALOG_ID = "_TeamsSsoDialog"; private Regex _tokenExchangeIdRegex; - private TeamsSsoPrompt _prompt; + protected TeamsSsoPrompt _prompt; /// /// Initializes the class @@ -90,6 +90,7 @@ private async Task CreateSsoDialogContext(ITurnContext context, T TurnStateProperty accessor = new(state, "conversation", dialogStateProperty); DialogSet dialogSet = new(accessor); WaterfallDialog ssoDialog = new(SSO_DIALOG_ID); + dialogSet.Add(this._prompt); dialogSet.Add(new WaterfallDialog(SSO_DIALOG_ID, new WaterfallStep[] { @@ -99,7 +100,7 @@ private async Task CreateSsoDialogContext(ITurnContext context, T }, async (step, cancellationToken) => { - TokenResponse? tokenResponse = step.Result as TokenResponse; + TokenResponse? tokenResponse = step.Result as TokenResponse; if (tokenResponse != null && await ShouldDedup(context)) { state.Temp.DuplicateTokenExchange = true; diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/Bot/TeamsSsoPrompt.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/Bot/TeamsSsoPrompt.cs index a30dc1c11..3b10f3f8c 100644 --- a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/Bot/TeamsSsoPrompt.cs +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/Bot/TeamsSsoPrompt.cs @@ -10,6 +10,8 @@ namespace Microsoft.Teams.AI { internal class TeamsSsoPrompt : Dialog { + protected IConfidentialClientApplicationAdapter _msalAdapter; + private const string _expiresKey = "expires"; private string _name; private TeamsSsoSettings _settings; @@ -17,8 +19,9 @@ internal class TeamsSsoPrompt : Dialog public TeamsSsoPrompt(string dialogId, string name, TeamsSsoSettings settings) : base(dialogId) { - this._name = name; - this._settings = settings; + _name = name; + _settings = settings; + _msalAdapter = new ConfidentialClientApplicationAdapter(settings.MSAL); } public override async Task BeginDialogAsync(DialogContext dc, object options, CancellationToken cancellationToken) @@ -28,19 +31,7 @@ public override async Task BeginDialogAsync(DialogContext dc, IDictionary state = dc.ActiveDialog.State; state[_expiresKey] = DateTime.Now.AddMilliseconds(timeout); - AuthenticationResult? token = await this.TryGetUserToken(dc.Context); - if (token != null) - { - TokenResponse tokenResponse = new() - { - ConnectionName = "", // No connection name is available in this implementation - Token = token.AccessToken, - Expiration = token.ExpiresOn.ToString("o") - }; - return await dc.EndDialogAsync(tokenResponse); - } - - // Cannot get token from cache, send OAuth card to get SSO token + // Send OAuth card to get SSO token await this.SendOAuthCardToObtainTokenAsync(dc.Context, cancellationToken); return EndOfTurn; } @@ -106,11 +97,7 @@ private async Task> RecognizeTokenAsync(Di try { string homeAccountId = $"{context.Activity.From.AadObjectId}.{context.Activity.Conversation.TenantId}"; - AuthenticationResult exchangedToken = await ((ILongRunningWebApi)_settings.MSAL).InitiateLongRunningProcessInWebApi( - _settings.Scopes, - ssoToken, - ref homeAccountId - ).ExecuteAsync(); + AuthenticationResult exchangedToken = await _msalAdapter.InitiateLongRunningProcessInWebApi(_settings.Scopes, ssoToken!, ref homeAccountId); tokenResponse = new TokenResponse { @@ -185,7 +172,7 @@ private async Task SendOAuthCardToObtainTokenAsync(ITurnContext context, Cancell private SignInResource GetSignInResource() { - string signInLink = $"{_settings.SignInLink}?scope={Uri.EscapeDataString(string.Join(" ", _settings.Scopes))}&clientId={_settings.MSAL.AppConfig.ClientId}&tenantId={_settings.MSAL.AppConfig.TenantId}"; + string signInLink = $"{_settings.SignInLink}?scope={Uri.EscapeDataString(string.Join(" ", _settings.Scopes))}&clientId={_msalAdapter.AppConfig.ClientId}&tenantId={_msalAdapter.AppConfig.TenantId}"; SignInResource signInResource = new() { @@ -199,18 +186,6 @@ private SignInResource GetSignInResource() return signInResource; } - private async Task TryGetUserToken(ITurnContext context) - { - string homeAccountId = $"{context.Activity.From.AadObjectId}.{context.Activity.Conversation.TenantId}"; - IAccount account = await this._settings.MSAL.GetAccountAsync(homeAccountId); - if (account != null) - { - AuthenticationResult result = await this._settings.MSAL.AcquireTokenSilent(this._settings.Scopes, account).ExecuteAsync(); - return result; - } - return null; // Return empty indication no token found in cache - } - private bool IsTeamsVerificationInvoke(ITurnContext context) { return (context.Activity.Type == ActivityTypes.Invoke) && (context.Activity.Name == SignInConstants.VerifyStateOperationName); diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/ConfidentialClientApplicationAdapter.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/ConfidentialClientApplicationAdapter.cs new file mode 100644 index 000000000..58b1796cd --- /dev/null +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/ConfidentialClientApplicationAdapter.cs @@ -0,0 +1,50 @@ +using Microsoft.Identity.Client; +using Microsoft.Identity.Client.Extensibility; + +namespace Microsoft.Teams.AI +{ + internal class ConfidentialClientApplicationAdapter : IConfidentialClientApplicationAdapter + { + private readonly IConfidentialClientApplication _msal; + + public ConfidentialClientApplicationAdapter(IConfidentialClientApplication msal) + { + _msal = msal; + } + + public IAppConfig AppConfig + { + get + { + return _msal.AppConfig; + } + } + + public Task InitiateLongRunningProcessInWebApi(IEnumerable scopes, string userToken, ref string longRunningProcessSessionKey) + { + return ((ILongRunningWebApi)_msal).InitiateLongRunningProcessInWebApi( + scopes, + userToken, + ref longRunningProcessSessionKey + ).ExecuteAsync(); + } + + public async Task StopLongRunningProcessInWebApiAsync(string longRunningProcessSessionKey, CancellationToken cancellationToken = default) + { + ILongRunningWebApi? oboCca = _msal as ILongRunningWebApi; + if (oboCca != null) + { + return await oboCca.StopLongRunningProcessInWebApiAsync(longRunningProcessSessionKey, cancellationToken); + } + return false; + } + + public async Task AcquireTokenInLongRunningProcess(IEnumerable scopes, string longRunningProcessSessionKey) + { + return await ((ILongRunningWebApi)_msal).AcquireTokenInLongRunningProcess( + scopes, + longRunningProcessSessionKey + ).ExecuteAsync(); + } + } +} diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/IConfidentialClientApplicationAdapter.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/IConfidentialClientApplicationAdapter.cs new file mode 100644 index 000000000..01bd59166 --- /dev/null +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/IConfidentialClientApplicationAdapter.cs @@ -0,0 +1,15 @@ +using Microsoft.Identity.Client; + +namespace Microsoft.Teams.AI +{ + internal interface IConfidentialClientApplicationAdapter + { + IAppConfig AppConfig { get; } + + Task InitiateLongRunningProcessInWebApi(IEnumerable scopes, string userToken, ref string longRunningProcessSessionKey); + + Task StopLongRunningProcessInWebApiAsync(string longRunningProcessSessionKey, CancellationToken cancellationToken = default); + + Task AcquireTokenInLongRunningProcess(IEnumerable scopes, string longRunningProcessSessionKey); + } +} diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/MessageExtensions/TeamsSsoMessageExtensionsAuthentication.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/MessageExtensions/TeamsSsoMessageExtensionsAuthentication.cs index 7bbed9a5d..154914353 100644 --- a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/MessageExtensions/TeamsSsoMessageExtensionsAuthentication.cs +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/MessageExtensions/TeamsSsoMessageExtensionsAuthentication.cs @@ -11,11 +11,14 @@ namespace Microsoft.Teams.AI /// internal class TeamsSsoMessageExtensionsAuthentication : MessageExtensionsAuthenticationBase { + protected IConfidentialClientApplicationAdapter _msalAdapter; + private TeamsSsoSettings _settings; public TeamsSsoMessageExtensionsAuthentication(TeamsSsoSettings settings) { _settings = settings; + _msalAdapter = new ConfidentialClientApplicationAdapter(settings.MSAL); } @@ -26,7 +29,7 @@ public TeamsSsoMessageExtensionsAuthentication(TeamsSsoSettings settings) /// The sign in link public override Task GetSignInLink(ITurnContext context) { - string signInLink = $"{_settings.SignInLink}?scope={Uri.EscapeDataString(string.Join(" ", _settings.Scopes))}&clientId={_settings.MSAL.AppConfig.ClientId}&tenantId={_settings.MSAL.AppConfig.TenantId}"; + string signInLink = $"{_settings.SignInLink}?scope={Uri.EscapeDataString(string.Join(" ", _settings.Scopes))}&clientId={_msalAdapter.AppConfig.ClientId}&tenantId={_msalAdapter.AppConfig.TenantId}"; return Task.FromResult(signInLink); } @@ -58,11 +61,7 @@ public override async Task HandleSsoTokenExchange(ITurnContext co try { string homeAccountId = $"{context.Activity.From.AadObjectId}.{context.Activity.Conversation.TenantId}"; - AuthenticationResult result = await ((ILongRunningWebApi)_settings.MSAL).InitiateLongRunningProcessInWebApi( - _settings.Scopes, - token.ToString(), - ref homeAccountId - ).ExecuteAsync(); + AuthenticationResult result = await _msalAdapter.InitiateLongRunningProcessInWebApi(_settings.Scopes, token.ToString(), ref homeAccountId); return new TokenResponse() { diff --git a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/TeamsSsoAuthentication.cs b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/TeamsSsoAuthentication.cs index 8ee11bbe8..8a4edfce0 100644 --- a/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/TeamsSsoAuthentication.cs +++ b/dotnet/packages/Microsoft.TeamsAI/Microsoft.TeamsAI/Application/Authentication/TeamsSsoAuthentication.cs @@ -1,6 +1,5 @@ using Microsoft.Bot.Builder; using Microsoft.Identity.Client; -using Microsoft.Identity.Client.Extensibility; using Microsoft.Teams.AI.Exceptions; using Microsoft.Teams.AI.State; @@ -12,7 +11,9 @@ namespace Microsoft.Teams.AI public class TeamsSsoAuthentication : IAuthentication where TState : TurnState, new() { - private TeamsSsoBotAuthentication? _botAuth; + internal IConfidentialClientApplicationAdapter _msalAdapter; + + internal TeamsSsoBotAuthentication? _botAuth; private TeamsSsoMessageExtensionsAuthentication? _messageExtensionsAuth; private TeamsSsoSettings _settings; @@ -28,6 +29,7 @@ public TeamsSsoAuthentication(Application app, string name, TeamsSsoSett _settings = settings; _botAuth = new TeamsSsoBotAuthentication(app, name, _settings, storage); _messageExtensionsAuth = new TeamsSsoMessageExtensionsAuthentication(_settings); + _msalAdapter = new ConfidentialClientApplicationAdapter(settings.MSAL); } /// @@ -68,11 +70,7 @@ public async Task SignOutUserAsync(ITurnContext context, TState state, Cancellat { string homeAccountId = $"{context.Activity.From.AadObjectId}.{context.Activity.Conversation.TenantId}"; - ILongRunningWebApi? oboCca = _settings.MSAL as ILongRunningWebApi; - if (oboCca != null) - { - await oboCca.StopLongRunningProcessInWebApiAsync(homeAccountId, cancellationToken); - } + await _msalAdapter.StopLongRunningProcessInWebApiAsync(homeAccountId, cancellationToken); } /// @@ -120,10 +118,7 @@ private async Task _TryGetUserToken(ITurnContext context) string homeAccountId = $"{context.Activity.From.AadObjectId}.{context.Activity.Conversation.TenantId}"; try { - AuthenticationResult result = await ((ILongRunningWebApi)_settings.MSAL).AcquireTokenInLongRunningProcess( - _settings.Scopes, - homeAccountId - ).ExecuteAsync(); + AuthenticationResult result = await _msalAdapter.AcquireTokenInLongRunningProcess(_settings.Scopes, homeAccountId); return result.AccessToken; } catch (MsalClientException)