Skip to content

Commit

Permalink
Updated downloadhelper (#134)
Browse files Browse the repository at this point in the history
  • Loading branch information
msmygit authored Oct 28, 2023
1 parent 39094dc commit 08e79d3
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 26 deletions.
12 changes: 10 additions & 2 deletions jvector-examples/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
</parent>
<artifactId>jvector-examples</artifactId>
<name>JVector Examples</name>
<properties>
<awssdk.version>2.21.10</awssdk.version>
</properties>
<build>
<plugins>
<plugin>
Expand Down Expand Up @@ -42,12 +45,17 @@
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>s3-transfer-manager</artifactId>
<version>2.21.2</version>
<version>${awssdk.version}</version>
</dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>aws-crt-client</artifactId>
<version>2.21.2</version>
<version>${awssdk.version}</version>
</dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>s3</artifactId>
<version>${awssdk.version}</version>
</dependency>
<dependency>
<groupId>com.kohlschutter.junixsocket</groupId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.S3AsyncClientBuilder;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.ListObjectsV2Request;
import software.amazon.awssdk.services.s3.model.ListObjectsV2Response;
import software.amazon.awssdk.services.s3.model.S3Object;
import software.amazon.awssdk.transfer.s3.S3TransferManager;
import software.amazon.awssdk.transfer.s3.model.CompletedFileDownload;
import software.amazon.awssdk.transfer.s3.model.DownloadFileRequest;
Expand All @@ -19,31 +23,39 @@
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

public class DownloadHelper {
private static String bucketName = "astra-vector";

public static void maybeDownloadFvecs() {
// TODO how to detect and recover from incomplete downloads?
String[] keys = {
"wikipedia_squad/100k/ada_002_100000_base_vectors.fvec",
"wikipedia_squad/100k/ada_002_100000_query_vectors_10000.fvec",
"wikipedia_squad/100k/ada_002_100000_indices_query_10000.ivec"
};

String bucketName = "astra-vector";

private static S3AsyncClientBuilder getS3AsyncClientBuilder() {
S3AsyncClientBuilder s3ClientBuilder = S3AsyncClient.builder()
.region(Region.of("us-east-1"))
.region(Region.US_EAST_1)
.httpClient(AwsCrtAsyncHttpClient.builder()
.maxConcurrency(1)
.build())
.maxConcurrency(1)
.build())
.credentialsProvider(AnonymousCredentialsProvider.create());
return s3ClientBuilder;
}

public static void maybeDownloadFvecs(List<String> files) {
List<String> keys;
if (null == files || files.isEmpty()) {
keys = Arrays.asList(new String[] {
"wikipedia_squad/100k/ada_002_100000_base_vectors.fvec",
"wikipedia_squad/100k/ada_002_100000_query_vectors_10000.fvec",
"wikipedia_squad/100k/ada_002_100000_indices_query_10000.ivec",
});
} else {
keys = files;
}
// TODO how to detect and recover from incomplete downloads?

// get directory from paths in keys
List<String> dirs = Arrays.stream(keys).map(key -> key.substring(0, key.lastIndexOf("/"))).distinct().collect(Collectors.toList());
List<String> dirs = keys.stream().map(key -> key.substring(0, key.lastIndexOf("/"))).distinct().collect(Collectors.toList());
for (String dir : dirs) {
try {
dir = "fvec/" + dir;
Expand All @@ -53,7 +65,7 @@ public static void maybeDownloadFvecs() {
}
}

try (S3AsyncClient s3Client = s3ClientBuilder.build()) {
try (S3AsyncClient s3Client = getS3AsyncClientBuilder().build()) {
S3TransferManager tm = S3TransferManager.builder().s3Client(s3Client).build();
for (String key : keys) {
Path path = Paths.get("fvec", key);
Expand All @@ -69,11 +81,20 @@ public static void maybeDownloadFvecs() {
.destination(Paths.get(path.toString()))
.build();

FileDownload downloadFile = tm.downloadFile(downloadFileRequest);

CompletedFileDownload downloadResult = downloadFile.completionFuture().join();
System.out.println("Downloaded file of length " + downloadResult.response().contentLength());

// Up to 3 attempts in total (the initial download plus at most 2 retries)
for (int i = 0; i < 3; i++) {
FileDownload downloadFile = tm.downloadFile(downloadFileRequest);
CompletedFileDownload downloadResult = downloadFile.completionFuture().join();
long downloadedSize = Files.size(path);

// Check if downloaded file size matches the expected size
if (downloadedSize == downloadResult.response().contentLength()) {
System.out.println("Downloaded file of length " + downloadResult.response().contentLength());
break; // Successfully downloaded
} else {
System.out.println("Incomplete download. Retrying...");
}
}
}
tm.close();
} catch (Exception e) {
Expand All @@ -82,14 +103,19 @@ public static void maybeDownloadFvecs() {
}
}

/**
 * Downloads the default fvec/ivec files.
 * <p>
 * Convenience overload: delegates to {@link #maybeDownloadFvecs(List)} with {@code null},
 * which that overload treats the same as an empty list and replaces with its built-in keys.
 */
public static void maybeDownloadFvecs() {
maybeDownloadFvecs(null);
}

public static void maybeDownloadHdf5(String datasetName) {
var fullPath = Path.of(Hdf5Loader.HDF5_DIR).resolve(datasetName);
Path path = Path.of(Hdf5Loader.HDF5_DIR);
var fullPath = path.resolve(datasetName);
if (Files.exists(fullPath)) {
return;
}

// Download from http://ann-benchmarks.com/datasetName
var url = "http://ann-benchmarks.com/" + datasetName;
// Download from https://ann-benchmarks.com/datasetName
var url = "https://ann-benchmarks.com/" + datasetName;
System.out.println("Downloading: " + url);

HttpURLConnection connection = null;
Expand All @@ -111,11 +137,22 @@ public static void maybeDownloadHdf5(String datasetName) {
}

try (InputStream in = connection.getInputStream()) {
Files.createDirectories(Path.of(Hdf5Loader.HDF5_DIR));
Files.createDirectories(path);
Files.copy(in, fullPath, StandardCopyOption.REPLACE_EXISTING);
} catch (IOException e) {
System.out.println("Error downloading data: " + e.getMessage());
System.exit(1);
}
}

/**
 * Lists the keys of all objects in the {@code bucketName} S3 bucket.
 * <p>
 * The bucket is read with anonymous credentials in the US_EAST_1 region. The listing goes
 * through the SDK paginator, so it is not cut off after the first ListObjectsV2 response
 * page, and the client is closed once the keys have been collected.
 *
 * @return the object keys in the bucket, in the order S3 returns them
 */
public static List<String> s3FileListing() {
    ListObjectsV2Request req = ListObjectsV2Request.builder().bucket(bucketName).build();
    // S3Client owns HTTP resources; try-with-resources releases them even if the listing fails.
    try (S3Client s3 = S3Client.builder()
            .region(Region.US_EAST_1)
            .credentialsProvider(AnonymousCredentialsProvider.create())
            .build()) {
        // The paginator follows continuation tokens while the stream is consumed, so the
        // stream must be fully collected before the client is closed.
        return s3.listObjectsV2Paginator(req).contents().stream()
                .map(S3Object::key)
                .collect(Collectors.toList());
    }
}
}

0 comments on commit 08e79d3

Please sign in to comment.