Skip to content

Commit

Permalink
fix(e2ei): get user current user clients identities
Browse files Browse the repository at this point in the history
  • Loading branch information
mchenani committed Jan 31, 2024
1 parent 70f23f1 commit e17bf1d
Show file tree
Hide file tree
Showing 10 changed files with 190 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ class MLSClientImpl(
val clientIds = clients.map {
it.toString().encodeToByteArray()
}
return coreCrypto.getDeviceIdentities(groupId.decodeBase64Bytes(), clientIds).map {
return coreCrypto.getDeviceIdentities(groupId.decodeBase64Bytes(), clientIds).mapNotNull {
toIdentity(it)
}
}
Expand All @@ -290,7 +290,7 @@ class MLSClientImpl(
it.value
}
return coreCrypto.getUserIdentities(groupId.decodeBase64Bytes(), usersIds).mapValues {
it.value.map { identity -> toIdentity(identity) }
it.value.mapNotNull { identity -> toIdentity(identity) }
}
}

Expand Down Expand Up @@ -350,15 +350,20 @@ class MLSClientImpl(
value.crlNewDistributionPoints
)

fun toIdentity(value: com.wire.crypto.WireIdentity) = WireIdentity(
value.clientId,
value.handle,
value.displayName,
value.domain,
value.certificate,
toDeviceStatus(value.status),
value.thumbprint
)
fun toIdentity(value: com.wire.crypto.WireIdentity): WireIdentity? {
val clientId = CryptoQualifiedClientId.fromEncodedString(value.clientId)
return clientId?.let {
WireIdentity(
CryptoQualifiedClientId.fromEncodedString(value.clientId)!!,
value.handle,
value.displayName,
value.domain,
value.certificate,
toDeviceStatus(value.status),
value.thumbprint
)
}
}

fun toDeviceStatus(value: com.wire.crypto.DeviceStatus) = when (value) {
com.wire.crypto.DeviceStatus.VALID -> CryptoCertificateStatus.VALID
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ data class CryptoQualifiedClientId(
}

data class WireIdentity(
val clientId: String,
val clientId: CryptoQualifiedClientId,
val handle: String,
val displayName: String,
val domain: String,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@ fun com.wire.kalium.cryptography.DecryptedMessageBundle.toModel(groupID: GroupID
identity.clientId,
identity.handle,
identity.displayName,
identity.domain
identity.domain,
identity.certificate,
identity.status,
identity.thumbprint
)
}
)
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
package com.wire.kalium.logic.data.conversation

import com.wire.kalium.cryptography.CommitBundle
import com.wire.kalium.cryptography.CryptoCertificateStatus
import com.wire.kalium.cryptography.CryptoQualifiedClientId
import com.wire.kalium.cryptography.E2EIClient
import com.wire.kalium.cryptography.WireIdentity
Expand Down Expand Up @@ -96,7 +97,15 @@ data class DecryptedMessageBundle(
val identity: E2EIdentity?
)

data class E2EIdentity(val clientId: String, val handle: String, val displayName: String, val domain: String)
data class E2EIdentity(
val clientId: CryptoQualifiedClientId,
val handle: String,
val displayName: String,
val domain: String,
val certificate: String,
val status: CryptoCertificateStatus,
val thumbprint: String
)

@Suppress("TooManyFunctions", "LongParameterList")
interface MLSConversationRepository {
Expand Down Expand Up @@ -575,19 +584,26 @@ internal class MLSConversationDataSource(
mlsClient.getDeviceIdentities(
it.mlsGroupId,
listOf(CryptoQualifiedClientId(it.clientId, it.userId.toModel().toCrypto()))
).first() // todo: ask if it's possible that's a client has more than one identity?
).first()
}
}
}

