From 52a0a1c3fe9fa6df23f3735050799b0de522c7ef Mon Sep 17 00:00:00 2001
From: Jonatan Rhodin <jonatan.rhodin@mullvad.net>
Date: Fri, 15 Dec 2023 10:52:35 +0100
Subject: [PATCH 1/4] Try to mitigate timing errors when getting the relay list

---
 .../mullvadvpn/ui/serviceconnection/RelayListListener.kt    | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/ui/serviceconnection/RelayListListener.kt b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/ui/serviceconnection/RelayListListener.kt
index d04e7394f9f3..c405aed28011 100644
--- a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/ui/serviceconnection/RelayListListener.kt
+++ b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/ui/serviceconnection/RelayListListener.kt
@@ -32,7 +32,11 @@ class RelayListListener(
             // not be a relay list since the fetching of a relay list would be done before the
             // event stream is available.
             .onStart { messageHandler.trySendRequest(Request.FetchRelayList) }
-            .stateIn(CoroutineScope(dispatcher), SharingStarted.Eagerly, defaultRelayList())
+            .stateIn(
+                CoroutineScope(dispatcher),
+                SharingStarted.WhileSubscribed(),
+                defaultRelayList()
+            )
 
     fun updateSelectedRelayLocation(value: GeographicLocationConstraint) {
         messageHandler.trySendRequest(Request.SetRelayLocation(value))

From bf2d6a6e15a2d405537948d37bdc94b510805f67 Mon Sep 17 00:00:00 2001
From: Jonatan Rhodin <jonatan.rhodin@mullvad.net>
Date: Fri, 15 Dec 2023 10:55:04 +0100
Subject: [PATCH 2/4] Fetch the relay list when entering the select location
 screen

---
 .../mullvad/mullvadvpn/compose/screen/SelectLocationScreen.kt | 2 +-
 .../mullvadvpn/ui/serviceconnection/RelayListListener.kt      | 4 ++++
 .../kotlin/net/mullvad/mullvadvpn/usecase/RelayListUseCase.kt | 4 ++++
 .../mullvad/mullvadvpn/viewmodel/SelectLocationViewModel.kt   | 4 ++++
 4 files changed, 13 insertions(+), 1 deletion(-)

diff --git a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/compose/screen/SelectLocationScreen.kt b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/compose/screen/SelectLocationScreen.kt
index 8f0624cf8a5c..b0157e4e1503 100644
--- a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/compose/screen/SelectLocationScreen.kt
+++ b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/compose/screen/SelectLocationScreen.kt
@@ -194,7 +194,7 @@ fun SelectLocationScreen(
                         }
                     }
                     is SelectLocationUiState.ShowData -> {
-                        if (uiState.countries.isEmpty()) {
+                        if (uiState.countries.isEmpty() && uiState.searchTerm.isNotEmpty()) {
                             item(contentType = ContentType.EMPTY_TEXT) {
                                 val firstRow =
                                     HtmlCompat.fromHtml(
diff --git a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/ui/serviceconnection/RelayListListener.kt b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/ui/serviceconnection/RelayListListener.kt
index c405aed28011..2d6ce332f693 100644
--- a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/ui/serviceconnection/RelayListListener.kt
+++ b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/ui/serviceconnection/RelayListListener.kt
@@ -53,5 +53,9 @@ class RelayListListener(
         messageHandler.trySendRequest(Request.SetOwnershipAndProviders(ownership, providers))
     }
 
+    fun fetchRelayList() {
+        messageHandler.trySendRequest(Request.FetchRelayList)
+    }
+
     private fun defaultRelayList() = RelayList(ArrayList(), WireguardEndpointData(ArrayList()))
 }
diff --git a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/usecase/RelayListUseCase.kt b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/usecase/RelayListUseCase.kt
index 0bfe1d038c81..8db90b939050 100644
--- a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/usecase/RelayListUseCase.kt
+++ b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/usecase/RelayListUseCase.kt
@@ -48,6 +48,10 @@ class RelayListUseCase(
 
     fun selectedRelayItem(): Flow<RelayItem?> = relayListWithSelection().map { it.selectedItem }
 
+    fun fetchRelayList() {
+        relayListListener.fetchRelayList()
+    }
+
     private fun List<RelayCountry>.findSelectedRelayItem(
         relaySettings: RelaySettings?,
     ): RelayItem? {
diff --git a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/viewmodel/SelectLocationViewModel.kt b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/viewmodel/SelectLocationViewModel.kt
index dc9d5e7d6fc2..70bd92fef77d 100644
--- a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/viewmodel/SelectLocationViewModel.kt
+++ b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/viewmodel/SelectLocationViewModel.kt
@@ -84,6 +84,10 @@ class SelectLocationViewModel(
     private val _uiSideEffect = Channel<SelectLocationSideEffect>(1, BufferOverflow.DROP_OLDEST)
     val uiSideEffect = _uiSideEffect.receiveAsFlow()
 
+    init {
+        viewModelScope.launch { relayListUseCase.fetchRelayList() }
+    }
+
     fun selectRelay(relayItem: RelayItem) {
         relayListUseCase.updateSelectedRelayLocation(relayItem.location)
         serviceConnectionManager.connectionProxy()?.connect()

From 93eb40ea81019f4389c894ca30076b8fd4aafb90 Mon Sep 17 00:00:00 2001
From: Jonatan Rhodin <jonatan.rhodin@mullvad.net>
Date: Mon, 18 Dec 2023 11:25:28 +0100
Subject: [PATCH 3/4] Improve handling of empty relay list

---
 .../compose/screen/SelectLocationScreen.kt    | 161 ++++++++++--------
 .../state/FilterConstrainExtensions.kt        |   4 +-
 .../compose/state/SelectLocationUiState.kt    |  14 +-
 .../mullvadvpn/viewmodel/FilterViewModel.kt   |   1 +
 .../viewmodel/SelectLocationViewModel.kt      |  70 +++-----
 5 files changed, 132 insertions(+), 118 deletions(-)

diff --git a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/compose/screen/SelectLocationScreen.kt b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/compose/screen/SelectLocationScreen.kt
index b0157e4e1503..0de13b5e6c3d 100644
--- a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/compose/screen/SelectLocationScreen.kt
+++ b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/compose/screen/SelectLocationScreen.kt
@@ -10,8 +10,8 @@ import androidx.compose.foundation.layout.fillMaxSize
 import androidx.compose.foundation.layout.fillMaxWidth
 import androidx.compose.foundation.layout.height
 import androidx.compose.foundation.layout.padding
-import androidx.compose.foundation.layout.size
 import androidx.compose.foundation.lazy.LazyColumn
+import androidx.compose.foundation.lazy.LazyListScope
 import androidx.compose.foundation.lazy.LazyListState
 import androidx.compose.foundation.lazy.rememberLazyListState
 import androidx.compose.material3.Icon
@@ -45,6 +45,7 @@ import net.mullvad.mullvadvpn.compose.component.textResource
 import net.mullvad.mullvadvpn.compose.constant.ContentType
 import net.mullvad.mullvadvpn.compose.destinations.FilterScreenDestination
 import net.mullvad.mullvadvpn.compose.extensions.toAnnotatedString
+import net.mullvad.mullvadvpn.compose.state.RelayListState
 import net.mullvad.mullvadvpn.compose.state.SelectLocationUiState
 import net.mullvad.mullvadvpn.compose.test.CIRCULAR_PROGRESS_INDICATOR
 import net.mullvad.mullvadvpn.compose.textfield.SearchTextField
@@ -62,12 +63,15 @@ import org.koin.androidx.compose.koinViewModel
 @Composable
 private fun PreviewSelectLocationScreen() {
     val state =
-        SelectLocationUiState.ShowData(
+        SelectLocationUiState.Data(
             searchTerm = "",
-            countries = listOf(RelayCountry("Country 1", "Code 1", false, emptyList())),
-            selectedRelay = null,
             selectedOwnership = null,
-            selectedProvidersCount = 0
+            selectedProvidersCount = 0,
+            relayListState =
+                RelayListState.RelayList(
+                    countries = listOf(RelayCountry("Country 1", "Code 1", false, emptyList())),
+                    selectedRelay = null,
+                )
         )
     AppTheme {
         SelectLocationScreen(
@@ -141,7 +145,7 @@ fun SelectLocationScreen(
 
             when (uiState) {
                 SelectLocationUiState.Loading -> {}
-                is SelectLocationUiState.ShowData -> {
+                is SelectLocationUiState.Data -> {
                     if (uiState.hasFilter) {
                         FilterCell(
                             ownershipFilter = uiState.selectedOwnership,
@@ -163,12 +167,16 @@ fun SelectLocationScreen(
             }
             Spacer(modifier = Modifier.height(height = Dimens.verticalSpace))
             val lazyListState = rememberLazyListState()
-            if (uiState is SelectLocationUiState.ShowData && uiState.selectedRelay != null) {
-                LaunchedEffect(uiState.selectedRelay) {
+            if (
+                uiState is SelectLocationUiState.Data &&
+                    uiState.relayListState is RelayListState.RelayList &&
+                    uiState.relayListState.selectedRelay != null
+            ) {
+                LaunchedEffect(uiState.relayListState.selectedRelay) {
                     val index =
-                        uiState.countries.indexOfFirst {
-                            it.location.location.country ==
-                                uiState.selectedRelay.location.location.country
+                        uiState.relayListState.countries.indexOfFirst { relayCountry ->
+                            relayCountry.location.location.country ==
+                                uiState.relayListState.selectedRelay.location.location.country
                         }
 
                     lazyListState.scrollToItem(index)
@@ -187,66 +195,14 @@ fun SelectLocationScreen(
             ) {
                 when (uiState) {
                     SelectLocationUiState.Loading -> {
-                        item(contentType = ContentType.PROGRESS) {
-                            MullvadCircularProgressIndicatorLarge(
-                                Modifier.testTag(CIRCULAR_PROGRESS_INDICATOR)
-                            )
-                        }
+                        loading()
                     }
-                    is SelectLocationUiState.ShowData -> {
-                        if (uiState.countries.isEmpty() && uiState.searchTerm.isNotEmpty()) {
-                            item(contentType = ContentType.EMPTY_TEXT) {
-                                val firstRow =
-                                    HtmlCompat.fromHtml(
-                                            textResource(
-                                                id = R.string.select_location_empty_text_first_row,
-                                                uiState.searchTerm
-                                            ),
-                                            HtmlCompat.FROM_HTML_MODE_COMPACT
-                                        )
-                                        .toAnnotatedString(boldFontWeight = FontWeight.ExtraBold)
-                                val secondRow =
-                                    textResource(
-                                        id = R.string.select_location_empty_text_second_row
-                                    )
-                                Column(
-                                    modifier =
-                                        Modifier.padding(
-                                            horizontal = Dimens.selectLocationTitlePadding
-                                        ),
-                                    horizontalAlignment = Alignment.CenterHorizontally
-                                ) {
-                                    Text(
-                                        text = firstRow,
-                                        style = MaterialTheme.typography.labelMedium,
-                                        textAlign = TextAlign.Center,
-                                        color = MaterialTheme.colorScheme.onSecondary,
-                                        maxLines = 2,
-                                        overflow = TextOverflow.Ellipsis
-                                    )
-                                    Text(
-                                        text = secondRow,
-                                        style = MaterialTheme.typography.labelMedium,
-                                        textAlign = TextAlign.Center,
-                                        color = MaterialTheme.colorScheme.onSecondary
-                                    )
-                                }
-                            }
-                        } else {
-                            items(
-                                count = uiState.countries.size,
-                                key = { index -> uiState.countries[index].hashCode() },
-                                contentType = { ContentType.ITEM }
-                            ) { index ->
-                                val country = uiState.countries[index]
-                                RelayLocationCell(
-                                    relay = country,
-                                    selectedItem = uiState.selectedRelay,
-                                    onSelectRelay = onSelectRelay,
-                                    modifier = Modifier.animateContentSize()
-                                )
-                            }
-                        }
+                    is SelectLocationUiState.Data -> {
+                        relayList(
+                            relayListState = uiState.relayListState,
+                            searchTerm = uiState.searchTerm,
+                            onSelectRelay = onSelectRelay
+                        )
                     }
                 }
             }
@@ -254,6 +210,71 @@ fun SelectLocationScreen(
     }
 }
 
+private fun LazyListScope.loading() {
+    item(contentType = ContentType.PROGRESS) {
+        MullvadCircularProgressIndicatorLarge(Modifier.testTag(CIRCULAR_PROGRESS_INDICATOR))
+    }
+}
+
+private fun LazyListScope.relayList(
+    relayListState: RelayListState,
+    searchTerm: String,
+    onSelectRelay: (item: RelayItem) -> Unit
+) {
+    when (relayListState) {
+        is RelayListState.RelayList -> {
+            items(
+                count = relayListState.countries.size,
+                key = { index -> relayListState.countries[index].hashCode() },
+                contentType = { ContentType.ITEM }
+            ) { index ->
+                val country = relayListState.countries[index]
+                RelayLocationCell(
+                    relay = country,
+                    selectedItem = relayListState.selectedRelay,
+                    onSelectRelay = onSelectRelay,
+                    modifier = Modifier.animateContentSize()
+                )
+            }
+        }
+        RelayListState.Empty -> {
+            if (searchTerm.isNotEmpty())
+                item(contentType = ContentType.EMPTY_TEXT) {
+                    val firstRow =
+                        HtmlCompat.fromHtml(
+                                textResource(
+                                    id = R.string.select_location_empty_text_first_row,
+                                    searchTerm
+                                ),
+                                HtmlCompat.FROM_HTML_MODE_COMPACT
+                            )
+                            .toAnnotatedString(boldFontWeight = FontWeight.ExtraBold)
+                    val secondRow =
+                        textResource(id = R.string.select_location_empty_text_second_row)
+                    Column(
+                        modifier = Modifier.padding(horizontal = Dimens.selectLocationTitlePadding),
+                        horizontalAlignment = Alignment.CenterHorizontally
+                    ) {
+                        Text(
+                            text = firstRow,
+                            style = MaterialTheme.typography.labelMedium,
+                            textAlign = TextAlign.Center,
+                            color = MaterialTheme.colorScheme.onSecondary,
+                            maxLines = 2,
+                            overflow = TextOverflow.Ellipsis
+                        )
+                        Text(
+                            text = secondRow,
+                            style = MaterialTheme.typography.labelMedium,
+                            textAlign = TextAlign.Center,
+                            color = MaterialTheme.colorScheme.onSecondary
+                        )
+                    }
+                }
+        }
+    }
+}
+
 suspend fun LazyListState.animateScrollAndCentralizeItem(index: Int) {
     val itemInfo = this.layoutInfo.visibleItemsInfo.firstOrNull { it.index == index }
     if (itemInfo != null) {
diff --git a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/compose/state/FilterConstrainExtensions.kt b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/compose/state/FilterConstrainExtensions.kt
index 8a65c64b0196..da46938e8427 100644
--- a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/compose/state/FilterConstrainExtensions.kt
+++ b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/compose/state/FilterConstrainExtensions.kt
@@ -17,9 +17,9 @@ fun Ownership?.toOwnershipConstraint(): Constraint<Ownership> =
         else -> Constraint.Only(this)
     }
 
-fun Constraint<Providers>.toSelectedProviders(allProviders: List<Provider>): List<Provider> =
+fun Constraint<Providers>.toSelectedProviders(allProviders: List<Provider>): List<Provider>? =
     when (this) {
-        is Constraint.Any -> allProviders
+        is Constraint.Any -> null
         is Constraint.Only ->
             this.value.providers.toList().mapNotNull { providerName ->
                 allProviders.firstOrNull { it.name == providerName }
diff --git a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/compose/state/SelectLocationUiState.kt b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/compose/state/SelectLocationUiState.kt
index 123bf821e605..3152ed1a348d 100644
--- a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/compose/state/SelectLocationUiState.kt
+++ b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/compose/state/SelectLocationUiState.kt
@@ -8,13 +8,19 @@ sealed interface SelectLocationUiState {
 
     data object Loading : SelectLocationUiState
 
-    data class ShowData(
+    data class Data(
         val searchTerm: String,
-        val countries: List<RelayCountry>,
-        val selectedRelay: RelayItem?,
         val selectedOwnership: Ownership?,
-        val selectedProvidersCount: Int?
+        val selectedProvidersCount: Int?,
+        val relayListState: RelayListState
     ) : SelectLocationUiState {
         val hasFilter: Boolean = (selectedProvidersCount != null || selectedOwnership != null)
     }
 }
+
+sealed interface RelayListState {
+    data object Empty : RelayListState
+
+    data class RelayList(val countries: List<RelayCountry>, val selectedRelay: RelayItem?) :
+        RelayListState
+}
diff --git a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/viewmodel/FilterViewModel.kt b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/viewmodel/FilterViewModel.kt
index 3f95d7919399..bd0703e6ae6f 100644
--- a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/viewmodel/FilterViewModel.kt
+++ b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/viewmodel/FilterViewModel.kt
@@ -40,6 +40,7 @@ class FilterViewModel(
                         selectedConstraintProviders.toSelectedProviders(allProviders)
                     }
                     .first()
+                    ?: emptyList()
 
             val ownershipConstraint = relayListFilterUseCase.selectedOwnership().first()
             selectedOwnership.value = ownershipConstraint.toNullableOwnership()
diff --git a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/viewmodel/SelectLocationViewModel.kt b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/viewmodel/SelectLocationViewModel.kt
index 70bd92fef77d..d3c5977e27bd 100644
--- a/android/app/src/main/kotlin/net/mullvad/mullvadvpn/viewmodel/SelectLocationViewModel.kt
+++ b/android/app/src/main/kotlin/net/mullvad/mullvadvpn/viewmodel/SelectLocationViewModel.kt
@@ -11,6 +11,7 @@ import kotlinx.coroutines.flow.first
 import kotlinx.coroutines.flow.receiveAsFlow
 import kotlinx.coroutines.flow.stateIn
 import kotlinx.coroutines.launch
+import net.mullvad.mullvadvpn.compose.state.RelayListState
 import net.mullvad.mullvadvpn.compose.state.SelectLocationUiState
 import net.mullvad.mullvadvpn.compose.state.toNullableOwnership
 import net.mullvad.mullvadvpn.compose.state.toSelectedProviders
@@ -44,35 +45,30 @@ class SelectLocationViewModel(
                 selectedOwnership,
                 allProviders,
                 selectedConstraintProviders ->
-                val selectedProviders =
-                    selectedConstraintProviders.toSelectedProviders(allProviders)
-
-                val selectedProvidersByOwnershipList =
+                val selectedOwnershipItem = selectedOwnership.toNullableOwnership()
+                val selectedProvidersCount =
                     filterSelectedProvidersByOwnership(
-                        selectedProviders,
-                        selectedOwnership.toNullableOwnership()
-                    )
-
-                val allProvidersByOwnershipListList =
-                    filterAllProvidersByOwnership(
-                        allProviders,
-                        selectedOwnership.toNullableOwnership()
-                    )
+                            selectedConstraintProviders.toSelectedProviders(allProviders),
+                            selectedOwnershipItem
+                        )
+                        ?.size
 
                 val filteredRelayCountries =
                     relayCountries.filterOnSearchTerm(searchTerm, relayItem)
-                SelectLocationUiState.ShowData(
+
+                SelectLocationUiState.Data(
                     searchTerm = searchTerm,
-                    countries = filteredRelayCountries,
-                    selectedRelay = relayItem,
-                    selectedOwnership = selectedOwnership.toNullableOwnership(),
-                    selectedProvidersCount =
-                        if (
-                            selectedProvidersByOwnershipList.size ==
-                                allProvidersByOwnershipListList.size
-                        )
-                            null
-                        else selectedProvidersByOwnershipList.size
+                    selectedOwnership = selectedOwnershipItem,
+                    selectedProvidersCount = selectedProvidersCount,
+                    relayListState =
+                        if (filteredRelayCountries.isNotEmpty()) {
+                            RelayListState.RelayList(
+                                countries = filteredRelayCountries,
+                                selectedRelay = relayItem
+                            )
+                        } else {
+                            RelayListState.Empty
+                        },
                 )
             }
             .stateIn(
@@ -99,26 +95,16 @@ class SelectLocationViewModel(
     }
 
     private fun filterSelectedProvidersByOwnership(
-        selectedProviders: List<Provider>,
+        selectedProviders: List<Provider>?,
         selectedOwnership: Ownership?
-    ): List<Provider> {
-        return when (selectedOwnership) {
-            Ownership.MullvadOwned -> selectedProviders.filter { it.mullvadOwned }
-            Ownership.Rented -> selectedProviders.filterNot { it.mullvadOwned }
-            else -> selectedProviders
-        }
-    }
-
-    private fun filterAllProvidersByOwnership(
-        allProviders: List<Provider>,
-        selectedOwnership: Ownership?
-    ): List<Provider> {
-        return when (selectedOwnership) {
-            Ownership.MullvadOwned -> allProviders.filter { it.mullvadOwned }
-            Ownership.Rented -> allProviders.filterNot { it.mullvadOwned }
-            else -> allProviders
+    ): List<Provider>? =
+        selectedProviders?.let {
+            when (selectedOwnership) {
+                Ownership.MullvadOwned -> selectedProviders.filter { it.mullvadOwned }
+                Ownership.Rented -> selectedProviders.filterNot { it.mullvadOwned }
+                else -> selectedProviders
+            }
         }
-    }
 
     fun removeOwnerFilter() {
         viewModelScope.launch {

From 7efda2d553370be150a0ccc789f76ca3c68bf554 Mon Sep 17 00:00:00 2001
From: Jonatan Rhodin <jonatan.rhodin@mullvad.net>
Date: Fri, 15 Dec 2023 10:55:52 +0100
Subject: [PATCH 4/4] Fix tests

---
 .../screen/SelectLocationScreenTest.kt        | 33 ++++++++-----
 .../viewmodel/SelectLocationViewModelTest.kt  | 49 ++++++++++++++-----
 2 files changed, 58 insertions(+), 24 deletions(-)

diff --git a/android/app/src/androidTest/kotlin/net/mullvad/mullvadvpn/compose/screen/SelectLocationScreenTest.kt b/android/app/src/androidTest/kotlin/net/mullvad/mullvadvpn/compose/screen/SelectLocationScreenTest.kt
index fbc8b046fd77..bb64adc8a68f 100644
--- a/android/app/src/androidTest/kotlin/net/mullvad/mullvadvpn/compose/screen/SelectLocationScreenTest.kt
+++ b/android/app/src/androidTest/kotlin/net/mullvad/mullvadvpn/compose/screen/SelectLocationScreenTest.kt
@@ -9,6 +9,7 @@ import io.mockk.MockKAnnotations
 import io.mockk.mockk
 import io.mockk.verify
 import net.mullvad.mullvadvpn.compose.setContentWithTheme
+import net.mullvad.mullvadvpn.compose.state.RelayListState
 import net.mullvad.mullvadvpn.compose.state.SelectLocationUiState
 import net.mullvad.mullvadvpn.compose.test.CIRCULAR_PROGRESS_INDICATOR
 import net.mullvad.mullvadvpn.model.Constraint
@@ -54,9 +55,12 @@ class SelectLocationScreenTest {
             setContentWithTheme {
                 SelectLocationScreen(
                     uiState =
-                        SelectLocationUiState.ShowData(
-                            countries = DUMMY_RELAY_COUNTRIES,
-                            selectedRelay = null,
+                        SelectLocationUiState.Data(
+                            relayListState =
+                                RelayListState.RelayList(
+                                    countries = DUMMY_RELAY_COUNTRIES,
+                                    selectedRelay = null
+                                ),
                             selectedOwnership = null,
                             selectedProvidersCount = 0,
                             searchTerm = ""
@@ -91,9 +95,12 @@ class SelectLocationScreenTest {
             setContentWithTheme {
                 SelectLocationScreen(
                     uiState =
-                        SelectLocationUiState.ShowData(
-                            countries = updatedDummyList,
-                            selectedRelay = updatedDummyList[0].cities[0].relays[0],
+                        SelectLocationUiState.Data(
+                            relayListState =
+                                RelayListState.RelayList(
+                                    countries = updatedDummyList,
+                                    selectedRelay = updatedDummyList[0].cities[0].relays[0]
+                                ),
                             selectedOwnership = null,
                             selectedProvidersCount = 0,
                             searchTerm = ""
@@ -118,9 +125,12 @@ class SelectLocationScreenTest {
             setContentWithTheme {
                 SelectLocationScreen(
                     uiState =
-                        SelectLocationUiState.ShowData(
-                            countries = emptyList(),
-                            selectedRelay = null,
+                        SelectLocationUiState.Data(
+                            relayListState =
+                                RelayListState.RelayList(
+                                    countries = emptyList(),
+                                    selectedRelay = null
+                                ),
                             selectedOwnership = null,
                             selectedProvidersCount = 0,
                             searchTerm = ""
@@ -146,9 +156,8 @@ class SelectLocationScreenTest {
             setContentWithTheme {
                 SelectLocationScreen(
                     uiState =
-                        SelectLocationUiState.ShowData(
-                            countries = emptyList(),
-                            selectedRelay = null,
+                        SelectLocationUiState.Data(
+                            relayListState = RelayListState.Empty,
                             selectedOwnership = null,
                             selectedProvidersCount = 0,
                             searchTerm = mockSearchString
diff --git a/android/app/src/test/kotlin/net/mullvad/mullvadvpn/viewmodel/SelectLocationViewModelTest.kt b/android/app/src/test/kotlin/net/mullvad/mullvadvpn/viewmodel/SelectLocationViewModelTest.kt
index 46ea0bf3fb68..fc6408d8ab7e 100644
--- a/android/app/src/test/kotlin/net/mullvad/mullvadvpn/viewmodel/SelectLocationViewModelTest.kt
+++ b/android/app/src/test/kotlin/net/mullvad/mullvadvpn/viewmodel/SelectLocationViewModelTest.kt
@@ -3,8 +3,10 @@ package net.mullvad.mullvadvpn.viewmodel
 import androidx.lifecycle.viewModelScope
 import app.cash.turbine.test
 import io.mockk.every
+import io.mockk.just
 import io.mockk.mockk
 import io.mockk.mockkStatic
+import io.mockk.runs
 import io.mockk.unmockkAll
 import io.mockk.verify
 import kotlin.test.assertEquals
@@ -12,6 +14,7 @@ import kotlin.test.assertIs
 import kotlinx.coroutines.cancel
 import kotlinx.coroutines.flow.MutableStateFlow
 import kotlinx.coroutines.test.runTest
+import net.mullvad.mullvadvpn.compose.state.RelayListState
 import net.mullvad.mullvadvpn.compose.state.SelectLocationUiState
 import net.mullvad.mullvadvpn.lib.common.test.TestCoroutineRule
 import net.mullvad.mullvadvpn.lib.common.test.assertLists
@@ -53,6 +56,7 @@ class SelectLocationViewModelTest {
         every { mockRelayListFilterUseCase.selectedProviders() } returns selectedProvider
         every { mockRelayListFilterUseCase.availableProviders() } returns allProvider
         every { mockRelayListUseCase.relayListWithSelection() } returns relayListWithSelectionFlow
+        every { mockRelayListUseCase.fetchRelayList() } just runs
 
         mockkStatic(SERVICE_CONNECTION_MANAGER_EXTENSIONS)
         mockkStatic(RELAY_LIST_EXTENSIONS)
@@ -86,9 +90,16 @@ class SelectLocationViewModelTest {
         // Act, Assert
         viewModel.uiState.test {
             val actualState = awaitItem()
-            assertIs<SelectLocationUiState.ShowData>(actualState)
-            assertLists(mockCountries, actualState.countries)
-            assertEquals(selectedRelay, actualState.selectedRelay)
+            assertIs<SelectLocationUiState.Data>(actualState)
+            assertIs<RelayListState.RelayList>(actualState.relayListState)
+            assertLists(
+                mockCountries,
+                (actualState.relayListState as RelayListState.RelayList).countries
+            )
+            assertEquals(
+                selectedRelay,
+                (actualState.relayListState as RelayListState.RelayList).selectedRelay
+            )
         }
     }
 
@@ -103,9 +114,16 @@ class SelectLocationViewModelTest {
         // Act, Assert
         viewModel.uiState.test {
             val actualState = awaitItem()
-            assertIs<SelectLocationUiState.ShowData>(actualState)
-            assertLists(mockCountries, actualState.countries)
-            assertEquals(selectedRelay, actualState.selectedRelay)
+            assertIs<SelectLocationUiState.Data>(actualState)
+            assertIs<RelayListState.RelayList>(actualState.relayListState)
+            assertLists(
+                mockCountries,
+                (actualState.relayListState as RelayListState.RelayList).countries
+            )
+            assertEquals(
+                selectedRelay,
+                (actualState.relayListState as RelayListState.RelayList).selectedRelay
+            )
         }
     }
 
@@ -145,16 +163,23 @@ class SelectLocationViewModelTest {
         // Act, Assert
         viewModel.uiState.test {
             // Wait for first data
-            assertIs<SelectLocationUiState.ShowData>(awaitItem())
+            assertIs<SelectLocationUiState.Data>(awaitItem())
 
             // Update search string
             viewModel.onSearchTermInput(mockSearchString)
 
             // Assert
             val actualState = awaitItem()
-            assertIs<SelectLocationUiState.ShowData>(actualState)
-            assertLists(mockCountries, actualState.countries)
-            assertEquals(selectedRelay, actualState.selectedRelay)
+            assertIs<SelectLocationUiState.Data>(actualState)
+            assertIs<RelayListState.RelayList>(actualState.relayListState)
+            assertLists(
+                mockCountries,
+                (actualState.relayListState as RelayListState.RelayList).countries
+            )
+            assertEquals(
+                selectedRelay,
+                (actualState.relayListState as RelayListState.RelayList).selectedRelay
+            )
         }
     }
 
@@ -172,14 +197,14 @@ class SelectLocationViewModelTest {
         // Act, Assert
         viewModel.uiState.test {
             // Wait for first data
-            assertIs<SelectLocationUiState.ShowData>(awaitItem())
+            assertIs<SelectLocationUiState.Data>(awaitItem())
 
             // Update search string
             viewModel.onSearchTermInput(mockSearchString)
 
             // Assert
             val actualState = awaitItem()
-            assertIs<SelectLocationUiState.ShowData>(actualState)
+            assertIs<SelectLocationUiState.Data>(actualState)
             assertEquals(mockSearchString, actualState.searchTerm)
         }
     }