Skip to content

Commit

Permalink
Merge pull request #1164 from flightstats/fix-access-monitor
Browse files Browse the repository at this point in the history
ensure get object stream is closed
  • Loading branch information
Paul-Hess authored Apr 24, 2019
2 parents cf6c8ba + 60c9433 commit 486e6ef
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 9 deletions.
18 changes: 13 additions & 5 deletions src/main/java/com/flightstats/hub/dao/aws/S3AccessMonitor.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,18 @@
import com.amazonaws.services.s3.model.ObjectMetadata;
import com.amazonaws.services.s3.model.PutObjectRequest;
import com.amazonaws.services.s3.model.PutObjectResult;
import com.amazonaws.services.s3.model.S3Object;
import com.flightstats.hub.dao.Dao;
import com.flightstats.hub.model.ChannelConfig;
import com.flightstats.hub.model.Content;
import com.flightstats.hub.model.ContentKey;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.joda.time.DateTime;

import javax.inject.Inject;
import javax.inject.Named;
import java.io.IOException;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;

Expand Down Expand Up @@ -78,13 +81,18 @@ private CompletableFuture<PutObjectResult> waitForWrite() {
}
}

private CompletableFuture<String> waitForRead(String versionId) {
private CompletableFuture<String> waitForRead() {
try {
return CompletableFuture.supplyAsync(() -> {
hubS3Client.getObject(new GetObjectRequest(s3BucketName.getS3BucketName(), key(), versionId));
return versionId;
try (S3Object s3Object = hubS3Client
.getObject(new GetObjectRequest(s3BucketName.getS3BucketName(), key()))) {
return s3Object.getObjectMetadata().getVersionId();
} catch (IOException e) {
log.info("error closing connection to s3", e);
return StringUtils.EMPTY;
}
});
} catch(Exception e) {
} catch (Exception e) {
log.error("error getting object from s3", e);
throw e;
}
Expand All @@ -94,7 +102,7 @@ public boolean verifyReadWriteAccess() {
try {
createChannelIfNotExist();
waitForWrite()
.thenCompose(putObjectResult -> waitForRead(putObjectResult.getVersionId())).get();
.thenCompose(result -> waitForRead()).get();
} catch (Exception e) {
log.error("error reaching S3: ", e);
return false;
Expand Down
42 changes: 38 additions & 4 deletions src/test/java/com/flightstats/hub/dao/aws/S3AccessMonitorTest.java
Original file line number Diff line number Diff line change
@@ -1,41 +1,48 @@
package com.flightstats.hub.dao.aws;

import com.amazonaws.services.s3.model.DeleteObjectRequest;
import com.amazonaws.services.s3.model.DeleteVersionRequest;
import com.amazonaws.services.s3.model.GetObjectRequest;
import com.amazonaws.services.s3.model.ObjectMetadata;
import com.amazonaws.services.s3.model.PutObjectRequest;
import com.amazonaws.services.s3.model.PutObjectResult;
import com.flightstats.hub.dao.CachedLowerCaseDao;
import com.amazonaws.services.s3.model.S3Object;
import com.flightstats.hub.dao.Dao;
import com.flightstats.hub.model.ChannelConfig;
import org.junit.Before;
import org.junit.Test;

import java.io.IOException;

import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.junit.Assert.assertFalse;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

public class S3AccessMonitorTest {
private final S3BucketName s3BucketName = new S3BucketName("test", "bucket");
private HubS3Client s3Client;
private S3AccessMonitor monitor;
private PutObjectResult putObjectResult;
private S3Object s3Object;

@Before
public void setUpTest() {
s3Client = mock(HubS3Client.class);
Dao<ChannelConfig> channelConfigDao = mock(DynamoChannelConfigDao.class);
monitor = new S3AccessMonitor(channelConfigDao, s3Client, s3BucketName);
s3Object = new S3Object();
ObjectMetadata metadata = new ObjectMetadata();
metadata.addUserMetadata("versionId", "testVersionId");
s3Object.setObjectMetadata(metadata);
putObjectResult = new PutObjectResult();
putObjectResult.setVersionId("testVersionId");
}

@Test
public void testVerifyReadWriteAccess_errorPutObject_false() {
doThrow(new RuntimeException("public void testVerifyReadWriteAccess_errorPutObject_false"))
doThrow(new RuntimeException("testVerifyReadWriteAccess_errorPutObject_false"))
.when(s3Client).putObject(any(PutObjectRequest.class));
assertFalse(monitor.verifyReadWriteAccess());
}
Expand All @@ -52,7 +59,34 @@ public void testVerifyReadWriteAccess_errorGetObject_false() {

@Test
public void testVerifyReadWriteAccess_mockVersionId_true() {
when(s3Client.getObject(any(GetObjectRequest.class))).thenReturn(s3Object);
when(s3Client.putObject(any(PutObjectRequest.class))).thenReturn(putObjectResult);
assertTrue(monitor.verifyReadWriteAccess());
}

@Test
public void testVerifyReadWriteAccess_closeCalledGreenField_true() throws IOException {
S3Object mockS3Object = mock(S3Object.class);
when(s3Client.getObject(any(GetObjectRequest.class))).thenReturn(mockS3Object);
when(s3Client.putObject(any(PutObjectRequest.class))).thenReturn(putObjectResult);
monitor.verifyReadWriteAccess();
verify(mockS3Object).close();
}

@Test
public void testVerifyReadWriteAccess_closeCalledWithErrorOnGet_true() throws IOException {
S3Object mockS3Object = mock(S3Object.class);
when(s3Client.getObject(any(GetObjectRequest.class)))
.thenReturn(mockS3Object)
.thenThrow(new RuntimeException("testVerifyReadWriteAccess_closeCalledErrorOnGet_true"));
assertFalse(monitor.verifyReadWriteAccess());
verify(mockS3Object).close();
}

@Test
public void testVerifyReadWriteAccess_handlesNullSafely_false() throws IOException {
when(s3Client.getObject(any(GetObjectRequest.class)))
.thenReturn(null);
assertFalse(monitor.verifyReadWriteAccess());
}
}

0 comments on commit 486e6ef

Please sign in to comment.