Skip to content

Commit

Permalink
work around lastfm dataset containing f64 query vectors
Browse files Browse the repository at this point in the history
  • Loading branch information
jbellis committed Oct 26, 2023
1 parent 9e02dd5 commit 39094dc
Showing 1 changed file with 18 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,14 @@
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import io.github.jbellis.jvector.vector.VectorUtil;
import io.jhdf.HdfFile;
import io.jhdf.api.Dataset;
import io.jhdf.object.datatype.DataType;
import io.jhdf.object.datatype.FloatingPoint;

import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;
import java.util.stream.IntStream;

public class Hdf5Loader {
public static final String HDF5_DIR = "hdf5/";
Expand All @@ -47,7 +51,20 @@ else if (filename.contains("-euclidean")) {
Path path = Path.of(HDF5_DIR).resolve(filename);
try (HdfFile hdf = new HdfFile(path)) {
baseVectors = (float[][]) hdf.getDatasetByPath("train").getData();
queryVectors = (float[][]) hdf.getDatasetByPath("test").getData();
Dataset queryDataset = hdf.getDatasetByPath("test");
if (((FloatingPoint) queryDataset.getDataType()).getBitPrecision() == 64) {
// lastfm dataset contains f64 queries but f32 everything else
var doubles = ((double[][]) queryDataset.getData());
queryVectors = IntStream.range(0, doubles.length).parallel().mapToObj(i -> {
var a = new float[doubles[i].length];
for (int j = 0; j < doubles[i].length; j++) {
a[j] = (float) doubles[i][j];
}
return a;
}).toArray(float[][]::new);
} else {
queryVectors = (float[][]) queryDataset.getData();
}
groundTruth = (int[][]) hdf.getDatasetByPath("neighbors").getData();
}

Expand Down

0 comments on commit 39094dc

Please sign in to comment.