diff --git a/README.md b/README.md index 40bf25c..7c24125 100644 --- a/README.md +++ b/README.md @@ -4,14 +4,15 @@ This Plugin allows you to score Elasticsearch documents based on embedding-vecto ## General * Updated version for ES 6.1 of [this plugin](https://github.com/lior-k/fast-elasticsearch-vector-scoring) -* Cosine support removed. * This plugin was inspired from [This elasticsearch vector scoring plugin](https://github.com/MLnick/elasticsearch-vector-scoring) and [this discussion](https://discuss.elastic.co/t/vector-scoring/85227/6) to achieve 10 times faster processing over the original. * lior-k gained this substantial speed improvement by using the lucene index directly * lior-k developed it for their workplace which needs to pick KNN from a set of ~4M vectors. Their current ES setup is able to answer this in ~80ms +* Cosine similarity scoring is supported +* A vector's norm can be stored alongside the vector and reused when computing cosine similarity, so that it is not recalculated for each request ## Elasticsearch version -* Currently designed for Elasticsearch 6.1.2. +* Currently designed for Elasticsearch 6.3.1. ## Setup @@ -42,14 +43,45 @@ Place this into an elasticsearch checkout, add the plugin to the projects list i ## Usage ### Documents -* Each document you score should have a field containing the base64 representation of your vector. for example: +* Each document you score should have a field containing the base64 representation of your vector. For example, the vector: ``` - { - "id": 1, - .... 
- "content_vector": "v7l48eAAAAA/s4VHwAAAAD+R7I5AAAAAv8MBMAAAAAA/yEI3AAAAAL/IWkeAAAAAv7s480AAAAC/v6DUgAAAAL+wJi0gAAAAP76VqUAAAAC/sL1ZYAAAAL/dyq/gAAAAP62FVcAAAAC/tQRvYAAAAL+j6ycAAAAAP6v1KcAAAAC/bN5hQAAAAL+u9ItAAAAAP4ckTsAAAAC/pmkjYAAAAD+cYpwAAAAAP5renEAAAAC/qY0HQAAAAD+wyYGgAAAAP5WrCcAAAAA/qzjTQAAAAD++LBzAAAAAP49wNKAAAAC/vu/aIAAAAD+hqXfAAAAAP4FfNCAAAAA/pjC64AAAAL+qwT2gAAAAv6S3OGAAAAC/gfMtgAAAAD/If5ZAAAAAP5mcXOAAAAC/xYAU4AAAAL+2nlfAAAAAP7sCXOAAAAA/petBIAAAAD9soYnAAAAAv5R7X+AAAAC/pgM/IAAAAL+ojI/gAAAAP2gPz2AAAAA/3FonoAAAAL/IHg1AAAAAv6p1SmAAAAA/tvKlQAAAAD/I2OMAAAAAP3FBiCAAAAA/wEd8IAAAAL94wI9AAAAAP2Y1IIAAAAA/rnS4wAAAAL9vriVgAAAAv1QxoCAAAAC/1/qu4AAAAL+inZFAAAAAv7aGA+AAAAA/lqYVYAAAAD+kNP0AAAAAP730BiAAAAA=" - } - ``` +[-0.09950172156095505, 0.07625244557857513, 0.017503950744867325, -0.14847373962402344, 0.1895207166671753, -0.19025510549545288, -0.10633774101734161, -0.12354782223701477, -0.06308252364397049, 0.11947114765644073, -0.0653892382979393, -0.4654960334300995, 0.057657890021800995, -0.08209892362356186, -0.03890344500541687, 0.054604820907115936, -0.0035240077413618565, -0.06045947223901749, 0.011299720034003258, -0.043770890682935715, 0.02771991491317749, 0.02623981609940529, -0.04990408569574356, 0.06557474285364151, 0.021160271018743515, 0.0531679168343544, 0.11786060035228729, 0.015350733883678913, -0.12084735184907913, 0.034496061503887177, 0.008482367731630802, 0.0433405302464962, -0.05225555971264839, -0.040460359305143356, -0.008764605969190598, 0.19139364361763, 0.02501053921878338, -0.16797123849391937, -0.08835361897945404, 0.10550480335950851, 0.04281047359108925, 0.0034949961118400097, -0.020001886412501335, -0.04299351945519447, -0.04794740304350853, 0.0029372263234108686, 0.4430026113986969, -0.18841710686683655, -0.051676105707883835, 0.08963997662067413, 0.19411885738372803, 0.004212886560708284, 0.1271815448999405, -0.006043014116585255, 0.0027108797803521156, 0.05948426574468613, -0.0038672189693897963, -0.0012325347634032369, 
-0.3746754825115204, -0.03635839372873306, -0.0879824087023735, 0.02211793325841427, 0.03946676850318909, 0.11700475960969925] + +``` +should be converted to: +``` +{ + "id": 1, + .... + "content_vector": "v7l48eAAAAA/s4VHwAAAAD+R7I5AAAAAv8MBMAAAAAA/yEI3AAAAAL/IWkeAAAAAv7s480AAAAC/v6DUgAAAAL+wJi0gAAAAP76VqUAAAAC/sL1ZYAAAAL/dyq/gAAAAP62FVcAAAAC/tQRvYAAAAL+j6ycAAAAAP6v1KcAAAAC/bN5hQAAAAL+u9ItAAAAAP4ckTsAAAAC/pmkjYAAAAD+cYpwAAAAAP5renEAAAAC/qY0HQAAAAD+wyYGgAAAAP5WrCcAAAAA/qzjTQAAAAD++LBzAAAAAP49wNKAAAAC/vu/aIAAAAD+hqXfAAAAAP4FfNCAAAAA/pjC64AAAAL+qwT2gAAAAv6S3OGAAAAC/gfMtgAAAAD/If5ZAAAAAP5mcXOAAAAC/xYAU4AAAAL+2nlfAAAAAP7sCXOAAAAA/petBIAAAAD9soYnAAAAAv5R7X+AAAAC/pgM/IAAAAL+ojI/gAAAAP2gPz2AAAAA/3FonoAAAAL/IHg1AAAAAv6p1SmAAAAA/tvKlQAAAAD/I2OMAAAAAP3FBiCAAAAA/wEd8IAAAAL94wI9AAAAAP2Y1IIAAAAA/rnS4wAAAAL9vriVgAAAAv1QxoCAAAAC/1/qu4AAAAL+inZFAAAAAv7aGA+AAAAA/lqYVYAAAAD+kNP0AAAAAP730BiAAAAA/8AAAAq9PCQ==" +} +``` + +Also in this example I saved vector norm as last element of array and convert vector and its norm to base64 representation + +### Converting a vector to Base64 + +**Python** +``` +import struct +import base64 +import math + + +def calc_vector_norm(vector): + norm = 0 + for elem in vector: + norm += elem**2 + return math.sqrt(norm) + + +def encode_v(vector): + vector.append(calc_vector_norm(vector)) + return base64.b64encode( + struct.pack('>%sd' % len(vector), *vector) + ).decode('utf-8') +``` + * Use this field mapping: ``` PUT my_index @@ -88,6 +120,8 @@ POST /_search "source": "vector_scoring", "lang": "binary_vector_score", "params": { + "use_stored_vector_norm": true, + "cosine": true, "vector_field": "content_vector", "vector": [ -0.09217305481433868, @@ -167,5 +201,7 @@ POST /_search * The example above shows a vector of 64 dimensions * Parameters: 1. `field_vector`: The field containing the base64 vector. - 3. `vector`: The vector (comma separated) to compare to. + 2. `vector`: The vector (comma separated) to compare to. + 3. 
`cosine`: (boolean, default `false`) if `true`, the score is the cosine similarity of the two vectors; otherwise it is their dot product + 4. `use_stored_vector_norm`: (boolean, default `false`) if `true`, the norm stored as the last element of the document vector is used when computing cosine similarity instead of being recalculated diff --git a/src/main/java/com/gosololaw/elasticsearch/VectorScoringPlugin.java b/src/main/java/com/gosololaw/elasticsearch/VectorScoringPlugin.java index 3d2c8fe..e6c0f5e 100755 --- a/src/main/java/com/gosololaw/elasticsearch/VectorScoringPlugin.java +++ b/src/main/java/com/gosololaw/elasticsearch/VectorScoringPlugin.java @@ -60,11 +60,16 @@ public T compile(String scriptName, String scriptSource, ScriptContext co if (!context.equals(SearchScript.CONTEXT)) { throw new IllegalArgumentException(getType() + " scripts cannot be used for context [" + context.name + "]"); } + // we use the script "source" as the script identifier if ("vector_scoring".equals(scriptSource)) { SearchScript.Factory factory = (p, lookup) -> new SearchScript.LeafFactory() { + private final double[] inputVector; + private final double inputVectorNorm; final String field; + final boolean useStoredVectorNorm; + final boolean cosine; { final Object field = p.get("vector_field"); if (field == null) @@ -74,9 +79,26 @@ public T compile(String scriptName, String scriptSource, ScriptContext co // get query inputVector - convert to primitive final ArrayList tmp = (ArrayList) p.get("vector"); this.inputVector = new double[tmp.size()]; + for (int i = 0; i < inputVector.length; i++) { inputVector[i] = tmp.get(i); } + + final Object cosine = p.get("cosine"); + this.cosine = cosine != null && (boolean)cosine; + + if (this.cosine) { + double norm = 0.0f; + for (int i = 0; i < inputVector.length; i++) { + norm += inputVector[i] * inputVector[i]; + } + this.inputVectorNorm = Math.sqrt(norm); + } else { + this.inputVectorNorm = 0; + } + + final Object useStoredVectorNorm = p.get("use_stored_vector_norm"); + this.useStoredVectorNorm = useStoredVectorNorm != null && (boolean)useStoredVectorNorm; } @Override @@ -101,17 +123,23 @@ 
public double runAsDouble() { if (!is_value) return 0; final byte[] bytes; try { - bytes = accessor.binaryValue().bytes; + bytes = accessor.binaryValue().bytes; } catch (IOException e) { - return 0; + return 0; } final int input_vector_size = inputVector.length; final ByteArrayDataInput doc_vector = new ByteArrayDataInput(bytes); + doc_vector.readVInt(); // returns the number of values which should be 1, MUST appear hear since it affect the next calls - final int doc_vector_length = doc_vector.readVInt(); // returns the number of bytes to read - if(doc_vector_length != input_vector_size * DOUBLE_SIZE) { + int doc_vector_length = doc_vector.readVInt(); // returns the number of bytes to read + + if (useStoredVectorNorm) { + doc_vector_length = doc_vector_length - DOUBLE_SIZE; + } + + if(doc_vector_length < input_vector_size * DOUBLE_SIZE) { return 0.0; } final int position = doc_vector.getPosition(); @@ -124,6 +152,28 @@ public double runAsDouble() { for (int i = 0; i < input_vector_size; i++) { score += docVector[i] * inputVector[i]; } + + if (!cosine) { + return score; + } + + + double docVectorNorm = 0.0f; + + if (useStoredVectorNorm) { + doc_vector.skipBytes(doc_vector_length); + docVectorNorm = Double.longBitsToDouble(doc_vector.readLong()); + } else { + for (int i = 0; i < input_vector_size; i++) { + docVectorNorm += docVector[i] * docVector[i]; + } + docVectorNorm = Math.sqrt(docVectorNorm); + } + + score /= inputVectorNorm; + score /= docVectorNorm; + + return score; } };