diff --git a/ios/MullvadREST/Relay/RelaySelector.swift b/ios/MullvadREST/Relay/RelaySelector.swift index b156c83fa402..f45605410859 100644 --- a/ios/MullvadREST/Relay/RelaySelector.swift +++ b/ios/MullvadREST/Relay/RelaySelector.swift @@ -136,7 +136,7 @@ public enum RelaySelector { } /// Produce a list of `RelayWithLocation` items satisfying the given constraints - private static func applyConstraints<T: AnyRelay>( + static func applyConstraints<T: AnyRelay>( _ constraints: RelayConstraints, relays: [RelayWithLocation<T>] ) -> [RelayWithLocation<T>] { @@ -154,24 +154,10 @@ public enum RelaySelector { case .any: return true case let .only(relayConstraint): - for location in relayConstraint.locations { - switch location { - case let .country(countryCode): - return relayWithLocation.serverLocation.countryCode == countryCode && - relayWithLocation.relay.includeInCountry - - case let .city(countryCode, cityCode): - return relayWithLocation.serverLocation.countryCode == countryCode && - relayWithLocation.serverLocation.cityCode == cityCode - - case let .hostname(countryCode, cityCode, hostname): - return relayWithLocation.serverLocation.countryCode == countryCode && - relayWithLocation.serverLocation.cityCode == cityCode && - relayWithLocation.relay.hostname == hostname - } + // At least one location must match the relay under test. + return relayConstraint.locations.contains { location in + relayWithLocation.matches(location: location) } - - return false } }.filter { relayWithLocation -> Bool in relayWithLocation.relay.active @@ -310,9 +296,26 @@ public struct RelaySelectorResult: Codable, Equatable { public var location: Location } -private struct RelayWithLocation<T: AnyRelay> { +struct RelayWithLocation<T: AnyRelay> { let relay: T let serverLocation: Location + + func matches(location: RelayLocation) -> Bool { + switch location { + case let .country(countryCode): + serverLocation.countryCode == countryCode && + relay.includeInCountry + + case let .city(countryCode, cityCode): + serverLocation.countryCode == countryCode && + serverLocation.cityCode == cityCode + + case let .hostname(countryCode, cityCode, hostname): + serverLocation.countryCode == countryCode && + serverLocation.cityCode == cityCode && + relay.hostname == hostname + } + } } private struct RelayWithDistance<T: AnyRelay> { diff --git a/ios/MullvadVPNTests/RelaySelectorTests.swift b/ios/MullvadVPNTests/RelaySelectorTests.swift index 03ff9983d90f..34e1c2d1d49d 100644 --- a/ios/MullvadVPNTests/RelaySelectorTests.swift +++ b/ios/MullvadVPNTests/RelaySelectorTests.swift @@ -59,6 +59,45 @@ class RelaySelectorTests: XCTestCase { XCTAssertEqual(result.relay.hostname, "se6-wireguard") } + func testMultipleLocationsConstraint() throws { + let constraints = RelayConstraints( + locations: .only(RelayLocations(locations: [ + .city("se", "got"), + .hostname("se", "sto", "se6-wireguard"), + ])) + ) + + let relayWithLocations = sampleRelays.wireguard.relays.map { + let location = sampleRelays.locations[$0.location]! + let locationComponents = $0.location.split(separator: "-") + + return RelayWithLocation( + relay: $0, + serverLocation: Location( + country: location.country, + countryCode: String(locationComponents[0]), + city: location.city, + cityCode: String(locationComponents[1]), + latitude: location.latitude, + longitude: location.longitude + ) + ) + } + + let constrainedLocations = RelaySelector.applyConstraints(constraints, relays: relayWithLocations) + + XCTAssertTrue( + constrainedLocations.contains( + where: { $0.matches(location: .city("se", "got")) } + ) + ) + XCTAssertTrue( + constrainedLocations.contains( + where: { $0.matches(location: .hostname("se", "sto", "se6-wireguard")) } + ) + ) + } + func testSpecificPortConstraint() throws { let constraints = RelayConstraints( locations: .only(RelayLocations(locations: [.hostname("se", "sto", "se6-wireguard")])), diff --git a/ios/MullvadVPNTests/ServerRelaysResponse+Stubs.swift b/ios/MullvadVPNTests/ServerRelaysResponse+Stubs.swift index 14a614184c80..13b0d443e37b 100644 --- a/ios/MullvadVPNTests/ServerRelaysResponse+Stubs.swift +++ b/ios/MullvadVPNTests/ServerRelaysResponse+Stubs.swift @@ -63,6 +63,12 @@ enum ServerRelaysResponseStubs { latitude: 32.89748, longitude: -97.040443 ), + "us-nyc": REST.ServerLocation( + country: "USA", + city: "New York, NY", + latitude: 40.6963302, + longitude: -74.6034843 + ), ], wireguard: REST.ServerWireguardTunnels( ipv4Gateway: .loopback,