From 2585b445b5fe985b08b776f9e730c06af9cf9260 Mon Sep 17 00:00:00 2001 From: thabart Date: Fri, 8 Nov 2024 10:16:25 +0100 Subject: [PATCH] Always filter on the `master` realm --- .../SimpleIdServer.IdServer.Startup/Program.cs | 1 - .../ClientRepository.cs | 6 ++++-- .../ClientRepository.cs | 5 +++-- .../Api/Realms/RealmsController.cs | 10 +++++----- .../Api/Realms/RemoveRealmCommandConsumer.cs | 2 +- .../SimpleIdServer.IdServer/Jobs/UserSessionJob.cs | 2 +- .../Stores/IClientRepository.cs | 2 +- 7 files changed, 15 insertions(+), 13 deletions(-) diff --git a/src/IdServer/SimpleIdServer.IdServer.Startup/Program.cs b/src/IdServer/SimpleIdServer.IdServer.Startup/Program.cs index 1c1679d03..7308d482e 100644 --- a/src/IdServer/SimpleIdServer.IdServer.Startup/Program.cs +++ b/src/IdServer/SimpleIdServer.IdServer.Startup/Program.cs @@ -394,7 +394,6 @@ async void SeedData(WebApplication application, string scimBaseUrl) { var isInMemory = dbContext.Database.IsInMemory(); if (!isInMemory) dbContext.Database.Migrate(); - if (dbContext.Translations.Any()) return; var masterRealm = dbContext.Realms.FirstOrDefault(r => r.Name == SimpleIdServer.IdServer.Constants.StandardRealms.Master.Name) ?? SimpleIdServer.IdServer.Constants.StandardRealms.Master; if (!dbContext.Realms.Any()) dbContext.Realms.AddRange(SimpleIdServer.IdServer.Startup.IdServerConfiguration.Realms); diff --git a/src/IdServer/SimpleIdServer.IdServer.Store.EF/ClientRepository.cs b/src/IdServer/SimpleIdServer.IdServer.Store.EF/ClientRepository.cs index 5494327e9..ceecabb59 100644 --- a/src/IdServer/SimpleIdServer.IdServer.Store.EF/ClientRepository.cs +++ b/src/IdServer/SimpleIdServer.IdServer.Store.EF/ClientRepository.cs @@ -60,13 +60,15 @@ public Task> GetByClientIds(string realm, List clientIds, C .ToListAsync(cancellationToken); } - public Task> GetByClientIdsAndExistingBackchannelLogoutUri(List clientIds, CancellationToken cancellationToken) + public Task> GetByClientIdsAndExistingBackchannelLogoutUri(string realm, List clientIds, CancellationToken cancellationToken) { return _dbContext.Clients - .Where(c => clientIds.Contains(c.ClientId) && !string.IsNullOrWhiteSpace(c.BackChannelLogoutUri)) + .Include(c => c.Realms) + .Where(c => clientIds.Contains(c.ClientId) && c.Realms.Any(r => r.Name == realm) && !string.IsNullOrWhiteSpace(c.BackChannelLogoutUri)) .ToListAsync(); } + public Task> GetByClientIdsAndExistingFrontchannelLogoutUri(string realm, List clientIds, CancellationToken cancellationToken) { return _dbContext.Clients diff --git a/src/IdServer/SimpleIdServer.IdServer.Store.SqlSugar/ClientRepository.cs b/src/IdServer/SimpleIdServer.IdServer.Store.SqlSugar/ClientRepository.cs index 601fa1646..2396adde5 100644 --- a/src/IdServer/SimpleIdServer.IdServer.Store.SqlSugar/ClientRepository.cs +++ b/src/IdServer/SimpleIdServer.IdServer.Store.SqlSugar/ClientRepository.cs @@ -117,10 +117,11 @@ public async Task> GetByClientIds(string realm, List client return result.Select(r => r.ToDomain()).ToList(); } - public async Task> GetByClientIdsAndExistingBackchannelLogoutUri(List clientIds, CancellationToken cancellationToken) + public async Task> GetByClientIdsAndExistingBackchannelLogoutUri(string realm, List clientIds, CancellationToken cancellationToken) { var result = await _dbContext.Client.Queryable() - .Where(c => clientIds.Contains(c.ClientId) && !string.IsNullOrWhiteSpace(c.BackChannelLogoutUri)) + .Includes(c => c.Realms) + .Where(c => clientIds.Contains(c.ClientId) && c.Realms.Any(r => r.RealmsName == realm) && !string.IsNullOrWhiteSpace(c.BackChannelLogoutUri)) .ToListAsync(); return result.Select(r => r.ToDomain()).ToList(); } diff --git a/src/IdServer/SimpleIdServer.IdServer/Api/Realms/RealmsController.cs b/src/IdServer/SimpleIdServer.IdServer/Api/Realms/RealmsController.cs index 0eab9b288..ef431cf4f 100644 --- a/src/IdServer/SimpleIdServer.IdServer/Api/Realms/RealmsController.cs +++ b/src/IdServer/SimpleIdServer.IdServer/Api/Realms/RealmsController.cs @@ -99,11 +99,11 @@ public async Task Add([FromRoute] string prefix, [FromBody] AddRe var realm = new Realm { Name = request.Name, Description = request.Description, CreateDateTime = DateTime.UtcNow, UpdateDateTime = DateTime.UtcNow }; var administratorRole = RealmRoleBuilder.BuildAdministrativeRole(realm); var users = await _userRepository.GetUsersBySubjects(Constants.RealmStandardUsers, Constants.DefaultRealm, cancellationToken); - var groups = await _groupRepository.GetAllByStrictFullPath(Constants.RealmStandardGroupsFullPath, cancellationToken); - var clients = await _clientRepository.GetByClientIds(Constants.RealmStandardClients, cancellationToken); - var scopes = await _scopeRepository.GetByNames(Constants.RealmStandardScopes, cancellationToken); - var keys = await _fileSerializedKeyStore.GetByKeyIds(Constants.StandardKeyIds, cancellationToken); - var acrs = await _authenticationContextClassReferenceRepository.GetByNames(Constants.StandardAcrNames, cancellationToken); + var groups = await _groupRepository.GetAllByStrictFullPath(Constants.DefaultRealm, Constants.RealmStandardGroupsFullPath, cancellationToken); + var clients = await _clientRepository.GetAll(Constants.DefaultRealm, Constants.RealmStandardClients, cancellationToken); + var scopes = await _scopeRepository.GetAll(Constants.DefaultRealm, Constants.RealmStandardScopes, cancellationToken); + var keys = await _fileSerializedKeyStore.GetAll(Constants.DefaultRealm, cancellationToken); + var acrs = await _authenticationContextClassReferenceRepository.GetAll(cancellationToken); _realmRepository.Add(realm); foreach (var user in users) { diff --git a/src/IdServer/SimpleIdServer.IdServer/Api/Realms/RemoveRealmCommandConsumer.cs b/src/IdServer/SimpleIdServer.IdServer/Api/Realms/RemoveRealmCommandConsumer.cs index 736206998..6b25bbc30 100644 --- a/src/IdServer/SimpleIdServer.IdServer/Api/Realms/RemoveRealmCommandConsumer.cs +++ b/src/IdServer/SimpleIdServer.IdServer/Api/Realms/RemoveRealmCommandConsumer.cs @@ -77,7 +77,7 @@ private async Task RevokeRealmUserSessions(string realm, CancellationToken cance .Where(s => !string.IsNullOrWhiteSpace(s)) .Distinct(); var sub = _authenticationHelper.GetLogin(activeSession.User); - var targetedClients = await _clientRepository.GetByClientIdsAndExistingBackchannelLogoutUri(clientIds.ToList(), CancellationToken.None); + var targetedClients = await _clientRepository.GetByClientIdsAndExistingBackchannelLogoutUri(realm, clientIds.ToList(), CancellationToken.None); var sessionClients = targetedClients.Where(c => activeSession.ClientIds.Contains(c.ClientId)); activeSession.State = UserSessionStates.Rejected; _userSessionRepository.Update(activeSession); diff --git a/src/IdServer/SimpleIdServer.IdServer/Jobs/UserSessionJob.cs b/src/IdServer/SimpleIdServer.IdServer/Jobs/UserSessionJob.cs index 2bfa77716..57acabad5 100644 --- a/src/IdServer/SimpleIdServer.IdServer/Jobs/UserSessionJob.cs +++ b/src/IdServer/SimpleIdServer.IdServer/Jobs/UserSessionJob.cs @@ -56,7 +56,7 @@ public async Task Execute() .Where(s => !string.IsNullOrWhiteSpace(s)) .Distinct(); - var targetedClients = await _clientRepository.GetByClientIdsAndExistingBackchannelLogoutUri(clientIds.ToList(), CancellationToken.None); + var targetedClients = await _clientRepository.GetByClientIdsAndExistingBackchannelLogoutUri(group.Key, clientIds.ToList(), CancellationToken.None); var sigCredentials = _keyStore.GetAllSigningKeys(group.Key); await Parallel.ForEachAsync(group.Select(_ => _), async (inactiveSession, c) => { diff --git a/src/IdServer/SimpleIdServer.IdServer/Stores/IClientRepository.cs b/src/IdServer/SimpleIdServer.IdServer/Stores/IClientRepository.cs index 8eb8a8cd4..a60e234b9 100644 --- a/src/IdServer/SimpleIdServer.IdServer/Stores/IClientRepository.cs +++ b/src/IdServer/SimpleIdServer.IdServer/Stores/IClientRepository.cs @@ -14,7 +14,7 @@ public interface IClientRepository Task GetByClientId(string realm, string clientId, CancellationToken cancellationToken); Task> GetByClientIds(List clientIds, CancellationToken cancellationToken); Task> GetByClientIds(string realm, List clientIds, CancellationToken cancellationToken); - Task> GetByClientIdsAndExistingBackchannelLogoutUri(List clientIds, CancellationToken cancellationToken); + Task> GetByClientIdsAndExistingBackchannelLogoutUri(string realm, List clientIds, CancellationToken cancellationToken); Task> GetByClientIdsAndExistingFrontchannelLogoutUri(string realm, List clientIds, CancellationToken cancellationToken); Task> GetAll(string realm, CancellationToken cancellationToken); Task> GetAll(string realm, List clientIds, CancellationToken cancellationToken);