override suspend fun getUserIdentity(userId: UserId) =
wrapStorageRequest { conversationDAO.getMLSGroupIdByUserId(userId.toDao()) }.flatMap { mlsGroupId ->
wrapStorageRequest {
if (userId == selfUserId) {
val selfConversationId = conversationDAO.getSelfConversationId(ConversationEntity.Protocol.MLS)
conversationDAO.getMLSGroupIdByConversationId(selfConversationId!!)
} else {
conversationDAO.getMLSGroupIdByUserId(userId.toDao())
}
}.flatMap { mlsGroupId ->
mlsClientProvider.getMLSClient().flatMap { mlsClient ->
wrapMLSRequest {
mlsClient.getUserIdentities(
mlsGroupId,
listOf(userId.toCrypto())
)[userId.value]!!
)[userId.value] ?: emptyList()
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,8 @@ class GetUserE2eiCertificatesUseCaseImpl internal constructor(
mlsConversationRepository.getUserIdentity(userId).map { identities ->
val result = mutableMapOf<String, E2eiCertificate>()
identities.forEach {
result[it.clientId] = pemCertificateDecoder.decode(it.certificate, it.status)
result[it.clientId.value] = pemCertificateDecoder.decode(it.certificate, it.status)
}
result
}.getOrElse(mapOf())

}
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ package com.wire.kalium.logic.data.conversation
import com.benasher44.uuid.uuid4
import com.wire.kalium.cryptography.CommitBundle
import com.wire.kalium.cryptography.CryptoCertificateStatus
import com.wire.kalium.cryptography.CryptoQualifiedClientId
import com.wire.kalium.cryptography.E2EIClient
import com.wire.kalium.cryptography.E2EIConversationState
import com.wire.kalium.cryptography.GroupInfoBundle
import com.wire.kalium.cryptography.GroupInfoEncryptionType
import com.wire.kalium.cryptography.MLSClient
import com.wire.kalium.cryptography.MLSGroupId
import com.wire.kalium.cryptography.RatchetTreeType
import com.wire.kalium.cryptography.RotateBundle
import com.wire.kalium.cryptography.WelcomeBundle
Expand All @@ -35,6 +35,7 @@ import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.StorageFailure
import com.wire.kalium.logic.data.client.MLSClientProvider
import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.COMMIT_BUNDLE
import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.CRYPTO_CLIENT_ID
import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.E2EI_CONVERSATION_CLIENT_INFO_ENTITY
import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.ROTATE_BUNDLE
import com.wire.kalium.logic.data.conversation.MLSConversationRepositoryTest.Arrangement.Companion.TEST_FAILURE
Expand All @@ -43,6 +44,8 @@ import com.wire.kalium.logic.data.e2ei.CertificateRevocationListRepository
import com.wire.kalium.logic.data.event.Event
import com.wire.kalium.logic.data.id.GroupID
import com.wire.kalium.logic.data.id.QualifiedClientID
import com.wire.kalium.logic.data.id.toCrypto
import com.wire.kalium.logic.data.id.toDao
import com.wire.kalium.logic.data.keypackage.KeyPackageLimitsProvider
import com.wire.kalium.logic.data.keypackage.KeyPackageRepository
import com.wire.kalium.logic.data.mlspublickeys.Ed25519Key
Expand Down Expand Up @@ -72,6 +75,7 @@ import com.wire.kalium.network.api.base.authenticated.notification.MemberLeaveRe
import com.wire.kalium.network.api.base.model.ErrorResponse
import com.wire.kalium.network.exceptions.KaliumException
import com.wire.kalium.network.utils.NetworkResponse
import com.wire.kalium.persistence.dao.QualifiedIDEntity
import com.wire.kalium.persistence.dao.UserIDEntity
import com.wire.kalium.persistence.dao.conversation.ConversationDAO
import com.wire.kalium.persistence.dao.conversation.ConversationEntity
Expand Down Expand Up @@ -1336,21 +1340,62 @@ class MLSConversationRepositoryTest {
}

@Test
fun givenUserId_whenGetMLSGroupIdByUserIdSucceed_thenReturnsIdentities() = runTest {
val groupId = "some_group"
fun givenSelfUserId_whenGetMLSGroupIdByUserIdSucceed_thenReturnsIdentities() = runTest {
val groupId = TestConversation.MLS_PROTOCOL_INFO.groupId.value
val (arrangement, mlsConversationRepository) = Arrangement()
.withGetMLSClientSuccessful()
.withGetSelfConversationIdReturns(TestConversation.MLS_CONVERSATION.id.toDao())
.withGetMLSGroupIdByConversationIdReturns(groupId)
.withGetUserIdentitiesReturn(
mapOf(
TestUser.USER_ID.value to listOf(WIRE_IDENTITY),
"some_other_user_id" to listOf(WIRE_IDENTITY.copy(clientId = "another_client_id")),
"some_other_user_id" to listOf(WIRE_IDENTITY.copy(clientId = CRYPTO_CLIENT_ID.copy("another_client_id"))),
)
)
.withGetMLSGroupIdByUserIdReturns(groupId)
.arrange()

assertEquals(Either.Right(listOf(WIRE_IDENTITY)), mlsConversationRepository.getUserIdentity(TestUser.USER_ID))

verify(arrangement.mlsClient)
.suspendFunction(arrangement.mlsClient::getUserIdentities)
.with(eq(groupId), any())
.wasInvoked(once)

verify(arrangement.conversationDAO)
.suspendFunction(arrangement.conversationDAO::getMLSGroupIdByUserId)
.with(any())
.wasNotInvoked()

verify(arrangement.conversationDAO)
.suspendFunction(arrangement.conversationDAO::getSelfConversationId)
.with(eq(ConversationEntity.Protocol.MLS))
.wasInvoked(once)

verify(arrangement.conversationDAO)
.suspendFunction(arrangement.conversationDAO::getMLSGroupIdByConversationId)
.with(eq(TestConversation.MLS_CONVERSATION.id.toDao()))
.wasInvoked(once)
}

@Test
fun givenOtherUserId_whenGetMLSGroupIdByUserIdSucceed_thenReturnsIdentities() = runTest {
val groupId = TestConversation.MLS_PROTOCOL_INFO.groupId.value
val (arrangement, mlsConversationRepository) = Arrangement()
.withGetMLSClientSuccessful()
.withGetSelfConversationIdReturns(TestConversation.MLS_CONVERSATION.id.toDao())
.withGetMLSGroupIdByConversationIdReturns(groupId)
.withGetUserIdentitiesReturn(
mapOf(
TestUser.OTHER_USER_ID.value to listOf(WIRE_IDENTITY),
"some_other_user_id" to listOf(WIRE_IDENTITY.copy(clientId = CRYPTO_CLIENT_ID.copy("another_client_id"))),
)
)
.withGetMLSGroupIdByUserIdReturns(groupId)
.arrange()

assertEquals(Either.Right(listOf(WIRE_IDENTITY)), mlsConversationRepository.getUserIdentity(TestUser.OTHER_USER_ID))

verify(arrangement.mlsClient)
.suspendFunction(arrangement.mlsClient::getUserIdentities)
.with(eq(groupId), any())
Expand All @@ -1360,6 +1405,16 @@ class MLSConversationRepositoryTest {
.suspendFunction(arrangement.conversationDAO::getMLSGroupIdByUserId)
.with(any())
.wasInvoked(once)

verify(arrangement.conversationDAO)
.suspendFunction(arrangement.conversationDAO::getSelfConversationId)
.with(eq(ConversationEntity.Protocol.MLS))
.wasNotInvoked()

verify(arrangement.conversationDAO)
.suspendFunction(arrangement.conversationDAO::getMLSGroupIdByConversationId)
.with(eq(TestConversation.MLS_CONVERSATION.id.toDao()))
.wasNotInvoked()
}

@Test
Expand All @@ -1373,7 +1428,7 @@ class MLSConversationRepositoryTest {
.withGetUserIdentitiesReturn(
mapOf(
member1.value to listOf(WIRE_IDENTITY),
member2.value to listOf(WIRE_IDENTITY.copy(clientId = "member_2_client_id"))
member2.value to listOf(WIRE_IDENTITY.copy(clientId = CRYPTO_CLIENT_ID.copy("member_2_client_id")))
)
)
.withGetMLSGroupIdByConversationIdReturns(groupId)
Expand All @@ -1383,7 +1438,7 @@ class MLSConversationRepositoryTest {
Either.Right(
mapOf(
member1 to listOf(WIRE_IDENTITY),
member2 to listOf(WIRE_IDENTITY.copy(clientId = "member_2_client_id"))
member2 to listOf(WIRE_IDENTITY.copy(clientId = CRYPTO_CLIENT_ID.copy("member_2_client_id")))
)
),
mlsConversationRepository.getMembersIdentities(TestConversation.ID, listOf(member1, member2, member3))
Expand Down Expand Up @@ -1554,6 +1609,13 @@ class MLSConversationRepositoryTest {
.thenReturn(e2eiInfo)
}

fun withGetSelfConversationIdReturns(id: QualifiedIDEntity?) = apply {
given(conversationDAO)
.suspendFunction(conversationDAO::getSelfConversationId)
.whenInvokedWith(anything())
.thenReturn(id)
}

fun withAddMLSMemberSuccessful(commitBundle: CommitBundle = COMMIT_BUNDLE) = apply {
given(mlsClient)
.suspendFunction(mlsClient::addMember)
Expand Down Expand Up @@ -1737,8 +1799,17 @@ class MLSConversationRepositoryTest {
)
val COMMIT_BUNDLE = CommitBundle(COMMIT, WELCOME, PUBLIC_GROUP_STATE_BUNDLE, null)
val ROTATE_BUNDLE = RotateBundle(mapOf(RAW_GROUP_ID to COMMIT_BUNDLE), emptyList(), emptyList(), null)
val CRYPTO_CLIENT_ID = CryptoQualifiedClientId("clientId", TestConversation.USER_1.toCrypto())
val WIRE_IDENTITY =
WireIdentity("id", "user_handle", "User Test", "domain.com", "certificate", CryptoCertificateStatus.VALID, thumbprint = "thumbprint")
WireIdentity(
CRYPTO_CLIENT_ID,
"user_handle",
"User Test",
"domain.com",
"certificate",
CryptoCertificateStatus.VALID,
thumbprint = "thumbprint"
)
val E2EI_CONVERSATION_CLIENT_INFO_ENTITY =
E2EIConversationClientInfoEntity(UserIDEntity(uuid4().toString(), "domain.com"), "clientId", "groupId")
val DECRYPTED_MESSAGE_BUNDLE = com.wire.kalium.cryptography.DecryptedMessageBundle(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package com.wire.kalium.logic.feature.e2ei

import com.wire.kalium.cryptography.CryptoCertificateStatus
import com.wire.kalium.cryptography.CryptoQualifiedClientId
import com.wire.kalium.cryptography.WireIdentity
import com.wire.kalium.logic.E2EIFailure
import com.wire.kalium.logic.feature.e2ei.usecase.GetE2eiCertificateUseCaseImpl
Expand All @@ -33,6 +34,8 @@ import kotlin.test.Test
import kotlin.test.assertEquals
import com.wire.kalium.logic.data.conversation.ClientId
import com.wire.kalium.logic.data.conversation.MLSConversationRepository
import com.wire.kalium.logic.data.id.toCrypto
import com.wire.kalium.logic.data.user.UserId
import com.wire.kalium.logic.feature.e2ei.usecase.GetE2EICertificateUseCaseResult
import kotlinx.coroutines.test.runTest

Expand Down Expand Up @@ -113,9 +116,12 @@ class GetE2eiCertificateUseCaseTest {

companion object {
val CLIENT_ID = ClientId("client-id")
private val USER_ID = UserId("value", "domain")
private val CRYPTO_QUALIFIED_CLIENT_ID = CryptoQualifiedClientId("clientId", USER_ID.toCrypto())

val e2eiCertificate = E2eiCertificate("certificate")
val identity = WireIdentity(
CLIENT_ID.value,
CRYPTO_QUALIFIED_CLIENT_ID,
handle = "alic_test",
displayName = "Alice Test",
domain = "test.com",
Expand Down
Loading

0 comments on commit e17bf1d

Please sign in to comment.