Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Check revocation list on decrypt message (WPB-3243) #2413

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -388,15 +388,17 @@ class MLSClientImpl(
value.commitDelay?.toLong(),
value.senderClientId?.let { CryptoQualifiedClientId.fromEncodedString(String(it)) },
value.hasEpochChanged,
value.identity?.let { toIdentity(it) }
value.identity?.let { toIdentity(it) },
value.crlNewDistributionPoints
)

fun toDecryptedMessageBundle(value: BufferedDecryptedMessage) = DecryptedMessageBundle(
value.message,
value.commitDelay?.toLong(),
value.senderClientId?.let { CryptoQualifiedClientId.fromEncodedString(String(it)) },
value.hasEpochChanged,
value.identity?.let { toIdentity(it) }
value.identity?.let { toIdentity(it) },
value.crlNewDistributionPoints
)

fun toCredentialType(value: CredentialType) = when (value) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ class DecryptedMessageBundle(
val commitDelay: Long?,
val senderClientId: CryptoQualifiedClientId?,
val hasEpochChanged: Boolean,
val identity: WireIdentity?
val identity: WireIdentity?,
val crlNewDistributionPoints: List<String>?
)

@JvmInline
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ import com.wire.kalium.logic.data.id.toModel

fun com.wire.kalium.cryptography.DecryptedMessageBundle.toModel(groupID: GroupID): DecryptedMessageBundle =
DecryptedMessageBundle(
groupID,
message?.let { message ->
groupID = groupID,
applicationMessage = message?.let { message ->
// We will always have senderClientId together with an application message
// but CoreCrypto API doesn't express this
ApplicationMessage(
Expand All @@ -32,13 +32,14 @@ fun com.wire.kalium.cryptography.DecryptedMessageBundle.toModel(groupID: GroupID
senderClientID = senderClientId!!.toModel().clientId
)
},
commitDelay,
identity?.let { identity ->
commitDelay = commitDelay,
identity = identity?.let { identity ->
E2EIdentity(
identity.clientId,
identity.handle,
identity.displayName,
identity.domain
)
}
},
crlNewDistributionPoints = crlNewDistributionPoints
)
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ data class DecryptedMessageBundle(
val groupID: GroupID,
val applicationMessage: ApplicationMessage?,
val commitDelay: Long?,
val identity: E2EIdentity?
val identity: E2EIdentity?,
val crlNewDistributionPoints: List<String>?
)

data class E2EIdentity(val clientId: String, val handle: String, val displayName: String, val domain: String)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1210,6 +1210,8 @@ class UserSessionScope internal constructor(
subconversationRepository = subconversationRepository,
mlsConversationRepository = mlsConversationRepository,
pendingProposalScheduler = pendingProposalScheduler,
checkRevocationList = checkRevocationList,
certificateRevocationListRepository = certificateRevocationListRepository,
selfUserId = userId
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import com.wire.kalium.logic.data.conversation.ConversationRepository
import com.wire.kalium.logic.data.conversation.DecryptedMessageBundle
import com.wire.kalium.logic.data.conversation.MLSConversationRepository
import com.wire.kalium.logic.data.conversation.SubconversationRepository
import com.wire.kalium.logic.data.e2ei.CertificateRevocationListRepository
import com.wire.kalium.logic.data.event.Event
import com.wire.kalium.logic.data.id.ConversationId
import com.wire.kalium.logic.data.id.GroupID
Expand All @@ -34,6 +35,7 @@ import com.wire.kalium.logic.data.message.ProtoContent
import com.wire.kalium.logic.data.message.ProtoContentMapper
import com.wire.kalium.logic.data.user.UserId
import com.wire.kalium.logic.di.MapperProvider
import com.wire.kalium.logic.feature.e2ei.usecase.CheckRevocationListUseCase
import com.wire.kalium.logic.feature.message.PendingProposalScheduler
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.functional.flatMap
Expand All @@ -58,6 +60,8 @@ internal class MLSMessageUnpackerImpl(
private val mlsConversationRepository: MLSConversationRepository,
private val pendingProposalScheduler: PendingProposalScheduler,
private val selfUserId: UserId,
private val checkRevocationList: CheckRevocationListUseCase,
private val certificateRevocationListRepository: CertificateRevocationListRepository,
private val protoContentMapper: ProtoContentMapper = MapperProvider.protoContentMapper(selfUserId = selfUserId),
) : MLSMessageUnpacker {

Expand All @@ -68,10 +72,21 @@ internal class MLSMessageUnpackerImpl(
if (bundles.isEmpty()) return@map listOf(MessageUnpackResult.HandshakeMessage)

bundles.map { bundle ->
checkRevocationList(bundle)
unpackMlsBundle(bundle, event.conversationId, event.timestampIso.toInstant())
}
}

private suspend fun checkRevocationList(bundle: DecryptedMessageBundle) {
bundle.crlNewDistributionPoints?.forEach { url ->
checkRevocationList(url).map { newExpiration ->
newExpiration?.let {
certificateRevocationListRepository.addOrUpdateCRL(url, newExpiration)
}
}
}
}

override suspend fun unpackMlsBundle(
bundle: DecryptedMessageBundle,
conversationId: ConversationId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1663,7 +1663,8 @@ class MLSConversationRepositoryTest {
commitDelay = null,
senderClientId = null,
hasEpochChanged = true,
identity = null
identity = null,
crlNewDistributionPoints = null
)
val MEMBER_JOIN_EVENT = EventContentDTO.Conversation.MemberJoinDTO(
TestConversation.NETWORK_ID,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ import com.wire.kalium.logic.data.conversation.ConversationRepository
import com.wire.kalium.logic.data.conversation.DecryptedMessageBundle
import com.wire.kalium.logic.data.conversation.MLSConversationRepository
import com.wire.kalium.logic.data.conversation.SubconversationRepository
import com.wire.kalium.logic.data.e2ei.CertificateRevocationListRepository
import com.wire.kalium.logic.data.user.UserId
import com.wire.kalium.logic.feature.e2ei.usecase.CheckRevocationListUseCase
import com.wire.kalium.logic.feature.message.PendingProposalScheduler
import com.wire.kalium.logic.framework.TestConversation
import com.wire.kalium.logic.framework.TestEvent
Expand All @@ -45,6 +47,7 @@ import io.mockative.given
import io.mockative.matching
import io.mockative.mock
import io.mockative.once
import io.mockative.twice
import io.mockative.verify
import kotlinx.coroutines.test.runTest
import kotlin.test.Test
Expand Down Expand Up @@ -138,12 +141,50 @@ class MLSMessageUnpackerTest {
val messageEvent = TestEvent.newMLSMessageEvent(eventTimestamp)
mlsUnpacker.unpackMlsMessage(messageEvent)

verify(arrangement.checkRevocationList)
.suspendFunction(arrangement.checkRevocationList::invoke)
.with(eq(DECRYPTED_MESSAGE_BUNDLE))
.wasNotInvoked()

verify(arrangement.mlsConversationRepository)
.suspendFunction(arrangement.mlsConversationRepository::decryptMessage)
.with(matching { it.contentEquals(messageEvent.content.decodeBase64Bytes()) }, eq(TestConversation.GROUP_ID))
.wasInvoked(once)
}

@Test
fun givenNewMLSMessageEventWithCrlNewDistributionPoints_whenUnpacking_thenCheckRevocationList() = runTest {
val eventTimestamp = DateTimeUtil.currentInstant()
val decryptedMessageBundleWithDistributionPoints = DECRYPTED_MESSAGE_BUNDLE.copy(
crlNewDistributionPoints = listOf("https://crl.wire.com/crl.pem", "https://crl2.wire.com/crl.pem")
)
val (arrangement, mlsUnpacker) = Arrangement()
.withMLSClientProviderReturningClient()
.withGetConversationProtocolInfoSuccessful(TestConversation.MLS_CONVERSATION.protocol)
.withDecryptMessageReturning(Either.Right(listOf(decryptedMessageBundleWithDistributionPoints)))
.withCheckRevocationListReturning()
.arrange()

val messageEvent = TestEvent.newMLSMessageEvent(eventTimestamp)

mlsUnpacker.unpackMlsMessage(messageEvent)

verify(arrangement.mlsConversationRepository)
.suspendFunction(arrangement.mlsConversationRepository::decryptMessage)
.with(matching { it.contentEquals(messageEvent.content.decodeBase64Bytes()) }, eq(TestConversation.GROUP_ID))
.wasInvoked(once)

verify(arrangement.checkRevocationList)
.suspendFunction(arrangement.checkRevocationList::invoke)
.with(any())
.wasInvoked(twice)

verify(arrangement.certificateRevocationListRepository)
.suspendFunction(arrangement.certificateRevocationListRepository::addOrUpdateCRL)
.with(any(), any())
.wasInvoked(twice)
}

private class Arrangement {

@Mock
Expand All @@ -164,12 +205,20 @@ class MLSMessageUnpackerTest {
@Mock
val subconversationRepository = mock(classOf<SubconversationRepository>())

@Mock
val checkRevocationList = mock(classOf<CheckRevocationListUseCase>())

@Mock
val certificateRevocationListRepository = mock(classOf<CertificateRevocationListRepository>())

private val mlsMessageUnpacker = MLSMessageUnpackerImpl(
conversationRepository,
subconversationRepository,
mlsConversationRepository,
pendingProposalScheduler,
SELF_USER_ID
SELF_USER_ID,
checkRevocationList,
certificateRevocationListRepository
)

fun withMLSClientProviderReturningClient() = apply {
Expand All @@ -185,6 +234,12 @@ class MLSMessageUnpackerTest {
.whenInvokedWith(anything(), anything())
.thenReturn(result)
}
fun withCheckRevocationListReturning() = apply {
given(checkRevocationList)
.suspendFunction(checkRevocationList::invoke)
.whenInvokedWith(anything())
.thenReturn(Either.Right(ULong.MIN_VALUE))
}

fun withScheduleCommitSucceeding() = apply {
given(pendingProposalScheduler)
Expand All @@ -208,7 +263,8 @@ class MLSMessageUnpackerTest {
groupID = TestConversation.GROUP_ID,
applicationMessage = null,
commitDelay = null,
identity = null
identity = null,
crlNewDistributionPoints = null
)
}
}
Loading