Skip to content

Commit

Permalink
feat(e2ei): respect E2EI during login and client creation (WPB-5851) (#…
Browse files Browse the repository at this point in the history
…2403)


Co-authored-by: boris <[email protected]>
  • Loading branch information
mchenani and borichellow authored Jan 26, 2024
1 parent 65b08ae commit 06b6882
Show file tree
Hide file tree
Showing 45 changed files with 455 additions and 441 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ class LoginCommand : CliktCommand(name = "login") {
when (client.getOrRegister(RegisterClientUseCase.RegisterClientParam(password, emptyList()))) {
is RegisterClientResult.Failure -> throw PrintMessage("Client registration failed")
is RegisterClientResult.Success -> echo("Login successful")
is RegisterClientResult.E2EICertificateRequired -> echo("Login successful and e2ei is required")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,9 @@ class E2EIClientImpl(

fun toNewAcmeAuthz(value: com.wire.crypto.NewAcmeAuthz) = NewAcmeAuthz(
value.identifier,
value.wireOidcChallenge?.let { toAcmeChallenge(it) },
value.wireDpopChallenge?.let { toAcmeChallenge(it) },
keyAuth = value.keyauth,
wireDpopChallenge = toAcmeChallenge(value.wireDpopChallenge),
wireOidcChallenge = toAcmeChallenge(value.wireOidcChallenge)
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,8 @@ class MLSClientImpl(
value.displayName,
value.domain,
value.certificate,
toDeviceStatus(value.status)
toDeviceStatus(value.status),
value.thumbprint
)

fun toDeviceStatus(value: com.wire.crypto.DeviceStatus) = when (value) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@ data class AcmeChallenge(

data class NewAcmeAuthz(
var identifier: String,
var wireOidcChallenge: AcmeChallenge?,
var wireDpopChallenge: AcmeChallenge?
var keyAuth: String,
var wireOidcChallenge: AcmeChallenge,
var wireDpopChallenge: AcmeChallenge
)

@Suppress("TooManyFunctions")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ data class WireIdentity(
val displayName: String,
val domain: String,
val certificate: String,
val status: CryptoCertificateStatus
val status: CryptoCertificateStatus,
val thumbprint: String
)

enum class CryptoCertificateStatus {
Expand Down
2 changes: 1 addition & 1 deletion gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ pbandk = "0.14.2"
turbine = "1.0.0"
avs = "9.6.9"
jna = "5.14.0"
core-crypto = "1.0.0-rc.29"
core-crypto = "1.0.0-rc.30"
core-crypto-multiplatform = "0.6.0-rc.3-multiplatform-pre1"
completeKotlin = "1.1.0"
desugar-jdk = "2.0.4"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ interface ClientRepository {
suspend fun clearRetainedClientId(): Either<CoreFailure, Unit>
suspend fun clearHasRegisteredMLSClient(): Either<CoreFailure, Unit>
suspend fun observeCurrentClientId(): Flow<ClientId?>
suspend fun setClientRegistrationBlockedByE2EI(): Either<CoreFailure, Unit>
suspend fun clearClientRegistrationBlockedByE2EI(): Either<CoreFailure, Unit>
suspend fun observeIsClientRegistrationBlockedByE2EI(): Flow<Boolean?>
suspend fun isClientRegistrationBlockedByE2EI(): Either<CoreFailure, Boolean>
suspend fun deleteClient(param: DeleteClientParam): Either<NetworkFailure, Unit>
suspend fun selfListOfClients(): Either<NetworkFailure, List<Client>>
suspend fun observeClientsByUserIdAndClientId(userId: UserId, clientId: ClientId): Flow<Either<StorageFailure, Client>>
Expand Down Expand Up @@ -152,6 +156,24 @@ class ClientDataSource(
rawClientId?.let { ClientId(it) }
}

override suspend fun setClientRegistrationBlockedByE2EI(): Either<CoreFailure, Unit> =
wrapStorageRequest {
clientRegistrationStorage.setClientRegistrationBlockedByE2EI()
}

override suspend fun clearClientRegistrationBlockedByE2EI(): Either<CoreFailure, Unit> =
wrapStorageRequest {
clientRegistrationStorage.clearClientRegistrationBlockedByE2EI()
}

override suspend fun observeIsClientRegistrationBlockedByE2EI(): Flow<Boolean> =
clientRegistrationStorage.observeIsClientRegistrationBlockedByE2EI()

override suspend fun isClientRegistrationBlockedByE2EI(): Either<CoreFailure, Boolean> =
wrapStorageRequest {
clientRegistrationStorage.isBlockedByE2EI()
}

override suspend fun deleteClient(param: DeleteClientParam): Either<NetworkFailure, Unit> {
return clientRemoteRepository.deleteClient(param).onSuccess {
wrapStorageRequest { clientDAO.deleteClient(selfUserID.toDao(), param.clientId.value) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ import com.wire.kalium.cryptography.E2EIClient
import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.E2EIFailure
import com.wire.kalium.logic.data.conversation.ClientId
import com.wire.kalium.logic.data.user.UserRepository
import com.wire.kalium.logic.data.id.CurrentClientIdProvider
import com.wire.kalium.logic.data.user.SelfUser
import com.wire.kalium.logic.data.user.UserRepository
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.functional.flatMap
import com.wire.kalium.logic.functional.fold
Expand All @@ -34,7 +34,7 @@ import com.wire.kalium.util.KaliumDispatcherImpl
import kotlinx.coroutines.withContext

interface E2EIClientProvider {
suspend fun getE2EIClient(clientId: ClientId? = null): Either<CoreFailure, E2EIClient>
suspend fun getE2EIClient(clientId: ClientId? = null, isNewClient: Boolean = false): Either<CoreFailure, E2EIClient>
suspend fun nuke()
}

Expand All @@ -47,7 +47,7 @@ internal class EI2EIClientProviderImpl(

private var e2EIClient: E2EIClient? = null

override suspend fun getE2EIClient(clientId: ClientId?): Either<CoreFailure, E2EIClient> =
override suspend fun getE2EIClient(clientId: ClientId?, isNewClient: Boolean): Either<CoreFailure, E2EIClient> =
withContext(dispatchers.io) {
val currentClientId =
clientId ?: currentClientIdProvider().fold({ return@withContext Either.Left(it) }, { it })
Expand All @@ -56,6 +56,7 @@ internal class EI2EIClientProviderImpl(
Either.Right(it)
} ?: run {
getSelfUserInfo().flatMap { selfUser ->
// TODO: use e2eiNewEnrollment for new clients, when CC fix the issues in it
mlsClientProvider.getMLSClient(currentClientId).flatMap {
val newE2EIClient = if (it.isE2EIEnabled()) {
kaliumLogger.e("initial E2EI client for mls client that already has e2ei enabled")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ class MLSClientProviderImpl(
override suspend fun getMLSClient(clientId: ClientId?): Either<CoreFailure, MLSClient> = withContext(dispatchers.io) {
val currentClientId = clientId ?: currentClientIdProvider().fold({ return@withContext Either.Left(it) }, { it })
val cryptoUserId = CryptoUserID(value = userId.value, domain = userId.domain)

return@withContext mlsClient?.let {
Either.Right(it)
} ?: run {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import com.wire.kalium.logic.data.id.toApi
import com.wire.kalium.logic.data.id.toCrypto
import com.wire.kalium.logic.data.id.toDao
import com.wire.kalium.logic.data.id.toModel
import com.wire.kalium.logic.data.keypackage.KeyPackageLimitsProvider
import com.wire.kalium.logic.data.keypackage.KeyPackageRepository
import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeysMapper
import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeysRepository
Expand Down Expand Up @@ -116,9 +117,9 @@ interface MLSConversationRepository {
suspend fun rotateKeysAndMigrateConversations(
clientId: ClientId,
e2eiClient: E2EIClient,
certificateChain: String
certificateChain: String,
isNewClient: Boolean = false
): Either<CoreFailure, Unit>

suspend fun getClientIdentity(clientId: ClientId): Either<CoreFailure, WireIdentity>
suspend fun getUserIdentity(userId: UserId): Either<CoreFailure, List<WireIdentity>>
suspend fun getMembersIdentities(
Expand Down Expand Up @@ -171,6 +172,7 @@ internal class MLSConversationDataSource(
private val commitBundleEventReceiver: CommitBundleEventReceiver,
private val epochsFlow: MutableSharedFlow<GroupID>,
private val proposalTimersFlow: MutableSharedFlow<ProposalTimer>,
private val keyPackageLimitsProvider: KeyPackageLimitsProvider,
private val idMapper: IdMapper = MapperProvider.idMapper(),
private val conversationMapper: ConversationMapper = MapperProvider.conversationMapper(selfUserId),
private val mlsPublicKeysMapper: MLSPublicKeysMapper = MapperProvider.mlsPublicKeyMapper(),
Expand Down Expand Up @@ -528,18 +530,21 @@ internal class MLSConversationDataSource(
override suspend fun rotateKeysAndMigrateConversations(
clientId: ClientId,
e2eiClient: E2EIClient,
certificateChain: String
) = mlsClientProvider.getMLSClient().flatMap { mlsClient ->
certificateChain: String,
isNewClient: Boolean
) = mlsClientProvider.getMLSClient(clientId).flatMap { mlsClient ->
wrapMLSRequest {
mlsClient.e2eiRotateAll(e2eiClient, certificateChain, 10U)
mlsClient.e2eiRotateAll(e2eiClient, certificateChain, keyPackageLimitsProvider.refillAmount().toUInt())
}.map { rotateBundle ->
// todo: store keypackages to drop, later drop them again
kaliumLogger.w("upload new keypackages and drop old ones")
keyPackageRepository.replaceKeyPackages(clientId, rotateBundle.newKeyPackages).flatMapLeft {
return Either.Left(it)
if (!isNewClient) {
kaliumLogger.w("enrollment for existing client: upload new keypackages and drop old ones")
keyPackageRepository.replaceKeyPackages(clientId, rotateBundle.newKeyPackages).flatMapLeft {
return Either.Left(it)
}
}

kaliumLogger.w("send migration commits after key rotations")
kaliumLogger.w("rotate bundles: ${rotateBundle.commits.size}")
rotateBundle.commits.map {
sendCommitBundle(GroupID(it.key), it.value)
}.foldToEitherWhileRight(Unit) { value, _ -> value }.fold({ return Either.Left(it) }, { })
Expand Down Expand Up @@ -575,7 +580,7 @@ internal class MLSConversationDataSource(
userIds: List<UserId>
): Either<CoreFailure, Map<UserId, List<WireIdentity>>> =
wrapStorageRequest {
conversationDAO.getMLSGroupIdByConversationId(conversationId.toDao())!!
conversationDAO.getMLSGroupIdByConversationId(conversationId.toDao())
}.flatMap { mlsGroupId ->
mlsClientProvider.getMLSClient().flatMap { mlsClient ->
wrapMLSRequest {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,14 @@ import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.configuration.UserConfigRepository
import com.wire.kalium.logic.data.client.E2EIClientProvider
import com.wire.kalium.logic.data.client.MLSClientProvider
import com.wire.kalium.logic.data.conversation.ClientId
import com.wire.kalium.logic.data.conversation.MLSConversationRepository
import com.wire.kalium.logic.data.id.CurrentClientIdProvider
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.functional.flatMap
import com.wire.kalium.logic.functional.fold
import com.wire.kalium.logic.functional.map
import com.wire.kalium.logic.kaliumLogger
import com.wire.kalium.logic.wrapApiRequest
import com.wire.kalium.logic.wrapE2EIRequest
import com.wire.kalium.logic.wrapMLSRequest
Expand All @@ -46,14 +49,15 @@ import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.Json

interface E2EIRepository {
suspend fun initE2EIClient(clientId: ClientId? = null, isNewClient: Boolean = false): Either<CoreFailure, Unit>
suspend fun fetchTrustAnchors(): Either<CoreFailure, Unit>
suspend fun loadACMEDirectories(): Either<CoreFailure, AcmeDirectory>
suspend fun getACMENonce(endpoint: String): Either<CoreFailure, String>
suspend fun createNewAccount(prevNonce: String, createAccountEndpoint: String): Either<CoreFailure, String>
suspend fun createNewOrder(prevNonce: String, createOrderEndpoint: String): Either<CoreFailure, Triple<NewAcmeOrder, String, String>>
suspend fun createAuthz(prevNonce: String, authzEndpoint: String): Either<CoreFailure, Triple<NewAcmeAuthz, String, String>>
suspend fun getWireNonce(): Either<CoreFailure, String>
suspend fun getWireAccessToken(wireNonce: String): Either<CoreFailure, AccessTokenResponse>
suspend fun getWireAccessToken(dpopToken: String): Either<CoreFailure, AccessTokenResponse>
suspend fun getDPoPToken(wireNonce: String): Either<CoreFailure, String>
suspend fun validateDPoPChallenge(
accessToken: String,
Expand All @@ -73,7 +77,7 @@ interface E2EIRepository {
suspend fun finalize(location: String, prevNonce: String): Either<CoreFailure, Pair<ACMEResponse, String>>
suspend fun checkOrderRequest(location: String, prevNonce: String): Either<CoreFailure, Pair<ACMEResponse, String>>
suspend fun certificateRequest(location: String, prevNonce: String): Either<CoreFailure, ACMEResponse>
suspend fun rotateKeysAndMigrateConversations(certificateChain: String): Either<CoreFailure, Unit>
suspend fun rotateKeysAndMigrateConversations(certificateChain: String, isNewClient: Boolean = false): Either<CoreFailure, Unit>
suspend fun getOAuthRefreshToken(): Either<CoreFailure, String?>
suspend fun nukeE2EIClient()
suspend fun fetchFederationCertificates(): Either<CoreFailure, Unit>
Expand All @@ -92,13 +96,22 @@ class E2EIRepositoryImpl(
private val userConfigRepository: UserConfigRepository
) : E2EIRepository {

override suspend fun initE2EIClient(clientId: ClientId?, isNewClient: Boolean): Either<CoreFailure, Unit> =
e2EIClientProvider.getE2EIClient(clientId, isNewClient).fold({
kaliumLogger.w("E2EI client initialization failed: $it")
Either.Left(it)
}, {
kaliumLogger.w("E2EI client initialized for enrollment")
Either.Right(Unit)
})

override suspend fun fetchTrustAnchors(): Either<CoreFailure, Unit> = userConfigRepository.getE2EISettings().flatMap {
wrapApiRequest {
acmeApi.getTrustAnchors(Url(it.discoverUrl).protocolWithAuthority)
}.flatMap { trustAnchors ->
mlsClientProvider.getMLSClient().flatMap { mlsClient ->
wrapE2EIRequest {
mlsClient.registerTrustAnchors(trustAnchors.value)
mlsClient.registerTrustAnchors(trustAnchors.decodeToString())
}
}
}
Expand Down Expand Up @@ -235,10 +248,10 @@ class E2EIRepositoryImpl(
}.map { it }
}

override suspend fun rotateKeysAndMigrateConversations(certificateChain: String) =
override suspend fun rotateKeysAndMigrateConversations(certificateChain: String, isNewClient: Boolean) =
e2EIClientProvider.getE2EIClient().flatMap { e2eiClient ->
currentClientIdProvider().flatMap { clientId ->
mlsConversationRepository.rotateKeysAndMigrateConversations(clientId, e2eiClient, certificateChain)
mlsConversationRepository.rotateKeysAndMigrateConversations(clientId, e2eiClient, certificateChain, isNewClient)
}
}

Expand All @@ -248,11 +261,11 @@ class E2EIRepositoryImpl(

override suspend fun fetchFederationCertificates(): Either<CoreFailure, Unit> = userConfigRepository.getE2EISettings().flatMap {
wrapApiRequest {
acmeApi.getACMEFederation(Url(it.discoverUrl).host)
acmeApi.getACMEFederation(Url(it.discoverUrl).protocolWithAuthority)
}.flatMap { data ->
mlsClientProvider.getMLSClient().flatMap { mlsClient ->
wrapMLSRequest {
mlsClient.registerIntermediateCa(data.value)
mlsClient.registerIntermediateCa(data)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ import com.wire.kalium.logic.feature.client.IsAllowedToRegisterMLSClientUseCase
import com.wire.kalium.logic.feature.client.IsAllowedToRegisterMLSClientUseCaseImpl
import com.wire.kalium.logic.feature.client.MLSClientManager
import com.wire.kalium.logic.feature.client.MLSClientManagerImpl
import com.wire.kalium.logic.feature.client.RegisterMLSClientUseCase
import com.wire.kalium.logic.feature.client.RegisterMLSClientUseCaseImpl
import com.wire.kalium.logic.feature.connection.ConnectionScope
import com.wire.kalium.logic.feature.connection.SyncConnectionsUseCase
Expand Down Expand Up @@ -435,6 +436,7 @@ import com.wire.kalium.util.DelicateKaliumApi
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.cancel
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.firstOrNull
Expand Down Expand Up @@ -498,7 +500,7 @@ class UserSessionScope internal constructor(
userId, qualifiedIdMapper, globalScope.sessionRepository
)

private val clientIdProvider = CurrentClientIdProvider { clientId() }
val clientIdProvider = CurrentClientIdProvider { clientId() }
private val mlsSelfConversationIdProvider: MLSSelfConversationIdProvider by lazy {
MLSSelfConversationIdProviderImpl(
conversationRepository
Expand Down Expand Up @@ -641,7 +643,8 @@ class UserSessionScope internal constructor(
mlsPublicKeysRepository,
commitBundleEventReceiver,
epochsFlow,
proposalTimersFlow
proposalTimersFlow,
keyPackageLimitsProvider
)

private val e2eiRepository: E2EIRepository
Expand Down Expand Up @@ -911,6 +914,15 @@ class UserSessionScope internal constructor(
mlsConversationRepository
)

private val registerMLSClientUseCase: RegisterMLSClientUseCase
get() = RegisterMLSClientUseCaseImpl(
mlsClientProvider,
clientRepository,
keyPackageRepository,
keyPackageLimitsProvider,
userConfigRepository
)

private val recoverMLSConversationsUseCase: RecoverMLSConversationsUseCase
get() = RecoverMLSConversationsUseCaseImpl(
featureSupport,
Expand Down Expand Up @@ -1102,14 +1114,14 @@ class UserSessionScope internal constructor(
lazy { conversations.updateMLSGroupsKeyingMaterials },
lazy { users.timestampKeyRepository })

internal val mlsClientManager: MLSClientManager = MLSClientManagerImpl(clientIdProvider,
val mlsClientManager: MLSClientManager = MLSClientManagerImpl(clientIdProvider,
isAllowedToRegisterMLSClient,
incrementalSyncRepository,
lazy { slowSyncRepository },
lazy { clientRepository },
lazy {
RegisterMLSClientUseCaseImpl(
mlsClientProvider, clientRepository, keyPackageRepository, keyPackageLimitsProvider
mlsClientProvider, clientRepository, keyPackageRepository, keyPackageLimitsProvider, userConfigRepository
)
})

Expand Down Expand Up @@ -1368,6 +1380,8 @@ class UserSessionScope internal constructor(
val observeLegalHoldStateForUser: ObserveLegalHoldStateForUserUseCase
get() = ObserveLegalHoldStateForUserUseCaseImpl(clientRepository)

suspend fun observeIfE2EIRequiredDuringLogin(): Flow<Boolean?> = clientRepository.observeIsClientRegistrationBlockedByE2EI()

val observeLegalHoldForSelfUser: ObserveLegalHoldForSelfUserUseCase
get() = ObserveLegalHoldForSelfUserUseCaseImpl(userId, observeLegalHoldStateForUser)

Expand Down Expand Up @@ -1610,7 +1624,9 @@ class UserSessionScope internal constructor(
authenticationScope.secondFactorVerificationRepository,
slowSyncRepository,
cachedClientIdClearer,
updateSupportedProtocolsAndResolveOneOnOnes
updateSupportedProtocolsAndResolveOneOnOnes,
registerMLSClientUseCase,
syncFeatureConfigsUseCase
)
val conversations: ConversationScope by lazy {
ConversationScope(
Expand Down Expand Up @@ -1716,7 +1732,9 @@ class UserSessionScope internal constructor(
e2eiRepository,
mlsConversationRepository,
team.isSelfATeamMember,
updateSupportedProtocols
updateSupportedProtocols,
clientRepository,
joinExistingMLSConversations
)

val search: SearchScope
Expand Down
Loading

0 comments on commit 06b6882

Please sign in to comment.