Skip to content

Commit

Permalink
fix: split OSV batch queries for 1000 purls (#65)
Browse files Browse the repository at this point in the history
Signed-off-by: Ruben Romero Montes <[email protected]>
  • Loading branch information
ruromero authored Mar 12, 2024
1 parent 321fddf commit a70ef0f
Show file tree
Hide file tree
Showing 7 changed files with 359 additions and 117 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Copyright 2024 Red Hat, Inc. and/or its affiliates
* and other contributors as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.redhat.ecosystemappeng.onguard.model.osv;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Stream;

public record Partition(List<String> items, int size) {

public Stream<List<String>> stream() {
List<List<String>> partitions = new ArrayList<>();
if(items == null) {
return Stream.empty();
}
int pos = 0;
while(pos < items.size()) {
var to = pos + size;
if(to > items.size()) {
to = items.size();
}
partitions.add(items.subList(pos, to));
pos = to;
}
return partitions.stream();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ExecutorService;
import java.util.stream.Collectors;

import org.eclipse.microprofile.rest.client.inject.RestClient;
import org.jboss.resteasy.reactive.ClientWebApplicationException;
Expand All @@ -34,6 +35,7 @@
import com.redhat.ecosystemappeng.onguard.model.VulnerabilityAlias;
import com.redhat.ecosystemappeng.onguard.model.osv.OsvVulnerability;
import com.redhat.ecosystemappeng.onguard.model.osv.PackageRef;
import com.redhat.ecosystemappeng.onguard.model.osv.Partition;
import com.redhat.ecosystemappeng.onguard.model.osv.QueryRequest;
import com.redhat.ecosystemappeng.onguard.model.osv.QueryRequestItem;
import com.redhat.ecosystemappeng.onguard.model.osv.VulnerabilityRef;
Expand All @@ -49,6 +51,7 @@
public class VulnerabilityServiceImpl implements VulnerabilityService {

private static final Logger LOGGER = LoggerFactory.getLogger(VulnerabilityServiceImpl.class);
private static final int MAX_OSV_BATCH = 1000;

@Inject
VulnerabilityRepository repository;
Expand Down Expand Up @@ -80,6 +83,15 @@ public List<Vulnerability> find(List<String> aliases, boolean reload) {

@Override
public Map<String, List<Vulnerability>> findByPurls(List<String> purls, boolean reload) {
return new Partition(purls, MAX_OSV_BATCH)
.stream()
.parallel()
.map(p -> this.processBatch(p, reload))
.flatMap(m -> m.entrySet().stream())
.collect(Collectors.toMap(a -> a.getKey(), a -> a.getValue()));
}

private Map<String, List<Vulnerability>> processBatch(List<String> purls, boolean reload) {
List<QueryRequestItem> queries = purls.stream().map(purl -> new QueryRequestItem(new PackageRef(purl))).toList();
var response = osvApi.queryBatch(new QueryRequest(queries));
Map<String, List<Vulnerability>> vulnerabilities = new HashMap<>();
Expand All @@ -89,7 +101,12 @@ public Map<String, List<Vulnerability>> findByPurls(List<String> purls, boolean
if (resultItem == null || resultItem.vulns() == null) {
vulnerabilities.put(purl, Collections.emptyList());
} else {
vulnerabilities.put(purl, find(resultItem.vulns().stream().map(VulnerabilityRef::id).toList(), reload));
var vulns = find(resultItem.vulns().stream().map(VulnerabilityRef::id).toList(), reload);
if(vulns == null) {
vulnerabilities.put(purl, Collections.emptyList());
} else {
vulnerabilities.put(purl, vulns);
}
}
}
return vulnerabilities;
Expand All @@ -109,7 +126,7 @@ public void ingestNvdVulnerability(Vulnerability vuln) {
builder = Vulnerability.builder(vuln);
}
var osvVuln = getOsvVulnerabilityData(vuln.cveId());
if(osvVuln != null) {
if (osvVuln != null) {
builder.affected(osvVuln.affected()).summary(osvVuln.summary()).description(osvVuln.description());
}
repository.save(builder.build());
Expand All @@ -120,16 +137,16 @@ private Vulnerability load(VulnerabilityAlias vulnAlias, boolean reload) {
return vulnAlias.vulnerability();
}
var vuln = getOsvVulnerabilityData(vulnAlias.alias());
if(vuln == null || vuln.cveId() == null) {
if (vuln == null || vuln.cveId() == null) {
return null;
}
var metrics = nvdService.getCveMetrics(vuln.cveId());
var updated = Vulnerability.builder(vuln);
if(metrics != null) {
if (metrics != null) {
updated.metrics(metrics);
}
var existing = repository.get(vuln.cveId());
if(existing != null && existing.hasData()) {
if (existing != null && existing.hasData()) {
updated.created(existing.created()).lastModified(new Date());
} else {
updated.created(new Date());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,58 +41,58 @@
@ApplicationScoped
public class NvdFallbackService implements FallbackHandler<NvdResponse> {

private static final Logger LOGGER = LoggerFactory.getLogger(NvdFallbackService.class);
private static final Logger LOGGER = LoggerFactory.getLogger(NvdFallbackService.class);

@Inject
VulnerabilityRepository repository;
@Inject
VulnerabilityRepository repository;

@Inject
ManagedExecutor executor;
@Inject
ManagedExecutor executor;

@Inject
NvdService nvdService;
@Inject
NvdService nvdService;

@ConfigProperty(name = "CircuitBreaker/delay", defaultValue = "30")
Long delay;
@ConfigProperty(name = "CircuitBreaker/delay", defaultValue = "30")
Long delay;

private void updateMetrics(String cveId) {
var metrics = nvdService.getCveMetrics(cveId);
if (metrics == null) {
LOGGER.debug("Unable to retrieve metrics from NVD for CVE {}", cveId);
return;
}
var vuln = repository.get(cveId);
if (vuln != null) {
var newVuln = Vulnerability.builder(vuln).metrics(metrics).lastModified(new Date()).build();
repository.save(newVuln);
repository.setAliases(List.of(cveId), cveId);
}
private void updateMetrics(String cveId) {
var metrics = nvdService.getCveMetrics(cveId);
if (metrics == null) {
LOGGER.debug("Unable to retrieve metrics from NVD for CVE {}", cveId);
return;
}
var vuln = repository.get(cveId);
if (vuln != null) {
var newVuln = Vulnerability.builder(vuln).metrics(metrics).lastModified(new Date()).build();
repository.save(newVuln);
repository.setAliases(List.of(cveId), cveId);
}
}

@Override
public NvdResponse handle(ExecutionContext context) {
if(shouldHandle(context.getFailure())) {
Uni.createFrom()
.item((String) context.getParameters()[0])
.onItem().delayIt().by(Duration.ofSeconds(delay)).invoke(cveId -> updateMetrics(cveId))
.runSubscriptionOn(executor).subscribeAsCompletionStage();
}
return new NvdResponse(0, 0, 0, null, null, new Date(), Collections.emptyList());
@Override
public NvdResponse handle(ExecutionContext context) {
if (shouldHandle(context.getFailure())) {
Uni.createFrom()
.item((String) context.getParameters()[0])
.onItem().delayIt().by(Duration.ofSeconds(delay)).invoke(cveId -> updateMetrics(cveId))
.runSubscriptionOn(executor).subscribeAsCompletionStage();
}
return new NvdResponse(0, 0, 0, null, null, new Date(), Collections.emptyList());
}

private boolean shouldHandle(Throwable failure) {
if(failure == null) {
return true;
}
var cause = failure;
if(failure.getCause() != null && !(failure instanceof WebApplicationException)) {
cause = failure.getCause();
}
if(cause instanceof WebApplicationException) {
var error = (WebApplicationException) cause;
var status = error.getResponse().getStatus();
return status != 404;
}
return true;
private boolean shouldHandle(Throwable failure) {
if (failure == null) {
return true;
}
var cause = failure;
if (failure.getCause() != null && !(failure instanceof WebApplicationException)) {
cause = failure.getCause();
}
if (cause instanceof WebApplicationException) {
var error = (WebApplicationException) cause;
var status = error.getResponse().getStatus();
return status != 404;
}
return true;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
/*
* Copyright 2024 Red Hat, Inc. and/or its affiliates
* and other contributors as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.redhat.ecosystemappeng.onguard.service;

import static com.github.tomakehurst.wiremock.client.WireMock.postRequestedFor;
import static com.github.tomakehurst.wiremock.client.WireMock.resetAllRequests;
import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.anyList;
import static org.mockito.Mockito.when;

import java.util.Collections;
import java.util.List;

import org.jboss.resteasy.reactive.ClientWebApplicationException;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;

import com.github.tomakehurst.wiremock.WireMockServer;
import com.redhat.ecosystemappeng.onguard.model.Vulnerability;
import com.redhat.ecosystemappeng.onguard.model.VulnerabilityAlias;
import com.redhat.ecosystemappeng.onguard.repository.VulnerabilityRepository;
import com.redhat.ecosystemappeng.onguard.test.InjectWireMock;
import com.redhat.ecosystemappeng.onguard.test.WireMockExtensions;

import io.quarkus.test.InjectMock;
import io.quarkus.test.common.QuarkusTestResource;
import io.quarkus.test.junit.QuarkusTest;
import io.restassured.RestAssured;
import jakarta.inject.Inject;

@QuarkusTest
@QuarkusTestResource(WireMockExtensions.class)
public class VulnerabilityServiceTest {

@InjectWireMock
WireMockServer server;

@InjectMock
VulnerabilityRepository repository;

static {
RestAssured.enableLoggingOfRequestAndResponseIfValidationFails();
}

@AfterEach
public void reset() {
server.resetRequests();
}

@Inject
VulnerabilityService vulnerabilityService;

@Test
void testFindByPurls_Empty() {
var result = vulnerabilityService.findByPurls(null, false);
assertTrue(result.isEmpty());

result = vulnerabilityService.findByPurls(Collections.emptyList(), false);
assertTrue(result.isEmpty());

server.verify(0, postRequestedFor(urlEqualTo(WireMockExtensions.OSV_API_PATH)));
}

@Test
void testFindByPurls_NoVulns() {
var purl = "pkg:maven/org.mvnpm.at.vaadin/[email protected]?type=jar";
var purls = List.of(purl);
var results = vulnerabilityService.findByPurls(purls, false);

assertFalse(results.isEmpty());
assertEquals(purls.size(), results.size());
var result = results.get(purl);
assertNotNull(result);
assertTrue(result.isEmpty());

server.verify(1, postRequestedFor(urlEqualTo(WireMockExtensions.OSV_API_PATH)));
}

@Test
void testFindByPurls_Success() {
var purls = List.of(WireMockExtensions.PURL_WITH_VULNS);
var alias = "GHSA-hr8g-6v94-x4m9";
var vulnerability = new Vulnerability("cve-001", null, null, null, null, null, null);
when(repository.listByAliases(anyList())).thenReturn(List.of(new VulnerabilityAlias(alias, vulnerability)));

var results = vulnerabilityService.findByPurls(purls, false);

assertFalse(results.isEmpty());
assertEquals(purls.size(), results.size());
var result = results.get(WireMockExtensions.PURL_WITH_VULNS);
assertNotNull(result);
assertEquals(1, result.size());
var found = result.get(0);
assertEquals(vulnerability, found);

server.verify(1, postRequestedFor(urlEqualTo(WireMockExtensions.OSV_API_PATH)));
}

@Test
void testFindByPurls_Error() {
var purls = List.of(WireMockExtensions.PURL_WITH_ERROR);

assertThrows(ClientWebApplicationException.class, () -> vulnerabilityService.findByPurls(purls, false));

server.verify(1, postRequestedFor(urlEqualTo(WireMockExtensions.OSV_API_PATH)));
}
}
Loading

0 comments on commit a70ef0f

Please sign in to comment.