Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AIK-4641 Add support for IP ranges in bypass list and in allowlist #127

Open
wants to merge 10 commits into
base: AIK-4490
Choose a base branch
from
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
package dev.aikido.agent_api.background;

import dev.aikido.agent_api.helpers.net.IPList;

import java.util.List;

import static dev.aikido.agent_api.helpers.IPListBuilder.createIPList;

public class Endpoint {
public record RateLimitingConfig(long maxRequests, long windowSizeInMS, boolean enabled) {}
private final String method;
Expand All @@ -16,8 +20,8 @@ public Endpoint(
boolean forceProtectionOff, boolean rateLimitingEnabled) {
this.method = method;
this.route = route;
this.rateLimiting = new RateLimitingConfig(maxRequests, windowSizeMS, rateLimitingEnabled);
this.allowedIPAddresses = allowedIPAddresses;
this.rateLimiting = new RateLimitingConfig(maxRequests, windowSizeMS, rateLimitingEnabled);
this.graphql = graphql;
this.forceProtectionOff = forceProtectionOff;
}
Expand All @@ -32,13 +36,18 @@ public String getRoute() {
public RateLimitingConfig getRateLimiting() {
return rateLimiting;
}
public List<String> getAllowedIPAddresses() {
return allowedIPAddresses;
}
public boolean isGraphql() {
return graphql;
}
public boolean protectionForcedOff() {
return forceProtectionOff;
}

// allowed ip addresses :
public boolean allowedIpAddressesEmpty() {
return allowedIPAddresses == null || allowedIPAddresses.size() == 0;
}
public boolean isIpAllowed(String ip) {
return createIPList(allowedIPAddresses).matches(ip);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import dev.aikido.agent_api.background.cloud.api.events.APIEvent;
import dev.aikido.agent_api.helpers.logging.LogManager;
import dev.aikido.agent_api.helpers.logging.Logger;
import dev.aikido.agent_api.helpers.net.IPList;
import dev.aikido.agent_api.storage.routes.RouteEntry;

import java.io.InputStream;
Expand Down Expand Up @@ -100,8 +101,12 @@ public APIResponse toApiResponse(HttpResponse<String> res) {
} else if (status == 401) {
return getUnsuccessfulAPIResponse("invalid_token");
} else if (status == 200) {
Gson gson = new Gson();
return gson.fromJson(res.body(), APIResponse.class);
try {
return new Gson().fromJson(res.body(), APIResponse.class);
} catch (Throwable e) {
logger.debug("json error: %s", e);
return getUnsuccessfulAPIResponse("json_deserialize");
}
}
return getUnsuccessfulAPIResponse("unknown_error");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,18 @@ public static boolean ipAllowedToAccessRoute(String remoteAddress, List<Endpoint
}

for (Endpoint endpoint : endpoints) {
if (endpoint.getAllowedIPAddresses() == null) {
if (endpoint.allowedIpAddressesEmpty()) {
// Feature might not be enabled
continue;
}
if (endpoint.getAllowedIPAddresses().isEmpty()) {
// We will continue to check all the other matches
continue;
}

if (remoteAddress == null) {
// We only check it here because if allowedIPAddresses isn't set
// We don't want to change any default behaviour
return false;
}

if (!endpoint.getAllowedIPAddresses().contains(remoteAddress)) {
if (!endpoint.isIpAllowed(remoteAddress)) {
// The IP is not in the allowlist, so block
return false;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package dev.aikido.agent_api.helpers;

import dev.aikido.agent_api.helpers.net.IPList;

import java.util.Collection;
import java.util.List;
import java.util.Set;

public final class IPListBuilder {
private IPListBuilder() {}
public static IPList createIPList(Collection<String> ips) {
IPList ipList = new IPList();
if (ips == null) {
return ipList; // Don't iterate over null.
}
for (String ip: ips) {
// Add ip address or subnet to IP list :
ipList.add(ip);
}

return ipList;
}
}
Original file line number Diff line number Diff line change
@@ -1,46 +1,48 @@
package dev.aikido.agent_api.helpers.net;

import inet.ipaddr.AddressStringException;
import inet.ipaddr.IPAddress;
import inet.ipaddr.IPAddressString;
import inet.ipaddr.IPAddressNetwork;

import java.util.HashSet;
import java.util.List;
import java.util.ArrayList;
import java.util.Set;

public class BlockList {
//private HashSet<String> blockedAddresses; // Both IPv4 and IPv6
private List<IPAddress> blockedList;
public class IPList {
private Set<IPAddress> ipAddresses;

public BlockList() {
//this.blockedAddresses = new HashSet<>();
this.blockedList = new ArrayList<>();
public IPList() {
this.ipAddresses = new HashSet<>();
}

public void add(String ipOrCIDR) {
if (ipOrCIDR == null) {
return; // Don't add if IP is null
}
IPAddress ip = new IPAddressString(ipOrCIDR).getAddress();
if (ipOrCIDR.contains("/")) {
// CIDR :
ip = ip.toPrefixBlock();
}
if (ip != null) {
blockedList.add(ip);
ipAddresses.add(ip);
}
}

public boolean isBlocked(String ip) {
public boolean matches(String ip) {
IPAddressString ipAddressString = new IPAddressString(ip);
if (!ipAddressString.isValid()) {
return false; // Invalid IP address
}
IPAddress ipAddress = ipAddressString.getAddress();

// Check if the IP address is in any of the blocked subnets
for (IPAddress subnet : blockedList) {
for (IPAddress subnet : ipAddresses) {
if (subnet.contains(ipAddress)) {
return true;
}
}
return false;
}
public int length() {
return ipAddresses.size();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,29 @@

import dev.aikido.agent_api.background.Endpoint;
import dev.aikido.agent_api.background.cloud.api.ReportingApi;
import dev.aikido.agent_api.helpers.net.BlockList;
import dev.aikido.agent_api.helpers.net.IPList;
import dev.aikido.agent_api.storage.Hostnames;
import dev.aikido.agent_api.storage.routes.Routes;

import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import static dev.aikido.agent_api.helpers.IPListBuilder.createIPList;
import static dev.aikido.agent_api.helpers.UnixTimeMS.getUnixTimeMS;

public class ThreadCacheObject {
private final List<Endpoint> endpoints;
private final Set<String> blockedUserIds;
private final Set<String> bypassedIPs;
private final IPList bypassedIPs;
private final long lastRenewedAtMS;
private final Hostnames hostnames;
private final Routes routes;

// IP Blocking (e.g. Geo-IP Restrictions) :
public record BlockedIpEntry(BlockList blocklist, String description) {}
public record BlockedIpEntry(IPList blocklist, String description) {}
private List<BlockedIpEntry> blockedIps = new ArrayList<>();
// User-Agent Blocking (e.g. bot blocking) :
private Pattern blockedUserAgentRegex;
Expand All @@ -36,7 +36,7 @@ public ThreadCacheObject(List<Endpoint> endpoints, Set<String> blockedUserIDs, S
// Set endpoints :
this.endpoints = endpoints;
this.blockedUserIds = blockedUserIDs;
this.bypassedIPs = bypassedIPs;
this.bypassedIPs = createIPList(bypassedIPs);
this.routes = routes;
this.hostnames = new Hostnames(5000);
this.updateBlockedLists(blockedListsRes);
Expand Down Expand Up @@ -65,15 +65,15 @@ public Routes getRoutes() {
return routes;
}
public boolean isBypassedIP(String ip) {
return bypassedIPs.contains(ip);
return bypassedIPs.matches(ip);
}

/**
* Check if the IP is blocked (e.g. Geo IP Restrictions)
*/
public BlockedResult isIpBlocked(String ip) {
for (BlockedIpEntry entry: blockedIps) {
if (entry.blocklist.isBlocked(ip)) {
if (entry.blocklist.matches(ip)) {
return new BlockedResult(true, entry.description);
}
}
Expand All @@ -86,11 +86,8 @@ public void updateBlockedLists(Optional<ReportingApi.APIListsResponse> blockedLi
// Update blocked IP addresses (e.g. for geo restrictions) :
if (res.blockedIPAddresses() != null) {
for (ReportingApi.ListsResponseEntry entry : res.blockedIPAddresses()) {
BlockList blockList = new BlockList();
for (String ip : entry.ips()) {
blockList.add(ip);
}
blockedIps.add(new BlockedIpEntry(blockList, entry.description()));
IPList ipList = createIPList(entry.ips());
blockedIps.add(new BlockedIpEntry(ipList, entry.description()));
}
}
// Update Blocked User-Agents regex
Expand Down
113 changes: 0 additions & 113 deletions agent_api/src/test/java/helpers/BlockListTest.java

This file was deleted.

23 changes: 23 additions & 0 deletions agent_api/src/test/java/helpers/IPAccessControllerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,29 @@ public void testChecksEveryMatchingEndpoint() {
);
assertFalse(IPAccessController.ipAllowedToAccessRoute("3.4.5.6", endpoints));
}
@Test
public void testWithSubnet() {
List<Endpoint> endpoints = List.of(
genEndpoint(List.of("10.0.0.0/8"))
);
assertTrue(IPAccessController.ipAllowedToAccessRoute("10.0.0.1", endpoints));
assertTrue(IPAccessController.ipAllowedToAccessRoute("10.0.0.20", endpoints));
assertTrue(IPAccessController.ipAllowedToAccessRoute("10.0.1.50", endpoints));
assertFalse(IPAccessController.ipAllowedToAccessRoute("1.1.1.1", endpoints));
}

@Test
public void testSubnetMultiple() {
List<Endpoint> endpoints = List.of(
genEndpoint(List.of("10.0.0.0/8")),
genEndpoint(List.of("10.0.0.0/24"))
);
assertTrue(IPAccessController.ipAllowedToAccessRoute("10.0.0.1", endpoints));
assertTrue(IPAccessController.ipAllowedToAccessRoute("10.0.0.20", endpoints));
assertTrue(IPAccessController.ipAllowedToAccessRoute("10.0.0.255", endpoints));
assertFalse(IPAccessController.ipAllowedToAccessRoute("10.0.1.1", endpoints));
assertFalse(IPAccessController.ipAllowedToAccessRoute("10.0.1.50", endpoints));
}

@Test
public void testIfAllowedIpsIsEmptyOrBroken() {
Expand Down
Loading